diff --git a/core/mvcc/.github/workflows/smoke_test.yml b/core/mvcc/.github/workflows/smoke_test.yml
index c776a1635..3a00d72b4 100644
--- a/core/mvcc/.github/workflows/smoke_test.yml
+++ b/core/mvcc/.github/workflows/smoke_test.yml
@@ -8,6 +8,7 @@ on:

 env:
   CARGO_TERM_COLOR: always
+  RUST_LOG: info,mvcc_rs=trace

 jobs:
   build:
diff --git a/core/mvcc/mvcc-rs/Cargo.toml b/core/mvcc/mvcc-rs/Cargo.toml
index 21c83167d..27f030a73 100644
--- a/core/mvcc/mvcc-rs/Cargo.toml
+++ b/core/mvcc/mvcc-rs/Cargo.toml
@@ -16,13 +16,12 @@ aws-config = "0.55.2"
 parking_lot = "0.12.1"
 futures = "0.3.28"
 crossbeam-skiplist = "0.1.1"
+tracing-test = "0"

 [dev-dependencies]
 criterion = { version = "0.4", features = ["html_reports", "async", "async_futures"] }
 pprof = { version = "0.11.1", features = ["criterion", "flamegraph"] }
-shuttle = "0.6.0"
 tracing-subscriber = "0"
-tracing-test = "0"
 mvcc-rs = { path = "." }

 [[bench]]
diff --git a/core/mvcc/mvcc-rs/src/database/mod.rs b/core/mvcc/mvcc-rs/src/database/mod.rs
index 01e0c2dce..4941b5305 100644
--- a/core/mvcc/mvcc-rs/src/database/mod.rs
+++ b/core/mvcc/mvcc-rs/src/database/mod.rs
@@ -17,7 +17,7 @@ pub struct RowID {
     pub row_id: u64,
 }

-#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
+#[derive(Clone, Debug, PartialEq, PartialOrd, Serialize, Deserialize)]
 pub struct Row {
     pub id: RowID,
@@ -25,7 +25,7 @@
 }

 /// A row version.
-#[derive(Clone, Debug, Serialize, Deserialize)]
+#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
 pub struct RowVersion {
     begin: TxTimestampOrID,
     end: Option<TxTimestampOrID>,
@@ -56,7 +56,7 @@ impl LogRecord {
 /// phase of the transaction. During the active phase, new versions track the
 /// transaction ID in the `begin` and `end` fields. After a transaction commits,
 /// versions switch to tracking timestamps.
-#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
+#[derive(Clone, Debug, PartialEq, PartialOrd, Serialize, Deserialize)]
 enum TxTimestampOrID {
     Timestamp(u64),
     TxID(TxID),
@@ -274,6 +274,49 @@ impl Database {
         }
     }

+    // Extracts the begin timestamp from a transaction
+    fn get_begin_timestamp(&self, ts_or_id: &TxTimestampOrID) -> u64 {
+        match ts_or_id {
+            TxTimestampOrID::Timestamp(ts) => *ts,
+            TxTimestampOrID::TxID(tx_id) => {
+                self.txs
+                    .get(tx_id)
+                    .unwrap()
+                    .value()
+                    .read()
+                    .unwrap()
+                    .begin_ts
+            }
+        }
+    }
+
+    /// Inserts a new row version into the database, while making sure that
+    /// the row version is inserted in the correct order.
+    fn insert_version(&self, id: RowID, row_version: RowVersion) {
+        let versions = self.rows.get_or_insert_with(id, || RwLock::new(Vec::new()));
+        let mut versions = versions.value().write().unwrap();
+        self.insert_version_raw(&mut versions, row_version)
+    }
+
+    /// Inserts a new row version into the internal data structure for versions,
+    /// while making sure that the row version is inserted in the correct order.
+    fn insert_version_raw(&self, versions: &mut Vec<RowVersion>, row_version: RowVersion) {
+        // NOTICE: this is an insert à la insertion sort, with pessimistic linear complexity.
+        // However, we expect the number of versions to be nearly sorted, so we deem it worthy
+        // to search linearly for the insertion point instead of paying the price of using
+        // another data structure, e.g. a BTreeSet. If it proves to be too quadratic empirically,
+        // we can either switch to a tree-like structure, or at least use partition_point(),
+        // which performs a binary search for the insertion point.
+        let position = versions
+            .iter()
+            .rposition(|v| {
+                self.get_begin_timestamp(&v.begin) < self.get_begin_timestamp(&row_version.begin)
+            })
+            .map(|p| p + 1)
+            .unwrap_or(0);
+        versions.insert(position, row_version);
+    }
+
     /// Inserts a new row into the database.
     ///
     /// This function inserts a new `row` into the database within the context
@@ -297,10 +340,9 @@ impl Database {
             end: None,
             row,
         };
-        let versions = self.rows.get_or_insert_with(id, || RwLock::new(Vec::new()));
-        let mut versions = versions.value().write().unwrap();
-        versions.push(row_version);
         tx.insert_to_write_set(id);
+        drop(tx);
+        self.insert_version(id, row_version);
         Ok(())
     }

@@ -364,7 +406,9 @@ impl Database {
             }
             if is_version_visible(&self.txs, &tx, rv) {
                 rv.end = Some(TxTimestampOrID::TxID(tx.tx_id));
-                drop(tx); // FIXME: maybe just grab the write lock above? Do we ever expect conflicts?
+                drop(row_versions);
+                drop(row_versions_opt);
+                drop(tx);
                 let tx = self
                     .txs
                     .get(&tx_id)
@@ -440,7 +484,7 @@ impl Database {
         let tx_id = self.get_tx_id();
         let begin_ts = self.get_timestamp();
         let tx = Transaction::new(tx_id, begin_ts);
-        tracing::trace!("BEGIN {tx}");
+        tracing::trace!("BEGIN {tx}");
         self.txs.insert(tx_id, RwLock::new(tx));
         tx_id
     }
@@ -456,6 +500,8 @@ impl Database {
     /// * `tx_id` - The ID of the transaction to commit.
     pub fn commit_tx(&self, tx_id: TxID) -> Result<()> {
         let end_ts = self.get_timestamp();
+        // NOTICE: the first shadowed tx keeps the entry alive in the map
+        // for the duration of this whole function, which is important for correctness!
         let tx = self.txs.get(&tx_id).unwrap();
         let tx = tx.value().write().unwrap();
         match tx.state.load() {
@@ -542,29 +588,38 @@ impl Database {
         */
         tx.state.store(TransactionState::Committed(end_ts));
         tracing::trace!("COMMIT {tx}");
+        let tx_begin_ts = tx.begin_ts;
+        let write_set: Vec<RowID> = tx.write_set.iter().map(|v| *v.value()).collect();
+        drop(tx);
         // Postprocessing: inserting row versions and logging the transaction to persistent storage.
         // TODO: we should probably save to persistent storage first, and only then update the in-memory structures.
         let mut log_record: LogRecord = LogRecord::new(end_ts);
-        for id in &tx.write_set {
-            let id = id.value();
+        for ref id in write_set {
             if let Some(row_versions) = self.rows.get(id) {
                 let mut row_versions = row_versions.value().write().unwrap();
                 for row_version in row_versions.iter_mut() {
                     if let TxTimestampOrID::TxID(id) = row_version.begin {
                         if id == tx_id {
-                            row_version.begin = TxTimestampOrID::Timestamp(tx.begin_ts);
-                            log_record.row_versions.push(row_version.clone()); // FIXME: optimize cloning out
+                            row_version.begin = TxTimestampOrID::Timestamp(tx_begin_ts);
+                            self.insert_version_raw(
+                                &mut log_record.row_versions,
+                                row_version.clone(),
+                            ); // FIXME: optimize cloning out
                         }
                     }
                     if let Some(TxTimestampOrID::TxID(id)) = row_version.end {
                         if id == tx_id {
                             row_version.end = Some(TxTimestampOrID::Timestamp(end_ts));
-                            log_record.row_versions.push(row_version.clone()); // FIXME: optimize cloning out
+                            self.insert_version_raw(
+                                &mut log_record.row_versions,
+                                row_version.clone(),
+                            ); // FIXME: optimize cloning out
                         }
                     }
                 }
             }
         }
+        tracing::trace!("UPDATED TX{tx_id}");
         // We have now updated all the versions with a reference to the
         // transaction ID to a timestamp and can, therefore, remove the
         // transaction. Please note that when we move to lockless, the
@@ -577,6 +632,7 @@ impl Database {
         if !log_record.row_versions.is_empty() {
             self.storage.log_tx(log_record)?;
         }
+        tracing::trace!("LOGGED {tx_id}");
         Ok(())
     }

@@ -589,13 +645,14 @@ impl Database {
     ///
     /// * `tx_id` - The ID of the transaction to abort.
     pub fn rollback_tx(&self, tx_id: TxID) {
-        let tx = self.txs.get(&tx_id).unwrap();
-        let tx = tx.value().write().unwrap();
+        let tx_unlocked = self.txs.get(&tx_id).unwrap();
+        let tx = tx_unlocked.value().write().unwrap();
         assert!(tx.state == TransactionState::Active);
         tx.state.store(TransactionState::Aborted);
         tracing::trace!("ABORT {tx}");
-        for id in &tx.write_set {
-            let id = id.value();
+        let write_set: Vec<RowID> = tx.write_set.iter().map(|v| *v.value()).collect();
+        drop(tx);
+        for ref id in write_set {
             if let Some(row_versions) = self.rows.get(id) {
                 let mut row_versions = row_versions.value().write().unwrap();
                 row_versions.retain(|rv| rv.begin != TxTimestampOrID::TxID(tx_id));
@@ -604,6 +661,7 @@ impl Database {
                 }
             }
         }
+        let tx = tx_unlocked.value().write().unwrap();
         tx.state.store(TransactionState::Terminated);
         tracing::trace!("TERMINATE {tx}");
     }
@@ -664,11 +722,7 @@ impl Database {
         for record in tx_log {
             tracing::debug!("RECOVERING {:?}", record);
             for version in record.row_versions {
-                let row_versions = self
-                    .rows
-                    .get_or_insert_with(version.row.id, || RwLock::new(Vec::new()));
-                let mut row_versions = row_versions.value().write().unwrap();
-                row_versions.push(version);
+                self.insert_version(version.row.id, version);
             }
             self.clock.reset(record.tx_timestamp);
         }
diff --git a/core/mvcc/mvcc-rs/tests/concurrency_test.rs b/core/mvcc/mvcc-rs/tests/concurrency_test.rs
index 12321aa10..fced575d7 100644
--- a/core/mvcc/mvcc-rs/tests/concurrency_test.rs
+++ b/core/mvcc/mvcc-rs/tests/concurrency_test.rs
@@ -1,65 +1,111 @@
 use mvcc_rs::clock::LocalClock;
 use mvcc_rs::database::{Database, Row, RowID};
-use shuttle::sync::atomic::AtomicU64;
-use shuttle::sync::Arc;
-use shuttle::thread;
+use std::sync::atomic::AtomicU64;
 use std::sync::atomic::Ordering;
+use std::sync::{Arc, Once};
+
+static IDS: AtomicU64 = AtomicU64::new(1);
+
+static START: Once = Once::new();

 #[test]
 fn test_non_overlapping_concurrent_inserts() {
+    START.call_once(|| {
+        tracing_subscriber::fmt::init();
+    });
     // Two threads insert to the database concurrently using non-overlapping
     // row IDs.
     let clock = LocalClock::default();
     let storage = mvcc_rs::persistent_storage::Storage::new_noop();
     let db = Arc::new(Database::new(clock, storage));
-    let ids = Arc::new(AtomicU64::new(0));
-    shuttle::check_random(
-        move || {
-            {
-                let db = db.clone();
-                let ids = ids.clone();
-                thread::spawn(move || {
-                    let tx = db.begin_tx();
-                    let id = ids.fetch_add(1, Ordering::SeqCst);
-                    let id = RowID {
-                        table_id: 1,
-                        row_id: id,
-                    };
-                    let row = Row {
-                        id,
-                        data: "Hello".to_string(),
-                    };
-                    db.insert(tx, row.clone()).unwrap();
-                    db.commit_tx(tx).unwrap();
-                    let tx = db.begin_tx();
-                    let committed_row = db.read(tx, id).unwrap();
-                    db.commit_tx(tx).unwrap();
-                    assert_eq!(committed_row, Some(row));
-                });
+    let iterations = 100000;
+
+    let th1 = {
+        let db = db.clone();
+        std::thread::spawn(move || {
+            for _ in 0..iterations {
+                let tx = db.begin_tx();
+                let id = IDS.fetch_add(1, Ordering::SeqCst);
+                let id = RowID {
+                    table_id: 1,
+                    row_id: id,
+                };
+                let row = Row {
+                    id,
+                    data: "Hello".to_string(),
+                };
+                db.insert(tx, row.clone()).unwrap();
+                db.commit_tx(tx).unwrap();
+                let tx = db.begin_tx();
+                let committed_row = db.read(tx, id).unwrap();
+                db.commit_tx(tx).unwrap();
+                assert_eq!(committed_row, Some(row));
             }
-            {
-                let db = db.clone();
-                let ids = ids.clone();
-                thread::spawn(move || {
-                    let tx = db.begin_tx();
-                    let id = ids.fetch_add(1, Ordering::SeqCst);
-                    let id = RowID {
-                        table_id: 1,
-                        row_id: id,
-                    };
-                    let row = Row {
-                        id,
-                        data: "World".to_string(),
-                    };
-                    db.insert(tx, row.clone()).unwrap();
-                    db.commit_tx(tx).unwrap();
-                    let tx = db.begin_tx();
-                    let committed_row = db.read(tx, id).unwrap();
-                    db.commit_tx(tx).unwrap();
-                    assert_eq!(committed_row, Some(row));
-                });
+        })
+    };
+    let th2 = {
+        std::thread::spawn(move || {
+            for _ in 0..iterations {
+                let tx = db.begin_tx();
+                let id = IDS.fetch_add(1, Ordering::SeqCst);
+                let id = RowID {
+                    table_id: 1,
+                    row_id: id,
+                };
+                let row = Row {
+                    id,
+                    data: "World".to_string(),
+                };
+                db.insert(tx, row.clone()).unwrap();
+                db.commit_tx(tx).unwrap();
+                let tx = db.begin_tx();
+                let committed_row = db.read(tx, id).unwrap();
+                db.commit_tx(tx).unwrap();
+                assert_eq!(committed_row, Some(row));
             }
-        },
-        100,
-    );
+        })
+    };
+    th1.join().unwrap();
+    th2.join().unwrap();
+}
+
+#[test]
+fn test_overlapping_concurrent_inserts_read_your_writes() {
+    START.call_once(|| {
+        tracing_subscriber::fmt::init();
+    });
+    // Four threads insert to the database concurrently using overlapping row IDs.
+    let clock = LocalClock::default();
+    let storage = mvcc_rs::persistent_storage::Storage::new_noop();
+    let db = Arc::new(Database::new(clock, storage));
+    let iterations = 100000;
+
+    let work = |prefix: &'static str| {
+        let db = db.clone();
+        std::thread::spawn(move || {
+            for i in 0..iterations {
+                if i % 1000 == 0 {
+                    tracing::debug!("{prefix}: {i}");
+                }
+                let tx = db.begin_tx();
+                let id = i % 16;
+                let id = RowID {
+                    table_id: 1,
+                    row_id: id,
+                };
+                let row = Row {
+                    id,
+                    data: format!("{prefix} @{tx}"),
+                };
+                db.insert(tx, row.clone()).unwrap();
+                let committed_row = db.read(tx, id).unwrap();
+                db.commit_tx(tx).unwrap();
+                assert_eq!(committed_row, Some(row));
+            }
+        })
+    };
+
+    let threads = vec![work("A"), work("B"), work("C"), work("D")];
+    for th in threads {
+        th.join().unwrap();
+    }
+}
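
The sorted-insert strategy introduced in insert_version_raw above can be illustrated in isolation. The sketch below is a simplified model, not code from this patch: the Version struct and insert_sorted function are hypothetical names, and a plain u64 begin timestamp stands in for TxTimestampOrID, which the real implementation resolves through the txs map (see get_begin_timestamp).

    // Simplified model of the ordered insert: keep `versions` sorted by begin
    // timestamp, scanning backwards because new versions usually arrive in order.
    #[derive(Debug, PartialEq)]
    struct Version {
        begin_ts: u64, // hypothetical stand-in for TxTimestampOrID
    }

    fn insert_sorted(versions: &mut Vec<Version>, v: Version) {
        // rposition() finds the last element with a smaller timestamp; inserting
        // right after it keeps the vector sorted. unwrap_or(0) covers the case
        // where the new version is the smallest, or the vector is empty.
        let position = versions
            .iter()
            .rposition(|existing| existing.begin_ts < v.begin_ts)
            .map(|p| p + 1)
            .unwrap_or(0);
        versions.insert(position, v);
    }

    fn main() {
        let mut versions = vec![Version { begin_ts: 1 }, Version { begin_ts: 3 }];
        insert_sorted(&mut versions, Version { begin_ts: 2 });
        assert_eq!(
            versions,
            vec![
                Version { begin_ts: 1 },
                Version { begin_ts: 2 },
                Version { begin_ts: 3 },
            ]
        );
        // The binary-search alternative from the NOTICE comment:
        // versions.partition_point(|e| e.begin_ts < new_ts) yields the same
        // insertion point in O(log n) if the nearly-sorted assumption breaks down.
    }

Scanning backwards with rposition() makes the common case, appending a version with the newest timestamp, stop after a single comparison, which is exactly what the nearly-sorted assumption in the NOTICE comment relies on.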