Merge 'Add per page checksums' from Avinash Sajjanshetty

This patch adds checksums to Turso DB. You may check the design here in
the [RFC](https://github.com/tursodatabase/turso/issues/2178).
1. We use reserved bytes (8 bytes) to store the checksums. On every IO
read, we verify that the checksum matches.
2. We use twox hash for checksums.
3. Checksums work only on 4K pages for now. It's a small change to enable
them for all other page sizes; I will send a follow-up PR.
4. Right now, it's not possible to switch to a different algorithm or to
turn checksums off altogether. That will be added in future PRs.
5. Checksums can be enabled only for new dbs. For existing DBs, we will
disable it.
6. To add checksums to existing DBs, we need a vacuum, since it would
require a rewrite of the whole database.

Closes #2840
This commit is contained in:
Pekka Enberg
2025-09-13 18:46:53 +03:00
committed by GitHub
19 changed files with 686 additions and 181 deletions

View File

@@ -42,10 +42,14 @@ jobs:
run: |
cargo test --features encryption --color=always --test integration_tests query_processing::encryption
cargo test --features encryption --color=always --lib storage::encryption
- name: Test Checksums
run: |
cargo test --features checksum --color=always --lib storage::checksum
cargo test --features checksum --color=always --test integration_tests storage::checksum
- name: Test
env:
RUST_LOG: ${{ runner.debug && 'turso_core::storage=trace' || '' }}
run: cargo test --verbose
run: cargo test --verbose --features checksum
timeout-minutes: 20
clippy:

11
Cargo.lock generated
View File

@@ -678,6 +678,7 @@ dependencies = [
"tracing-subscriber",
"turso",
"turso_core",
"twox-hash",
"zerocopy 0.8.26",
]
@@ -4265,6 +4266,7 @@ dependencies = [
"turso_macros",
"turso_parser",
"turso_sqlite3_parser",
"twox-hash",
"uncased",
"uuid",
]
@@ -4438,6 +4440,15 @@ dependencies = [
"turso_parser",
]
[[package]]
name = "twox-hash"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b907da542cbced5261bd3256de1b3a1bf340a3d37f93425a07362a1d687de56"
dependencies = [
"rand 0.9.2",
]
[[package]]
name = "typenum"
version = "1.18.0"

View File

@@ -29,6 +29,7 @@ simulator = ["fuzz", "serde"]
serde = ["dep:serde"]
series = []
encryption = []
checksum = []
[target.'cfg(target_os = "linux")'.dependencies]
io-uring = { version = "0.7.5", optional = true }
@@ -79,6 +80,7 @@ aes-gcm = { version = "0.10.3"}
aes = { version = "0.8.4"}
turso_parser = { workspace = true }
aegis = "0.9.0"
twox-hash = "2.1.1"
[build-dependencies]
chrono = { version = "0.4.38", default-features = false }

View File

@@ -126,6 +126,12 @@ pub enum CompletionError {
DecryptionError { page_idx: usize },
#[error("I/O error: partial write")]
ShortWrite,
#[error("Checksum mismatch on page {page_id}: expected {expected}, got {actual}")]
ChecksumMismatch {
page_id: usize,
expected: u64,
actual: u64,
},
}
#[macro_export]

View File

@@ -1,4 +1,5 @@
#![allow(clippy::arc_with_non_send_sync)]
extern crate core;
mod assert;
mod error;
@@ -40,6 +41,7 @@ pub mod numeric;
mod numeric;
use crate::incremental::view::AllViewsTxState;
use crate::storage::checksum::CHECKSUM_REQUIRED_RESERVED_BYTES;
use crate::storage::encryption::CipherMode;
use crate::translate::pragma::TURSO_CDC_DEFAULT_TABLE_NAME;
#[cfg(all(feature = "fs", feature = "conn_raw_api"))]
@@ -517,6 +519,23 @@ impl Database {
Ok(page_size)
}
/// Read the reserved-space byte directly from the on-disk database header.
///
/// Reads the first `PageSize::MIN` bytes of the file (keeping the read a
/// multiple of 512 so it stays valid under O_DIRECT) and returns byte 20 of
/// the SQLite header, which holds the number of reserved bytes per page.
///
/// Panics (via `turso_assert!`) if called before the database is
/// initialized on disk.
fn read_reserved_space_bytes_from_db_header(&self) -> Result<u8> {
    turso_assert!(
        self.db_state.is_initialized(),
        "read_reserved_space_bytes_from_db_header called on uninitialized database"
    );
    turso_assert!(
        PageSize::MIN % 512 == 0,
        "header read must be a multiple of 512 for O_DIRECT"
    );
    let buf = Arc::new(Buffer::new_temporary(PageSize::MIN as usize));
    let c = Completion::new_read(buf.clone(), move |_res| {});
    let c = self.db_file.read_header(c)?;
    self.io.wait_for_completion(c)?;
    // Header offset 20 is a single byte, so no endianness conversion is
    // needed — indexing it directly avoids the from_be_bytes/unwrap dance.
    Ok(buf.as_slice()[20])
}
/// Read the page size in order of preference:
/// 1. From the WAL header if it exists and is initialized
/// 2. From the database header if the database is initialized
@@ -551,7 +570,24 @@ impl Database {
}
}
/// if the database is initialized i.e. it exists on disk, return the reserved space bytes from
/// the header or None
/// If the database already exists on disk (i.e. it is initialized), return
/// the reserved-space byte stored in its header; otherwise return `None`.
fn maybe_get_reserved_space_bytes(&self) -> Result<Option<u8>> {
    if !self.db_state.is_initialized() {
        return Ok(None);
    }
    let reserved = self.read_reserved_space_bytes_from_db_header()?;
    Ok(Some(reserved))
}
fn init_pager(&self, requested_page_size: Option<usize>) -> Result<Pager> {
let reserved_bytes = self.maybe_get_reserved_space_bytes()?;
let disable_checksums = if let Some(reserved_bytes) = reserved_bytes {
// if the required reserved bytes for checksums is not present, disable checksums
reserved_bytes != CHECKSUM_REQUIRED_RESERVED_BYTES
} else {
false
};
// Check if WAL is enabled
let shared_wal = self.shared_wal.read();
if shared_wal.enabled.load(Ordering::Relaxed) {
@@ -579,6 +615,12 @@ impl Database {
self.init_lock.clone(),
)?;
pager.page_size.set(Some(page_size));
if let Some(reserved_bytes) = reserved_bytes {
pager.set_reserved_space_bytes(reserved_bytes);
}
if disable_checksums {
pager.reset_checksum_context();
}
return Ok(pager);
}
let page_size = self.determine_actual_page_size(&shared_wal, requested_page_size)?;
@@ -603,6 +645,12 @@ impl Database {
)?;
pager.page_size.set(Some(page_size));
if let Some(reserved_bytes) = reserved_bytes {
pager.set_reserved_space_bytes(reserved_bytes);
}
if disable_checksums {
pager.reset_checksum_context();
}
let file = self
.io
.open_file(&self.wal_path, OpenFlags::Create, false)?;

View File

@@ -6506,7 +6506,12 @@ fn find_free_slot(
pub fn btree_init_page(page: &PageRef, page_type: PageType, offset: usize, usable_space: usize) {
// setup btree page
let contents = page.get_contents();
tracing::debug!("btree_init_page(id={}, offset={})", page.get().id, offset);
tracing::debug!(
"btree_init_page(id={}, offset={}, usable_space={})",
page.get().id,
offset,
usable_space
);
contents.offset = offset;
let id = page_type as u8;
contents.write_page_type(id);
@@ -7988,7 +7993,7 @@ mod tests {
// FIXME: handle page cache is full
let _ = run_until_done(|| pager.allocate_page1(), &pager);
let page2 = run_until_done(|| pager.allocate_page(), &pager).unwrap();
btree_init_page(&page2, PageType::TableLeaf, 0, 4096);
btree_init_page(&page2, PageType::TableLeaf, 0, pager.usable_space());
(pager, page2.get().id, db, conn)
}

187
core/storage/checksum.rs Normal file
View File

@@ -0,0 +1,187 @@
#![allow(unused_variables, dead_code)]
use crate::{CompletionError, Result};
const CHECKSUM_PAGE_SIZE: usize = 4096;
const CHECKSUM_SIZE: usize = 8;
pub(crate) const CHECKSUM_REQUIRED_RESERVED_BYTES: u8 = CHECKSUM_SIZE as u8;
/// Per-page checksum support.
///
/// The checksum of a page's payload is stored little-endian in the page's
/// reserved tail bytes (`CHECKSUM_SIZE` == 8) and verified on read. With the
/// `checksum` cargo feature disabled, all operations are no-ops. Currently
/// only 4096-byte pages are checksummed; other sizes pass through untouched.
#[derive(Clone)]
pub struct ChecksumContext {}
impl ChecksumContext {
    /// Create a new (stateless) checksum context.
    pub fn new() -> Self {
        ChecksumContext {}
    }
    /// Feature-off stub: leave the page untouched.
    #[cfg(not(feature = "checksum"))]
    pub fn add_checksum_to_page(&self, _page: &mut [u8], _page_id: usize) -> Result<()> {
        Ok(())
    }
    /// Feature-off stub: every page verifies successfully.
    #[cfg(not(feature = "checksum"))]
    pub fn verify_checksum(
        &self,
        _page: &mut [u8],
        _page_id: usize,
    ) -> std::result::Result<(), CompletionError> {
        Ok(())
    }
    /// Compute the checksum of the page payload and store it (little-endian)
    /// in the last `CHECKSUM_SIZE` bytes of the page.
    ///
    /// Pages that are not exactly `CHECKSUM_PAGE_SIZE` (4096) bytes are
    /// skipped and returned unchanged — only 4K pages are supported for now.
    #[cfg(feature = "checksum")]
    pub fn add_checksum_to_page(&self, page: &mut [u8], _page_id: usize) -> Result<()> {
        if page.len() != CHECKSUM_PAGE_SIZE {
            return Ok(());
        }
        // compute checksum on the actual page data (excluding the reserved checksum area)
        let actual_page = &page[..CHECKSUM_PAGE_SIZE - CHECKSUM_SIZE];
        let checksum = self.compute_checksum(actual_page);
        let checksum_bytes = checksum.to_le_bytes();
        assert_eq!(checksum_bytes.len(), CHECKSUM_SIZE);
        page[CHECKSUM_PAGE_SIZE - CHECKSUM_SIZE..].copy_from_slice(&checksum_bytes);
        Ok(())
    }
    /// Recompute the checksum of the page payload and compare it against the
    /// value stored in the reserved tail bytes.
    ///
    /// Returns `CompletionError::ChecksumMismatch` when they differ. Pages
    /// that are not exactly 4096 bytes are skipped (treated as valid).
    #[cfg(feature = "checksum")]
    pub fn verify_checksum(
        &self,
        page: &mut [u8],
        page_id: usize,
    ) -> std::result::Result<(), CompletionError> {
        if page.len() != CHECKSUM_PAGE_SIZE {
            return Ok(());
        }
        let actual_page = &page[..CHECKSUM_PAGE_SIZE - CHECKSUM_SIZE];
        let stored_checksum_bytes = &page[CHECKSUM_PAGE_SIZE - CHECKSUM_SIZE..];
        let stored_checksum = u64::from_le_bytes(stored_checksum_bytes.try_into().unwrap());
        let computed_checksum = self.compute_checksum(actual_page);
        if stored_checksum != computed_checksum {
            tracing::error!(
                "Checksum mismatch on page {}: expected {:x}, got {:x}",
                page_id,
                stored_checksum,
                computed_checksum
            );
            return Err(CompletionError::ChecksumMismatch {
                page_id,
                expected: stored_checksum,
                actual: computed_checksum,
            });
        }
        Ok(())
    }
    /// Hash the given bytes with XxHash3 (64-bit, one-shot).
    fn compute_checksum(&self, data: &[u8]) -> u64 {
        twox_hash::XxHash3_64::oneshot(data)
    }
    /// Number of reserved bytes per page this context requires (8).
    pub fn required_reserved_bytes(&self) -> u8 {
        CHECKSUM_REQUIRED_RESERVED_BYTES
    }
}
impl Default for ChecksumContext {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CompletionError;
    /// Build a 4K page with a deterministic byte pattern in the payload area
    /// and zeroed reserved tail bytes. (Despite the name, the content is
    /// deterministic, not random.)
    fn get_random_page() -> [u8; CHECKSUM_PAGE_SIZE] {
        let mut page = [0u8; CHECKSUM_PAGE_SIZE];
        for (i, byte) in page
            .iter_mut()
            .enumerate()
            .take(CHECKSUM_PAGE_SIZE - CHECKSUM_SIZE)
        {
            *byte = (i % 256) as u8;
        }
        page
    }
    /// The stored tail bytes must equal the checksum of the payload area.
    #[test]
    fn test_add_checksum_to_page() {
        let ctx = ChecksumContext::new();
        let mut page = get_random_page();
        let result = ctx.add_checksum_to_page(&mut page, 2);
        assert!(result.is_ok());
        let checksum_bytes = &page[CHECKSUM_PAGE_SIZE - CHECKSUM_SIZE..];
        let stored_checksum = u64::from_le_bytes(checksum_bytes.try_into().unwrap());
        let actual_page = &page[..CHECKSUM_PAGE_SIZE - CHECKSUM_SIZE];
        let expected_checksum = ctx.compute_checksum(actual_page);
        assert_eq!(stored_checksum, expected_checksum);
    }
    /// A freshly stamped page must verify cleanly.
    #[test]
    fn test_verify_and_strip_checksum_valid() {
        let ctx = ChecksumContext::new();
        let mut page = get_random_page();
        ctx.add_checksum_to_page(&mut page, 2).unwrap();
        let result = ctx.verify_checksum(&mut page, 2);
        assert!(result.is_ok());
    }
    /// Corrupting the payload must surface a ChecksumMismatch for that page.
    #[test]
    fn test_verify_and_strip_checksum_mismatch() {
        let ctx = ChecksumContext::new();
        let mut page = get_random_page();
        ctx.add_checksum_to_page(&mut page, 2).unwrap();
        // corrupt the data to cause checksum mismatch
        page[0] = 255;
        let result = ctx.verify_checksum(&mut page, 2);
        assert!(result.is_err());
        match result.unwrap_err() {
            CompletionError::ChecksumMismatch {
                page_id,
                expected,
                actual,
            } => {
                assert_eq!(page_id, 2);
                assert_ne!(expected, actual);
            }
            _ => panic!("Expected ChecksumMismatch error"),
        }
    }
    /// Corrupting the stored checksum itself must also be detected.
    #[test]
    fn test_verify_and_strip_checksum_corrupted_checksum() {
        let ctx = ChecksumContext::new();
        let mut page = get_random_page();
        ctx.add_checksum_to_page(&mut page, 2).unwrap();
        // corrupt the checksum itself
        page[CHECKSUM_PAGE_SIZE - 1] = 255;
        let result = ctx.verify_checksum(&mut page, 2);
        assert!(result.is_err());
        match result.unwrap_err() {
            CompletionError::ChecksumMismatch {
                page_id,
                expected,
                actual,
            } => {
                assert_eq!(page_id, 2);
                assert_ne!(expected, actual);
            }
            _ => panic!("Expected ChecksumMismatch error"),
        }
    }
}

View File

@@ -1,4 +1,5 @@
use crate::error::LimboError;
use crate::storage::checksum::ChecksumContext;
use crate::storage::encryption::EncryptionContext;
use crate::{io::Completion, Buffer, CompletionError, Result};
use std::sync::Arc;
@@ -7,7 +8,7 @@ use tracing::{instrument, Level};
#[derive(Clone)]
pub enum EncryptionOrChecksum {
Encryption(EncryptionContext),
Checksum,
Checksum(ChecksumContext),
None,
}
@@ -24,15 +25,31 @@ impl IOContext {
}
}
/// Number of per-page reserved bytes required by the active page transform.
///
/// Encryption and checksums each dictate their own requirement; with
/// neither active, no bytes are reserved.
pub fn get_reserved_space_bytes(&self) -> u8 {
    match &self.encryption_or_checksum {
        EncryptionOrChecksum::Encryption(ctx) => ctx.required_reserved_bytes(),
        EncryptionOrChecksum::Checksum(ctx) => ctx.required_reserved_bytes(),
        EncryptionOrChecksum::None => 0,
    }
}
/// Switch this IO context to encrypt pages with the given context,
/// replacing whatever transform (e.g. checksums) was active before.
pub fn set_encryption(&mut self, encryption_ctx: EncryptionContext) {
    self.encryption_or_checksum = EncryptionOrChecksum::Encryption(encryption_ctx);
}
/// Borrow the currently active page transform (encryption, checksum, or none).
pub fn encryption_or_checksum(&self) -> &EncryptionOrChecksum {
    &self.encryption_or_checksum
}
/// Disable checksumming by resetting the transform to `None`. Used for
/// existing databases created without the reserved bytes checksums need.
/// NOTE(review): this also clears an Encryption context if one is set —
/// confirm callers only invoke it when checksums were active.
pub fn reset_checksum(&mut self) {
    self.encryption_or_checksum = EncryptionOrChecksum::None;
}
}
impl Default for IOContext {
    fn default() -> Self {
        Self {
            // Checksumming is now the default page transform for new
            // contexts. With the `checksum` feature disabled the context's
            // add/verify operations are no-ops, but NOTE(review):
            // required_reserved_bytes() still reports 8 in that case —
            // confirm reserving the bytes without the feature is intended.
            encryption_or_checksum: EncryptionOrChecksum::Checksum(ChecksumContext::default()),
        }
    }
}
@@ -95,42 +112,77 @@ impl DatabaseStorage for DatabaseFile {
return Err(LimboError::IntegerOverflow);
};
if let Some(ctx) = io_ctx.encryption_context() {
let encryption_ctx = ctx.clone();
let read_buffer = r.buf_arc();
let original_c = c.clone();
match &io_ctx.encryption_or_checksum {
EncryptionOrChecksum::Encryption(ctx) => {
let encryption_ctx = ctx.clone();
let read_buffer = r.buf_arc();
let original_c = c.clone();
let decrypt_complete =
Box::new(move |res: Result<(Arc<Buffer>, i32), CompletionError>| {
let Ok((buf, bytes_read)) = res else {
return;
};
assert!(
bytes_read > 0,
"Expected to read some data on success for page_id={page_idx}"
);
match encryption_ctx.decrypt_page(buf.as_slice(), page_idx) {
Ok(decrypted_data) => {
let original_buf = original_c.as_read().buf();
original_buf.as_mut_slice().copy_from_slice(&decrypted_data);
original_c.complete(bytes_read);
}
Err(e) => {
tracing::error!(
"Failed to decrypt page data for page_id={page_idx}: {e}"
);
assert!(
!original_c.has_error(),
"Original completion already has an error"
);
original_c.error(CompletionError::DecryptionError { page_idx });
}
}
});
let wrapped_completion = Completion::new_read(read_buffer, decrypt_complete);
self.file.pread(pos, wrapped_completion)
}
EncryptionOrChecksum::Checksum(ctx) => {
let checksum_ctx = ctx.clone();
let read_buffer = r.buf_arc();
let original_c = c.clone();
let decrypt_complete =
Box::new(move |res: Result<(Arc<Buffer>, i32), CompletionError>| {
let Ok((buf, bytes_read)) = res else {
return;
};
assert!(
bytes_read > 0,
"Expected to read some data on success for page_id={page_idx}"
);
match encryption_ctx.decrypt_page(buf.as_slice(), page_idx) {
Ok(decrypted_data) => {
let original_buf = original_c.as_read().buf();
original_buf.as_mut_slice().copy_from_slice(&decrypted_data);
let verify_complete =
Box::new(move |res: Result<(Arc<Buffer>, i32), CompletionError>| {
let Ok((buf, bytes_read)) = res else {
return;
};
if bytes_read <= 0 {
tracing::trace!("Read page {page_idx} with {} bytes", bytes_read);
original_c.complete(bytes_read);
return;
}
Err(e) => {
tracing::error!(
"Failed to decrypt page data for page_id={page_idx}: {e}"
);
assert!(
!original_c.has_error(),
"Original completion already has an error"
);
original_c.error(CompletionError::DecryptionError { page_idx });
match checksum_ctx.verify_checksum(buf.as_mut_slice(), page_idx) {
Ok(_) => {
original_c.complete(bytes_read);
}
Err(e) => {
tracing::error!(
"Failed to verify checksum for page_id={page_idx}: {e}"
);
assert!(
!original_c.has_error(),
"Original completion already has an error"
);
original_c.error(e);
}
}
}
});
let new_completion = Completion::new_read(read_buffer, decrypt_complete);
self.file.pread(pos, new_completion)
} else {
self.file.pread(pos, c)
});
let wrapped_completion = Completion::new_read(read_buffer, verify_complete);
self.file.pread(pos, wrapped_completion)
}
EncryptionOrChecksum::None => self.file.pread(pos, c),
}
}
@@ -150,12 +202,10 @@ impl DatabaseStorage for DatabaseFile {
let Some(pos) = (page_idx as u64 - 1).checked_mul(buffer_size as u64) else {
return Err(LimboError::IntegerOverflow);
};
let buffer = {
if let Some(ctx) = io_ctx.encryption_context() {
encrypt_buffer(page_idx, buffer, ctx)
} else {
buffer
}
let buffer = match &io_ctx.encryption_or_checksum {
EncryptionOrChecksum::Encryption(ctx) => encrypt_buffer(page_idx, buffer, ctx),
EncryptionOrChecksum::Checksum(ctx) => checksum_buffer(page_idx, buffer, ctx),
EncryptionOrChecksum::None => buffer,
};
self.file.pwrite(pos, buffer, c)
}
@@ -176,18 +226,19 @@ impl DatabaseStorage for DatabaseFile {
let Some(pos) = (first_page_idx as u64 - 1).checked_mul(page_size as u64) else {
return Err(LimboError::IntegerOverflow);
};
let buffers = {
if let Some(ctx) = io_ctx.encryption_context() {
buffers
.into_iter()
.enumerate()
.map(|(i, buffer)| encrypt_buffer(first_page_idx + i, buffer, ctx))
.collect::<Vec<_>>()
} else {
buffers
}
let buffers = match &io_ctx.encryption_or_checksum() {
EncryptionOrChecksum::Encryption(ctx) => buffers
.into_iter()
.enumerate()
.map(|(i, buffer)| encrypt_buffer(first_page_idx + i, buffer, ctx))
.collect::<Vec<_>>(),
EncryptionOrChecksum::Checksum(ctx) => buffers
.into_iter()
.enumerate()
.map(|(i, buffer)| checksum_buffer(first_page_idx + i, buffer, ctx))
.collect::<Vec<_>>(),
EncryptionOrChecksum::None => buffers,
};
let c = self.file.pwritev(pos, buffers, c)?;
Ok(c)
}
@@ -220,3 +271,9 @@ fn encrypt_buffer(page_idx: usize, buffer: Arc<Buffer>, ctx: &EncryptionContext)
let encrypted_data = ctx.encrypt_page(buffer.as_slice(), page_idx).unwrap();
Arc::new(Buffer::new(encrypted_data.to_vec()))
}
/// Stamp the page checksum into `buffer`'s reserved tail bytes in place and
/// return the same buffer, ready for the write path.
fn checksum_buffer(page_idx: usize, buffer: Arc<Buffer>, ctx: &ChecksumContext) -> Arc<Buffer> {
    // add_checksum_to_page currently always returns Ok (it skips unsupported
    // page sizes); use expect so any future failure mode panics with context
    // instead of a bare unwrap.
    ctx.add_checksum_to_page(buffer.as_mut_slice(), page_idx)
        .expect("failed to stamp page checksum before write");
    buffer
}

View File

@@ -12,6 +12,7 @@
//! for the database, also either local or remote.
pub(crate) mod btree;
pub(crate) mod buffer_pool;
pub(crate) mod checksum;
pub mod database;
pub(crate) mod encryption;
pub(crate) mod page_cache;

View File

@@ -16,7 +16,7 @@ use crate::{
Result, TransactionState,
};
use parking_lot::RwLock;
use std::cell::{Cell, OnceCell, RefCell, UnsafeCell};
use std::cell::{Cell, RefCell, UnsafeCell};
use std::collections::HashSet;
use std::hash;
use std::rc::Rc;
@@ -43,7 +43,7 @@ impl HeaderRef {
pub fn from_pager(pager: &Pager) -> Result<IOResult<Self>> {
loop {
let state = pager.header_ref_state.borrow().clone();
tracing::trace!(?state);
tracing::trace!("HeaderRef::from_pager - {:?}", state);
match state {
HeaderRefState::Start => {
if !pager.db_state.is_initialized() {
@@ -491,7 +491,7 @@ pub struct Pager {
/// `usable_space` calls. TODO: Invalidate reserved_space when we add the functionality
/// to change it.
pub(crate) page_size: Cell<Option<PageSize>>,
reserved_space: OnceCell<u8>,
reserved_space: Cell<Option<u8>>,
free_page_state: RefCell<FreePageState>,
/// Maximum number of pages allowed in the database. Default is 1073741823 (SQLite default).
max_page_count: Cell<u32>,
@@ -597,7 +597,7 @@ impl Pager {
init_lock,
allocate_page1_state,
page_size: Cell::new(None),
reserved_space: OnceCell::new(),
reserved_space: Cell::new(None),
free_page_state: RefCell::new(FreePageState::Start),
allocate_page_state: RefCell::new(AllocatePageState::Start),
max_page_count: Cell::new(DEFAULT_MAX_PAGE_COUNT),
@@ -969,7 +969,7 @@ impl Pager {
.unwrap_or_default()
});
let reserved_space = *self.reserved_space.get_or_init(|| {
let reserved_space = *self.reserved_space.get().get_or_insert_with(|| {
self.io
.block(|| self.with_header(|header| header.reserved_space))
.unwrap_or_default()
@@ -1826,15 +1826,25 @@ impl Pager {
assert_eq!(default_header.database_size.get(), 0);
default_header.database_size = 1.into();
// if a key is set, then we will reserve space for encryption metadata
let io_ctx = self.io_ctx.borrow();
if let Some(ctx) = io_ctx.encryption_context() {
default_header.reserved_space = ctx.required_reserved_bytes()
}
// based on the IOContext set, we will set the reserved space bytes as required by
// either the encryption or checksum, or None if they are not set.
let reserved_space_bytes = {
let io_ctx = self.io_ctx.borrow();
io_ctx.get_reserved_space_bytes()
};
default_header.reserved_space = reserved_space_bytes;
self.reserved_space.set(Some(reserved_space_bytes));
if let Some(size) = self.page_size.get() {
default_header.page_size = size;
}
tracing::info!(
"allocate_page1(Start) page_size = {:?}, reserved_space = {}",
default_header.page_size,
default_header.reserved_space
);
self.buffer_pool
.finalize_with_page_size(default_header.page_size.get() as usize)?;
let page = allocate_new_page(1, &self.buffer_pool, 0);
@@ -2212,6 +2222,20 @@ impl Pager {
.set_io_context(self.io_ctx.borrow().clone());
Ok(())
}
/// Drop the checksum context from the pager's IO context (used when an
/// existing database's reserved-space byte doesn't match what checksums
/// require) and propagate the updated context to the WAL, if attached.
pub fn reset_checksum_context(&self) {
    {
        // Scope the mutable RefCell borrow so the immutable borrow for the
        // clone below doesn't panic at runtime.
        let mut io_ctx = self.io_ctx.borrow_mut();
        io_ctx.reset_checksum();
    }
    let Some(wal) = self.wal.as_ref() else { return };
    wal.borrow_mut()
        .set_io_context(self.io_ctx.borrow().clone())
}
/// Cache the reserved-space byte count (e.g. read from the database header)
/// so later usable-space calculations don't re-read the header.
pub fn set_reserved_space_bytes(&self, value: u8) {
    self.reserved_space.set(Some(value))
}
}
pub fn allocate_new_page(page_id: usize, buffer_pool: &Arc<BufferPool>, offset: usize) -> PageRef {

View File

@@ -58,7 +58,7 @@ use crate::storage::btree::offset::{
};
use crate::storage::btree::{payload_overflow_threshold_max, payload_overflow_threshold_min};
use crate::storage::buffer_pool::BufferPool;
use crate::storage::database::DatabaseStorage;
use crate::storage::database::{DatabaseStorage, EncryptionOrChecksum};
use crate::storage::pager::Pager;
use crate::storage::wal::READMARK_NOT_USED;
use crate::types::{RawSlice, RefValue, SerialType, SerialTypeKind, TextRef, TextSubtype};
@@ -1985,40 +1985,74 @@ pub fn begin_read_wal_frame(
let buf = buffer_pool.get_page();
let buf = Arc::new(buf);
if let Some(ctx) = io_ctx.encryption_context() {
let encryption_ctx = ctx.clone();
let original_complete = complete;
match io_ctx.encryption_or_checksum() {
EncryptionOrChecksum::Encryption(ctx) => {
let encryption_ctx = ctx.clone();
let original_complete = complete;
let decrypt_complete = Box::new(move |res: Result<(Arc<Buffer>, i32), CompletionError>| {
let Ok((encrypted_buf, bytes_read)) = res else {
original_complete(res);
return;
};
assert!(
bytes_read > 0,
"Expected to read some data on success for page_idx={page_idx}"
);
match encryption_ctx.decrypt_page(encrypted_buf.as_slice(), page_idx) {
Ok(decrypted_data) => {
encrypted_buf
.as_mut_slice()
.copy_from_slice(&decrypted_data);
original_complete(Ok((encrypted_buf, bytes_read)));
}
Err(e) => {
tracing::error!(
"Failed to decrypt WAL frame data for page_idx={page_idx}: {e}"
let decrypt_complete =
Box::new(move |res: Result<(Arc<Buffer>, i32), CompletionError>| {
let Ok((encrypted_buf, bytes_read)) = res else {
original_complete(res);
return;
};
assert!(
bytes_read > 0,
"Expected to read some data on success for page_idx={page_idx}"
);
original_complete(Err(CompletionError::DecryptionError { page_idx }));
}
}
});
match encryption_ctx.decrypt_page(encrypted_buf.as_slice(), page_idx) {
Ok(decrypted_data) => {
encrypted_buf
.as_mut_slice()
.copy_from_slice(&decrypted_data);
original_complete(Ok((encrypted_buf, bytes_read)));
}
Err(e) => {
tracing::error!(
"Failed to decrypt WAL frame data for page_idx={page_idx}: {e}"
);
original_complete(Err(CompletionError::DecryptionError { page_idx }));
}
}
});
let new_completion = Completion::new_read(buf, decrypt_complete);
io.pread(offset, new_completion)
} else {
let c = Completion::new_read(buf, complete);
io.pread(offset, c)
let new_completion = Completion::new_read(buf, decrypt_complete);
io.pread(offset, new_completion)
}
EncryptionOrChecksum::Checksum(ctx) => {
let checksum_ctx = ctx.clone();
let original_c = complete;
let verify_complete =
Box::new(move |res: Result<(Arc<Buffer>, i32), CompletionError>| {
let Ok((buf, bytes_read)) = res else {
original_c(res);
return;
};
if bytes_read <= 0 {
tracing::trace!("Read page {page_idx} with {} bytes", bytes_read);
original_c(Ok((buf, bytes_read)));
return;
}
match checksum_ctx.verify_checksum(buf.as_mut_slice(), page_idx) {
Ok(_) => {
original_c(Ok((buf, bytes_read)));
}
Err(e) => {
tracing::error!(
"Failed to verify checksum for page_id={page_idx}: {e}"
);
original_c(Err(e))
}
}
});
let c = Completion::new_read(buf, verify_complete);
io.pread(offset, c)
}
EncryptionOrChecksum::None => {
let c = Completion::new_read(buf, complete);
io.pread(offset, c)
}
}
}

View File

@@ -18,6 +18,7 @@ use super::sqlite3_ondisk::{self, checksum_wal, WalHeader, WAL_MAGIC_BE, WAL_MAG
use crate::fast_lock::SpinLock;
use crate::io::{clock, File, IO};
use crate::result::LimboResult;
use crate::storage::database::EncryptionOrChecksum;
use crate::storage::sqlite3_ondisk::{
begin_read_wal_frame, begin_read_wal_frame_raw, finish_read_page, prepare_wal_frame,
write_pages_vectored, PageSize, WAL_FRAME_HEADER_SIZE, WAL_HEADER_SIZE,
@@ -1312,18 +1313,17 @@ impl Wal for WalFile {
let page_buf = page_content.as_ptr();
let io_ctx = self.io_ctx.borrow();
let encryption_ctx = io_ctx.encryption_context();
let encrypted_data = {
if let Some(key) = encryption_ctx.as_ref() {
Some(key.encrypt_page(page_buf, page_id)?)
} else {
None
let encrypted_data;
let data_to_write = match &io_ctx.encryption_or_checksum() {
EncryptionOrChecksum::Encryption(ctx) => {
encrypted_data = ctx.encrypt_page(page_buf, page_id)?;
encrypted_data.as_slice()
}
};
let data_to_write = if encryption_ctx.as_ref().is_some() {
encrypted_data.as_ref().unwrap().as_slice()
} else {
page_buf
EncryptionOrChecksum::Checksum(ctx) => {
ctx.add_checksum_to_page(page_buf, page_id)?;
page_buf
}
EncryptionOrChecksum::None => page_buf,
};
let (frame_checksums, frame_bytes) = prepare_wal_frame(
@@ -1521,11 +1521,15 @@ impl Wal for WalFile {
let data_to_write: std::borrow::Cow<[u8]> = {
let io_ctx = self.io_ctx.borrow();
let ectx = io_ctx.encryption_context();
if let Some(ctx) = ectx.as_ref() {
Cow::Owned(ctx.encrypt_page(plain, page_id)?)
} else {
Cow::Borrowed(plain)
match &io_ctx.encryption_or_checksum() {
EncryptionOrChecksum::Encryption(ctx) => {
Cow::Owned(ctx.encrypt_page(plain, page_id)?)
}
EncryptionOrChecksum::Checksum(ctx) => {
ctx.add_checksum_to_page(plain, page_id)?;
Cow::Borrowed(plain)
}
EncryptionOrChecksum::None => Cow::Borrowed(plain),
}
};

View File

@@ -28,6 +28,7 @@ rand_chacha = "0.9.0"
rand = "0.9.0"
zerocopy = "0.8.26"
ctor = "0.5.0"
twox-hash = "2.1.1"
[dev-dependencies]
test-log = { version = "0.2.17", features = ["trace"] }
@@ -36,3 +37,4 @@ tracing = "0.1.41"
[features]
encryption = ["turso_core/encryption"]
checksum = ["turso_core/checksum"]

View File

@@ -5,7 +5,7 @@ use std::path::{Path, PathBuf};
use std::sync::Arc;
use tempfile::TempDir;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
use turso_core::{Connection, Database, IO};
use turso_core::{Connection, Database, Row, StepResult, IO};
#[allow(dead_code)]
pub struct TempDatabase {
@@ -293,6 +293,45 @@ pub(crate) fn rng_from_time() -> (ChaCha8Rng, u64) {
(rng, seed)
}
/// Run `query` to completion on `conn`, discarding any rows it produces.
pub fn run_query(tmp_db: &TempDatabase, conn: &Arc<Connection>, query: &str) -> anyhow::Result<()> {
    run_query_core(tmp_db, conn, query, None::<fn(&Row)>)
}
/// Run `query` to completion on `conn`, invoking `on_row` for each row.
pub fn run_query_on_row(
    tmp_db: &TempDatabase,
    conn: &Arc<Connection>,
    query: &str,
    on_row: impl FnMut(&Row),
) -> anyhow::Result<()> {
    run_query_core(tmp_db, conn, query, Some(on_row))
}
/// Drive `query` on `conn` to completion, invoking `on_row` (when provided)
/// for every produced row.
///
/// Statements that return no row iterator complete immediately.
/// `StepResult::IO` is serviced via `run_once`; any step result other than
/// `IO`/`Row`/`Done` panics, since these test helpers never expect
/// busy/interrupt outcomes.
pub fn run_query_core(
    _tmp_db: &TempDatabase,
    conn: &Arc<Connection>,
    query: &str,
    mut on_row: Option<impl FnMut(&Row)>,
) -> anyhow::Result<()> {
    if let Some(ref mut rows) = conn.query(query)? {
        loop {
            match rows.step()? {
                StepResult::IO => {
                    rows.run_once()?;
                }
                StepResult::Done => break,
                StepResult::Row => {
                    if let Some(on_row) = on_row.as_mut() {
                        let row = rows.row().unwrap();
                        on_row(row)
                    }
                }
                r => panic!("unexpected step result: {r:?}"),
            }
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
use std::{sync::Arc, vec};

View File

@@ -4,6 +4,7 @@ mod fuzz;
mod fuzz_transaction;
mod pragma;
mod query_processing;
mod storage;
mod wal;
#[cfg(test)]

View File

@@ -1,5 +1,4 @@
use crate::common::{do_flush, TempDatabase};
use crate::query_processing::test_write_path::{run_query, run_query_on_row};
use crate::common::{do_flush, run_query, run_query_on_row, TempDatabase};
use rand::{rng, RngCore};
use std::panic;
use turso_core::Row;

View File

@@ -175,10 +175,10 @@ fn test_sequential_write() -> anyhow::Result<()> {
println!("progress {progress:.1}%");
}
let insert_query = format!("INSERT INTO test VALUES ({i})");
run_query(&tmp_db, &conn, &insert_query)?;
common::run_query(&tmp_db, &conn, &insert_query)?;
let mut current_read_index = 0;
run_query_on_row(&tmp_db, &conn, list_query, |row: &Row| {
common::run_query_on_row(&tmp_db, &conn, list_query, |row: &Row| {
let first_value = row.get::<&Value>(0).expect("missing id");
let id = match first_value {
turso_core::Value::Integer(i) => *i as i32,
@@ -204,14 +204,14 @@ fn test_regression_multi_row_insert() -> anyhow::Result<()> {
let insert_query = "INSERT INTO test VALUES (-2), (-3), (-1)";
let list_query = "SELECT * FROM test";
run_query(&tmp_db, &conn, insert_query)?;
common::run_query(&tmp_db, &conn, insert_query)?;
common::do_flush(&conn, &tmp_db)?;
let mut current_read_index = 1;
let expected_ids = vec![-3, -2, -1];
let mut actual_ids = Vec::new();
run_query_on_row(&tmp_db, &conn, list_query, |row: &Row| {
common::run_query_on_row(&tmp_db, &conn, list_query, |row: &Row| {
let first_value = row.get::<&Value>(0).expect("missing id");
let id = match first_value {
Value::Float(f) => *f as i32,
@@ -290,13 +290,13 @@ fn test_wal_checkpoint() -> anyhow::Result<()> {
conn.checkpoint(CheckpointMode::Passive {
upper_bound_inclusive: None,
})?;
run_query(&tmp_db, &conn, &insert_query)?;
common::run_query(&tmp_db, &conn, &insert_query)?;
}
do_flush(&conn, &tmp_db)?;
let list_query = "SELECT * FROM test LIMIT 1";
let mut current_index = 0;
run_query_on_row(&tmp_db, &conn, list_query, |row: &Row| {
common::run_query_on_row(&tmp_db, &conn, list_query, |row: &Row| {
let id = row.get::<i64>(0).unwrap();
assert_eq!(current_index, id as usize);
current_index += 1;
@@ -315,7 +315,7 @@ fn test_wal_restart() -> anyhow::Result<()> {
fn insert(i: usize, conn: &Arc<Connection>, tmp_db: &TempDatabase) -> anyhow::Result<()> {
debug!("inserting {i}");
let insert_query = format!("INSERT INTO test VALUES ({i})");
run_query(tmp_db, conn, &insert_query)?;
common::run_query(tmp_db, conn, &insert_query)?;
debug!("inserted {i}");
tmp_db.io.step()?;
Ok(())
@@ -325,7 +325,7 @@ fn test_wal_restart() -> anyhow::Result<()> {
debug!("counting");
let list_query = "SELECT count(x) FROM test";
let mut count = None;
run_query_on_row(tmp_db, conn, list_query, |row: &Row| {
common::run_query_on_row(tmp_db, conn, list_query, |row: &Row| {
assert!(count.is_none());
count = Some(row.get::<i64>(0).unwrap() as usize);
debug!("counted {count:?}");
@@ -378,15 +378,15 @@ fn test_write_delete_with_index() -> anyhow::Result<()> {
for i in 0..max_iterations {
println!("inserting {i} ");
let insert_query = format!("INSERT INTO test VALUES ({i})");
run_query(&tmp_db, &conn, &insert_query)?;
common::run_query(&tmp_db, &conn, &insert_query)?;
}
for i in 0..max_iterations {
println!("deleting {i} ");
let delete_query = format!("delete from test where x={i}");
run_query(&tmp_db, &conn, &delete_query)?;
common::run_query(&tmp_db, &conn, &delete_query)?;
println!("listing after deleting {i} ");
let mut current_read_index = i + 1;
run_query_on_row(&tmp_db, &conn, list_query, |row: &Row| {
common::run_query_on_row(&tmp_db, &conn, list_query, |row: &Row| {
let first_value = row.get::<&Value>(0).expect("missing id");
let id = match first_value {
turso_core::Value::Integer(i) => *i as i32,
@@ -398,7 +398,7 @@ fn test_write_delete_with_index() -> anyhow::Result<()> {
})?;
for i in i + 1..max_iterations {
// now test with seek
run_query_on_row(
common::run_query_on_row(
&tmp_db,
&conn,
&format!("select * from test where x = {i}"),
@@ -428,20 +428,20 @@ fn test_update_with_index() -> anyhow::Result<()> {
TempDatabase::new_with_rusqlite("CREATE TABLE test (x REAL PRIMARY KEY, y TEXT);", true);
let conn = tmp_db.connect_limbo();
run_query(&tmp_db, &conn, "INSERT INTO test VALUES (1.0, 'foo')")?;
run_query(&tmp_db, &conn, "INSERT INTO test VALUES (2.0, 'bar')")?;
common::run_query(&tmp_db, &conn, "INSERT INTO test VALUES (1.0, 'foo')")?;
common::run_query(&tmp_db, &conn, "INSERT INTO test VALUES (2.0, 'bar')")?;
run_query_on_row(&tmp_db, &conn, "SELECT * from test WHERE x=10.0", |row| {
common::run_query_on_row(&tmp_db, &conn, "SELECT * from test WHERE x=10.0", |row| {
assert_eq!(row.get::<f64>(0).unwrap(), 1.0);
})?;
run_query(&tmp_db, &conn, "UPDATE test SET x=10.0 WHERE x=1.0")?;
run_query_on_row(&tmp_db, &conn, "SELECT * from test WHERE x=10.0", |row| {
common::run_query(&tmp_db, &conn, "UPDATE test SET x=10.0 WHERE x=1.0")?;
common::run_query_on_row(&tmp_db, &conn, "SELECT * from test WHERE x=10.0", |row| {
assert_eq!(row.get::<f64>(0).unwrap(), 10.0);
})?;
let mut count_1 = 0;
let mut count_10 = 0;
run_query_on_row(&tmp_db, &conn, "SELECT * from test", |row| {
common::run_query_on_row(&tmp_db, &conn, "SELECT * from test", |row| {
let v = row.get::<f64>(0).unwrap();
if v == 1.0 {
count_1 += 1;
@@ -464,10 +464,10 @@ fn test_delete_with_index() -> anyhow::Result<()> {
let tmp_db = TempDatabase::new_with_rusqlite("CREATE TABLE t (x UNIQUE)", true);
let conn = tmp_db.connect_limbo();
run_query(&tmp_db, &conn, "INSERT INTO t VALUES (1), (2)")?;
run_query(&tmp_db, &conn, "DELETE FROM t WHERE x >= 1")?;
common::run_query(&tmp_db, &conn, "INSERT INTO t VALUES (1), (2)")?;
common::run_query(&tmp_db, &conn, "DELETE FROM t WHERE x >= 1")?;
run_query_on_row(&tmp_db, &conn, "SELECT * FROM t", |_| {
common::run_query_on_row(&tmp_db, &conn, "SELECT * FROM t", |_| {
panic!("Delete should've deleted every row!");
})?;
@@ -516,7 +516,7 @@ fn test_multiple_statements() -> anyhow::Result<()> {
conn.execute("INSERT INTO t values(1); insert into t values(2);")?;
run_query_on_row(&tmp_db, &conn, "SELECT count(1) from t;", |row| {
common::run_query_on_row(&tmp_db, &conn, "SELECT count(1) from t;", |row| {
let count = row.get::<i64>(0).unwrap();
assert_eq!(count, 2);
})
@@ -526,7 +526,7 @@ fn test_multiple_statements() -> anyhow::Result<()> {
}
fn check_integrity_is_ok(tmp_db: TempDatabase, conn: Arc<Connection>) -> Result<(), anyhow::Error> {
run_query_on_row(&tmp_db, &conn, "pragma integrity_check", |row: &Row| {
common::run_query_on_row(&tmp_db, &conn, "pragma integrity_check", |row: &Row| {
let res = row.get::<String>(0).unwrap();
assert!(res.contains("ok"));
})?;
@@ -639,7 +639,7 @@ fn test_write_concurrent_connections() -> anyhow::Result<()> {
}
let conn = tmp_db.connect_limbo();
run_query_on_row(&tmp_db, &conn, "SELECT count(1) from t", |row: &Row| {
common::run_query_on_row(&tmp_db, &conn, "SELECT count(1) from t", |row: &Row| {
let count = row.get::<i64>(0).unwrap();
assert_eq!(
count,
@@ -665,12 +665,12 @@ fn test_wal_bad_frame() -> anyhow::Result<()> {
conn.execute("INSERT INTO t2(x) VALUES (1)")?;
conn.execute("INSERT INTO t3(x) VALUES (1)")?;
conn.execute("COMMIT")?;
run_query_on_row(&tmp_db, &conn, "SELECT count(1) from t2", |row| {
common::run_query_on_row(&tmp_db, &conn, "SELECT count(1) from t2", |row| {
let x = row.get::<i64>(0).unwrap();
assert_eq!(x, 1);
})
.unwrap();
run_query_on_row(&tmp_db, &conn, "SELECT count(1) from t3", |row| {
common::run_query_on_row(&tmp_db, &conn, "SELECT count(1) from t3", |row| {
let x = row.get::<i64>(0).unwrap();
assert_eq!(x, 1);
})
@@ -715,7 +715,7 @@ fn test_wal_bad_frame() -> anyhow::Result<()> {
db,
};
let conn = tmp_db.connect_limbo();
run_query_on_row(&tmp_db, &conn, "SELECT count(1) from t2", |row| {
common::run_query_on_row(&tmp_db, &conn, "SELECT count(1) from t2", |row| {
let x = row.get::<i64>(0).unwrap();
assert_eq!(x, 0);
})
@@ -789,42 +789,3 @@ fn test_insert_with_column_names() -> anyhow::Result<()> {
Ok(())
}
/// Execute `query` on `conn`, consuming and discarding any rows it produces.
pub fn run_query(tmp_db: &TempDatabase, conn: &Arc<Connection>, query: &str) -> anyhow::Result<()> {
    // No per-row callback: the concrete `fn(&Row)` type is only named so the
    // generic parameter of `run_query_core` can be inferred for `None`.
    run_query_core(tmp_db, conn, query, Option::<fn(&Row)>::None)
}
/// Execute `query` on `conn`, invoking `on_row` once for every result row.
pub fn run_query_on_row(
    tmp_db: &TempDatabase,
    conn: &Arc<Connection>,
    query: &str,
    on_row: impl FnMut(&Row),
) -> anyhow::Result<()> {
    let callback = Some(on_row);
    run_query_core(tmp_db, conn, query, callback)
}
pub fn run_query_core(
_tmp_db: &TempDatabase,
conn: &Arc<Connection>,
query: &str,
mut on_row: Option<impl FnMut(&Row)>,
) -> anyhow::Result<()> {
if let Some(ref mut rows) = conn.query(query)? {
loop {
match rows.step()? {
StepResult::IO => {
rows.run_once()?;
}
StepResult::Done => break,
StepResult::Row => {
if let Some(on_row) = on_row.as_mut() {
let row = rows.row().unwrap();
on_row(row)
}
}
r => panic!("unexpected step result: {r:?}"),
}
}
};
Ok(())
}

View File

@@ -0,0 +1,118 @@
use crate::common::{do_flush, run_query, run_query_on_row, TempDatabase};
use rand::{rng, RngCore};
use std::panic;
use turso_core::Row;
/// End-to-end check that a freshly written 4K page carries a trailing 8-byte
/// checksum, and that the stored value equals an independently computed
/// XxHash3-64 of the page body.
#[test]
fn test_per_page_checksum() -> anyhow::Result<()> {
    let _ = env_logger::try_init();
    let db_file = format!("test-{}.db", rng().next_u32());
    let tmp_db = TempDatabase::new(&db_file, false);
    let db_path = tmp_db.path.clone();
    {
        let conn = tmp_db.connect_limbo();
        run_query(
            &tmp_db,
            &conn,
            "CREATE TABLE test (id INTEGER PRIMARY KEY, value TEXT);",
        )?;
        run_query(
            &tmp_db,
            &conn,
            "INSERT INTO test (value) VALUES ('Hello, World!')",
        )?;
        // Sanity-check the row round-trips before inspecting the file.
        let mut seen = 0;
        run_query_on_row(&tmp_db, &conn, "SELECT * FROM test", |row: &Row| {
            assert_eq!(row.get::<i64>(0).unwrap(), 1);
            assert_eq!(row.get::<String>(1).unwrap(), "Hello, World!");
            seen += 1;
        })?;
        assert_eq!(seen, 1);
        do_flush(&conn, &tmp_db)?;
    }
    {
        let metadata = std::fs::metadata(&db_path)?;
        assert_eq!(metadata.len(), 4096, "db file should be exactly 4096 bytes");
    }
    // let's test that page actually contains checksum bytes
    {
        let file_contents = std::fs::read(&db_path)?;
        assert_eq!(
            file_contents.len(),
            4096,
            "file contents should be 4096 bytes"
        );
        // split the page: first 4088 bytes are actual page, last 8 bytes are checksum
        let (page_body, checksum_tail) = file_contents.split_at(4096 - 8);
        let stored_checksum = u64::from_le_bytes(checksum_tail.try_into().unwrap());
        let expected_checksum = twox_hash::XxHash3_64::oneshot(page_body);
        assert_eq!(
            stored_checksum, expected_checksum,
            "Stored checksum should match manually calculated checksum"
        );
    }
    Ok(())
}
/// Corrupt a single byte of an on-disk page and verify that reading the
/// database afterwards panics due to a checksum mismatch.
///
/// Fix: the file-length assertion checked for 8192 bytes but its failure
/// message claimed "File should be 4096 bytes" — the message now matches
/// the asserted value.
#[test]
fn test_checksum_detects_corruption() {
    let _ = env_logger::try_init();
    let db_name = format!("test-corruption-{}.db", rng().next_u32());
    let tmp_db = TempDatabase::new(&db_name, false);
    let db_path = tmp_db.path.clone();
    // Create and populate the database, then checkpoint so the pages
    // (with their checksums) land in the main db file.
    {
        let conn = tmp_db.connect_limbo();
        run_query(
            &tmp_db,
            &conn,
            "CREATE TABLE test (id INTEGER PRIMARY KEY, value TEXT);",
        )
        .unwrap();
        run_query(
            &tmp_db,
            &conn,
            "INSERT INTO test (value) VALUES ('Hello, World!')",
        )
        .unwrap();
        do_flush(&conn, &tmp_db).unwrap();
        run_query(&tmp_db, &conn, "PRAGMA wal_checkpoint(TRUNCATE);").unwrap();
    }
    // Flip one byte in the middle of a page to simulate on-disk corruption.
    {
        let mut file_contents = std::fs::read(&db_path).unwrap();
        // Two 4K pages are expected after the checkpoint.
        assert_eq!(file_contents.len(), 8192, "File should be 8192 bytes");
        // lets corrupt the db at byte 2025, the year of Turso DB
        file_contents[2025] = !file_contents[2025];
        std::fs::write(&db_path, file_contents).unwrap();
    }
    {
        let existing_db = TempDatabase::new_with_existent(&db_path, false);
        // this query should fail and result in panic because db is now corrupted
        let should_panic = panic::catch_unwind(panic::AssertUnwindSafe(|| {
            let conn = existing_db.connect_limbo();
            run_query_on_row(
                &existing_db,
                &conn,
                "SELECT * FROM test",
                |_: &Row| unreachable!(),
            )
            .unwrap();
        }));
        assert!(
            should_panic.is_err(),
            "should panic when accessing corrupted DB"
        );
    }
}

View File

@@ -0,0 +1,2 @@
// Checksum integration tests only compile when the `checksum` feature is
// enabled, since per-page checksum support in turso_core is feature-gated.
#[cfg(feature = "checksum")]
mod checksum;