diff --git a/bindings/rust/src/lib.rs b/bindings/rust/src/lib.rs index 94a6556c9..d8d597cc6 100644 --- a/bindings/rust/src/lib.rs +++ b/bindings/rust/src/lib.rs @@ -48,6 +48,8 @@ pub use params::IntoParams; use std::fmt::Debug; use std::future::Future; use std::num::NonZero; +use std::sync::atomic::AtomicU8; +use std::sync::atomic::Ordering; use std::sync::{Arc, Mutex}; use std::task::Poll; pub use turso_core::EncryptionOpts; @@ -55,6 +57,7 @@ use turso_core::OpenFlags; // Re-exports rows pub use crate::rows::{Row, Rows}; +use crate::transaction::DropBehavior; #[derive(Debug, thiserror::Error)] pub enum Error { @@ -218,10 +221,39 @@ impl Database { } } +/// Atomic wrapper for [DropBehavior] +struct AtomicDropBehavior { + inner: AtomicU8, +} + +impl AtomicDropBehavior { + fn new(behavior: DropBehavior) -> Self { + Self { + inner: AtomicU8::new(behavior.into()), + } + } + + fn load(&self, ordering: Ordering) -> DropBehavior { + self.inner.load(ordering).into() + } + + fn store(&self, behavior: DropBehavior, ordering: Ordering) { + self.inner.store(behavior.into(), ordering); + } +} + /// A database connection. pub struct Connection { inner: Arc>>, transaction_behavior: TransactionBehavior, + /// If there is a dangling transaction after it was dropped without being finished, + /// [Connection::dangling_tx] will be set to the [DropBehavior] of the dangling transaction, + /// and the corresponding action will be taken when a new transaction is requested + /// or the connection queries/executes. + /// We cannot do this eagerly on Drop because drop is not async. + /// + /// By default, the value is [DropBehavior::Ignore] which effectively does nothing. + dangling_tx: AtomicDropBehavior, } impl Clone for Connection { @@ -229,6 +261,7 @@ impl Clone for Connection { Self { inner: Arc::clone(&self.inner), transaction_behavior: self.transaction_behavior, + dangling_tx: AtomicDropBehavior::new(self.dangling_tx.load(Ordering::SeqCst)), } } } @@ -242,17 +275,43 @@ impl Connection { let connection = Connection { inner: Arc::new(Mutex::new(conn)), transaction_behavior: TransactionBehavior::Deferred, + dangling_tx: AtomicDropBehavior::new(DropBehavior::Ignore), }; connection } + + async fn maybe_handle_dangling_tx(&self) -> Result<()> { + match self.dangling_tx.load(Ordering::SeqCst) { + DropBehavior::Rollback => { + let mut stmt = self.prepare("ROLLBACK").await?; + stmt.execute(()).await?; + self.dangling_tx + .store(DropBehavior::Ignore, Ordering::SeqCst); + } + DropBehavior::Commit => { + let mut stmt = self.prepare("COMMIT").await?; + stmt.execute(()).await?; + self.dangling_tx + .store(DropBehavior::Ignore, Ordering::SeqCst); + } + DropBehavior::Ignore => {} + DropBehavior::Panic => { + panic!("Transaction dropped unexpectedly."); + } + } + Ok(()) + } + /// Query the database with SQL. pub async fn query(&self, sql: &str, params: impl IntoParams) -> Result { + self.maybe_handle_dangling_tx().await?; let mut stmt = self.prepare(sql).await?; stmt.query(params).await } /// Execute SQL statement on the database. pub async fn execute(&self, sql: &str, params: impl IntoParams) -> Result { + self.maybe_handle_dangling_tx().await?; let mut stmt = self.prepare(sql).await?; stmt.execute(params).await } @@ -337,6 +396,7 @@ impl Connection { /// Execute a batch of SQL statements on the database. pub async fn execute_batch(&self, sql: &str) -> Result<()> { + self.maybe_handle_dangling_tx().await?; self.prepare_execute_batch(sql).await?; Ok(()) } @@ -358,6 +418,7 @@ impl Connection { } async fn prepare_execute_batch(&self, sql: impl AsRef) -> Result<()> { + self.maybe_handle_dangling_tx().await?; let conn = self .inner .lock() diff --git a/bindings/rust/src/transaction.rs b/bindings/rust/src/transaction.rs index 6da5c133d..dd77ed0d8 100644 --- a/bindings/rust/src/transaction.rs +++ b/bindings/rust/src/transaction.rs @@ -1,4 +1,4 @@ -use std::ops::Deref; +use std::{ops::Deref, sync::atomic::Ordering}; use crate::{Connection, Result}; @@ -36,13 +36,36 @@ pub enum DropBehavior { Panic, } +impl From for u8 { + fn from(behavior: DropBehavior) -> Self { + match behavior { + DropBehavior::Rollback => 0, + DropBehavior::Commit => 1, + DropBehavior::Ignore => 2, + DropBehavior::Panic => 3, + } + } +} + +impl From for DropBehavior { + fn from(value: u8) -> Self { + match value { + 0 => DropBehavior::Rollback, + 1 => DropBehavior::Commit, + 2 => DropBehavior::Ignore, + 3 => DropBehavior::Panic, + _ => panic!("Invalid drop behavior: {value}"), + } + } +} + /// Represents a transaction on a database connection. /// /// ## Note /// /// Transactions will roll back by default. Use `commit` method to explicitly /// commit the transaction, or use `set_drop_behavior` to change what happens -/// when the transaction is dropped. +/// on the next access to the connection after the transaction is dropped. /// /// ## Example /// @@ -63,7 +86,7 @@ pub enum DropBehavior { pub struct Transaction<'conn> { conn: &'conn Connection, drop_behavior: DropBehavior, - must_finish: bool, + in_progress: bool, } impl Transaction<'_> { @@ -99,7 +122,7 @@ impl Transaction<'_> { conn.execute(query, ()).await.map(move |_| Transaction { conn, drop_behavior: DropBehavior::Rollback, - must_finish: true, + in_progress: true, }) } @@ -126,8 +149,8 @@ impl Transaction<'_> { #[inline] async fn _commit(&mut self) -> Result<()> { - self.must_finish = false; self.conn.execute("COMMIT", ()).await?; + self.in_progress = false; Ok(()) } @@ -139,8 +162,8 @@ impl Transaction<'_> { #[inline] async fn _rollback(&mut self) -> Result<()> { - self.must_finish = false; self.conn.execute("ROLLBACK", ()).await?; + self.in_progress = false; Ok(()) } @@ -186,8 +209,14 @@ impl Deref for Transaction<'_> { impl Drop for Transaction<'_> { #[inline] fn drop(&mut self) { - if self.must_finish { - panic!("Transaction dropped without finish()") + if self.in_progress { + self.conn + .dangling_tx + .store(self.drop_behavior(), Ordering::SeqCst); + } else { + self.conn + .dangling_tx + .store(DropBehavior::Ignore, Ordering::SeqCst); } } } @@ -195,7 +224,8 @@ impl Drop for Transaction<'_> { impl Connection { /// Begin a new transaction with the default behavior (DEFERRED). /// - /// The transaction defaults to rolling back when it is dropped. If you + /// The transaction defaults to rolling back on the next access to the connection + /// if it is not finished when the transaction is dropped. If you /// want the transaction to commit, you must call /// [`commit`](Transaction::commit) or /// [`set_drop_behavior(DropBehavior::Commit)`](Transaction::set_drop_behavior). @@ -221,7 +251,8 @@ impl Connection { /// Will return `Err` if the call fails. #[inline] pub async fn transaction(&mut self) -> Result> { - Transaction::new(self, self.transaction_behavior).await + self.transaction_with_behavior(self.transaction_behavior) + .await } /// Begin a new transaction with a specified behavior. @@ -236,6 +267,7 @@ impl Connection { &mut self, behavior: TransactionBehavior, ) -> Result> { + self.maybe_handle_dangling_tx().await?; Transaction::new(self, behavior).await } @@ -318,13 +350,66 @@ mod test { } #[tokio::test] - #[should_panic(expected = "Transaction dropped without finish()")] - async fn test_drop_panic() { + async fn test_drop_rollback_on_new_transaction() { let mut conn = checked_memory_handle().await.unwrap(); { let tx = conn.transaction().await.unwrap(); tx.execute("INSERT INTO foo VALUES(?)", &[1]).await.unwrap(); + // Drop without finish - should be rolled back when next transaction starts } + + // Start a new transaction - this should rollback the dangling one + let tx = conn.transaction().await.unwrap(); + tx.execute("INSERT INTO foo VALUES(?)", &[2]).await.unwrap(); + let result = tx + .prepare("SELECT SUM(x) FROM foo") + .await + .unwrap() + .query_row(()) + .await + .unwrap(); + + // The insert from the dropped transaction should have been rolled back + assert_eq!(2, result.get::(0).unwrap()); + tx.finish().await.unwrap(); + } + + #[tokio::test] + async fn test_drop_rollback_on_query() { + let mut conn = checked_memory_handle().await.unwrap(); + { + let tx = conn.transaction().await.unwrap(); + tx.execute("INSERT INTO foo VALUES(?)", &[1]).await.unwrap(); + // Drop without finish - should be rolled back when conn.query is called + } + + // Using conn.query should rollback the dangling transaction + let mut rows = conn.query("SELECT count(*) FROM foo", ()).await.unwrap(); + let result = rows.next().await.unwrap().unwrap(); + + // The insert from the dropped transaction should have been rolled back + assert_eq!(0, result.get::(0).unwrap()); + } + + #[tokio::test] + async fn test_drop_rollback_on_execute() { + let mut conn = checked_memory_handle().await.unwrap(); + { + let tx = conn.transaction().await.unwrap(); + tx.execute("INSERT INTO foo VALUES(?)", &[1]).await.unwrap(); + // Drop without finish - should be rolled back when conn.execute is called + } + + // Using conn.execute should rollback the dangling transaction + conn.execute("INSERT INTO foo VALUES(?)", &[2]) + .await + .unwrap(); + + let mut rows = conn.query("SELECT count(*) FROM foo", ()).await.unwrap(); + let result = rows.next().await.unwrap().unwrap(); + + // The insert from the dropped transaction should have been rolled back + assert_eq!(1, result.get::(0).unwrap()); } #[tokio::test] @@ -334,14 +419,12 @@ mod test { { let tx = conn.transaction().await?; tx.execute("INSERT INTO foo VALUES(?)", &[1]).await?; - tx.finish().await?; // default: rollback } { let mut tx = conn.transaction().await?; tx.execute("INSERT INTO foo VALUES(?)", &[2]).await?; tx.set_drop_behavior(DropBehavior::Commit); - tx.finish().await?; } { let tx = conn.transaction().await?; @@ -352,7 +435,6 @@ mod test { .await?; assert_eq!(2, result.get::(0)?); - tx.finish().await?; } Ok(()) }