diff --git a/core/benches/mvcc_benchmark.rs b/core/benches/mvcc_benchmark.rs index 0ebd33fa5..7d316707d 100644 --- a/core/benches/mvcc_benchmark.rs +++ b/core/benches/mvcc_benchmark.rs @@ -36,8 +36,7 @@ fn bench(c: &mut Criterion) { let conn = db.conn.clone(); let tx_id = db.mvcc_store.begin_tx(conn.get_pager().clone()).unwrap(); db.mvcc_store - .rollback_tx(tx_id, conn.get_pager().clone(), &conn) - .unwrap(); + .rollback_tx(tx_id, conn.get_pager().clone(), &conn); }) }); diff --git a/core/lib.rs b/core/lib.rs index 29450c471..f45715dc2 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -498,7 +498,7 @@ impl Database { let result = schema .make_from_btree(None, pager.clone(), &syms) .or_else(|e| { - pager.end_read_tx()?; + pager.end_read_tx(); Err(e) }); if let Err(LimboError::ExtensionError(e)) = result { @@ -1195,11 +1195,11 @@ impl Connection { 0 } Err(err) => { - pager.end_read_tx().expect("read txn must be finished"); + pager.end_read_tx(); return Err(err); } }; - pager.end_read_tx().expect("read txn must be finished"); + pager.end_read_tx(); let db_schema_version = self.db.schema.lock().unwrap().schema_version; tracing::debug!( @@ -1236,7 +1236,7 @@ impl Connection { // close opened transaction if it was kept open // (in most cases, it will be automatically closed if stmt was executed properly) if previous == TransactionState::Read { - pager.end_read_tx().expect("read txn must be finished"); + pager.end_read_tx(); } reparse_result?; @@ -1654,7 +1654,7 @@ impl Connection { let pager = self.pager.read(); pager.begin_read_tx()?; pager.io.block(|| pager.begin_write_tx()).inspect_err(|_| { - pager.end_read_tx().expect("read txn must be closed"); + pager.end_read_tx(); })?; // start write transaction and disable auto-commit mode as SQL can be executed within WAL session (at caller own risk) @@ -1702,13 +1702,11 @@ impl Connection { wal.end_read_tx(); } - let rollback_err = if !force_commit { + if !force_commit { // remove all non-commited changes in case if WAL session left some suffix without commit frame - pager.rollback(false, self, true).err() - } else { - None - }; - if let Some(err) = commit_err.or(rollback_err) { + pager.rollback(false, self, true); + } + if let Some(err) = commit_err { return Err(err); } } @@ -1752,12 +1750,7 @@ impl Connection { _ => { if !self.mvcc_enabled() { let pager = self.pager.read(); - pager.io.block(|| { - pager.end_tx( - true, // rollback = true for close - self, - ) - })?; + pager.rollback_tx(self); } self.set_tx_state(TransactionState::None); } @@ -2632,12 +2625,8 @@ impl Statement { } let state = self.program.connection.get_tx_state(); if let TransactionState::Write { .. } = state { - let end_tx_res = self.pager.end_tx(true, &self.program.connection)?; + self.pager.rollback_tx(&self.program.connection); self.program.connection.set_tx_state(TransactionState::None); - assert!( - matches!(end_tx_res, IOResult::Done(_)), - "end_tx should not return IO as it should just end txn without flushing anything. Got {end_tx_res:?}" - ); } } res diff --git a/core/mvcc/database/checkpoint_state_machine.rs b/core/mvcc/database/checkpoint_state_machine.rs index a207d5ad2..fee93c2d8 100644 --- a/core/mvcc/database/checkpoint_state_machine.rs +++ b/core/mvcc/database/checkpoint_state_machine.rs @@ -548,7 +548,7 @@ impl CheckpointStateMachine { CheckpointState::CommitPagerTxn => { tracing::debug!("Committing pager transaction"); - let result = self.pager.end_tx(false, &self.connection)?; + let result = self.pager.commit_tx(&self.connection)?; match result { IOResult::Done(_) => { self.state = CheckpointState::TruncateLogicalLog; @@ -642,16 +642,12 @@ impl StateTransition for CheckpointStateMachine { Err(err) => { tracing::info!("Error in checkpoint state machine: {err}"); if self.lock_states.pager_write_tx { - let rollback = true; - self.pager - .io - .block(|| self.pager.end_tx(rollback, self.connection.as_ref())) - .expect("failed to end pager write tx"); + self.pager.rollback_tx(self.connection.as_ref()); if self.update_transaction_state { *self.connection.transaction_state.write() = TransactionState::None; } } else if self.lock_states.pager_read_tx { - self.pager.end_read_tx().unwrap(); + self.pager.end_read_tx(); if self.update_transaction_state { *self.connection.transaction_state.write() = TransactionState::None; } diff --git a/core/mvcc/database/mod.rs b/core/mvcc/database/mod.rs index 45b7cb7e8..a03fba7ba 100644 --- a/core/mvcc/database/mod.rs +++ b/core/mvcc/database/mod.rs @@ -1566,12 +1566,7 @@ impl MvStore { /// # Arguments /// /// * `tx_id` - The ID of the transaction to abort. - pub fn rollback_tx( - &self, - tx_id: TxID, - _pager: Arc, - connection: &Connection, - ) -> Result<()> { + pub fn rollback_tx(&self, tx_id: TxID, _pager: Arc, connection: &Connection) { let tx_unlocked = self.txs.get(&tx_id).unwrap(); let tx = tx_unlocked.value(); *connection.mv_tx.write() = None; @@ -1615,8 +1610,6 @@ impl MvStore { // FIXME: verify that we can already remove the transaction here! // Maybe it's fine for snapshot isolation, but too early for serializable? self.remove_tx(tx_id); - - Ok(()) } /// Returns true if the given transaction is the exclusive transaction. diff --git a/core/mvcc/database/tests.rs b/core/mvcc/database/tests.rs index 35a45e728..e559bda52 100644 --- a/core/mvcc/database/tests.rs +++ b/core/mvcc/database/tests.rs @@ -347,8 +347,7 @@ fn test_rollback() { .unwrap(); assert_eq!(row3, row4); db.mvcc_store - .rollback_tx(tx1, db.conn.pager.read().clone(), &db.conn) - .unwrap(); + .rollback_tx(tx1, db.conn.pager.read().clone(), &db.conn); let tx2 = db .mvcc_store .begin_tx(db.conn.pager.read().clone()) @@ -592,8 +591,7 @@ fn test_lost_update() { )); // hack: in the actual tursodb database we rollback the mvcc tx ourselves, so manually roll it back here db.mvcc_store - .rollback_tx(tx3, conn3.pager.read().clone(), &conn3) - .unwrap(); + .rollback_tx(tx3, conn3.pager.read().clone(), &conn3); commit_tx(db.mvcc_store.clone(), &conn2, tx2).unwrap(); assert!(matches!( diff --git a/core/schema.rs b/core/schema.rs index 5188f962f..9208c5d1d 100644 --- a/core/schema.rs +++ b/core/schema.rs @@ -472,7 +472,7 @@ impl Schema { pager.io.block(|| cursor.next())?; } - pager.end_read_tx()?; + pager.end_read_tx(); self.populate_indices(from_sql_indexes, automatic_indices)?; diff --git a/core/storage/btree.rs b/core/storage/btree.rs index 0dfaff8c1..9ab4c689e 100644 --- a/core/storage/btree.rs +++ b/core/storage/btree.rs @@ -8183,7 +8183,7 @@ mod tests { // force allocate page1 with a transaction pager.begin_read_tx().unwrap(); run_until_done(|| pager.begin_write_tx(), &pager).unwrap(); - run_until_done(|| pager.end_tx(false, &conn), &pager).unwrap(); + run_until_done(|| pager.commit_tx(&conn), &pager).unwrap(); let page2 = run_until_done(|| pager.allocate_page(), &pager).unwrap(); btree_init_page(&page2, PageType::TableLeaf, 0, pager.usable_space()); @@ -8495,7 +8495,7 @@ mod tests { pager.deref(), ) .unwrap(); - pager.io.block(|| pager.end_tx(false, &conn)).unwrap(); + pager.io.block(|| pager.commit_tx(&conn)).unwrap(); pager.begin_read_tx().unwrap(); // FIXME: add sorted vector instead, should be okay for small amounts of keys for now :P, too lazy to fix right now let _c = cursor.move_to_root().unwrap(); @@ -8524,7 +8524,7 @@ mod tests { println!("btree after:\n{btree_after}"); panic!("invalid btree"); } - pager.end_read_tx().unwrap(); + pager.end_read_tx(); } pager.begin_read_tx().unwrap(); tracing::info!( @@ -8546,7 +8546,7 @@ mod tests { "key {key} is not found, got {cursor_rowid}" ); } - pager.end_read_tx().unwrap(); + pager.end_read_tx(); } } @@ -8641,7 +8641,7 @@ mod tests { if let Some(c) = c { pager.io.wait_for_completion(c).unwrap(); } - pager.io.block(|| pager.end_tx(false, &conn)).unwrap(); + pager.io.block(|| pager.commit_tx(&conn)).unwrap(); } // Check that all keys can be found by seeking @@ -8702,7 +8702,7 @@ mod tests { } prev = Some(cur); } - pager.end_read_tx().unwrap(); + pager.end_read_tx(); } } @@ -8848,7 +8848,7 @@ mod tests { if let Some(c) = c { pager.io.wait_for_completion(c).unwrap(); } - pager.io.block(|| pager.end_tx(false, &conn)).unwrap(); + pager.io.block(|| pager.commit_tx(&conn)).unwrap(); } // Final validation @@ -8856,7 +8856,7 @@ mod tests { sorted_keys.sort(); validate_expected_keys(&pager, &mut cursor, &sorted_keys, seed); - pager.end_read_tx().unwrap(); + pager.end_read_tx(); } } @@ -8939,7 +8939,7 @@ mod tests { "key {key:?} is not found, seed: {seed}" ); } - pager.end_read_tx().unwrap(); + pager.end_read_tx(); } #[test] diff --git a/core/storage/pager.rs b/core/storage/pager.rs index f2bde3d04..32f9e834e 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -1161,33 +1161,20 @@ impl Pager { } #[instrument(skip_all, level = Level::DEBUG)] - pub fn end_tx( - &self, - rollback: bool, - connection: &Connection, - ) -> Result> { + pub fn commit_tx(&self, connection: &Connection) -> Result> { if connection.is_nested_stmt.load(Ordering::SeqCst) { // Parent statement will handle the transaction rollback. return Ok(IOResult::Done(PagerCommitResult::Rollback)); } - tracing::trace!("end_tx(rollback={})", rollback); let Some(wal) = self.wal.as_ref() else { // TODO: Unsure what the semantics of "end_tx" is for in-memory databases, ephemeral tables and ephemeral indexes. return Ok(IOResult::Done(PagerCommitResult::Rollback)); }; - let (is_write, schema_did_change) = match connection.get_tx_state() { + let (_, schema_did_change) = match connection.get_tx_state() { TransactionState::Write { schema_did_change } => (true, schema_did_change), _ => (false, false), }; - tracing::trace!("end_tx(schema_did_change={})", schema_did_change); - if rollback { - if is_write { - wal.borrow().end_write_tx(); - } - wal.borrow().end_read_tx(); - self.rollback(schema_did_change, connection, is_write)?; - return Ok(IOResult::Done(PagerCommitResult::Rollback)); - } + tracing::trace!("commit_tx(schema_did_change={})", schema_did_change); let commit_status = return_if_io!(self.commit_dirty_pages( connection.is_wal_auto_checkpoint_disabled(), connection.get_sync_mode(), @@ -1204,12 +1191,33 @@ impl Pager { } #[instrument(skip_all, level = Level::DEBUG)] - pub fn end_read_tx(&self) -> Result<()> { + pub fn rollback_tx(&self, connection: &Connection) { + if connection.is_nested_stmt.load(Ordering::SeqCst) { + // Parent statement will handle the transaction rollback. + return; + } let Some(wal) = self.wal.as_ref() else { - return Ok(()); + // TODO: Unsure what the semantics of "end_tx" is for in-memory databases, ephemeral tables and ephemeral indexes. + return; + }; + let (is_write, schema_did_change) = match connection.get_tx_state() { + TransactionState::Write { schema_did_change } => (true, schema_did_change), + _ => (false, false), + }; + tracing::trace!("rollback_tx(schema_did_change={})", schema_did_change); + if is_write { + wal.borrow().end_write_tx(); + } + wal.borrow().end_read_tx(); + self.rollback(schema_did_change, connection, is_write); + } + + #[instrument(skip_all, level = Level::DEBUG)] + pub fn end_read_tx(&self) { + let Some(wal) = self.wal.as_ref() else { + return; }; wal.borrow().end_read_tx(); - Ok(()) } /// Reads a page from disk (either WAL or DB file) bypassing page-cache @@ -2393,12 +2401,7 @@ impl Pager { } #[instrument(skip_all, level = Level::DEBUG)] - pub fn rollback( - &self, - schema_did_change: bool, - connection: &Connection, - is_write: bool, - ) -> Result<(), LimboError> { + pub fn rollback(&self, schema_did_change: bool, connection: &Connection, is_write: bool) { tracing::debug!(schema_did_change); self.clear_page_cache(); if is_write { @@ -2415,11 +2418,9 @@ impl Pager { } if is_write { if let Some(wal) = self.wal.as_ref() { - wal.borrow_mut().rollback()?; + wal.borrow_mut().rollback(); } } - - Ok(()) } fn reset_internal_states(&self) { @@ -2764,7 +2765,7 @@ mod ptrmap_tests { use super::*; use crate::io::{MemoryIO, OpenFlags, IO}; use crate::storage::buffer_pool::BufferPool; - use crate::storage::database::{DatabaseFile, DatabaseStorage}; + use crate::storage::database::DatabaseFile; use crate::storage::page_cache::PageCache; use crate::storage::pager::Pager; use crate::storage::sqlite3_ondisk::PageSize; diff --git a/core/storage/wal.rs b/core/storage/wal.rs index c590219bd..1f55e1cc0 100644 --- a/core/storage/wal.rs +++ b/core/storage/wal.rs @@ -302,7 +302,7 @@ pub trait Wal: Debug { fn get_checkpoint_seq(&self) -> u32; fn get_max_frame(&self) -> u64; fn get_min_frame(&self) -> u64; - fn rollback(&mut self) -> Result<()>; + fn rollback(&mut self); /// Return unique set of pages changed **after** frame_watermark position and until current WAL session max_frame_no fn changed_pages_after(&self, frame_watermark: u64) -> Result>; @@ -1351,8 +1351,8 @@ impl Wal for WalFile { self.min_frame.load(Ordering::Acquire) } - #[instrument(err, skip_all, level = Level::DEBUG)] - fn rollback(&mut self) -> Result<()> { + #[instrument(skip_all, level = Level::DEBUG)] + fn rollback(&mut self) { let (max_frame, last_checksum) = { let shared = self.get_shared(); let max_frame = shared.max_frame.load(Ordering::Acquire); @@ -1369,7 +1369,6 @@ impl Wal for WalFile { self.last_checksum = last_checksum; self.max_frame.store(max_frame, Ordering::Release); self.reset_internal_states(); - Ok(()) } #[instrument(skip_all, level = Level::DEBUG)] @@ -2825,7 +2824,7 @@ pub mod test { } } drop(w); - conn2.pager.write().end_read_tx().unwrap(); + conn2.pager.write().end_read_tx(); conn1 .execute("create table test(id integer primary key, value text)") diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index 87fbaec0a..ec5d6f3f9 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -2372,7 +2372,7 @@ pub fn op_transaction_inner( // That is, if the transaction had not started, end the read transaction so that next time we // start a new one. if matches!(current_state, TransactionState::None) { - pager.end_read_tx()?; + pager.end_read_tx(); conn.set_tx_state(TransactionState::None); } assert_eq!(conn.get_tx_state(), current_state); @@ -2456,10 +2456,10 @@ pub fn op_auto_commit( // TODO(pere): add rollback I/O logic once we implement rollback journal if let Some(mv_store) = mv_store { if let Some(tx_id) = conn.get_mv_tx_id() { - mv_store.rollback_tx(tx_id, pager.clone(), &conn)?; + mv_store.rollback_tx(tx_id, pager.clone(), &conn); } } else { - return_if_io!(pager.end_tx(true, &conn)); + pager.rollback_tx(&conn); } conn.set_tx_state(TransactionState::None); conn.auto_commit.store(true, Ordering::SeqCst); diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index fa1a88df8..ad3c7ad3b 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -30,7 +30,7 @@ use crate::{ function::{AggFunc, FuncCtx}, mvcc::{database::CommitStateMachine, LocalClock}, state_machine::StateMachine, - storage::sqlite3_ondisk::SmallVec, + storage::{pager::PagerCommitResult, sqlite3_ondisk::SmallVec}, translate::{collate::CollationSeq, plan::TableReferences}, types::{IOCompletions, IOResult, RawSlice, TextRef}, vdbe::{ @@ -41,7 +41,7 @@ use crate::{ }, metrics::StatementMetrics, }, - IOExt, RefValue, + RefValue, }; use crate::{ @@ -533,7 +533,7 @@ impl Program { // Connection is closed for whatever reason, rollback the transaction. let state = self.connection.get_tx_state(); if let TransactionState::Write { .. } = state { - pager.io.block(|| pager.end_tx(true, &self.connection))?; + pager.rollback_tx(&self.connection); } return Err(LimboError::InternalError("Connection closed".to_string())); } @@ -588,7 +588,7 @@ impl Program { // Connection is closed for whatever reason, rollback the transaction. let state = self.connection.get_tx_state(); if let TransactionState::Write { .. } = state { - pager.io.block(|| pager.end_tx(true, &self.connection))?; + pager.rollback_tx(&self.connection); } return Err(LimboError::InternalError("Connection closed".to_string())); } @@ -636,7 +636,7 @@ impl Program { // Connection is closed for whatever reason, rollback the transaction. let state = self.connection.get_tx_state(); if let TransactionState::Write { .. } = state { - pager.io.block(|| pager.end_tx(true, &self.connection))?; + pager.rollback_tx(&self.connection); } return Err(LimboError::InternalError("Connection closed".to_string())); } @@ -888,7 +888,7 @@ impl Program { ), TransactionState::Read => { connection.set_tx_state(TransactionState::None); - pager.end_read_tx()?; + pager.end_read_tx(); Ok(IOResult::Done(())) } TransactionState::None => Ok(IOResult::Done(())), @@ -914,7 +914,12 @@ impl Program { connection: &Connection, rollback: bool, ) -> Result> { - let cacheflush_status = pager.end_tx(rollback, connection)?; + let cacheflush_status = if !rollback { + pager.commit_tx(connection)? + } else { + pager.rollback_tx(connection); + IOResult::Done(PagerCommitResult::Rollback) + }; match cacheflush_status { IOResult::Done(_) => { if self.change_cnt_on {