From 7b1e2093e897ec3c6c406b4607a0ac8fa0ce12a6 Mon Sep 17 00:00:00 2001 From: Pere Diaz Bou Date: Wed, 5 Mar 2025 17:29:22 +0100 Subject: [PATCH] add multi threaded test for simple writer/reader test --- tests/integration/wal/test_wal.rs | 84 +++++++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) diff --git a/tests/integration/wal/test_wal.rs b/tests/integration/wal/test_wal.rs index a0173a2b4..f88786cf1 100644 --- a/tests/integration/wal/test_wal.rs +++ b/tests/integration/wal/test_wal.rs @@ -1,7 +1,9 @@ use crate::common::{do_flush, TempDatabase}; use limbo_core::{Connection, LimboError, Result, StepResult}; use std::cell::RefCell; +use std::ops::{Add, Deref}; use std::rc::Rc; +use std::sync::{Arc, Mutex}; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::EnvFilter; @@ -33,6 +35,88 @@ fn test_wal_checkpoint_result() -> Result<()> { Ok(()) } +#[test] +fn test_wal_1_writer_1_reader() -> Result<()> { + maybe_setup_tracing(); + let tmp_db = Arc::new(Mutex::new(TempDatabase::new("test_wal.db"))); + let db = tmp_db.lock().unwrap().limbo_database(); + + { + let conn = db.connect().unwrap(); + match conn.query("CREATE TABLE t (id)")? { + Some(ref mut rows) => loop { + match rows.step().unwrap() { + StepResult::Row => {} + StepResult::IO => { + tmp_db.lock().unwrap().io.run_once().unwrap(); + } + StepResult::Interrupt => break, + StepResult::Done => break, + StepResult::Busy => unreachable!(), + } + }, + None => todo!(), + } + do_flush(&conn, tmp_db.lock().unwrap().deref()).unwrap(); + } + let rows = Arc::new(std::sync::Mutex::new(0)); + let rows_ = rows.clone(); + const ROWS_WRITE: usize = 1000; + let tmp_db_w = db.clone(); + let writer_thread = std::thread::spawn(move || { + let conn = tmp_db_w.connect().unwrap(); + for i in 0..ROWS_WRITE { + println!("adding {}", i); + conn.execute(format!("INSERT INTO t values({})", i).as_str()) + .unwrap(); + let mut rows = rows_.lock().unwrap(); + *rows += 1; + } + }); + let rows_ = rows.clone(); + let reader_thread = std::thread::spawn(move || { + let conn = db.connect().unwrap(); + loop { + let rows = *rows_.lock().unwrap(); + let mut i = 0; + println!("reading {}", rows); + match conn.query("SELECT * FROM t") { + Ok(Some(ref mut rows)) => loop { + match rows.step().unwrap() { + StepResult::Row => { + let row = rows.row().unwrap(); + let first_value = row.get_value(0); + let id = match first_value { + limbo_core::OwnedValue::Integer(i) => *i as i32, + _ => unreachable!(), + }; + assert_eq!(id, i); + i += 1; + } + StepResult::IO => { + tmp_db.lock().unwrap().io.run_once().unwrap(); + } + StepResult::Interrupt => break, + StepResult::Done => break, + StepResult::Busy => unreachable!(), + } + }, + Ok(None) => {} + Err(err) => { + eprintln!("{}", err); + } + } + if rows == ROWS_WRITE { + break; + } + } + }); + + writer_thread.join().unwrap(); + reader_thread.join().unwrap(); + Ok(()) +} + fn maybe_setup_tracing() { let _ = tracing_subscriber::registry() .with(