diff --git a/Cargo.lock b/Cargo.lock index 83db58c03..ee0f3efd4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1617,7 +1617,7 @@ dependencies = [ [[package]] name = "limbo_completion" -version = "0.0.15" +version = "0.0.16" dependencies = [ "limbo_ext", "mimalloc", diff --git a/bindings/go/rs_src/lib.rs b/bindings/go/rs_src/lib.rs index a6068732a..b2240a16b 100644 --- a/bindings/go/rs_src/lib.rs +++ b/bindings/go/rs_src/lib.rs @@ -22,15 +22,13 @@ pub unsafe extern "C" fn db_open(path: *const c_char) -> *mut c_void { let path = unsafe { std::ffi::CStr::from_ptr(path) }; let path = path.to_str().unwrap(); let io: Arc = match path { - p if p.contains(":memory:") => { - Arc::new(limbo_core::MemoryIO::new().expect("Failed to create IO")) - } + p if p.contains(":memory:") => Arc::new(limbo_core::MemoryIO::new()), _ => Arc::new(limbo_core::PlatformIO::new().expect("Failed to create IO")), }; let db = Database::open_file(io.clone(), path); match db { Ok(db) => { - let conn = db.connect(); + let conn = db.connect().unwrap(); LimboConn::new(conn, io).to_ptr() } Err(e) => { diff --git a/bindings/java/rs_src/limbo_db.rs b/bindings/java/rs_src/limbo_db.rs index a7f6abec5..ef33bf6e2 100644 --- a/bindings/java/rs_src/limbo_db.rs +++ b/bindings/java/rs_src/limbo_db.rs @@ -92,7 +92,7 @@ pub extern "system" fn Java_tech_turso_core_LimboDB_connect0<'local>( } }; - let conn = LimboConnection::new(db.db.connect(), db.io.clone()); + let conn = LimboConnection::new(db.db.connect().unwrap(), db.io.clone()); conn.to_ptr() } diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index 31d1fed3b..fa9045ef7 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -284,15 +284,15 @@ pub fn connect(path: &str) -> Result { match path { ":memory:" => { - let io: Arc = Arc::new(limbo_core::MemoryIO::new()?); + let io: Arc = Arc::new(limbo_core::MemoryIO::new()); let db = open_or(io.clone(), path)?; - let conn: Rc = db.connect(); + let conn: Rc = db.connect().unwrap(); Ok(Connection { conn, io }) } path => { let io: Arc = Arc::new(limbo_core::PlatformIO::new()?); let db = open_or(io.clone(), path)?; - let conn: Rc = db.connect(); + let conn: Rc = db.connect().unwrap(); Ok(Connection { conn, io }) } } diff --git a/bindings/rust/src/lib.rs b/bindings/rust/src/lib.rs index b1c8d6439..d99e7bbb6 100644 --- a/bindings/rust/src/lib.rs +++ b/bindings/rust/src/lib.rs @@ -41,7 +41,7 @@ impl Builder { pub async fn build(self) -> Result { match self.path.as_str() { ":memory:" => { - let io: Arc = Arc::new(limbo_core::MemoryIO::new()?); + let io: Arc = Arc::new(limbo_core::MemoryIO::new()); let db = limbo_core::Database::open_file(io, self.path.as_str())?; Ok(Database { inner: db }) } @@ -63,7 +63,7 @@ unsafe impl Sync for Database {} impl Database { pub fn connect(&self) -> Result { - let conn = self.inner.connect(); + let conn = self.inner.connect().unwrap(); #[allow(clippy::arc_with_non_send_sync)] let connection = Connection { inner: Arc::new(Mutex::new(conn)), diff --git a/bindings/wasm/lib.rs b/bindings/wasm/lib.rs index 34213b171..a9ee77f8e 100644 --- a/bindings/wasm/lib.rs +++ b/bindings/wasm/lib.rs @@ -1,6 +1,6 @@ use js_sys::{Array, Object}; use limbo_core::{ - maybe_init_database_file, BufferPool, OpenFlags, Pager, Result, WalFile, WalFileShared, + maybe_init_database_file, OpenFlags, Pager, Result, WalFileShared, }; use std::cell::RefCell; use std::rc::Rc; @@ -23,26 +23,19 @@ impl Database { .open_file(path, limbo_core::OpenFlags::Create, false) .unwrap(); maybe_init_database_file(&file, &io).unwrap(); - let page_io = Rc::new(DatabaseStorage::new(file)); + let page_io = Arc::new(DatabaseStorage::new(file)); let db_header = Pager::begin_open(page_io.clone()).unwrap(); // ensure db header is there io.run_once().unwrap(); - let page_size = db_header.borrow().page_size; + let page_size = db_header.lock().unwrap().page_size; let wal_path = format!("{}-wal", path); let wal_shared = WalFileShared::open_shared(&io, wal_path.as_str(), page_size).unwrap(); - let buffer_pool = Rc::new(BufferPool::new(page_size as usize)); - let wal = Rc::new(RefCell::new(WalFile::new( - io.clone(), - db_header.borrow().page_size as usize, - wal_shared.clone(), - buffer_pool.clone(), - ))); - let db = limbo_core::Database::open(io, page_io, wal, wal_shared, buffer_pool).unwrap(); - let conn = db.connect(); + let db = limbo_core::Database::open(io, page_io, wal_shared).unwrap(); + let conn = db.connect().unwrap(); Database { db, conn } } @@ -209,6 +202,9 @@ pub struct File { fd: i32, } +unsafe impl Send for File {} +unsafe impl Sync for File {} + #[allow(dead_code)] impl File { fn new(vfs: VFS, fd: i32) -> Self { @@ -245,7 +241,7 @@ impl limbo_core::File for File { fn pwrite( &self, pos: usize, - buffer: Rc>, + buffer: Arc>, c: limbo_core::Completion, ) -> Result<()> { let w = match &c { @@ -273,6 +269,8 @@ impl limbo_core::File for File { pub struct PlatformIO { vfs: VFS, } +unsafe impl Send for PlatformIO {} +unsafe impl Sync for PlatformIO {} impl limbo_core::IO for PlatformIO { fn open_file( @@ -280,9 +278,9 @@ impl limbo_core::IO for PlatformIO { path: &str, _flags: OpenFlags, _direct: bool, - ) -> Result> { + ) -> Result> { let fd = self.vfs.open(path, "a+"); - Ok(Rc::new(File { + Ok(Arc::new(File { vfs: VFS::new(), fd, })) @@ -320,15 +318,18 @@ extern "C" { } pub struct DatabaseStorage { - file: Rc, + file: Arc, } +unsafe impl Send for DatabaseStorage {} +unsafe impl Sync for DatabaseStorage {} impl DatabaseStorage { - pub fn new(file: Rc) -> Self { + pub fn new(file: Arc) -> Self { Self { file } } } + impl limbo_core::DatabaseStorage for DatabaseStorage { fn read_page(&self, page_idx: usize, c: limbo_core::Completion) -> Result<()> { let r = match c { @@ -348,7 +349,7 @@ impl limbo_core::DatabaseStorage for DatabaseStorage { fn write_page( &self, page_idx: usize, - buffer: Rc>, + buffer: Arc>, c: limbo_core::Completion, ) -> Result<()> { let size = buffer.borrow().len(); diff --git a/cli/app.rs b/cli/app.rs index bbc49ac46..4d1470adf 100644 --- a/cli/app.rs +++ b/cli/app.rs @@ -211,7 +211,7 @@ impl<'a> Limbo<'a> { } }; let db = Database::open_file(io.clone(), &db_file)?; - let conn = db.connect(); + let conn = db.connect().unwrap(); let h = LimboHelper::new(conn.clone(), io.clone()); rl.set_helper(Some(h)); let interrupt_count = Arc::new(AtomicUsize::new(0)); @@ -413,7 +413,7 @@ impl<'a> Limbo<'a> { }; self.io = Arc::clone(&io); let db = Database::open_file(self.io.clone(), path)?; - self.conn = db.connect(); + self.conn = db.connect().unwrap(); self.opts.db_file = path.to_string(); Ok(()) } diff --git a/cli/input.rs b/cli/input.rs index 2ace039de..f1c1fd3f6 100644 --- a/cli/input.rs +++ b/cli/input.rs @@ -120,7 +120,7 @@ pub fn get_writer(output: &str) -> Box { pub fn get_io(db_location: DbLocation, io_choice: Io) -> anyhow::Result> { Ok(match db_location { - DbLocation::Memory => Arc::new(limbo_core::MemoryIO::new()?), + DbLocation::Memory => Arc::new(limbo_core::MemoryIO::new()), DbLocation::Path => { match io_choice { Io::Syscall => { diff --git a/core/benches/benchmark.rs b/core/benches/benchmark.rs index d2aef982b..57deca54e 100644 --- a/core/benches/benchmark.rs +++ b/core/benches/benchmark.rs @@ -19,7 +19,7 @@ fn bench(criterion: &mut Criterion) { #[allow(clippy::arc_with_non_send_sync)] let io = Arc::new(PlatformIO::new().unwrap()); let db = Database::open_file(io.clone(), "../testing/testing.db").unwrap(); - let limbo_conn = db.connect(); + let limbo_conn = db.connect().unwrap(); let queries = [ "SELECT 1", diff --git a/core/ext/mod.rs b/core/ext/mod.rs index 961cf5a85..6122aa9b6 100644 --- a/core/ext/mod.rs +++ b/core/ext/mod.rs @@ -76,7 +76,7 @@ unsafe extern "C" fn register_module( impl Database { fn register_scalar_function_impl(&self, name: &str, func: ScalarFunction) -> ResultCode { - self.syms.borrow_mut().functions.insert( + self.syms.lock().unwrap().functions.insert( name.to_string(), Rc::new(ExternalFunc::new_scalar(name.to_string(), func)), ); @@ -89,7 +89,7 @@ impl Database { args: i32, func: ExternAggFunc, ) -> ResultCode { - self.syms.borrow_mut().functions.insert( + self.syms.lock().unwrap().functions.insert( name.to_string(), Rc::new(ExternalFunc::new_aggregate(name.to_string(), args, func)), ); @@ -108,7 +108,8 @@ impl Database { implementation: module, }; self.syms - .borrow_mut() + .lock() + .unwrap() .vtab_modules .insert(name.to_string(), vmodule.into()); ResultCode::OK diff --git a/core/io/memory.rs b/core/io/memory.rs index ca15f1d3b..e5f4d03ea 100644 --- a/core/io/memory.rs +++ b/core/io/memory.rs @@ -4,15 +4,12 @@ use crate::Result; use std::{ cell::{Cell, RefCell, UnsafeCell}, collections::BTreeMap, - rc::Rc, sync::Arc, }; use tracing::debug; -pub struct MemoryIO { - pages: UnsafeCell>, - size: Cell, -} +pub struct MemoryIO {} +unsafe impl Send for MemoryIO {} // TODO: page size flag const PAGE_SIZE: usize = 4096; @@ -20,33 +17,23 @@ type MemPage = Box<[u8; PAGE_SIZE]>; impl MemoryIO { #[allow(clippy::arc_with_non_send_sync)] - pub fn new() -> Result> { + pub fn new() -> Self { debug!("Using IO backend 'memory'"); - Ok(Arc::new(Self { - pages: BTreeMap::new().into(), - size: 0.into(), - })) - } - - #[allow(clippy::mut_from_ref)] - fn get_or_allocate_page(&self, page_no: usize) -> &mut MemPage { - unsafe { - let pages = &mut *self.pages.get(); - pages - .entry(page_no) - .or_insert_with(|| Box::new([0; PAGE_SIZE])) - } - } - - fn get_page(&self, page_no: usize) -> Option<&MemPage> { - unsafe { (*self.pages.get()).get(&page_no) } + Self {} } } -impl IO for Arc { - fn open_file(&self, _path: &str, _flags: OpenFlags, _direct: bool) -> Result> { - Ok(Rc::new(MemoryFile { - io: Arc::clone(self), +impl Default for MemoryIO { + fn default() -> Self { + Self::new() + } +} + +impl IO for MemoryIO { + fn open_file(&self, _path: &str, _flags: OpenFlags, _direct: bool) -> Result> { + Ok(Arc::new(MemoryFile { + pages: BTreeMap::new().into(), + size: 0.into(), })) } @@ -67,8 +54,12 @@ impl IO for Arc { } pub struct MemoryFile { - io: Arc, + pages: UnsafeCell>, + size: Cell, } +unsafe impl Send for MemoryFile {} +unsafe impl Sync for MemoryFile {} + impl File for MemoryFile { fn lock_file(&self, _exclusive: bool) -> Result<()> { @@ -86,7 +77,7 @@ impl File for MemoryFile { return Ok(()); } - let file_size = self.io.size.get(); + let file_size = self.size.get(); if pos >= file_size { c.complete(0); return Ok(()); @@ -103,7 +94,7 @@ impl File for MemoryFile { let page_no = offset / PAGE_SIZE; let page_offset = offset % PAGE_SIZE; let bytes_to_read = remaining.min(PAGE_SIZE - page_offset); - if let Some(page) = self.io.get_page(page_no) { + if let Some(page) = self.get_page(page_no) { read_buf.as_mut_slice()[buf_offset..buf_offset + bytes_to_read] .copy_from_slice(&page[page_offset..page_offset + bytes_to_read]); } else { @@ -119,7 +110,7 @@ impl File for MemoryFile { Ok(()) } - fn pwrite(&self, pos: usize, buffer: Rc>, c: Completion) -> Result<()> { + fn pwrite(&self, pos: usize, buffer: Arc>, c: Completion) -> Result<()> { let buf = buffer.borrow(); let buf_len = buf.len(); if buf_len == 0 { @@ -138,7 +129,7 @@ impl File for MemoryFile { let bytes_to_write = remaining.min(PAGE_SIZE - page_offset); { - let page = self.io.get_or_allocate_page(page_no); + let page = self.get_or_allocate_page(page_no); page[page_offset..page_offset + bytes_to_write] .copy_from_slice(&data[buf_offset..buf_offset + bytes_to_write]); } @@ -148,9 +139,8 @@ impl File for MemoryFile { remaining -= bytes_to_write; } - self.io - .size - .set(core::cmp::max(pos + buf_len, self.io.size.get())); + self.size + .set(core::cmp::max(pos + buf_len, self.size.get())); c.complete(buf_len as i32); Ok(()) @@ -163,7 +153,7 @@ impl File for MemoryFile { } fn size(&self) -> Result { - Ok(self.io.size.get() as u64) + Ok(self.size.get() as u64) } } @@ -172,3 +162,19 @@ impl Drop for MemoryFile { // no-op } } + +impl MemoryFile { + #[allow(clippy::mut_from_ref)] + fn get_or_allocate_page(&self, page_no: usize) -> &mut MemPage { + unsafe { + let pages = &mut *self.pages.get(); + pages + .entry(page_no) + .or_insert_with(|| Box::new([0; PAGE_SIZE])) + } + } + + fn get_page(&self, page_no: usize) -> Option<&MemPage> { + unsafe { (*self.pages.get()).get(&page_no) } + } +} diff --git a/core/io/mod.rs b/core/io/mod.rs index 6c1eddb04..519109565 100644 --- a/core/io/mod.rs +++ b/core/io/mod.rs @@ -1,6 +1,7 @@ use crate::Result; use cfg_block::cfg_block; use std::fmt; +use std::sync::Arc; use std::{ cell::{Ref, RefCell, RefMut}, fmt::Debug, @@ -9,11 +10,11 @@ use std::{ rc::Rc, }; -pub trait File { +pub trait File: Send + Sync { fn lock_file(&self, exclusive: bool) -> Result<()>; fn unlock_file(&self) -> Result<()>; fn pread(&self, pos: usize, c: Completion) -> Result<()>; - fn pwrite(&self, pos: usize, buffer: Rc>, c: Completion) -> Result<()>; + fn pwrite(&self, pos: usize, buffer: Arc>, c: Completion) -> Result<()>; fn sync(&self, c: Completion) -> Result<()>; fn size(&self) -> Result; } @@ -23,8 +24,8 @@ pub enum OpenFlags { Create, } -pub trait IO { - fn open_file(&self, path: &str, flags: OpenFlags, direct: bool) -> Result>; +pub trait IO: Send + Sync { + fn open_file(&self, path: &str, flags: OpenFlags, direct: bool) -> Result>; fn run_once(&self) -> Result<()>; @@ -33,7 +34,7 @@ pub trait IO { fn get_current_time(&self) -> String; } -pub type Complete = dyn Fn(Rc>); +pub type Complete = dyn Fn(Arc>); pub type WriteComplete = dyn Fn(i32); pub type SyncComplete = dyn Fn(i32); @@ -44,7 +45,7 @@ pub enum Completion { } pub struct ReadCompletion { - pub buf: Rc>, + pub buf: Arc>, pub complete: Box, } @@ -76,7 +77,7 @@ pub struct SyncCompletion { } impl ReadCompletion { - pub fn new(buf: Rc>, complete: Box) -> Self { + pub fn new(buf: Arc>, complete: Box) -> Self { Self { buf, complete } } diff --git a/core/io/unix.rs b/core/io/unix.rs index f1323f2bd..8dfe44736 100644 --- a/core/io/unix.rs +++ b/core/io/unix.rs @@ -9,15 +9,20 @@ use rustix::{ fs::{self, FlockOperation, OFlags, OpenOptionsExt}, io::Errno, }; -use std::io::{ErrorKind, Read, Seek, Write}; -use std::rc::Rc; -use tracing::{debug, trace}; use std::{ cell::{RefCell, UnsafeCell}, mem::MaybeUninit, }; +use std::{ + io::{ErrorKind, Read, Seek, Write}, + sync::Arc, +}; +use tracing::{debug, trace}; struct OwnedCallbacks(UnsafeCell); +// We assume we locking on IO level is done by user. +unsafe impl Send for OwnedCallbacks {} +unsafe impl Sync for OwnedCallbacks {} struct BorrowedCallbacks<'io>(UnsafeCell<&'io mut Callbacks>); impl OwnedCallbacks { @@ -163,6 +168,9 @@ pub struct UnixIO { callbacks: OwnedCallbacks, } +unsafe impl Send for UnixIO {} +unsafe impl Sync for UnixIO {} + impl UnixIO { #[cfg(feature = "fs")] pub fn new() -> Result { @@ -176,7 +184,7 @@ impl UnixIO { } impl IO for UnixIO { - fn open_file(&self, path: &str, flags: OpenFlags, _direct: bool) -> Result> { + fn open_file(&self, path: &str, flags: OpenFlags, _direct: bool) -> Result> { trace!("open_file(path = {})", path); let file = std::fs::File::options() .read(true) @@ -185,8 +193,8 @@ impl IO for UnixIO { .create(matches!(flags, OpenFlags::Create)) .open(path)?; - let unix_file = Rc::new(UnixFile { - file: Rc::new(RefCell::new(file)), + let unix_file = Arc::new(UnixFile { + file: Arc::new(RefCell::new(file)), poller: BorrowedPollHandler(self.poller.as_mut().into()), callbacks: BorrowedCallbacks(self.callbacks.as_mut().into()), }); @@ -245,20 +253,22 @@ impl IO for UnixIO { } enum CompletionCallback { - Read(Rc>, Completion, usize), + Read(Arc>, Completion, usize), Write( - Rc>, + Arc>, Completion, - Rc>, + Arc>, usize, ), } pub struct UnixFile<'io> { - file: Rc>, + file: Arc>, poller: BorrowedPollHandler<'io>, callbacks: BorrowedCallbacks<'io>, } +unsafe impl Send for UnixFile<'_> {} +unsafe impl Sync for UnixFile<'_> {} impl File for UnixFile<'_> { fn lock_file(&self, exclusive: bool) -> Result<()> { @@ -332,7 +342,7 @@ impl File for UnixFile<'_> { } } - fn pwrite(&self, pos: usize, buffer: Rc>, c: Completion) -> Result<()> { + fn pwrite(&self, pos: usize, buffer: Arc>, c: Completion) -> Result<()> { let file = self.file.borrow(); let result = { let buf = buffer.borrow(); diff --git a/core/lib.rs b/core/lib.rs index e08f65f7c..9561b589d 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -36,7 +36,8 @@ use std::borrow::Cow; use std::cell::Cell; use std::collections::HashMap; use std::num::NonZero; -use std::sync::{Arc, OnceLock}; +use std::ops::Deref; +use std::sync::{Arc, Mutex, OnceLock}; use std::{cell::RefCell, rc::Rc}; use storage::btree::btree_init_page; #[cfg(feature = "fs")] @@ -83,16 +84,21 @@ enum TransactionState { } pub struct Database { - pager: Rc, - schema: Rc>, - header: Rc>, - syms: Rc>, + schema: Arc>, + header: Arc>, + syms: Arc>, + page_io: Arc, + io: Arc, + page_size: u16, // Shared structures of a Database are the parts that are common to multiple threads that might // create DB connections. - _shared_page_cache: Arc>, - _shared_wal: Arc>, + shared_page_cache: Arc>, + shared_wal: Arc>, } +unsafe impl Send for Database {} +unsafe impl Sync for Database {} + impl Database { #[cfg(feature = "fs")] pub fn open_file(io: Arc, path: &str) -> Result> { @@ -100,81 +106,73 @@ impl Database { let file = io.open_file(path, OpenFlags::Create, true)?; maybe_init_database_file(&file, &io)?; - let page_io = Rc::new(FileStorage::new(file)); + let page_io = Arc::new(FileStorage::new(file)); let wal_path = format!("{}-wal", path); let db_header = Pager::begin_open(page_io.clone())?; io.run_once()?; - let page_size = db_header.borrow().page_size; + let page_size = db_header.lock().unwrap().page_size; let wal_shared = WalFileShared::open_shared(&io, wal_path.as_str(), page_size)?; - let buffer_pool = Rc::new(BufferPool::new(page_size as usize)); - let wal = Rc::new(RefCell::new(WalFile::new( - io.clone(), - db_header.borrow().page_size as usize, - wal_shared.clone(), - buffer_pool.clone(), - ))); - Self::open(io, page_io, wal, wal_shared, buffer_pool) + Self::open(io, page_io, wal_shared) } #[allow(clippy::arc_with_non_send_sync)] pub fn open( io: Arc, - page_io: Rc, - wal: Rc>, + page_io: Arc, shared_wal: Arc>, - buffer_pool: Rc, ) -> Result> { let db_header = Pager::begin_open(page_io.clone())?; io.run_once()?; DATABASE_VERSION.get_or_init(|| { - let version = db_header.borrow().version_number; + let version = db_header.lock().unwrap().version_number; version.to_string() }); - let _shared_page_cache = Arc::new(RwLock::new(DumbLruPageCache::new(10))); - let pager = Rc::new(Pager::finish_open( - db_header.clone(), - page_io, - wal, - io.clone(), - _shared_page_cache.clone(), - buffer_pool, - )?); + let shared_page_cache = Arc::new(RwLock::new(DumbLruPageCache::new(10))); + let page_size = db_header.lock().unwrap().page_size; let header = db_header; - let schema = Rc::new(RefCell::new(Schema::new())); - let syms = Rc::new(RefCell::new(SymbolTable::new())); + let schema = Arc::new(Mutex::new(Schema::new())); + let syms = Arc::new(Mutex::new(SymbolTable::new())); let db = Database { - pager: pager.clone(), schema: schema.clone(), header: header.clone(), - _shared_page_cache: _shared_page_cache.clone(), - _shared_wal: shared_wal.clone(), + shared_page_cache: shared_page_cache.clone(), + shared_wal: shared_wal.clone(), syms: syms.clone(), + page_io, + io: io.clone(), + page_size }; if let Err(e) = db.register_builtins() { return Err(LimboError::ExtensionError(e)); } let db = Arc::new(db); - let conn = Rc::new(Connection { - db: db.clone(), - pager, - schema: schema.clone(), - header, - auto_commit: RefCell::new(true), - transaction_state: RefCell::new(TransactionState::None), - last_insert_rowid: Cell::new(0), - last_change: Cell::new(0), - total_changes: Cell::new(0), - }); + let conn = db.connect()?; let rows = conn.query("SELECT * FROM sqlite_schema")?; - let mut schema = schema.borrow_mut(); - parse_schema_rows(rows, &mut schema, io, &syms.borrow())?; + let mut schema = schema.lock().unwrap(); + let syms = syms.lock().unwrap(); + parse_schema_rows(rows, &mut schema, io, syms.deref())?; Ok(db) } - pub fn connect(self: &Arc) -> Rc { - Rc::new(Connection { + pub fn connect(self: &Arc) -> Result> { + let buffer_pool = Rc::new(BufferPool::new(self.page_size as usize)); + let wal = Rc::new(RefCell::new(WalFile::new( + self.io.clone(), + self.page_size as usize, + self.shared_wal.clone(), + buffer_pool.clone(), + ))); + let pager = Rc::new(Pager::finish_open( + self.header.clone(), + self.page_io.clone(), + wal, + self.io.clone(), + self.shared_page_cache.clone(), + buffer_pool, + )?); + Ok(Rc::new(Connection { db: self.clone(), - pager: self.pager.clone(), + pager: pager.clone(), schema: self.schema.clone(), header: self.header.clone(), last_insert_rowid: Cell::new(0), @@ -182,7 +180,7 @@ impl Database { transaction_state: RefCell::new(TransactionState::None), last_change: Cell::new(0), total_changes: Cell::new(0), - }) + })) } #[cfg(not(target_family = "wasm"))] @@ -197,7 +195,7 @@ impl Database { let api_ptr: *const ExtensionApi = Box::into_raw(api); let result_code = unsafe { entry(api_ptr) }; if result_code.is_ok() { - self.syms.borrow_mut().extensions.push((lib, api_ptr)); + self.syms.lock().unwrap().extensions.push((lib, api_ptr)); Ok(()) } else { if !api_ptr.is_null() { @@ -210,7 +208,7 @@ impl Database { } } -pub fn maybe_init_database_file(file: &Rc, io: &Arc) -> Result<()> { +pub fn maybe_init_database_file(file: &Arc, io: &Arc) -> Result<()> { if file.size()? == 0 { // init db let db_header = DatabaseHeader::default(); @@ -261,8 +259,8 @@ pub fn maybe_init_database_file(file: &Rc, io: &Arc) -> Result pub struct Connection { db: Arc, pager: Rc, - schema: Rc>, - header: Rc>, + schema: Arc>, + header: Arc>, auto_commit: RefCell, transaction_state: RefCell, last_insert_rowid: Cell, @@ -276,12 +274,12 @@ impl Connection { tracing::trace!("Preparing: {}", sql); let mut parser = Parser::new(sql.as_bytes()); let cmd = parser.next()?; - let syms = self.db.syms.borrow(); + let syms = self.db.syms.lock().unwrap(); if let Some(cmd) = cmd { match cmd { Cmd::Stmt(stmt) => { let program = Rc::new(translate::translate( - &self.schema.borrow(), + &self.schema.lock().unwrap(), stmt, self.header.clone(), self.pager.clone(), @@ -312,11 +310,11 @@ impl Connection { pub(crate) fn run_cmd(self: &Rc, cmd: Cmd) -> Result> { let db = self.db.clone(); - let syms = db.syms.borrow(); + let syms = db.syms.lock().unwrap(); match cmd { Cmd::Stmt(stmt) => { let program = Rc::new(translate::translate( - &self.schema.borrow(), + &self.schema.lock().unwrap(), stmt, self.header.clone(), self.pager.clone(), @@ -329,7 +327,7 @@ impl Connection { } Cmd::Explain(stmt) => { let program = translate::translate( - &self.schema.borrow(), + &self.schema.lock().unwrap(), stmt, self.header.clone(), self.pager.clone(), @@ -344,8 +342,8 @@ impl Connection { match stmt { ast::Stmt::Select(select) => { let mut plan = - prepare_select_plan(&self.schema.borrow(), *select, &syms, None)?; - optimize_plan(&mut plan, &self.schema.borrow())?; + prepare_select_plan(&self.schema.lock().unwrap(), *select, &syms, None)?; + optimize_plan(&mut plan, &self.schema.lock().unwrap())?; println!("{}", plan); } _ => todo!(), @@ -363,12 +361,12 @@ impl Connection { let sql = sql.as_ref(); let mut parser = Parser::new(sql.as_bytes()); let cmd = parser.next()?; - let syms = self.db.syms.borrow(); + let syms = self.db.syms.lock().unwrap(); if let Some(cmd) = cmd { match cmd { Cmd::Explain(stmt) => { let program = translate::translate( - &self.schema.borrow(), + &self.schema.lock().unwrap(), stmt, self.header.clone(), self.pager.clone(), @@ -381,7 +379,7 @@ impl Connection { Cmd::ExplainQueryPlan(_stmt) => todo!(), Cmd::Stmt(stmt) => { let program = translate::translate( - &self.schema.borrow(), + &self.schema.lock().unwrap(), stmt, self.header.clone(), self.pager.clone(), diff --git a/core/schema.rs b/core/schema.rs index 22130a825..bff535250 100644 --- a/core/schema.rs +++ b/core/schema.rs @@ -1,6 +1,7 @@ use crate::VirtualTable; use crate::{util::normalize_ident, Result}; use core::fmt; +use std::sync::Arc; use fallible_iterator::FallibleIterator; use limbo_sqlite3_parser::ast::{Expr, Literal, TableOptions}; use limbo_sqlite3_parser::{ @@ -12,18 +13,18 @@ use std::rc::Rc; use tracing::trace; pub struct Schema { - pub tables: HashMap>, + pub tables: HashMap>, // table_name to list of indexes for the table - pub indexes: HashMap>>, + pub indexes: HashMap>>, } impl Schema { pub fn new() -> Self { - let mut tables: HashMap> = HashMap::new(); - let indexes: HashMap>> = HashMap::new(); + let mut tables: HashMap> = HashMap::new(); + let indexes: HashMap>> = HashMap::new(); tables.insert( "sqlite_schema".to_string(), - Rc::new(Table::BTree(sqlite_schema_table().into())), + Arc::new(Table::BTree(sqlite_schema_table().into())), ); Self { tables, indexes } } @@ -38,7 +39,7 @@ impl Schema { self.tables.insert(name, Table::Virtual(table).into()); } - pub fn get_table(&self, name: &str) -> Option> { + pub fn get_table(&self, name: &str) -> Option> { let name = normalize_ident(name); self.tables.get(&name).cloned() } @@ -52,7 +53,7 @@ impl Schema { } } - pub fn add_index(&mut self, index: Rc) { + pub fn add_index(&mut self, index: Arc) { let table_name = normalize_ident(&index.table_name); self.indexes .entry(table_name) diff --git a/core/storage/btree.rs b/core/storage/btree.rs index 0da906386..d6b85d92f 100644 --- a/core/storage/btree.rs +++ b/core/storage/btree.rs @@ -1871,7 +1871,7 @@ impl BTreeCursor { let Some(first_page) = first_overflow_page else { return Ok(CursorResult::Ok(())); }; - let page_count = self.pager.db_header.borrow().database_size as usize; + let page_count = self.pager.db_header.lock().unwrap().database_size as usize; let mut pages_left = n_overflow; let mut current_page = first_page; // Clear overflow pages @@ -2752,12 +2752,14 @@ mod tests { use crate::storage::sqlite3_ondisk; use crate::storage::sqlite3_ondisk::DatabaseHeader; use crate::types::Text; + use crate::Connection; use crate::{BufferPool, DatabaseStorage, WalFile, WalFileShared, WriteCompletion}; use std::cell::RefCell; use std::ops::Deref; use std::panic; use std::rc::Rc; use std::sync::Arc; + use std::sync::Mutex; use rand::{thread_rng, Rng}; use tempfile::TempDir; @@ -2785,7 +2787,7 @@ mod tests { let drop_fn = Rc::new(|_| {}); let inner = PageContent { offset: 0, - buffer: Rc::new(RefCell::new(Buffer::new( + buffer: Arc::new(RefCell::new(Buffer::new( BufferData::new(vec![0; 4096]), drop_fn, ))), @@ -2831,7 +2833,7 @@ mod tests { pos: usize, page: &mut PageContent, record: Record, - db: &Arc, + conn: &Rc, ) -> Vec { let mut payload: Vec = Vec::new(); fill_cell_payload( @@ -2840,7 +2842,7 @@ mod tests { &mut payload, &record, 4096, - db.pager.clone(), + conn.pager.clone(), ); insert_into_cell(page, &payload, pos, 4096).unwrap(); payload @@ -2849,11 +2851,12 @@ mod tests { #[test] fn test_insert_cell() { let db = get_database(); + let conn = db.connect().unwrap(); let page = get_page(2); let page = page.get_contents(); let header_size = 8; let record = Record::new([OwnedValue::Integer(1)].to_vec()); - let payload = add_record(1, 0, page, record, &db); + let payload = add_record(1, 0, page, record, &conn); assert_eq!(page.cell_count(), 1); let free = compute_free_space(page, 4096); assert_eq!(free, 4096 - payload.len() as u16 - 2 - header_size); @@ -2870,6 +2873,7 @@ mod tests { #[test] fn test_drop_1() { let db = get_database(); + let conn = db.connect().unwrap(); let page = get_page(2); let page = page.get_contents(); @@ -2880,7 +2884,7 @@ mod tests { let usable_space = 4096; for i in 0..3 { let record = Record::new([OwnedValue::Integer(i as i64)].to_vec()); - let payload = add_record(i, i, page, record, &db); + let payload = add_record(i, i, page, record, &conn); assert_eq!(page.cell_count(), i + 1); let free = compute_free_space(page, usable_space); total_size += payload.len() as u16 + 2; @@ -3027,9 +3031,9 @@ mod tests { let page_size = db_header.page_size as usize; #[allow(clippy::arc_with_non_send_sync)] - let io: Arc = Arc::new(MemoryIO::new().unwrap()); + let io: Arc = Arc::new(MemoryIO::new()); let io_file = io.open_file("test.db", OpenFlags::Create, false).unwrap(); - let page_io = Rc::new(FileStorage::new(io_file)); + let page_io = Arc::new(FileStorage::new(io_file)); let buffer_pool = Rc::new(BufferPool::new(db_header.page_size as usize)); let wal_shared = WalFileShared::open_shared(&io, "test.wal", db_header.page_size).unwrap(); @@ -3038,7 +3042,7 @@ mod tests { let page_cache = Arc::new(parking_lot::RwLock::new(DumbLruPageCache::new(10))); let pager = { - let db_header = Rc::new(RefCell::new(db_header.clone())); + let db_header = Arc::new(Mutex::new(db_header.clone())); Pager::finish_open(db_header, page_io, wal, io, page_cache, buffer_pool).unwrap() }; let pager = Rc::new(pager); @@ -3208,7 +3212,7 @@ mod tests { let total_cells = 10; for i in 0..total_cells { let record = Record::new([OwnedValue::Integer(i as i64)].to_vec()); - let payload = add_record(i, i, page, record, &db); + let payload = add_record(i, i, page, record, &conn); assert_eq!(page.cell_count(), i + 1); let free = compute_free_space(page, usable_space); total_size += payload.len() as u16 + 2; @@ -3270,12 +3274,12 @@ mod tests { } #[allow(clippy::arc_with_non_send_sync)] - fn setup_test_env(database_size: u32) -> (Rc, Rc>) { + fn setup_test_env(database_size: u32) -> (Rc, Arc>) { let page_size = 512; let mut db_header = DatabaseHeader::default(); db_header.page_size = page_size; db_header.database_size = database_size; - let db_header = Rc::new(RefCell::new(db_header)); + let db_header = Arc::new(Mutex::new(db_header)); let buffer_pool = Rc::new(BufferPool::new(10)); @@ -3285,17 +3289,17 @@ mod tests { buffer_pool.put(Pin::new(vec)); } - let io: Arc = Arc::new(MemoryIO::new().unwrap()); - let page_io = Rc::new(FileStorage::new( + let io: Arc = Arc::new(MemoryIO::new()); + let page_io = Arc::new(FileStorage::new( io.open_file("test.db", OpenFlags::Create, false).unwrap(), )); let drop_fn = Rc::new(|_buf| {}); - let buf = Rc::new(RefCell::new(Buffer::allocate(page_size as usize, drop_fn))); + let buf = Arc::new(RefCell::new(Buffer::allocate(page_size as usize, drop_fn))); { let mut buf_mut = buf.borrow_mut(); let buf_slice = buf_mut.as_mut_slice(); - sqlite3_ondisk::write_header_to_buf(buf_slice, &db_header.borrow()); + sqlite3_ondisk::write_header_to_buf(buf_slice, &db_header.lock().unwrap()); } let write_complete = Box::new(|_| {}); @@ -3343,8 +3347,8 @@ mod tests { let mut current_page = 2u32; while current_page <= 4 { let drop_fn = Rc::new(|_buf| {}); - let buf = Rc::new(RefCell::new(Buffer::allocate( - db_header.borrow().page_size as usize, + let buf = Arc::new(RefCell::new(Buffer::allocate( + db_header.lock().unwrap().page_size as usize, drop_fn, ))); let write_complete = Box::new(|_| {}); @@ -3384,20 +3388,20 @@ mod tests { first_overflow_page: Some(2), // Point to first overflow page }); - let initial_freelist_pages = db_header.borrow().freelist_pages; + let initial_freelist_pages = db_header.lock().unwrap().freelist_pages; // Clear overflow pages let clear_result = cursor.clear_overflow_pages(&leaf_cell)?; match clear_result { CursorResult::Ok(_) => { // Verify proper number of pages were added to freelist assert_eq!( - db_header.borrow().freelist_pages, + db_header.lock().unwrap().freelist_pages, initial_freelist_pages + 3, "Expected 3 pages to be added to freelist" ); // If this is first trunk page - let trunk_page_id = db_header.borrow().freelist_trunk_page; + let trunk_page_id = db_header.lock().unwrap().freelist_trunk_page; if trunk_page_id > 0 { // Verify trunk page structure let trunk_page = cursor.pager.read_page(trunk_page_id as usize)?; @@ -3439,7 +3443,7 @@ mod tests { first_overflow_page: None, }); - let initial_freelist_pages = db_header.borrow().freelist_pages; + let initial_freelist_pages = db_header.lock().unwrap().freelist_pages; // Try to clear non-existent overflow pages let clear_result = cursor.clear_overflow_pages(&leaf_cell)?; @@ -3447,14 +3451,14 @@ mod tests { CursorResult::Ok(_) => { // Verify freelist was not modified assert_eq!( - db_header.borrow().freelist_pages, + db_header.lock().unwrap().freelist_pages, initial_freelist_pages, "Freelist should not change when no overflow pages exist" ); // Verify trunk page wasn't created assert_eq!( - db_header.borrow().freelist_trunk_page, + db_header.lock().unwrap().freelist_trunk_page, 0, "No trunk page should be created when no overflow pages exist" ); @@ -3469,6 +3473,7 @@ mod tests { #[test] pub fn test_defragment() { let db = get_database(); + let conn = db.connect().unwrap(); let page = get_page(2); let page = page.get_contents(); @@ -3479,7 +3484,7 @@ mod tests { let usable_space = 4096; for i in 0..3 { let record = Record::new([OwnedValue::Integer(i as i64)].to_vec()); - let payload = add_record(i, i, page, record, &db); + let payload = add_record(i, i, page, record, &conn); assert_eq!(page.cell_count(), i + 1); let free = compute_free_space(page, usable_space); total_size += payload.len() as u16 + 2; @@ -3507,6 +3512,7 @@ mod tests { #[test] pub fn test_drop_odd_with_defragment() { let db = get_database(); + let conn = db.connect().unwrap(); let page = get_page(2); let page = page.get_contents(); @@ -3518,7 +3524,7 @@ mod tests { let total_cells = 10; for i in 0..total_cells { let record = Record::new([OwnedValue::Integer(i as i64)].to_vec()); - let payload = add_record(i, i, page, record, &db); + let payload = add_record(i, i, page, record, &conn); assert_eq!(page.cell_count(), i + 1); let free = compute_free_space(page, usable_space); total_size += payload.len() as u16 + 2; @@ -3551,6 +3557,7 @@ mod tests { #[test] pub fn test_fuzz_drop_defragment_insert() { let db = get_database(); + let conn = db.connect().unwrap(); let page = get_page(2); let page = page.get_contents(); @@ -3577,7 +3584,7 @@ mod tests { &mut payload, &record, 4096, - db.pager.clone(), + conn.pager.clone(), ); if (free as usize) < payload.len() - 2 { // do not try to insert overflow pages because they require balancing @@ -3616,13 +3623,14 @@ mod tests { #[test] pub fn test_free_space() { let db = get_database(); + let conn = db.connect().unwrap(); let page = get_page(2); let page = page.get_contents(); let header_size = 8; let usable_space = 4096; let record = Record::new([OwnedValue::Integer(0)].to_vec()); - let payload = add_record(0, 0, page, record, &db); + let payload = add_record(0, 0, page, record, &conn); let free = compute_free_space(page, usable_space); assert_eq!(free, 4096 - payload.len() as u16 - 2 - header_size); } @@ -3630,13 +3638,14 @@ mod tests { #[test] pub fn test_defragment_1() { let db = get_database(); + let conn = db.connect().unwrap(); let page = get_page(2); let page = page.get_contents(); let usable_space = 4096; let record = Record::new([OwnedValue::Integer(0 as i64)].to_vec()); - let payload = add_record(0, 0, page, record, &db); + let payload = add_record(0, 0, page, record, &conn); assert_eq!(page.cell_count(), 1); defragment_page(page, usable_space); @@ -3654,6 +3663,7 @@ mod tests { #[test] pub fn test_insert_drop_insert() { let db = get_database(); + let conn = db.connect().unwrap(); let page = get_page(2); let page = page.get_contents(); @@ -3666,14 +3676,14 @@ mod tests { ] .to_vec(), ); - let _ = add_record(0, 0, page, record, &db); + let _ = add_record(0, 0, page, record, &conn); assert_eq!(page.cell_count(), 1); drop_cell(page, 0, usable_space).unwrap(); assert_eq!(page.cell_count(), 0); let record = Record::new([OwnedValue::Integer(0 as i64)].to_vec()); - let payload = add_record(0, 0, page, record, &db); + let payload = add_record(0, 0, page, record, &conn); assert_eq!(page.cell_count(), 1); let (start, len) = page.cell_get_raw_region( @@ -3689,6 +3699,7 @@ mod tests { #[test] pub fn test_insert_drop_insert_multiple() { let db = get_database(); + let conn = db.connect().unwrap(); let page = get_page(2); let page = page.get_contents(); @@ -3701,7 +3712,7 @@ mod tests { ] .to_vec(), ); - let _ = add_record(0, 0, page, record, &db); + let _ = add_record(0, 0, page, record, &conn); for _ in 0..100 { assert_eq!(page.cell_count(), 1); @@ -3709,7 +3720,7 @@ mod tests { assert_eq!(page.cell_count(), 0); let record = Record::new([OwnedValue::Integer(0)].to_vec()); - let payload = add_record(0, 0, page, record, &db); + let payload = add_record(0, 0, page, record, &conn); assert_eq!(page.cell_count(), 1); let (start, len) = page.cell_get_raw_region( @@ -3726,17 +3737,18 @@ mod tests { #[test] pub fn test_drop_a_few_insert() { let db = get_database(); + let conn = db.connect().unwrap(); let page = get_page(2); let page = page.get_contents(); let usable_space = 4096; let record = Record::new([OwnedValue::Integer(0)].to_vec()); - let payload = add_record(0, 0, page, record, &db); + let payload = add_record(0, 0, page, record, &conn); let record = Record::new([OwnedValue::Integer(1)].to_vec()); - let _ = add_record(1, 1, page, record, &db); + let _ = add_record(1, 1, page, record, &conn); let record = Record::new([OwnedValue::Integer(2)].to_vec()); - let _ = add_record(2, 2, page, record, &db); + let _ = add_record(2, 2, page, record, &conn); drop_cell(page, 1, usable_space).unwrap(); drop_cell(page, 1, usable_space).unwrap(); @@ -3747,38 +3759,40 @@ mod tests { #[test] pub fn test_fuzz_victim_1() { let db = get_database(); + let conn = db.connect().unwrap(); let page = get_page(2); let page = page.get_contents(); let usable_space = 4096; let record = Record::new([OwnedValue::Integer(0)].to_vec()); - let _ = add_record(0, 0, page, record, &db); + let _ = add_record(0, 0, page, record, &conn); let record = Record::new([OwnedValue::Integer(0)].to_vec()); - let _ = add_record(0, 0, page, record, &db); + let _ = add_record(0, 0, page, record, &conn); drop_cell(page, 0, usable_space).unwrap(); defragment_page(page, usable_space); let record = Record::new([OwnedValue::Integer(0)].to_vec()); - let _ = add_record(0, 1, page, record, &db); + let _ = add_record(0, 1, page, record, &conn); drop_cell(page, 0, usable_space).unwrap(); let record = Record::new([OwnedValue::Integer(0)].to_vec()); - let _ = add_record(0, 1, page, record, &db); + let _ = add_record(0, 1, page, record, &conn); } #[test] pub fn test_fuzz_victim_2() { let db = get_database(); + let conn = db.connect().unwrap(); let page = get_page(2); let usable_space = 4096; let insert = |pos, page| { let record = Record::new([OwnedValue::Integer(0)].to_vec()); - let _ = add_record(0, pos, page, record, &db); + let _ = add_record(0, pos, page, record, &conn); }; let drop = |pos, page| { drop_cell(page, pos, usable_space).unwrap(); @@ -3811,12 +3825,13 @@ mod tests { #[test] pub fn test_fuzz_victim_3() { let db = get_database(); + let conn = db.connect().unwrap(); let page = get_page(2); let usable_space = 4096; let insert = |pos, page| { let record = Record::new([OwnedValue::Integer(0)].to_vec()); - let _ = add_record(0, pos, page, record, &db); + let _ = add_record(0, pos, page, record, &conn); }; let drop = |pos, page| { drop_cell(page, pos, usable_space).unwrap(); @@ -3889,6 +3904,7 @@ mod tests { #[test] pub fn test_big_payload_compute_free() { let db = get_database(); + let conn = db.connect().unwrap(); let page = get_page(2); let usable_space = 4096; @@ -3900,7 +3916,7 @@ mod tests { &mut payload, &record, 4096, - db.pager.clone(), + conn.pager.clone(), ); insert_into_cell(page.get_contents(), &payload, 0, 4096).unwrap(); let free = compute_free_space(page.get_contents(), usable_space); diff --git a/core/storage/database.rs b/core/storage/database.rs index 97bb85721..87be27d72 100644 --- a/core/storage/database.rs +++ b/core/storage/database.rs @@ -1,25 +1,33 @@ #[cfg(feature = "fs")] use crate::error::LimboError; use crate::{io::Completion, Buffer, Result}; -use std::{cell::RefCell, rc::Rc}; +use std::{cell::RefCell, sync::Arc}; /// DatabaseStorage is an interface a database file that consists of pages. /// /// The purpose of this trait is to abstract the upper layers of Limbo from /// the storage medium. A database can either be a file on disk, like in SQLite, /// or something like a remote page server service. -pub trait DatabaseStorage { +pub trait DatabaseStorage: Send + Sync { fn read_page(&self, page_idx: usize, c: Completion) -> Result<()>; - fn write_page(&self, page_idx: usize, buffer: Rc>, c: Completion) - -> Result<()>; + fn write_page( + &self, + page_idx: usize, + buffer: Arc>, + c: Completion, + ) -> Result<()>; fn sync(&self, c: Completion) -> Result<()>; } + #[cfg(feature = "fs")] pub struct FileStorage { - file: Rc, + file: Arc, } +unsafe impl Send for FileStorage {} +unsafe impl Sync for FileStorage {} + #[cfg(feature = "fs")] impl DatabaseStorage for FileStorage { fn read_page(&self, page_idx: usize, c: Completion) -> Result<()> { @@ -40,7 +48,7 @@ impl DatabaseStorage for FileStorage { fn write_page( &self, page_idx: usize, - buffer: Rc>, + buffer: Arc>, c: Completion, ) -> Result<()> { let buffer_size = buffer.borrow().len(); @@ -59,7 +67,7 @@ impl DatabaseStorage for FileStorage { #[cfg(feature = "fs")] impl FileStorage { - pub fn new(file: Rc) -> Self { + pub fn new(file: Arc) -> Self { Self { file } } } diff --git a/core/storage/pager.rs b/core/storage/pager.rs index 526374b96..54ef933de 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -9,7 +9,7 @@ use std::cell::{RefCell, UnsafeCell}; use std::collections::HashSet; use std::rc::Rc; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use tracing::trace; use super::page_cache::{DumbLruPageCache, PageCacheKey}; @@ -152,7 +152,7 @@ struct FlushInfo { /// transaction management. pub struct Pager { /// Source of the database pages. - pub page_io: Rc, + pub page_io: Arc, /// The write-ahead log (WAL) for the database. wal: Rc>, /// A page cache for the database. @@ -162,7 +162,7 @@ pub struct Pager { /// I/O interface for input/output operations. pub io: Arc, dirty_pages: Rc>>, - pub db_header: Rc>, + pub db_header: Arc>, flush_info: RefCell, checkpoint_state: RefCell, @@ -172,14 +172,14 @@ pub struct Pager { impl Pager { /// Begins opening a database by reading the database header. - pub fn begin_open(page_io: Rc) -> Result>> { + pub fn begin_open(page_io: Arc) -> Result>> { sqlite3_ondisk::begin_read_database_header(page_io) } /// Completes opening a database by initializing the Pager with the database header. pub fn finish_open( - db_header_ref: Rc>, - page_io: Rc, + db_header_ref: Arc>, + page_io: Arc, wal: Rc>, io: Arc, page_cache: Arc>, @@ -230,7 +230,7 @@ impl Pager { /// The usable size of a page might be an odd number. However, the usable size is not allowed to be less than 480. /// In other words, if the page size is 512, then the reserved space size cannot exceed 32. pub fn usable_space(&self) -> usize { - let db_header = self.db_header.borrow(); + let db_header = self.db_header.lock().unwrap(); (db_header.page_size - db_header.reserved_space as u16) as usize } @@ -349,7 +349,7 @@ impl Pager { let state = self.flush_info.borrow().state.clone(); match state { FlushState::Start => { - let db_size = self.db_header.borrow().database_size; + let db_size = self.db_header.lock().unwrap().database_size; for page_id in self.dirty_pages.borrow().iter() { let mut cache = self.page_cache.write(); let page_key = @@ -496,7 +496,7 @@ impl Pager { const TRUNK_PAGE_NEXT_PAGE_OFFSET: usize = 0; // Offset to next trunk page pointer const TRUNK_PAGE_LEAF_COUNT_OFFSET: usize = 4; // Offset to leaf count - if page_id < 2 || page_id > self.db_header.borrow().database_size as usize { + if page_id < 2 || page_id > self.db_header.lock().unwrap().database_size as usize { return Err(LimboError::Corrupt(format!( "Invalid page number {} for free operation", page_id @@ -511,9 +511,9 @@ impl Pager { None => self.read_page(page_id)?, }; - self.db_header.borrow_mut().freelist_pages += 1; + self.db_header.lock().unwrap().freelist_pages += 1; - let trunk_page_id = self.db_header.borrow().freelist_trunk_page; + let trunk_page_id = self.db_header.lock().unwrap().freelist_trunk_page; if trunk_page_id != 0 { // Add as leaf to current trunk @@ -551,7 +551,7 @@ impl Pager { // Zero leaf count contents.write_u32(TRUNK_PAGE_LEAF_COUNT_OFFSET, 0); // Update page 1 to point to new trunk - self.db_header.borrow_mut().freelist_trunk_page = page_id as u32; + self.db_header.lock().unwrap().freelist_trunk_page = page_id as u32; // Clear flags page.clear_uptodate(); page.clear_loaded(); @@ -565,7 +565,7 @@ impl Pager { #[allow(clippy::readonly_write_lock)] pub fn allocate_page(&self) -> Result { let header = &self.db_header; - let mut header = RefCell::borrow_mut(header); + let mut header = header.lock().unwrap(); header.database_size += 1; { // update database size @@ -607,7 +607,7 @@ impl Pager { } pub fn usable_size(&self) -> usize { - let db_header = self.db_header.borrow(); + let db_header = self.db_header.lock().unwrap(); (db_header.page_size - db_header.reserved_space as u16) as usize } } @@ -620,7 +620,7 @@ pub fn allocate_page(page_id: usize, buffer_pool: &Rc, offset: usize let drop_fn = Rc::new(move |buf| { bp.put(buf); }); - let buffer = Rc::new(RefCell::new(Buffer::new(buffer, drop_fn))); + let buffer = Arc::new(RefCell::new(Buffer::new(buffer, drop_fn))); page.set_loaded(); page.get().contents = Some(PageContent { offset, diff --git a/core/storage/sqlite3_ondisk.rs b/core/storage/sqlite3_ondisk.rs index 0e49d30ae..3adafadac 100644 --- a/core/storage/sqlite3_ondisk.rs +++ b/core/storage/sqlite3_ondisk.rs @@ -52,7 +52,7 @@ use parking_lot::RwLock; use std::cell::RefCell; use std::pin::Pin; use std::rc::Rc; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use tracing::trace; use super::pager::PageRef; @@ -243,13 +243,13 @@ impl Default for DatabaseHeader { } pub fn begin_read_database_header( - page_io: Rc, -) -> Result>> { + page_io: Arc, +) -> Result>> { let drop_fn = Rc::new(|_buf| {}); - let buf = Rc::new(RefCell::new(Buffer::allocate(512, drop_fn))); - let result = Rc::new(RefCell::new(DatabaseHeader::default())); + let buf = Arc::new(RefCell::new(Buffer::allocate(512, drop_fn))); + let result = Arc::new(Mutex::new(DatabaseHeader::default())); let header = result.clone(); - let complete = Box::new(move |buf: Rc>| { + let complete = Box::new(move |buf: Arc>| { let header = header.clone(); finish_read_database_header(buf, header).unwrap(); }); @@ -259,12 +259,12 @@ pub fn begin_read_database_header( } fn finish_read_database_header( - buf: Rc>, - header: Rc>, + buf: Arc>, + header: Arc>, ) -> Result<()> { let buf = buf.borrow(); let buf = buf.as_slice(); - let mut header = RefCell::borrow_mut(&header); + let mut header = header.lock().unwrap(); header.magic.copy_from_slice(&buf[0..16]); header.page_size = u16::from_be_bytes([buf[16], buf[17]]); header.write_version = buf[18]; @@ -299,10 +299,10 @@ pub fn begin_write_database_header(header: &DatabaseHeader, pager: &Pager) -> Re let page_source = pager.page_io.clone(); let drop_fn = Rc::new(|_buf| {}); - let buffer_to_copy = Rc::new(RefCell::new(Buffer::allocate(512, drop_fn))); + let buffer_to_copy = Arc::new(RefCell::new(Buffer::allocate(512, drop_fn))); let buffer_to_copy_in_cb = buffer_to_copy.clone(); - let read_complete = Box::new(move |buffer: Rc>| { + let read_complete = Box::new(move |buffer: Arc>| { let buffer = buffer.borrow().clone(); let buffer = Rc::new(RefCell::new(buffer)); let mut buf_mut = buffer.borrow_mut(); @@ -312,7 +312,7 @@ pub fn begin_write_database_header(header: &DatabaseHeader, pager: &Pager) -> Re }); let drop_fn = Rc::new(|_buf| {}); - let buf = Rc::new(RefCell::new(Buffer::allocate(512, drop_fn))); + let buf = Arc::new(RefCell::new(Buffer::allocate(512, drop_fn))); let c = Completion::Read(ReadCompletion::new(buf, read_complete)); page_source.read_page(1, c)?; // run get header block @@ -393,7 +393,7 @@ pub struct OverflowCell { #[derive(Debug)] pub struct PageContent { pub offset: usize, - pub buffer: Rc>, + pub buffer: Arc>, pub overflow_cells: Vec, } @@ -401,7 +401,7 @@ impl Clone for PageContent { fn clone(&self) -> Self { Self { offset: self.offset, - buffer: Rc::new(RefCell::new((*self.buffer.borrow()).clone())), + buffer: Arc::new(RefCell::new((*self.buffer.borrow()).clone())), overflow_cells: self.overflow_cells.clone(), } } @@ -673,7 +673,7 @@ impl PageContent { } pub fn begin_read_page( - page_io: Rc, + page_io: Arc, buffer_pool: Rc, page: PageRef, page_idx: usize, @@ -684,8 +684,8 @@ pub fn begin_read_page( let buffer_pool = buffer_pool.clone(); buffer_pool.put(buf); }); - let buf = Rc::new(RefCell::new(Buffer::new(buf, drop_fn))); - let complete = Box::new(move |buf: Rc>| { + let buf = Arc::new(RefCell::new(Buffer::new(buf, drop_fn))); + let complete = Box::new(move |buf: Arc>| { let page = page.clone(); if finish_read_page(page_idx, buf, page.clone()).is_err() { page.set_error(); @@ -696,7 +696,11 @@ pub fn begin_read_page( Ok(()) } -fn finish_read_page(page_idx: usize, buffer_ref: Rc>, page: PageRef) -> Result<()> { +fn finish_read_page( + page_idx: usize, + buffer_ref: Arc>, + page: PageRef, +) -> Result<()> { trace!("finish_read_btree_page(page_idx = {})", page_idx); let pos = if page_idx == 1 { DATABASE_HEADER_SIZE @@ -754,7 +758,7 @@ pub fn begin_write_btree_page( Ok(()) } -pub fn begin_sync(page_io: Rc, syncing: Rc>) -> Result<()> { +pub fn begin_sync(page_io: Arc, syncing: Rc>) -> Result<()> { assert!(!*syncing.borrow()); *syncing.borrow_mut() = true; let completion = Completion::Sync(SyncCompletion { @@ -1248,12 +1252,12 @@ pub fn write_varint_to_vec(value: u64, payload: &mut Vec) { payload.extend_from_slice(&varint[0..n]); } -pub fn begin_read_wal_header(io: &Rc) -> Result>> { +pub fn begin_read_wal_header(io: &Arc) -> Result>> { let drop_fn = Rc::new(|_buf| {}); - let buf = Rc::new(RefCell::new(Buffer::allocate(512, drop_fn))); + let buf = Arc::new(RefCell::new(Buffer::allocate(512, drop_fn))); let result = Arc::new(RwLock::new(WalHeader::default())); let header = result.clone(); - let complete = Box::new(move |buf: Rc>| { + let complete = Box::new(move |buf: Arc>| { let header = header.clone(); finish_read_wal_header(buf, header).unwrap(); }); @@ -1262,7 +1266,7 @@ pub fn begin_read_wal_header(io: &Rc) -> Result> Ok(result) } -fn finish_read_wal_header(buf: Rc>, header: Arc>) -> Result<()> { +fn finish_read_wal_header(buf: Arc>, header: Arc>) -> Result<()> { let buf = buf.borrow(); let buf = buf.as_slice(); let mut header = header.write(); @@ -1278,7 +1282,7 @@ fn finish_read_wal_header(buf: Rc>, header: Arc, + io: &Arc, offset: usize, buffer_pool: Rc, page: PageRef, @@ -1293,9 +1297,9 @@ pub fn begin_read_wal_frame( let buffer_pool = buffer_pool.clone(); buffer_pool.put(buf); }); - let buf = Rc::new(RefCell::new(Buffer::new(buf, drop_fn))); + let buf = Arc::new(RefCell::new(Buffer::new(buf, drop_fn))); let frame = page.clone(); - let complete = Box::new(move |buf: Rc>| { + let complete = Box::new(move |buf: Arc>| { let frame = frame.clone(); finish_read_page(2, buf, frame).unwrap(); }); @@ -1305,7 +1309,7 @@ pub fn begin_read_wal_frame( } pub fn begin_write_wal_frame( - io: &Rc, + io: &Arc, offset: usize, page: &PageRef, db_size: u32, @@ -1357,7 +1361,7 @@ pub fn begin_write_wal_frame( buf[20..24].copy_from_slice(&header.checksum_2.to_be_bytes()); buf[WAL_FRAME_HEADER_SIZE..].copy_from_slice(contents.as_ptr()); - (Rc::new(RefCell::new(buffer)), checksums) + (Arc::new(RefCell::new(buffer)), checksums) }; *write_counter.borrow_mut() += 1; @@ -1379,7 +1383,7 @@ pub fn begin_write_wal_frame( Ok(checksums) } -pub fn begin_write_wal_header(io: &Rc, header: &WalHeader) -> Result<()> { +pub fn begin_write_wal_header(io: &Arc, header: &WalHeader) -> Result<()> { let buffer = { let drop_fn = Rc::new(|_buf| {}); @@ -1395,7 +1399,7 @@ pub fn begin_write_wal_header(io: &Rc, header: &WalHeader) -> Result<( buf[24..28].copy_from_slice(&header.checksum_1.to_be_bytes()); buf[28..32].copy_from_slice(&header.checksum_2.to_be_bytes()); - Rc::new(RefCell::new(buffer)) + Arc::new(RefCell::new(buffer)) }; let write_complete = { diff --git a/core/storage/wal.rs b/core/storage/wal.rs index 30941e9b9..9df9121f0 100644 --- a/core/storage/wal.rs +++ b/core/storage/wal.rs @@ -283,7 +283,7 @@ pub struct WalFileShared { // Another memory inefficient array made to just keep track of pages that are in frame_cache. pages_in_frames: Vec, last_checksum: (u32, u32), // Check of last frame in WAL, this is a cumulative checksum over all frames in the WAL - file: Rc, + file: Arc, /// read_locks is a list of read locks that can coexist with the max_frame number stored in /// value. There is a limited amount because and unbounded amount of connections could be /// fatal. Therefore, for now we copy how SQLite behaves with limited amounts of read max @@ -675,7 +675,7 @@ impl WalFile { }); checkpoint_page.get().contents = Some(PageContent { offset: 0, - buffer: Rc::new(RefCell::new(Buffer::new(buffer, drop_fn))), + buffer: Arc::new(RefCell::new(Buffer::new(buffer, drop_fn))), overflow_cells: Vec::new(), }); } diff --git a/core/translate/mod.rs b/core/translate/mod.rs index fe49d05ce..b04991593 100644 --- a/core/translate/mod.rs +++ b/core/translate/mod.rs @@ -35,16 +35,16 @@ use crate::{bail_parse_error, Connection, LimboError, Result, SymbolTable}; use insert::translate_insert; use limbo_sqlite3_parser::ast::{self, fmt::ToTokens, CreateVirtualTable, Delete, Insert}; use select::translate_select; -use std::cell::RefCell; use std::fmt::Display; use std::rc::{Rc, Weak}; +use std::sync::{Arc, Mutex}; use transaction::{translate_tx_begin, translate_tx_commit}; /// Translate SQL statement into bytecode program. pub fn translate( schema: &Schema, stmt: ast::Stmt, - database_header: Rc>, + database_header: Arc>, pager: Rc, connection: Weak, syms: &SymbolTable, diff --git a/core/translate/optimizer.rs b/core/translate/optimizer.rs index 89a8b3a38..925ad8d2b 100644 --- a/core/translate/optimizer.rs +++ b/core/translate/optimizer.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, rc::Rc}; +use std::{collections::HashMap, sync::Arc}; use limbo_sqlite3_parser::ast; @@ -76,7 +76,7 @@ fn optimize_subqueries(plan: &mut SelectPlan, schema: &Schema) -> Result<()> { fn query_is_already_ordered_by( table_references: &[TableReference], key: &mut ast::Expr, - available_indexes: &HashMap>>, + available_indexes: &HashMap>>, ) -> Result { let first_table = table_references.first(); if first_table.is_none() { @@ -91,7 +91,7 @@ fn query_is_already_ordered_by( Search::IndexSearch { index, .. } => { let index_rc = key.check_index_scan(0, &table_reference, available_indexes)?; let index_is_the_same = - index_rc.map(|irc| Rc::ptr_eq(index, &irc)).unwrap_or(false); + index_rc.map(|irc| Arc::ptr_eq(index, &irc)).unwrap_or(false); Ok(index_is_the_same) } }, @@ -138,7 +138,7 @@ fn eliminate_unnecessary_orderby(plan: &mut SelectPlan, schema: &Schema) -> Resu */ fn use_indexes( table_references: &mut [TableReference], - available_indexes: &HashMap>>, + available_indexes: &HashMap>>, where_clause: &mut Vec, ) -> Result<()> { if where_clause.is_empty() { @@ -276,8 +276,8 @@ pub trait Optimizable { &mut self, table_index: usize, table_reference: &TableReference, - available_indexes: &HashMap>>, - ) -> Result>>; + available_indexes: &HashMap>>, + ) -> Result>>; } impl Optimizable for ast::Expr { @@ -295,8 +295,8 @@ impl Optimizable for ast::Expr { &mut self, table_index: usize, table_reference: &TableReference, - available_indexes: &HashMap>>, - ) -> Result>> { + available_indexes: &HashMap>>, + ) -> Result>> { match self { Self::Column { table, column, .. } => { if *table != table_index { @@ -497,7 +497,7 @@ pub fn try_extract_index_search_expression( cond: &mut WhereTerm, table_index: usize, table_reference: &TableReference, - available_indexes: &HashMap>>, + available_indexes: &HashMap>>, ) -> Result> { if !cond.should_eval_at_loop(table_index) { return Ok(None); diff --git a/core/translate/plan.rs b/core/translate/plan.rs index f40d5bb38..9fa78a608 100644 --- a/core/translate/plan.rs +++ b/core/translate/plan.rs @@ -3,7 +3,7 @@ use limbo_sqlite3_parser::ast; use std::{ cmp::Ordering, fmt::{Display, Formatter}, - rc::Rc, + rc::Rc, sync::Arc, }; use crate::{ @@ -325,7 +325,7 @@ pub enum Search { }, /// A secondary index search. Uses bytecode instructions like SeekGE, SeekGT etc. IndexSearch { - index: Rc, + index: Arc, cmp_op: ast::Operator, cmp_expr: WhereTerm, }, diff --git a/core/translate/pragma.rs b/core/translate/pragma.rs index b33f38011..30c9a6674 100644 --- a/core/translate/pragma.rs +++ b/core/translate/pragma.rs @@ -3,8 +3,8 @@ use limbo_sqlite3_parser::ast; use limbo_sqlite3_parser::ast::PragmaName; -use std::cell::RefCell; use std::rc::Rc; +use std::sync::{Arc, Mutex}; use crate::schema::Schema; use crate::storage::sqlite3_ondisk::{DatabaseHeader, MIN_PAGE_CACHE_SIZE}; @@ -38,7 +38,7 @@ pub fn translate_pragma( schema: &Schema, name: &ast::QualifiedName, body: Option, - database_header: Rc>, + database_header: Arc>, pager: Rc, ) -> crate::Result { let mut program = ProgramBuilder::new(ProgramBuilderOpts { @@ -115,7 +115,7 @@ fn update_pragma( pragma: PragmaName, schema: &Schema, value: ast::Expr, - header: Rc>, + header: Arc>, pager: Rc, program: &mut ProgramBuilder, ) -> crate::Result<()> { @@ -166,14 +166,14 @@ fn query_pragma( pragma: PragmaName, schema: &Schema, value: Option, - database_header: Rc>, + database_header: Arc>, program: &mut ProgramBuilder, ) -> crate::Result<()> { let register = program.alloc_register(); match pragma { PragmaName::CacheSize => { program.emit_int( - database_header.borrow().default_page_cache_size.into(), + database_header.lock().unwrap().default_page_cache_size.into(), register, ); program.emit_result_row(register, 1); @@ -261,7 +261,7 @@ fn query_pragma( Ok(()) } -fn update_cache_size(value: i64, header: Rc>, pager: Rc) { +fn update_cache_size(value: i64, header: Arc>, pager: Rc) { let mut cache_size_unformatted: i64 = value; let mut cache_size = if cache_size_unformatted < 0 { let kb = cache_size_unformatted.abs() * 1024; @@ -277,12 +277,12 @@ fn update_cache_size(value: i64, header: Rc>, pager: Rc< } // update in-memory header - header.borrow_mut().default_page_cache_size = cache_size_unformatted + header.lock().unwrap().default_page_cache_size = cache_size_unformatted .try_into() .unwrap_or_else(|_| panic!("invalid value, too big for a i32 {}", value)); // update in disk - let header_copy = header.borrow().clone(); + let header_copy = header.lock().unwrap().clone(); pager.write_database_header(&header_copy); // update cache size diff --git a/core/util.rs b/core/util.rs index 688ba4264..b16c29de3 100644 --- a/core/util.rs +++ b/core/util.rs @@ -70,7 +70,7 @@ pub fn parse_schema_rows( match row.get::<&str>(4) { Ok(sql) => { let index = schema::Index::from_sql(sql, root_page as usize)?; - schema.add_index(Rc::new(index)); + schema.add_index(Arc::new(index)); } _ => { // Automatic index on primary key, e.g. @@ -105,7 +105,7 @@ pub fn parse_schema_rows( let table = schema.get_btree_table(&table_name).unwrap(); let index = schema::Index::automatic_from_primary_key(&table, &index_name, root_page as usize)?; - schema.add_index(Rc::new(index)); + schema.add_index(Arc::new(index)); } } Ok(()) diff --git a/core/vdbe/builder.rs b/core/vdbe/builder.rs index 470997158..1e038ceee 100644 --- a/core/vdbe/builder.rs +++ b/core/vdbe/builder.rs @@ -1,7 +1,7 @@ use std::{ - cell::{Cell, RefCell}, + cell::Cell, collections::HashMap, - rc::{Rc, Weak}, + rc::{Rc, Weak}, sync::{Arc, Mutex}, }; use crate::{ @@ -38,7 +38,7 @@ pub struct ProgramBuilder { #[derive(Debug, Clone)] pub enum CursorType { BTreeTable(Rc), - BTreeIndex(Rc), + BTreeIndex(Arc), Pseudo(Rc), Sorter, VirtualTable(Rc), @@ -437,7 +437,7 @@ impl ProgramBuilder { pub fn build( mut self, - database_header: Rc>, + database_header: Arc>, connection: Weak, change_cnt_on: bool, ) -> Program { diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index 8b59879b4..c8c64c146 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -72,6 +72,7 @@ use std::collections::HashMap; use std::ffi::c_void; use std::num::NonZero; use std::rc::{Rc, Weak}; +use std::sync::{Arc, Mutex}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] /// Represents a target for a jump instruction. @@ -318,7 +319,7 @@ pub struct Program { pub max_registers: usize, pub insns: Vec, pub cursor_ref: Vec<(Option, CursorType)>, - pub database_header: Rc>, + pub database_header: Arc>, pub comments: Option>, pub parameters: crate::parameters::Parameters, pub connection: Weak, @@ -813,15 +814,15 @@ impl Program { Some(&table_name), &module_name, args, - &conn.db.syms.borrow(), + &conn.db.syms.lock().unwrap(), limbo_ext::VTabKind::VirtualTable, None, )?; { conn.db .syms - .as_ref() - .borrow_mut() + .lock() + .unwrap() .vtabs .insert(table_name, table.clone()); } @@ -2982,7 +2983,7 @@ impl Program { } // SQLite returns "0" on an empty database, and 2 on the first insertion, // so we'll mimic that behavior. - let mut pages = pager.db_header.borrow().database_size.into(); + let mut pages = pager.db_header.lock().unwrap().database_size.into(); if pages == 1 { pages = 0; } @@ -2999,13 +3000,13 @@ impl Program { "SELECT * FROM sqlite_schema WHERE {}", where_clause ))?; - let mut schema = RefCell::borrow_mut(&conn.schema); + let mut schema = conn.schema.lock().unwrap(); // TODO: This function below is synchronous, make it async parse_schema_rows( Some(stmt), &mut schema, conn.pager.io.clone(), - &conn.db.syms.borrow(), + &conn.db.syms.lock().unwrap() )?; state.pc += 1; } @@ -3015,7 +3016,7 @@ impl Program { todo!("temp databases not implemented yet"); } let cookie_value = match cookie { - Cookie::UserVersion => pager.db_header.borrow().user_version.into(), + Cookie::UserVersion => pager.db_header.lock().unwrap().user_version.into(), cookie => todo!("{cookie:?} is not yet implement for ReadCookie"), }; state.registers[*dest] = OwnedValue::Integer(cookie_value); diff --git a/simulator/runner/differential.rs b/simulator/runner/differential.rs index bfadc5687..a1737b70e 100644 --- a/simulator/runner/differential.rs +++ b/simulator/runner/differential.rs @@ -220,7 +220,7 @@ fn execute_plan( if let SimConnection::Disconnected = connection { log::info!("connecting {}", connection_index); - env.connections[connection_index] = SimConnection::Connected(env.db.connect()); + env.connections[connection_index] = SimConnection::Connected(env.db.connect().unwrap()); } else { let limbo_result = execute_interaction(env, connection_index, interaction, &mut state.stack); diff --git a/simulator/runner/execution.rs b/simulator/runner/execution.rs index 822660260..f184e6b60 100644 --- a/simulator/runner/execution.rs +++ b/simulator/runner/execution.rs @@ -117,7 +117,7 @@ fn execute_plan( if let SimConnection::Disconnected = connection { log::info!("connecting {}", connection_index); - env.connections[connection_index] = SimConnection::Connected(env.db.connect()); + env.connections[connection_index] = SimConnection::Connected(env.db.connect().unwrap()); } else { match execute_interaction(env, connection_index, interaction, &mut state.stack) { Ok(next_execution) => { diff --git a/simulator/runner/file.rs b/simulator/runner/file.rs index 6d73af505..3e66a02fa 100644 --- a/simulator/runner/file.rs +++ b/simulator/runner/file.rs @@ -1,8 +1,8 @@ -use std::{cell::RefCell, rc::Rc}; +use std::{cell::RefCell, sync::Arc}; use limbo_core::{File, Result}; pub(crate) struct SimulatorFile { - pub(crate) inner: Rc, + pub(crate) inner: Arc, pub(crate) fault: RefCell, /// Number of `pread` function calls (both success and failures). @@ -23,6 +23,9 @@ pub(crate) struct SimulatorFile { pub(crate) page_size: usize, } +unsafe impl Send for SimulatorFile {} +unsafe impl Sync for SimulatorFile {} + impl SimulatorFile { pub(crate) fn inject_fault(&self, fault: bool) { self.fault.replace(fault); @@ -88,7 +91,7 @@ impl File for SimulatorFile { fn pwrite( &self, pos: usize, - buffer: Rc>, + buffer: Arc>, c: limbo_core::Completion, ) -> Result<()> { *self.nr_pwrite_calls.borrow_mut() += 1; diff --git a/simulator/runner/io.rs b/simulator/runner/io.rs index 1034065ac..48340d170 100644 --- a/simulator/runner/io.rs +++ b/simulator/runner/io.rs @@ -1,4 +1,4 @@ -use std::{cell::RefCell, rc::Rc}; +use std::{cell::RefCell, sync::Arc}; use limbo_core::{OpenFlags, PlatformIO, Result, IO}; use rand::{RngCore, SeedableRng}; @@ -9,12 +9,15 @@ use crate::runner::file::SimulatorFile; pub(crate) struct SimulatorIO { pub(crate) inner: Box, pub(crate) fault: RefCell, - pub(crate) files: RefCell>>, + pub(crate) files: RefCell>>, pub(crate) rng: RefCell, pub(crate) nr_run_once_faults: RefCell, pub(crate) page_size: usize, } +unsafe impl Send for SimulatorIO {} +unsafe impl Sync for SimulatorIO {} + impl SimulatorIO { pub(crate) fn new(seed: u64, page_size: usize) -> Result { let inner = Box::new(PlatformIO::new()?); @@ -55,9 +58,9 @@ impl IO for SimulatorIO { path: &str, flags: OpenFlags, _direct: bool, - ) -> Result> { + ) -> Result> { let inner = self.inner.open_file(path, flags, false)?; - let file = Rc::new(SimulatorFile { + let file = Arc::new(SimulatorFile { inner, fault: RefCell::new(false), nr_pread_faults: RefCell::new(0), diff --git a/simulator/runner/watch.rs b/simulator/runner/watch.rs index 75ecb1801..fb08b705e 100644 --- a/simulator/runner/watch.rs +++ b/simulator/runner/watch.rs @@ -98,7 +98,7 @@ fn execute_plan( if let SimConnection::Disconnected = connection { log::info!("connecting {}", connection_index); - env.connections[connection_index] = SimConnection::Connected(env.db.connect()); + env.connections[connection_index] = SimConnection::Connected(env.db.connect().unwrap()); } else { match execute_interaction(env, connection_index, interaction, &mut state.stack) { Ok(next_execution) => { diff --git a/sqlite3/src/lib.rs b/sqlite3/src/lib.rs index ea11d3472..43d1ff0fc 100644 --- a/sqlite3/src/lib.rs +++ b/sqlite3/src/lib.rs @@ -110,10 +110,7 @@ pub unsafe extern "C" fn sqlite3_open( Err(_) => return SQLITE_MISUSE, }; let io: Arc = match filename { - ":memory:" => match limbo_core::MemoryIO::new() { - Ok(io) => Arc::new(io), - Err(_) => return SQLITE_MISUSE, - }, + ":memory:" => Arc::new(limbo_core::MemoryIO::new()), _ => match limbo_core::PlatformIO::new() { Ok(io) => Arc::new(io), Err(_) => return SQLITE_MISUSE, @@ -121,7 +118,7 @@ pub unsafe extern "C" fn sqlite3_open( }; match limbo_core::Database::open_file(io, filename) { Ok(db) => { - let conn = db.connect(); + let conn = db.connect().unwrap(); *db_out = Box::leak(Box::new(sqlite3::new(db, conn))); SQLITE_OK } diff --git a/tests/integration/common.rs b/tests/integration/common.rs index 06501769a..1831abbd0 100644 --- a/tests/integration/common.rs +++ b/tests/integration/common.rs @@ -8,8 +8,9 @@ use tempfile::TempDir; #[allow(dead_code)] pub struct TempDatabase { pub path: PathBuf, - pub io: Arc, + pub io: Arc, } +unsafe impl Send for TempDatabase {} #[allow(dead_code, clippy::arc_with_non_send_sync)] impl TempDatabase { @@ -20,7 +21,7 @@ impl TempDatabase { pub fn new(db_name: &str) -> Self { let mut path = TempDir::new().unwrap().into_path(); path.push(db_name); - let io: Arc = Arc::new(limbo_core::PlatformIO::new().unwrap()); + let io: Arc = Arc::new(limbo_core::PlatformIO::new().unwrap()); Self { path, io } } @@ -43,10 +44,15 @@ impl TempDatabase { log::debug!("conneting to limbo"); let db = Database::open_file(self.io.clone(), self.path.to_str().unwrap()).unwrap(); - let conn = db.connect(); + let conn = db.connect().unwrap(); log::debug!("connected to limbo"); conn } + + pub fn limbo_database(&self) -> Arc { + log::debug!("conneting to limbo"); + Database::open_file(self.io.clone(), self.path.to_str().unwrap()).unwrap() + } } pub(crate) fn do_flush(conn: &Rc, tmp_db: &TempDatabase) -> anyhow::Result<()> {