Set and propagate IOContext as required

This commit is contained in:
Avinash Sajjanshetty
2025-08-27 21:44:25 +05:30
parent 9e663c7f46
commit 2c0842ff52
5 changed files with 61 additions and 48 deletions

View File

@@ -7404,7 +7404,7 @@ mod tests {
},
types::Text,
vdbe::Register,
BufferPool, Completion, Connection, StepResult, WalFile, WalFileShared,
BufferPool, Completion, Connection, IOContext, StepResult, WalFile, WalFileShared,
};
use std::{
cell::RefCell,
@@ -8713,9 +8713,12 @@ mod tests {
let c = Completion::new_write(move |_| {
let _ = _buf.clone();
});
let _c = pager
.db_file
.write_page(current_page as usize, buf.clone(), None, c)?;
let _c = pager.db_file.write_page(
current_page as usize,
buf.clone(),
&IOContext::default(),
c,
)?;
pager.io.run_once()?;
let (page, _c) = cursor.read_page(current_page as usize)?;

View File

@@ -4,12 +4,14 @@ use crate::{io::Completion, Buffer, CompletionError, Result};
use std::sync::Arc;
use tracing::{instrument, Level};
#[derive(Clone)]
pub enum EncryptionOrChecksum {
Encryption(EncryptionContext),
Checksum,
None,
}
#[derive(Clone)]
pub struct IOContext {
encryption_or_checksum: EncryptionOrChecksum,
}
@@ -21,6 +23,18 @@ impl IOContext {
_ => None,
}
}
pub fn set_encryption(&mut self, encryption_ctx: EncryptionContext) {
self.encryption_or_checksum = EncryptionOrChecksum::Encryption(encryption_ctx);
}
}
impl Default for IOContext {
fn default() -> Self {
Self {
encryption_or_checksum: EncryptionOrChecksum::None,
}
}
}
/// DatabaseStorage is an interface a database file that consists of pages.

View File

