Merge ' core/lib: init_pager lock shared wal until filled ' from Pere Diaz Bou

maybe_shared_wal's lock is held for a limited time increasing the chance
of initializing the shared wal twice.

Reviewed-by: Jussi Saurio <jussi.saurio@gmail.com>

Closes #2201
This commit is contained in:
Pere Diaz Bou
2025-07-21 13:06:41 +02:00
3 changed files with 165 additions and 6 deletions

View File

@@ -324,7 +324,8 @@ impl Database {
fn init_pager(&self, page_size: Option<usize>) -> Result<Pager> {
// Open existing WAL file if present
if let Some(shared_wal) = self.maybe_shared_wal.read().clone() {
let mut maybe_shared_wal = self.maybe_shared_wal.write();
if let Some(shared_wal) = maybe_shared_wal.clone() {
let size = match page_size {
None => unsafe { (*shared_wal.get()).page_size() as usize },
Some(size) => size,
@@ -379,7 +380,7 @@ impl Database {
let real_shared_wal = WalFileShared::new_shared(size, &self.io, file)?;
// Modify Database::maybe_shared_wal to point to the new WAL file so that other connections
// can open the existing WAL.
*self.maybe_shared_wal.write() = Some(real_shared_wal.clone());
*maybe_shared_wal = Some(real_shared_wal.clone());
let wal = Rc::new(RefCell::new(WalFile::new(
self.io.clone(),
real_shared_wal,

View File

@@ -1,4 +1,8 @@
use crate::common::TempDatabase;
use std::sync::{atomic::AtomicUsize, Arc};
use turso_core::StepResult;
use crate::common::{maybe_setup_tracing, TempDatabase};
#[test]
fn test_schema_change() {
@@ -25,3 +29,157 @@ fn test_schema_change() {
};
println!("{:?} {:?}", row.get_value(0), row.get_value(1));
}
#[test]
#[ignore]
fn test_create_multiple_connections() -> anyhow::Result<()> {
maybe_setup_tracing();
let tries = 1;
for _ in 0..tries {
let tmp_db = Arc::new(TempDatabase::new_empty(false));
{
let conn = tmp_db.connect_limbo();
conn.execute("CREATE TABLE t(x)").unwrap();
}
let mut threads = Vec::new();
for i in 0..10 {
let tmp_db_ = tmp_db.clone();
threads.push(std::thread::spawn(move || {
let conn = tmp_db_.connect_limbo();
'outer: loop {
let mut stmt = conn
.prepare(format!("INSERT INTO t VALUES ({i})").as_str())
.unwrap();
tracing::info!("inserting row {}", i);
loop {
match stmt.step().unwrap() {
StepResult::Row => {
panic!("unexpected row result");
}
StepResult::IO => {
stmt.run_once().unwrap();
}
StepResult::Done => {
tracing::info!("inserted row {}", i);
break 'outer;
}
StepResult::Interrupt => {
panic!("unexpected step result");
}
StepResult::Busy => {
// repeat until we can insert it
tracing::info!("busy {}, repeating", i);
break;
}
}
}
}
}));
}
for thread in threads {
thread.join().unwrap();
}
let conn = tmp_db.connect_limbo();
let mut stmt = conn.prepare("SELECT * FROM t").unwrap();
let mut rows = Vec::new();
loop {
match stmt.step().unwrap() {
StepResult::Row => {
let row = stmt.row().unwrap();
rows.push(row.get::<i64>(0).unwrap());
}
StepResult::IO => {
stmt.run_once().unwrap();
}
StepResult::Done => {
break;
}
StepResult::Interrupt => {
panic!("unexpected step result");
}
StepResult::Busy => {
panic!("unexpected busy result on select");
}
}
}
rows.sort();
assert_eq!(rows, (0..10).collect::<Vec<_>>());
}
Ok(())
}
#[test]
#[ignore]
fn test_reader_writer() -> anyhow::Result<()> {
let tries = 10;
for _ in 0..tries {
let tmp_db = Arc::new(TempDatabase::new_empty(false));
{
let conn = tmp_db.connect_limbo();
conn.execute("CREATE TABLE t(x)").unwrap();
}
let mut threads = Vec::new();
let number_of_writers = 100;
let current_written_rows = Arc::new(AtomicUsize::new(0));
{
let tmp_db = tmp_db.clone();
let current_written_rows = current_written_rows.clone();
threads.push(std::thread::spawn(move || {
let conn = tmp_db.connect_limbo();
for i in 0..number_of_writers {
conn.execute(format!("INSERT INTO t VALUES ({i})").as_str())
.unwrap();
current_written_rows.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
}));
}
{
let current_written_rows = current_written_rows.clone();
threads.push(std::thread::spawn(move || {
let conn = tmp_db.connect_limbo();
loop {
let current_written_rows =
current_written_rows.load(std::sync::atomic::Ordering::Relaxed);
if current_written_rows == number_of_writers {
break;
}
let mut stmt = conn.prepare("SELECT * FROM t").unwrap();
let mut rows = Vec::new();
loop {
match stmt.step().unwrap() {
StepResult::Row => {
let row = stmt.row().unwrap();
let x = row.get::<i64>(0).unwrap();
rows.push(x);
}
StepResult::IO => {
stmt.run_once().unwrap();
}
StepResult::Done => {
rows.sort();
for i in 0..current_written_rows {
let i = i as i64;
assert!(
rows.contains(&i),
"row {i} not found in {rows:?}. current_written_rows: {current_written_rows}",
);
}
break;
}
StepResult::Interrupt | StepResult::Busy => {
panic!("unexpected step result");
}
}
}
}
}));
}
for thread in threads {
thread.join().unwrap();
}
}
Ok(())
}

View File

@@ -765,11 +765,11 @@ fn test_read_wal_dumb_no_frames() -> anyhow::Result<()> {
Ok(())
}
fn run_query(tmp_db: &TempDatabase, conn: &Arc<Connection>, query: &str) -> anyhow::Result<()> {
pub fn run_query(tmp_db: &TempDatabase, conn: &Arc<Connection>, query: &str) -> anyhow::Result<()> {
run_query_core(tmp_db, conn, query, None::<fn(&Row)>)
}
fn run_query_on_row(
pub fn run_query_on_row(
tmp_db: &TempDatabase,
conn: &Arc<Connection>,
query: &str,
@@ -778,7 +778,7 @@ fn run_query_on_row(
run_query_core(tmp_db, conn, query, Some(on_row))
}
fn run_query_core(
pub fn run_query_core(
_tmp_db: &TempDatabase,
conn: &Arc<Connection>,
query: &str,