diff --git a/bindings/wasm/lib.rs b/bindings/wasm/lib.rs
index 88827e79e..3e53b8c04 100644
--- a/bindings/wasm/lib.rs
+++ b/bindings/wasm/lib.rs
@@ -27,7 +27,7 @@ impl Database {
 
         // ensure db header is there
         io.run_once().unwrap();
-        let page_size = db_header.lock().unwrap().page_size;
+        let page_size = db_header.lock().get_mut().page_size;
 
         let wal_path = format!("{}-wal", path);
         let wal_shared = WalFileShared::open_shared(&io, wal_path.as_str(), page_size).unwrap();
diff --git a/core/fast_lock.rs b/core/fast_lock.rs
new file mode 100644
index 000000000..9d2c1dc2b
--- /dev/null
+++ b/core/fast_lock.rs
@@ -0,0 +1,108 @@
+use std::{
+    cell::UnsafeCell,
+    sync::atomic::{AtomicBool, Ordering},
+};
+
+/// A minimal test-and-set spin lock guarding a single value.
+///
+/// Intended for very short critical sections (e.g. reading a header field)
+/// where a full `std::sync::Mutex` is overkill.
+#[derive(Debug)]
+pub struct FastLock<T> {
+    lock: AtomicBool,
+    value: UnsafeCell<T>,
+}
+
+/// RAII guard returned by [`FastLock::lock`]; releases the lock on drop.
+pub struct FastLockGuard<'a, T> {
+    lock: &'a FastLock<T>,
+}
+
+impl<'a, T> FastLockGuard<'a, T> {
+    /// Access the protected value while the lock is held.
+    pub fn get_mut(&self) -> &mut T {
+        // SAFETY: the existence of the guard proves the lock is held, so no
+        // other thread can be handed a reference to the value concurrently.
+        unsafe { self.lock.value.get().as_mut().unwrap() }
+    }
+}
+
+impl<'a, T> Drop for FastLockGuard<'a, T> {
+    fn drop(&mut self) {
+        self.lock.unlock();
+    }
+}
+
+// SAFETY: the lock serializes access to `value`, but the value itself still
+// crosses threads, so `T` must be `Send` for either impl to be sound.
+unsafe impl<T: Send> Send for FastLock<T> {}
+unsafe impl<T: Send> Sync for FastLock<T> {}
+
+impl<T> FastLock<T> {
+    pub fn new(value: T) -> Self {
+        Self {
+            lock: AtomicBool::new(false),
+            value: UnsafeCell::new(value),
+        }
+    }
+
+    /// Spin until the lock is acquired, yielding between attempts.
+    pub fn lock(&self) -> FastLockGuard<T> {
+        // `compare_exchange_weak` may fail spuriously, which is fine inside a
+        // retry loop and cheaper on some architectures. (`compare_and_swap`
+        // is deprecated since Rust 1.50.)
+        while self
+            .lock
+            .compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed)
+            .is_err()
+        {
+            std::thread::yield_now();
+        }
+        FastLockGuard { lock: self }
+    }
+
+    pub fn unlock(&self) {
+        // `Release` pairs with the `Acquire` in `lock` so writes made while
+        // holding the lock become visible to the next owner.
+        let was_locked = self
+            .lock
+            .compare_exchange(true, false, Ordering::Release, Ordering::Relaxed)
+            .is_ok();
+        assert!(was_locked, "unlock called on an unlocked FastLock");
+    }
+
+    /// Access the value without taking the lock.
+    ///
+    /// Caller must guarantee there is no concurrent access (e.g. a guard is
+    /// alive on this thread, or the value is not shared yet).
+    pub fn get_mut(&self) -> &mut T {
+        unsafe { self.value.get().as_mut().unwrap() }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use super::FastLock;
+
+    #[test]
+    fn test_fast_lock_multiple_thread_sum() {
+        let lock = Arc::new(FastLock::new(0));
+        let mut threads = vec![];
+        const NTHREADS: usize = 1000;
+        for _ in 0..NTHREADS {
+            let lock = lock.clone();
+            threads.push(std::thread::spawn(move || {
+                // Hold the guard for the whole increment; a bare
+                // `lock.lock();` would drop it immediately and race.
+                let guard = lock.lock();
+                *guard.get_mut() += 1;
+            }));
+        }
+        for thread in threads {
+            thread.join().unwrap();
+        }
+        assert_eq!(*lock.get_mut(), NTHREADS);
+    }
+}
diff --git a/core/lib.rs b/core/lib.rs
index eb2036d9e..3f9df8378 100644
--- a/core/lib.rs
+++ b/core/lib.rs
@@ -1,5 +1,6 @@
 mod error;
 mod ext;
