encrypt/decrypt when writing/reading from DB

This commit is contained in:
Avinash Sajjanshetty
2025-08-12 21:57:35 +05:30
parent 657daeded3
commit bd9b4bbfd2
5 changed files with 125 additions and 17 deletions

View File

@@ -280,6 +280,10 @@ impl ReadCompletion {
pub fn callback(&self, bytes_read: Result<i32, CompletionError>) {
(self.complete)(bytes_read.map(|b| (self.buf.clone(), b)));
}
pub fn buf_arc(&self) -> Arc<Buffer> {
self.buf.clone()
}
}
pub struct WriteCompletion {

View File

@@ -8539,6 +8539,11 @@ mod tests {
let c = Completion::new_write(move |_| {
let _ = _buf.clone();
});
#[cfg(feature = "encryption")]
let _c = pager
.db_file
.write_page(current_page as usize, buf.clone(), None, c)?;
#[cfg(not(feature = "encryption"))]
let _c = pager
.db_file
.write_page(current_page as usize, buf.clone(), c)?;

View File

@@ -58,7 +58,12 @@ impl DatabaseStorage for DatabaseFile {
self.file.pread(0, c)
}
#[instrument(skip_all, level = Level::DEBUG)]
fn read_page(&self, page_idx: usize, c: Completion) -> Result<Completion> {
fn read_page(
&self,
page_idx: usize,
#[cfg(feature = "encryption")] encryption_key: Option<&EncryptionKey>,
c: Completion,
) -> Result<Completion> {
let r = c.as_read();
let size = r.buf().len();
assert!(page_idx > 0);
@@ -66,7 +71,42 @@ impl DatabaseStorage for DatabaseFile {
return Err(LimboError::NotADB);
}
let pos = (page_idx - 1) * size;
self.file.pread(pos, c)
#[cfg(feature = "encryption")]
{
if let Some(key) = encryption_key {
let key_clone = key.clone();
let read_buffer = r.buf_arc();
let original_c = c.clone();
let decrypt_complete = Box::new(move |buf: Arc<Buffer>, bytes_read: i32| {
if bytes_read > 0 {
match decrypt_page(buf.as_slice(), page_idx, &key_clone) {
Ok(decrypted_data) => {
let original_buf = original_c.as_read().buf();
original_buf.as_mut_slice().copy_from_slice(&decrypted_data);
original_c.complete(bytes_read);
}
Err(_) => {
original_c.complete(-1);
}
}
} else {
original_c.complete(bytes_read);
}
});
let new_completion = Completion::new_read(read_buffer, decrypt_complete);
self.file.pread(pos, new_completion)
} else {
self.file.pread(pos, c)
}
}
#[cfg(not(feature = "encryption"))]
{
self.file.pread(pos, c)
}
}
#[instrument(skip_all, level = Level::DEBUG)]
@@ -74,6 +114,7 @@ impl DatabaseStorage for DatabaseFile {
&self,
page_idx: usize,
buffer: Arc<Buffer>,
#[cfg(feature = "encryption")] encryption_key: Option<&EncryptionKey>,
c: Completion,
) -> Result<Completion> {
let buffer_size = buffer.len();
@@ -82,21 +123,57 @@ impl DatabaseStorage for DatabaseFile {
assert!(buffer_size <= 65536);
assert_eq!(buffer_size & (buffer_size - 1), 0);
let pos = (page_idx - 1) * buffer_size;
let buffer = {
#[cfg(feature = "encryption")]
{
if let Some(key) = encryption_key {
encrypt_buffer(page_idx, buffer, key)
} else {
buffer
}
}
#[cfg(not(feature = "encryption"))]
{
buffer
}
};
self.file.pwrite(pos, buffer, c)
}
fn write_pages(
&self,
page_idx: usize,
first_page_idx: usize,
page_size: usize,
buffers: Vec<Arc<Buffer>>,
#[cfg(feature = "encryption")] encryption_key: Option<&EncryptionKey>,
c: Completion,
) -> Result<Completion> {
assert!(page_idx > 0);
assert!(first_page_idx > 0);
assert!(page_size >= 512);
assert!(page_size <= 65536);
assert_eq!(page_size & (page_size - 1), 0);
let pos = (page_idx - 1) * page_size;
let pos = (first_page_idx - 1) * page_size;
let buffers = {
#[cfg(feature = "encryption")]
{
if let Some(key) = encryption_key {
buffers
.into_iter()
.enumerate()
.map(|(i, buffer)| encrypt_buffer(first_page_idx + i, buffer, key))
.collect::<Vec<_>>()
} else {
buffers
}
}
#[cfg(not(feature = "encryption"))]
{
buffers
}
};
let c = self.file.pwritev(pos, buffers, c)?;
Ok(c)
}
@@ -124,3 +201,9 @@ impl DatabaseFile {
Self { file }
}
}
#[cfg(feature = "encryption")]
fn encrypt_buffer(page_idx: usize, buffer: Arc<Buffer>, key: &EncryptionKey) -> Arc<Buffer> {
let encrypted_data = encrypt_page(buffer.as_slice(), page_idx, key).unwrap();
Arc::new(Buffer::new(encrypted_data.to_vec()))
}

View File

@@ -633,8 +633,8 @@ impl Pager {
// Check if the calculated offset for the entry is within the bounds of the actual page data length.
if offset_in_ptrmap_page + PTRMAP_ENTRY_SIZE > actual_data_length {
return Err(LimboError::InternalError(format!(
"Ptrmap offset {offset_in_ptrmap_page} + entry size {PTRMAP_ENTRY_SIZE} out of bounds for page {ptrmap_pg_no} (actual data len {actual_data_length})"
)));
"Ptrmap offset {offset_in_ptrmap_page} + entry size {PTRMAP_ENTRY_SIZE} out of bounds for page {ptrmap_pg_no} (actual data len {actual_data_length})"
)));
}
let entry_slice = &ptrmap_page_data_slice
@@ -676,8 +676,8 @@ impl Pager {
|| is_ptrmap_page(db_page_no_to_update, page_size)
{
return Err(LimboError::InternalError(format!(
"Cannot set ptrmap entry for page {db_page_no_to_update}: it's a header/ptrmap page or invalid."
)));
"Cannot set ptrmap entry for page {db_page_no_to_update}: it's a header/ptrmap page or invalid."
)));
}
let ptrmap_pg_no = get_ptrmap_page_no_for_db_page(db_page_no_to_update, page_size);
@@ -1012,8 +1012,15 @@ impl Pager {
matches!(frame_watermark, Some(0) | None),
"frame_watermark must be either None or Some(0) because DB has no WAL and read with other watermark is invalid"
);
page.set_locked();
let c = self.begin_read_disk_page(page_idx, page.clone(), allow_empty_read)?;
let c = self.begin_read_disk_page(
page_idx,
page.clone(),
allow_empty_read,
#[cfg(feature = "encryption")]
self.encryption_key.borrow().as_ref(),
)?;
return Ok((page, c));
};
@@ -1026,7 +1033,13 @@ impl Pager {
return Ok((page, c));
}
let c = self.begin_read_disk_page(page_idx, page.clone(), allow_empty_read)?;
let c = self.begin_read_disk_page(
page_idx,
page.clone(),
allow_empty_read,
#[cfg(feature = "encryption")]
self.encryption_key.borrow().as_ref(),
)?;
Ok((page, c))
}
@@ -2087,10 +2100,10 @@ mod ptrmap {
pub fn serialize(&self, buffer: &mut [u8]) -> Result<()> {
if buffer.len() < PTRMAP_ENTRY_SIZE {
return Err(LimboError::InternalError(format!(
"Buffer too small to serialize ptrmap entry. Expected at least {} bytes, got {}",
PTRMAP_ENTRY_SIZE,
buffer.len()
)));
"Buffer too small to serialize ptrmap entry. Expected at least {} bytes, got {}",
PTRMAP_ENTRY_SIZE,
buffer.len()
)));
}
buffer[0] = self.entry_type as u8;
buffer[1..5].copy_from_slice(&self.parent_page_no.to_be_bytes());

