From 2b73260dd97d7463ea702d58f67b6c45ef4fbff2 Mon Sep 17 00:00:00 2001 From: Jussi Saurio Date: Wed, 22 Oct 2025 09:45:52 +0300 Subject: [PATCH] Handle cases where DB grows or shrinks due to savepoint rollback --- core/storage/pager.rs | 41 ++++-- core/vdbe/execute.rs | 9 +- core/vdbe/mod.rs | 20 +-- .../query_processing/test_write_path.rs | 128 +++++++++++++++++- 4 files changed, 172 insertions(+), 26 deletions(-) diff --git a/core/storage/pager.rs b/core/storage/pager.rs index d2a813525..a6c3891dc 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -473,14 +473,19 @@ pub struct Savepoint { write_offset: AtomicU64, /// Bitmap of page numbers that are dirty in the savepoint. page_bitmap: RwLock, + /// Database size at the start of the savepoint. + /// If the database grows during the savepoint and a rollback to the savepoint is performed, + /// the pages exceeding the database size at the start of the savepoint will be ignored. + db_size: AtomicU32, } impl Savepoint { - pub fn new(subjournal_offset: u64) -> Self { + pub fn new(subjournal_offset: u64, db_size: u32) -> Self { Self { start_offset: AtomicU64::new(subjournal_offset), write_offset: AtomicU64::new(subjournal_offset), page_bitmap: RwLock::new(RoaringBitmap::new()), + db_size: AtomicU32::new(db_size), } } @@ -660,9 +665,9 @@ impl Pager { }) } - pub fn begin_statement(&self) -> Result<()> { + pub fn begin_statement(&self, db_size: u32) -> Result<()> { self.open_subjournal()?; - self.open_savepoint()?; + self.open_savepoint(db_size)?; Ok(()) } @@ -756,14 +761,14 @@ impl Pager { Ok(()) } - pub fn open_savepoint(&self) -> Result<()> { + pub fn open_savepoint(&self, db_size: u32) -> Result<()> { self.open_subjournal()?; let subjournal_offset = self.subjournal.read().as_ref().unwrap().size()?; // Currently as we only have anonymous savepoints opened at the start of a statement, // the subjournal offset should always be 0 as we should only have max 1 savepoint // opened at any given time. turso_assert!(subjournal_offset == 0, "subjournal offset should be 0"); - let savepoint = Savepoint::new(subjournal_offset); + let savepoint = Savepoint::new(subjournal_offset, db_size); let mut savepoints = self.savepoints.write(); turso_assert!( savepoints.is_empty(), @@ -824,6 +829,9 @@ impl Pager { let mut current_offset = journal_start_offset; let page_size = self.page_size.load(Ordering::SeqCst) as u64; let journal_end_offset = savepoint.write_offset.load(Ordering::SeqCst); + let db_size = savepoint.db_size.load(Ordering::SeqCst); + + let mut dirty_pages = self.dirty_pages.write(); while current_offset < journal_end_offset { // Read 4 bytes for page id @@ -833,12 +841,27 @@ impl Pager { let page_id = u32::from_be_bytes(page_id_buffer.as_slice()[0..4].try_into().unwrap()); current_offset += 4; - // Check if we've already rolled back this page - if rollback_bitset.contains(page_id) { - // Skip reading the page, just advance offset + // Check if we've already rolled back this page or if the page is beyond the database size at the start of the savepoint + let already_rolled_back = rollback_bitset.contains(page_id); + if already_rolled_back { current_offset += page_size; continue; } + let page_wont_exist_after_rollback = page_id > db_size; + if page_wont_exist_after_rollback { + dirty_pages.remove(&(page_id as usize)); + if let Some(page) = self + .page_cache + .write() + .get(&PageCacheKey::new(page_id as usize))? + { + page.clear_dirty(); + page.try_unpin(); + } + current_offset += page_size; + rollback_bitset.insert(page_id); + continue; + } // Read the page data let page_buffer = Arc::new(self.buffer_pool.allocate(page_size as usize)); @@ -870,6 +893,8 @@ impl Pager { "memory IO should complete immediately" ); + self.page_cache.write().truncate(db_size as usize)?; + Ok(()) } diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index d90765c62..503c9c016 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -2266,6 +2266,7 @@ pub fn op_halt_if_null( pub enum OpTransactionState { Start, CheckSchemaCookie, + BeginStatement, } pub fn op_transaction( @@ -2471,9 +2472,15 @@ pub fn op_transaction_inner( } } + state.op_transaction_state = OpTransactionState::BeginStatement; + } + OpTransactionState::BeginStatement => { if program.needs_stmt_subtransactions && mv_store.is_none() { let write = matches!(tx_mode, TransactionMode::Write); - state.begin_statement(&program.connection, &pager, write)?; + let res = state.begin_statement(&program.connection, &pager, write)?; + if let IOResult::IO(io) = res { + return Ok(InsnFunctionStepResult::IO(io)); + } } state.pc += 1; diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index 39109fd1f..68981dd70 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -29,6 +29,7 @@ use crate::{ error::LimboError, function::{AggFunc, FuncCtx}, mvcc::{database::CommitStateMachine, LocalClock}, + return_if_io, state_machine::StateMachine, storage::{pager::PagerCommitResult, sqlite3_ondisk::SmallVec}, translate::{collate::CollationSeq, plan::TableReferences}, @@ -179,18 +180,6 @@ pub enum StepResult { Busy, } -/// If there is I/O, the instruction is restarted. -/// Evaluate a Result>, if IO return Ok(StepResult::IO). -#[macro_export] -macro_rules! return_step_if_io { - ($expr:expr) => { - match $expr? { - IOResult::Ok(v) => v, - IOResult::IO => return Ok(StepResult::IO), - } - }; -} - struct RegexCache { like: HashMap, glob: HashMap, @@ -482,7 +471,7 @@ impl ProgramState { connection: &Connection, pager: &Arc, write: bool, - ) -> Result<()> { + ) -> Result> { // Store the deferred foreign key violations counter at the start of the statement. // This is used to ensure that if an interactive transaction had deferred FK violations and a statement subtransaction rolls back, // the deferred FK violations are not lost. @@ -494,9 +483,10 @@ impl ProgramState { self.fk_immediate_violations_during_stmt .store(0, Ordering::SeqCst); if write { - pager.begin_statement()?; + let db_size = return_if_io!(pager.with_header(|header| header.database_size.get())); + pager.begin_statement(db_size)?; } - Ok(()) + Ok(IOResult::Done(())) } /// End a statement subtransaction. diff --git a/tests/integration/query_processing/test_write_path.rs b/tests/integration/query_processing/test_write_path.rs index 85f666ee4..b9384cdc3 100644 --- a/tests/integration/query_processing/test_write_path.rs +++ b/tests/integration/query_processing/test_write_path.rs @@ -1,10 +1,11 @@ -use crate::common::{self, limbo_exec_rows, maybe_setup_tracing}; +use crate::common::{self, limbo_exec_rows, maybe_setup_tracing, rusqlite_integrity_check}; use crate::common::{compare_string, do_flush, TempDatabase}; use log::debug; use std::io::{Read, Seek, Write}; use std::sync::Arc; use turso_core::{ - CheckpointMode, Connection, Database, LimboError, Row, Statement, StepResult, Value, + CheckpointMode, Connection, Database, DatabaseOpts, LimboError, Row, Statement, StepResult, + Value, }; const WAL_HEADER_SIZE: usize = 32; @@ -508,6 +509,129 @@ fn test_update_regression() -> anyhow::Result<()> { Ok(()) } +#[test] +/// Test that a large insert statement containing a UNIQUE constraint violation +/// is properly rolled back so that the database size is also shrunk to the size +/// before that statement is executed. +fn test_rollback_on_unique_constraint_violation() -> anyhow::Result<()> { + let _ = env_logger::try_init(); + let tmp_db = TempDatabase::new_with_opts( + "big_statement_rollback.db", + DatabaseOpts::new().with_indexes(true), + ); + let conn = tmp_db.connect_limbo(); + + conn.execute("CREATE TABLE t(x UNIQUE)")?; + + conn.execute("BEGIN")?; + conn.execute("INSERT INTO t VALUES (10000)")?; + + // This should fail due to unique constraint violation + let result = conn.execute("INSERT INTO t SELECT value FROM generate_series(1,10000)"); + assert!(result.is_err(), "Expected unique constraint violation"); + + conn.execute("COMMIT")?; + + // Should have exactly 1 row (the first insert) + common::run_query_on_row(&tmp_db, &conn, "SELECT count(*) FROM t", |row| { + let count = row.get::(0).unwrap(); + assert_eq!(count, 1, "Expected 1 row after rollback"); + })?; + + // Check page count + common::run_query_on_row(&tmp_db, &conn, "PRAGMA page_count", |row| { + let page_count = row.get::(0).unwrap(); + assert_eq!(page_count, 3, "Expected 3 pages"); + })?; + + // Checkpoint the WAL + conn.execute("PRAGMA wal_checkpoint(TRUNCATE)")?; + + // Integrity check with rusqlite + rusqlite_integrity_check(tmp_db.path.as_path())?; + + // Size on disk should be 3 * 4096 + let db_size = std::fs::metadata(&tmp_db.path).unwrap().len(); + assert_eq!(db_size, 3 * 4096); + + Ok(()) +} + +#[test] +/// Test that a large delete statement containing a foreign key constraint violation +/// is properly rolled back. +fn test_rollback_on_foreign_key_constraint_violation() -> anyhow::Result<()> { + let _ = env_logger::try_init(); + let tmp_db = TempDatabase::new_with_opts( + "big_delete_rollback.db", + DatabaseOpts::new().with_indexes(true), + ); + let conn = tmp_db.connect_limbo(); + + // Enable foreign keys + conn.execute("PRAGMA foreign_keys = ON")?; + + // Create parent and child tables + conn.execute("CREATE TABLE parent(id INTEGER PRIMARY KEY)")?; + conn.execute( + "CREATE TABLE child(id INTEGER PRIMARY KEY, parent_id INTEGER REFERENCES parent(id))", + )?; + + // Insert 10000 parent rows + conn.execute("INSERT INTO parent SELECT value FROM generate_series(1,10000)")?; + + // Insert a child row that references the 10000th parent row + conn.execute("INSERT INTO child VALUES (1, 10000)")?; + + conn.execute("BEGIN")?; + + // Delete first parent row (should succeed) + conn.execute("DELETE FROM parent WHERE id = 1")?; + + // This should fail due to foreign key constraint violation (trying to delete parent row 10000 which has a child) + let result = conn.execute("DELETE FROM parent WHERE id >= 2"); + assert!(result.is_err(), "Expected foreign key constraint violation"); + + conn.execute("COMMIT")?; + + // Should have 9999 parent rows (10000 - 1 that was successfully deleted) + common::run_query_on_row(&tmp_db, &conn, "SELECT count(*) FROM parent", |row| { + let count = row.get::(0).unwrap(); + assert_eq!(count, 9999, "Expected 9999 parent rows after rollback"); + })?; + + // Verify rows 2-10000 are intact + common::run_query_on_row( + &tmp_db, + &conn, + "SELECT min(id), max(id) FROM parent", + |row| { + let min_id = row.get::(0).unwrap(); + let max_id = row.get::(1).unwrap(); + assert_eq!(min_id, 2, "Expected min id to be 2"); + assert_eq!(max_id, 10000, "Expected max id to be 10000"); + }, + )?; + + // Child row should still exist + common::run_query_on_row(&tmp_db, &conn, "SELECT count(*) FROM child", |row| { + let count = row.get::(0).unwrap(); + assert_eq!(count, 1, "Expected 1 child row"); + })?; + + // Checkpoint the WAL + conn.execute("PRAGMA wal_checkpoint(TRUNCATE)")?; + + // Integrity check with rusqlite + rusqlite_integrity_check(tmp_db.path.as_path())?; + + // Size on disk should be 21 * 4096 + let db_size = std::fs::metadata(&tmp_db.path).unwrap().len(); + assert_eq!(db_size, 21 * 4096); + + Ok(()) +} + #[test] fn test_multiple_statements() -> anyhow::Result<()> { let _ = env_logger::try_init();