make rollback non-failing method

This commit is contained in:
Nikita Sivukhin
2025-10-06 13:21:45 +04:00
parent 38d2630969
commit 8dae601fac
11 changed files with 77 additions and 97 deletions

View File

@@ -36,8 +36,7 @@ fn bench(c: &mut Criterion) {
let conn = db.conn.clone();
let tx_id = db.mvcc_store.begin_tx(conn.get_pager().clone()).unwrap();
db.mvcc_store
.rollback_tx(tx_id, conn.get_pager().clone(), &conn)
.unwrap();
.rollback_tx(tx_id, conn.get_pager().clone(), &conn);
})
});

View File

@@ -498,7 +498,7 @@ impl Database {
let result = schema
.make_from_btree(None, pager.clone(), &syms)
.or_else(|e| {
pager.end_read_tx()?;
pager.end_read_tx();
Err(e)
});
if let Err(LimboError::ExtensionError(e)) = result {
@@ -1195,11 +1195,11 @@ impl Connection {
0
}
Err(err) => {
pager.end_read_tx().expect("read txn must be finished");
pager.end_read_tx();
return Err(err);
}
};
pager.end_read_tx().expect("read txn must be finished");
pager.end_read_tx();
let db_schema_version = self.db.schema.lock().unwrap().schema_version;
tracing::debug!(
@@ -1236,7 +1236,7 @@ impl Connection {
// close opened transaction if it was kept open
// (in most cases, it will be automatically closed if stmt was executed properly)
if previous == TransactionState::Read {
pager.end_read_tx().expect("read txn must be finished");
pager.end_read_tx();
}
reparse_result?;
@@ -1654,7 +1654,7 @@ impl Connection {
let pager = self.pager.read();
pager.begin_read_tx()?;
pager.io.block(|| pager.begin_write_tx()).inspect_err(|_| {
pager.end_read_tx().expect("read txn must be closed");
pager.end_read_tx();
})?;
// start write transaction and disable auto-commit mode as SQL can be executed within WAL session (at caller own risk)
@@ -1702,13 +1702,11 @@ impl Connection {
wal.end_read_tx();
}
let rollback_err = if !force_commit {
if !force_commit {
// remove all non-commited changes in case if WAL session left some suffix without commit frame
pager.rollback(false, self, true).err()
} else {
None
};
if let Some(err) = commit_err.or(rollback_err) {
pager.rollback(false, self, true);
}
if let Some(err) = commit_err {
return Err(err);
}
}
@@ -1752,12 +1750,7 @@ impl Connection {
_ => {
if !self.mvcc_enabled() {
let pager = self.pager.read();
pager.io.block(|| {
pager.end_tx(
true, // rollback = true for close
self,
)
})?;
pager.rollback_tx(self);
}
self.set_tx_state(TransactionState::None);
}
@@ -2632,12 +2625,8 @@ impl Statement {
}
let state = self.program.connection.get_tx_state();
if let TransactionState::Write { .. } = state {
let end_tx_res = self.pager.end_tx(true, &self.program.connection)?;
self.pager.rollback_tx(&self.program.connection);
self.program.connection.set_tx_state(TransactionState::None);
assert!(
matches!(end_tx_res, IOResult::Done(_)),
"end_tx should not return IO as it should just end txn without flushing anything. Got {end_tx_res:?}"
);
}
}
res

View File

@@ -548,7 +548,7 @@ impl<Clock: LogicalClock> CheckpointStateMachine<Clock> {
CheckpointState::CommitPagerTxn => {
tracing::debug!("Committing pager transaction");
let result = self.pager.end_tx(false, &self.connection)?;
let result = self.pager.commit_tx(&self.connection)?;
match result {
IOResult::Done(_) => {
self.state = CheckpointState::TruncateLogicalLog;
@@ -642,16 +642,12 @@ impl<Clock: LogicalClock> StateTransition for CheckpointStateMachine<Clock> {
Err(err) => {
tracing::info!("Error in checkpoint state machine: {err}");
if self.lock_states.pager_write_tx {
let rollback = true;
self.pager
.io
.block(|| self.pager.end_tx(rollback, self.connection.as_ref()))
.expect("failed to end pager write tx");
self.pager.rollback_tx(self.connection.as_ref());
if self.update_transaction_state {
*self.connection.transaction_state.write() = TransactionState::None;
}
} else if self.lock_states.pager_read_tx {
self.pager.end_read_tx().unwrap();
self.pager.end_read_tx();
if self.update_transaction_state {
*self.connection.transaction_state.write() = TransactionState::None;
}

View File

@@ -1566,12 +1566,7 @@ impl<Clock: LogicalClock> MvStore<Clock> {
/// # Arguments
///
/// * `tx_id` - The ID of the transaction to abort.
pub fn rollback_tx(
&self,
tx_id: TxID,
_pager: Arc<Pager>,
connection: &Connection,
) -> Result<()> {
pub fn rollback_tx(&self, tx_id: TxID, _pager: Arc<Pager>, connection: &Connection) {
let tx_unlocked = self.txs.get(&tx_id).unwrap();
let tx = tx_unlocked.value();
*connection.mv_tx.write() = None;
@@ -1615,8 +1610,6 @@ impl<Clock: LogicalClock> MvStore<Clock> {
// FIXME: verify that we can already remove the transaction here!
// Maybe it's fine for snapshot isolation, but too early for serializable?
self.remove_tx(tx_id);
Ok(())
}
/// Returns true if the given transaction is the exclusive transaction.

View File

@@ -347,8 +347,7 @@ fn test_rollback() {
.unwrap();
assert_eq!(row3, row4);
db.mvcc_store
.rollback_tx(tx1, db.conn.pager.read().clone(), &db.conn)
.unwrap();
.rollback_tx(tx1, db.conn.pager.read().clone(), &db.conn);
let tx2 = db
.mvcc_store
.begin_tx(db.conn.pager.read().clone())
@@ -592,8 +591,7 @@ fn test_lost_update() {
));
// hack: in the actual tursodb database we rollback the mvcc tx ourselves, so manually roll it back here
db.mvcc_store
.rollback_tx(tx3, conn3.pager.read().clone(), &conn3)
.unwrap();
.rollback_tx(tx3, conn3.pager.read().clone(), &conn3);
commit_tx(db.mvcc_store.clone(), &conn2, tx2).unwrap();
assert!(matches!(

View File

@@ -472,7 +472,7 @@ impl Schema {
pager.io.block(|| cursor.next())?;
}
pager.end_read_tx()?;
pager.end_read_tx();
self.populate_indices(from_sql_indexes, automatic_indices)?;

View File

@@ -8183,7 +8183,7 @@ mod tests {
// force allocate page1 with a transaction
pager.begin_read_tx().unwrap();
run_until_done(|| pager.begin_write_tx(), &pager).unwrap();
run_until_done(|| pager.end_tx(false, &conn), &pager).unwrap();
run_until_done(|| pager.commit_tx(&conn), &pager).unwrap();
let page2 = run_until_done(|| pager.allocate_page(), &pager).unwrap();
btree_init_page(&page2, PageType::TableLeaf, 0, pager.usable_space());
@@ -8495,7 +8495,7 @@ mod tests {
pager.deref(),
)
.unwrap();
pager.io.block(|| pager.end_tx(false, &conn)).unwrap();
pager.io.block(|| pager.commit_tx(&conn)).unwrap();
pager.begin_read_tx().unwrap();
// FIXME: add sorted vector instead, should be okay for small amounts of keys for now :P, too lazy to fix right now
let _c = cursor.move_to_root().unwrap();
@@ -8524,7 +8524,7 @@ mod tests {
println!("btree after:\n{btree_after}");
panic!("invalid btree");
}
pager.end_read_tx().unwrap();
pager.end_read_tx();
}
pager.begin_read_tx().unwrap();
tracing::info!(
@@ -8546,7 +8546,7 @@ mod tests {
"key {key} is not found, got {cursor_rowid}"
);
}
pager.end_read_tx().unwrap();
pager.end_read_tx();
}
}
@@ -8641,7 +8641,7 @@ mod tests {
if let Some(c) = c {
pager.io.wait_for_completion(c).unwrap();
}
pager.io.block(|| pager.end_tx(false, &conn)).unwrap();
pager.io.block(|| pager.commit_tx(&conn)).unwrap();
}
// Check that all keys can be found by seeking
@@ -8702,7 +8702,7 @@ mod tests {
}
prev = Some(cur);
}
pager.end_read_tx().unwrap();
pager.end_read_tx();
}
}
@@ -8848,7 +8848,7 @@ mod tests {
if let Some(c) = c {
pager.io.wait_for_completion(c).unwrap();
}
pager.io.block(|| pager.end_tx(false, &conn)).unwrap();
pager.io.block(|| pager.commit_tx(&conn)).unwrap();
}
// Final validation
@@ -8856,7 +8856,7 @@ mod tests {
sorted_keys.sort();
validate_expected_keys(&pager, &mut cursor, &sorted_keys, seed);
pager.end_read_tx().unwrap();
pager.end_read_tx();
}
}
@@ -8939,7 +8939,7 @@ mod tests {
"key {key:?} is not found, seed: {seed}"
);
}
pager.end_read_tx().unwrap();
pager.end_read_tx();
}
#[test]

View File

@@ -1161,33 +1161,20 @@ impl Pager {
}
#[instrument(skip_all, level = Level::DEBUG)]
pub fn end_tx(
&self,
rollback: bool,
connection: &Connection,
) -> Result<IOResult<PagerCommitResult>> {
pub fn commit_tx(&self, connection: &Connection) -> Result<IOResult<PagerCommitResult>> {
if connection.is_nested_stmt.load(Ordering::SeqCst) {
// Parent statement will handle the transaction rollback.
return Ok(IOResult::Done(PagerCommitResult::Rollback));
}
tracing::trace!("end_tx(rollback={})", rollback);
let Some(wal) = self.wal.as_ref() else {
// TODO: Unsure what the semantics of "end_tx" is for in-memory databases, ephemeral tables and ephemeral indexes.
return Ok(IOResult::Done(PagerCommitResult::Rollback));
};
let (is_write, schema_did_change) = match connection.get_tx_state() {
let (_, schema_did_change) = match connection.get_tx_state() {
TransactionState::Write { schema_did_change } => (true, schema_did_change),
_ => (false, false),
};
tracing::trace!("end_tx(schema_did_change={})", schema_did_change);
if rollback {
if is_write {
wal.borrow().end_write_tx();
}
wal.borrow().end_read_tx();
self.rollback(schema_did_change, connection, is_write)?;
return Ok(IOResult::Done(PagerCommitResult::Rollback));
}
tracing::trace!("commit_tx(schema_did_change={})", schema_did_change);
let commit_status = return_if_io!(self.commit_dirty_pages(
connection.is_wal_auto_checkpoint_disabled(),
connection.get_sync_mode(),
@@ -1204,12 +1191,33 @@ impl Pager {
}
#[instrument(skip_all, level = Level::DEBUG)]
pub fn end_read_tx(&self) -> Result<()> {
pub fn rollback_tx(&self, connection: &Connection) {
if connection.is_nested_stmt.load(Ordering::SeqCst) {
// Parent statement will handle the transaction rollback.
return;
}
let Some(wal) = self.wal.as_ref() else {
return Ok(());
// TODO: Unsure what the semantics of "end_tx" is for in-memory databases, ephemeral tables and ephemeral indexes.
return;
};
let (is_write, schema_did_change) = match connection.get_tx_state() {
TransactionState::Write { schema_did_change } => (true, schema_did_change),
_ => (false, false),
};
tracing::trace!("rollback_tx(schema_did_change={})", schema_did_change);
if is_write {
wal.borrow().end_write_tx();
}
wal.borrow().end_read_tx();
self.rollback(schema_did_change, connection, is_write);
}
#[instrument(skip_all, level = Level::DEBUG)]
pub fn end_read_tx(&self) {
let Some(wal) = self.wal.as_ref() else {
return;
};
wal.borrow().end_read_tx();
Ok(())
}
/// Reads a page from disk (either WAL or DB file) bypassing page-cache
@@ -2393,12 +2401,7 @@ impl Pager {
}
#[instrument(skip_all, level = Level::DEBUG)]
pub fn rollback(
&self,
schema_did_change: bool,
connection: &Connection,
is_write: bool,
) -> Result<(), LimboError> {
pub fn rollback(&self, schema_did_change: bool, connection: &Connection, is_write: bool) {
tracing::debug!(schema_did_change);
self.clear_page_cache();
if is_write {
@@ -2415,11 +2418,9 @@ impl Pager {
}
if is_write {
if let Some(wal) = self.wal.as_ref() {
wal.borrow_mut().rollback()?;
wal.borrow_mut().rollback();
}
}
Ok(())
}
fn reset_internal_states(&self) {
@@ -2764,7 +2765,7 @@ mod ptrmap_tests {
use super::*;
use crate::io::{MemoryIO, OpenFlags, IO};
use crate::storage::buffer_pool::BufferPool;
use crate::storage::database::{DatabaseFile, DatabaseStorage};
use crate::storage::database::DatabaseFile;
use crate::storage::page_cache::PageCache;
use crate::storage::pager::Pager;
use crate::storage::sqlite3_ondisk::PageSize;

View File

@@ -302,7 +302,7 @@ pub trait Wal: Debug {
fn get_checkpoint_seq(&self) -> u32;
fn get_max_frame(&self) -> u64;
fn get_min_frame(&self) -> u64;
fn rollback(&mut self) -> Result<()>;
fn rollback(&mut self);
/// Return unique set of pages changed **after** frame_watermark position and until current WAL session max_frame_no
fn changed_pages_after(&self, frame_watermark: u64) -> Result<Vec<u32>>;
@@ -1351,8 +1351,8 @@ impl Wal for WalFile {
self.min_frame.load(Ordering::Acquire)
}
#[instrument(err, skip_all, level = Level::DEBUG)]
fn rollback(&mut self) -> Result<()> {
#[instrument(skip_all, level = Level::DEBUG)]
fn rollback(&mut self) {
let (max_frame, last_checksum) = {
let shared = self.get_shared();
let max_frame = shared.max_frame.load(Ordering::Acquire);
@@ -1369,7 +1369,6 @@ impl Wal for WalFile {
self.last_checksum = last_checksum;
self.max_frame.store(max_frame, Ordering::Release);
self.reset_internal_states();
Ok(())
}
#[instrument(skip_all, level = Level::DEBUG)]
@@ -2825,7 +2824,7 @@ pub mod test {
}
}
drop(w);
conn2.pager.write().end_read_tx().unwrap();
conn2.pager.write().end_read_tx();
conn1
.execute("create table test(id integer primary key, value text)")

View File

@@ -2372,7 +2372,7 @@ pub fn op_transaction_inner(
// That is, if the transaction had not started, end the read transaction so that next time we
// start a new one.
if matches!(current_state, TransactionState::None) {
pager.end_read_tx()?;
pager.end_read_tx();
conn.set_tx_state(TransactionState::None);
}
assert_eq!(conn.get_tx_state(), current_state);
@@ -2456,10 +2456,10 @@ pub fn op_auto_commit(
// TODO(pere): add rollback I/O logic once we implement rollback journal
if let Some(mv_store) = mv_store {
if let Some(tx_id) = conn.get_mv_tx_id() {
mv_store.rollback_tx(tx_id, pager.clone(), &conn)?;
mv_store.rollback_tx(tx_id, pager.clone(), &conn);
}
} else {
return_if_io!(pager.end_tx(true, &conn));
pager.rollback_tx(&conn);
}
conn.set_tx_state(TransactionState::None);
conn.auto_commit.store(true, Ordering::SeqCst);

View File

@@ -30,7 +30,7 @@ use crate::{
function::{AggFunc, FuncCtx},
mvcc::{database::CommitStateMachine, LocalClock},
state_machine::StateMachine,
storage::sqlite3_ondisk::SmallVec,
storage::{pager::PagerCommitResult, sqlite3_ondisk::SmallVec},
translate::{collate::CollationSeq, plan::TableReferences},
types::{IOCompletions, IOResult, RawSlice, TextRef},
vdbe::{
@@ -41,7 +41,7 @@ use crate::{
},
metrics::StatementMetrics,
},
IOExt, RefValue,
RefValue,
};
use crate::{
@@ -533,7 +533,7 @@ impl Program {
// Connection is closed for whatever reason, rollback the transaction.
let state = self.connection.get_tx_state();
if let TransactionState::Write { .. } = state {
pager.io.block(|| pager.end_tx(true, &self.connection))?;
pager.rollback_tx(&self.connection);
}
return Err(LimboError::InternalError("Connection closed".to_string()));
}
@@ -588,7 +588,7 @@ impl Program {
// Connection is closed for whatever reason, rollback the transaction.
let state = self.connection.get_tx_state();
if let TransactionState::Write { .. } = state {
pager.io.block(|| pager.end_tx(true, &self.connection))?;
pager.rollback_tx(&self.connection);
}
return Err(LimboError::InternalError("Connection closed".to_string()));
}
@@ -636,7 +636,7 @@ impl Program {
// Connection is closed for whatever reason, rollback the transaction.
let state = self.connection.get_tx_state();
if let TransactionState::Write { .. } = state {
pager.io.block(|| pager.end_tx(true, &self.connection))?;
pager.rollback_tx(&self.connection);
}
return Err(LimboError::InternalError("Connection closed".to_string()));
}
@@ -888,7 +888,7 @@ impl Program {
),
TransactionState::Read => {
connection.set_tx_state(TransactionState::None);
pager.end_read_tx()?;
pager.end_read_tx();
Ok(IOResult::Done(()))
}
TransactionState::None => Ok(IOResult::Done(())),
@@ -914,7 +914,12 @@ impl Program {
connection: &Connection,
rollback: bool,
) -> Result<IOResult<()>> {
let cacheflush_status = pager.end_tx(rollback, connection)?;
let cacheflush_status = if !rollback {
pager.commit_tx(connection)?
} else {
pager.rollback_tx(connection);
IOResult::Done(PagerCommitResult::Rollback)
};
match cacheflush_status {
IOResult::Done(_) => {
if self.change_cnt_on {