diff --git a/core/lib.rs b/core/lib.rs index ada35e497..c52bdb2d7 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -1360,10 +1360,12 @@ impl Connection { self.closed.set(true); match self.transaction_state.get() { - TransactionState::Write { schema_did_change } => { + TransactionState::None => { + // No active transaction + } + _ => { while let IOResult::IO = self.pager.borrow().end_tx( true, // rollback = true for close - schema_did_change, self, self.wal_checkpoint_disabled.get(), )? { @@ -1371,13 +1373,6 @@ impl Connection { } self.transaction_state.set(TransactionState::None); } - TransactionState::PendingUpgrade | TransactionState::Read => { - self.pager.borrow().end_read_tx()?; - self.transaction_state.set(TransactionState::None); - } - TransactionState::None => { - // No active transaction - } } self.pager @@ -1936,7 +1931,9 @@ impl Statement { res } + #[instrument(skip_all, level = Level::DEBUG)] fn reprepare(&mut self) -> Result<()> { + tracing::trace!("repreparing statement"); let conn = self.program.connection.clone(); *conn.schema.borrow_mut() = conn._db.clone_schema()?; self.program = { @@ -1972,10 +1969,8 @@ impl Statement { let res = self.pager.io.run_once(); if res.is_err() { let state = self.program.connection.transaction_state.get(); - if let TransactionState::Write { schema_did_change } = state { - let end_tx_res = - self.pager - .end_tx(true, schema_did_change, &self.program.connection, true)?; + if let TransactionState::Write { .. } = state { + let end_tx_res = self.pager.end_tx(true, &self.program.connection, true)?; self.program .connection .transaction_state diff --git a/core/mvcc/database/mod.rs b/core/mvcc/database/mod.rs index 4466d51d2..5605c6d1c 100644 --- a/core/mvcc/database/mod.rs +++ b/core/mvcc/database/mod.rs @@ -514,7 +514,6 @@ impl StateTransition for CommitStateMachine { .pager .end_tx( false, // rollback = false since we're committing - false, // schema_did_change = false for now (could be improved) &self.connection, self.connection.wal_checkpoint_disabled.get(), ) diff --git a/core/storage/btree.rs b/core/storage/btree.rs index 33788b211..c90b45968 100644 --- a/core/storage/btree.rs +++ b/core/storage/btree.rs @@ -7824,7 +7824,7 @@ mod tests { ) .unwrap(); loop { - match pager.end_tx(false, false, &conn, false).unwrap() { + match pager.end_tx(false, &conn, false).unwrap() { IOResult::Done(_) => break, IOResult::IO => { pager.io.run_once().unwrap(); @@ -7982,7 +7982,7 @@ mod tests { .unwrap(); let _c = cursor.move_to_root().unwrap(); loop { - match pager.end_tx(false, false, &conn, false).unwrap() { + match pager.end_tx(false, &conn, false).unwrap() { IOResult::Done(_) => break, IOResult::IO => { pager.io.run_once().unwrap(); @@ -8200,7 +8200,7 @@ mod tests { let _c = cursor.move_to_root().unwrap(); loop { - match pager.end_tx(false, false, &conn, false).unwrap() { + match pager.end_tx(false, &conn, false).unwrap() { IOResult::Done(_) => break, IOResult::IO => { pager.io.run_once().unwrap(); diff --git a/core/storage/pager.rs b/core/storage/pager.rs index 974df2aac..575626be5 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -1009,7 +1009,6 @@ impl Pager { pub fn end_tx( &self, rollback: bool, - schema_did_change: bool, connection: &Connection, wal_checkpoint_disabled: bool, ) -> Result> { @@ -1018,11 +1017,11 @@ impl Pager { // TODO: Unsure what the semantics of "end_tx" is for in-memory databases, ephemeral tables and ephemeral indexes. return Ok(IOResult::Done(PagerCommitResult::Rollback)); }; + let (is_write, schema_did_change) = match connection.transaction_state.get() { + TransactionState::Write { schema_did_change } => (true, schema_did_change), + _ => (false, false), + }; if rollback { - let is_write = matches!( - connection.transaction_state.get(), - TransactionState::Write { .. } - ); if is_write { wal.borrow().end_write_tx(); } diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index 9d9c8caab..cf7aad99b 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -2030,6 +2030,7 @@ pub fn op_transaction( match pager.begin_write_tx()? { IOResult::Done(r) => { if let LimboResult::Busy = r { + tracing::error!("connection is busy"); pager.end_read_tx()?; conn.transaction_state.replace(TransactionState::None); conn.auto_commit.replace(true); @@ -2048,6 +2049,11 @@ pub fn op_transaction( } } + // Transaction state should be updated before checking for Schema cookie so that the tx is ended properly on error + if updated { + conn.transaction_state.replace(new_transaction_state); + } + // Check whether schema has changed if we are actually going to access the database. if !matches!(new_transaction_state, TransactionState::None) { let res = pager @@ -2066,10 +2072,6 @@ pub fn op_transaction( } } } - - if updated { - conn.transaction_state.replace(new_transaction_state); - } } state.pc += 1; Ok(InsnFunctionStepResult::Step) @@ -2095,17 +2097,11 @@ pub fn op_auto_commit( .commit_txn(pager.clone(), state, mv_store, *rollback) .map(Into::into); } - let schema_did_change = - if let TransactionState::Write { schema_did_change } = conn.transaction_state.get() { - schema_did_change - } else { - false - }; if *auto_commit != conn.auto_commit.get() { if *rollback { // TODO(pere): add rollback I/O logic once we implement rollback journal - return_if_io!(pager.end_tx(true, schema_did_change, &conn, false)); + return_if_io!(pager.end_tx(true, &conn, false)); conn.transaction_state.replace(TransactionState::None); conn.auto_commit.replace(true); } else { diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index f0adcebe8..9760c7d60 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -420,8 +420,8 @@ impl Program { if self.connection.closed.get() { // Connection is closed for whatever reason, rollback the transaction. let state = self.connection.transaction_state.get(); - if let TransactionState::Write { schema_did_change } = state { - match pager.end_tx(true, schema_did_change, &self.connection, false)? { + if let TransactionState::Write { .. } = state { + match pager.end_tx(true, &self.connection, false)? { IOResult::IO => return Ok(StepResult::IO), IOResult::Done(_) => {} } @@ -510,9 +510,7 @@ impl Program { program_state.commit_state ); if program_state.commit_state == CommitState::Committing { - let TransactionState::Write { schema_did_change } = - connection.transaction_state.get() - else { + let TransactionState::Write { .. } = connection.transaction_state.get() else { unreachable!("invalid state for write commit step") }; self.step_end_write_txn( @@ -520,18 +518,16 @@ impl Program { &mut program_state.commit_state, &connection, rollback, - schema_did_change, ) } else if auto_commit { let current_state = connection.transaction_state.get(); tracing::trace!("Auto-commit state: {:?}", current_state); match current_state { - TransactionState::Write { schema_did_change } => self.step_end_write_txn( + TransactionState::Write { .. } => self.step_end_write_txn( &pager, &mut program_state.commit_state, &connection, rollback, - schema_did_change, ), TransactionState::Read => { connection.transaction_state.replace(TransactionState::None); @@ -559,11 +555,9 @@ impl Program { commit_state: &mut CommitState, connection: &Connection, rollback: bool, - schema_did_change: bool, ) -> Result { let cacheflush_status = pager.end_tx( rollback, - schema_did_change, connection, connection.wal_checkpoint_disabled.get(), )?; @@ -817,20 +811,15 @@ pub fn handle_program_error( // Table locked errors, e.g. trying to checkpoint in an interactive transaction, do not cause a rollback. LimboError::TableLocked => {} _ => { - let state = connection.transaction_state.get(); - if let TransactionState::Write { schema_did_change } = state { - loop { - match pager.end_tx(true, schema_did_change, connection, false) { - Ok(IOResult::IO) => connection.run_once()?, - Ok(IOResult::Done(_)) => break, - Err(e) => { - tracing::error!("end_tx failed: {e}"); - break; - } + loop { + match pager.end_tx(true, connection, false) { + Ok(IOResult::IO) => connection.run_once()?, + Ok(IOResult::Done(_)) => break, + Err(e) => { + tracing::error!("end_tx failed: {e}"); + break; } } - } else if let Err(e) = pager.end_read_tx() { - tracing::error!("end_read_tx failed: {e}"); } connection.transaction_state.replace(TransactionState::None); } diff --git a/tests/integration/common.rs b/tests/integration/common.rs index 058d44787..83e38429a 100644 --- a/tests/integration/common.rs +++ b/tests/integration/common.rs @@ -153,7 +153,7 @@ pub fn maybe_setup_tracing() { let _ = tracing_subscriber::registry() .with( tracing_subscriber::fmt::layer() - .with_ansi(false) + .with_ansi(true) .with_line_number(true) .with_thread_ids(true), ) diff --git a/tests/integration/query_processing/test_multi_thread.rs b/tests/integration/query_processing/test_multi_thread.rs index c509fe0da..4450e5d2a 100644 --- a/tests/integration/query_processing/test_multi_thread.rs +++ b/tests/integration/query_processing/test_multi_thread.rs @@ -196,3 +196,41 @@ fn test_reader_writer() -> anyhow::Result<()> { } Ok(()) } + +#[test] +fn test_schema_reprepare_write() { + maybe_setup_tracing(); + let tmp_db = TempDatabase::new_empty(false); + let conn1 = tmp_db.connect_limbo(); + conn1.execute("CREATE TABLE t(x, y, z)").unwrap(); + let conn2 = tmp_db.connect_limbo(); + let mut stmt = conn2.prepare("INSERT INTO t(y, z) VALUES (1, 2)").unwrap(); + let mut stmt2 = conn2.prepare("INSERT INTO t(y, z) VALUES (3, 4)").unwrap(); + conn1.execute("ALTER TABLE t DROP COLUMN x").unwrap(); + + tracing::info!("Executing Stmt 1"); + loop { + match stmt.step().unwrap() { + turso_core::StepResult::Done => { + break; + } + turso_core::StepResult::IO => { + stmt.run_once().unwrap(); + } + step => panic!("unexpected step result {step:?}"), + } + } + + tracing::info!("Executing Stmt 2"); + loop { + match stmt2.step().unwrap() { + turso_core::StepResult::Done => { + break; + } + turso_core::StepResult::IO => { + stmt2.run_once().unwrap(); + } + step => panic!("unexpected step result {step:?}"), + } + } +}