diff --git a/core/lib.rs b/core/lib.rs index 2c38bfe41..6945e6a63 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -324,7 +324,8 @@ impl Database { fn init_pager(&self, page_size: Option) -> Result { // Open existing WAL file if present - if let Some(shared_wal) = self.maybe_shared_wal.read().clone() { + let mut maybe_shared_wal = self.maybe_shared_wal.write(); + if let Some(shared_wal) = maybe_shared_wal.clone() { let size = match page_size { None => unsafe { (*shared_wal.get()).page_size() as usize }, Some(size) => size, @@ -379,7 +380,7 @@ impl Database { let real_shared_wal = WalFileShared::new_shared(size, &self.io, file)?; // Modify Database::maybe_shared_wal to point to the new WAL file so that other connections // can open the existing WAL. - *self.maybe_shared_wal.write() = Some(real_shared_wal.clone()); + *maybe_shared_wal = Some(real_shared_wal.clone()); let wal = Rc::new(RefCell::new(WalFile::new( self.io.clone(), real_shared_wal, diff --git a/tests/integration/query_processing/test_multi_thread.rs b/tests/integration/query_processing/test_multi_thread.rs index 1d96e0422..4682659c2 100644 --- a/tests/integration/query_processing/test_multi_thread.rs +++ b/tests/integration/query_processing/test_multi_thread.rs @@ -1,4 +1,8 @@ -use crate::common::TempDatabase; +use std::sync::{atomic::AtomicUsize, Arc}; + +use turso_core::StepResult; + +use crate::common::{maybe_setup_tracing, TempDatabase}; #[test] fn test_schema_change() { @@ -25,3 +29,157 @@ fn test_schema_change() { }; println!("{:?} {:?}", row.get_value(0), row.get_value(1)); } + +#[test] +#[ignore] +fn test_create_multiple_connections() -> anyhow::Result<()> { + maybe_setup_tracing(); + let tries = 1; + for _ in 0..tries { + let tmp_db = Arc::new(TempDatabase::new_empty(false)); + { + let conn = tmp_db.connect_limbo(); + conn.execute("CREATE TABLE t(x)").unwrap(); + } + + let mut threads = Vec::new(); + for i in 0..10 { + let tmp_db_ = tmp_db.clone(); + threads.push(std::thread::spawn(move || { + let conn = tmp_db_.connect_limbo(); + 'outer: loop { + let mut stmt = conn + .prepare(format!("INSERT INTO t VALUES ({i})").as_str()) + .unwrap(); + tracing::info!("inserting row {}", i); + loop { + match stmt.step().unwrap() { + StepResult::Row => { + panic!("unexpected row result"); + } + StepResult::IO => { + stmt.run_once().unwrap(); + } + StepResult::Done => { + tracing::info!("inserted row {}", i); + break 'outer; + } + StepResult::Interrupt => { + panic!("unexpected step result"); + } + StepResult::Busy => { + // repeat until we can insert it + tracing::info!("busy {}, repeating", i); + break; + } + } + } + } + })); + } + for thread in threads { + thread.join().unwrap(); + } + + let conn = tmp_db.connect_limbo(); + let mut stmt = conn.prepare("SELECT * FROM t").unwrap(); + let mut rows = Vec::new(); + loop { + match stmt.step().unwrap() { + StepResult::Row => { + let row = stmt.row().unwrap(); + rows.push(row.get::(0).unwrap()); + } + StepResult::IO => { + stmt.run_once().unwrap(); + } + StepResult::Done => { + break; + } + StepResult::Interrupt => { + panic!("unexpected step result"); + } + StepResult::Busy => { + panic!("unexpected busy result on select"); + } + } + } + rows.sort(); + assert_eq!(rows, (0..10).collect::>()); + } + Ok(()) +} + +#[test] +#[ignore] +fn test_reader_writer() -> anyhow::Result<()> { + let tries = 10; + for _ in 0..tries { + let tmp_db = Arc::new(TempDatabase::new_empty(false)); + { + let conn = tmp_db.connect_limbo(); + conn.execute("CREATE TABLE t(x)").unwrap(); + } + + let mut threads = Vec::new(); + let number_of_writers = 100; + let current_written_rows = Arc::new(AtomicUsize::new(0)); + { + let tmp_db = tmp_db.clone(); + let current_written_rows = current_written_rows.clone(); + threads.push(std::thread::spawn(move || { + let conn = tmp_db.connect_limbo(); + for i in 0..number_of_writers { + conn.execute(format!("INSERT INTO t VALUES ({i})").as_str()) + .unwrap(); + current_written_rows.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + } + })); + } + { + let current_written_rows = current_written_rows.clone(); + threads.push(std::thread::spawn(move || { + let conn = tmp_db.connect_limbo(); + loop { + let current_written_rows = + current_written_rows.load(std::sync::atomic::Ordering::Relaxed); + if current_written_rows == number_of_writers { + break; + } + let mut stmt = conn.prepare("SELECT * FROM t").unwrap(); + let mut rows = Vec::new(); + loop { + match stmt.step().unwrap() { + StepResult::Row => { + let row = stmt.row().unwrap(); + let x = row.get::(0).unwrap(); + rows.push(x); + } + StepResult::IO => { + stmt.run_once().unwrap(); + } + StepResult::Done => { + rows.sort(); + for i in 0..current_written_rows { + let i = i as i64; + assert!( + rows.contains(&i), + "row {i} not found in {rows:?}. current_written_rows: {current_written_rows}", + ); + } + break; + } + StepResult::Interrupt | StepResult::Busy => { + panic!("unexpected step result"); + } + } + } + } + })); + } + for thread in threads { + thread.join().unwrap(); + } + } + Ok(()) +} diff --git a/tests/integration/query_processing/test_write_path.rs b/tests/integration/query_processing/test_write_path.rs index f7605323e..b5c4e0b51 100644 --- a/tests/integration/query_processing/test_write_path.rs +++ b/tests/integration/query_processing/test_write_path.rs @@ -765,11 +765,11 @@ fn test_read_wal_dumb_no_frames() -> anyhow::Result<()> { Ok(()) } -fn run_query(tmp_db: &TempDatabase, conn: &Arc, query: &str) -> anyhow::Result<()> { +pub fn run_query(tmp_db: &TempDatabase, conn: &Arc, query: &str) -> anyhow::Result<()> { run_query_core(tmp_db, conn, query, None::) } -fn run_query_on_row( +pub fn run_query_on_row( tmp_db: &TempDatabase, conn: &Arc, query: &str, @@ -778,7 +778,7 @@ fn run_query_on_row( run_query_core(tmp_db, conn, query, Some(on_row)) } -fn run_query_core( +pub fn run_query_core( _tmp_db: &TempDatabase, conn: &Arc, query: &str,