diff --git a/core/mvcc/database/src/database.rs b/core/mvcc/database/src/database.rs index a1dd27265..51176c617 100644 --- a/core/mvcc/database/src/database.rs +++ b/core/mvcc/database/src/database.rs @@ -119,6 +119,25 @@ pub struct DatabaseInner { clock: Clock, } +impl DatabaseInner { + fn rollback_tx(&self, tx_id: TxID) { + let mut txs = self.txs.borrow_mut(); + let mut tx = txs.get_mut(&tx_id).unwrap(); + assert!(tx.state == TransactionState::Active); + tx.state = TransactionState::Aborted; + let mut rows = self.rows.borrow_mut(); + for id in &tx.write_set { + if let Some(row_versions) = rows.get_mut(id) { + row_versions.retain(|rv| rv.begin != TxTimestampOrID::TxID(tx_id)); + if row_versions.is_empty() { + rows.remove(id); + } + } + } + tx.state = TransactionState::Terminated; + } +} + impl Database { /// Creates a new database. pub fn new(clock: Clock) -> Self { @@ -206,26 +225,29 @@ impl Database { let inner = self.inner.lock().unwrap(); let mut rows = inner.rows.borrow_mut(); let mut txs = inner.txs.borrow_mut(); - match rows.get_mut(&id) { - Some(row_versions) => match row_versions.last_mut() { - Some(v) => { - let tx = txs.get(&tx_id).ok_or(DatabaseError::NoSuchTransactionID(tx_id))?; - assert!(tx.state == TransactionState::Active); - if is_version_visible(&txs, tx, v) { - v.end = Some(TxTimestampOrID::TxID(tx.tx_id)); - } else { - return Ok(false); - } + if let Some(row_versions) = rows.get_mut(&id) { + for rv in row_versions.iter_mut().rev() { + let tx = txs + .get(&tx_id) + .ok_or(DatabaseError::NoSuchTransactionID(tx_id))?; + assert!(tx.state == TransactionState::Active); + if is_write_write_conflict(&txs, tx, rv) { + drop(txs); + drop(rows); + inner.rollback_tx(tx_id); + return Err(DatabaseError::WriteWriteConflict); } - None => unreachable!("no versions for row {}", id), - }, - None => return Ok(false), + if is_version_visible(&txs, tx, rv) { + rv.end = Some(TxTimestampOrID::TxID(tx.tx_id)); + let tx = txs + .get_mut(&tx_id) + .ok_or(DatabaseError::NoSuchTransactionID(tx_id))?; + tx.insert_to_write_set(id); + return Ok(true); + } + } } - let tx = txs - .get_mut(&tx_id) - .ok_or(DatabaseError::NoSuchTransactionID(tx_id))?; - tx.insert_to_write_set(id); - Ok(true) + Ok(false) } /// Retrieves a row from the table with the given `id`. @@ -283,12 +305,17 @@ impl Database { /// # Arguments /// /// * `tx_id` - The ID of the transaction to commit. - pub fn commit_tx(&self, tx_id: TxID) { + pub fn commit_tx(&self, tx_id: TxID) -> Result<()> { let mut inner = self.inner.lock().unwrap(); let end_ts = get_timestamp(&mut inner); let mut txs = inner.txs.borrow_mut(); let mut tx = txs.get_mut(&tx_id).unwrap(); - assert!(tx.state == TransactionState::Active); + match tx.state { + TransactionState::Terminated => return Err(DatabaseError::TxTerminated), + _ => { + assert!(tx.state == TransactionState::Active); + } + } let mut rows = inner.rows.borrow_mut(); tx.state = TransactionState::Preparing; for id in &tx.write_set { @@ -308,6 +335,7 @@ impl Database { } } tx.state = TransactionState::Committed; + Ok(()) } /// Rolls back a transaction with the specified ID. @@ -320,20 +348,28 @@ impl Database { /// * `tx_id` - The ID of the transaction to abort. pub fn rollback_tx(&self, tx_id: TxID) { let inner = self.inner.lock().unwrap(); - let mut txs = inner.txs.borrow_mut(); - let mut tx = txs.get_mut(&tx_id).unwrap(); - assert!(tx.state == TransactionState::Active); - tx.state = TransactionState::Aborted; - let mut rows = inner.rows.borrow_mut(); - for id in &tx.write_set { - if let Some(row_versions) = rows.get_mut(id) { - row_versions.retain(|rv| rv.begin != TxTimestampOrID::TxID(tx_id)); - if row_versions.is_empty() { - rows.remove(id); - } + inner.rollback_tx(tx_id); + } +} + +fn is_write_write_conflict( + txs: &HashMap, + tx: &Transaction, + rv: &RowVersion, +) -> bool { + match rv.end { + Some(TxTimestampOrID::TxID(rv_end)) => { + let te = txs.get(&rv_end).unwrap(); + match te.state { + TransactionState::Active => tx.tx_id != te.tx_id, + TransactionState::Preparing => todo!(), + TransactionState::Committed => todo!(), + TransactionState::Aborted => todo!(), + TransactionState::Terminated => todo!(), } } - tx.state = TransactionState::Terminated; + Some(TxTimestampOrID::Timestamp(_)) => false, + None => false, } } @@ -399,7 +435,7 @@ mod tests { db.insert(tx1, tx1_row.clone()).unwrap(); let row = db.read(tx1, 1).unwrap().unwrap(); assert_eq!(tx1_row, row); - db.commit_tx(tx1); + db.commit_tx(tx1).unwrap(); let tx2 = db.begin_tx(); let row = db.read(tx2, 1).unwrap().unwrap(); @@ -431,7 +467,7 @@ mod tests { db.delete(tx1, 1).unwrap(); let row = db.read(tx1, 1).unwrap(); assert!(row.is_none()); - db.commit_tx(tx1); + db.commit_tx(tx1).unwrap(); let tx2 = db.begin_tx(); let row = db.read(tx2, 1).unwrap(); @@ -465,11 +501,11 @@ mod tests { db.update(tx1, tx1_updated_row.clone()).unwrap(); let row = db.read(tx1, 1).unwrap().unwrap(); assert_eq!(tx1_updated_row, row); - db.commit_tx(tx1); + db.commit_tx(tx1).unwrap(); let tx2 = db.begin_tx(); let row = db.read(tx2, 1).unwrap().unwrap(); - db.commit_tx(tx2); + db.commit_tx(tx2).unwrap(); assert_eq!(tx1_updated_row, row); } @@ -557,7 +593,7 @@ mod tests { data: "Hello".to_string(), }; db.insert(tx1, tx1_row.clone()).unwrap(); - db.commit_tx(tx1); + db.commit_tx(tx1).unwrap(); // T2 deletes row with ID 1, but does not commit. let tx2 = db.begin_tx(); @@ -583,7 +619,7 @@ mod tests { db.insert(tx1, tx1_row.clone()).unwrap(); let row = db.read(tx1.clone(), 1).unwrap().unwrap(); assert_eq!(tx1_row, row); - db.commit_tx(tx1); + db.commit_tx(tx1).unwrap(); // T2 reads the row with ID 1 within an active transaction. let tx2 = db.begin_tx(); @@ -597,14 +633,13 @@ mod tests { data: "World".to_string(), }; db.update(tx3, tx3_row.clone()).unwrap(); - db.commit_tx(tx3); + db.commit_tx(tx3).unwrap(); // T2 still reads the same version of the row as before. let row = db.read(tx2, 1).unwrap().unwrap(); assert_eq!(tx1_row, row); } - #[ignore] #[test] fn test_lost_update() { let clock = LocalClock::default(); @@ -619,7 +654,7 @@ mod tests { db.insert(tx1, tx1_row.clone()).unwrap(); let row = db.read(tx1.clone(), 1).unwrap().unwrap(); assert_eq!(tx1_row, row); - db.commit_tx(tx1); + db.commit_tx(tx1).unwrap(); // T2 attempts to update row ID 1 within an active transaction. let tx2 = db.begin_tx(); @@ -627,7 +662,7 @@ mod tests { id: 1, data: "World".to_string(), }; - db.update(tx2, tx2_row.clone()).unwrap(); + assert!(db.update(tx2, tx2_row.clone()).unwrap()); // T3 also attempts to update row ID 1 within an active transaction. let tx3 = db.begin_tx(); @@ -635,10 +670,13 @@ mod tests { id: 1, data: "Hello, world!".to_string(), }; - db.update(tx3, tx3_row.clone()).unwrap(); + assert_eq!( + Err(DatabaseError::WriteWriteConflict), + db.update(tx3, tx3_row.clone()) + ); - db.commit_tx(tx2); - db.commit_tx(tx3); // TODO: this should fail + db.commit_tx(tx2).unwrap(); + assert_eq!(Err(DatabaseError::TxTerminated), db.commit_tx(tx3)); let tx4 = db.begin_tx(); let row = db.read(tx4, 1).unwrap().unwrap(); diff --git a/core/mvcc/database/src/errors.rs b/core/mvcc/database/src/errors.rs index 95901137b..7bd5bab57 100644 --- a/core/mvcc/database/src/errors.rs +++ b/core/mvcc/database/src/errors.rs @@ -1,7 +1,11 @@ use thiserror::Error; -#[derive(Error, Debug)] +#[derive(Error, Debug, PartialEq)] pub enum DatabaseError { #[error("no such transaction ID: `{0}`")] NoSuchTransactionID(u64), + #[error("transaction aborted because of a write-write conflict")] + WriteWriteConflict, + #[error("transaction is terminated")] + TxTerminated, } diff --git a/core/mvcc/database/tests/concurrency_test.rs b/core/mvcc/database/tests/concurrency_test.rs index a39a87dec..b18e9a9bd 100644 --- a/core/mvcc/database/tests/concurrency_test.rs +++ b/core/mvcc/database/tests/concurrency_test.rs @@ -24,10 +24,10 @@ fn test_non_overlapping_concurrent_inserts() { data: "Hello".to_string(), }; db.insert(tx, row.clone()).unwrap(); - db.commit_tx(tx); + db.commit_tx(tx).unwrap(); let tx = db.begin_tx(); let committed_row = db.read(tx, id).unwrap(); - db.commit_tx(tx); + db.commit_tx(tx).unwrap(); assert_eq!(committed_row, Some(row)); }); } @@ -42,10 +42,10 @@ fn test_non_overlapping_concurrent_inserts() { data: "World".to_string(), }; db.insert(tx, row.clone()).unwrap(); - db.commit_tx(tx); + db.commit_tx(tx).unwrap(); let tx = db.begin_tx(); let committed_row = db.read(tx, id).unwrap(); - db.commit_tx(tx); + db.commit_tx(tx).unwrap(); assert_eq!(committed_row, Some(row)); }); }