Merge 'Wrap more Connection fields with atomics' from Pekka Enberg

Closes #3307
This commit is contained in:
Pekka Enberg
2025-09-24 20:16:42 +03:00
committed by GitHub
6 changed files with 59 additions and 43 deletions

View File

@@ -1506,7 +1506,7 @@ impl DbspCompiler {
let io = Arc::new(MemoryIO::new());
let db = Database::open_file(io, ":memory:", false, false)?;
let internal_conn = db.connect()?;
internal_conn.query_only.set(true);
internal_conn.set_query_only(true);
internal_conn.auto_commit.store(false, Ordering::SeqCst);
// Create temporary symbol table

View File

@@ -62,7 +62,7 @@ impl ProjectOperator {
)?;
let internal_conn = db.connect()?;
// Set to read-only mode and disable auto-commit since we're only evaluating expressions
internal_conn.query_only.set(true);
internal_conn.set_query_only(true);
internal_conn.auto_commit.store(false, Ordering::SeqCst);
// Create ProjectColumn structs from compiled expressions

View File

@@ -511,17 +511,17 @@ impl Database {
page_size: AtomicU16::new(page_size.get_raw()),
wal_auto_checkpoint_disabled: AtomicBool::new(false),
capture_data_changes: RwLock::new(CaptureDataChangesMode::Off),
closed: Cell::new(false),
closed: AtomicBool::new(false),
attached_databases: RefCell::new(DatabaseCatalog::new()),
query_only: Cell::new(false),
query_only: AtomicBool::new(false),
mv_tx: Cell::new(None),
view_transaction_states: AllViewsTxState::new(),
metrics: RefCell::new(ConnectionMetrics::new()),
is_nested_stmt: Cell::new(false),
is_nested_stmt: AtomicBool::new(false),
encryption_key: RefCell::new(None),
encryption_cipher_mode: Cell::new(None),
sync_mode: Cell::new(SyncMode::Full),
data_sync_retry: Cell::new(false),
data_sync_retry: AtomicBool::new(false),
busy_timeout: Cell::new(Duration::new(0, 0)),
});
self.n_connections
@@ -1000,10 +1000,10 @@ pub struct Connection {
/// Client still can manually execute PRAGMA wal_checkpoint(...) commands
wal_auto_checkpoint_disabled: AtomicBool,
capture_data_changes: RwLock<CaptureDataChangesMode>,
closed: Cell<bool>,
closed: AtomicBool,
/// Attached databases
attached_databases: RefCell<DatabaseCatalog>,
query_only: Cell<bool>,
query_only: AtomicBool,
pub(crate) mv_tx: Cell<Option<(crate::mvcc::database::TxID, TransactionMode)>>,
/// Per-connection view transaction states for uncommitted changes. This represents
@@ -1013,11 +1013,11 @@ pub struct Connection {
pub metrics: RefCell<ConnectionMetrics>,
/// Whether the connection is executing a statement initiated by another statement.
/// Generally this is only true for ParseSchema.
is_nested_stmt: Cell<bool>,
is_nested_stmt: AtomicBool,
encryption_key: RefCell<Option<EncryptionKey>>,
encryption_cipher_mode: Cell<Option<CipherMode>>,
sync_mode: Cell<SyncMode>,
data_sync_retry: Cell<bool>,
data_sync_retry: AtomicBool,
/// User defined max accumulated Busy timeout duration
/// Default is 0 (no timeout)
busy_timeout: Cell<std::time::Duration>,
@@ -1025,7 +1025,7 @@ pub struct Connection {
impl Drop for Connection {
fn drop(&mut self) {
if !self.closed.get() {
if !self.is_closed() {
// if connection wasn't properly closed, decrement the connection counter
self.db
.n_connections
@@ -1037,7 +1037,7 @@ impl Drop for Connection {
impl Connection {
#[instrument(skip_all, level = Level::INFO)]
pub fn prepare(self: &Arc<Connection>, sql: impl AsRef<str>) -> Result<Statement> {
if self.closed.get() {
if self.is_closed() {
return Err(LimboError::InternalError("Connection closed".to_string()));
}
if sql.as_ref().is_empty() {
@@ -1197,7 +1197,7 @@ impl Connection {
#[instrument(skip_all, level = Level::INFO)]
pub fn prepare_execute_batch(self: &Arc<Connection>, sql: impl AsRef<str>) -> Result<()> {
if self.closed.get() {
if self.is_closed() {
return Err(LimboError::InternalError("Connection closed".to_string()));
}
if sql.as_ref().is_empty() {
@@ -1235,7 +1235,7 @@ impl Connection {
#[instrument(skip_all, level = Level::INFO)]
pub fn query(self: &Arc<Connection>, sql: impl AsRef<str>) -> Result<Option<Statement>> {
if self.closed.get() {
if self.is_closed() {
return Err(LimboError::InternalError("Connection closed".to_string()));
}
let sql = sql.as_ref();
@@ -1259,7 +1259,7 @@ impl Connection {
cmd: Cmd,
input: &str,
) -> Result<Option<Statement>> {
if self.closed.get() {
if self.is_closed() {
return Err(LimboError::InternalError("Connection closed".to_string()));
}
let syms = self.syms.read();
@@ -1287,7 +1287,7 @@ impl Connection {
/// TODO: make this api async
#[instrument(skip_all, level = Level::INFO)]
pub fn execute(self: &Arc<Connection>, sql: impl AsRef<str>) -> Result<()> {
if self.closed.get() {
if self.is_closed() {
return Err(LimboError::InternalError("Connection closed".to_string()));
}
let sql = sql.as_ref();
@@ -1603,7 +1603,7 @@ impl Connection {
/// Flush dirty pages to disk.
pub fn cacheflush(&self) -> Result<Vec<Completion>> {
if self.closed.get() {
if self.is_closed() {
return Err(LimboError::InternalError("Connection closed".to_string()));
}
self.pager.read().cacheflush()
@@ -1615,7 +1615,7 @@ impl Connection {
}
pub fn checkpoint(&self, mode: CheckpointMode) -> Result<CheckpointResult> {
if self.closed.get() {
if self.is_closed() {
return Err(LimboError::InternalError("Connection closed".to_string()));
}
self.pager.read().wal_checkpoint(mode)
@@ -1623,10 +1623,10 @@ impl Connection {
/// Close a connection and checkpoint.
pub fn close(&self) -> Result<()> {
if self.closed.get() {
if self.is_closed() {
return Ok(());
}
self.closed.set(true);
self.closed.store(true, Ordering::SeqCst);
match self.get_tx_state() {
TransactionState::None => {
@@ -1707,6 +1707,14 @@ impl Connection {
PageSize::new_from_header_u16(value).unwrap_or_default()
}
pub fn is_closed(&self) -> bool {
self.closed.load(Ordering::SeqCst)
}
pub fn is_query_only(&self) -> bool {
self.query_only.load(Ordering::SeqCst)
}
pub fn get_database_canonical_path(&self) -> String {
if self.db.path == ":memory:" {
// For in-memory databases, SQLite shows empty string
@@ -1791,7 +1799,7 @@ impl Connection {
}
pub fn parse_schema_rows(self: &Arc<Connection>) -> Result<()> {
if self.closed.get() {
if self.is_closed() {
return Err(LimboError::InternalError("Connection closed".to_string()));
}
let rows = self
@@ -1819,7 +1827,7 @@ impl Connection {
// Clearly there is something to improve here, Vec<Vec<Value>> isn't a couple of tea
/// Query the current rows/values of `pragma_name`.
pub fn pragma_query(self: &Arc<Connection>, pragma_name: &str) -> Result<Vec<Vec<Value>>> {
if self.closed.get() {
if self.is_closed() {
return Err(LimboError::InternalError("Connection closed".to_string()));
}
let pragma = format!("PRAGMA {pragma_name}");
@@ -1836,7 +1844,7 @@ impl Connection {
pragma_name: &str,
pragma_value: V,
) -> Result<Vec<Vec<Value>>> {
if self.closed.get() {
if self.is_closed() {
return Err(LimboError::InternalError("Connection closed".to_string()));
}
let pragma = format!("PRAGMA {pragma_name} = {pragma_value}");
@@ -1863,7 +1871,7 @@ impl Connection {
pragma_name: &str,
pragma_value: V,
) -> Result<Vec<Vec<Value>>> {
if self.closed.get() {
if self.is_closed() {
return Err(LimboError::InternalError("Connection closed".to_string()));
}
let pragma = format!("PRAGMA {pragma_name}({pragma_value})");
@@ -1923,7 +1931,7 @@ impl Connection {
/// Attach a database file with the given alias name
#[cfg(feature = "fs")]
pub(crate) fn attach_database(&self, path: &str, alias: &str) -> Result<()> {
if self.closed.get() {
if self.is_closed() {
return Err(LimboError::InternalError("Connection closed".to_string()));
}
@@ -1967,7 +1975,7 @@ impl Connection {
// Detach a database by alias name
fn detach_database(&self, alias: &str) -> Result<()> {
if self.closed.get() {
if self.is_closed() {
return Err(LimboError::InternalError("Connection closed".to_string()));
}
@@ -2116,11 +2124,11 @@ impl Connection {
}
pub fn get_query_only(&self) -> bool {
self.query_only.get()
self.is_query_only()
}
pub fn set_query_only(&self, value: bool) {
self.query_only.set(value);
self.query_only.store(value, Ordering::SeqCst);
}
pub fn get_sync_mode(&self) -> SyncMode {
@@ -2132,11 +2140,13 @@ impl Connection {
}
pub fn get_data_sync_retry(&self) -> bool {
self.data_sync_retry.get()
self.data_sync_retry
.load(std::sync::atomic::Ordering::SeqCst)
}
pub fn set_data_sync_retry(&self, value: bool) {
self.data_sync_retry.set(value);
self.data_sync_retry
.store(value, std::sync::atomic::Ordering::SeqCst);
}
/// Creates a HashSet of modules that have been loaded
@@ -2481,7 +2491,12 @@ impl Statement {
pub fn run_once(&self) -> Result<()> {
let res = self.pager.io.step();
if self.program.connection.is_nested_stmt.get() {
if self
.program
.connection
.is_nested_stmt
.load(Ordering::SeqCst)
{
return res;
}
if res.is_err() {

View File

@@ -1122,7 +1122,7 @@ impl Pager {
rollback: bool,
connection: &Connection,
) -> Result<IOResult<PagerCommitResult>> {
if connection.is_nested_stmt.get() {
if connection.is_nested_stmt.load(Ordering::SeqCst) {
// Parent statement will handle the transaction rollback.
return Ok(IOResult::Done(PagerCommitResult::Rollback));
}

View File

@@ -2207,7 +2207,8 @@ pub fn op_transaction_inner(
// 1. We try to upgrade current version
let current_state = conn.get_tx_state();
let (new_transaction_state, updated) = if conn.is_nested_stmt.get() {
let (new_transaction_state, updated) = if conn.is_nested_stmt.load(Ordering::SeqCst)
{
(current_state, false)
} else {
match (current_state, write) {
@@ -2295,7 +2296,7 @@ pub fn op_transaction_inner(
}
if updated && matches!(current_state, TransactionState::None) {
turso_assert!(
!conn.is_nested_stmt.get(),
!conn.is_nested_stmt.load(Ordering::SeqCst),
"nested stmt should not begin a new read transaction"
);
pager.begin_read_tx()?;
@@ -2303,7 +2304,7 @@ pub fn op_transaction_inner(
if updated && matches!(new_transaction_state, TransactionState::Write { .. }) {
turso_assert!(
!conn.is_nested_stmt.get(),
!conn.is_nested_stmt.load(Ordering::SeqCst),
"nested stmt should not begin a new write transaction"
);
let begin_w_tx_res = pager.begin_write_tx();
@@ -6837,7 +6838,7 @@ pub fn op_parse_schema(
conn.with_schema_mut(|schema| {
// TODO: This function below is synchronous, make it async
let existing_views = schema.incremental_views.clone();
conn.is_nested_stmt.set(true);
conn.is_nested_stmt.store(true, Ordering::SeqCst);
parse_schema_rows(
stmt,
schema,
@@ -6853,7 +6854,7 @@ pub fn op_parse_schema(
conn.with_schema_mut(|schema| {
// TODO: This function below is synchronous, make it async
let existing_views = schema.incremental_views.clone();
conn.is_nested_stmt.set(true);
conn.is_nested_stmt.store(true, Ordering::SeqCst);
parse_schema_rows(
stmt,
schema,
@@ -6864,7 +6865,7 @@ pub fn op_parse_schema(
)
})
};
conn.is_nested_stmt.set(false);
conn.is_nested_stmt.store(false, Ordering::SeqCst);
conn.auto_commit
.store(previous_auto_commit, Ordering::SeqCst);
maybe_nested_stmt_err?;

View File

@@ -529,7 +529,7 @@ impl Program {
pager: Arc<Pager>,
) -> Result<StepResult> {
debug_assert!(state.column_count() == EXPLAIN_COLUMNS.len());
if self.connection.closed.get() {
if self.connection.is_closed() {
// Connection is closed for whatever reason, rollback the transaction.
let state = self.connection.get_tx_state();
if let TransactionState::Write { .. } = state {
@@ -584,7 +584,7 @@ impl Program {
) -> Result<StepResult> {
debug_assert!(state.column_count() == EXPLAIN_QUERY_PLAN_COLUMNS.len());
loop {
if self.connection.closed.get() {
if self.connection.is_closed() {
// Connection is closed for whatever reason, rollback the transaction.
let state = self.connection.get_tx_state();
if let TransactionState::Write { .. } = state {
@@ -632,7 +632,7 @@ impl Program {
) -> Result<StepResult> {
let enable_tracing = tracing::enabled!(tracing::Level::TRACE);
loop {
if self.connection.closed.get() {
if self.connection.is_closed() {
// Connection is closed for whatever reason, rollback the transaction.
let state = self.connection.get_tx_state();
if let TransactionState::Write { .. } = state {
@@ -825,7 +825,7 @@ impl Program {
return Ok(IOResult::Done(()));
}
if let Some(mv_store) = mv_store {
if self.connection.is_nested_stmt.get() {
if self.connection.is_nested_stmt.load(Ordering::SeqCst) {
// We don't want to commit on nested statements. Let parent handle it.
return Ok(IOResult::Done(()));
}
@@ -1069,7 +1069,7 @@ pub fn handle_program_error(
err: &LimboError,
mv_store: Option<&Arc<MvStore>>,
) -> Result<()> {
if connection.is_nested_stmt.get() {
if connection.is_nested_stmt.load(Ordering::SeqCst) {
// Errors from nested statements are handled by the parent statement.
return Ok(());
}