diff --git a/core/storage/database.rs b/core/storage/database.rs index 21dcf1617..c5f6552ed 100644 --- a/core/storage/database.rs +++ b/core/storage/database.rs @@ -112,42 +112,77 @@ impl DatabaseStorage for DatabaseFile { return Err(LimboError::IntegerOverflow); }; - if let Some(ctx) = io_ctx.encryption_context() { - let encryption_ctx = ctx.clone(); - let read_buffer = r.buf_arc(); - let original_c = c.clone(); + match &io_ctx.encryption_or_checksum { + EncryptionOrChecksum::Encryption(ctx) => { + let encryption_ctx = ctx.clone(); + let read_buffer = r.buf_arc(); + let original_c = c.clone(); + let decrypt_complete = + Box::new(move |res: Result<(Arc, i32), CompletionError>| { + let Ok((buf, bytes_read)) = res else { + return; + }; + assert!( + bytes_read > 0, + "Expected to read some data on success for page_id={page_idx}" + ); + match encryption_ctx.decrypt_page(buf.as_slice(), page_idx) { + Ok(decrypted_data) => { + let original_buf = original_c.as_read().buf(); + original_buf.as_mut_slice().copy_from_slice(&decrypted_data); + original_c.complete(bytes_read); + } + Err(e) => { + tracing::error!( + "Failed to decrypt page data for page_id={page_idx}: {e}" + ); + assert!( + !original_c.has_error(), + "Original completion already has an error" + ); + original_c.error(CompletionError::DecryptionError { page_idx }); + } + } + }); + let wrapped_completion = Completion::new_read(read_buffer, decrypt_complete); + self.file.pread(pos, wrapped_completion) + } + EncryptionOrChecksum::Checksum(ctx) => { + let checksum_ctx = ctx.clone(); + let read_buffer = r.buf_arc(); + let original_c = c.clone(); - let decrypt_complete = - Box::new(move |res: Result<(Arc, i32), CompletionError>| { - let Ok((buf, bytes_read)) = res else { - return; - }; - assert!( - bytes_read > 0, - "Expected to read some data on success for page_id={page_idx}" - ); - match encryption_ctx.decrypt_page(buf.as_slice(), page_idx) { - Ok(decrypted_data) => { - let original_buf = original_c.as_read().buf(); - original_buf.as_mut_slice().copy_from_slice(&decrypted_data); + let verify_complete = + Box::new(move |res: Result<(Arc, i32), CompletionError>| { + let Ok((buf, bytes_read)) = res else { + return; + }; + if bytes_read <= 0 { + tracing::trace!("Read page {page_idx} with {} bytes", bytes_read); original_c.complete(bytes_read); + return; } - Err(e) => { - tracing::error!( - "Failed to decrypt page data for page_id={page_idx}: {e}" - ); - assert!( - !original_c.has_error(), - "Original completion already has an error" - ); - original_c.error(CompletionError::DecryptionError { page_idx }); + match checksum_ctx.verify_and_strip_checksum(buf.as_mut_slice(), page_idx) { + Ok(_) => { + original_c.complete(bytes_read); + } + Err(e) => { + tracing::error!( + "Failed to verify checksum for page_id={page_idx}: {e}" + ); + assert!( + !original_c.has_error(), + "Original completion already has an error" + ); + original_c.error(e); + } } - } - }); - let new_completion = Completion::new_read(read_buffer, decrypt_complete); - self.file.pread(pos, new_completion) - } else { - self.file.pread(pos, c) + }); + + let wrapped_completion = Completion::new_read(read_buffer, verify_complete); + self.file.pread(pos, wrapped_completion) + } + EncryptionOrChecksum::None => self.file.pread(pos, c), } } @@ -167,12 +202,10 @@ impl DatabaseStorage for DatabaseFile { let Some(pos) = (page_idx as u64 - 1).checked_mul(buffer_size as u64) else { return Err(LimboError::IntegerOverflow); }; - let buffer = { - if let Some(ctx) = io_ctx.encryption_context() { - encrypt_buffer(page_idx, buffer, ctx) - } else { - buffer - } + let buffer = match &io_ctx.encryption_or_checksum { + EncryptionOrChecksum::Encryption(ctx) => encrypt_buffer(page_idx, buffer, ctx), + EncryptionOrChecksum::Checksum(ctx) => checksum_buffer(page_idx, buffer, ctx), + EncryptionOrChecksum::None => buffer, }; self.file.pwrite(pos, buffer, c) } @@ -193,18 +226,19 @@ impl DatabaseStorage for DatabaseFile { let Some(pos) = (first_page_idx as u64 - 1).checked_mul(page_size as u64) else { return Err(LimboError::IntegerOverflow); }; - let buffers = { - if let Some(ctx) = io_ctx.encryption_context() { - buffers - .into_iter() - .enumerate() - .map(|(i, buffer)| encrypt_buffer(first_page_idx + i, buffer, ctx)) - .collect::>() - } else { - buffers - } + let buffers = match &io_ctx.encryption_or_checksum() { + EncryptionOrChecksum::Encryption(ctx) => buffers + .into_iter() + .enumerate() + .map(|(i, buffer)| encrypt_buffer(first_page_idx + i, buffer, ctx)) + .collect::>(), + EncryptionOrChecksum::Checksum(ctx) => buffers + .into_iter() + .enumerate() + .map(|(i, buffer)| checksum_buffer(first_page_idx + i, buffer, ctx)) + .collect::>(), + EncryptionOrChecksum::None => buffers, }; - let c = self.file.pwritev(pos, buffers, c)?; Ok(c) } @@ -237,3 +271,9 @@ fn encrypt_buffer(page_idx: usize, buffer: Arc, ctx: &EncryptionContext) let encrypted_data = ctx.encrypt_page(buffer.as_slice(), page_idx).unwrap(); Arc::new(Buffer::new(encrypted_data.to_vec())) } + +fn checksum_buffer(page_idx: usize, buffer: Arc, ctx: &ChecksumContext) -> Arc { + ctx.add_checksum_to_page(buffer.as_mut_slice(), page_idx) + .unwrap(); + buffer +} diff --git a/core/storage/sqlite3_ondisk.rs b/core/storage/sqlite3_ondisk.rs index d2af9a6b0..27f5bc7e7 100644 --- a/core/storage/sqlite3_ondisk.rs +++ b/core/storage/sqlite3_ondisk.rs @@ -58,7 +58,7 @@ use crate::storage::btree::offset::{ }; use crate::storage::btree::{payload_overflow_threshold_max, payload_overflow_threshold_min}; use crate::storage::buffer_pool::BufferPool; -use crate::storage::database::DatabaseStorage; +use crate::storage::database::{DatabaseStorage, EncryptionOrChecksum}; use crate::storage::pager::Pager; use crate::storage::wal::READMARK_NOT_USED; use crate::types::{RawSlice, RefValue, SerialType, SerialTypeKind, TextRef, TextSubtype}; @@ -1985,40 +1985,74 @@ pub fn begin_read_wal_frame( let buf = buffer_pool.get_page(); let buf = Arc::new(buf); - if let Some(ctx) = io_ctx.encryption_context() { - let encryption_ctx = ctx.clone(); - let original_complete = complete; + match io_ctx.encryption_or_checksum() { + EncryptionOrChecksum::Encryption(ctx) => { + let encryption_ctx = ctx.clone(); + let original_complete = complete; - let decrypt_complete = Box::new(move |res: Result<(Arc, i32), CompletionError>| { - let Ok((encrypted_buf, bytes_read)) = res else { - original_complete(res); - return; - }; - assert!( - bytes_read > 0, - "Expected to read some data on success for page_idx={page_idx}" - ); - match encryption_ctx.decrypt_page(encrypted_buf.as_slice(), page_idx) { - Ok(decrypted_data) => { - encrypted_buf - .as_mut_slice() - .copy_from_slice(&decrypted_data); - original_complete(Ok((encrypted_buf, bytes_read))); - } - Err(e) => { - tracing::error!( - "Failed to decrypt WAL frame data for page_idx={page_idx}: {e}" + let decrypt_complete = + Box::new(move |res: Result<(Arc, i32), CompletionError>| { + let Ok((encrypted_buf, bytes_read)) = res else { + original_complete(res); + return; + }; + assert!( + bytes_read > 0, + "Expected to read some data on success for page_idx={page_idx}" ); - original_complete(Err(CompletionError::DecryptionError { page_idx })); - } - } - }); + match encryption_ctx.decrypt_page(encrypted_buf.as_slice(), page_idx) { + Ok(decrypted_data) => { + encrypted_buf + .as_mut_slice() + .copy_from_slice(&decrypted_data); + original_complete(Ok((encrypted_buf, bytes_read))); + } + Err(e) => { + tracing::error!( + "Failed to decrypt WAL frame data for page_idx={page_idx}: {e}" + ); + original_complete(Err(CompletionError::DecryptionError { page_idx })); + } + } + }); - let new_completion = Completion::new_read(buf, decrypt_complete); - io.pread(offset, new_completion) - } else { - let c = Completion::new_read(buf, complete); - io.pread(offset, c) + let new_completion = Completion::new_read(buf, decrypt_complete); + io.pread(offset, new_completion) + } + EncryptionOrChecksum::Checksum(ctx) => { + let checksum_ctx = ctx.clone(); + let original_c = complete; + let verify_complete = + Box::new(move |res: Result<(Arc, i32), CompletionError>| { + let Ok((buf, bytes_read)) = res else { + original_c(res); + return; + }; + if bytes_read <= 0 { + tracing::trace!("Read page {page_idx} with {} bytes", bytes_read); + original_c(Ok((buf, bytes_read))); + return; + } + + match checksum_ctx.verify_and_strip_checksum(buf.as_mut_slice(), page_idx) { + Ok(_) => { + original_c(Ok((buf, bytes_read))); + } + Err(e) => { + tracing::error!( + "Failed to verify checksum for page_id={page_idx}: {e}" + ); + original_c(Err(e)) + } + } + }); + let c = Completion::new_read(buf, verify_complete); + io.pread(offset, c) + } + EncryptionOrChecksum::None => { + let c = Completion::new_read(buf, complete); + io.pread(offset, c) + } } }