@@ -11,7 +11,7 @@ use crate::storage::{
};
use crate::types::{IOCompletions, WalState};
use crate::util::IOExt as _;
use crate::{io_yield_many, io_yield_one};
use crate::{io_yield_many, io_yield_one, IOContext};
use crate::{
return_if_io, turso_assert, types::WalFrameInfo, Completion, Connection, IOResult, LimboError,
Result, TransactionState,
@@ -503,7 +503,7 @@ pub struct Pager {
header_ref_state: RefCell<HeaderRefState>,
#[cfg(not(feature = "omit_autovacuum"))]
btree_create_vacuum_full_state: Cell<BtreeCreateVacuumFullState>,
pub(crate) encryption_ctx: RefCell<Option<EncryptionContext>>,
pub(crate) io_ctx: RefCell<IOContext>,
}
#[derive(Debug, Clone)]
@@ -607,7 +607,7 @@ impl Pager {
header_ref_state: RefCell::new(HeaderRefState::Start),
#[cfg(not(feature = "omit_autovacuum"))]
btree_create_vacuum_full_state: Cell::new(BtreeCreateVacuumFullState::Start),
encryption_ctx: RefCell::new(None),
io_ctx: RefCell::new(IOContext::default()),
})
}
@@ -1094,7 +1094,7 @@ impl Pager {
) -> Result<(PageRef, Completion)> {
tracing::trace!("read_page_no_cache(page_idx = {})", page_idx);
let page = Arc::new(Page::new(page_idx));
let io_ctx = &self.io_ctx.borrow();
let Some(wal) = self.wal.as_ref() else {
turso_assert!(
matches!(frame_watermark, Some(0) | None),
@@ -1102,12 +1102,7 @@ impl Pager {
);
page.set_locked();
let c = self.begin_read_disk_page(
page_idx,
page.clone(),
allow_empty_read,
self.encryption_ctx.borrow().as_ref(),
)?;
let c = self.begin_read_disk_page(page_idx, page.clone(), allow_empty_read, io_ctx)?;
return Ok((page, c));
};
@@ -1120,12 +1115,7 @@ impl Pager {
return Ok((page, c));
}
let c = self.begin_read_disk_page(
page_idx,
page.clone(),
allow_empty_read,
self.encryption_ctx.borrow().as_ref(),
)?;
let c = self.begin_read_disk_page(page_idx, page.clone(), allow_empty_read, io_ctx)?;
Ok((page, c))
}
@@ -1149,7 +1139,7 @@ impl Pager {
page_idx: usize,
page: PageRef,
allow_empty_read: bool,
encryption_key: Option<&EncryptionContext>,
io_ctx: &IOContext,
) -> Result<Completion> {
sqlite3_ondisk::begin_read_page(
self.db_file.clone(),
@@ -1157,7 +1147,7 @@ impl Pager {
page,
page_idx,
allow_empty_read,
encryption_key,
io_ctx,
)
}
@@ -1802,7 +1792,8 @@ impl Pager {
default_header.database_size = 1.into();
// if a key is set, then we will reserve space for encryption metadata
if let Some(ref ctx) = *self.encryption_ctx.borrow() {
let io_ctx = self.io_ctx.borrow();
if let Some(ctx) = io_ctx.encryption_context() {
default_header.reserved_space = ctx.required_reserved_bytes()
}
@@ -2190,9 +2181,13 @@ impl Pager {
pub fn set_encryption_context(&self, cipher_mode: CipherMode, key: &EncryptionKey) {
let encryption_ctx = EncryptionContext::new(cipher_mode, key).unwrap();
self.encryption_ctx.replace(Some(encryption_ctx.clone()));
{
let mut io_ctx = self.io_ctx.borrow_mut();
io_ctx.set_encryption(encryption_ctx);
}
let Some(wal) = self.wal.as_ref() else { return };
wal.borrow_mut().set_encryption_context(encryption_ctx)
wal.borrow_mut()
.set_io_context(self.io_ctx.borrow().clone())
}
}

View File

@@ -59,11 +59,12 @@ use crate::storage::btree::offset::{
use crate::storage::btree::{payload_overflow_threshold_max, payload_overflow_threshold_min};
use crate::storage::buffer_pool::BufferPool;
use crate::storage::database::DatabaseStorage;
use crate::storage::encryption::EncryptionContext;
use crate::storage::pager::Pager;
use crate::storage::wal::READMARK_NOT_USED;
use crate::types::{RawSlice, RefValue, SerialType, SerialTypeKind, TextRef, TextSubtype};
use crate::{bail_corrupt_error, turso_assert, CompletionError, File, Result, WalFileShared};
use crate::{
bail_corrupt_error, turso_assert, CompletionError, File, IOContext, Result, WalFileShared,
};
use std::cell::{Cell, UnsafeCell};
use std::collections::{BTreeMap, HashMap};
use std::mem::MaybeUninit;
@@ -905,7 +906,7 @@ pub fn begin_read_page(
page: PageRef,
page_idx: usize,
allow_empty_read: bool,
encryption_key: Option<&EncryptionContext>,
io_ctx: &IOContext,
) -> Result<Completion> {
tracing::trace!("begin_read_btree_page(page_idx = {})", page_idx);
let buf = buffer_pool.get_page();
@@ -928,7 +929,7 @@ pub fn begin_read_page(
finish_read_page(page_idx, buf, page.clone());
});
let c = Completion::new_read(buf, complete);
db_file.read_page(page_idx, encryption_key, c)
db_file.read_page(page_idx, io_ctx, c)
}
#[instrument(skip_all, level = Level::INFO)]
@@ -982,7 +983,8 @@ pub fn begin_write_btree_page(pager: &Pager, page: &PageRef) -> Result<Completio
})
};
let c = Completion::new_write(write_complete);
page_source.write_page(page_id, buffer.clone(), None, c)
let io_ctx = &pager.io_ctx.borrow();
page_source.write_page(page_id, buffer.clone(), io_ctx, c)
}
#[instrument(skip_all, level = Level::DEBUG)]
@@ -1000,7 +1002,6 @@ pub fn write_pages_vectored(
pager: &Pager,
batch: BTreeMap<usize, Arc<Buffer>>,
done_flag: Arc<AtomicBool>,
encryption_key: Option<&EncryptionContext>,
) -> Result<Vec<Completion>> {
if batch.is_empty() {
done_flag.store(true, Ordering::Relaxed);
@@ -1076,11 +1077,12 @@ pub fn write_pages_vectored(
});
// Submit write operation for this run, decrementing the counter if we error
let io_ctx = &pager.io_ctx.borrow();
match pager.db_file.write_pages(
start_id,
page_sz,
std::mem::replace(&mut run_bufs, Vec::with_capacity(EST_BUFF_CAPACITY)),
encryption_key,
io_ctx,
c,
) {
Ok(c) => {

View File

@@ -17,7 +17,6 @@ use super::sqlite3_ondisk::{self, checksum_wal, WalHeader, WAL_MAGIC_BE, WAL_MAG
use crate::fast_lock::SpinLock;
use crate::io::{clock, File, IO};
use crate::result::LimboResult;
use crate::storage::encryption::EncryptionContext;
use crate::storage::sqlite3_ondisk::{
begin_read_wal_frame, begin_read_wal_frame_raw, finish_read_page, prepare_wal_frame,
write_pages_vectored, PageSize, WAL_FRAME_HEADER_SIZE, WAL_HEADER_SIZE,
@@ -25,7 +24,7 @@ use crate::storage::sqlite3_ondisk::{
use crate::types::{IOCompletions, IOResult};
use crate::{
bail_corrupt_error, io_yield_many, turso_assert, Buffer, Completion, CompletionError,
LimboError, Result,
IOContext, LimboError, Result,
};
#[derive(Debug, Clone, Default)]
@@ -304,7 +303,7 @@ pub trait Wal: Debug {
/// Return unique set of pages changed **after** frame_watermark position and until current WAL session max_frame_no
fn changed_pages_after(&self, frame_watermark: u64) -> Result<Vec<u32>>;
fn set_encryption_context(&mut self, ctx: EncryptionContext);
fn set_io_context(&mut self, ctx: IOContext);
#[cfg(debug_assertions)]
fn as_any(&self) -> &dyn std::any::Any;
@@ -576,7 +575,7 @@ pub struct WalFile {
/// Manages locks needed for checkpointing
checkpoint_guard: Option<CheckpointLocks>,
encryption_ctx: RefCell<Option<EncryptionContext>>,
io_ctx: RefCell<IOContext>,
}
impl fmt::Debug for WalFile {
@@ -1065,7 +1064,10 @@ impl Wal for WalFile {
page.set_locked();
let frame = page.clone();
let page_idx = page.get().id;
let encryption_ctx = self.encryption_ctx.borrow().clone();
let encryption_ctx = {
let io_ctx = self.io_ctx.borrow();
io_ctx.encryption_context().cloned()
};
let seq = self.header.checkpoint_seq;
let complete = Box::new(move |res: Result<(Arc<Buffer>, i32), CompletionError>| {
let Ok((buf, bytes_read)) = res else {
@@ -1244,7 +1246,8 @@ impl Wal for WalFile {
let page_content = page.get_contents();
let page_buf = page_content.as_ptr();
let encryption_ctx = self.encryption_ctx.borrow();
let io_ctx = self.io_ctx.borrow();
let encryption_ctx = io_ctx.encryption_context();
let encrypted_data = {
if let Some(key) = encryption_ctx.as_ref() {
Some(key.encrypt_page(page_buf, page_id)?)
@@ -1442,7 +1445,8 @@ impl Wal for WalFile {
let plain = page.get_contents().as_ptr();
let data_to_write: std::borrow::Cow<[u8]> = {
let ectx = self.encryption_ctx.borrow();
let io_ctx = self.io_ctx.borrow();
let ectx = io_ctx.encryption_context();
if let Some(ctx) = ectx.as_ref() {
Cow::Owned(ctx.encrypt_page(plain, page_id as usize)?)
} else {
@@ -1511,8 +1515,8 @@ impl Wal for WalFile {
self
}
fn set_encryption_context(&mut self, ctx: EncryptionContext) {
self.encryption_ctx.replace(Some(ctx));
fn set_io_context(&mut self, ctx: IOContext) {
self.io_ctx.replace(ctx);
}
}
@@ -1550,7 +1554,7 @@ impl WalFile {
prev_checkpoint: CheckpointResult::default(),
checkpoint_guard: None,
header: *header,
encryption_ctx: RefCell::new(None),
io_ctx: RefCell::new(IOContext::default()),
}
}
@@ -1798,12 +1802,7 @@ impl WalFile {
let batch_map = self.ongoing_checkpoint.pending_writes.take();
if !batch_map.is_empty() {
let done_flag = self.ongoing_checkpoint.add_write();
completions.extend(write_pages_vectored(
pager,
batch_map,
done_flag,
self.encryption_ctx.borrow().as_ref(),
)?);
completions.extend(write_pages_vectored(pager, batch_map, done_flag)?);
}
}