View File

@@ -874,6 +874,7 @@ pub fn begin_read_page(
page: PageRef,
page_idx: usize,
allow_empty_read: bool,
encryption_key: Option<&EncryptionKey>,
) -> Result<Completion> {
tracing::trace!("begin_read_btree_page(page_idx = {})", page_idx);
let buf = buffer_pool.get_page();
@@ -895,7 +896,7 @@ pub fn begin_read_page(
finish_read_page(page_idx, buf, page.clone());
});
let c = Completion::new_read(buf, complete);
db_file.read_page(page_idx, c)
db_file.read_page(page_idx, encryption_key, c)
}
#[instrument(skip_all, level = Level::INFO)]
@@ -946,7 +947,7 @@ pub fn begin_write_btree_page(pager: &Pager, page: &PageRef) -> Result<Completio
})
};
let c = Completion::new_write(write_complete);
page_source.write_page(page_id, buffer.clone(), c)
page_source.write_page(page_id, buffer.clone(), None, c)
}
#[instrument(skip_all, level = Level::DEBUG)]
@@ -964,6 +965,7 @@ pub fn write_pages_vectored(
pager: &Pager,
batch: BTreeMap<usize, Arc<Buffer>>,
flush: &PendingFlush,
encryption_key: Option<&EncryptionKey>,
) -> Result<Vec<Completion>> {
if batch.is_empty() {
return Ok(Vec::new());
@@ -1043,6 +1045,7 @@ pub fn write_pages_vectored(
start_id,
page_sz,
std::mem::replace(&mut run_bufs, Vec::with_capacity(EST_BUFF_CAPACITY)),
encryption_key,
c,
) {
Ok(c) => {