diff --git a/core/lib.rs b/core/lib.rs index 737fecb98..e95184982 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -1169,7 +1169,12 @@ impl Connection { let pager = self.pager.borrow(); // remove all non-commited changes in case if WAL session left some suffix without commit frame - pager.rollback(false, self).expect("rollback must succeed"); + while let IOResult::IO = pager + .end_tx(false, false, self, false) + .expect("rollback must succeed") + { + self.run_once()?; + } let wal = pager.wal.borrow_mut(); wal.end_write_tx(); @@ -1726,13 +1731,6 @@ impl Statement { if res.is_err() { let state = self.program.connection.transaction_state.get(); if let TransactionState::Write { schema_did_change } = state { - if let Err(e) = self - .pager - .rollback(schema_did_change, &self.program.connection) - { - // Let's panic for now as we don't want to leave state in a bad state. - panic!("rollback failed: {e:?}"); - } let end_tx_res = self.pager .end_tx(true, schema_did_change, &self.program.connection, true)?; diff --git a/core/storage/pager.rs b/core/storage/pager.rs index fbdabece3..24ffb9ff5 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -818,6 +818,7 @@ impl Pager { ) -> Result> { tracing::trace!("end_tx(rollback={})", rollback); if rollback { + self.rollback(schema_did_change, connection)?; self.wal.borrow().end_write_tx(); self.wal.borrow().end_read_tx(); return Ok(IOResult::Done(PagerCommitResult::Rollback)); @@ -1797,11 +1798,7 @@ impl Pager { } #[instrument(skip_all, level = Level::DEBUG)] - pub fn rollback( - &self, - schema_did_change: bool, - connection: &Connection, - ) -> Result<(), LimboError> { + fn rollback(&self, schema_did_change: bool, connection: &Connection) -> Result<(), LimboError> { tracing::debug!(schema_did_change); self.dirty_pages.borrow_mut().clear(); let mut cache = self.page_cache.write(); diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index 8f211c482..f25946835 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -2085,7 +2085,7 @@ pub fn op_auto_commit( if *auto_commit != conn.auto_commit.get() { if *rollback { // TODO(pere): add rollback I/O logic once we implement rollback journal - pager.rollback(schema_did_change, &conn)?; + return_if_io!(pager.end_tx(true, schema_did_change, &conn, false)); conn.auto_commit.replace(true); } else { conn.auto_commit.replace(*auto_commit); diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index 1c4addfe5..6113c8e3c 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -398,7 +398,10 @@ impl Program { // Connection is closed for whatever reason, rollback the transaction. let state = self.connection.transaction_state.get(); if let TransactionState::Write { schema_did_change } = state { - pager.rollback(schema_did_change, &self.connection)? + match pager.end_tx(true, schema_did_change, &self.connection, false)? { + IOResult::IO => return Ok(StepResult::IO), + IOResult::Done(_) => {} + } } return Err(LimboError::InternalError("Connection closed".to_string())); } @@ -510,13 +513,10 @@ impl Program { connection.wal_checkpoint_disabled.get(), )?; match cacheflush_status { - IOResult::Done(status) => { + IOResult::Done(_) => { if self.change_cnt_on { self.connection.set_changes(self.n_change.get()); } - if matches!(status, pager::PagerCommitResult::Rollback) { - pager.rollback(schema_did_change, connection)?; - } connection.transaction_state.replace(TransactionState::None); *commit_state = CommitState::Ready; } @@ -761,11 +761,15 @@ pub fn handle_program_error( _ => { let state = connection.transaction_state.get(); if let TransactionState::Write { schema_did_change } = state { - if let Err(e) = pager.rollback(schema_did_change, connection) { - tracing::error!("rollback failed: {e}"); - } - if let Err(e) = pager.end_tx(false, schema_did_change, connection, false) { - tracing::error!("end_tx failed: {e}"); + loop { + match pager.end_tx(true, schema_did_change, connection, false) { + Ok(IOResult::IO) => connection.run_once()?, + Ok(IOResult::Done(_)) => break, + Err(e) => { + tracing::error!("end_tx failed: {e}"); + break; + } + } } } else if let Err(e) = pager.end_read_tx() { tracing::error!("end_read_tx failed: {e}");