From 2c0842ff52eca61d002d736eff78e0b3872264f4 Mon Sep 17 00:00:00 2001 From: Avinash Sajjanshetty Date: Wed, 27 Aug 2025 21:44:25 +0530 Subject: [PATCH] Set and propagate `IOContext` as required --- core/storage/btree.rs | 11 ++++++---- core/storage/database.rs | 14 +++++++++++++ core/storage/pager.rs | 37 +++++++++++++++------------------- core/storage/sqlite3_ondisk.rs | 16 ++++++++------- core/storage/wal.rs | 31 ++++++++++++++-------------- 5 files changed, 61 insertions(+), 48 deletions(-) diff --git a/core/storage/btree.rs b/core/storage/btree.rs index 49f4c06ef..ae922f757 100644 --- a/core/storage/btree.rs +++ b/core/storage/btree.rs @@ -7404,7 +7404,7 @@ mod tests { }, types::Text, vdbe::Register, - BufferPool, Completion, Connection, StepResult, WalFile, WalFileShared, + BufferPool, Completion, Connection, IOContext, StepResult, WalFile, WalFileShared, }; use std::{ cell::RefCell, @@ -8713,9 +8713,12 @@ mod tests { let c = Completion::new_write(move |_| { let _ = _buf.clone(); }); - let _c = pager - .db_file - .write_page(current_page as usize, buf.clone(), None, c)?; + let _c = pager.db_file.write_page( + current_page as usize, + buf.clone(), + &IOContext::default(), + c, + )?; pager.io.run_once()?; let (page, _c) = cursor.read_page(current_page as usize)?; diff --git a/core/storage/database.rs b/core/storage/database.rs index f6bbd8515..1cd1a2d42 100644 --- a/core/storage/database.rs +++ b/core/storage/database.rs @@ -4,12 +4,14 @@ use crate::{io::Completion, Buffer, CompletionError, Result}; use std::sync::Arc; use tracing::{instrument, Level}; +#[derive(Clone)] pub enum EncryptionOrChecksum { Encryption(EncryptionContext), Checksum, None, } +#[derive(Clone)] pub struct IOContext { encryption_or_checksum: EncryptionOrChecksum, } @@ -21,6 +23,18 @@ impl IOContext { _ => None, } } + + pub fn set_encryption(&mut self, encryption_ctx: EncryptionContext) { + self.encryption_or_checksum = EncryptionOrChecksum::Encryption(encryption_ctx); + } +} + +impl Default for IOContext { + fn default() -> Self { + Self { + encryption_or_checksum: EncryptionOrChecksum::None, + } + } } /// DatabaseStorage is an interface a database file that consists of pages. diff --git a/core/storage/pager.rs b/core/storage/pager.rs index 1e6341fd5..003c13b61 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -11,7 +11,7 @@ use crate::storage::{ }; use crate::types::{IOCompletions, WalState}; use crate::util::IOExt as _; -use crate::{io_yield_many, io_yield_one}; +use crate::{io_yield_many, io_yield_one, IOContext}; use crate::{ return_if_io, turso_assert, types::WalFrameInfo, Completion, Connection, IOResult, LimboError, Result, TransactionState, @@ -503,7 +503,7 @@ pub struct Pager { header_ref_state: RefCell, #[cfg(not(feature = "omit_autovacuum"))] btree_create_vacuum_full_state: Cell, - pub(crate) encryption_ctx: RefCell>, + pub(crate) io_ctx: RefCell, } #[derive(Debug, Clone)] @@ -607,7 +607,7 @@ impl Pager { header_ref_state: RefCell::new(HeaderRefState::Start), #[cfg(not(feature = "omit_autovacuum"))] btree_create_vacuum_full_state: Cell::new(BtreeCreateVacuumFullState::Start), - encryption_ctx: RefCell::new(None), + io_ctx: RefCell::new(IOContext::default()), }) } @@ -1094,7 +1094,7 @@ impl Pager { ) -> Result<(PageRef, Completion)> { tracing::trace!("read_page_no_cache(page_idx = {})", page_idx); let page = Arc::new(Page::new(page_idx)); - + let io_ctx = &self.io_ctx.borrow(); let Some(wal) = self.wal.as_ref() else { turso_assert!( matches!(frame_watermark, Some(0) | None), @@ -1102,12 +1102,7 @@ impl Pager { ); page.set_locked(); - let c = self.begin_read_disk_page( - page_idx, - page.clone(), - allow_empty_read, - self.encryption_ctx.borrow().as_ref(), - )?; + let c = self.begin_read_disk_page(page_idx, page.clone(), allow_empty_read, io_ctx)?; return Ok((page, c)); }; @@ -1120,12 +1115,7 @@ impl Pager { return Ok((page, c)); } - let c = self.begin_read_disk_page( - page_idx, - page.clone(), - allow_empty_read, - self.encryption_ctx.borrow().as_ref(), - )?; + let c = self.begin_read_disk_page(page_idx, page.clone(), allow_empty_read, io_ctx)?; Ok((page, c)) } @@ -1149,7 +1139,7 @@ impl Pager { page_idx: usize, page: PageRef, allow_empty_read: bool, - encryption_key: Option<&EncryptionContext>, + io_ctx: &IOContext, ) -> Result { sqlite3_ondisk::begin_read_page( self.db_file.clone(), @@ -1157,7 +1147,7 @@ impl Pager { page, page_idx, allow_empty_read, - encryption_key, + io_ctx, ) } @@ -1802,7 +1792,8 @@ impl Pager { default_header.database_size = 1.into(); // if a key is set, then we will reserve space for encryption metadata - if let Some(ref ctx) = *self.encryption_ctx.borrow() { + let io_ctx = self.io_ctx.borrow(); + if let Some(ctx) = io_ctx.encryption_context() { default_header.reserved_space = ctx.required_reserved_bytes() } @@ -2190,9 +2181,13 @@ impl Pager { pub fn set_encryption_context(&self, cipher_mode: CipherMode, key: &EncryptionKey) { let encryption_ctx = EncryptionContext::new(cipher_mode, key).unwrap(); - self.encryption_ctx.replace(Some(encryption_ctx.clone())); + { + let mut io_ctx = self.io_ctx.borrow_mut(); + io_ctx.set_encryption(encryption_ctx); + } let Some(wal) = self.wal.as_ref() else { return }; - wal.borrow_mut().set_encryption_context(encryption_ctx) + wal.borrow_mut() + .set_io_context(self.io_ctx.borrow().clone()) } } diff --git a/core/storage/sqlite3_ondisk.rs b/core/storage/sqlite3_ondisk.rs index 53ebf04a4..8ca674cef 100644 --- a/core/storage/sqlite3_ondisk.rs +++ b/core/storage/sqlite3_ondisk.rs @@ -59,11 +59,12 @@ use crate::storage::btree::offset::{ use crate::storage::btree::{payload_overflow_threshold_max, payload_overflow_threshold_min}; use crate::storage::buffer_pool::BufferPool; use crate::storage::database::DatabaseStorage; -use crate::storage::encryption::EncryptionContext; use crate::storage::pager::Pager; use crate::storage::wal::READMARK_NOT_USED; use crate::types::{RawSlice, RefValue, SerialType, SerialTypeKind, TextRef, TextSubtype}; -use crate::{bail_corrupt_error, turso_assert, CompletionError, File, Result, WalFileShared}; +use crate::{ + bail_corrupt_error, turso_assert, CompletionError, File, IOContext, Result, WalFileShared, +}; use std::cell::{Cell, UnsafeCell}; use std::collections::{BTreeMap, HashMap}; use std::mem::MaybeUninit; @@ -905,7 +906,7 @@ pub fn begin_read_page( page: PageRef, page_idx: usize, allow_empty_read: bool, - encryption_key: Option<&EncryptionContext>, + io_ctx: &IOContext, ) -> Result { tracing::trace!("begin_read_btree_page(page_idx = {})", page_idx); let buf = buffer_pool.get_page(); @@ -928,7 +929,7 @@ pub fn begin_read_page( finish_read_page(page_idx, buf, page.clone()); }); let c = Completion::new_read(buf, complete); - db_file.read_page(page_idx, encryption_key, c) + db_file.read_page(page_idx, io_ctx, c) } #[instrument(skip_all, level = Level::INFO)] @@ -982,7 +983,8 @@ pub fn begin_write_btree_page(pager: &Pager, page: &PageRef) -> Result>, done_flag: Arc, - encryption_key: Option<&EncryptionContext>, ) -> Result> { if batch.is_empty() { done_flag.store(true, Ordering::Relaxed); @@ -1076,11 +1077,12 @@ pub fn write_pages_vectored( }); // Submit write operation for this run, decrementing the counter if we error + let io_ctx = &pager.io_ctx.borrow(); match pager.db_file.write_pages( start_id, page_sz, std::mem::replace(&mut run_bufs, Vec::with_capacity(EST_BUFF_CAPACITY)), - encryption_key, + io_ctx, c, ) { Ok(c) => { diff --git a/core/storage/wal.rs b/core/storage/wal.rs index 15fd04afe..3f1d09571 100644 --- a/core/storage/wal.rs +++ b/core/storage/wal.rs @@ -17,7 +17,6 @@ use super::sqlite3_ondisk::{self, checksum_wal, WalHeader, WAL_MAGIC_BE, WAL_MAG use crate::fast_lock::SpinLock; use crate::io::{clock, File, IO}; use crate::result::LimboResult; -use crate::storage::encryption::EncryptionContext; use crate::storage::sqlite3_ondisk::{ begin_read_wal_frame, begin_read_wal_frame_raw, finish_read_page, prepare_wal_frame, write_pages_vectored, PageSize, WAL_FRAME_HEADER_SIZE, WAL_HEADER_SIZE, @@ -25,7 +24,7 @@ use crate::storage::sqlite3_ondisk::{ use crate::types::{IOCompletions, IOResult}; use crate::{ bail_corrupt_error, io_yield_many, turso_assert, Buffer, Completion, CompletionError, - LimboError, Result, + IOContext, LimboError, Result, }; #[derive(Debug, Clone, Default)] @@ -304,7 +303,7 @@ pub trait Wal: Debug { /// Return unique set of pages changed **after** frame_watermark position and until current WAL session max_frame_no fn changed_pages_after(&self, frame_watermark: u64) -> Result>; - fn set_encryption_context(&mut self, ctx: EncryptionContext); + fn set_io_context(&mut self, ctx: IOContext); #[cfg(debug_assertions)] fn as_any(&self) -> &dyn std::any::Any; @@ -576,7 +575,7 @@ pub struct WalFile { /// Manages locks needed for checkpointing checkpoint_guard: Option, - encryption_ctx: RefCell>, + io_ctx: RefCell, } impl fmt::Debug for WalFile { @@ -1065,7 +1064,10 @@ impl Wal for WalFile { page.set_locked(); let frame = page.clone(); let page_idx = page.get().id; - let encryption_ctx = self.encryption_ctx.borrow().clone(); + let encryption_ctx = { + let io_ctx = self.io_ctx.borrow(); + io_ctx.encryption_context().cloned() + }; let seq = self.header.checkpoint_seq; let complete = Box::new(move |res: Result<(Arc, i32), CompletionError>| { let Ok((buf, bytes_read)) = res else { @@ -1244,7 +1246,8 @@ impl Wal for WalFile { let page_content = page.get_contents(); let page_buf = page_content.as_ptr(); - let encryption_ctx = self.encryption_ctx.borrow(); + let io_ctx = self.io_ctx.borrow(); + let encryption_ctx = io_ctx.encryption_context(); let encrypted_data = { if let Some(key) = encryption_ctx.as_ref() { Some(key.encrypt_page(page_buf, page_id)?) @@ -1442,7 +1445,8 @@ impl Wal for WalFile { let plain = page.get_contents().as_ptr(); let data_to_write: std::borrow::Cow<[u8]> = { - let ectx = self.encryption_ctx.borrow(); + let io_ctx = self.io_ctx.borrow(); + let ectx = io_ctx.encryption_context(); if let Some(ctx) = ectx.as_ref() { Cow::Owned(ctx.encrypt_page(plain, page_id as usize)?) } else { @@ -1511,8 +1515,8 @@ impl Wal for WalFile { self } - fn set_encryption_context(&mut self, ctx: EncryptionContext) { - self.encryption_ctx.replace(Some(ctx)); + fn set_io_context(&mut self, ctx: IOContext) { + self.io_ctx.replace(ctx); } } @@ -1550,7 +1554,7 @@ impl WalFile { prev_checkpoint: CheckpointResult::default(), checkpoint_guard: None, header: *header, - encryption_ctx: RefCell::new(None), + io_ctx: RefCell::new(IOContext::default()), } } @@ -1798,12 +1802,7 @@ impl WalFile { let batch_map = self.ongoing_checkpoint.pending_writes.take(); if !batch_map.is_empty() { let done_flag = self.ongoing_checkpoint.add_write(); - completions.extend(write_pages_vectored( - pager, - batch_map, - done_flag, - self.encryption_ctx.borrow().as_ref(), - )?); + completions.extend(write_pages_vectored(pager, batch_map, done_flag)?); } }