diff --git a/core/storage/pager.rs b/core/storage/pager.rs index 325f3ce2a..06dddfa9a 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -750,12 +750,14 @@ impl Pager { match state { FlushState::Start => { let db_size = header_accessor::get_database_size(self)?; - for page_id in self.dirty_pages.borrow().iter() { + for (dirty_page_idx, page_id) in self.dirty_pages.borrow().iter().enumerate() { + let is_last_frame = dirty_page_idx == self.dirty_pages.borrow().len() - 1; let mut cache = self.page_cache.write(); let page_key = PageCacheKey::new(*page_id); let page = cache.get(&page_key).expect("we somehow added a page to dirty list but we didn't mark it as dirty, causing cache to drop it."); let page_type = page.get().contents.as_ref().unwrap().maybe_page_type(); trace!("cacheflush(page={}, page_type={:?}", page_id, page_type); + let db_size = if is_last_frame { db_size } else { 0 }; self.wal.borrow_mut().append_frame( page.clone(), db_size, diff --git a/core/storage/sqlite3_ondisk.rs b/core/storage/sqlite3_ondisk.rs index 4fbd1d8ac..433fd1151 100644 --- a/core/storage/sqlite3_ondisk.rs +++ b/core/storage/sqlite3_ondisk.rs @@ -1385,7 +1385,7 @@ pub fn read_entire_wal_dumb(file: &Arc) -> Result) -> Result 0; + if is_commit_record { + wfs_data.max_frame.store(frame_idx, Ordering::SeqCst); + wfs_data.last_checksum = cumulative_checksum; + } + frame_idx += 1; current_offset += WAL_FRAME_HEADER_SIZE + page_size; } - wfs_data - .max_frame - .store(frame_idx.saturating_sub(1), Ordering::SeqCst); - wfs_data.last_checksum = cumulative_checksum; wfs_data.loaded.store(true, Ordering::SeqCst); }); let c = Completion::new(CompletionType::Read(ReadCompletion::new( diff --git a/core/storage/wal.rs b/core/storage/wal.rs index 4f4502388..2161b2617 100644 --- a/core/storage/wal.rs +++ b/core/storage/wal.rs @@ -223,6 +223,10 @@ pub trait Wal { ) -> Result>; /// Write a frame to the WAL. + /// db_size is the database size in pages after the transaction finishes. + /// db_size > 0 -> last frame written in transaction + /// db_size == 0 -> non-last frame written in transaction + /// write_counter is the counter we use to track when the I/O operation starts and completes fn append_frame( &mut self, page: PageRef, diff --git a/tests/integration/query_processing/test_write_path.rs b/tests/integration/query_processing/test_write_path.rs index c25d9034b..6099eac05 100644 --- a/tests/integration/query_processing/test_write_path.rs +++ b/tests/integration/query_processing/test_write_path.rs @@ -1,8 +1,13 @@ use crate::common::{self, maybe_setup_tracing}; use crate::common::{compare_string, do_flush, TempDatabase}; use log::debug; +use std::io::Write; +use std::os::unix::fs::FileExt; use std::sync::Arc; -use turso_core::{Connection, Row, Statement, StepResult, Value}; +use turso_core::{Connection, Database, Row, Statement, StepResult, Value}; + +const WAL_HEADER_SIZE: usize = 32; +const WAL_FRAME_HEADER_SIZE: usize = 24; #[macro_export] macro_rules! change_state { @@ -638,6 +643,96 @@ fn test_write_concurrent_connections() -> anyhow::Result<()> { Ok(()) } +#[test] +fn test_wal_bad_frame() -> anyhow::Result<()> { + maybe_setup_tracing(); + let _ = env_logger::try_init(); + let db_path = { + let tmp_db = TempDatabase::new_with_rusqlite("CREATE TABLE t1(x)", false); + let db_path = tmp_db.path.clone(); + let conn = tmp_db.connect_limbo(); + conn.execute("BEGIN")?; + conn.execute("CREATE TABLE t2(x)")?; + conn.execute("CREATE TABLE t3(x)")?; + conn.execute("INSERT INTO t2(x) VALUES (1)")?; + conn.execute("INSERT INTO t3(x) VALUES (1)")?; + conn.execute("COMMIT")?; + run_query_on_row(&tmp_db, &conn, "SELECT count(1) from t2", |row| { + let x = row.get::(0).unwrap(); + assert_eq!(x, 1); + }) + .unwrap(); + run_query_on_row(&tmp_db, &conn, "SELECT count(1) from t3", |row| { + let x = row.get::(0).unwrap(); + assert_eq!(x, 1); + }) + .unwrap(); + // Now let's modify last frame record + let path = tmp_db.path.clone(); + let path = path.with_extension("db-wal"); + let mut file = std::fs::OpenOptions::new() + .read(true) + .write(true) + .open(&path) + .unwrap(); + let offset = WAL_HEADER_SIZE + (WAL_FRAME_HEADER_SIZE + 4096) * 2; + let mut buf = [0u8; WAL_FRAME_HEADER_SIZE]; + file.read_at(&mut buf, offset as u64).unwrap(); + dbg!(&buf); + let db_size = u32::from_be_bytes(buf[4..8].try_into().unwrap()); + dbg!(offset); + assert_eq!(db_size, 4); + // let's overwrite size_after to be 0 so that we think transaction never finished + buf[4..8].copy_from_slice(&[0, 0, 0, 0]); + file.write_at(&buf, offset as u64).unwrap(); + file.flush().unwrap(); + + db_path + }; + { + let result = std::panic::catch_unwind(|| { + let io: Arc = Arc::new(limbo_core::PlatformIO::new().unwrap()); + let db = Database::open_file_with_flags( + io.clone(), + db_path.to_str().unwrap(), + limbo_core::OpenFlags::default(), + false, + false, + ) + .unwrap(); + let tmp_db = TempDatabase { + path: db_path, + io, + db, + }; + let conn = tmp_db.connect_limbo(); + run_query_on_row(&tmp_db, &conn, "SELECT count(1) from t2", |row| { + let x = row.get::(0).unwrap(); + assert_eq!(x, 0); + }) + }); + + match result { + Err(panic_info) => { + let panic_msg = panic_info + .downcast_ref::() + .map(|s| s.as_str()) + .or_else(|| panic_info.downcast_ref::<&str>().copied()) + .unwrap_or("Unknown panic message"); + + assert!( + panic_msg.contains("WAL frame checksum mismatch."), + "Expected panic message not found. Got: {}", + panic_msg + ); + } + Ok(_) => panic!("Expected query to panic, but it succeeded"), + } + } + + Ok(()) +} + fn run_query(tmp_db: &TempDatabase, conn: &Arc, query: &str) -> anyhow::Result<()> { run_query_core(tmp_db, conn, query, None::) }