diff --git a/core/incremental/compiler.rs b/core/incremental/compiler.rs index f375d2f95..529cc5b23 100644 --- a/core/incremental/compiler.rs +++ b/core/incremental/compiler.rs @@ -1506,7 +1506,7 @@ impl DbspCompiler { let io = Arc::new(MemoryIO::new()); let db = Database::open_file(io, ":memory:", false, false)?; let internal_conn = db.connect()?; - internal_conn.query_only.set(true); + internal_conn.set_query_only(true); internal_conn.auto_commit.store(false, Ordering::SeqCst); // Create temporary symbol table diff --git a/core/incremental/project_operator.rs b/core/incremental/project_operator.rs index 435f6d90a..b82a1a138 100644 --- a/core/incremental/project_operator.rs +++ b/core/incremental/project_operator.rs @@ -62,7 +62,7 @@ impl ProjectOperator { )?; let internal_conn = db.connect()?; // Set to read-only mode and disable auto-commit since we're only evaluating expressions - internal_conn.query_only.set(true); + internal_conn.set_query_only(true); internal_conn.auto_commit.store(false, Ordering::SeqCst); // Create ProjectColumn structs from compiled expressions diff --git a/core/lib.rs b/core/lib.rs index ab894bd4a..b99ea4036 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -511,17 +511,17 @@ impl Database { page_size: AtomicU16::new(page_size.get_raw()), wal_auto_checkpoint_disabled: AtomicBool::new(false), capture_data_changes: RwLock::new(CaptureDataChangesMode::Off), - closed: Cell::new(false), + closed: AtomicBool::new(false), attached_databases: RefCell::new(DatabaseCatalog::new()), - query_only: Cell::new(false), + query_only: AtomicBool::new(false), mv_tx: Cell::new(None), view_transaction_states: AllViewsTxState::new(), metrics: RefCell::new(ConnectionMetrics::new()), - is_nested_stmt: Cell::new(false), + is_nested_stmt: AtomicBool::new(false), encryption_key: RefCell::new(None), encryption_cipher_mode: Cell::new(None), sync_mode: Cell::new(SyncMode::Full), - data_sync_retry: Cell::new(false), + data_sync_retry: AtomicBool::new(false), busy_timeout: Cell::new(Duration::new(0, 0)), }); self.n_connections @@ -1000,10 +1000,10 @@ pub struct Connection { /// Client still can manually execute PRAGMA wal_checkpoint(...) commands wal_auto_checkpoint_disabled: AtomicBool, capture_data_changes: RwLock, - closed: Cell, + closed: AtomicBool, /// Attached databases attached_databases: RefCell, - query_only: Cell, + query_only: AtomicBool, pub(crate) mv_tx: Cell>, /// Per-connection view transaction states for uncommitted changes. This represents @@ -1013,11 +1013,11 @@ pub struct Connection { pub metrics: RefCell, /// Whether the connection is executing a statement initiated by another statement. /// Generally this is only true for ParseSchema. - is_nested_stmt: Cell, + is_nested_stmt: AtomicBool, encryption_key: RefCell>, encryption_cipher_mode: Cell>, sync_mode: Cell, - data_sync_retry: Cell, + data_sync_retry: AtomicBool, /// User defined max accumulated Busy timeout duration /// Default is 0 (no timeout) busy_timeout: Cell, @@ -1025,7 +1025,7 @@ pub struct Connection { impl Drop for Connection { fn drop(&mut self) { - if !self.closed.get() { + if !self.is_closed() { // if connection wasn't properly closed, decrement the connection counter self.db .n_connections @@ -1037,7 +1037,7 @@ impl Drop for Connection { impl Connection { #[instrument(skip_all, level = Level::INFO)] pub fn prepare(self: &Arc, sql: impl AsRef) -> Result { - if self.closed.get() { + if self.is_closed() { return Err(LimboError::InternalError("Connection closed".to_string())); } if sql.as_ref().is_empty() { @@ -1197,7 +1197,7 @@ impl Connection { #[instrument(skip_all, level = Level::INFO)] pub fn prepare_execute_batch(self: &Arc, sql: impl AsRef) -> Result<()> { - if self.closed.get() { + if self.is_closed() { return Err(LimboError::InternalError("Connection closed".to_string())); } if sql.as_ref().is_empty() { @@ -1235,7 +1235,7 @@ impl Connection { #[instrument(skip_all, level = Level::INFO)] pub fn query(self: &Arc, sql: impl AsRef) -> Result> { - if self.closed.get() { + if self.is_closed() { return Err(LimboError::InternalError("Connection closed".to_string())); } let sql = sql.as_ref(); @@ -1259,7 +1259,7 @@ impl Connection { cmd: Cmd, input: &str, ) -> Result> { - if self.closed.get() { + if self.is_closed() { return Err(LimboError::InternalError("Connection closed".to_string())); } let syms = self.syms.read(); @@ -1287,7 +1287,7 @@ impl Connection { /// TODO: make this api async #[instrument(skip_all, level = Level::INFO)] pub fn execute(self: &Arc, sql: impl AsRef) -> Result<()> { - if self.closed.get() { + if self.is_closed() { return Err(LimboError::InternalError("Connection closed".to_string())); } let sql = sql.as_ref(); @@ -1603,7 +1603,7 @@ impl Connection { /// Flush dirty pages to disk. pub fn cacheflush(&self) -> Result> { - if self.closed.get() { + if self.is_closed() { return Err(LimboError::InternalError("Connection closed".to_string())); } self.pager.read().cacheflush() @@ -1615,7 +1615,7 @@ impl Connection { } pub fn checkpoint(&self, mode: CheckpointMode) -> Result { - if self.closed.get() { + if self.is_closed() { return Err(LimboError::InternalError("Connection closed".to_string())); } self.pager.read().wal_checkpoint(mode) @@ -1623,10 +1623,10 @@ impl Connection { /// Close a connection and checkpoint. pub fn close(&self) -> Result<()> { - if self.closed.get() { + if self.is_closed() { return Ok(()); } - self.closed.set(true); + self.closed.store(true, Ordering::SeqCst); match self.get_tx_state() { TransactionState::None => { @@ -1707,6 +1707,14 @@ impl Connection { PageSize::new_from_header_u16(value).unwrap_or_default() } + pub fn is_closed(&self) -> bool { + self.closed.load(Ordering::SeqCst) + } + + pub fn is_query_only(&self) -> bool { + self.query_only.load(Ordering::SeqCst) + } + pub fn get_database_canonical_path(&self) -> String { if self.db.path == ":memory:" { // For in-memory databases, SQLite shows empty string @@ -1791,7 +1799,7 @@ impl Connection { } pub fn parse_schema_rows(self: &Arc) -> Result<()> { - if self.closed.get() { + if self.is_closed() { return Err(LimboError::InternalError("Connection closed".to_string())); } let rows = self @@ -1819,7 +1827,7 @@ impl Connection { // Clearly there is something to improve here, Vec> isn't a couple of tea /// Query the current rows/values of `pragma_name`. pub fn pragma_query(self: &Arc, pragma_name: &str) -> Result>> { - if self.closed.get() { + if self.is_closed() { return Err(LimboError::InternalError("Connection closed".to_string())); } let pragma = format!("PRAGMA {pragma_name}"); @@ -1836,7 +1844,7 @@ impl Connection { pragma_name: &str, pragma_value: V, ) -> Result>> { - if self.closed.get() { + if self.is_closed() { return Err(LimboError::InternalError("Connection closed".to_string())); } let pragma = format!("PRAGMA {pragma_name} = {pragma_value}"); @@ -1863,7 +1871,7 @@ impl Connection { pragma_name: &str, pragma_value: V, ) -> Result>> { - if self.closed.get() { + if self.is_closed() { return Err(LimboError::InternalError("Connection closed".to_string())); } let pragma = format!("PRAGMA {pragma_name}({pragma_value})"); @@ -1923,7 +1931,7 @@ impl Connection { /// Attach a database file with the given alias name #[cfg(feature = "fs")] pub(crate) fn attach_database(&self, path: &str, alias: &str) -> Result<()> { - if self.closed.get() { + if self.is_closed() { return Err(LimboError::InternalError("Connection closed".to_string())); } @@ -1967,7 +1975,7 @@ impl Connection { // Detach a database by alias name fn detach_database(&self, alias: &str) -> Result<()> { - if self.closed.get() { + if self.is_closed() { return Err(LimboError::InternalError("Connection closed".to_string())); } @@ -2116,11 +2124,11 @@ impl Connection { } pub fn get_query_only(&self) -> bool { - self.query_only.get() + self.is_query_only() } pub fn set_query_only(&self, value: bool) { - self.query_only.set(value); + self.query_only.store(value, Ordering::SeqCst); } pub fn get_sync_mode(&self) -> SyncMode { @@ -2132,11 +2140,13 @@ impl Connection { } pub fn get_data_sync_retry(&self) -> bool { - self.data_sync_retry.get() + self.data_sync_retry + .load(std::sync::atomic::Ordering::SeqCst) } pub fn set_data_sync_retry(&self, value: bool) { - self.data_sync_retry.set(value); + self.data_sync_retry + .store(value, std::sync::atomic::Ordering::SeqCst); } /// Creates a HashSet of modules that have been loaded @@ -2481,7 +2491,12 @@ impl Statement { pub fn run_once(&self) -> Result<()> { let res = self.pager.io.step(); - if self.program.connection.is_nested_stmt.get() { + if self + .program + .connection + .is_nested_stmt + .load(Ordering::SeqCst) + { return res; } if res.is_err() { diff --git a/core/storage/pager.rs b/core/storage/pager.rs index 59578a95d..c4c4df620 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -1122,7 +1122,7 @@ impl Pager { rollback: bool, connection: &Connection, ) -> Result> { - if connection.is_nested_stmt.get() { + if connection.is_nested_stmt.load(Ordering::SeqCst) { // Parent statement will handle the transaction rollback. return Ok(IOResult::Done(PagerCommitResult::Rollback)); } diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index 5068a819e..176715ace 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -2207,7 +2207,8 @@ pub fn op_transaction_inner( // 1. We try to upgrade current version let current_state = conn.get_tx_state(); - let (new_transaction_state, updated) = if conn.is_nested_stmt.get() { + let (new_transaction_state, updated) = if conn.is_nested_stmt.load(Ordering::SeqCst) + { (current_state, false) } else { match (current_state, write) { @@ -2295,7 +2296,7 @@ pub fn op_transaction_inner( } if updated && matches!(current_state, TransactionState::None) { turso_assert!( - !conn.is_nested_stmt.get(), + !conn.is_nested_stmt.load(Ordering::SeqCst), "nested stmt should not begin a new read transaction" ); pager.begin_read_tx()?; @@ -2303,7 +2304,7 @@ pub fn op_transaction_inner( if updated && matches!(new_transaction_state, TransactionState::Write { .. }) { turso_assert!( - !conn.is_nested_stmt.get(), + !conn.is_nested_stmt.load(Ordering::SeqCst), "nested stmt should not begin a new write transaction" ); let begin_w_tx_res = pager.begin_write_tx(); @@ -6837,7 +6838,7 @@ pub fn op_parse_schema( conn.with_schema_mut(|schema| { // TODO: This function below is synchronous, make it async let existing_views = schema.incremental_views.clone(); - conn.is_nested_stmt.set(true); + conn.is_nested_stmt.store(true, Ordering::SeqCst); parse_schema_rows( stmt, schema, @@ -6853,7 +6854,7 @@ pub fn op_parse_schema( conn.with_schema_mut(|schema| { // TODO: This function below is synchronous, make it async let existing_views = schema.incremental_views.clone(); - conn.is_nested_stmt.set(true); + conn.is_nested_stmt.store(true, Ordering::SeqCst); parse_schema_rows( stmt, schema, @@ -6864,7 +6865,7 @@ pub fn op_parse_schema( ) }) }; - conn.is_nested_stmt.set(false); + conn.is_nested_stmt.store(false, Ordering::SeqCst); conn.auto_commit .store(previous_auto_commit, Ordering::SeqCst); maybe_nested_stmt_err?; diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index f80bbd6ab..f70028f47 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -529,7 +529,7 @@ impl Program { pager: Arc, ) -> Result { debug_assert!(state.column_count() == EXPLAIN_COLUMNS.len()); - if self.connection.closed.get() { + if self.connection.is_closed() { // Connection is closed for whatever reason, rollback the transaction. let state = self.connection.get_tx_state(); if let TransactionState::Write { .. } = state { @@ -584,7 +584,7 @@ impl Program { ) -> Result { debug_assert!(state.column_count() == EXPLAIN_QUERY_PLAN_COLUMNS.len()); loop { - if self.connection.closed.get() { + if self.connection.is_closed() { // Connection is closed for whatever reason, rollback the transaction. let state = self.connection.get_tx_state(); if let TransactionState::Write { .. } = state { @@ -632,7 +632,7 @@ impl Program { ) -> Result { let enable_tracing = tracing::enabled!(tracing::Level::TRACE); loop { - if self.connection.closed.get() { + if self.connection.is_closed() { // Connection is closed for whatever reason, rollback the transaction. let state = self.connection.get_tx_state(); if let TransactionState::Write { .. } = state { @@ -825,7 +825,7 @@ impl Program { return Ok(IOResult::Done(())); } if let Some(mv_store) = mv_store { - if self.connection.is_nested_stmt.get() { + if self.connection.is_nested_stmt.load(Ordering::SeqCst) { // We don't want to commit on nested statements. Let parent handle it. return Ok(IOResult::Done(())); } @@ -1069,7 +1069,7 @@ pub fn handle_program_error( err: &LimboError, mv_store: Option<&Arc>, ) -> Result<()> { - if connection.is_nested_stmt.get() { + if connection.is_nested_stmt.load(Ordering::SeqCst) { // Errors from nested statements are handled by the parent statement. return Ok(()); }