diff --git a/core/storage/wal.rs b/core/storage/wal.rs index 9db5bc267..8eb1a5199 100644 --- a/core/storage/wal.rs +++ b/core/storage/wal.rs @@ -921,18 +921,34 @@ impl Wal for WalFile { let offset = self.frame_offset(frame_id); page.set_locked(); let frame = page.clone(); - let complete = Box::new(move |res: Result<(Arc, i32), CompletionError>| { - let Ok((buf, bytes_read)) = res else { - return; - }; + #[cfg(feature = "encryption")] + let page_idx = page.get().id; + #[cfg(feature = "encryption")] + let key = self.encryption_key.borrow().clone(); + let complete = Box::new(move |buf: Arc, bytes_read: i32| { let buf_len = buf.len(); turso_assert!( bytes_read == buf_len as i32, "read({bytes_read}) less than expected({buf_len}): frame_id={frame_id}" ); let frame = frame.clone(); - finish_read_page(page.get().id, buf, frame); + #[cfg(feature = "encryption")] + { + if let Some(key) = key.clone() { + match decrypt_page(buf.as_slice(), page_idx, &key) { + Ok(decrypted_data) => { + buf.as_mut_slice().copy_from_slice(&decrypted_data); + } + Err(_) => { + tracing::error!("Failed to decrypt page data for frame_id={frame_id}"); + return; + } + } + } + } + finish_read_page(page.get().id, buf, frame).unwrap(); }); + begin_read_wal_frame( &self.get_shared().file, offset + WAL_FRAME_HEADER_SIZE, @@ -1083,6 +1099,31 @@ impl Wal for WalFile { let checksums = self.last_checksum; let page_content = page.get_contents(); let page_buf = page_content.as_ptr(); + + #[cfg(feature = "encryption")] + let encrypted_data = { + let key = self.encryption_key.borrow(); + if let Some(key) = key.as_ref() { + Some(encrypt_page(page_buf, page_id, key)?) + } else { + None + } + }; + let data_to_write = { + #[cfg(feature = "encryption")] + { + if let Some(ref data) = encrypted_data { + data.as_slice() + } else { + page_buf + } + } + #[cfg(not(feature = "encryption"))] + { + page_buf + } + }; + let (frame_checksums, frame_bytes) = prepare_wal_frame( &self.buffer_pool, &header, @@ -1090,7 +1131,7 @@ impl Wal for WalFile { header.page_size, page_id as u32, db_size, - page_buf, + data_to_write, ); let c = Completion::new_write({ @@ -1494,6 +1535,8 @@ impl WalFile { pager, std::mem::take(&mut self.ongoing_checkpoint.batch), &self.ongoing_checkpoint.pending_flush, + #[cfg(feature = "encryption")] + self.encryption_key.borrow().as_ref(), )?; // batch is queued self.ongoing_checkpoint.state = CheckpointState::AfterFlush; @@ -1907,7 +1950,7 @@ pub mod test { sync::{atomic::Ordering, Arc}, }; #[allow(clippy::arc_with_non_send_sync)] - fn get_database() -> (Arc, std::path::PathBuf) { + pub(crate) fn get_database() -> (Arc, std::path::PathBuf) { let mut path = tempfile::tempdir().unwrap().keep(); let dbpath = path.clone(); path.push("test.db");