diff --git a/core/mvcc/.github/workflows/smoke_test.yml b/core/mvcc/.github/workflows/smoke_test.yml index c776a1635..3a00d72b4 100644 --- a/core/mvcc/.github/workflows/smoke_test.yml +++ b/core/mvcc/.github/workflows/smoke_test.yml @@ -8,6 +8,7 @@ on: env: CARGO_TERM_COLOR: always + RUST_LOG: info,mvcc_rs=trace jobs: build: diff --git a/core/mvcc/mvcc-rs/src/database/mod.rs b/core/mvcc/mvcc-rs/src/database/mod.rs index 29bfbdf73..890c98f62 100644 --- a/core/mvcc/mvcc-rs/src/database/mod.rs +++ b/core/mvcc/mvcc-rs/src/database/mod.rs @@ -297,10 +297,11 @@ impl Database { end: None, row, }; + tx.insert_to_write_set(id); + drop(tx); let versions = self.rows.get_or_insert_with(id, || RwLock::new(Vec::new())); let mut versions = versions.value().write().unwrap(); versions.push(row_version); - tx.insert_to_write_set(id); Ok(()) } @@ -364,7 +365,9 @@ impl Database { } if is_version_visible(&self.txs, &tx, rv) { rv.end = Some(TxTimestampOrID::TxID(tx.tx_id)); - drop(tx); // FIXME: maybe just grab the write lock above? Do we ever expect conflicts? + drop(row_versions); + drop(row_versions_opt); + drop(tx); let tx = self .txs .get(&tx_id) @@ -456,6 +459,8 @@ impl Database { /// * `tx_id` - The ID of the transaction to commit. pub fn commit_tx(&self, tx_id: TxID) -> Result<()> { let end_ts = self.get_timestamp(); + // NOTICE: the first shadowed tx keeps the entry alive in the map + // for the duration of this whole function, which is important for correctness! let tx = self.txs.get(&tx_id).unwrap(); let tx = tx.value().write().unwrap(); match tx.state.load() { @@ -542,17 +547,19 @@ impl Database { */ tx.state.store(TransactionState::Committed(end_ts)); tracing::trace!("COMMIT {tx}"); + let tx_begin_ts = tx.begin_ts; + let write_set: Vec<RowID> = tx.write_set.iter().map(|v| *v.value()).collect(); + drop(tx); // Postprocessing: inserting row versions and logging the transaction to persistent storage. 
// TODO: we should probably save to persistent storage first, and only then update the in-memory structures. let mut log_record: LogRecord = LogRecord::new(end_ts); - for id in &tx.write_set { - let id = id.value(); + for ref id in write_set { if let Some(row_versions) = self.rows.get(id) { let mut row_versions = row_versions.value().write().unwrap(); for row_version in row_versions.iter_mut() { if let TxTimestampOrID::TxID(id) = row_version.begin { if id == tx_id { - row_version.begin = TxTimestampOrID::Timestamp(tx.begin_ts); + row_version.begin = TxTimestampOrID::Timestamp(tx_begin_ts); log_record.row_versions.push(row_version.clone()); // FIXME: optimize cloning out } } @@ -565,7 +572,7 @@ impl Database { } } } - tracing::trace!("UPDATED {tx}"); + tracing::trace!("UPDATED TX{tx_id}"); // We have now updated all the versions with a reference to the // transaction ID to a timestamp and can, therefore, remove the // transaction. Please note that when we move to lockless, the @@ -578,7 +585,7 @@ impl Database { if !log_record.row_versions.is_empty() { self.storage.log_tx(log_record)?; } - tracing::trace!("LOGGED {tx}"); + tracing::trace!("LOGGED {tx_id}"); Ok(()) } @@ -591,13 +598,14 @@ impl Database { /// /// * `tx_id` - The ID of the transaction to abort. 
pub fn rollback_tx(&self, tx_id: TxID) { - let tx = self.txs.get(&tx_id).unwrap(); - let tx = tx.value().write().unwrap(); + let tx_unlocked = self.txs.get(&tx_id).unwrap(); + let tx = tx_unlocked.value().write().unwrap(); assert!(tx.state == TransactionState::Active); tx.state.store(TransactionState::Aborted); tracing::trace!("ABORT {tx}"); - for id in &tx.write_set { - let id = id.value(); + let write_set: Vec<RowID> = tx.write_set.iter().map(|v| *v.value()).collect(); + drop(tx); + for ref id in write_set { if let Some(row_versions) = self.rows.get(id) { let mut row_versions = row_versions.value().write().unwrap(); row_versions.retain(|rv| rv.begin != TxTimestampOrID::TxID(tx_id)); @@ -606,6 +614,7 @@ impl Database { } } } + let tx = tx_unlocked.value().write().unwrap(); tx.state.store(TransactionState::Terminated); tracing::trace!("TERMINATE {tx}"); } diff --git a/core/mvcc/mvcc-rs/tests/concurrency_test.rs b/core/mvcc/mvcc-rs/tests/concurrency_test.rs index e284dd6da..3c8085ea0 100644 --- a/core/mvcc/mvcc-rs/tests/concurrency_test.rs +++ b/core/mvcc/mvcc-rs/tests/concurrency_test.rs @@ -2,14 +2,17 @@ use mvcc_rs::clock::LocalClock; use mvcc_rs::database::{Database, Row, RowID}; use std::sync::atomic::AtomicU64; use std::sync::atomic::Ordering; -use std::sync::Arc; +use std::sync::{Arc, Once}; static IDS: AtomicU64 = AtomicU64::new(1); -#[tracing_test::traced_test] +static START: Once = Once::new(); + #[test] fn test_non_overlapping_concurrent_inserts() { - tracing_subscriber::fmt::init(); + START.call_once(|| { + tracing_subscriber::fmt::init(); + }); // Two threads insert to the database concurrently using non-overlapping // row IDs. let clock = LocalClock::default(); @@ -68,8 +71,9 @@ fn test_non_overlapping_concurrent_inserts() { #[test] fn test_overlapping_concurrent_inserts_read_your_writes() { - tracing_subscriber::fmt::init(); - // Two threads insert to the database concurrently using overlapping row IDs. 
+ START.call_once(|| { + tracing_subscriber::fmt::init(); + }); // Two threads insert to the database concurrently using overlapping row IDs. let clock = LocalClock::default(); let storage = mvcc_rs::persistent_storage::Storage::new_noop(); let db = Arc::new(Database::new(clock, storage));