diff --git a/core/lib.rs b/core/lib.rs index 6e2fb82a7..f5cb80b57 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -495,6 +495,9 @@ pub struct Connection { impl Connection { #[instrument(skip_all, level = Level::INFO)] pub fn prepare(self: &Arc, sql: impl AsRef) -> Result { + if self.closed.get() { + return Err(LimboError::InternalError("Connection closed".to_string())); + } if sql.as_ref().is_empty() { return Err(LimboError::InvalidArgument( "The supplied SQL string contains no statements".to_string(), @@ -536,6 +539,9 @@ impl Connection { #[instrument(skip_all, level = Level::INFO)] pub fn query(self: &Arc, sql: impl AsRef) -> Result> { + if self.closed.get() { + return Err(LimboError::InternalError("Connection closed".to_string())); + } let sql = sql.as_ref(); tracing::trace!("Querying: {}", sql); let mut parser = Parser::new(sql.as_bytes()); @@ -556,6 +562,9 @@ impl Connection { cmd: Cmd, input: &str, ) -> Result> { + if self.closed.get() { + return Err(LimboError::InternalError("Connection closed".to_string())); + } let syms = self.syms.borrow(); match cmd { Cmd::Stmt(ref stmt) | Cmd::Explain(ref stmt) => { @@ -605,6 +614,9 @@ impl Connection { /// TODO: make this api async #[instrument(skip_all, level = Level::INFO)] pub fn execute(self: &Arc, sql: impl AsRef) -> Result<()> { + if self.closed.get() { + return Err(LimboError::InternalError("Connection closed".to_string())); + } let sql = sql.as_ref(); let mut parser = Parser::new(sql.as_bytes()); while let Some(cmd) = parser.next()? { @@ -659,6 +671,9 @@ impl Connection { } fn run_once(&self) -> Result<()> { + if self.closed.get() { + return Err(LimboError::InternalError("Connection closed".to_string())); + } let res = self._db.io.run_once(); if res.is_err() { let state = self.transaction_state.get(); @@ -727,6 +742,9 @@ impl Connection { /// If the WAL size is over the checkpoint threshold, it will checkpoint the WAL to /// the database file and then fsync the database file. pub fn cacheflush(&self) -> Result { + if self.closed.get() { + return Err(LimboError::InternalError("Connection closed".to_string())); + } self.pager.cacheflush(self.wal_checkpoint_disabled.get()) } @@ -736,13 +754,18 @@ impl Connection { } pub fn checkpoint(&self) -> Result { + if self.closed.get() { + return Err(LimboError::InternalError("Connection closed".to_string())); + } self.pager .wal_checkpoint(self.wal_checkpoint_disabled.get()) } /// Close a connection and checkpoint. pub fn close(&self) -> Result<()> { - turso_assert!(!self.closed.get(), "Connection already closed"); + if self.closed.get() { + return Ok(()); + } self.closed.set(true); self.pager .checkpoint_shutdown(self.wal_checkpoint_disabled.get()) @@ -811,6 +834,9 @@ impl Connection { } pub fn parse_schema_rows(self: &Arc) -> Result<()> { + if self.closed.get() { + return Err(LimboError::InternalError("Connection closed".to_string())); + } let rows = self.query("SELECT * FROM sqlite_schema")?; let mut schema = self.schema.borrow_mut(); { @@ -829,6 +855,9 @@ impl Connection { // Clearly there is something to improve here, Vec> isn't a couple of tea /// Query the current rows/values of `pragma_name`. pub fn pragma_query(self: &Arc, pragma_name: &str) -> Result>> { + if self.closed.get() { + return Err(LimboError::InternalError("Connection closed".to_string())); + } let pragma = format!("PRAGMA {}", pragma_name); let mut stmt = self.prepare(pragma)?; let mut results = Vec::new(); @@ -857,6 +886,9 @@ impl Connection { pragma_name: &str, pragma_value: V, ) -> Result>> { + if self.closed.get() { + return Err(LimboError::InternalError("Connection closed".to_string())); + } let pragma = format!("PRAGMA {} = {}", pragma_name, pragma_value); let mut stmt = self.prepare(pragma)?; let mut results = Vec::new(); @@ -887,6 +919,9 @@ impl Connection { pragma_name: &str, pragma_value: V, ) -> Result>> { + if self.closed.get() { + return Err(LimboError::InternalError("Connection closed".to_string())); + } let pragma = format!("PRAGMA {}({})", pragma_name, pragma_value); let mut stmt = self.prepare(pragma)?; let mut results = Vec::new(); diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index 8e3c4427d..7a45238cb 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -380,7 +380,7 @@ impl Program { pager: Rc, ) -> Result { loop { - if *self.connection.closed.borrow() { + if self.connection.closed.get() { // Connection is closed for whatever reason, rollback the transaction. let state = self.connection.transaction_state.get(); if let TransactionState::Write { schema_did_change } = state {