Merge pull request #54 from penberg/concurrency_no_shuttle

concurrency test: port to OS threads
Pekka Enberg
2023-06-13 14:49:48 +03:00
committed by GitHub
4 changed files with 175 additions and 75 deletions

View File

@@ -8,6 +8,7 @@ on:
env:
CARGO_TERM_COLOR: always
RUST_LOG: info,mvcc_rs=trace
jobs:
build:

View File

@@ -16,13 +16,12 @@ aws-config = "0.55.2"
parking_lot = "0.12.1"
futures = "0.3.28"
crossbeam-skiplist = "0.1.1"
tracing-test = "0"
[dev-dependencies]
criterion = { version = "0.4", features = ["html_reports", "async", "async_futures"] }
pprof = { version = "0.11.1", features = ["criterion", "flamegraph"] }
shuttle = "0.6.0"
tracing-subscriber = "0"
tracing-test = "0"
mvcc-rs = { path = "." }
[[bench]]

View File

@@ -17,7 +17,7 @@ pub struct RowID {
pub row_id: u64,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[derive(Clone, Debug, PartialEq, PartialOrd, Serialize, Deserialize)]
pub struct Row {
pub id: RowID,
@@ -25,7 +25,7 @@ pub struct Row {
}
/// A row version.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct RowVersion {
begin: TxTimestampOrID,
end: Option<TxTimestampOrID>,
@@ -56,7 +56,7 @@ impl LogRecord {
/// phase of the transaction. During the active phase, new versions track the
/// transaction ID in the `begin` and `end` fields. After a transaction commits,
/// versions switch to tracking timestamps.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[derive(Clone, Debug, PartialEq, PartialOrd, Serialize, Deserialize)]
enum TxTimestampOrID {
Timestamp(u64),
TxID(TxID),
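For orientation (not part of the diff): once versions carry committed timestamps, the crate's is_version_visible() can apply the classic MVCC lifetime rule. A minimal sketch of that rule in Rust, simplified to the committed-timestamp case only, under assumed [begin, end) semantics:

// Hypothetical, simplified visibility check: a version is visible to a
// reader whose begin timestamp falls inside the version's [begin, end)
// lifetime. The real check must also resolve the TxID variants through the
// transaction table.
fn is_visible(begin_ts: u64, end_ts: Option<u64>, reader_begin_ts: u64) -> bool {
    begin_ts <= reader_begin_ts && end_ts.map_or(true, |end| reader_begin_ts < end)
}

fn main() {
    assert!(is_visible(10, None, 15)); // live version, created before the reader
    assert!(!is_visible(10, Some(12), 15)); // ended before the reader began
}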
@@ -274,6 +274,49 @@ impl<Clock: LogicalClock> Database<Clock> {
}
}
// Extracts the begin timestamp from a TxTimestampOrID: either the timestamp
// itself, or the begin timestamp of the transaction it refers to.
fn get_begin_timestamp(&self, ts_or_id: &TxTimestampOrID) -> u64 {
match ts_or_id {
TxTimestampOrID::Timestamp(ts) => *ts,
TxTimestampOrID::TxID(tx_id) => {
self.txs
.get(tx_id)
.unwrap()
.value()
.read()
.unwrap()
.begin_ts
}
}
}
/// Inserts a new row version into the database, while making sure that
/// the row version is inserted in the correct order.
fn insert_version(&self, id: RowID, row_version: RowVersion) {
let versions = self.rows.get_or_insert_with(id, || RwLock::new(Vec::new()));
let mut versions = versions.value().write().unwrap();
self.insert_version_raw(&mut versions, row_version)
}
/// Inserts a new row version into the internal data structure for versions,
/// while making sure that the row version is inserted in the correct order.
fn insert_version_raw(&self, versions: &mut Vec<RowVersion>, row_version: RowVersion) {
// NOTICE: this is an insert à la insertion sort, with linear worst-case complexity
// per insert. However, we expect the versions to arrive nearly sorted, so we deem
// it worthwhile to search linearly for the insertion point instead of paying the
// price of another data structure, e.g. a BTreeSet. If it proves too quadratic
// empirically, we can either switch to a tree-like structure, or at least use
// partition_point(), which performs a binary search for the insertion point.
let position = versions
.iter()
.rposition(|v| {
self.get_begin_timestamp(&v.begin) < self.get_begin_timestamp(&row_version.begin)
})
.map(|p| p + 1)
.unwrap_or(0);
versions.insert(position, row_version);
}
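As a concrete illustration of the partition_point() route mentioned in the comment above, a self-contained sketch reduced to plain u64 begin timestamps (insert_sorted is a hypothetical helper, not part of the crate):

// Binary-search insertion preserving the same ordering as the linear
// rposition() scan: among equal begin timestamps, the new element lands
// before the existing ones in both variants.
fn insert_sorted(versions: &mut Vec<u64>, begin_ts: u64) {
    let position = versions.partition_point(|&v| v < begin_ts);
    versions.insert(position, begin_ts);
}

fn main() {
    let mut versions = vec![1, 3, 4];
    insert_sorted(&mut versions, 2);
    assert_eq!(versions, vec![1, 2, 3, 4]);
}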
/// Inserts a new row into the database.
///
/// This function inserts a new `row` into the database within the context
@@ -297,10 +340,9 @@ impl<Clock: LogicalClock> Database<Clock> {
end: None,
row,
};
let versions = self.rows.get_or_insert_with(id, || RwLock::new(Vec::new()));
let mut versions = versions.value().write().unwrap();
versions.push(row_version);
tx.insert_to_write_set(id);
drop(tx);
self.insert_version(id, row_version);
Ok(())
}
@@ -364,7 +406,9 @@ impl<Clock: LogicalClock> Database<Clock> {
}
if is_version_visible(&self.txs, &tx, rv) {
rv.end = Some(TxTimestampOrID::TxID(tx.tx_id));
drop(tx); // FIXME: maybe just grab the write lock above? Do we ever expect conflicts?
drop(row_versions);
drop(row_versions_opt);
drop(tx);
let tx = self
.txs
.get(&tx_id)
@@ -440,7 +484,7 @@ impl<Clock: LogicalClock> Database<Clock> {
let tx_id = self.get_tx_id();
let begin_ts = self.get_timestamp();
let tx = Transaction::new(tx_id, begin_ts);
tracing::trace!("BEGIN {tx}");
tracing::trace!("BEGIN {tx}");
self.txs.insert(tx_id, RwLock::new(tx));
tx_id
}
@@ -456,6 +500,8 @@ impl<Clock: LogicalClock> Database<Clock> {
/// * `tx_id` - The ID of the transaction to commit.
pub fn commit_tx(&self, tx_id: TxID) -> Result<()> {
let end_ts = self.get_timestamp();
// NOTICE: the first `tx` binding (immediately shadowed below) keeps the map
// entry alive for the duration of this whole function, which is important
// for correctness!
let tx = self.txs.get(&tx_id).unwrap();
let tx = tx.value().write().unwrap();
match tx.state.load() {
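A standalone sketch of the guard-lifetime point made in the NOTICE, assuming self.txs is a crossbeam_skiplist::SkipMap of RwLock-wrapped transactions (crossbeam-skiplist appears in the crate's dependencies; the types here are simplified):

use crossbeam_skiplist::SkipMap;
use std::sync::RwLock;

fn main() {
    let txs: SkipMap<u64, RwLock<&'static str>> = SkipMap::new();
    txs.insert(1, RwLock::new("Active"));

    // The Entry returned by get() pins the skiplist node. Shadowing keeps the
    // first binding alive (though no longer nameable) until the end of the
    // scope, so a concurrent remove() cannot free the value that the lock
    // guard borrows from it.
    let tx = txs.get(&1).unwrap();
    let tx = tx.value().write().unwrap();
    assert_eq!(*tx, "Active");
}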
@@ -542,29 +588,38 @@ impl<Clock: LogicalClock> Database<Clock> {
*/
tx.state.store(TransactionState::Committed(end_ts));
tracing::trace!("COMMIT {tx}");
let tx_begin_ts = tx.begin_ts;
let write_set: Vec<RowID> = tx.write_set.iter().map(|v| *v.value()).collect();
drop(tx);
// Postprocessing: inserting row versions and logging the transaction to persistent storage.
// TODO: we should probably save to persistent storage first, and only then update the in-memory structures.
let mut log_record: LogRecord = LogRecord::new(end_ts);
for id in &tx.write_set {
let id = id.value();
for ref id in write_set {
if let Some(row_versions) = self.rows.get(id) {
let mut row_versions = row_versions.value().write().unwrap();
for row_version in row_versions.iter_mut() {
if let TxTimestampOrID::TxID(id) = row_version.begin {
if id == tx_id {
row_version.begin = TxTimestampOrID::Timestamp(tx.begin_ts);
log_record.row_versions.push(row_version.clone()); // FIXME: optimize cloning out
row_version.begin = TxTimestampOrID::Timestamp(tx_begin_ts);
self.insert_version_raw(
&mut log_record.row_versions,
row_version.clone(),
); // FIXME: optimize cloning out
}
}
if let Some(TxTimestampOrID::TxID(id)) = row_version.end {
if id == tx_id {
row_version.end = Some(TxTimestampOrID::Timestamp(end_ts));
log_record.row_versions.push(row_version.clone()); // FIXME: optimize cloning out
self.insert_version_raw(
&mut log_record.row_versions,
row_version.clone(),
); // FIXME: optimize cloning out
}
}
}
}
}
tracing::trace!("UPDATED TX{tx_id}");
// We have now updated all the versions with a reference to the
// transaction ID to a timestamp and can, therefore, remove the
// transaction. Please note that when we move to lockless, the
@@ -577,6 +632,7 @@ impl<Clock: LogicalClock> Database<Clock> {
if !log_record.row_versions.is_empty() {
self.storage.log_tx(log_record)?;
}
tracing::trace!("LOGGED {tx_id}");
Ok(())
}
@@ -589,13 +645,14 @@ impl<Clock: LogicalClock> Database<Clock> {
///
/// * `tx_id` - The ID of the transaction to abort.
pub fn rollback_tx(&self, tx_id: TxID) {
let tx = self.txs.get(&tx_id).unwrap();
let tx = tx.value().write().unwrap();
let tx_unlocked = self.txs.get(&tx_id).unwrap();
let tx = tx_unlocked.value().write().unwrap();
assert!(tx.state == TransactionState::Active);
tx.state.store(TransactionState::Aborted);
tracing::trace!("ABORT {tx}");
for id in &tx.write_set {
let id = id.value();
let write_set: Vec<RowID> = tx.write_set.iter().map(|v| *v.value()).collect();
drop(tx);
for ref id in write_set {
if let Some(row_versions) = self.rows.get(id) {
let mut row_versions = row_versions.value().write().unwrap();
row_versions.retain(|rv| rv.begin != TxTimestampOrID::TxID(tx_id));
@@ -604,6 +661,7 @@ impl<Clock: LogicalClock> Database<Clock> {
}
}
}
let tx = tx_unlocked.value().write().unwrap();
tx.state.store(TransactionState::Terminated);
tracing::trace!("TERMINATE {tx}");
}
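Both commit_tx and rollback_tx now follow the same discipline: copy the write set out, drop the transaction guard, and only then take the per-row locks. A minimal self-contained sketch of that pattern, with hypothetical Tx and row types:

use std::sync::RwLock;

struct Tx {
    write_set: Vec<usize>,
}

// Snapshot under the tx lock, release it, then take row locks: the two
// guards never overlap, so a lock-order inversion on another thread cannot
// deadlock against us.
fn apply_to_rows(tx_lock: &RwLock<Tx>, rows: &[RwLock<String>]) {
    let write_set = tx_lock.read().unwrap().write_set.clone();
    // the tx read guard is dropped at the end of the statement above
    for id in write_set {
        rows[id].write().unwrap().push_str(" updated");
    }
}

fn main() {
    let tx = RwLock::new(Tx { write_set: vec![0] });
    let rows = [RwLock::new(String::from("row0"))];
    apply_to_rows(&tx, &rows);
    assert_eq!(*rows[0].read().unwrap(), "row0 updated");
}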
@@ -664,11 +722,7 @@ impl<Clock: LogicalClock> Database<Clock> {
for record in tx_log {
tracing::debug!("RECOVERING {:?}", record);
for version in record.row_versions {
let row_versions = self
.rows
.get_or_insert_with(version.row.id, || RwLock::new(Vec::new()));
let mut row_versions = row_versions.value().write().unwrap();
row_versions.push(version);
self.insert_version(version.row.id, version);
}
self.clock.reset(record.tx_timestamp);
}

View File

@@ -1,65 +1,111 @@
use mvcc_rs::clock::LocalClock;
use mvcc_rs::database::{Database, Row, RowID};
use shuttle::sync::atomic::AtomicU64;
use shuttle::sync::Arc;
use shuttle::thread;
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering;
use std::sync::{Arc, Once};
static IDS: AtomicU64 = AtomicU64::new(1);
static START: Once = Once::new();
#[test]
fn test_non_overlapping_concurrent_inserts() {
START.call_once(|| {
tracing_subscriber::fmt::init();
});
// Two threads insert to the database concurrently using non-overlapping
// row IDs.
let clock = LocalClock::default();
let storage = mvcc_rs::persistent_storage::Storage::new_noop();
let db = Arc::new(Database::new(clock, storage));
let ids = Arc::new(AtomicU64::new(0));
shuttle::check_random(
move || {
{
let db = db.clone();
let ids = ids.clone();
thread::spawn(move || {
let tx = db.begin_tx();
let id = ids.fetch_add(1, Ordering::SeqCst);
let id = RowID {
table_id: 1,
row_id: id,
};
let row = Row {
id,
data: "Hello".to_string(),
};
db.insert(tx, row.clone()).unwrap();
db.commit_tx(tx).unwrap();
let tx = db.begin_tx();
let committed_row = db.read(tx, id).unwrap();
db.commit_tx(tx).unwrap();
assert_eq!(committed_row, Some(row));
});
let iterations = 100000;
let th1 = {
let db = db.clone();
std::thread::spawn(move || {
for _ in 0..iterations {
let tx = db.begin_tx();
let id = IDS.fetch_add(1, Ordering::SeqCst);
let id = RowID {
table_id: 1,
row_id: id,
};
let row = Row {
id,
data: "Hello".to_string(),
};
db.insert(tx, row.clone()).unwrap();
db.commit_tx(tx).unwrap();
let tx = db.begin_tx();
let committed_row = db.read(tx, id).unwrap();
db.commit_tx(tx).unwrap();
assert_eq!(committed_row, Some(row));
}
{
let db = db.clone();
let ids = ids.clone();
thread::spawn(move || {
let tx = db.begin_tx();
let id = ids.fetch_add(1, Ordering::SeqCst);
let id = RowID {
table_id: 1,
row_id: id,
};
let row = Row {
id,
data: "World".to_string(),
};
db.insert(tx, row.clone()).unwrap();
db.commit_tx(tx).unwrap();
let tx = db.begin_tx();
let committed_row = db.read(tx, id).unwrap();
db.commit_tx(tx).unwrap();
assert_eq!(committed_row, Some(row));
});
})
};
let th2 = {
std::thread::spawn(move || {
for _ in 0..iterations {
let tx = db.begin_tx();
let id = IDS.fetch_add(1, Ordering::SeqCst);
let id = RowID {
table_id: 1,
row_id: id,
};
let row = Row {
id,
data: "World".to_string(),
};
db.insert(tx, row.clone()).unwrap();
db.commit_tx(tx).unwrap();
let tx = db.begin_tx();
let committed_row = db.read(tx, id).unwrap();
db.commit_tx(tx).unwrap();
assert_eq!(committed_row, Some(row));
}
},
100,
);
})
};
th1.join().unwrap();
th2.join().unwrap();
}
#[test]
fn test_overlapping_concurrent_inserts_read_your_writes() {
START.call_once(|| {
tracing_subscriber::fmt::init();
});
// Four threads insert to the database concurrently using overlapping row IDs.
let clock = LocalClock::default();
let storage = mvcc_rs::persistent_storage::Storage::new_noop();
let db = Arc::new(Database::new(clock, storage));
let iterations = 100000;
let work = |prefix: &'static str| {
let db = db.clone();
std::thread::spawn(move || {
for i in 0..iterations {
if i % 1000 == 0 {
tracing::debug!("{prefix}: {i}");
}
let tx = db.begin_tx();
let id = i % 16;
let id = RowID {
table_id: 1,
row_id: id,
};
let row = Row {
id,
data: format!("{prefix} @{tx}"),
};
db.insert(tx, row.clone()).unwrap();
let committed_row = db.read(tx, id).unwrap();
db.commit_tx(tx).unwrap();
assert_eq!(committed_row, Some(row));
}
})
};
let threads = vec![work("A"), work("B"), work("C"), work("D")];
for th in threads {
th.join().unwrap();
}
}