diff --git a/core/lib.rs b/core/lib.rs index 87daf3dc6..1ec36633b 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -61,7 +61,6 @@ pub use io::{ }; use parking_lot::RwLock; use schema::Schema; -use std::sync::atomic::{AtomicUsize, Ordering}; use std::{ borrow::Cow, cell::{Cell, RefCell, UnsafeCell}, @@ -76,7 +75,7 @@ use std::{ #[cfg(feature = "fs")] use storage::database::DatabaseFile; use storage::page_cache::DumbLruPageCache; -use storage::pager::{DB_STATE_INITIALIZED, DB_STATE_INITIALIZING, DB_STATE_UNINITIALIZED}; +use storage::pager::{AtomicDbState, DbState}; pub use storage::{ buffer_pool::BufferPool, database::DatabaseStorage, @@ -117,7 +116,7 @@ pub struct Database { // create DB connections. _shared_page_cache: Arc>, maybe_shared_wal: RwLock>>>, - db_state: Arc, + db_state: Arc, init_lock: Arc>, open_flags: OpenFlags, builtin_syms: RefCell, @@ -134,11 +133,10 @@ impl fmt::Debug for Database { .field("open_flags", &self.open_flags); // Database state information - let db_state_value = match self.db_state.load(std::sync::atomic::Ordering::Relaxed) { - DB_STATE_UNINITIALIZED => "uninitialized".to_string(), - DB_STATE_INITIALIZING => "initializing".to_string(), - DB_STATE_INITIALIZED => "initialized".to_string(), - x => format!("invalid ({x})"), + let db_state_value = match self.db_state.get() { + DbState::Uninitialized => "uninitialized".to_string(), + DbState::Initializing => "initializing".to_string(), + DbState::Initialized => "initialized".to_string(), }; debug_struct.field("db_state", &db_state_value); @@ -239,9 +237,9 @@ impl Database { let db_size = db_file.size()?; let db_state = if db_size == 0 { - DB_STATE_UNINITIALIZED + DbState::Uninitialized } else { - DB_STATE_INITIALIZED + DbState::Initialized }; let shared_page_cache = Arc::new(RwLock::new(DumbLruPageCache::default())); @@ -256,14 +254,14 @@ impl Database { builtin_syms: syms.into(), io: io.clone(), open_flags: flags, - db_state: Arc::new(AtomicUsize::new(db_state)), + db_state: Arc::new(AtomicDbState::new(db_state)), init_lock: Arc::new(Mutex::new(())), }); db.register_global_builtin_extensions() .expect("unable to register global extensions"); // Check: https://github.com/tursodatabase/turso/pull/1761#discussion_r2154013123 - if db_state == DB_STATE_INITIALIZED { + if db_state.is_initialized() { // parse schema let conn = db.connect()?; @@ -918,7 +916,7 @@ impl Connection { } self.page_size.set(size); - if self._db.db_state.load(Ordering::SeqCst) != DB_STATE_UNINITIALIZED { + if self._db.db_state.get() != DbState::Uninitialized { return Ok(()); } diff --git a/core/storage/btree.rs b/core/storage/btree.rs index 7b4008263..3139ac492 100644 --- a/core/storage/btree.rs +++ b/core/storage/btree.rs @@ -6648,7 +6648,11 @@ mod tests { use crate::{ io::{Buffer, MemoryIO, OpenFlags, IO}, schema::IndexColumn, - storage::{database::DatabaseFile, page_cache::DumbLruPageCache}, + storage::{ + database::DatabaseFile, + page_cache::DumbLruPageCache, + pager::{AtomicDbState, DbState}, + }, types::Text, util::IOExt as _, vdbe::Register, @@ -6660,7 +6664,7 @@ mod tests { mem::transmute, ops::Deref, rc::Rc, - sync::{atomic::AtomicUsize, Arc, Mutex}, + sync::{Arc, Mutex}, }; use tempfile::TempDir; @@ -7531,7 +7535,7 @@ mod tests { io, Arc::new(parking_lot::RwLock::new(DumbLruPageCache::new(10))), buffer_pool, - Arc::new(AtomicUsize::new(0)), + Arc::new(AtomicDbState::new(DbState::Uninitialized)), Arc::new(Mutex::new(())), ) .unwrap(), diff --git a/core/storage/header_accessor.rs b/core/storage/header_accessor.rs index b9ac8902b..33edbfe37 100644 --- a/core/storage/header_accessor.rs +++ b/core/storage/header_accessor.rs @@ -8,7 +8,6 @@ use crate::{ types::IOResult, LimboError, Result, }; -use std::sync::atomic::Ordering; // const HEADER_OFFSET_MAGIC: usize = 0; const HEADER_OFFSET_PAGE_SIZE: usize = 16; @@ -36,7 +35,7 @@ const HEADER_OFFSET_VERSION_NUMBER: usize = 96; // Helper to get a read-only reference to the header page. fn get_header_page(pager: &Pager) -> Result> { - if pager.db_state.load(Ordering::SeqCst) < 2 { + if !pager.db_state.is_initialized() { return Err(LimboError::InternalError( "Database is empty, header does not exist - page 1 should've been allocated before this".to_string(), )); @@ -50,7 +49,7 @@ fn get_header_page(pager: &Pager) -> Result> { // Helper to get a writable reference to the header page and mark it dirty. fn get_header_page_for_write(pager: &Pager) -> Result> { - if pager.db_state.load(Ordering::SeqCst) < 2 { + if !pager.db_state.is_initialized() { // This should not be called on an empty DB for writing, as page 1 is allocated on first transaction. return Err(LimboError::InternalError( "Cannot write to header of an empty database - page 1 should've been allocated before this".to_string(), @@ -104,7 +103,7 @@ macro_rules! impl_header_field_accessor { // Async version #[allow(dead_code)] pub fn [](pager: &Pager) -> Result> { - if pager.db_state.load(Ordering::SeqCst) < 2 { + if !pager.db_state.is_initialized() { return Err(LimboError::InternalError(format!("Database is empty, header does not exist - page 1 should've been allocated before this"))); } let page = match get_header_page(pager)? { diff --git a/core/storage/pager.rs b/core/storage/pager.rs index 9a760f2ff..95f7efb98 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -245,9 +245,57 @@ pub enum AutoVacuumMode { Incremental, } -pub const DB_STATE_UNINITIALIZED: usize = 0; -pub const DB_STATE_INITIALIZING: usize = 1; -pub const DB_STATE_INITIALIZED: usize = 2; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(usize)] +pub enum DbState { + Uninitialized = Self::UNINITIALIZED, + Initializing = Self::INITIALIZING, + Initialized = Self::INITIALIZED, +} + +impl DbState { + pub(self) const UNINITIALIZED: usize = 0; + pub(self) const INITIALIZING: usize = 1; + pub(self) const INITIALIZED: usize = 2; + + #[inline] + pub fn is_initialized(&self) -> bool { + matches!(self, DbState::Initialized) + } +} + +#[derive(Debug)] +#[repr(transparent)] +pub struct AtomicDbState(AtomicUsize); + +impl AtomicDbState { + #[inline] + pub const fn new(state: DbState) -> Self { + Self(AtomicUsize::new(state as usize)) + } + + #[inline] + pub fn set(&self, state: DbState) { + self.0.store(state as usize, Ordering::SeqCst); + } + + #[inline] + pub fn get(&self) -> DbState { + let v = self.0.load(Ordering::SeqCst); + match v { + DbState::UNINITIALIZED => DbState::Uninitialized, + DbState::INITIALIZING => DbState::Initializing, + DbState::INITIALIZED => DbState::Initialized, + _ => unreachable!(), + } + } + + #[inline] + pub fn is_initialized(&self) -> bool { + self.get().is_initialized() + } +} + /// The pager interface implements the persistence layer by providing access /// to pages of the database file, including caching, concurrency control, and /// transaction management. @@ -273,7 +321,7 @@ pub struct Pager { /// 0 -> Database is empty, /// 1 -> Database is being initialized, /// 2 -> Database is initialized and ready for use. - pub db_state: Arc, + pub db_state: Arc, /// Mutex for synchronizing database initialization to prevent race conditions init_lock: Arc>, allocate_page1_state: RefCell, @@ -325,10 +373,10 @@ impl Pager { io: Arc, page_cache: Arc>, buffer_pool: Arc, - db_state: Arc, + db_state: Arc, init_lock: Arc>, ) -> Result { - let allocate_page1_state = if db_state.load(Ordering::SeqCst) < DB_STATE_INITIALIZED { + let allocate_page1_state = if !db_state.is_initialized() { RefCell::new(AllocatePage1State::Start) } else { RefCell::new(AllocatePage1State::Done) @@ -665,7 +713,7 @@ impl Pager { /// Set the initial page size for the database. Should only be called before the database is initialized pub fn set_initial_page_size(&self, size: u32) { - assert_eq!(self.db_state.load(Ordering::SeqCst), DB_STATE_UNINITIALIZED); + assert_eq!(self.db_state.get(), DbState::Uninitialized); self.page_size.replace(Some(size)); } @@ -682,17 +730,16 @@ impl Pager { #[instrument(skip_all, level = Level::DEBUG)] pub fn maybe_allocate_page1(&self) -> Result> { - if self.db_state.load(Ordering::SeqCst) < DB_STATE_INITIALIZED { + if !self.db_state.is_initialized() { if let Ok(_lock) = self.init_lock.try_lock() { - match ( - self.db_state.load(Ordering::SeqCst), - self.allocating_page1(), - ) { + match (self.db_state.get(), self.allocating_page1()) { // In case of being empty or (allocating and this connection is performing allocation) then allocate the first page - (0, false) | (1, true) => match self.allocate_page1()? { - IOResult::Done(_) => Ok(IOResult::Done(())), - IOResult::IO => Ok(IOResult::IO), - }, + (DbState::Uninitialized, false) | (DbState::Initializing, true) => { + match self.allocate_page1()? { + IOResult::Done(_) => Ok(IOResult::Done(())), + IOResult::IO => Ok(IOResult::IO), + } + } _ => Ok(IOResult::IO), } } else { @@ -1183,7 +1230,7 @@ impl Pager { match state { AllocatePage1State::Start => { tracing::trace!("allocate_page1(Start)"); - self.db_state.store(DB_STATE_INITIALIZING, Ordering::SeqCst); + self.db_state.set(DbState::Initializing); let mut default_header = DatabaseHeader::default(); default_header.database_size += 1; if let Some(size) = self.page_size.get() { @@ -1232,7 +1279,7 @@ impl Pager { cache.insert(page_key, page1_ref.clone()).map_err(|e| { LimboError::InternalError(format!("Failed to insert page 1 into cache: {e:?}")) })?; - self.db_state.store(DB_STATE_INITIALIZED, Ordering::SeqCst); + self.db_state.set(DbState::Initialized); self.allocate_page1_state.replace(AllocatePage1State::Done); Ok(IOResult::Done(page1_ref.clone())) } @@ -1671,7 +1718,7 @@ mod ptrmap_tests { io, page_cache, buffer_pool, - Arc::new(AtomicUsize::new(0)), + Arc::new(AtomicDbState::new(DbState::Uninitialized)), Arc::new(Mutex::new(())), ) .unwrap(); diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index af396384b..ae0d341cd 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -4,7 +4,7 @@ use crate::numeric::{NullableInteger, Numeric}; use crate::storage::btree::{integrity_check, IntegrityCheckError, IntegrityCheckState}; use crate::storage::database::FileMemoryStorage; use crate::storage::page_cache::DumbLruPageCache; -use crate::storage::pager::CreateBTreeFlags; +use crate::storage::pager::{AtomicDbState, CreateBTreeFlags, DbState}; use crate::storage::sqlite3_ondisk::read_varint; use crate::storage::wal::DummyWAL; use crate::storage::{self, header_accessor}; @@ -30,7 +30,6 @@ use crate::{ IO, }; use std::ops::DerefMut; -use std::sync::atomic::AtomicUsize; use std::{ borrow::BorrowMut, rc::Rc, @@ -6112,7 +6111,7 @@ pub fn op_open_ephemeral( io, page_cache, buffer_pool.clone(), - Arc::new(AtomicUsize::new(0)), + Arc::new(AtomicDbState::new(DbState::Uninitialized)), Arc::new(Mutex::new(())), )?);