Fix lost update anomaly

Fixes #5
This commit is contained in:
Pekka Enberg
2023-04-13 09:36:10 +03:00
parent eb250e1e83
commit d7ecfc054c
3 changed files with 92 additions and 50 deletions

View File

@@ -119,6 +119,25 @@ pub struct DatabaseInner<Clock: LogicalClock> {
clock: Clock,
}
impl<Clock: LogicalClock> DatabaseInner<Clock> {
fn rollback_tx(&self, tx_id: TxID) {
let mut txs = self.txs.borrow_mut();
let mut tx = txs.get_mut(&tx_id).unwrap();
assert!(tx.state == TransactionState::Active);
tx.state = TransactionState::Aborted;
let mut rows = self.rows.borrow_mut();
for id in &tx.write_set {
if let Some(row_versions) = rows.get_mut(id) {
row_versions.retain(|rv| rv.begin != TxTimestampOrID::TxID(tx_id));
if row_versions.is_empty() {
rows.remove(id);
}
}
}
tx.state = TransactionState::Terminated;
}
}
impl<Clock: LogicalClock> Database<Clock> {
/// Creates a new database.
pub fn new(clock: Clock) -> Self {
@@ -206,26 +225,29 @@ impl<Clock: LogicalClock> Database<Clock> {
let inner = self.inner.lock().unwrap();
let mut rows = inner.rows.borrow_mut();
let mut txs = inner.txs.borrow_mut();
match rows.get_mut(&id) {
Some(row_versions) => match row_versions.last_mut() {
Some(v) => {
let tx = txs.get(&tx_id).ok_or(DatabaseError::NoSuchTransactionID(tx_id))?;
assert!(tx.state == TransactionState::Active);
if is_version_visible(&txs, tx, v) {
v.end = Some(TxTimestampOrID::TxID(tx.tx_id));
} else {
return Ok(false);
}
if let Some(row_versions) = rows.get_mut(&id) {
for rv in row_versions.iter_mut().rev() {
let tx = txs
.get(&tx_id)
.ok_or(DatabaseError::NoSuchTransactionID(tx_id))?;
assert!(tx.state == TransactionState::Active);
if is_write_write_conflict(&txs, tx, rv) {
drop(txs);
drop(rows);
inner.rollback_tx(tx_id);
return Err(DatabaseError::WriteWriteConflict);
}
None => unreachable!("no versions for row {}", id),
},
None => return Ok(false),
if is_version_visible(&txs, tx, rv) {
rv.end = Some(TxTimestampOrID::TxID(tx.tx_id));
let tx = txs
.get_mut(&tx_id)
.ok_or(DatabaseError::NoSuchTransactionID(tx_id))?;
tx.insert_to_write_set(id);
return Ok(true);
}
}
}
let tx = txs
.get_mut(&tx_id)
.ok_or(DatabaseError::NoSuchTransactionID(tx_id))?;
tx.insert_to_write_set(id);
Ok(true)
Ok(false)
}
/// Retrieves a row from the table with the given `id`.
@@ -283,12 +305,17 @@ impl<Clock: LogicalClock> Database<Clock> {
/// # Arguments
///
/// * `tx_id` - The ID of the transaction to commit.
pub fn commit_tx(&self, tx_id: TxID) {
pub fn commit_tx(&self, tx_id: TxID) -> Result<()> {
let mut inner = self.inner.lock().unwrap();
let end_ts = get_timestamp(&mut inner);
let mut txs = inner.txs.borrow_mut();
let mut tx = txs.get_mut(&tx_id).unwrap();
assert!(tx.state == TransactionState::Active);
match tx.state {
TransactionState::Terminated => return Err(DatabaseError::TxTerminated),
_ => {
assert!(tx.state == TransactionState::Active);
}
}
let mut rows = inner.rows.borrow_mut();
tx.state = TransactionState::Preparing;
for id in &tx.write_set {
@@ -308,6 +335,7 @@ impl<Clock: LogicalClock> Database<Clock> {
}
}
tx.state = TransactionState::Committed;
Ok(())
}
/// Rolls back a transaction with the specified ID.
@@ -320,20 +348,28 @@ impl<Clock: LogicalClock> Database<Clock> {
/// * `tx_id` - The ID of the transaction to abort.
pub fn rollback_tx(&self, tx_id: TxID) {
let inner = self.inner.lock().unwrap();
let mut txs = inner.txs.borrow_mut();
let mut tx = txs.get_mut(&tx_id).unwrap();
assert!(tx.state == TransactionState::Active);
tx.state = TransactionState::Aborted;
let mut rows = inner.rows.borrow_mut();
for id in &tx.write_set {
if let Some(row_versions) = rows.get_mut(id) {
row_versions.retain(|rv| rv.begin != TxTimestampOrID::TxID(tx_id));
if row_versions.is_empty() {
rows.remove(id);
}
inner.rollback_tx(tx_id);
}
}
fn is_write_write_conflict(
txs: &HashMap<TxID, Transaction>,
tx: &Transaction,
rv: &RowVersion,
) -> bool {
match rv.end {
Some(TxTimestampOrID::TxID(rv_end)) => {
let te = txs.get(&rv_end).unwrap();
match te.state {
TransactionState::Active => tx.tx_id != te.tx_id,
TransactionState::Preparing => todo!(),
TransactionState::Committed => todo!(),
TransactionState::Aborted => todo!(),
TransactionState::Terminated => todo!(),
}
}
tx.state = TransactionState::Terminated;
Some(TxTimestampOrID::Timestamp(_)) => false,
None => false,
}
}
@@ -399,7 +435,7 @@ mod tests {
db.insert(tx1, tx1_row.clone()).unwrap();
let row = db.read(tx1, 1).unwrap().unwrap();
assert_eq!(tx1_row, row);
db.commit_tx(tx1);
db.commit_tx(tx1).unwrap();
let tx2 = db.begin_tx();
let row = db.read(tx2, 1).unwrap().unwrap();
@@ -431,7 +467,7 @@ mod tests {
db.delete(tx1, 1).unwrap();
let row = db.read(tx1, 1).unwrap();
assert!(row.is_none());
db.commit_tx(tx1);
db.commit_tx(tx1).unwrap();
let tx2 = db.begin_tx();
let row = db.read(tx2, 1).unwrap();
@@ -465,11 +501,11 @@ mod tests {
db.update(tx1, tx1_updated_row.clone()).unwrap();
let row = db.read(tx1, 1).unwrap().unwrap();
assert_eq!(tx1_updated_row, row);
db.commit_tx(tx1);
db.commit_tx(tx1).unwrap();
let tx2 = db.begin_tx();
let row = db.read(tx2, 1).unwrap().unwrap();
db.commit_tx(tx2);
db.commit_tx(tx2).unwrap();
assert_eq!(tx1_updated_row, row);
}
@@ -557,7 +593,7 @@ mod tests {
data: "Hello".to_string(),
};
db.insert(tx1, tx1_row.clone()).unwrap();
db.commit_tx(tx1);
db.commit_tx(tx1).unwrap();
// T2 deletes row with ID 1, but does not commit.
let tx2 = db.begin_tx();
@@ -583,7 +619,7 @@ mod tests {
db.insert(tx1, tx1_row.clone()).unwrap();
let row = db.read(tx1.clone(), 1).unwrap().unwrap();
assert_eq!(tx1_row, row);
db.commit_tx(tx1);
db.commit_tx(tx1).unwrap();
// T2 reads the row with ID 1 within an active transaction.
let tx2 = db.begin_tx();
@@ -597,14 +633,13 @@ mod tests {
data: "World".to_string(),
};
db.update(tx3, tx3_row.clone()).unwrap();
db.commit_tx(tx3);
db.commit_tx(tx3).unwrap();
// T2 still reads the same version of the row as before.
let row = db.read(tx2, 1).unwrap().unwrap();
assert_eq!(tx1_row, row);
}
#[ignore]
#[test]
fn test_lost_update() {
let clock = LocalClock::default();
@@ -619,7 +654,7 @@ mod tests {
db.insert(tx1, tx1_row.clone()).unwrap();
let row = db.read(tx1.clone(), 1).unwrap().unwrap();
assert_eq!(tx1_row, row);
db.commit_tx(tx1);
db.commit_tx(tx1).unwrap();
// T2 attempts to update row ID 1 within an active transaction.
let tx2 = db.begin_tx();
@@ -627,7 +662,7 @@ mod tests {
id: 1,
data: "World".to_string(),
};
db.update(tx2, tx2_row.clone()).unwrap();
assert!(db.update(tx2, tx2_row.clone()).unwrap());
// T3 also attempts to update row ID 1 within an active transaction.
let tx3 = db.begin_tx();
@@ -635,10 +670,13 @@ mod tests {
id: 1,
data: "Hello, world!".to_string(),
};
db.update(tx3, tx3_row.clone()).unwrap();
assert_eq!(
Err(DatabaseError::WriteWriteConflict),
db.update(tx3, tx3_row.clone())
);
db.commit_tx(tx2);
db.commit_tx(tx3); // TODO: this should fail
db.commit_tx(tx2).unwrap();
assert_eq!(Err(DatabaseError::TxTerminated), db.commit_tx(tx3));
let tx4 = db.begin_tx();
let row = db.read(tx4, 1).unwrap().unwrap();

View File

@@ -1,7 +1,11 @@
use thiserror::Error;
#[derive(Error, Debug)]
#[derive(Error, Debug, PartialEq)]
pub enum DatabaseError {
#[error("no such transaction ID: `{0}`")]
NoSuchTransactionID(u64),
#[error("transaction aborted because of a write-write conflict")]
WriteWriteConflict,
#[error("transaction is terminated")]
TxTerminated,
}

View File

@@ -24,10 +24,10 @@ fn test_non_overlapping_concurrent_inserts() {
data: "Hello".to_string(),
};
db.insert(tx, row.clone()).unwrap();
db.commit_tx(tx);
db.commit_tx(tx).unwrap();
let tx = db.begin_tx();
let committed_row = db.read(tx, id).unwrap();
db.commit_tx(tx);
db.commit_tx(tx).unwrap();
assert_eq!(committed_row, Some(row));
});
}
@@ -42,10 +42,10 @@ fn test_non_overlapping_concurrent_inserts() {
data: "World".to_string(),
};
db.insert(tx, row.clone()).unwrap();
db.commit_tx(tx);
db.commit_tx(tx).unwrap();
let tx = db.begin_tx();
let committed_row = db.read(tx, id).unwrap();
db.commit_tx(tx);
db.commit_tx(tx).unwrap();
assert_eq!(committed_row, Some(row));
});
}