From cc8c7639424de7732515aed91a3f5fc4214226ca Mon Sep 17 00:00:00 2001 From: Avinash Sajjanshetty Date: Thu, 21 Aug 2025 22:22:43 +0530 Subject: [PATCH 1/2] refactor encryption module and make it configurable --- core/storage/encryption.rs | 279 ++++++++++++++++++++++++------------- 1 file changed, 184 insertions(+), 95 deletions(-) diff --git a/core/storage/encryption.rs b/core/storage/encryption.rs index 97c2b3574..8c8f7bd44 100644 --- a/core/storage/encryption.rs +++ b/core/storage/encryption.rs @@ -1,7 +1,5 @@ #![allow(unused_variables, dead_code)] -#[cfg(not(feature = "encryption"))] -use crate::LimboError; -use crate::Result; +use crate::{LimboError, Result}; use aes_gcm::{ aead::{Aead, AeadCore, KeyInit, OsRng}, Aes256Gcm, Key, Nonce, @@ -11,6 +9,7 @@ use std::ops::Deref; pub const ENCRYPTION_METADATA_SIZE: usize = 28; pub const ENCRYPTED_PAGE_SIZE: usize = 4096; pub const ENCRYPTION_NONCE_SIZE: usize = 12; +pub const ENCRYPTION_TAG_SIZE: usize = 16; #[repr(transparent)] #[derive(Clone)] @@ -71,106 +70,195 @@ impl Drop for EncryptionKey { } } -#[cfg(not(feature = "encryption"))] -pub fn encrypt_page(page: &[u8], page_id: usize, key: &EncryptionKey) -> Result> { - Err(LimboError::InvalidArgument( - "encryption is not enabled, cannot encrypt page. enable via passing `--features encryption`".into(), - )) +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum CipherMode { + Aes256Gcm, } -#[cfg(feature = "encryption")] -pub fn encrypt_page(page: &[u8], page_id: usize, key: &EncryptionKey) -> Result> { - if page_id == 1 { - tracing::debug!("skipping encryption for page 1 (database header)"); - return Ok(page.to_vec()); +impl CipherMode { + /// Every cipher requires a specific key size. For 256-bit algorithms, this is 32 bytes. + /// For 128-bit algorithms, it would be 16 bytes, etc. + pub fn required_key_size(&self) -> usize { + match self { + CipherMode::Aes256Gcm => 32, + } } - tracing::debug!("encrypting page {}", page_id); - assert_eq!( - page.len(), - ENCRYPTED_PAGE_SIZE, - "Page data must be exactly {ENCRYPTED_PAGE_SIZE} bytes" - ); - let reserved_bytes = &page[ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE..]; - let reserved_bytes_zeroed = reserved_bytes.iter().all(|&b| b == 0); - assert!( - reserved_bytes_zeroed, - "last reserved bytes must be empty/zero, but found non-zero bytes" - ); - let payload = &page[..ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE]; - let (encrypted, nonce) = encrypt(payload, key)?; - assert_eq!( - encrypted.len(), - ENCRYPTED_PAGE_SIZE - nonce.len(), - "Encrypted page must be exactly {} bytes", - ENCRYPTED_PAGE_SIZE - nonce.len() - ); - let mut result = Vec::with_capacity(ENCRYPTED_PAGE_SIZE); - result.extend_from_slice(&encrypted); - result.extend_from_slice(&nonce); - assert_eq!( - result.len(), - ENCRYPTED_PAGE_SIZE, - "Encrypted page must be exactly {ENCRYPTED_PAGE_SIZE} bytes" - ); - Ok(result) -} -#[cfg(not(feature = "encryption"))] -pub fn decrypt_page(encrypted_page: &[u8], page_id: usize, key: &EncryptionKey) -> Result> { - Err(LimboError::InvalidArgument( - "encryption is not enabled, cannot decrypt page. enable via passing `--features encryption`".into(), - )) -} - -#[cfg(feature = "encryption")] -pub fn decrypt_page(encrypted_page: &[u8], page_id: usize, key: &EncryptionKey) -> Result> { - if page_id == 1 { - tracing::debug!("skipping decryption for page 1 (database header)"); - return Ok(encrypted_page.to_vec()); + /// Returns the nonce size for this cipher mode. Though most AEAD ciphers use 12-byte nonces. + pub fn nonce_size(&self) -> usize { + match self { + CipherMode::Aes256Gcm => ENCRYPTION_NONCE_SIZE, + } } - tracing::debug!("decrypting page {}", page_id); - assert_eq!( - encrypted_page.len(), - ENCRYPTED_PAGE_SIZE, - "Encrypted page data must be exactly {ENCRYPTED_PAGE_SIZE} bytes" - ); - let nonce_start = encrypted_page.len() - ENCRYPTION_NONCE_SIZE; - let payload = &encrypted_page[..nonce_start]; - let nonce = &encrypted_page[nonce_start..]; - - let decrypted_data = decrypt(payload, nonce, key)?; - assert_eq!( - decrypted_data.len(), - ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE, - "Decrypted page data must be exactly {} bytes", - ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE - ); - let mut result = Vec::with_capacity(ENCRYPTED_PAGE_SIZE); - result.extend_from_slice(&decrypted_data); - result.resize(ENCRYPTED_PAGE_SIZE, 0); - assert_eq!( - result.len(), - ENCRYPTED_PAGE_SIZE, - "Decrypted page data must be exactly {ENCRYPTED_PAGE_SIZE} bytes" - ); - Ok(result) + /// Returns the authentication tag size for this cipher mode. All common AEAD ciphers use 16-byte tags. + pub fn tag_size(&self) -> usize { + match self { + CipherMode::Aes256Gcm => ENCRYPTION_TAG_SIZE, + } + } } -fn encrypt(plaintext: &[u8], key: &EncryptionKey) -> Result<(Vec, Vec)> { - let key: &Key = key.as_ref().into(); - let cipher = Aes256Gcm::new(key); - let nonce = Aes256Gcm::generate_nonce(&mut OsRng); - let ciphertext = cipher.encrypt(&nonce, plaintext).unwrap(); - Ok((ciphertext, nonce.to_vec())) +#[derive(Clone)] +pub enum Cipher { + Aes256Gcm(Box), } -fn decrypt(ciphertext: &[u8], nonce: &[u8], key: &EncryptionKey) -> Result> { - let key: &Key = key.as_ref().into(); - let cipher = Aes256Gcm::new(key); - let nonce = Nonce::from_slice(nonce); - let plaintext = cipher.decrypt(nonce, ciphertext).unwrap(); - Ok(plaintext) +impl std::fmt::Debug for Cipher { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Cipher::Aes256Gcm(_) => write!(f, "Cipher::Aes256Gcm"), + } + } +} + +#[derive(Clone)] +pub struct PerConnEncryptionContext { + cipher_mode: CipherMode, + cipher: Cipher, +} + +impl PerConnEncryptionContext { + pub fn new(key: &EncryptionKey) -> Result { + let cipher_mode = CipherMode::Aes256Gcm; + let required_size = cipher_mode.required_key_size(); + if key.as_slice().len() != required_size { + return Err(crate::LimboError::InvalidArgument(format!( + "Invalid key size for {:?}: expected {} bytes, got {}", + cipher_mode, + required_size, + key.as_slice().len() + ))); + } + + let cipher = match cipher_mode { + CipherMode::Aes256Gcm => { + let cipher_key: &Key = key.as_ref().into(); + Cipher::Aes256Gcm(Box::new(Aes256Gcm::new(cipher_key))) + } + }; + Ok(Self { + cipher_mode, + cipher, + }) + } + + pub fn cipher_mode(&self) -> CipherMode { + self.cipher_mode + } + + #[cfg(feature = "encryption")] + pub fn encrypt_page(&self, page: &[u8], page_id: usize) -> Result> { + if page_id == 1 { + tracing::debug!("skipping encryption for page 1 (database header)"); + return Ok(page.to_vec()); + } + tracing::debug!("encrypting page {}", page_id); + assert_eq!( + page.len(), + ENCRYPTED_PAGE_SIZE, + "Page data must be exactly {ENCRYPTED_PAGE_SIZE} bytes" + ); + let reserved_bytes = &page[ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE..]; + let reserved_bytes_zeroed = reserved_bytes.iter().all(|&b| b == 0); + assert!( + reserved_bytes_zeroed, + "last reserved bytes must be empty/zero, but found non-zero bytes" + ); + let payload = &page[..ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE]; + let (encrypted, nonce) = self.encrypt_raw(payload)?; + + assert_eq!( + encrypted.len(), + ENCRYPTED_PAGE_SIZE - nonce.len(), + "Encrypted page must be exactly {} bytes", + ENCRYPTED_PAGE_SIZE - nonce.len() + ); + let mut result = Vec::with_capacity(ENCRYPTED_PAGE_SIZE); + result.extend_from_slice(&encrypted); + result.extend_from_slice(&nonce); + assert_eq!( + result.len(), + ENCRYPTED_PAGE_SIZE, + "Encrypted page must be exactly {ENCRYPTED_PAGE_SIZE} bytes" + ); + Ok(result) + } + + #[cfg(feature = "encryption")] + pub fn decrypt_page(&self, encrypted_page: &[u8], page_id: usize) -> Result> { + if page_id == 1 { + tracing::debug!("skipping decryption for page 1 (database header)"); + return Ok(encrypted_page.to_vec()); + } + tracing::debug!("decrypting page {}", page_id); + assert_eq!( + encrypted_page.len(), + ENCRYPTED_PAGE_SIZE, + "Encrypted page data must be exactly {ENCRYPTED_PAGE_SIZE} bytes" + ); + + let nonce_start = encrypted_page.len() - ENCRYPTION_NONCE_SIZE; + let payload = &encrypted_page[..nonce_start]; + let nonce = &encrypted_page[nonce_start..]; + + let decrypted_data = self.decrypt_raw(payload, nonce)?; + + assert_eq!( + decrypted_data.len(), + ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE, + "Decrypted page data must be exactly {} bytes", + ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE + ); + let mut result = Vec::with_capacity(ENCRYPTED_PAGE_SIZE); + result.extend_from_slice(&decrypted_data); + result.resize(ENCRYPTED_PAGE_SIZE, 0); + assert_eq!( + result.len(), + ENCRYPTED_PAGE_SIZE, + "Decrypted page data must be exactly {ENCRYPTED_PAGE_SIZE} bytes" + ); + Ok(result) + } + + /// encrypts raw data using the configured cipher, returns ciphertext and nonce + fn encrypt_raw(&self, plaintext: &[u8]) -> Result<(Vec, Vec)> { + match &self.cipher { + Cipher::Aes256Gcm(cipher) => { + let nonce = Aes256Gcm::generate_nonce(&mut OsRng); + let ciphertext = cipher + .encrypt(&nonce, plaintext) + .map_err(|e| LimboError::InternalError(format!("Encryption failed: {e:?}")))?; + Ok((ciphertext, nonce.to_vec())) + } + } + } + + fn decrypt_raw(&self, ciphertext: &[u8], nonce: &[u8]) -> Result> { + match &self.cipher { + Cipher::Aes256Gcm(cipher) => { + let nonce = Nonce::from_slice(nonce); + let plaintext = cipher.decrypt(nonce, ciphertext).map_err(|e| { + crate::LimboError::InternalError(format!("Decryption failed: {e:?}")) + })?; + Ok(plaintext) + } + } + } + + #[cfg(not(feature = "encryption"))] + pub fn encrypt_page(&self, _page: &[u8], _page_id: usize) -> Result> { + Err(LimboError::InvalidArgument( + "encryption is not enabled, cannot encrypt page. enable via passing `--features encryption`".into(), + )) + } + + #[cfg(not(feature = "encryption"))] + pub fn decrypt_page(&self, _encrypted_page: &[u8], _page_id: usize) -> Result> { + Err(LimboError::InvalidArgument( + "encryption is not enabled, cannot decrypt page. enable via passing `--features encryption`".into(), + )) + } } #[cfg(test)] @@ -193,14 +281,15 @@ mod tests { }; let key = EncryptionKey::from_string("alice and bob use encryption on database"); + let ctx = PerConnEncryptionContext::new(&key).unwrap(); let page_id = 42; - let encrypted = encrypt_page(&page_data, page_id, &key).unwrap(); + let encrypted = ctx.encrypt_page(&page_data, page_id).unwrap(); assert_eq!(encrypted.len(), ENCRYPTED_PAGE_SIZE); assert_ne!(&encrypted[..data_size], &page_data[..data_size]); assert_ne!(&encrypted[..], &page_data[..]); - let decrypted = decrypt_page(&encrypted, page_id, &key).unwrap(); + let decrypted = ctx.decrypt_page(&encrypted, page_id).unwrap(); assert_eq!(decrypted.len(), ENCRYPTED_PAGE_SIZE); assert_eq!(decrypted, page_data); } From 3090545167ee2ed60cdc97a764e629c0ebb541c5 Mon Sep 17 00:00:00 2001 From: Avinash Sajjanshetty Date: Thu, 21 Aug 2025 22:23:08 +0530 Subject: [PATCH 2/2] use encryption ctx instead of encryption key --- bindings/javascript/src/lib.rs | 6 +++--- core/lib.rs | 8 ++++---- core/storage/database.rs | 36 +++++++++++++++++++--------------- core/storage/pager.rs | 23 ++++++++++++---------- core/storage/sqlite3_ondisk.rs | 6 +++--- core/storage/wal.rs | 28 +++++++++++++------------- core/translate/pragma.rs | 2 +- 7 files changed, 58 insertions(+), 51 deletions(-) diff --git a/bindings/javascript/src/lib.rs b/bindings/javascript/src/lib.rs index 1c0218eeb..7cd1a6965 100644 --- a/bindings/javascript/src/lib.rs +++ b/bindings/javascript/src/lib.rs @@ -561,7 +561,7 @@ impl turso_core::DatabaseStorage for DatabaseFile { fn read_page( &self, page_idx: usize, - _key: Option<&turso_core::EncryptionKey>, + _encryption_ctx: Option<&turso_core::PerConnEncryptionContext>, c: turso_core::Completion, ) -> turso_core::Result { let r = c.as_read(); @@ -578,7 +578,7 @@ impl turso_core::DatabaseStorage for DatabaseFile { &self, page_idx: usize, buffer: Arc, - _key: Option<&turso_core::EncryptionKey>, + _encryption_ctx: Option<&turso_core::PerConnEncryptionContext>, c: turso_core::Completion, ) -> turso_core::Result { let size = buffer.len(); @@ -591,7 +591,7 @@ impl turso_core::DatabaseStorage for DatabaseFile { first_page_idx: usize, page_size: usize, buffers: Vec>, - _key: Option<&turso_core::EncryptionKey>, + _encryption_ctx: Option<&turso_core::PerConnEncryptionContext>, c: turso_core::Completion, ) -> turso_core::Result { let pos = first_page_idx.saturating_sub(1) * page_size; diff --git a/core/lib.rs b/core/lib.rs index 7162152ed..89ef7f4f2 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -76,7 +76,7 @@ use std::{ }; #[cfg(feature = "fs")] use storage::database::DatabaseFile; -pub use storage::encryption::EncryptionKey; +pub use storage::encryption::{EncryptionKey, PerConnEncryptionContext}; use storage::page_cache::DumbLruPageCache; use storage::pager::{AtomicDbState, DbState}; use storage::sqlite3_ondisk::PageSize; @@ -1904,11 +1904,11 @@ impl Connection { self.syms.borrow().vtab_modules.keys().cloned().collect() } - pub fn set_encryption_key(&self, key: Option) { + pub fn set_encryption_key(&self, key: EncryptionKey) { tracing::trace!("setting encryption key for connection"); - *self.encryption_key.borrow_mut() = key.clone(); + *self.encryption_key.borrow_mut() = Some(key.clone()); let pager = self.pager.borrow(); - pager.set_encryption_key(key); + pager.set_encryption_context(&key); } } diff --git a/core/storage/database.rs b/core/storage/database.rs index 980dce8e4..e2b27b302 100644 --- a/core/storage/database.rs +++ b/core/storage/database.rs @@ -1,5 +1,5 @@ use crate::error::LimboError; -use crate::storage::encryption::{decrypt_page, encrypt_page, EncryptionKey}; +use crate::storage::encryption::PerConnEncryptionContext; use crate::{io::Completion, Buffer, CompletionError, Result}; use std::sync::Arc; use tracing::{instrument, Level}; @@ -15,14 +15,14 @@ pub trait DatabaseStorage: Send + Sync { fn read_page( &self, page_idx: usize, - encryption_key: Option<&EncryptionKey>, + encryption_ctx: Option<&PerConnEncryptionContext>, c: Completion, ) -> Result; fn write_page( &self, page_idx: usize, buffer: Arc, - encryption_key: Option<&EncryptionKey>, + encryption_ctx: Option<&PerConnEncryptionContext>, c: Completion, ) -> Result; fn write_pages( @@ -30,7 +30,7 @@ pub trait DatabaseStorage: Send + Sync { first_page_idx: usize, page_size: usize, buffers: Vec>, - encryption_key: Option<&EncryptionKey>, + encryption_ctx: Option<&PerConnEncryptionContext>, c: Completion, ) -> Result; fn sync(&self, c: Completion) -> Result; @@ -59,7 +59,7 @@ impl DatabaseStorage for DatabaseFile { fn read_page( &self, page_idx: usize, - encryption_key: Option<&EncryptionKey>, + encryption_ctx: Option<&PerConnEncryptionContext>, c: Completion, ) -> Result { let r = c.as_read(); @@ -70,8 +70,8 @@ impl DatabaseStorage for DatabaseFile { } let pos = (page_idx - 1) * size; - if let Some(key) = encryption_key { - let key_clone = key.clone(); + if let Some(ctx) = encryption_ctx { + let encryption_ctx = ctx.clone(); let read_buffer = r.buf_arc(); let original_c = c.clone(); @@ -81,7 +81,7 @@ impl DatabaseStorage for DatabaseFile { return; }; if bytes_read > 0 { - match decrypt_page(buf.as_slice(), page_idx, &key_clone) { + match encryption_ctx.decrypt_page(buf.as_slice(), page_idx) { Ok(decrypted_data) => { let original_buf = original_c.as_read().buf(); original_buf.as_mut_slice().copy_from_slice(&decrypted_data); @@ -111,7 +111,7 @@ impl DatabaseStorage for DatabaseFile { &self, page_idx: usize, buffer: Arc, - encryption_key: Option<&EncryptionKey>, + encryption_ctx: Option<&PerConnEncryptionContext>, c: Completion, ) -> Result { let buffer_size = buffer.len(); @@ -121,8 +121,8 @@ impl DatabaseStorage for DatabaseFile { assert_eq!(buffer_size & (buffer_size - 1), 0); let pos = (page_idx - 1) * buffer_size; let buffer = { - if let Some(key) = encryption_key { - encrypt_buffer(page_idx, buffer, key) + if let Some(ctx) = encryption_ctx { + encrypt_buffer(page_idx, buffer, ctx) } else { buffer } @@ -135,7 +135,7 @@ impl DatabaseStorage for DatabaseFile { first_page_idx: usize, page_size: usize, buffers: Vec>, - encryption_key: Option<&EncryptionKey>, + encryption_key: Option<&PerConnEncryptionContext>, c: Completion, ) -> Result { assert!(first_page_idx > 0); @@ -145,11 +145,11 @@ impl DatabaseStorage for DatabaseFile { let pos = (first_page_idx - 1) * page_size; let buffers = { - if let Some(key) = encryption_key { + if let Some(ctx) = encryption_key { buffers .into_iter() .enumerate() - .map(|(i, buffer)| encrypt_buffer(first_page_idx + i, buffer, key)) + .map(|(i, buffer)| encrypt_buffer(first_page_idx + i, buffer, ctx)) .collect::>() } else { buffers @@ -184,7 +184,11 @@ impl DatabaseFile { } } -fn encrypt_buffer(page_idx: usize, buffer: Arc, key: &EncryptionKey) -> Arc { - let encrypted_data = encrypt_page(buffer.as_slice(), page_idx, key).unwrap(); +fn encrypt_buffer( + page_idx: usize, + buffer: Arc, + ctx: &PerConnEncryptionContext, +) -> Arc { + let encrypted_data = ctx.encrypt_page(buffer.as_slice(), page_idx).unwrap(); Arc::new(Buffer::new(encrypted_data.to_vec())) } diff --git a/core/storage/pager.rs b/core/storage/pager.rs index c1247449c..e18f95a94 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -28,7 +28,9 @@ use super::btree::{btree_init_page, BTreePage}; use super::page_cache::{CacheError, CacheResizeResult, DumbLruPageCache, PageCacheKey}; use super::sqlite3_ondisk::begin_write_btree_page; use super::wal::CheckpointMode; -use crate::storage::encryption::{EncryptionKey, ENCRYPTION_METADATA_SIZE}; +use crate::storage::encryption::{ + EncryptionKey, PerConnEncryptionContext, ENCRYPTION_METADATA_SIZE, +}; /// SQLite's default maximum page count const DEFAULT_MAX_PAGE_COUNT: u32 = 0xfffffffe; @@ -483,7 +485,7 @@ pub struct Pager { header_ref_state: RefCell, #[cfg(not(feature = "omit_autovacuum"))] btree_create_vacuum_full_state: Cell, - pub(crate) encryption_key: RefCell>, + pub(crate) encryption_ctx: RefCell>, } #[derive(Debug, Clone)] @@ -585,7 +587,7 @@ impl Pager { header_ref_state: RefCell::new(HeaderRefState::Start), #[cfg(not(feature = "omit_autovacuum"))] btree_create_vacuum_full_state: Cell::new(BtreeCreateVacuumFullState::Start), - encryption_key: RefCell::new(None), + encryption_ctx: RefCell::new(None), }) } @@ -1072,7 +1074,7 @@ impl Pager { page_idx, page.clone(), allow_empty_read, - self.encryption_key.borrow().as_ref(), + self.encryption_ctx.borrow().as_ref(), )?; return Ok((page, c)); }; @@ -1090,7 +1092,7 @@ impl Pager { page_idx, page.clone(), allow_empty_read, - self.encryption_key.borrow().as_ref(), + self.encryption_ctx.borrow().as_ref(), )?; Ok((page, c)) } @@ -1116,7 +1118,7 @@ impl Pager { page_idx: usize, page: PageRef, allow_empty_read: bool, - encryption_key: Option<&EncryptionKey>, + encryption_key: Option<&PerConnEncryptionContext>, ) -> Result { sqlite3_ondisk::begin_read_page( self.db_file.clone(), @@ -1694,7 +1696,7 @@ impl Pager { default_header.database_size = 1.into(); // if a key is set, then we will reserve space for encryption metadata - if self.encryption_key.borrow().is_some() { + if self.encryption_ctx.borrow().is_some() { default_header.reserved_space = ENCRYPTION_METADATA_SIZE as u8; } @@ -2076,10 +2078,11 @@ impl Pager { Ok(IOResult::Done(f(header))) } - pub fn set_encryption_key(&self, key: Option) { - self.encryption_key.replace(key.clone()); + pub fn set_encryption_context(&self, key: &EncryptionKey) { + let encryption_ctx = PerConnEncryptionContext::new(key).unwrap(); + self.encryption_ctx.replace(Some(encryption_ctx.clone())); let Some(wal) = self.wal.as_ref() else { return }; - wal.borrow_mut().set_encryption_key(key) + wal.borrow_mut().set_encryption_context(encryption_ctx) } } diff --git a/core/storage/sqlite3_ondisk.rs b/core/storage/sqlite3_ondisk.rs index 7bc1202c8..d1a37615d 100644 --- a/core/storage/sqlite3_ondisk.rs +++ b/core/storage/sqlite3_ondisk.rs @@ -59,7 +59,7 @@ use crate::storage::btree::offset::{ use crate::storage::btree::{payload_overflow_threshold_max, payload_overflow_threshold_min}; use crate::storage::buffer_pool::BufferPool; use crate::storage::database::DatabaseStorage; -use crate::storage::encryption::EncryptionKey; +use crate::storage::encryption::PerConnEncryptionContext; use crate::storage::pager::Pager; use crate::storage::wal::READMARK_NOT_USED; use crate::types::{RawSlice, RefValue, SerialType, SerialTypeKind, TextRef, TextSubtype}; @@ -870,7 +870,7 @@ pub fn begin_read_page( page: PageRef, page_idx: usize, allow_empty_read: bool, - encryption_key: Option<&EncryptionKey>, + encryption_key: Option<&PerConnEncryptionContext>, ) -> Result { tracing::trace!("begin_read_btree_page(page_idx = {})", page_idx); let buf = buffer_pool.get_page(); @@ -965,7 +965,7 @@ pub fn write_pages_vectored( pager: &Pager, batch: BTreeMap>, done_flag: Arc, - encryption_key: Option<&EncryptionKey>, + encryption_key: Option<&PerConnEncryptionContext>, ) -> Result> { if batch.is_empty() { done_flag.store(true, Ordering::Relaxed); diff --git a/core/storage/wal.rs b/core/storage/wal.rs index 35882bf7d..0a903f4b3 100644 --- a/core/storage/wal.rs +++ b/core/storage/wal.rs @@ -17,7 +17,7 @@ use super::sqlite3_ondisk::{self, checksum_wal, WalHeader, WAL_MAGIC_BE, WAL_MAG use crate::fast_lock::SpinLock; use crate::io::{clock, File, IO}; use crate::result::LimboResult; -use crate::storage::encryption::{decrypt_page, encrypt_page, EncryptionKey}; +use crate::storage::encryption::PerConnEncryptionContext; use crate::storage::sqlite3_ondisk::{ begin_read_wal_frame, begin_read_wal_frame_raw, finish_read_page, prepare_wal_frame, write_pages_vectored, PageSize, WAL_FRAME_HEADER_SIZE, WAL_HEADER_SIZE, @@ -297,7 +297,7 @@ pub trait Wal: Debug { /// Return unique set of pages changed **after** frame_watermark position and until current WAL session max_frame_no fn changed_pages_after(&self, frame_watermark: u64) -> Result>; - fn set_encryption_key(&mut self, key: Option); + fn set_encryption_context(&mut self, ctx: PerConnEncryptionContext); #[cfg(debug_assertions)] fn as_any(&self) -> &dyn std::any::Any; @@ -568,7 +568,7 @@ pub struct WalFile { /// Manages locks needed for checkpointing checkpoint_guard: Option, - encryption_key: RefCell>, + encryption_ctx: RefCell>, } impl fmt::Debug for WalFile { @@ -1034,7 +1034,7 @@ impl Wal for WalFile { page.set_locked(); let frame = page.clone(); let page_idx = page.get().id; - let key = self.encryption_key.borrow().clone(); + let encryption_ctx = self.encryption_ctx.borrow().clone(); let seq = self.header.checkpoint_seq; let complete = Box::new(move |res: Result<(Arc, i32), CompletionError>| { let Ok((buf, bytes_read)) = res else { @@ -1047,8 +1047,8 @@ impl Wal for WalFile { "read({bytes_read}) less than expected({buf_len}): frame_id={frame_id}" ); let cloned = frame.clone(); - if let Some(key) = key.clone() { - match decrypt_page(buf.as_slice(), page_idx, &key) { + if let Some(ctx) = encryption_ctx.clone() { + match ctx.decrypt_page(buf.as_slice(), page_idx) { Ok(decrypted_data) => { buf.as_mut_slice().copy_from_slice(&decrypted_data); } @@ -1213,15 +1213,15 @@ impl Wal for WalFile { let page_content = page.get_contents(); let page_buf = page_content.as_ptr(); - let key = self.encryption_key.borrow(); + let encryption_ctx = self.encryption_ctx.borrow(); let encrypted_data = { - if let Some(key) = key.as_ref() { - Some(encrypt_page(page_buf, page_id, key)?) + if let Some(key) = encryption_ctx.as_ref() { + Some(key.encrypt_page(page_buf, page_id)?) } else { None } }; - let data_to_write = if key.as_ref().is_some() { + let data_to_write = if encryption_ctx.as_ref().is_some() { encrypted_data.as_ref().unwrap().as_slice() } else { page_buf @@ -1374,8 +1374,8 @@ impl Wal for WalFile { self } - fn set_encryption_key(&mut self, key: Option) { - self.encryption_key.replace(key); + fn set_encryption_context(&mut self, ctx: PerConnEncryptionContext) { + self.encryption_ctx.replace(Some(ctx)); } } @@ -1413,7 +1413,7 @@ impl WalFile { prev_checkpoint: CheckpointResult::default(), checkpoint_guard: None, header: *header, - encryption_key: RefCell::new(None), + encryption_ctx: RefCell::new(None), } } @@ -1665,7 +1665,7 @@ impl WalFile { pager, batch_map, done_flag, - self.encryption_key.borrow().as_ref(), + self.encryption_ctx.borrow().as_ref(), )?); } } diff --git a/core/translate/pragma.rs b/core/translate/pragma.rs index 1cb9a9d04..c7528027a 100644 --- a/core/translate/pragma.rs +++ b/core/translate/pragma.rs @@ -312,7 +312,7 @@ fn update_pragma( PragmaName::EncryptionKey => { let value = parse_string(&value)?; let key = EncryptionKey::from_string(&value); - connection.set_encryption_key(Some(key)); + connection.set_encryption_key(key); Ok((program, TransactionMode::None)) } }