Merge 'Safe AtomicUsize wrapper for db_state: add DbState and AtomicDbState' from Levy A.

Makes invalid states impossible, removes magic numbers. Functionally
equivalent.
> [!NOTE]
> ~`DbState` is implemented as a transparent wrapper over `usize` to
avoid undefined behavior with `mem::transmute`~
> Switched to a regular enum, at @Shourya742's suggestion.

Reviewed-by: bit-aloo (@Shourya742)
Reviewed-by: Preston Thorpe (@PThorpe92)

Closes #2174
This commit is contained in:
PThorpe92
2025-07-22 19:21:10 -04:00
5 changed files with 89 additions and 42 deletions

View File

@@ -61,7 +61,6 @@ pub use io::{
};
use parking_lot::RwLock;
use schema::Schema;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::{
borrow::Cow,
cell::{Cell, RefCell, UnsafeCell},
@@ -76,7 +75,7 @@ use std::{
#[cfg(feature = "fs")]
use storage::database::DatabaseFile;
use storage::page_cache::DumbLruPageCache;
use storage::pager::{DB_STATE_INITIALIZED, DB_STATE_INITIALIZING, DB_STATE_UNINITIALIZED};
use storage::pager::{AtomicDbState, DbState};
pub use storage::{
buffer_pool::BufferPool,
database::DatabaseStorage,
@@ -117,7 +116,7 @@ pub struct Database {
// create DB connections.
_shared_page_cache: Arc<RwLock<DumbLruPageCache>>,
maybe_shared_wal: RwLock<Option<Arc<UnsafeCell<WalFileShared>>>>,
db_state: Arc<AtomicUsize>,
db_state: Arc<AtomicDbState>,
init_lock: Arc<Mutex<()>>,
open_flags: OpenFlags,
builtin_syms: RefCell<SymbolTable>,
@@ -134,11 +133,10 @@ impl fmt::Debug for Database {
.field("open_flags", &self.open_flags);
// Database state information
let db_state_value = match self.db_state.load(std::sync::atomic::Ordering::Relaxed) {
DB_STATE_UNINITIALIZED => "uninitialized".to_string(),
DB_STATE_INITIALIZING => "initializing".to_string(),
DB_STATE_INITIALIZED => "initialized".to_string(),
x => format!("invalid ({x})"),
let db_state_value = match self.db_state.get() {
DbState::Uninitialized => "uninitialized".to_string(),
DbState::Initializing => "initializing".to_string(),
DbState::Initialized => "initialized".to_string(),
};
debug_struct.field("db_state", &db_state_value);
@@ -239,9 +237,9 @@ impl Database {
let db_size = db_file.size()?;
let db_state = if db_size == 0 {
DB_STATE_UNINITIALIZED
DbState::Uninitialized
} else {
DB_STATE_INITIALIZED
DbState::Initialized
};
let shared_page_cache = Arc::new(RwLock::new(DumbLruPageCache::default()));
@@ -256,14 +254,14 @@ impl Database {
builtin_syms: syms.into(),
io: io.clone(),
open_flags: flags,
db_state: Arc::new(AtomicUsize::new(db_state)),
db_state: Arc::new(AtomicDbState::new(db_state)),
init_lock: Arc::new(Mutex::new(())),
});
db.register_global_builtin_extensions()
.expect("unable to register global extensions");
// Check: https://github.com/tursodatabase/turso/pull/1761#discussion_r2154013123
if db_state == DB_STATE_INITIALIZED {
if db_state.is_initialized() {
// parse schema
let conn = db.connect()?;
@@ -918,7 +916,7 @@ impl Connection {
}
self.page_size.set(size);
if self._db.db_state.load(Ordering::SeqCst) != DB_STATE_UNINITIALIZED {
if self._db.db_state.get() != DbState::Uninitialized {
return Ok(());
}

View File

@@ -6648,7 +6648,11 @@ mod tests {
use crate::{
io::{Buffer, MemoryIO, OpenFlags, IO},
schema::IndexColumn,
storage::{database::DatabaseFile, page_cache::DumbLruPageCache},
storage::{
database::DatabaseFile,
page_cache::DumbLruPageCache,
pager::{AtomicDbState, DbState},
},
types::Text,
util::IOExt as _,
vdbe::Register,
@@ -6660,7 +6664,7 @@ mod tests {
mem::transmute,
ops::Deref,
rc::Rc,
sync::{atomic::AtomicUsize, Arc, Mutex},
sync::{Arc, Mutex},
};
use tempfile::TempDir;
@@ -7531,7 +7535,7 @@ mod tests {
io,
Arc::new(parking_lot::RwLock::new(DumbLruPageCache::new(10))),
buffer_pool,
Arc::new(AtomicUsize::new(0)),
Arc::new(AtomicDbState::new(DbState::Uninitialized)),
Arc::new(Mutex::new(())),
)
.unwrap(),

View File

@@ -8,7 +8,6 @@ use crate::{
types::IOResult,
LimboError, Result,
};
use std::sync::atomic::Ordering;
// const HEADER_OFFSET_MAGIC: usize = 0;
const HEADER_OFFSET_PAGE_SIZE: usize = 16;
@@ -36,7 +35,7 @@ const HEADER_OFFSET_VERSION_NUMBER: usize = 96;
// Helper to get a read-only reference to the header page.
fn get_header_page(pager: &Pager) -> Result<IOResult<PageRef>> {
if pager.db_state.load(Ordering::SeqCst) < 2 {
if !pager.db_state.is_initialized() {
return Err(LimboError::InternalError(
"Database is empty, header does not exist - page 1 should've been allocated before this".to_string(),
));
@@ -50,7 +49,7 @@ fn get_header_page(pager: &Pager) -> Result<IOResult<PageRef>> {
// Helper to get a writable reference to the header page and mark it dirty.
fn get_header_page_for_write(pager: &Pager) -> Result<IOResult<PageRef>> {
if pager.db_state.load(Ordering::SeqCst) < 2 {
if !pager.db_state.is_initialized() {
// This should not be called on an empty DB for writing, as page 1 is allocated on first transaction.
return Err(LimboError::InternalError(
"Cannot write to header of an empty database - page 1 should've been allocated before this".to_string(),
@@ -104,7 +103,7 @@ macro_rules! impl_header_field_accessor {
// Async version
#[allow(dead_code)]
pub fn [<get_ $field_name _async>](pager: &Pager) -> Result<IOResult<$type>> {
if pager.db_state.load(Ordering::SeqCst) < 2 {
if !pager.db_state.is_initialized() {
return Err(LimboError::InternalError(format!("Database is empty, header does not exist - page 1 should've been allocated before this")));
}
let page = match get_header_page(pager)? {

View File

@@ -245,9 +245,57 @@ pub enum AutoVacuumMode {
Incremental,
}
pub const DB_STATE_UNINITIALIZED: usize = 0;
pub const DB_STATE_INITIALIZING: usize = 1;
pub const DB_STATE_INITIALIZED: usize = 2;
/// Initialization state of the database.
///
/// The discriminants are pinned to `0`, `1` and `2` so that a state can be
/// stored as a plain `usize` and mapped back to a variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(usize)]
pub enum DbState {
    /// The database is empty.
    Uninitialized = 0,
    /// The database is being initialized.
    Initializing = 1,
    /// The database is initialized and ready for use.
    Initialized = 2,
}

impl DbState {
    // Raw `usize` encodings of the variants. They are derived from the enum
    // itself so the constants and the discriminants cannot drift apart.
    const UNINITIALIZED: usize = DbState::Uninitialized as usize;
    const INITIALIZING: usize = DbState::Initializing as usize;
    const INITIALIZED: usize = DbState::Initialized as usize;

    /// Returns `true` only when the state is [`DbState::Initialized`].
    #[inline]
    pub fn is_initialized(&self) -> bool {
        *self == DbState::Initialized
    }
}
/// Atomic cell holding a [`DbState`].
///
/// The state is stored as its `usize` discriminant inside an `AtomicUsize`.
/// Every load and store performed by this type uses `Ordering::SeqCst`.
#[derive(Debug)]
#[repr(transparent)]
pub struct AtomicDbState(AtomicUsize);

impl AtomicDbState {
    /// Creates a new cell initialized to `state`.
    #[inline]
    pub const fn new(state: DbState) -> Self {
        Self(AtomicUsize::new(state as usize))
    }

    /// Stores `state` into the cell.
    #[inline]
    pub fn set(&self, state: DbState) {
        self.0.store(state as usize, Ordering::SeqCst);
    }

    /// Loads the current state.
    ///
    /// # Panics
    ///
    /// Panics if the stored value is not the discriminant of a [`DbState`]
    /// variant. `new` and `set` only ever write valid discriminants, so
    /// hitting this indicates a bug; the offending value is included in the
    /// panic message to make such a bug diagnosable.
    #[inline]
    pub fn get(&self) -> DbState {
        match self.0.load(Ordering::SeqCst) {
            DbState::UNINITIALIZED => DbState::Uninitialized,
            DbState::INITIALIZING => DbState::Initializing,
            DbState::INITIALIZED => DbState::Initialized,
            raw => unreachable!("invalid DbState discriminant: {}", raw),
        }
    }

    /// Returns `true` if the current state is [`DbState::Initialized`].
    #[inline]
    pub fn is_initialized(&self) -> bool {
        self.get().is_initialized()
    }
}
/// The pager interface implements the persistence layer by providing access
/// to pages of the database file, including caching, concurrency control, and
/// transaction management.
@@ -273,7 +321,7 @@ pub struct Pager {
/// `DbState::Uninitialized` (0) -> Database is empty,
/// `DbState::Initializing` (1) -> Database is being initialized,
/// `DbState::Initialized` (2) -> Database is initialized and ready for use.
pub db_state: Arc<AtomicUsize>,
pub db_state: Arc<AtomicDbState>,
/// Mutex for synchronizing database initialization to prevent race conditions
init_lock: Arc<Mutex<()>>,
allocate_page1_state: RefCell<AllocatePage1State>,
@@ -325,10 +373,10 @@ impl Pager {
io: Arc<dyn crate::io::IO>,
page_cache: Arc<RwLock<DumbLruPageCache>>,
buffer_pool: Arc<BufferPool>,
db_state: Arc<AtomicUsize>,
db_state: Arc<AtomicDbState>,
init_lock: Arc<Mutex<()>>,
) -> Result<Self> {
let allocate_page1_state = if db_state.load(Ordering::SeqCst) < DB_STATE_INITIALIZED {
let allocate_page1_state = if !db_state.is_initialized() {
RefCell::new(AllocatePage1State::Start)
} else {
RefCell::new(AllocatePage1State::Done)
@@ -665,7 +713,7 @@ impl Pager {
/// Set the initial page size for the database. Should only be called before the database is initialized
pub fn set_initial_page_size(&self, size: u32) {
assert_eq!(self.db_state.load(Ordering::SeqCst), DB_STATE_UNINITIALIZED);
assert_eq!(self.db_state.get(), DbState::Uninitialized);
self.page_size.replace(Some(size));
}
@@ -682,17 +730,16 @@ impl Pager {
#[instrument(skip_all, level = Level::DEBUG)]
pub fn maybe_allocate_page1(&self) -> Result<IOResult<()>> {
if self.db_state.load(Ordering::SeqCst) < DB_STATE_INITIALIZED {
if !self.db_state.is_initialized() {
if let Ok(_lock) = self.init_lock.try_lock() {
match (
self.db_state.load(Ordering::SeqCst),
self.allocating_page1(),
) {
match (self.db_state.get(), self.allocating_page1()) {
// In case of being empty or (allocating and this connection is performing allocation) then allocate the first page
(0, false) | (1, true) => match self.allocate_page1()? {
IOResult::Done(_) => Ok(IOResult::Done(())),
IOResult::IO => Ok(IOResult::IO),
},
(DbState::Uninitialized, false) | (DbState::Initializing, true) => {
match self.allocate_page1()? {
IOResult::Done(_) => Ok(IOResult::Done(())),
IOResult::IO => Ok(IOResult::IO),
}
}
_ => Ok(IOResult::IO),
}
} else {
@@ -1183,7 +1230,7 @@ impl Pager {
match state {
AllocatePage1State::Start => {
tracing::trace!("allocate_page1(Start)");
self.db_state.store(DB_STATE_INITIALIZING, Ordering::SeqCst);
self.db_state.set(DbState::Initializing);
let mut default_header = DatabaseHeader::default();
default_header.database_size += 1;
if let Some(size) = self.page_size.get() {
@@ -1232,7 +1279,7 @@ impl Pager {
cache.insert(page_key, page1_ref.clone()).map_err(|e| {
LimboError::InternalError(format!("Failed to insert page 1 into cache: {e:?}"))
})?;
self.db_state.store(DB_STATE_INITIALIZED, Ordering::SeqCst);
self.db_state.set(DbState::Initialized);
self.allocate_page1_state.replace(AllocatePage1State::Done);
Ok(IOResult::Done(page1_ref.clone()))
}
@@ -1671,7 +1718,7 @@ mod ptrmap_tests {
io,
page_cache,
buffer_pool,
Arc::new(AtomicUsize::new(0)),
Arc::new(AtomicDbState::new(DbState::Uninitialized)),
Arc::new(Mutex::new(())),
)
.unwrap();

View File

@@ -4,7 +4,7 @@ use crate::numeric::{NullableInteger, Numeric};
use crate::storage::btree::{integrity_check, IntegrityCheckError, IntegrityCheckState};
use crate::storage::database::FileMemoryStorage;
use crate::storage::page_cache::DumbLruPageCache;
use crate::storage::pager::CreateBTreeFlags;
use crate::storage::pager::{AtomicDbState, CreateBTreeFlags, DbState};
use crate::storage::sqlite3_ondisk::read_varint;
use crate::storage::wal::DummyWAL;
use crate::storage::{self, header_accessor};
@@ -30,7 +30,6 @@ use crate::{
IO,
};
use std::ops::DerefMut;
use std::sync::atomic::AtomicUsize;
use std::{
borrow::BorrowMut,
rc::Rc,
@@ -6112,7 +6111,7 @@ pub fn op_open_ephemeral(
io,
page_cache,
buffer_pool.clone(),
Arc::new(AtomicUsize::new(0)),
Arc::new(AtomicDbState::new(DbState::Uninitialized)),
Arc::new(Mutex::new(())),
)?);