From 3090545167ee2ed60cdc97a764e629c0ebb541c5 Mon Sep 17 00:00:00 2001 From: Avinash Sajjanshetty Date: Thu, 21 Aug 2025 22:23:08 +0530 Subject: [PATCH] use encryption ctx instead of encryption key --- bindings/javascript/src/lib.rs | 6 +++--- core/lib.rs | 8 ++++---- core/storage/database.rs | 36 +++++++++++++++++++--------------- core/storage/pager.rs | 23 ++++++++++++---------- core/storage/sqlite3_ondisk.rs | 6 +++--- core/storage/wal.rs | 28 +++++++++++++------------- core/translate/pragma.rs | 2 +- 7 files changed, 58 insertions(+), 51 deletions(-) diff --git a/bindings/javascript/src/lib.rs b/bindings/javascript/src/lib.rs index 1c0218eeb..7cd1a6965 100644 --- a/bindings/javascript/src/lib.rs +++ b/bindings/javascript/src/lib.rs @@ -561,7 +561,7 @@ impl turso_core::DatabaseStorage for DatabaseFile { fn read_page( &self, page_idx: usize, - _key: Option<&turso_core::EncryptionKey>, + _encryption_ctx: Option<&turso_core::PerConnEncryptionContext>, c: turso_core::Completion, ) -> turso_core::Result { let r = c.as_read(); @@ -578,7 +578,7 @@ impl turso_core::DatabaseStorage for DatabaseFile { &self, page_idx: usize, buffer: Arc, - _key: Option<&turso_core::EncryptionKey>, + _encryption_ctx: Option<&turso_core::PerConnEncryptionContext>, c: turso_core::Completion, ) -> turso_core::Result { let size = buffer.len(); @@ -591,7 +591,7 @@ impl turso_core::DatabaseStorage for DatabaseFile { first_page_idx: usize, page_size: usize, buffers: Vec>, - _key: Option<&turso_core::EncryptionKey>, + _encryption_ctx: Option<&turso_core::PerConnEncryptionContext>, c: turso_core::Completion, ) -> turso_core::Result { let pos = first_page_idx.saturating_sub(1) * page_size; diff --git a/core/lib.rs b/core/lib.rs index 7162152ed..89ef7f4f2 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -76,7 +76,7 @@ use std::{ }; #[cfg(feature = "fs")] use storage::database::DatabaseFile; -pub use storage::encryption::EncryptionKey; +pub use storage::encryption::{EncryptionKey, PerConnEncryptionContext}; use storage::page_cache::DumbLruPageCache; use storage::pager::{AtomicDbState, DbState}; use storage::sqlite3_ondisk::PageSize; @@ -1904,11 +1904,11 @@ impl Connection { self.syms.borrow().vtab_modules.keys().cloned().collect() } - pub fn set_encryption_key(&self, key: Option) { + pub fn set_encryption_key(&self, key: EncryptionKey) { tracing::trace!("setting encryption key for connection"); - *self.encryption_key.borrow_mut() = key.clone(); + *self.encryption_key.borrow_mut() = Some(key.clone()); let pager = self.pager.borrow(); - pager.set_encryption_key(key); + pager.set_encryption_context(&key); } } diff --git a/core/storage/database.rs b/core/storage/database.rs index 980dce8e4..e2b27b302 100644 --- a/core/storage/database.rs +++ b/core/storage/database.rs @@ -1,5 +1,5 @@ use crate::error::LimboError; -use crate::storage::encryption::{decrypt_page, encrypt_page, EncryptionKey}; +use crate::storage::encryption::PerConnEncryptionContext; use crate::{io::Completion, Buffer, CompletionError, Result}; use std::sync::Arc; use tracing::{instrument, Level}; @@ -15,14 +15,14 @@ pub trait DatabaseStorage: Send + Sync { fn read_page( &self, page_idx: usize, - encryption_key: Option<&EncryptionKey>, + encryption_ctx: Option<&PerConnEncryptionContext>, c: Completion, ) -> Result; fn write_page( &self, page_idx: usize, buffer: Arc, - encryption_key: Option<&EncryptionKey>, + encryption_ctx: Option<&PerConnEncryptionContext>, c: Completion, ) -> Result; fn write_pages( @@ -30,7 +30,7 @@ pub trait DatabaseStorage: Send + Sync { first_page_idx: usize, page_size: usize, buffers: Vec>, - encryption_key: Option<&EncryptionKey>, + encryption_ctx: Option<&PerConnEncryptionContext>, c: Completion, ) -> Result; fn sync(&self, c: Completion) -> Result; @@ -59,7 +59,7 @@ impl DatabaseStorage for DatabaseFile { fn read_page( &self, page_idx: usize, - encryption_key: Option<&EncryptionKey>, + encryption_ctx: Option<&PerConnEncryptionContext>, c: Completion, ) -> Result { let r = c.as_read(); @@ -70,8 +70,8 @@ impl DatabaseStorage for DatabaseFile { } let pos = (page_idx - 1) * size; - if let Some(key) = encryption_key { - let key_clone = key.clone(); + if let Some(ctx) = encryption_ctx { + let encryption_ctx = ctx.clone(); let read_buffer = r.buf_arc(); let original_c = c.clone(); @@ -81,7 +81,7 @@ impl DatabaseStorage for DatabaseFile { return; }; if bytes_read > 0 { - match decrypt_page(buf.as_slice(), page_idx, &key_clone) { + match encryption_ctx.decrypt_page(buf.as_slice(), page_idx) { Ok(decrypted_data) => { let original_buf = original_c.as_read().buf(); original_buf.as_mut_slice().copy_from_slice(&decrypted_data); @@ -111,7 +111,7 @@ impl DatabaseStorage for DatabaseFile { &self, page_idx: usize, buffer: Arc, - encryption_key: Option<&EncryptionKey>, + encryption_ctx: Option<&PerConnEncryptionContext>, c: Completion, ) -> Result { let buffer_size = buffer.len(); @@ -121,8 +121,8 @@ impl DatabaseStorage for DatabaseFile { assert_eq!(buffer_size & (buffer_size - 1), 0); let pos = (page_idx - 1) * buffer_size; let buffer = { - if let Some(key) = encryption_key { - encrypt_buffer(page_idx, buffer, key) + if let Some(ctx) = encryption_ctx { + encrypt_buffer(page_idx, buffer, ctx) } else { buffer } @@ -135,7 +135,7 @@ impl DatabaseStorage for DatabaseFile { first_page_idx: usize, page_size: usize, buffers: Vec>, - encryption_key: Option<&EncryptionKey>, + encryption_key: Option<&PerConnEncryptionContext>, c: Completion, ) -> Result { assert!(first_page_idx > 0); @@ -145,11 +145,11 @@ impl DatabaseStorage for DatabaseFile { let pos = (first_page_idx - 1) * page_size; let buffers = { - if let Some(key) = encryption_key { + if let Some(ctx) = encryption_key { buffers .into_iter() .enumerate() - .map(|(i, buffer)| encrypt_buffer(first_page_idx + i, buffer, key)) + .map(|(i, buffer)| encrypt_buffer(first_page_idx + i, buffer, ctx)) .collect::>() } else { buffers @@ -184,7 +184,11 @@ impl DatabaseFile { } } -fn encrypt_buffer(page_idx: usize, buffer: Arc, key: &EncryptionKey) -> Arc { - let encrypted_data = encrypt_page(buffer.as_slice(), page_idx, key).unwrap(); +fn encrypt_buffer( + page_idx: usize, + buffer: Arc, + ctx: &PerConnEncryptionContext, +) -> Arc { + let encrypted_data = ctx.encrypt_page(buffer.as_slice(), page_idx).unwrap(); Arc::new(Buffer::new(encrypted_data.to_vec())) } diff --git a/core/storage/pager.rs b/core/storage/pager.rs index c1247449c..e18f95a94 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -28,7 +28,9 @@ use super::btree::{btree_init_page, BTreePage}; use super::page_cache::{CacheError, CacheResizeResult, DumbLruPageCache, PageCacheKey}; use super::sqlite3_ondisk::begin_write_btree_page; use super::wal::CheckpointMode; -use crate::storage::encryption::{EncryptionKey, ENCRYPTION_METADATA_SIZE}; +use crate::storage::encryption::{ + EncryptionKey, PerConnEncryptionContext, ENCRYPTION_METADATA_SIZE, +}; /// SQLite's default maximum page count const DEFAULT_MAX_PAGE_COUNT: u32 = 0xfffffffe; @@ -483,7 +485,7 @@ pub struct Pager { header_ref_state: RefCell, #[cfg(not(feature = "omit_autovacuum"))] btree_create_vacuum_full_state: Cell, - pub(crate) encryption_key: RefCell>, + pub(crate) encryption_ctx: RefCell>, } #[derive(Debug, Clone)] @@ -585,7 +587,7 @@ impl Pager { header_ref_state: RefCell::new(HeaderRefState::Start), #[cfg(not(feature = "omit_autovacuum"))] btree_create_vacuum_full_state: Cell::new(BtreeCreateVacuumFullState::Start), - encryption_key: RefCell::new(None), + encryption_ctx: RefCell::new(None), }) } @@ -1072,7 +1074,7 @@ impl Pager { page_idx, page.clone(), allow_empty_read, - self.encryption_key.borrow().as_ref(), + self.encryption_ctx.borrow().as_ref(), )?; return Ok((page, c)); }; @@ -1090,7 +1092,7 @@ impl Pager { page_idx, page.clone(), allow_empty_read, - self.encryption_key.borrow().as_ref(), + self.encryption_ctx.borrow().as_ref(), )?; Ok((page, c)) } @@ -1116,7 +1118,7 @@ impl Pager { page_idx: usize, page: PageRef, allow_empty_read: bool, - encryption_key: Option<&EncryptionKey>, + encryption_key: Option<&PerConnEncryptionContext>, ) -> Result { sqlite3_ondisk::begin_read_page( self.db_file.clone(), @@ -1694,7 +1696,7 @@ impl Pager { default_header.database_size = 1.into(); // if a key is set, then we will reserve space for encryption metadata - if self.encryption_key.borrow().is_some() { + if self.encryption_ctx.borrow().is_some() { default_header.reserved_space = ENCRYPTION_METADATA_SIZE as u8; } @@ -2076,10 +2078,11 @@ impl Pager { Ok(IOResult::Done(f(header))) } - pub fn set_encryption_key(&self, key: Option) { - self.encryption_key.replace(key.clone()); + pub fn set_encryption_context(&self, key: &EncryptionKey) { + let encryption_ctx = PerConnEncryptionContext::new(key).unwrap(); + self.encryption_ctx.replace(Some(encryption_ctx.clone())); let Some(wal) = self.wal.as_ref() else { return }; - wal.borrow_mut().set_encryption_key(key) + wal.borrow_mut().set_encryption_context(encryption_ctx) } } diff --git a/core/storage/sqlite3_ondisk.rs b/core/storage/sqlite3_ondisk.rs index 7bc1202c8..d1a37615d 100644 --- a/core/storage/sqlite3_ondisk.rs +++ b/core/storage/sqlite3_ondisk.rs @@ -59,7 +59,7 @@ use crate::storage::btree::offset::{ use crate::storage::btree::{payload_overflow_threshold_max, payload_overflow_threshold_min}; use crate::storage::buffer_pool::BufferPool; use crate::storage::database::DatabaseStorage; -use crate::storage::encryption::EncryptionKey; +use crate::storage::encryption::PerConnEncryptionContext; use crate::storage::pager::Pager; use crate::storage::wal::READMARK_NOT_USED; use crate::types::{RawSlice, RefValue, SerialType, SerialTypeKind, TextRef, TextSubtype}; @@ -870,7 +870,7 @@ pub fn begin_read_page( page: PageRef, page_idx: usize, allow_empty_read: bool, - encryption_key: Option<&EncryptionKey>, + encryption_key: Option<&PerConnEncryptionContext>, ) -> Result { tracing::trace!("begin_read_btree_page(page_idx = {})", page_idx); let buf = buffer_pool.get_page(); @@ -965,7 +965,7 @@ pub fn write_pages_vectored( pager: &Pager, batch: BTreeMap>, done_flag: Arc, - encryption_key: Option<&EncryptionKey>, + encryption_key: Option<&PerConnEncryptionContext>, ) -> Result> { if batch.is_empty() { done_flag.store(true, Ordering::Relaxed); diff --git a/core/storage/wal.rs b/core/storage/wal.rs index 35882bf7d..0a903f4b3 100644 --- a/core/storage/wal.rs +++ b/core/storage/wal.rs @@ -17,7 +17,7 @@ use super::sqlite3_ondisk::{self, checksum_wal, WalHeader, WAL_MAGIC_BE, WAL_MAG use crate::fast_lock::SpinLock; use crate::io::{clock, File, IO}; use crate::result::LimboResult; -use crate::storage::encryption::{decrypt_page, encrypt_page, EncryptionKey}; +use crate::storage::encryption::PerConnEncryptionContext; use crate::storage::sqlite3_ondisk::{ begin_read_wal_frame, begin_read_wal_frame_raw, finish_read_page, prepare_wal_frame, write_pages_vectored, PageSize, WAL_FRAME_HEADER_SIZE, WAL_HEADER_SIZE, @@ -297,7 +297,7 @@ pub trait Wal: Debug { /// Return unique set of pages changed **after** frame_watermark position and until current WAL session max_frame_no fn changed_pages_after(&self, frame_watermark: u64) -> Result>; - fn set_encryption_key(&mut self, key: Option); + fn set_encryption_context(&mut self, ctx: PerConnEncryptionContext); #[cfg(debug_assertions)] fn as_any(&self) -> &dyn std::any::Any; @@ -568,7 +568,7 @@ pub struct WalFile { /// Manages locks needed for checkpointing checkpoint_guard: Option, - encryption_key: RefCell>, + encryption_ctx: RefCell>, } impl fmt::Debug for WalFile { @@ -1034,7 +1034,7 @@ impl Wal for WalFile { page.set_locked(); let frame = page.clone(); let page_idx = page.get().id; - let key = self.encryption_key.borrow().clone(); + let encryption_ctx = self.encryption_ctx.borrow().clone(); let seq = self.header.checkpoint_seq; let complete = Box::new(move |res: Result<(Arc, i32), CompletionError>| { let Ok((buf, bytes_read)) = res else { @@ -1047,8 +1047,8 @@ impl Wal for WalFile { "read({bytes_read}) less than expected({buf_len}): frame_id={frame_id}" ); let cloned = frame.clone(); - if let Some(key) = key.clone() { - match decrypt_page(buf.as_slice(), page_idx, &key) { + if let Some(ctx) = encryption_ctx.clone() { + match ctx.decrypt_page(buf.as_slice(), page_idx) { Ok(decrypted_data) => { buf.as_mut_slice().copy_from_slice(&decrypted_data); } @@ -1213,15 +1213,15 @@ impl Wal for WalFile { let page_content = page.get_contents(); let page_buf = page_content.as_ptr(); - let key = self.encryption_key.borrow(); + let encryption_ctx = self.encryption_ctx.borrow(); let encrypted_data = { - if let Some(key) = key.as_ref() { - Some(encrypt_page(page_buf, page_id, key)?) + if let Some(key) = encryption_ctx.as_ref() { + Some(key.encrypt_page(page_buf, page_id)?) } else { None } }; - let data_to_write = if key.as_ref().is_some() { + let data_to_write = if encryption_ctx.as_ref().is_some() { encrypted_data.as_ref().unwrap().as_slice() } else { page_buf @@ -1374,8 +1374,8 @@ impl Wal for WalFile { self } - fn set_encryption_key(&mut self, key: Option) { - self.encryption_key.replace(key); + fn set_encryption_context(&mut self, ctx: PerConnEncryptionContext) { + self.encryption_ctx.replace(Some(ctx)); } } @@ -1413,7 +1413,7 @@ impl WalFile { prev_checkpoint: CheckpointResult::default(), checkpoint_guard: None, header: *header, - encryption_key: RefCell::new(None), + encryption_ctx: RefCell::new(None), } } @@ -1665,7 +1665,7 @@ impl WalFile { pager, batch_map, done_flag, - self.encryption_key.borrow().as_ref(), + self.encryption_ctx.borrow().as_ref(), )?); } } diff --git a/core/translate/pragma.rs b/core/translate/pragma.rs index 1cb9a9d04..c7528027a 100644 --- a/core/translate/pragma.rs +++ b/core/translate/pragma.rs @@ -312,7 +312,7 @@ fn update_pragma( PragmaName::EncryptionKey => { let value = parse_string(&value)?; let key = EncryptionKey::from_string(&value); - connection.set_encryption_key(Some(key)); + connection.set_encryption_key(key); Ok((program, TransactionMode::None)) } }