From a93fcdcbcf3f5f7532c3d2bd8fffb3045a71bb05 Mon Sep 17 00:00:00 2001 From: Piotr Sarna Date: Mon, 12 Jun 2023 13:01:21 +0200 Subject: [PATCH] database: make transaction state atomic Without atomic access, we're subject to races when inspecting whether a transaction just changed its state, e.g. from Preparing to Committed. --- core/mvcc/mvcc-rs/src/database/mod.rs | 95 +++++++++++++++++++++---- core/mvcc/mvcc-rs/src/database/tests.rs | 1 + 2 files changed, 82 insertions(+), 14 deletions(-) diff --git a/core/mvcc/mvcc-rs/src/database/mod.rs b/core/mvcc/mvcc-rs/src/database/mod.rs index 3307b831b..01e0c2dce 100644 --- a/core/mvcc/mvcc-rs/src/database/mod.rs +++ b/core/mvcc/mvcc-rs/src/database/mod.rs @@ -66,7 +66,7 @@ enum TxTimestampOrID { #[derive(Debug, Serialize, Deserialize)] pub struct Transaction { /// The state of the transaction. - state: TransactionState, + state: AtomicTransactionState, /// The transaction ID. tx_id: u64, /// The transaction begin timestamp. @@ -126,7 +126,7 @@ mod skipset_rowid { impl Transaction { fn new(tx_id: u64, begin_ts: u64) -> Transaction { Transaction { - state: TransactionState::Active, + state: TransactionState::Active.into(), tx_id, begin_ts, write_set: SkipSet::new(), @@ -148,7 +148,7 @@ impl std::fmt::Display for Transaction { write!( f, "{{ state: {}, id: {}, begin_ts: {}, write_set: {:?}, read_set: {:?}", - self.state, + self.state.load(), self.tx_id, self.begin_ts, // FIXME: I'm sorry, we obviously shouldn't be cloning here. 
@@ -169,9 +169,66 @@ impl std::fmt::Display for Transaction { enum TransactionState { Active, Preparing, - Committed(u64), Aborted, Terminated, + Committed(u64), +} + +impl TransactionState { + pub fn encode(&self) -> u64 { + match self { + TransactionState::Active => 0, + TransactionState::Preparing => 1, + TransactionState::Aborted => 2, + TransactionState::Terminated => 3, + TransactionState::Committed(ts) => { + // We only support timestamps up to 2^63 - 1, because the top bit + // is used to encode the type. + assert!(ts & 0x8000_0000_0000_0000 == 0); + 0x8000_0000_0000_0000 | ts + } + } + } + + pub fn decode(v: u64) -> Self { + match v { + 0 => TransactionState::Active, + 1 => TransactionState::Preparing, + 2 => TransactionState::Aborted, + 3 => TransactionState::Terminated, + v if v & 0x8000_0000_0000_0000 != 0 => { + TransactionState::Committed(v & 0x7fff_ffff_ffff_ffff) + } + _ => panic!("Invalid transaction state"), + } + } +} + +// Transaction state encoded into a single 64-bit atomic. 
+#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct AtomicTransactionState { + pub(crate) state: AtomicU64, +} + +impl From for AtomicTransactionState { + fn from(state: TransactionState) -> Self { + Self { + state: AtomicU64::new(state.encode()), + } + } +} + +impl From for TransactionState { + fn from(state: AtomicTransactionState) -> Self { + let encoded = state.state.load(Ordering::Acquire); + TransactionState::decode(encoded) + } +} + +impl std::cmp::PartialEq for AtomicTransactionState { + fn eq(&self, other: &TransactionState) -> bool { + &self.load() == other + } } impl std::fmt::Display for TransactionState { @@ -186,6 +243,16 @@ impl std::fmt::Display for TransactionState { } } +impl AtomicTransactionState { + fn store(&self, state: TransactionState) { + self.state.store(state.encode(), Ordering::Release); + } + + fn load(&self) -> TransactionState { + TransactionState::decode(self.state.load(Ordering::Acquire)) + } +} + #[derive(Debug)] pub struct Database { rows: SkipMap>>, @@ -390,14 +457,14 @@ impl Database { pub fn commit_tx(&self, tx_id: TxID) -> Result<()> { let end_ts = self.get_timestamp(); let tx = self.txs.get(&tx_id).unwrap(); - let mut tx = tx.value().write().unwrap(); - match tx.state { + let tx = tx.value().write().unwrap(); + match tx.state.load() { TransactionState::Terminated => return Err(DatabaseError::TxTerminated), _ => { assert!(tx.state == TransactionState::Active); } } - tx.state = TransactionState::Preparing; + tx.state.store(TransactionState::Preparing); tracing::trace!("PREPARE {tx}"); /* TODO: The code we have here is sufficient for snapshot isolation. @@ -473,7 +540,7 @@ impl Database { only if TE commits. """ */ - tx.state = TransactionState::Committed(end_ts); + tx.state.store(TransactionState::Committed(end_ts)); tracing::trace!("COMMIT {tx}"); // Postprocessing: inserting row versions and logging the transaction to persistent storage. 
// TODO: we should probably save to persistent storage first, and only then update the in-memory structures. @@ -523,9 +590,9 @@ impl Database { /// * `tx_id` - The ID of the transaction to abort. pub fn rollback_tx(&self, tx_id: TxID) { let tx = self.txs.get(&tx_id).unwrap(); - let mut tx = tx.value().write().unwrap(); + let tx = tx.value().write().unwrap(); assert!(tx.state == TransactionState::Active); - tx.state = TransactionState::Aborted; + tx.state.store(TransactionState::Aborted); tracing::trace!("ABORT {tx}"); for id in &tx.write_set { let id = id.value(); @@ -537,7 +604,7 @@ impl Database { } } } - tx.state = TransactionState::Terminated; + tx.state.store(TransactionState::Terminated); tracing::trace!("TERMINATE {tx}"); } @@ -620,7 +687,7 @@ pub(crate) fn is_write_write_conflict( Some(TxTimestampOrID::TxID(rv_end)) => { let te = txs.get(&rv_end).unwrap(); let te = te.value().read().unwrap(); - match te.state { + match te.state.load() { TransactionState::Active => tx.tx_id != te.tx_id, TransactionState::Preparing => todo!(), TransactionState::Committed(_end_ts) => todo!(), @@ -651,7 +718,7 @@ fn is_begin_visible( TxTimestampOrID::TxID(rv_begin) => { let tb = txs.get(&rv_begin).unwrap(); let tb = tb.value().read().unwrap(); - let visible = match tb.state { + let visible = match tb.state.load() { TransactionState::Active => tx.tx_id == tb.tx_id && rv.end.is_none(), TransactionState::Preparing => false, // NOTICE: makes sense for snapshot isolation, not so much for serializable! TransactionState::Committed(committed_ts) => tx.begin_ts >= committed_ts, @@ -681,7 +748,7 @@ fn is_end_visible( Some(TxTimestampOrID::TxID(rv_end)) => { let te = txs.get(&rv_end).unwrap(); let te = te.value().read().unwrap(); - let visible = match te.state { + let visible = match te.state.load() { TransactionState::Active => tx.tx_id != te.tx_id, TransactionState::Preparing => false, // NOTICE: makes sense for snapshot isolation, not so much for serializable! 
TransactionState::Committed(committed_ts) => tx.begin_ts < committed_ts, diff --git a/core/mvcc/mvcc-rs/src/database/tests.rs b/core/mvcc/mvcc-rs/src/database/tests.rs index ada842218..e9023c76a 100644 --- a/core/mvcc/mvcc-rs/src/database/tests.rs +++ b/core/mvcc/mvcc-rs/src/database/tests.rs @@ -822,6 +822,7 @@ or not found | | the timestamp. */ fn new_tx(tx_id: TxID, begin_ts: u64, state: TransactionState) -> RwLock { + let state = state.into(); RwLock::new(Transaction { state, tx_id,