use encryption ctx instead of encryption key

This commit is contained in:
Avinash Sajjanshetty
2025-08-21 22:23:08 +05:30
parent cc8c763942
commit 3090545167
7 changed files with 58 additions and 51 deletions

View File

@@ -561,7 +561,7 @@ impl turso_core::DatabaseStorage for DatabaseFile {
fn read_page(
&self,
page_idx: usize,
_key: Option<&turso_core::EncryptionKey>,
_encryption_ctx: Option<&turso_core::PerConnEncryptionContext>,
c: turso_core::Completion,
) -> turso_core::Result<turso_core::Completion> {
let r = c.as_read();
@@ -578,7 +578,7 @@ impl turso_core::DatabaseStorage for DatabaseFile {
&self,
page_idx: usize,
buffer: Arc<turso_core::Buffer>,
_key: Option<&turso_core::EncryptionKey>,
_encryption_ctx: Option<&turso_core::PerConnEncryptionContext>,
c: turso_core::Completion,
) -> turso_core::Result<turso_core::Completion> {
let size = buffer.len();
@@ -591,7 +591,7 @@ impl turso_core::DatabaseStorage for DatabaseFile {
first_page_idx: usize,
page_size: usize,
buffers: Vec<Arc<turso_core::Buffer>>,
_key: Option<&turso_core::EncryptionKey>,
_encryption_ctx: Option<&turso_core::PerConnEncryptionContext>,
c: turso_core::Completion,
) -> turso_core::Result<turso_core::Completion> {
let pos = first_page_idx.saturating_sub(1) * page_size;

View File

@@ -76,7 +76,7 @@ use std::{
};
#[cfg(feature = "fs")]
use storage::database::DatabaseFile;
pub use storage::encryption::EncryptionKey;
pub use storage::encryption::{EncryptionKey, PerConnEncryptionContext};
use storage::page_cache::DumbLruPageCache;
use storage::pager::{AtomicDbState, DbState};
use storage::sqlite3_ondisk::PageSize;
@@ -1904,11 +1904,11 @@ impl Connection {
self.syms.borrow().vtab_modules.keys().cloned().collect()
}
pub fn set_encryption_key(&self, key: Option<EncryptionKey>) {
pub fn set_encryption_key(&self, key: EncryptionKey) {
tracing::trace!("setting encryption key for connection");
*self.encryption_key.borrow_mut() = key.clone();
*self.encryption_key.borrow_mut() = Some(key.clone());
let pager = self.pager.borrow();
pager.set_encryption_key(key);
pager.set_encryption_context(&key);
}
}

View File

@@ -1,5 +1,5 @@
use crate::error::LimboError;
use crate::storage::encryption::{decrypt_page, encrypt_page, EncryptionKey};
use crate::storage::encryption::PerConnEncryptionContext;
use crate::{io::Completion, Buffer, CompletionError, Result};
use std::sync::Arc;
use tracing::{instrument, Level};
@@ -15,14 +15,14 @@ pub trait DatabaseStorage: Send + Sync {
fn read_page(
&self,
page_idx: usize,
encryption_key: Option<&EncryptionKey>,
encryption_ctx: Option<&PerConnEncryptionContext>,
c: Completion,
) -> Result<Completion>;
fn write_page(
&self,
page_idx: usize,
buffer: Arc<Buffer>,
encryption_key: Option<&EncryptionKey>,
encryption_ctx: Option<&PerConnEncryptionContext>,
c: Completion,
) -> Result<Completion>;
fn write_pages(
@@ -30,7 +30,7 @@ pub trait DatabaseStorage: Send + Sync {
first_page_idx: usize,
page_size: usize,
buffers: Vec<Arc<Buffer>>,
encryption_key: Option<&EncryptionKey>,
encryption_ctx: Option<&PerConnEncryptionContext>,
c: Completion,
) -> Result<Completion>;
fn sync(&self, c: Completion) -> Result<Completion>;
@@ -59,7 +59,7 @@ impl DatabaseStorage for DatabaseFile {
fn read_page(
&self,
page_idx: usize,
encryption_key: Option<&EncryptionKey>,
encryption_ctx: Option<&PerConnEncryptionContext>,
c: Completion,
) -> Result<Completion> {
let r = c.as_read();
@@ -70,8 +70,8 @@ impl DatabaseStorage for DatabaseFile {
}
let pos = (page_idx - 1) * size;
if let Some(key) = encryption_key {
let key_clone = key.clone();
if let Some(ctx) = encryption_ctx {
let encryption_ctx = ctx.clone();
let read_buffer = r.buf_arc();
let original_c = c.clone();
@@ -81,7 +81,7 @@ impl DatabaseStorage for DatabaseFile {
return;
};
if bytes_read > 0 {
match decrypt_page(buf.as_slice(), page_idx, &key_clone) {
match encryption_ctx.decrypt_page(buf.as_slice(), page_idx) {
Ok(decrypted_data) => {
let original_buf = original_c.as_read().buf();
original_buf.as_mut_slice().copy_from_slice(&decrypted_data);
@@ -111,7 +111,7 @@ impl DatabaseStorage for DatabaseFile {
&self,
page_idx: usize,
buffer: Arc<Buffer>,
encryption_key: Option<&EncryptionKey>,
encryption_ctx: Option<&PerConnEncryptionContext>,
c: Completion,
) -> Result<Completion> {
let buffer_size = buffer.len();
@@ -121,8 +121,8 @@ impl DatabaseStorage for DatabaseFile {
assert_eq!(buffer_size & (buffer_size - 1), 0);
let pos = (page_idx - 1) * buffer_size;
let buffer = {
if let Some(key) = encryption_key {
encrypt_buffer(page_idx, buffer, key)
if let Some(ctx) = encryption_ctx {
encrypt_buffer(page_idx, buffer, ctx)
} else {
buffer
}
@@ -135,7 +135,7 @@ impl DatabaseStorage for DatabaseFile {
first_page_idx: usize,
page_size: usize,
buffers: Vec<Arc<Buffer>>,
encryption_key: Option<&EncryptionKey>,
encryption_key: Option<&PerConnEncryptionContext>,
c: Completion,
) -> Result<Completion> {
assert!(first_page_idx > 0);
@@ -145,11 +145,11 @@ impl DatabaseStorage for DatabaseFile {
let pos = (first_page_idx - 1) * page_size;
let buffers = {
if let Some(key) = encryption_key {
if let Some(ctx) = encryption_key {
buffers
.into_iter()
.enumerate()
.map(|(i, buffer)| encrypt_buffer(first_page_idx + i, buffer, key))
.map(|(i, buffer)| encrypt_buffer(first_page_idx + i, buffer, ctx))
.collect::<Vec<_>>()
} else {
buffers
@@ -184,7 +184,11 @@ impl DatabaseFile {
}
}
fn encrypt_buffer(page_idx: usize, buffer: Arc<Buffer>, key: &EncryptionKey) -> Arc<Buffer> {
let encrypted_data = encrypt_page(buffer.as_slice(), page_idx, key).unwrap();
fn encrypt_buffer(
page_idx: usize,
buffer: Arc<Buffer>,
ctx: &PerConnEncryptionContext,
) -> Arc<Buffer> {
let encrypted_data = ctx.encrypt_page(buffer.as_slice(), page_idx).unwrap();
Arc::new(Buffer::new(encrypted_data.to_vec()))
}

View File

@@ -28,7 +28,9 @@ use super::btree::{btree_init_page, BTreePage};
use super::page_cache::{CacheError, CacheResizeResult, DumbLruPageCache, PageCacheKey};
use super::sqlite3_ondisk::begin_write_btree_page;
use super::wal::CheckpointMode;
use crate::storage::encryption::{EncryptionKey, ENCRYPTION_METADATA_SIZE};
use crate::storage::encryption::{
EncryptionKey, PerConnEncryptionContext, ENCRYPTION_METADATA_SIZE,
};
/// SQLite's default maximum page count
const DEFAULT_MAX_PAGE_COUNT: u32 = 0xfffffffe;
@@ -483,7 +485,7 @@ pub struct Pager {
header_ref_state: RefCell<HeaderRefState>,
#[cfg(not(feature = "omit_autovacuum"))]
btree_create_vacuum_full_state: Cell<BtreeCreateVacuumFullState>,
pub(crate) encryption_key: RefCell<Option<EncryptionKey>>,
pub(crate) encryption_ctx: RefCell<Option<PerConnEncryptionContext>>,
}
#[derive(Debug, Clone)]
@@ -585,7 +587,7 @@ impl Pager {
header_ref_state: RefCell::new(HeaderRefState::Start),
#[cfg(not(feature = "omit_autovacuum"))]
btree_create_vacuum_full_state: Cell::new(BtreeCreateVacuumFullState::Start),
encryption_key: RefCell::new(None),
encryption_ctx: RefCell::new(None),
})
}
@@ -1072,7 +1074,7 @@ impl Pager {
page_idx,
page.clone(),
allow_empty_read,
self.encryption_key.borrow().as_ref(),
self.encryption_ctx.borrow().as_ref(),
)?;
return Ok((page, c));
};
@@ -1090,7 +1092,7 @@ impl Pager {
page_idx,
page.clone(),
allow_empty_read,
self.encryption_key.borrow().as_ref(),
self.encryption_ctx.borrow().as_ref(),
)?;
Ok((page, c))
}
@@ -1116,7 +1118,7 @@ impl Pager {
page_idx: usize,
page: PageRef,
allow_empty_read: bool,
encryption_key: Option<&EncryptionKey>,
encryption_key: Option<&PerConnEncryptionContext>,
) -> Result<Completion> {
sqlite3_ondisk::begin_read_page(
self.db_file.clone(),
@@ -1694,7 +1696,7 @@ impl Pager {
default_header.database_size = 1.into();
// if a key is set, then we will reserve space for encryption metadata
if self.encryption_key.borrow().is_some() {
if self.encryption_ctx.borrow().is_some() {
default_header.reserved_space = ENCRYPTION_METADATA_SIZE as u8;
}
@@ -2076,10 +2078,11 @@ impl Pager {
Ok(IOResult::Done(f(header)))
}
pub fn set_encryption_key(&self, key: Option<EncryptionKey>) {
self.encryption_key.replace(key.clone());
pub fn set_encryption_context(&self, key: &EncryptionKey) {
let encryption_ctx = PerConnEncryptionContext::new(key).unwrap();
self.encryption_ctx.replace(Some(encryption_ctx.clone()));
let Some(wal) = self.wal.as_ref() else { return };
wal.borrow_mut().set_encryption_key(key)
wal.borrow_mut().set_encryption_context(encryption_ctx)
}
}

View File

@@ -59,7 +59,7 @@ use crate::storage::btree::offset::{
use crate::storage::btree::{payload_overflow_threshold_max, payload_overflow_threshold_min};
use crate::storage::buffer_pool::BufferPool;
use crate::storage::database::DatabaseStorage;
use crate::storage::encryption::EncryptionKey;
use crate::storage::encryption::PerConnEncryptionContext;
use crate::storage::pager::Pager;
use crate::storage::wal::READMARK_NOT_USED;
use crate::types::{RawSlice, RefValue, SerialType, SerialTypeKind, TextRef, TextSubtype};
@@ -870,7 +870,7 @@ pub fn begin_read_page(
page: PageRef,
page_idx: usize,
allow_empty_read: bool,
encryption_key: Option<&EncryptionKey>,
encryption_key: Option<&PerConnEncryptionContext>,
) -> Result<Completion> {
tracing::trace!("begin_read_btree_page(page_idx = {})", page_idx);
let buf = buffer_pool.get_page();
@@ -965,7 +965,7 @@ pub fn write_pages_vectored(
pager: &Pager,
batch: BTreeMap<usize, Arc<Buffer>>,
done_flag: Arc<AtomicBool>,
encryption_key: Option<&EncryptionKey>,
encryption_key: Option<&PerConnEncryptionContext>,
) -> Result<Vec<Completion>> {
if batch.is_empty() {
done_flag.store(true, Ordering::Relaxed);

View File

@@ -17,7 +17,7 @@ use super::sqlite3_ondisk::{self, checksum_wal, WalHeader, WAL_MAGIC_BE, WAL_MAG
use crate::fast_lock::SpinLock;
use crate::io::{clock, File, IO};
use crate::result::LimboResult;
use crate::storage::encryption::{decrypt_page, encrypt_page, EncryptionKey};
use crate::storage::encryption::PerConnEncryptionContext;
use crate::storage::sqlite3_ondisk::{
begin_read_wal_frame, begin_read_wal_frame_raw, finish_read_page, prepare_wal_frame,
write_pages_vectored, PageSize, WAL_FRAME_HEADER_SIZE, WAL_HEADER_SIZE,
@@ -297,7 +297,7 @@ pub trait Wal: Debug {
/// Return unique set of pages changed **after** frame_watermark position and until current WAL session max_frame_no
fn changed_pages_after(&self, frame_watermark: u64) -> Result<Vec<u32>>;
fn set_encryption_key(&mut self, key: Option<EncryptionKey>);
fn set_encryption_context(&mut self, ctx: PerConnEncryptionContext);
#[cfg(debug_assertions)]
fn as_any(&self) -> &dyn std::any::Any;
@@ -568,7 +568,7 @@ pub struct WalFile {
/// Manages locks needed for checkpointing
checkpoint_guard: Option<CheckpointLocks>,
encryption_key: RefCell<Option<EncryptionKey>>,
encryption_ctx: RefCell<Option<PerConnEncryptionContext>>,
}
impl fmt::Debug for WalFile {
@@ -1034,7 +1034,7 @@ impl Wal for WalFile {
page.set_locked();
let frame = page.clone();
let page_idx = page.get().id;
let key = self.encryption_key.borrow().clone();
let encryption_ctx = self.encryption_ctx.borrow().clone();
let seq = self.header.checkpoint_seq;
let complete = Box::new(move |res: Result<(Arc<Buffer>, i32), CompletionError>| {
let Ok((buf, bytes_read)) = res else {
@@ -1047,8 +1047,8 @@ impl Wal for WalFile {
"read({bytes_read}) less than expected({buf_len}): frame_id={frame_id}"
);
let cloned = frame.clone();
if let Some(key) = key.clone() {
match decrypt_page(buf.as_slice(), page_idx, &key) {
if let Some(ctx) = encryption_ctx.clone() {
match ctx.decrypt_page(buf.as_slice(), page_idx) {
Ok(decrypted_data) => {
buf.as_mut_slice().copy_from_slice(&decrypted_data);
}
@@ -1213,15 +1213,15 @@ impl Wal for WalFile {
let page_content = page.get_contents();
let page_buf = page_content.as_ptr();
let key = self.encryption_key.borrow();
let encryption_ctx = self.encryption_ctx.borrow();
let encrypted_data = {
if let Some(key) = key.as_ref() {
Some(encrypt_page(page_buf, page_id, key)?)
if let Some(key) = encryption_ctx.as_ref() {
Some(key.encrypt_page(page_buf, page_id)?)
} else {
None
}
};
let data_to_write = if key.as_ref().is_some() {
let data_to_write = if encryption_ctx.as_ref().is_some() {
encrypted_data.as_ref().unwrap().as_slice()
} else {
page_buf
@@ -1374,8 +1374,8 @@ impl Wal for WalFile {
self
}
fn set_encryption_key(&mut self, key: Option<EncryptionKey>) {
self.encryption_key.replace(key);
fn set_encryption_context(&mut self, ctx: PerConnEncryptionContext) {
self.encryption_ctx.replace(Some(ctx));
}
}
@@ -1413,7 +1413,7 @@ impl WalFile {
prev_checkpoint: CheckpointResult::default(),
checkpoint_guard: None,
header: *header,
encryption_key: RefCell::new(None),
encryption_ctx: RefCell::new(None),
}
}
@@ -1665,7 +1665,7 @@ impl WalFile {
pager,
batch_map,
done_flag,
self.encryption_key.borrow().as_ref(),
self.encryption_ctx.borrow().as_ref(),
)?);
}
}

View File

@@ -312,7 +312,7 @@ fn update_pragma(
PragmaName::EncryptionKey => {
let value = parse_string(&value)?;
let key = EncryptionKey::from_string(&value);
connection.set_encryption_key(Some(key));
connection.set_encryption_key(key);
Ok((program, TransactionMode::None))
}
}