diff --git a/core/storage/page_cache.rs b/core/storage/page_cache.rs index e08c979e8..ca2546522 100644 --- a/core/storage/page_cache.rs +++ b/core/storage/page_cache.rs @@ -1,8 +1,11 @@ +use std::sync::atomic::Ordering; use std::{cell::RefCell, ptr::NonNull}; use std::sync::Arc; use tracing::{debug, trace}; +use crate::turso_assert; + use super::pager::PageRef; /// FIXME: https://github.com/tursodatabase/turso/issues/1661 @@ -343,13 +346,43 @@ impl DumbLruPageCache { Ok(()) } + /// Removes all pages from the cache with pgno greater than len + pub fn truncate(&mut self, len: usize) -> Result<(), CacheError> { + let head_ptr = *self.head.borrow(); + let mut current = head_ptr; + while let Some(node) = current { + let node_ref = unsafe { node.as_ref() }; + + current = node_ref.next; + if node_ref.key.pgno <= len { + continue; + } + + self.map.borrow_mut().remove(&node_ref.key); + turso_assert!(!node_ref.page.is_dirty(), "page must be clean"); + turso_assert!(!node_ref.page.is_locked(), "page must be unlocked"); + turso_assert!(!node_ref.page.is_pinned(), "page must be unpinned"); + self.detach(node, true)?; + + unsafe { + let _ = Box::from_raw(node.as_ptr()); + } + } + Ok(()) + } + pub fn print(&self) { tracing::debug!("page_cache_len={}", self.map.borrow().len()); let head_ptr = *self.head.borrow(); let mut current = head_ptr; while let Some(node) = current { unsafe { - tracing::debug!("page={:?}", node.as_ref().key); + tracing::debug!( + "page={:?}, flags={}, pin_count={}", + node.as_ref().key, + node.as_ref().page.get().flags.load(Ordering::SeqCst), + node.as_ref().page.get().pin_count.load(Ordering::SeqCst), + ); let node_ref = node.as_ref(); current = node_ref.next; } @@ -1231,6 +1264,38 @@ mod tests { assert!(cache.insert(create_key(4), page_with_content(4)).is_ok()); } + #[test] + fn test_truncate_page_cache() { + let mut cache = DumbLruPageCache::new(10); + let _ = insert_page(&mut cache, 1); + let _ = insert_page(&mut cache, 4); + let _ = insert_page(&mut cache, 8); + let _ = insert_page(&mut cache, 10); + cache.truncate(4).unwrap(); + assert!(cache.contains_key(&PageCacheKey { pgno: 1 })); + assert!(cache.contains_key(&PageCacheKey { pgno: 4 })); + assert!(!cache.contains_key(&PageCacheKey { pgno: 8 })); + assert!(!cache.contains_key(&PageCacheKey { pgno: 10 })); + assert_eq!(cache.len(), 2); + assert_eq!(cache.capacity, 10); + cache.verify_list_integrity(); + assert!(cache.insert(create_key(8), page_with_content(8)).is_ok()); + } + + #[test] + fn test_truncate_page_cache_remove_all() { + let mut cache = DumbLruPageCache::new(10); + let _ = insert_page(&mut cache, 8); + let _ = insert_page(&mut cache, 10); + cache.truncate(4).unwrap(); + assert!(!cache.contains_key(&PageCacheKey { pgno: 8 })); + assert!(!cache.contains_key(&PageCacheKey { pgno: 10 })); + assert_eq!(cache.len(), 0); + assert_eq!(cache.capacity, 10); + cache.verify_list_integrity(); + assert!(cache.insert(create_key(8), page_with_content(8)).is_ok()); + } + #[test] #[ignore = "long running test, remove to verify"] fn test_clear_memory_stability() { diff --git a/core/storage/pager.rs b/core/storage/pager.rs index c1247449c..07602ef23 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -1413,7 +1413,16 @@ impl Pager { page.get().id == header.page_number as usize, "page has unexpected id" ); - self.add_dirty(&page); + } + if header.page_number == 1 { + let db_size = self + .io + .block(|| self.with_header(|header| header.database_size))?; + tracing::debug!("truncate page_cache as first page was written: {}", db_size); + let mut page_cache = self.page_cache.write(); + page_cache.truncate(db_size.get() as usize).map_err(|e| { + LimboError::InternalError(format!("Failed to truncate page cache: {e:?}")) + })?; } if header.is_commit_frame() { for page_id in self.dirty_pages.borrow().iter() { diff --git a/tests/integration/functions/test_wal_api.rs b/tests/integration/functions/test_wal_api.rs index a93ac05c7..cf1548494 100644 --- a/tests/integration/functions/test_wal_api.rs +++ b/tests/integration/functions/test_wal_api.rs @@ -627,3 +627,50 @@ fn test_wal_checkpoint_no_work() { ); reader.execute("SELECT * FROM test").unwrap(); } + +#[test] +fn test_wal_revert_change_db_size() { + let db = TempDatabase::new_empty(false); + let writer = db.connect_limbo(); + + writer.execute("create table t(x, y)").unwrap(); + let watermark = writer.wal_state().unwrap().max_frame; + writer + .execute("insert into t values (1, randomblob(10 * 4096))") + .unwrap(); + writer + .execute("insert into t values (2, randomblob(20 * 4096))") + .unwrap(); + let mut changed = writer.wal_changed_pages_after(watermark).unwrap(); + changed.sort(); + + let mut frame = [0u8; 4096 + 24]; + + writer.wal_insert_begin().unwrap(); + let mut frames_count = writer.wal_state().unwrap().max_frame; + for page_no in changed { + let page = &mut frame[24..]; + if !writer + .try_wal_watermark_read_page(page_no, page, Some(watermark)) + .unwrap() + { + continue; + } + let info = WalFrameInfo { + page_no, + db_size: if page_no == 2 { 2 } else { 0 }, + }; + info.put_to_frame_header(&mut frame); + frames_count += 1; + writer.wal_insert_frame(frames_count, &frame).unwrap(); + } + writer.wal_insert_end().unwrap(); + + writer + .execute("insert into t values (3, randomblob(30 * 4096))") + .unwrap(); + assert_eq!( + limbo_exec_rows(&db, &writer, "SELECT x, length(y) FROM t"), + vec![vec![Value::Integer(3), Value::Integer(30 * 4096)]] + ); +}