From bd9b4bbfd25b39614272063b3f28ca9501bf2e0d Mon Sep 17 00:00:00 2001 From: Avinash Sajjanshetty Date: Tue, 12 Aug 2025 21:57:35 +0530 Subject: [PATCH] encrypt/decrypt when writing/reading from DB --- core/io/mod.rs | 4 ++ core/storage/btree.rs | 5 ++ core/storage/database.rs | 93 ++++++++++++++++++++++++++++++++-- core/storage/pager.rs | 33 ++++++++---- core/storage/sqlite3_ondisk.rs | 7 ++- 5 files changed, 125 insertions(+), 17 deletions(-) diff --git a/core/io/mod.rs b/core/io/mod.rs index a54443ce5..beb674d67 100644 --- a/core/io/mod.rs +++ b/core/io/mod.rs @@ -280,6 +280,10 @@ impl ReadCompletion { pub fn callback(&self, bytes_read: Result) { (self.complete)(bytes_read.map(|b| (self.buf.clone(), b))); } + + pub fn buf_arc(&self) -> Arc { + self.buf.clone() + } } pub struct WriteCompletion { diff --git a/core/storage/btree.rs b/core/storage/btree.rs index de831757d..e72673778 100644 --- a/core/storage/btree.rs +++ b/core/storage/btree.rs @@ -8539,6 +8539,11 @@ mod tests { let c = Completion::new_write(move |_| { let _ = _buf.clone(); }); + #[cfg(feature = "encryption")] + let _c = pager + .db_file + .write_page(current_page as usize, buf.clone(), None, c)?; + #[cfg(not(feature = "encryption"))] let _c = pager .db_file .write_page(current_page as usize, buf.clone(), c)?; diff --git a/core/storage/database.rs b/core/storage/database.rs index 59a354b8c..df7dd577f 100644 --- a/core/storage/database.rs +++ b/core/storage/database.rs @@ -58,7 +58,12 @@ impl DatabaseStorage for DatabaseFile { self.file.pread(0, c) } #[instrument(skip_all, level = Level::DEBUG)] - fn read_page(&self, page_idx: usize, c: Completion) -> Result { + fn read_page( + &self, + page_idx: usize, + #[cfg(feature = "encryption")] encryption_key: Option<&EncryptionKey>, + c: Completion, + ) -> Result { let r = c.as_read(); let size = r.buf().len(); assert!(page_idx > 0); @@ -66,7 +71,42 @@ impl DatabaseStorage for DatabaseFile { return Err(LimboError::NotADB); } let pos = (page_idx - 1) * size; - self.file.pread(pos, c) + + #[cfg(feature = "encryption")] + { + if let Some(key) = encryption_key { + let key_clone = key.clone(); + let read_buffer = r.buf_arc(); + let original_c = c.clone(); + + let decrypt_complete = Box::new(move |buf: Arc, bytes_read: i32| { + if bytes_read > 0 { + match decrypt_page(buf.as_slice(), page_idx, &key_clone) { + Ok(decrypted_data) => { + let original_buf = original_c.as_read().buf(); + original_buf.as_mut_slice().copy_from_slice(&decrypted_data); + original_c.complete(bytes_read); + } + Err(_) => { + original_c.complete(-1); + } + } + } else { + original_c.complete(bytes_read); + } + }); + + let new_completion = Completion::new_read(read_buffer, decrypt_complete); + self.file.pread(pos, new_completion) + } else { + self.file.pread(pos, c) + } + } + + #[cfg(not(feature = "encryption"))] + { + self.file.pread(pos, c) + } } #[instrument(skip_all, level = Level::DEBUG)] @@ -74,6 +114,7 @@ impl DatabaseStorage for DatabaseFile { &self, page_idx: usize, buffer: Arc, + #[cfg(feature = "encryption")] encryption_key: Option<&EncryptionKey>, c: Completion, ) -> Result { let buffer_size = buffer.len(); @@ -82,21 +123,57 @@ impl DatabaseStorage for DatabaseFile { assert!(buffer_size <= 65536); assert_eq!(buffer_size & (buffer_size - 1), 0); let pos = (page_idx - 1) * buffer_size; + let buffer = { + #[cfg(feature = "encryption")] + { + if let Some(key) = encryption_key { + encrypt_buffer(page_idx, buffer, key) + } else { + buffer + } + } + #[cfg(not(feature = "encryption"))] + { + buffer + } + }; self.file.pwrite(pos, buffer, c) } fn write_pages( &self, - page_idx: usize, + first_page_idx: usize, page_size: usize, buffers: Vec>, + #[cfg(feature = "encryption")] encryption_key: Option<&EncryptionKey>, c: Completion, ) -> Result { - assert!(page_idx > 0); + assert!(first_page_idx > 0); assert!(page_size >= 512); assert!(page_size <= 65536); assert_eq!(page_size & (page_size - 1), 0); - let pos = (page_idx - 1) * page_size; + + let pos = (first_page_idx - 1) * page_size; + + let buffers = { + #[cfg(feature = "encryption")] + { + if let Some(key) = encryption_key { + buffers + .into_iter() + .enumerate() + .map(|(i, buffer)| encrypt_buffer(first_page_idx + i, buffer, key)) + .collect::>() + } else { + buffers + } + } + #[cfg(not(feature = "encryption"))] + { + buffers + } + }; + let c = self.file.pwritev(pos, buffers, c)?; Ok(c) } @@ -124,3 +201,9 @@ impl DatabaseFile { Self { file } } } + +#[cfg(feature = "encryption")] +fn encrypt_buffer(page_idx: usize, buffer: Arc, key: &EncryptionKey) -> Arc { + let encrypted_data = encrypt_page(buffer.as_slice(), page_idx, key).unwrap(); + Arc::new(Buffer::new(encrypted_data.to_vec())) +} diff --git a/core/storage/pager.rs b/core/storage/pager.rs index 7f000c80b..6ba49cb84 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -633,8 +633,8 @@ impl Pager { // Check if the calculated offset for the entry is within the bounds of the actual page data length. if offset_in_ptrmap_page + PTRMAP_ENTRY_SIZE > actual_data_length { return Err(LimboError::InternalError(format!( - "Ptrmap offset {offset_in_ptrmap_page} + entry size {PTRMAP_ENTRY_SIZE} out of bounds for page {ptrmap_pg_no} (actual data len {actual_data_length})" - ))); + "Ptrmap offset {offset_in_ptrmap_page} + entry size {PTRMAP_ENTRY_SIZE} out of bounds for page {ptrmap_pg_no} (actual data len {actual_data_length})" + ))); } let entry_slice = &ptrmap_page_data_slice @@ -676,8 +676,8 @@ impl Pager { || is_ptrmap_page(db_page_no_to_update, page_size) { return Err(LimboError::InternalError(format!( - "Cannot set ptrmap entry for page {db_page_no_to_update}: it's a header/ptrmap page or invalid." - ))); + "Cannot set ptrmap entry for page {db_page_no_to_update}: it's a header/ptrmap page or invalid." + ))); } let ptrmap_pg_no = get_ptrmap_page_no_for_db_page(db_page_no_to_update, page_size); @@ -1012,8 +1012,15 @@ impl Pager { matches!(frame_watermark, Some(0) | None), "frame_watermark must be either None or Some(0) because DB has no WAL and read with other watermark is invalid" ); + page.set_locked(); - let c = self.begin_read_disk_page(page_idx, page.clone(), allow_empty_read)?; + let c = self.begin_read_disk_page( + page_idx, + page.clone(), + allow_empty_read, + #[cfg(feature = "encryption")] + self.encryption_key.borrow().as_ref(), + )?; return Ok((page, c)); }; @@ -1026,7 +1033,13 @@ impl Pager { return Ok((page, c)); } - let c = self.begin_read_disk_page(page_idx, page.clone(), allow_empty_read)?; + let c = self.begin_read_disk_page( + page_idx, + page.clone(), + allow_empty_read, + #[cfg(feature = "encryption")] + self.encryption_key.borrow().as_ref(), + )?; Ok((page, c)) } @@ -2087,10 +2100,10 @@ mod ptrmap { pub fn serialize(&self, buffer: &mut [u8]) -> Result<()> { if buffer.len() < PTRMAP_ENTRY_SIZE { return Err(LimboError::InternalError(format!( - "Buffer too small to serialize ptrmap entry. Expected at least {} bytes, got {}", - PTRMAP_ENTRY_SIZE, - buffer.len() - ))); + "Buffer too small to serialize ptrmap entry. Expected at least {} bytes, got {}", + PTRMAP_ENTRY_SIZE, + buffer.len() + ))); } buffer[0] = self.entry_type as u8; buffer[1..5].copy_from_slice(&self.parent_page_no.to_be_bytes()); diff --git a/core/storage/sqlite3_ondisk.rs b/core/storage/sqlite3_ondisk.rs index c6dce74a1..5c5a1a419 100644 --- a/core/storage/sqlite3_ondisk.rs +++ b/core/storage/sqlite3_ondisk.rs @@ -874,6 +874,7 @@ pub fn begin_read_page( page: PageRef, page_idx: usize, allow_empty_read: bool, + encryption_key: Option<&EncryptionKey>, ) -> Result { tracing::trace!("begin_read_btree_page(page_idx = {})", page_idx); let buf = buffer_pool.get_page(); @@ -895,7 +896,7 @@ pub fn begin_read_page( finish_read_page(page_idx, buf, page.clone()); }); let c = Completion::new_read(buf, complete); - db_file.read_page(page_idx, c) + db_file.read_page(page_idx, encryption_key, c) } #[instrument(skip_all, level = Level::INFO)] @@ -946,7 +947,7 @@ pub fn begin_write_btree_page(pager: &Pager, page: &PageRef) -> Result>, flush: &PendingFlush, + encryption_key: Option<&EncryptionKey>, ) -> Result> { if batch.is_empty() { return Ok(Vec::new()); @@ -1043,6 +1045,7 @@ pub fn write_pages_vectored( start_id, page_sz, std::mem::replace(&mut run_bufs, Vec::with_capacity(EST_BUFF_CAPACITY)), + encryption_key, c, ) { Ok(c) => {