diff --git a/sqlite3/src/lib.rs b/sqlite3/src/lib.rs index ffa1a4445..77f69cf94 100644 --- a/sqlite3/src/lib.rs +++ b/sqlite3/src/lib.rs @@ -5,7 +5,7 @@ use limbo_core::Value; use std::ffi::{self, CStr, CString}; use tracing::trace; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; macro_rules! stub { () => { @@ -35,6 +35,10 @@ pub const SQLITE_CHECKPOINT_RESTART: ffi::c_int = 2; pub const SQLITE_CHECKPOINT_TRUNCATE: ffi::c_int = 3; pub struct sqlite3 { + pub(crate) inner: Arc>, +} + +struct sqlite3Inner { pub(crate) io: Arc, pub(crate) _db: Arc, pub(crate) conn: Arc, @@ -51,7 +55,7 @@ impl sqlite3 { db: Arc, conn: Arc, ) -> Self { - Self { + let inner = sqlite3Inner { io, _db: db, conn, @@ -60,7 +64,9 @@ impl sqlite3 { malloc_failed: false, e_open_state: SQLITE_STATE_OPEN, p_err: std::ptr::null_mut(), - } + }; + let inner = Arc::new(Mutex::new(inner)); + Self { inner } } } @@ -202,16 +208,17 @@ pub unsafe extern "C" fn sqlite3_context_db_handle(_context: *mut ffi::c_void) - #[no_mangle] pub unsafe extern "C" fn sqlite3_prepare_v2( - db: *mut sqlite3, + raw_db: *mut sqlite3, sql: *const ffi::c_char, _len: ffi::c_int, out_stmt: *mut *mut sqlite3_stmt, _tail: *mut *const ffi::c_char, ) -> ffi::c_int { - if db.is_null() || sql.is_null() || out_stmt.is_null() { + if raw_db.is_null() || sql.is_null() || out_stmt.is_null() { return SQLITE_MISUSE; } - let db: &mut sqlite3 = &mut *db; + let db: &mut sqlite3 = &mut *raw_db; + let db = db.inner.lock().unwrap(); let sql = CStr::from_ptr(sql); let sql = match sql.to_str() { Ok(s) => s, @@ -221,7 +228,7 @@ pub unsafe extern "C" fn sqlite3_prepare_v2( Ok(stmt) => stmt, Err(_) => return SQLITE_ERROR, }; - *out_stmt = Box::leak(Box::new(sqlite3_stmt::new(db, stmt))); + *out_stmt = Box::leak(Box::new(sqlite3_stmt::new(raw_db, stmt))); SQLITE_OK } @@ -239,6 +246,7 @@ pub unsafe extern "C" fn sqlite3_step(stmt: *mut sqlite3_stmt) -> ffi::c_int { let stmt = &mut *stmt; let db = &mut *stmt.db; loop { + let db = db.inner.lock().unwrap(); if let Ok(result) = stmt.stmt.step() { match result { limbo_core::StepResult::IO => { @@ -278,6 +286,7 @@ pub unsafe extern "C" fn sqlite3_exec( return SQLITE_MISUSE; } let db: &mut sqlite3 = &mut *db; + let db = db.inner.lock().unwrap(); let sql = CStr::from_ptr(sql); let sql = match sql.to_str() { Ok(s) => s, @@ -395,7 +404,8 @@ pub unsafe extern "C" fn sqlite3_errcode(db: *mut sqlite3) -> ffi::c_int { return SQLITE_MISUSE; } let db: &mut sqlite3 = &mut *db; - if !sqlite3_safety_check_sick_or_ok(db) { + let db = db.inner.lock().unwrap(); + if !sqlite3_safety_check_sick_or_ok(&db) { return SQLITE_MISUSE; } if db.malloc_failed { @@ -949,7 +959,8 @@ pub unsafe extern "C" fn sqlite3_errmsg(db: *mut sqlite3) -> *const ffi::c_char return sqlite3_errstr(SQLITE_NOMEM); } let db: &mut sqlite3 = &mut *db; - if !sqlite3_safety_check_sick_or_ok(db) { + let db = db.inner.lock().unwrap(); + if !sqlite3_safety_check_sick_or_ok(&db) { return sqlite3_errstr(SQLITE_MISUSE); } if db.malloc_failed { @@ -977,7 +988,8 @@ pub unsafe extern "C" fn sqlite3_extended_errcode(db: *mut sqlite3) -> ffi::c_in return SQLITE_MISUSE; } let db: &mut sqlite3 = &mut *db; - if !sqlite3_safety_check_sick_or_ok(db) { + let db = db.inner.lock().unwrap(); + if !sqlite3_safety_check_sick_or_ok(&db) { return SQLITE_MISUSE; } if db.malloc_failed { @@ -1091,6 +1103,7 @@ pub unsafe extern "C" fn sqlite3_wal_checkpoint_v2( return SQLITE_MISUSE; } let db: &mut sqlite3 = &mut *db; + let db = db.inner.lock().unwrap(); // TODO: Checkpointing modes and reporting back log size and checkpoint count to caller. if db.conn.checkpoint().is_err() { return SQLITE_ERROR; @@ -1125,6 +1138,7 @@ pub unsafe extern "C" fn libsql_wal_frame_count( return SQLITE_MISUSE; } let db: &mut sqlite3 = &mut *db; + let db = db.inner.lock().unwrap(); let frame_count = match db.conn.wal_frame_count() { Ok(count) => count as u32, Err(_) => return SQLITE_ERROR, @@ -1163,6 +1177,7 @@ pub unsafe extern "C" fn libsql_wal_get_frame( return SQLITE_MISUSE; } let db: &mut sqlite3 = &mut *db; + let db = db.inner.lock().unwrap(); match db.conn.wal_get_frame(frame_no, p_frame, frame_len) { Ok(c) => match db.io.wait_for_completion(c) { Ok(_) => SQLITE_OK, @@ -1172,7 +1187,7 @@ pub unsafe extern "C" fn libsql_wal_get_frame( } } -fn sqlite3_safety_check_sick_or_ok(db: &sqlite3) -> bool { +fn sqlite3_safety_check_sick_or_ok(db: &sqlite3Inner) -> bool { match db.e_open_state { SQLITE_STATE_SICK | SQLITE_STATE_OPEN | SQLITE_STATE_BUSY => true, _ => {