diff --git a/core/lib.rs b/core/lib.rs index 130503ede..3d169121a 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -41,6 +41,7 @@ mod numeric; #[global_allocator] static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; +use crate::storage::sqlite3_ondisk::is_valid_page_size; use crate::storage::{header_accessor, wal::DummyWAL}; use crate::translate::optimizer::optimize_plan; use crate::translate::pragma::TURSO_CDC_DEFAULT_TABLE_NAME; @@ -77,7 +78,7 @@ use std::{ use storage::database::DatabaseFile; use storage::page_cache::DumbLruPageCache; pub use storage::pager::PagerCacheflushStatus; -use storage::pager::{DB_STATE_INITIALIZED, DB_STATE_UNITIALIZED}; +use storage::pager::{DB_STATE_INITIALIZED, DB_STATE_UNINITIALIZED}; pub use storage::{ buffer_pool::BufferPool, database::DatabaseStorage, @@ -117,7 +118,7 @@ pub struct Database { // create DB connections. _shared_page_cache: Arc>, maybe_shared_wal: RwLock>>>, - is_empty: Arc, + db_state: Arc, init_lock: Arc>, open_flags: OpenFlags, } @@ -178,7 +179,6 @@ impl Database { ) -> Result> { let wal_path = format!("{path}-wal"); let maybe_shared_wal = WalFileShared::open_shared_if_exists(&io, wal_path.as_str())?; - let db_size = db_file.size()?; let mv_store = if enable_mvcc { Some(Rc::new(MvStore::new( @@ -188,12 +188,10 @@ impl Database { } else { None }; - let wal_has_frames = maybe_shared_wal - .as_ref() - .is_some_and(|wal| unsafe { &*wal.get() }.max_frame.load(Ordering::SeqCst) > 0); - let is_empty = if db_size == 0 && !wal_has_frames { - DB_STATE_UNITIALIZED + let db_size = db_file.size()?; + let db_state = if db_size == 0 { + DB_STATE_UNINITIALIZED } else { DB_STATE_INITIALIZED }; @@ -209,13 +207,13 @@ impl Database { db_file, io: io.clone(), open_flags: flags, - is_empty: Arc::new(AtomicUsize::new(is_empty)), + db_state: Arc::new(AtomicUsize::new(db_state)), init_lock: Arc::new(Mutex::new(())), }; let db = Arc::new(db); // Check: https://github.com/tursodatabase/turso/pull/1761#discussion_r2154013123 - if is_empty == 2 { + if db_state == DB_STATE_INITIALIZED { // parse schema let conn = db.connect()?; let schema_version = get_schema_version(&conn)?; @@ -226,10 +224,9 @@ impl Database { .expect("lock on schema should succeed first try"); let syms = conn.syms.borrow(); + let pager = conn.pager.borrow().clone(); - if let Err(LimboError::ExtensionError(e)) = - schema.make_from_btree(None, conn.pager.clone(), &syms) - { + if let Err(LimboError::ExtensionError(e)) = schema.make_from_btree(None, pager, &syms) { // this means that a vtab exists and we no longer have the module loaded. we print // a warning to the user to load the module eprintln!("Warning: {e}"); @@ -239,90 +236,16 @@ impl Database { } pub fn connect(self: &Arc) -> Result> { - let buffer_pool = Arc::new(BufferPool::new(None)); + let pager = self.init_pager(None)?; - // Open existing WAL file if present - if let Some(shared_wal) = self.maybe_shared_wal.read().clone() { - // No pages in DB file or WAL -> empty database - let is_empty = self.is_empty.clone(); - let wal = Rc::new(RefCell::new(WalFile::new( - self.io.clone(), - shared_wal, - buffer_pool.clone(), - ))); - let pager = Rc::new(Pager::new( - self.db_file.clone(), - wal, - self.io.clone(), - Arc::new(RwLock::new(DumbLruPageCache::default())), - buffer_pool, - is_empty, - self.init_lock.clone(), - )?); - - let page_size = header_accessor::get_page_size(&pager) - .unwrap_or(storage::sqlite3_ondisk::DEFAULT_PAGE_SIZE) - as u32; - let default_cache_size = header_accessor::get_default_page_cache_size(&pager) - .unwrap_or(storage::sqlite3_ondisk::DEFAULT_CACHE_SIZE); - pager.buffer_pool.set_page_size(page_size as usize); - let conn = Arc::new(Connection { - _db: self.clone(), - pager: pager.clone(), - schema: RefCell::new(self.schema.read().clone()), - last_insert_rowid: Cell::new(0), - auto_commit: Cell::new(true), - mv_transactions: RefCell::new(Vec::new()), - transaction_state: Cell::new(TransactionState::None), - last_change: Cell::new(0), - syms: RefCell::new(SymbolTable::new()), - total_changes: Cell::new(0), - _shared_cache: false, - cache_size: Cell::new(default_cache_size), - readonly: Cell::new(false), - wal_checkpoint_disabled: Cell::new(false), - capture_data_changes: RefCell::new(CaptureDataChangesMode::Off), - closed: Cell::new(false), - }); - if let Err(e) = conn.register_builtins() { - return Err(LimboError::ExtensionError(e)); - } - return Ok(conn); - }; - - // No existing WAL; create one. - // TODO: currently Pager needs to be instantiated with some implementation of trait Wal, so here's a workaround. - let dummy_wal = Rc::new(RefCell::new(DummyWAL {})); - let is_empty = self.is_empty.clone(); - let mut pager = Pager::new( - self.db_file.clone(), - dummy_wal, - self.io.clone(), - Arc::new(RwLock::new(DumbLruPageCache::default())), - buffer_pool.clone(), - is_empty, - Arc::new(Mutex::new(())), - )?; let page_size = header_accessor::get_page_size(&pager) - .unwrap_or(storage::sqlite3_ondisk::DEFAULT_PAGE_SIZE) as u32; + .unwrap_or(storage::sqlite3_ondisk::DEFAULT_PAGE_SIZE); let default_cache_size = header_accessor::get_default_page_cache_size(&pager) .unwrap_or(storage::sqlite3_ondisk::DEFAULT_CACHE_SIZE); - let wal_path = format!("{}-wal", self.path); - let file = self.io.open_file(&wal_path, OpenFlags::Create, false)?; - let real_shared_wal = WalFileShared::new_shared(page_size, &self.io, file)?; - // Modify Database::maybe_shared_wal to point to the new WAL file so that other connections - // can open the existing WAL. - *self.maybe_shared_wal.write() = Some(real_shared_wal.clone()); - let wal = Rc::new(RefCell::new(WalFile::new( - self.io.clone(), - real_shared_wal, - buffer_pool, - ))); - pager.set_wal(wal); let conn = Arc::new(Connection { _db: self.clone(), - pager: Rc::new(pager), + pager: RefCell::new(Rc::new(pager)), schema: RefCell::new(self.schema.read().clone()), auto_commit: Cell::new(true), mv_transactions: RefCell::new(Vec::new()), @@ -333,6 +256,7 @@ impl Database { syms: RefCell::new(SymbolTable::new()), _shared_cache: false, cache_size: Cell::new(default_cache_size), + page_size: Cell::new(page_size), readonly: Cell::new(false), wal_checkpoint_disabled: Cell::new(false), capture_data_changes: RefCell::new(CaptureDataChangesMode::Off), @@ -345,6 +269,74 @@ impl Database { Ok(conn) } + fn init_pager(&self, page_size: Option) -> Result { + // Open existing WAL file if present + if let Some(shared_wal) = self.maybe_shared_wal.read().clone() { + let size = match page_size { + None => unsafe { (*shared_wal.get()).page_size() as usize }, + Some(size) => size, + }; + let buffer_pool = Arc::new(BufferPool::new(Some(size))); + + let db_state = self.db_state.clone(); + let wal = Rc::new(RefCell::new(WalFile::new( + self.io.clone(), + shared_wal, + buffer_pool.clone(), + ))); + let pager = Pager::new( + self.db_file.clone(), + wal, + self.io.clone(), + Arc::new(RwLock::new(DumbLruPageCache::default())), + buffer_pool.clone(), + db_state, + self.init_lock.clone(), + )?; + return Ok(pager); + } + + let buffer_pool = Arc::new(BufferPool::new(page_size)); + // No existing WAL; create one. + // TODO: currently Pager needs to be instantiated with some implementation of trait Wal, so here's a workaround. + let dummy_wal = Rc::new(RefCell::new(DummyWAL {})); + let db_state = self.db_state.clone(); + let mut pager = Pager::new( + self.db_file.clone(), + dummy_wal, + self.io.clone(), + Arc::new(RwLock::new(DumbLruPageCache::default())), + buffer_pool.clone(), + db_state, + Arc::new(Mutex::new(())), + )?; + + let size = match page_size { + Some(size) => size as u32, + None => { + let size = header_accessor::get_page_size(&pager) + .unwrap_or(storage::sqlite3_ondisk::DEFAULT_PAGE_SIZE); + buffer_pool.set_page_size(size as usize); + size + } + }; + + let wal_path = format!("{}-wal", self.path); + let file = self.io.open_file(&wal_path, OpenFlags::Create, false)?; + let real_shared_wal = WalFileShared::new_shared(size, &self.io, file)?; + // Modify Database::maybe_shared_wal to point to the new WAL file so that other connections + // can open the existing WAL. + *self.maybe_shared_wal.write() = Some(real_shared_wal.clone()); + let wal = Rc::new(RefCell::new(WalFile::new( + self.io.clone(), + real_shared_wal, + buffer_pool, + ))); + pager.set_wal(wal); + + Ok(pager) + } + /// Open a new database file with optionally specifying a VFS without an existing database /// connection and symbol table to register extensions. #[cfg(feature = "fs")] @@ -499,7 +491,7 @@ impl CaptureDataChangesMode { pub struct Connection { _db: Arc, - pager: Rc, + pager: RefCell>, schema: RefCell, /// Whether to automatically commit transaction auto_commit: Cell, @@ -511,6 +503,9 @@ pub struct Connection { syms: RefCell, _shared_cache: bool, cache_size: Cell, + /// page size used for an uninitialized database or the next vacuum command. + /// it's not always equal to the current page size of the database + page_size: Cell, readonly: Cell, wal_checkpoint_disabled: Cell, capture_data_changes: RefCell, @@ -540,22 +535,19 @@ impl Connection { .unwrap() .trim(); self.maybe_update_schema(); + let pager = self.pager.borrow().clone(); match cmd { Cmd::Stmt(stmt) => { let program = Rc::new(translate::translate( self.schema.borrow().deref(), stmt, - self.pager.clone(), + pager.clone(), self.clone(), &syms, QueryMode::Normal, input, )?); - Ok(Statement::new( - program, - self._db.mv_store.clone(), - self.pager.clone(), - )) + Ok(Statement::new(program, self._db.mv_store.clone(), pager)) } Cmd::Explain(_stmt) => todo!(), Cmd::ExplainQueryPlan(_stmt) => todo!(), @@ -591,22 +583,19 @@ impl Connection { return Err(LimboError::InternalError("Connection closed".to_string())); } let syms = self.syms.borrow(); + let pager = self.pager.borrow().clone(); match cmd { Cmd::Stmt(ref stmt) | Cmd::Explain(ref stmt) => { let program = translate::translate( self.schema.borrow().deref(), stmt.clone(), - self.pager.clone(), + pager.clone(), self.clone(), &syms, cmd.into(), input, )?; - let stmt = Statement::new( - program.into(), - self._db.mv_store.clone(), - self.pager.clone(), - ); + let stmt = Statement::new(program.into(), self._db.mv_store.clone(), pager); Ok(Some(stmt)) } Cmd::ExplainQueryPlan(stmt) => { @@ -646,6 +635,7 @@ impl Connection { let mut parser = Parser::new(sql.as_bytes()); while let Some(cmd) = parser.next()? { let syms = self.syms.borrow(); + let pager = self.pager.borrow().clone(); let byte_offset_end = parser.offset(); let input = str::from_utf8(&sql.as_bytes()[..byte_offset_end]) .unwrap() @@ -656,7 +646,7 @@ impl Connection { let program = translate::translate( self.schema.borrow().deref(), stmt, - self.pager.clone(), + pager, self.clone(), &syms, QueryMode::Explain, @@ -669,7 +659,7 @@ impl Connection { let program = translate::translate( self.schema.borrow().deref(), stmt, - self.pager.clone(), + pager.clone(), self.clone(), &syms, QueryMode::Normal, @@ -679,11 +669,8 @@ impl Connection { let mut state = vdbe::ProgramState::new(program.max_registers, program.cursor_ref.len()); loop { - let res = program.step( - &mut state, - self._db.mv_store.clone(), - self.pager.clone(), - )?; + let res = + program.step(&mut state, self._db.mv_store.clone(), pager.clone())?; if matches!(res, StepResult::Done) { break; } @@ -703,7 +690,7 @@ impl Connection { if res.is_err() { let state = self.transaction_state.get(); if let TransactionState::Write { schema_did_change } = state { - self.pager.rollback(schema_did_change, self)? + self.pager.borrow().rollback(schema_did_change, self)? } } res @@ -750,7 +737,7 @@ impl Connection { } pub fn wal_frame_count(&self) -> Result { - self.pager.wal_frame_count() + self.pager.borrow().wal_frame_count() } pub fn wal_get_frame( @@ -759,7 +746,9 @@ impl Connection { p_frame: *mut u8, frame_len: u32, ) -> Result> { - self.pager.wal_get_frame(frame_no, p_frame, frame_len) + self.pager + .borrow() + .wal_get_frame(frame_no, p_frame, frame_len) } /// Flush dirty pages to disk. @@ -770,11 +759,13 @@ impl Connection { if self.closed.get() { return Err(LimboError::InternalError("Connection closed".to_string())); } - self.pager.cacheflush(self.wal_checkpoint_disabled.get()) + self.pager + .borrow() + .cacheflush(self.wal_checkpoint_disabled.get()) } pub fn clear_page_cache(&self) -> Result<()> { - self.pager.clear_page_cache(); + self.pager.borrow().clear_page_cache(); Ok(()) } @@ -783,6 +774,7 @@ impl Connection { return Err(LimboError::InternalError("Connection closed".to_string())); } self.pager + .borrow() .wal_checkpoint(self.wal_checkpoint_disabled.get()) } @@ -793,6 +785,7 @@ impl Connection { } self.closed.set(true); self.pager + .borrow() .checkpoint_shutdown(self.wal_checkpoint_disabled.get()) } @@ -831,6 +824,33 @@ impl Connection { pub fn set_capture_data_changes(&self, opts: CaptureDataChangesMode) { self.capture_data_changes.replace(opts); } + pub fn get_page_size(&self) -> u32 { + self.page_size.get() + } + + /// Reset the page size for the current connection. + /// + /// Specifying a new page size does not change the page size immediately. + /// Instead, the new page size is remembered and is used to set the page size when the database + /// is first created, if it does not already exist when the page_size pragma is issued, + /// or at the next VACUUM command that is run on the same database connection while not in WAL mode. + pub fn reset_page_size(&self, size: u32) -> Result<()> { + if !is_valid_page_size(size) { + return Ok(()); + } + + self.page_size.set(size); + if self._db.db_state.load(Ordering::SeqCst) != DB_STATE_UNINITIALIZED { + return Ok(()); + } + + *self._db.maybe_shared_wal.write() = None; + let pager = self._db.init_pager(Some(size as usize))?; + self.pager.replace(Rc::new(pager)); + self.pager.borrow().set_initial_page_size(size); + + Ok(()) + } #[cfg(feature = "fs")] pub fn open_new(&self, path: &str, vfs: &str) -> Result<(Arc, Arc)> { diff --git a/core/storage/btree.rs b/core/storage/btree.rs index 09abcc2dd..a26c84802 100644 --- a/core/storage/btree.rs +++ b/core/storage/btree.rs @@ -6593,7 +6593,7 @@ mod tests { &mut payload, &record, 4096, - conn.pager.clone(), + conn.pager.borrow().clone(), ); insert_into_cell(page, &payload, pos, 4096).unwrap(); payload @@ -6831,7 +6831,7 @@ mod tests { let io: Arc = Arc::new(MemoryIO::new()); let db = Database::open_file(io.clone(), "test.db", false, false).unwrap(); let conn = db.connect().unwrap(); - let pager = conn.pager.clone(); + let pager = conn.pager.borrow().clone(); // FIXME: handle page cache is full let _ = run_until_done(|| pager.allocate_page1(), &pager); @@ -7368,7 +7368,7 @@ mod tests { pager.allocate_page().unwrap(); } - header_accessor::set_page_size(&pager, page_size as u16).unwrap(); + header_accessor::set_page_size(&pager, page_size).unwrap(); pager } @@ -7717,7 +7717,7 @@ mod tests { &mut payload, &record, 4096, - conn.pager.clone(), + conn.pager.borrow().clone(), ); if (free as usize) < payload.len() + 2 { // do not try to insert overflow pages because they require balancing @@ -7790,7 +7790,7 @@ mod tests { &mut payload, &record, 4096, - conn.pager.clone(), + conn.pager.borrow().clone(), ); if (free as usize) < payload.len() - 2 { // do not try to insert overflow pages because they require balancing @@ -8154,7 +8154,7 @@ mod tests { &mut payload, &record, 4096, - conn.pager.clone(), + conn.pager.borrow().clone(), ); let page = page.get(); insert(0, page.get_contents()); @@ -8231,7 +8231,7 @@ mod tests { &mut payload, &record, 4096, - conn.pager.clone(), + conn.pager.borrow().clone(), ); insert_into_cell(page.get().get_contents(), &payload, 0, 4096).unwrap(); let free = compute_free_space(page.get().get_contents(), usable_space); diff --git a/core/storage/header_accessor.rs b/core/storage/header_accessor.rs index d1d7c726a..a2aba8013 100644 --- a/core/storage/header_accessor.rs +++ b/core/storage/header_accessor.rs @@ -1,3 +1,4 @@ +use crate::storage::sqlite3_ondisk::MAX_PAGE_SIZE; use crate::{ storage::{ self, @@ -35,7 +36,7 @@ const HEADER_OFFSET_VERSION_NUMBER: usize = 96; // Helper to get a read-only reference to the header page. fn get_header_page(pager: &Pager) -> Result> { - if pager.is_empty.load(Ordering::SeqCst) < 2 { + if pager.db_state.load(Ordering::SeqCst) < 2 { return Err(LimboError::InternalError( "Database is empty, header does not exist - page 1 should've been allocated before this".to_string(), )); @@ -49,7 +50,7 @@ fn get_header_page(pager: &Pager) -> Result> { // Helper to get a writable reference to the header page and mark it dirty. fn get_header_page_for_write(pager: &Pager) -> Result> { - if pager.is_empty.load(Ordering::SeqCst) < 2 { + if pager.db_state.load(Ordering::SeqCst) < 2 { // This should not be called on an empty DB for writing, as page 1 is allocated on first transaction. return Err(LimboError::InternalError( "Cannot write to header of an empty database - page 1 should've been allocated before this".to_string(), @@ -103,7 +104,7 @@ macro_rules! impl_header_field_accessor { // Async version #[allow(dead_code)] pub fn [](pager: &Pager) -> Result> { - if pager.is_empty.load(Ordering::SeqCst) < 2 { + if pager.db_state.load(Ordering::SeqCst) < 2 { return Err(LimboError::InternalError(format!("Database is empty, header does not exist - page 1 should've been allocated before this"))); } let page = match get_header_page(pager)? { @@ -158,7 +159,7 @@ macro_rules! impl_header_field_accessor { } // impl_header_field_accessor!(magic, [u8; 16], HEADER_OFFSET_MAGIC); -impl_header_field_accessor!(page_size, u16, HEADER_OFFSET_PAGE_SIZE); +impl_header_field_accessor!(page_size_u16, u16, HEADER_OFFSET_PAGE_SIZE); impl_header_field_accessor!(write_version, u8, HEADER_OFFSET_WRITE_VERSION); impl_header_field_accessor!(read_version, u8, HEADER_OFFSET_READ_VERSION); impl_header_field_accessor!(reserved_space, u8, HEADER_OFFSET_RESERVED_SPACE); @@ -193,3 +194,34 @@ impl_header_field_accessor!(application_id, u32, HEADER_OFFSET_APPLICATION_ID); //impl_header_field_accessor!(reserved_for_expansion, [u8; 20], HEADER_OFFSET_RESERVED_FOR_EXPANSION); impl_header_field_accessor!(version_valid_for, u32, HEADER_OFFSET_VERSION_VALID_FOR); impl_header_field_accessor!(version_number, u32, HEADER_OFFSET_VERSION_NUMBER); + +pub fn get_page_size(pager: &Pager) -> Result { + let size = get_page_size_u16(pager)?; + if size == 1 { + return Ok(MAX_PAGE_SIZE); + } + Ok(size as u32) +} + +#[allow(dead_code)] +pub fn set_page_size(pager: &Pager, value: u32) -> Result<()> { + let page_size = if value == MAX_PAGE_SIZE { + 1 + } else { + value as u16 + }; + set_page_size_u16(pager, page_size) +} + +#[allow(dead_code)] +pub fn get_page_size_async(pager: &Pager) -> Result> { + match get_page_size_u16_async(pager)? { + CursorResult::Ok(size) => { + if size == 1 { + return Ok(CursorResult::Ok(MAX_PAGE_SIZE)); + } + Ok(CursorResult::Ok(size as u32)) + } + CursorResult::IO => Ok(CursorResult::IO), + } +} diff --git a/core/storage/pager.rs b/core/storage/pager.rs index 67499ee86..9d9a927d5 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -6,10 +6,10 @@ use crate::storage::header_accessor; use crate::storage::sqlite3_ondisk::{self, DatabaseHeader, PageContent, PageType}; use crate::storage::wal::{CheckpointResult, Wal, WalFsyncStatus}; use crate::types::CursorResult; +use crate::Completion; use crate::{Buffer, Connection, LimboError, Result}; -use crate::{Completion, WalFile}; use parking_lot::RwLock; -use std::cell::{OnceCell, RefCell, UnsafeCell}; +use std::cell::{Cell, OnceCell, RefCell, UnsafeCell}; use std::collections::HashSet; use std::rc::Rc; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -191,7 +191,7 @@ pub enum AutoVacuumMode { Incremental, } -pub const DB_STATE_UNITIALIZED: usize = 0; +pub const DB_STATE_UNINITIALIZED: usize = 0; pub const DB_STATE_INITIALIZING: usize = 1; pub const DB_STATE_INITIALIZED: usize = 2; /// The pager interface implements the persistence layer by providing access @@ -218,14 +218,14 @@ pub struct Pager { /// 0 -> Database is empty, /// 1 -> Database is being initialized, /// 2 -> Database is initialized and ready for use. - pub is_empty: Arc, + pub db_state: Arc, /// Mutex for synchronizing database initialization to prevent race conditions init_lock: Arc>, allocate_page1_state: RefCell, /// Cache page_size and reserved_space at Pager init and reuse for subsequent /// `usable_space` calls. TODO: Invalidate reserved_space when we add the functionality /// to change it. - page_size: OnceCell, + page_size: Cell>, reserved_space: OnceCell, } @@ -265,10 +265,10 @@ impl Pager { io: Arc, page_cache: Arc>, buffer_pool: Arc, - is_empty: Arc, + db_state: Arc, init_lock: Arc>, ) -> Result { - let allocate_page1_state = if is_empty.load(Ordering::SeqCst) < DB_STATE_INITIALIZED { + let allocate_page1_state = if db_state.load(Ordering::SeqCst) < DB_STATE_INITIALIZED { RefCell::new(AllocatePage1State::Start) } else { RefCell::new(AllocatePage1State::Done) @@ -288,15 +288,15 @@ impl Pager { checkpoint_inflight: Rc::new(RefCell::new(0)), buffer_pool, auto_vacuum_mode: RefCell::new(AutoVacuumMode::None), - is_empty, + db_state, init_lock, allocate_page1_state, - page_size: OnceCell::new(), + page_size: Cell::new(None), reserved_space: OnceCell::new(), }) } - pub fn set_wal(&mut self, wal: Rc>) { + pub fn set_wal(&mut self, wal: Rc>) { self.wal = wal; } @@ -586,7 +586,8 @@ impl Pager { pub fn usable_space(&self) -> usize { let page_size = *self .page_size - .get_or_init(|| header_accessor::get_page_size(self).unwrap_or_default()); + .get() + .get_or_insert_with(|| header_accessor::get_page_size(self).unwrap_or_default()); let reserved_space = *self .reserved_space @@ -595,6 +596,12 @@ impl Pager { (page_size as usize) - (reserved_space as usize) } + /// Set the initial page size for the database. Should only be called before the database is initialized + pub fn set_initial_page_size(&self, size: u32) { + assert_eq!(self.db_state.load(Ordering::SeqCst), DB_STATE_UNINITIALIZED); + self.page_size.replace(Some(size)); + } + #[inline(always)] #[instrument(skip_all, level = Level::INFO)] pub fn begin_read_tx(&self) -> Result> { @@ -608,10 +615,10 @@ impl Pager { #[instrument(skip_all, level = Level::INFO)] fn maybe_allocate_page1(&self) -> Result> { - if self.is_empty.load(Ordering::SeqCst) < DB_STATE_INITIALIZED { + if self.db_state.load(Ordering::SeqCst) < DB_STATE_INITIALIZED { if let Ok(_lock) = self.init_lock.try_lock() { match ( - self.is_empty.load(Ordering::SeqCst), + self.db_state.load(Ordering::SeqCst), self.allocating_page1(), ) { // In case of being empty or (allocating and this connection is performing allocation) then allocate the first page @@ -1054,9 +1061,12 @@ impl Pager { match state { AllocatePage1State::Start => { tracing::trace!("allocate_page1(Start)"); - self.is_empty.store(DB_STATE_INITIALIZING, Ordering::SeqCst); + self.db_state.store(DB_STATE_INITIALIZING, Ordering::SeqCst); let mut default_header = DatabaseHeader::default(); default_header.database_size += 1; + if let Some(size) = self.page_size.get() { + default_header.update_page_size(size); + } let page = allocate_page(1, &self.buffer_pool, 0); let contents = page.get_contents(); @@ -1100,7 +1110,7 @@ impl Pager { cache.insert(page_key, page1_ref.clone()).map_err(|e| { LimboError::InternalError(format!("Failed to insert page 1 into cache: {e:?}")) })?; - self.is_empty.store(DB_STATE_INITIALIZED, Ordering::SeqCst); + self.db_state.store(DB_STATE_INITIALIZED, Ordering::SeqCst); self.allocate_page1_state.replace(AllocatePage1State::Done); Ok(CursorResult::Ok(page1_ref.clone())) } diff --git a/core/storage/sqlite3_ondisk.rs b/core/storage/sqlite3_ondisk.rs index af5438859..abbdc637c 100644 --- a/core/storage/sqlite3_ondisk.rs +++ b/core/storage/sqlite3_ondisk.rs @@ -83,10 +83,10 @@ pub const MIN_PAGE_CACHE_SIZE: usize = 10; pub const MIN_PAGE_SIZE: u32 = 512; /// The maximum page size in bytes. -const MAX_PAGE_SIZE: u32 = 65536; +pub const MAX_PAGE_SIZE: u32 = 65536; /// The default page size in bytes. -pub const DEFAULT_PAGE_SIZE: u16 = 4096; +pub const DEFAULT_PAGE_SIZE: u32 = 4096; pub const DATABASE_HEADER_PAGE_ID: usize = 1; @@ -251,7 +251,7 @@ impl Default for DatabaseHeader { fn default() -> Self { Self { magic: *b"SQLite format 3\0", - page_size: DEFAULT_PAGE_SIZE, + page_size: DEFAULT_PAGE_SIZE as u16, write_version: 2, read_version: 2, reserved_space: 0, @@ -279,7 +279,7 @@ impl Default for DatabaseHeader { impl DatabaseHeader { pub fn update_page_size(&mut self, size: u32) { - if !(MIN_PAGE_SIZE..=MAX_PAGE_SIZE).contains(&size) || (size & (size - 1) != 0) { + if !is_valid_page_size(size) { return; } @@ -299,6 +299,10 @@ impl DatabaseHeader { } } +pub fn is_valid_page_size(size: u32) -> bool { + (MIN_PAGE_SIZE..=MAX_PAGE_SIZE).contains(&size) && (size & (size - 1)) == 0 +} + pub fn write_header_to_buf(buf: &mut [u8], header: &DatabaseHeader) { buf[0..16].copy_from_slice(&header.magic); buf[16..18].copy_from_slice(&header.page_size.to_be_bytes()); diff --git a/core/translate/pragma.rs b/core/translate/pragma.rs index 45f554fc9..da99412ad 100644 --- a/core/translate/pragma.rs +++ b/core/translate/pragma.rs @@ -22,6 +22,7 @@ use strum::IntoEnumIterator; use super::integrity_check::translate_integrity_check; use crate::storage::header_accessor; use crate::storage::pager::Pager; +use crate::translate::emitter::TransactionMode; fn list_pragmas(program: &mut ProgramBuilder) { for x in PragmaName::iter() { @@ -29,7 +30,7 @@ fn list_pragmas(program: &mut ProgramBuilder) { program.emit_result_row(register, 1); } program.add_pragma_result_column("pragma_list".into()); - program.epilogue(crate::translate::emitter::TransactionMode::None); + program.epilogue(TransactionMode::None); } #[allow(clippy::too_many_arguments)] @@ -47,7 +48,6 @@ pub fn translate_pragma( approx_num_labels: 0, }; program.extend(&opts); - let mut write = false; if name.name.0.eq_ignore_ascii_case("pragma_list") { list_pragmas(&mut program); @@ -59,22 +59,16 @@ pub fn translate_pragma( Err(_) => bail_parse_error!("Not a valid pragma name"), }; - let mut program = match body { + let (mut program, mode) = match body { None => query_pragma(pragma, schema, None, pager, connection, program)?, Some(ast::PragmaBody::Equals(value) | ast::PragmaBody::Call(value)) => match pragma { PragmaName::TableInfo => { query_pragma(pragma, schema, Some(value), pager, connection, program)? } - _ => { - write = true; - update_pragma(pragma, schema, value, pager, connection, program)? - } + _ => update_pragma(pragma, schema, value, pager, connection, program)?, }, }; - program.epilogue(match write { - false => super::emitter::TransactionMode::Read, - true => super::emitter::TransactionMode::Write, - }); + program.epilogue(mode); Ok(program) } @@ -86,7 +80,7 @@ fn update_pragma( pager: Rc, connection: Arc, mut program: ProgramBuilder, -) -> crate::Result { +) -> crate::Result<(ProgramBuilder, TransactionMode)> { match pragma { PragmaName::CacheSize => { let cache_size = match parse_signed_number(&value)? { @@ -95,7 +89,7 @@ fn update_pragma( _ => bail_parse_error!("Invalid value for cache size pragma"), }; update_cache_size(cache_size, pager, connection)?; - Ok(program) + Ok((program, TransactionMode::None)) } PragmaName::JournalMode => query_pragma( PragmaName::JournalMode, @@ -105,7 +99,7 @@ fn update_pragma( connection, program, ), - PragmaName::LegacyFileFormat => Ok(program), + PragmaName::LegacyFileFormat => Ok((program, TransactionMode::None)), PragmaName::WalCheckpoint => query_pragma( PragmaName::WalCheckpoint, schema, @@ -136,7 +130,7 @@ fn update_pragma( value: version_value, p5: 1, }); - Ok(program) + Ok((program, TransactionMode::Write)) } PragmaName::SchemaVersion => { // TODO: Implement updating schema_version @@ -149,7 +143,13 @@ fn update_pragma( unreachable!(); } PragmaName::PageSize => { - bail_parse_error!("Updating database page size is not supported."); + let page_size = match parse_signed_number(&value)? { + Value::Integer(size) => size, + Value::Float(size) => size as i64, + _ => bail_parse_error!("Invalid value for page size pragma"), + }; + update_page_size(connection, page_size as u32)?; + Ok((program, TransactionMode::None)) } PragmaName::AutoVacuum => { let auto_vacuum_mode = match value { @@ -205,7 +205,7 @@ fn update_pragma( value: auto_vacuum_mode - 1, p5: 0, }); - Ok(program) + Ok((program, TransactionMode::None)) } PragmaName::IntegrityCheck => unreachable!("integrity_check cannot be set"), PragmaName::UnstableCaptureDataChangesConn => { @@ -230,7 +230,7 @@ fn update_pragma( )?; } connection.set_capture_data_changes(opts); - Ok(program) + Ok((program, TransactionMode::Write)) } } } @@ -242,20 +242,22 @@ fn query_pragma( pager: Rc, connection: Arc, mut program: ProgramBuilder, -) -> crate::Result { +) -> crate::Result<(ProgramBuilder, TransactionMode)> { let register = program.alloc_register(); match pragma { PragmaName::CacheSize => { program.emit_int(connection.get_cache_size() as i64, register); program.emit_result_row(register, 1); program.add_pragma_result_column(pragma.to_string()); + Ok((program, TransactionMode::None)) } PragmaName::JournalMode => { program.emit_string8("wal".into(), register); program.emit_result_row(register, 1); program.add_pragma_result_column(pragma.to_string()); + Ok((program, TransactionMode::None)) } - PragmaName::LegacyFileFormat => {} + PragmaName::LegacyFileFormat => Ok((program, TransactionMode::None)), PragmaName::WalCheckpoint => { // Checkpoint uses 3 registers: P1, P2, P3. Ref Insn::Checkpoint for more info. // Allocate two more here as one was allocated at the top. @@ -282,6 +284,7 @@ fn query_pragma( dest: register, }); program.emit_result_row(register, 3); + Ok((program, TransactionMode::None)) } PragmaName::PageCount => { program.emit_insn(Insn::PageCount { @@ -290,6 +293,7 @@ fn query_pragma( }); program.emit_result_row(register, 1); program.add_pragma_result_column(pragma.to_string()); + Ok((program, TransactionMode::Read)) } PragmaName::TableInfo => { let table = match value { @@ -335,6 +339,7 @@ fn query_pragma( for name in col_names { program.add_pragma_result_column(name.into()); } + Ok((program, TransactionMode::None)) } PragmaName::UserVersion => { program.emit_insn(Insn::ReadCookie { @@ -344,6 +349,7 @@ fn query_pragma( }); program.add_pragma_result_column(pragma.to_string()); program.emit_result_row(register, 1); + Ok((program, TransactionMode::Read)) } PragmaName::SchemaVersion => { program.emit_insn(Insn::ReadCookie { @@ -353,15 +359,16 @@ fn query_pragma( }); program.add_pragma_result_column(pragma.to_string()); program.emit_result_row(register, 1); + Ok((program, TransactionMode::Read)) } PragmaName::PageSize => { program.emit_int( - header_accessor::get_page_size(&pager) - .unwrap_or(storage::sqlite3_ondisk::DEFAULT_PAGE_SIZE) as i64, + header_accessor::get_page_size(&pager).unwrap_or(connection.get_page_size()) as i64, register, ); program.emit_result_row(register, 1); program.add_pragma_result_column(pragma.to_string()); + Ok((program, TransactionMode::None)) } PragmaName::AutoVacuum => { let auto_vacuum_mode = pager.get_auto_vacuum_mode(); @@ -378,9 +385,11 @@ fn query_pragma( value: auto_vacuum_mode_i64, }); program.emit_result_row(register, 1); + Ok((program, TransactionMode::None)) } PragmaName::IntegrityCheck => { translate_integrity_check(schema, &mut program)?; + Ok((program, TransactionMode::Read)) } PragmaName::UnstableCaptureDataChangesConn => { let pragma = pragma_for(pragma); @@ -395,10 +404,9 @@ fn query_pragma( program.emit_result_row(register, 2); program.add_pragma_result_column(pragma.columns[0].to_string()); program.add_pragma_result_column(pragma.columns[1].to_string()); + Ok((program, TransactionMode::Read)) } } - - Ok(program) } fn update_auto_vacuum_mode( @@ -528,3 +536,8 @@ fn turso_cdc_table_columns() -> Vec { }, ] } + +fn update_page_size(connection: Arc, page_size: u32) -> crate::Result<()> { + connection.reset_page_size(page_size)?; + Ok(()) +} diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index be9f666cb..d77869386 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -26,6 +26,7 @@ use crate::{ }, printf::exec_printf, }, + IO, }; use std::ops::DerefMut; use std::sync::atomic::AtomicUsize; @@ -55,9 +56,7 @@ use crate::{ vector::{vector32, vector64, vector_distance_cos, vector_distance_l2, vector_extract}, }; -use crate::{ - info, BufferPool, MvCursor, OpenFlags, RefValue, Row, StepResult, TransactionState, IO, -}; +use crate::{info, BufferPool, MvCursor, OpenFlags, RefValue, Row, StepResult, TransactionState}; use super::{ insn::{Cookie, RegisterOrLiteral}, @@ -5936,7 +5935,7 @@ pub fn op_open_ephemeral( OpOpenEphemeralState::Start => { tracing::trace!("Start"); let conn = program.connection.clone(); - let io = conn.pager.io.get_memory_io(); + let io = conn.pager.borrow().io.get_memory_io(); let file = io.open_file("", OpenFlags::Create, true)?; let db_file = Arc::new(FileMemoryStorage::new(file)); diff --git a/testing/pragma.test b/testing/pragma.test index ff42424bf..879c0516e 100755 --- a/testing/pragma.test +++ b/testing/pragma.test @@ -169,3 +169,37 @@ do_execsql_test pragma-function-sql-injection { SELECT * FROM pragma_table_info('sqlite_schema'';CREATE TABLE foo(c0);SELECT ''bar'); SELECT * FROM pragma_table_info('foo'); } {} + +do_execsql_test_on_specific_db ":memory:" pragma-page-size-default { + PRAGMA page_size +} {4096} + +do_execsql_test_on_specific_db ":memory:" pragma-page-size-set { + PRAGMA page_size=1024; + PRAGMA page_size +} {1024} + +# pragma page_size=xxx doesn't change the page size of an initialized database. +do_execsql_test_on_specific_db ":memory:" pragma-page-size-set-initialized-db { + CREATE TABLE "foo bar"(c0); + + PRAGMA page_size=1024; + PRAGMA page_size +} {4096} + +# pragma page_size=xxx changes the page size of an uninitialized database and persists the change. +set test_pragma_page_size_db "testing/testing_pragma_page_size.db" +catch {file delete -force $test_pragma_page_size_db} +catch {file delete -force "${test_pragma_page_size_db}-wal"} +# set user_version to trigger database initialization. +do_execsql_test_on_specific_db $test_pragma_page_size_db pragma-page-size-set-uninitialized-db-1 { + PRAGMA page_size=1024; + PRAGMA user_version=1; + PRAGMA page_size +} {1024} + +do_execsql_test_on_specific_db $test_pragma_page_size_db pragma-page-size-set-uninitialized-db-2 { + PRAGMA page_size +} {1024} +catch {file delete -force $test_pragma_page_size_db} +catch {file delete -force "${test_pragma_page_size_db}-wal"}