From 10d02525d6d700141a2619880cea99528654a09e Mon Sep 17 00:00:00 2001 From: Pere Diaz Bou Date: Wed, 18 Jun 2025 16:52:42 +0200 Subject: [PATCH] introduce concurrent write test The idea is quite simple: write with 4 concurrent writers and once all are finsihed, check the count of rows written is correct. --- core/storage/wal.rs | 5 +- tests/integration/common.rs | 54 ++++--- .../query_processing/test_write_path.rs | 132 +++++++++++++++++- 3 files changed, 166 insertions(+), 25 deletions(-) diff --git a/core/storage/wal.rs b/core/storage/wal.rs index f73838a27..4532fc9c7 100644 --- a/core/storage/wal.rs +++ b/core/storage/wal.rs @@ -167,7 +167,7 @@ impl LimboRwLock { /// Unlock the current held lock. pub fn unlock(&mut self) { let lock = self.lock.load(Ordering::SeqCst); - tracing::trace!("unlock(lock={})", lock); + tracing::trace!("unlock(value={})", lock); match lock { NO_LOCK => {} SHARED_LOCK => { @@ -511,6 +511,7 @@ impl Wal for WalFile { let shared = self.get_shared(); { let lock = &mut shared.read_locks[max_read_mark_index as usize]; + tracing::trace!("begin_read_tx_read_lock(lock={})", max_read_mark_index); let busy = !lock.read(); if busy { return Ok(LimboResult::Busy); @@ -532,7 +533,7 @@ impl Wal for WalFile { /// End a read transaction. #[inline(always)] fn end_read_tx(&self) -> Result { - tracing::debug!("end_read_tx"); + tracing::debug!("end_read_tx(lock={})", self.max_frame_read_lock_index); let read_lock = &mut self.get_shared().read_locks[self.max_frame_read_lock_index]; read_lock.unlock(); Ok(LimboResult::Ok) diff --git a/tests/integration/common.rs b/tests/integration/common.rs index 4814c5777..25356deed 100644 --- a/tests/integration/common.rs +++ b/tests/integration/common.rs @@ -12,6 +12,7 @@ use tracing_subscriber::EnvFilter; pub struct TempDatabase { pub path: PathBuf, pub io: Arc, + pub db: Arc, } unsafe impl Send for TempDatabase {} @@ -25,14 +26,29 @@ impl TempDatabase { let mut path = TempDir::new().unwrap().keep(); path.push(db_name); let io: Arc = Arc::new(limbo_core::PlatformIO::new().unwrap()); - Self { path, io } + let db = Database::open_file_with_flags( + io.clone(), + path.to_str().unwrap(), + limbo_core::OpenFlags::default(), + false, + ) + .unwrap(); + Self { path, io, db } } pub fn new_with_existent(db_path: &Path) -> Self { + Self::new_with_existent_with_flags(db_path, limbo_core::OpenFlags::default()) + } + + pub fn new_with_existent_with_flags(db_path: &Path, flags: limbo_core::OpenFlags) -> Self { let io: Arc = Arc::new(limbo_core::PlatformIO::new().unwrap()); + let db = + Database::open_file_with_flags(io.clone(), db_path.to_str().unwrap(), flags, false) + .unwrap(); Self { path: db_path.to_path_buf(), io, + db, } } @@ -50,28 +66,21 @@ impl TempDatabase { connection.execute(table_sql, ()).unwrap(); } let io: Arc = Arc::new(limbo_core::PlatformIO::new().unwrap()); - - Self { path, io } - } - - pub fn connect_limbo(&self) -> Arc { - Self::connect_limbo_with_flags(&self, limbo_core::OpenFlags::default()) - } - - pub fn connect_limbo_with_flags( - &self, - flags: limbo_core::OpenFlags, - ) -> Arc { - log::debug!("conneting to limbo"); let db = Database::open_file_with_flags( - self.io.clone(), - self.path.to_str().unwrap(), - flags, + io.clone(), + path.to_str().unwrap(), + limbo_core::OpenFlags::default(), false, ) .unwrap(); - let conn = db.connect().unwrap(); + Self { path, io, db } + } + + pub fn connect_limbo(&self) -> Arc { + log::debug!("conneting to limbo"); + + let conn = self.db.connect().unwrap(); log::debug!("connected to limbo"); conn } @@ -260,7 +269,10 @@ mod tests { #[test] fn test_limbo_open_read_only() -> anyhow::Result<()> { let path = TempDir::new().unwrap().keep().join("temp_read_only"); - let db = TempDatabase::new_with_existent(&path); + let db = TempDatabase::new_with_existent_with_flags( + &path, + limbo_core::OpenFlags::default() | limbo_core::OpenFlags::ReadOnly, + ); { let conn = db.connect_limbo(); let ret = limbo_exec_rows(&db, &conn, "CREATE table t(a)"); @@ -270,9 +282,7 @@ mod tests { } { - let conn = db.connect_limbo_with_flags( - limbo_core::OpenFlags::default() | limbo_core::OpenFlags::ReadOnly, - ); + let conn = db.connect_limbo(); let ret = limbo_exec_rows(&db, &conn, "SELECT * from t"); assert_eq!(ret, vec![vec![Value::Integer(1)]]); diff --git a/tests/integration/query_processing/test_write_path.rs b/tests/integration/query_processing/test_write_path.rs index 0e0597167..da89b2077 100644 --- a/tests/integration/query_processing/test_write_path.rs +++ b/tests/integration/query_processing/test_write_path.rs @@ -1,9 +1,21 @@ use crate::common::{self, maybe_setup_tracing}; use crate::common::{compare_string, do_flush, TempDatabase}; -use limbo_core::{Connection, Row, StepResult, Value}; +use limbo_core::{Connection, Row, Statement, StepResult, Value}; use log::debug; use std::sync::Arc; +#[macro_export] +macro_rules! change_state { + ($current:expr, $pattern:pat => $selector:expr) => { + let state = match std::mem::replace($current, unsafe { std::mem::zeroed() }) { + $pattern => $selector, + _ => panic!("unexpected state"), + }; + #[allow(clippy::forget_non_drop)] + std::mem::forget(std::mem::replace($current, state)); + }; +} + #[test] #[ignore] fn test_simple_overflow_page() -> anyhow::Result<()> { @@ -448,6 +460,124 @@ fn test_delete_with_index() -> anyhow::Result<()> { Ok(()) } +enum ConnectionState { + PrepareQuery { query_idx: usize }, + ExecuteQuery { query_idx: usize, stmt: Statement }, + Done, +} + +struct ConnectionPlan { + queries: Vec, + conn: Arc, + state: ConnectionState, +} + +impl ConnectionPlan { + pub fn step(&mut self) -> anyhow::Result { + loop { + match &mut self.state { + ConnectionState::PrepareQuery { query_idx } => { + if *query_idx >= self.queries.len() { + self.state = ConnectionState::Done; + return Ok(true); + } + let query = &self.queries[*query_idx]; + tracing::info!("preparing {}", query); + let stmt = self.conn.query(query)?.unwrap(); + self.state = ConnectionState::ExecuteQuery { + query_idx: *query_idx, + stmt, + }; + } + ConnectionState::ExecuteQuery { stmt, query_idx } => loop { + let query = &self.queries[*query_idx]; + tracing::info!("stepping {}", query); + let current_query_idx = *query_idx; + let step_result = stmt.step()?; + match step_result { + StepResult::IO => { + return Ok(false); + } + StepResult::Done => { + change_state!(&mut self.state, ConnectionState::ExecuteQuery { .. } => ConnectionState::PrepareQuery { query_idx: current_query_idx + 1 }); + return Ok(false); + } + StepResult::Row => {} + StepResult::Busy => { + return Ok(false); + } + _ => unreachable!(), + } + }, + ConnectionState::Done => { + return Ok(true); + } + } + } + } + + pub fn is_finished(&self) -> bool { + matches!(self.state, ConnectionState::Done) + } +} + +#[test] +fn test_write_concurrent_connections() -> anyhow::Result<()> { + let _ = env_logger::try_init(); + + maybe_setup_tracing(); + + let tmp_db = TempDatabase::new_with_rusqlite("CREATE TABLE t(x)"); + let num_connections = 4; + let num_inserts_per_connection = 100; + let mut connections = vec![]; + for connection_idx in 0..num_connections { + let conn = tmp_db.connect_limbo(); + let mut queries = Vec::with_capacity(num_inserts_per_connection); + for query_idx in 0..num_inserts_per_connection { + queries.push(format!( + "INSERT INTO t VALUES({})", + (connection_idx * num_inserts_per_connection) + query_idx + )); + } + connections.push(ConnectionPlan { + queries, + conn, + state: ConnectionState::PrepareQuery { query_idx: 0 }, + }); + } + + let mut connections_finished = 0; + while connections_finished != num_connections { + for conn in &mut connections { + if conn.is_finished() { + continue; + } + let finished = conn.step()?; + if finished { + connections_finished += 1; + } + } + } + + let conn = tmp_db.connect_limbo(); + // run_query_on_row(&tmp_db, &conn, "SELECT * from t", |row: &Row| { + // for value in row.get_values() { + // tracing::info!("got value {:?}", value); + // } + // })?; + run_query_on_row(&tmp_db, &conn, "SELECT count(1) from t", |row: &Row| { + let count = row.get::(0).unwrap(); + assert_eq!( + count, + (num_connections * num_inserts_per_connection) as i64, + "received wrong number of rows" + ); + })?; + + Ok(()) +} + fn run_query(tmp_db: &TempDatabase, conn: &Arc, query: &str) -> anyhow::Result<()> { run_query_core(tmp_db, conn, query, None::) }