+mod fast_lock;
 mod function;
 mod functions;
 mod info;
@@ -25,6 +26,7 @@ static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
 
 use ext::list_vfs_modules;
 use fallible_iterator::FallibleIterator;
+use fast_lock::FastLock;
 use limbo_ext::{ResultCode, VTabKind, VTabModuleImpl};
 use limbo_sqlite3_parser::{ast, ast::Cmd, lexer::sql::Parser};
 use parking_lot::RwLock;
@@ -88,7 +90,7 @@ pub struct Database {
     mv_store: Option<Rc<MvStore>>,
     schema: Arc<RwLock<Schema>>,
     // TODO: make header work without lock
-    header: Arc<Mutex<DatabaseHeader>>,
+    header: Arc<FastLock<DatabaseHeader>>,
     page_io: Arc<dyn DatabaseStorage>,
     io: Arc<dyn IO>,
     page_size: u16,
@@ -112,7 +114,7 @@ impl Database {
         let wal_path = format!("{}-wal", path);
         let db_header = Pager::begin_open(page_io.clone())?;
         io.run_once()?;
-        let page_size = db_header.lock().unwrap().page_size;
+        let page_size = db_header.lock().get_mut().page_size;
         let wal_shared = WalFileShared::open_shared(&io, wal_path.as_str(), page_size)?;
         Self::open(io, page_io, wal_shared, enable_mvcc)
     }
@@ -127,7 +129,7 @@ impl Database {
         let db_header = Pager::begin_open(page_io.clone())?;
         io.run_once()?;
         DATABASE_VERSION.get_or_init(|| {
-            let version = db_header.lock().unwrap().version_number;
+            let version = db_header.lock().get_mut().version_number;
             version.to_string()
         });
         let mv_store = if enable_mvcc {
@@ -139,7 +141,7 @@ impl Database {
             None
         };
         let shared_page_cache = Arc::new(RwLock::new(DumbLruPageCache::new(10)));
-        let page_size = db_header.lock().unwrap().page_size;
+        let page_size = db_header.lock().get_mut().page_size;
         let header = db_header;
         let schema = Arc::new(RwLock::new(Schema::new()));
         let db = Database {
@@ -280,7 +282,7 @@ pub struct Connection {
     _db: Arc<Database>,
     pager: Rc<Pager>,
     schema: Arc<RwLock<Schema>>,
-    header: Arc<Mutex<DatabaseHeader>>,
+    header: Arc<FastLock<DatabaseHeader>>,
     auto_commit: RefCell<bool>,
     mv_transactions: RefCell<Vec<crate::mvcc::database::TxID>>,
     transaction_state: RefCell<TransactionState>,
diff --git a/core/storage/btree.rs b/core/storage/btree.rs
index bdbaf7ce9..ee8351bd3 100644
--- a/core/storage/btree.rs
+++ b/core/storage/btree.rs
@@ -1961,7 +1961,7 @@ impl BTreeCursor {
                 OverflowState::ProcessPage { next_page } => {
                     if next_page < 2
                         || next_page as usize
-                            > self.pager.db_header.lock().unwrap().database_size as usize
+                            > self.pager.db_header.lock().get_mut().database_size as usize
                     {
                         self.overflow_state = None;
                         return Err(LimboError::Corrupt("Invalid overflow page number".into()));
@@ -3037,6 +3037,7 @@ mod tests {
     use test_log::test;
 
     use super::*;
+    use crate::fast_lock::FastLock;
     use crate::io::{Buffer, Completion, MemoryIO, OpenFlags, IO};
     use crate::storage::database::FileStorage;
     use crate::storage::page_cache::DumbLruPageCache;
@@ -3332,7 +3333,7 @@ mod tests {
         let page_cache = Arc::new(parking_lot::RwLock::new(DumbLruPageCache::new(10)));
 
         let pager = {
-            let db_header = Arc::new(Mutex::new(db_header.clone()));
+            let db_header = Arc::new(FastLock::new(db_header.clone()));
             Pager::finish_open(db_header, page_io, wal, io, page_cache, buffer_pool).unwrap()
         };
         let pager = Rc::new(pager);
@@ -3564,12 +3565,12 @@ mod tests {
     }
 
     #[allow(clippy::arc_with_non_send_sync)]
-    fn setup_test_env(database_size: u32) -> (Rc<Pager>, Arc<Mutex<DatabaseHeader>>) {
+    fn setup_test_env(database_size: u32) -> (Rc<Pager>, Arc<FastLock<DatabaseHeader>>) {
         let page_size = 512;
         let mut db_header = DatabaseHeader::default();
         db_header.page_size = page_size;
         db_header.database_size = database_size;
-        let db_header = Arc::new(Mutex::new(db_header));
+        let db_header = Arc::new(FastLock::new(db_header));
 
         let buffer_pool = Rc::new(BufferPool::new(10));
 
@@ -3589,7 +3590,7 @@ mod tests {
         {
             let mut buf_mut = buf.borrow_mut();
             let buf_slice = buf_mut.as_mut_slice();
-            sqlite3_ondisk::write_header_to_buf(buf_slice, &db_header.lock().unwrap());
+            sqlite3_ondisk::write_header_to_buf(buf_slice, &db_header.lock().get_mut());
         }
 
         let write_complete = Box::new(|_| {});
@@ -3639,7 +3640,7 @@ mod tests {
         let drop_fn = Rc::new(|_buf| {});
         #[allow(clippy::arc_with_non_send_sync)]
         let buf = Arc::new(RefCell::new(Buffer::allocate(
-            db_header.lock().unwrap().page_size as usize,
+            db_header.lock().get_mut().page_size as usize,
             drop_fn,
         )));
         let write_complete = Box::new(|_| {});
@@ -3679,20 +3680,20 @@ mod tests {
             first_overflow_page: Some(2), // Point to first overflow page
         });
 
-        let initial_freelist_pages = db_header.lock().unwrap().freelist_pages;
+        let initial_freelist_pages = db_header.lock().get_mut().freelist_pages;
         // Clear overflow pages
         let clear_result = cursor.clear_overflow_pages(&leaf_cell)?;
         match clear_result {
             CursorResult::Ok(_) => {
                 // Verify proper number of pages were added to freelist
                 assert_eq!(
-                    db_header.lock().unwrap().freelist_pages,
+                    db_header.lock().get_mut().freelist_pages,
                     initial_freelist_pages + 3,
                     "Expected 3 pages to be added to freelist"
                 );
 
                 // If this is first trunk page
-                let trunk_page_id = db_header.lock().unwrap().freelist_trunk_page;
+                let trunk_page_id = db_header.lock().get_mut().freelist_trunk_page;
                 if trunk_page_id > 0 {
                     // Verify trunk page structure
                     let trunk_page = cursor.pager.read_page(trunk_page_id as usize)?;
@@ -3734,7 +3735,7 @@ mod tests {
             first_overflow_page: None,
         });
 
-        let initial_freelist_pages = db_header.lock().unwrap().freelist_pages;
+        let initial_freelist_pages = db_header.lock().get_mut().freelist_pages;
 
         // Try to clear non-existent overflow pages
         let clear_result = cursor.clear_overflow_pages(&leaf_cell)?;
@@ -3742,14 +3743,14 @@ mod tests {
             CursorResult::Ok(_) => {
                 // Verify freelist was not modified
                 assert_eq!(
-                    db_header.lock().unwrap().freelist_pages,
+                    db_header.lock().get_mut().freelist_pages,
                     initial_freelist_pages,
                     "Freelist should not change when no overflow pages exist"
                 );
 
                 // Verify trunk page wasn't created
                 assert_eq!(
-                    db_header.lock().unwrap().freelist_trunk_page,
+                    db_header.lock().get_mut().freelist_trunk_page,
                     0,
                     "No trunk page should be created when no overflow pages exist"
                 );
@@ -3768,7 +3769,7 @@ mod tests {
         let (pager, db_header) = setup_test_env(initial_size);
         let mut cursor = BTreeCursor::new(None, pager.clone(), 2);
         assert_eq!(
-            db_header.lock().unwrap().database_size,
+            db_header.lock().get_mut().database_size,
             initial_size,
             "Database should initially have 3 pages"
         );
@@ -3828,18 +3829,18 @@ mod tests {
 
         // Verify structure before destruction
         assert_eq!(
-            db_header.lock().unwrap().database_size,
+            db_header.lock().get_mut().database_size,
             5, // We should have pages 0-4
             "Database should have 4 pages total"
        );
 
         // Track freelist state before destruction
-        let initial_free_pages = db_header.lock().unwrap().freelist_pages;
+        let initial_free_pages = db_header.lock().get_mut().freelist_pages;
         assert_eq!(initial_free_pages, 0, "should start with no free pages");
 
         run_until_done(|| cursor.btree_destroy(), pager.deref())?;
 
-        let pages_freed = db_header.lock().unwrap().freelist_pages - initial_free_pages;
+        let pages_freed = db_header.lock().get_mut().freelist_pages - initial_free_pages;
         assert_eq!(pages_freed, 3, "should free 3 pages (root + 2 leaves)");
 
         Ok(())
diff --git a/core/storage/pager.rs b/core/storage/pager.rs
index 5c43eee1a..0db1f1b42 100644
--- a/core/storage/pager.rs
+++ b/core/storage/pager.rs
@@ -1,3 +1,4 @@
+use crate::fast_lock::FastLock;
 use crate::result::LimboResult;
 use crate::storage::buffer_pool::BufferPool;
 use crate::storage::database::DatabaseStorage;
@@ -162,7 +163,7 @@ pub struct Pager {
     /// I/O interface for input/output operations.
     pub io: Arc<dyn IO>,
     dirty_pages: Rc<RefCell<HashSet<usize>>>,
-    pub db_header: Arc<Mutex<DatabaseHeader>>,
+    pub db_header: Arc<FastLock<DatabaseHeader>>,
 
     flush_info: RefCell<FlushInfo>,
     checkpoint_state: RefCell<CheckpointState>,
@@ -172,13 +173,13 @@ pub struct Pager {
 
 impl Pager {
     /// Begins opening a database by reading the database header.
-    pub fn begin_open(page_io: Arc<dyn DatabaseStorage>) -> Result<Arc<Mutex<DatabaseHeader>>> {
+    pub fn begin_open(page_io: Arc<dyn DatabaseStorage>) -> Result<Arc<FastLock<DatabaseHeader>>> {
         sqlite3_ondisk::begin_read_database_header(page_io)
     }
 
     /// Completes opening a database by initializing the Pager with the database header.
     pub fn finish_open(
-        db_header_ref: Arc<Mutex<DatabaseHeader>>,
+        db_header_ref: Arc<FastLock<DatabaseHeader>>,
         page_io: Arc<dyn DatabaseStorage>,
         wal: Rc<RefCell<dyn Wal>>,
         io: Arc<dyn IO>,
@@ -230,8 +231,8 @@ impl Pager {
     /// The usable size of a page might be an odd number. However, the usable size is not allowed to be less than 480.
     /// In other words, if the page size is 512, then the reserved space size cannot exceed 32.
     pub fn usable_space(&self) -> usize {
-        let db_header = self.db_header.lock().unwrap();
-        (db_header.page_size - db_header.reserved_space as u16) as usize
+        let db_header = self.db_header.lock();
+        (db_header.get_mut().page_size - db_header.get_mut().reserved_space as u16) as usize
     }
 
     pub fn begin_read_tx(&self) -> Result<LimboResult> {
@@ -351,7 +352,7 @@ impl Pager {
             trace!("cacheflush {:?}", state);
             match state {
                 FlushState::Start => {
-                    let db_size = self.db_header.lock().unwrap().database_size;
+                    let db_size = self.db_header.lock().get_mut().database_size;
                     for page_id in self.dirty_pages.borrow().iter() {
                         let mut cache = self.page_cache.write();
                         let page_key =
@@ -502,7 +503,7 @@ impl Pager {
         const TRUNK_PAGE_NEXT_PAGE_OFFSET: usize = 0; // Offset to next trunk page pointer
         const TRUNK_PAGE_LEAF_COUNT_OFFSET: usize = 4; // Offset to leaf count
 
-        if page_id < 2 || page_id > self.db_header.lock().unwrap().database_size as usize {
+        if page_id < 2 || page_id > self.db_header.lock().get_mut().database_size as usize {
             return Err(LimboError::Corrupt(format!(
                 "Invalid page number {} for free operation",
                 page_id
@@ -517,9 +518,9 @@ impl Pager {
             None => self.read_page(page_id)?,
         };
 
-        self.db_header.lock().unwrap().freelist_pages += 1;
+        self.db_header.lock().get_mut().freelist_pages += 1;
 
-        let trunk_page_id = self.db_header.lock().unwrap().freelist_trunk_page;
+        let trunk_page_id = self.db_header.lock().get_mut().freelist_trunk_page;
 
         if trunk_page_id != 0 {
             // Add as leaf to current trunk
@@ -557,7 +558,7 @@ impl Pager {
             // Zero leaf count
             contents.write_u32(TRUNK_PAGE_LEAF_COUNT_OFFSET, 0);
             // Update page 1 to point to new trunk
-            self.db_header.lock().unwrap().freelist_trunk_page = page_id as u32;
+            self.db_header.lock().get_mut().freelist_trunk_page = page_id as u32;
             // Clear flags
             page.clear_uptodate();
             page.clear_loaded();
@@ -571,8 +572,8 @@ impl Pager {
     #[allow(clippy::readonly_write_lock)]
     pub fn allocate_page(&self) -> Result<PageRef> {
         let header = &self.db_header;
-        let mut header = header.lock().unwrap();
-        header.database_size += 1;
+        let header = header.lock();
+        header.get_mut().database_size += 1;
         {
             // update database size
             // read sync for now
@@ -586,12 +587,16 @@ impl Pager {
                     self.add_dirty(1);
 
                     let contents = first_page_ref.get().contents.as_ref().unwrap();
-                    contents.write_database_header(&header);
+                    contents.write_database_header(&header.get_mut());
                     break;
                 }
             }
        }
 
-        let page = allocate_page(header.database_size as usize, &self.buffer_pool, 0);
+        let page = allocate_page(
+            header.get_mut().database_size as usize,
+            &self.buffer_pool,
+            0,
+        );
         {
             // setup page and add to cache
             page.set_dirty();
@@ -613,8 +618,8 @@ impl Pager {
     }
 
     pub fn usable_size(&self) -> usize {
-        let db_header = self.db_header.lock().unwrap();
-        (db_header.page_size - db_header.reserved_space as u16) as usize
+        let db_header = self.db_header.lock();
+        (db_header.get_mut().page_size - db_header.get_mut().reserved_space as u16) as usize
     }
 }
diff --git a/core/storage/sqlite3_ondisk.rs b/core/storage/sqlite3_ondisk.rs
index 7ec197517..2901e5a2e 100644
--- a/core/storage/sqlite3_ondisk.rs
+++ b/core/storage/sqlite3_ondisk.rs
@@ -42,6 +42,7 @@
 //! https://www.sqlite.org/fileformat.html
 
 use crate::error::LimboError;
+use crate::fast_lock::FastLock;
 use crate::io::{Buffer, Completion, ReadCompletion, SyncCompletion, WriteCompletion};
 use crate::storage::buffer_pool::BufferPool;
 use crate::storage::database::DatabaseStorage;
@@ -244,11 +245,11 @@ impl Default for DatabaseHeader {
 
 pub fn begin_read_database_header(
     page_io: Arc<dyn DatabaseStorage>,
-) -> Result<Arc<Mutex<DatabaseHeader>>> {
+) -> Result<Arc<FastLock<DatabaseHeader>>> {
     let drop_fn = Rc::new(|_buf| {});
     #[allow(clippy::arc_with_non_send_sync)]
     let buf = Arc::new(RefCell::new(Buffer::allocate(512, drop_fn)));
-    let result = Arc::new(Mutex::new(DatabaseHeader::default()));
+    let result = Arc::new(FastLock::new(DatabaseHeader::default()));
     let header = result.clone();
     let complete = Box::new(move |buf: Arc<RefCell<Buffer>>| {
         let header = header.clone();
@@ -261,11 +262,12 @@ pub fn begin_read_database_header(
 
 fn finish_read_database_header(
     buf: Arc<RefCell<Buffer>>,
-    header: Arc<Mutex<DatabaseHeader>>,
+    header: Arc<FastLock<DatabaseHeader>>,
 ) -> Result<()> {
     let buf = buf.borrow();
     let buf = buf.as_slice();
-    let mut header = header.lock().unwrap();
+    let header = header.lock();
+    let header = header.get_mut();
     header.magic.copy_from_slice(&buf[0..16]);
     header.page_size = u16::from_be_bytes([buf[16], buf[17]]);
     header.write_version = buf[18];
diff --git a/core/translate/mod.rs b/core/translate/mod.rs
index ea95df1d2..1639d61f2 100644
--- a/core/translate/mod.rs
+++ b/core/translate/mod.rs
@@ -24,6 +24,7 @@ pub(crate) mod select;
 pub(crate) mod subquery;
 pub(crate) mod transaction;
 
+use crate::fast_lock::FastLock;
 use crate::schema::Schema;
 use crate::storage::pager::Pager;
 use crate::storage::sqlite3_ondisk::DatabaseHeader;
@@ -45,7 +46,7 @@ use transaction::{translate_tx_begin, translate_tx_commit};
 pub fn translate(
     schema: &Schema,
     stmt: ast::Stmt,
-    database_header: Arc<Mutex<DatabaseHeader>>,
+    database_header: Arc<FastLock<DatabaseHeader>>,
     pager: Rc<Pager>,
     connection: Weak<Connection>,
     syms: &SymbolTable,
diff --git a/core/translate/pragma.rs b/core/translate/pragma.rs
index 6d9e44e8f..8d9be8502 100644
--- a/core/translate/pragma.rs
+++ b/core/translate/pragma.rs
@@ -6,6 +6,7 @@ use limbo_sqlite3_parser::ast::PragmaName;
 use std::rc::Rc;
 use std::sync::{Arc, Mutex};
 
+use crate::fast_lock::FastLock;
 use crate::schema::Schema;
 use crate::storage::sqlite3_ondisk::{DatabaseHeader, MIN_PAGE_CACHE_SIZE};
 use crate::storage::wal::CheckpointMode;
@@ -38,7 +39,7 @@ pub fn translate_pragma(
     schema: &Schema,
     name: &ast::QualifiedName,
     body: Option<ast::PragmaBody>,
-    database_header: Arc<Mutex<DatabaseHeader>>,
+    database_header: Arc<FastLock<DatabaseHeader>>,
     pager: Rc<Pager>,
 ) -> crate::Result<Program> {
     let mut program = ProgramBuilder::new(ProgramBuilderOpts {
@@ -115,7 +116,7 @@ fn update_pragma(
     pragma: PragmaName,
     schema: &Schema,
     value: ast::Expr,
-    header: Arc<Mutex<DatabaseHeader>>,
+    header: Arc<FastLock<DatabaseHeader>>,
     pager: Rc<Pager>,
     program: &mut ProgramBuilder,
 ) -> crate::Result<()> {
@@ -166,7 +167,7 @@ fn query_pragma(
     pragma: PragmaName,
     schema: &Schema,
     value: Option<ast::Expr>,
-    database_header: Arc<Mutex<DatabaseHeader>>,
+    database_header: Arc<FastLock<DatabaseHeader>>,
     program: &mut ProgramBuilder,
 ) -> crate::Result<()> {
     let register = program.alloc_register();
@@ -175,7 +176,7 @@ fn query_pragma(
             program.emit_int(
                 database_header
                     .lock()
-                    .unwrap()
+                    .get_mut()
                     .default_page_cache_size
                     .into(),
                 register,
@@ -265,7 +266,7 @@ fn query_pragma(
     Ok(())
 }
 
-fn update_cache_size(value: i64, header: Arc<Mutex<DatabaseHeader>>, pager: Rc<Pager>) {
+fn update_cache_size(value: i64, header: Arc<FastLock<DatabaseHeader>>, pager: Rc<Pager>) {
     let mut cache_size_unformatted: i64 = value;
     let mut cache_size = if cache_size_unformatted < 0 {
         let kb = cache_size_unformatted.abs() * 1024;
@@ -281,12 +282,12 @@ fn update_cache_size(value: i64, header: Arc<FastLock<DatabaseHeader>>, pager: R
     if cache_size < MIN_PAGE_CACHE_SIZE {
         // update both in memory and stored disk value
         cache_size = MIN_PAGE_CACHE_SIZE;
         cache_size_unformatted = MIN_PAGE_CACHE_SIZE as i64;
     }
 
     // update cache size to be persisted across connections
-    header.lock().unwrap().default_page_cache_size = cache_size_unformatted as i32;
+    header.lock().get_mut().default_page_cache_size = cache_size_unformatted as i32;
 
     // update cache size
     pager.change_page_cache_size(cache_size as usize);
 }
diff --git a/core/vdbe/builder.rs b/core/vdbe/builder.rs
--- a/core/vdbe/builder.rs
+++ b/core/vdbe/builder.rs
@@ -1,5 +1,6 @@
 use std::{cell::RefCell, collections::HashMap, rc::Rc, sync::Arc};
 
+use crate::fast_lock::FastLock;
 use crate::{
     schema::BTreeTable,
     storage::sqlite3_ondisk::DatabaseHeader,
@@ -60,7 +61,7 @@ impl ProgramBuilder {
 
     pub fn build(
         self,
-        database_header: Arc<Mutex<DatabaseHeader>>,
+        database_header: Arc<FastLock<DatabaseHeader>>,
         connection: Weak<Connection>,
         change_cnt_on: bool,
     ) -> Program {
diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs
index 7621146d6..9d44377d4 100644
--- a/core/vdbe/mod.rs
+++ b/core/vdbe/mod.rs
@@ -25,6 +25,7 @@ pub mod sorter;
 
 use crate::error::{LimboError, SQLITE_CONSTRAINT_PRIMARYKEY};
 use crate::ext::ExtValue;
+use crate::fast_lock::FastLock;
 use crate::function::{AggFunc, ExtFunc, FuncCtx, MathFunc, MathFuncArity, ScalarFunc, VectorFunc};
 use crate::functions::datetime::{
     exec_date, exec_datetime_full, exec_julianday, exec_strftime, exec_time, exec_unixepoch,
@@ -332,7 +333,7 @@ pub struct Program {
     pub max_registers: usize,
     pub insns: Vec<Insn>,
     pub cursor_ref: Vec<(Option<String>, CursorType)>,
-    pub database_header: Arc<Mutex<DatabaseHeader>>,
+    pub database_header: Arc<FastLock<DatabaseHeader>>,
     pub comments: Option<HashMap<BranchOffset, &'static str>>,
     pub parameters: crate::parameters::Parameters,
     pub connection: Weak<Connection>,
@@ -3073,7 +3074,7 @@ impl Program {
                     }
                     // SQLite returns "0" on an empty database, and 2 on the first insertion,
                     // so we'll mimic that behavior.
-                    let mut pages = pager.db_header.lock().unwrap().database_size.into();
+                    let mut pages = pager.db_header.lock().get_mut().database_size.into();
                     if pages == 1 {
                         pages = 0;
                     }
@@ -3107,7 +3108,7 @@ impl Program {
                         todo!("temp databases not implemented yet");
                     }
                     let cookie_value = match cookie {
-                        Cookie::UserVersion => pager.db_header.lock().unwrap().user_version.into(),
+                        Cookie::UserVersion => pager.db_header.lock().get_mut().user_version.into(),
                         cookie => todo!("{cookie:?} is not yet implement for ReadCookie"),
                     };
                     state.registers[*dest] = OwnedValue::Integer(cookie_value);