diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index c3becd317..12073def9 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -38,6 +38,8 @@ jobs: python-version: "3.10" - name: Build run: cargo build --verbose + - name: Test Encryption + run: cargo test --features encryption --color=always --test integration_tests query_processing::encryption - name: Test env: RUST_LOG: ${{ runner.debug && 'turso_core::storage=trace' || '' }} diff --git a/Cargo.lock b/Cargo.lock index 1bbb07efb..9d521cd30 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,41 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" +[[package]] +name = "aead" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0" +dependencies = [ + "crypto-common", + "generic-array", +] + +[[package]] +name = "aes" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", +] + +[[package]] +name = "aes-gcm" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "831010a0f742e1209b3bcea8fab6a8e149051ba6099432c8cb2cc117dec3ead1" +dependencies = [ + "aead", + "aes", + "cipher", + "ctr", + "ghash", + "subtle", +] + [[package]] name = "ahash" version = "0.8.11" @@ -449,6 +484,16 @@ dependencies = [ "half", ] +[[package]] +name = "cipher" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +dependencies = [ + "crypto-common", + "inout", +] + [[package]] name = "clap" version = "4.5.32" @@ -607,6 +652,15 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + [[package]] name = "crc32fast" version = "1.4.2" @@ -732,6 +786,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" dependencies = [ "generic-array", + "rand_core 0.6.4", "typenum", ] @@ -772,6 +827,15 @@ version = "0.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4f211af61d8efdd104f96e57adf5e426ba1bc3ed7a4ead616e15e5881fd79c4d" +[[package]] +name = "ctr" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835" +dependencies = [ + "cipher", +] + [[package]] name = "ctrlc" version = "3.4.5" @@ -1335,6 +1399,16 @@ dependencies = [ "wasi 0.14.2+wasi-0.2.4", ] +[[package]] +name = "ghash" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0d8a4362ccb29cb0b265253fb0a2728f592895ee6854fd9bc13f2ffda266ff1" +dependencies = [ + "opaque-debug", + "polyval", +] + [[package]] name = "gimli" version = "0.31.1" @@ -1681,6 +1755,15 @@ dependencies = [ "libc", ] +[[package]] +name = "inout" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01" +dependencies = [ + "generic-array", +] + [[package]] name = "io-uring" version = "0.7.6" @@ -2435,6 +2518,12 @@ version = "11.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" +[[package]] +name = "opaque-debug" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" + [[package]] name = "option-ext" version = "0.2.0" @@ -2582,6 +2671,18 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "polyval" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d1fe60d06143b2430aa532c94cfe9e29783047f06c0d7fd359a9a51b729fa25" +dependencies = [ + "cfg-if", + "cpufeatures", + "opaque-debug", + "universal-hash", +] + [[package]] name = "portable-atomic" version = "1.11.0" @@ -3362,6 +3463,12 @@ dependencies = [ "syn 2.0.100", ] +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + [[package]] name = "supports-color" version = "3.0.2" @@ -3850,6 +3957,8 @@ dependencies = [ name = "turso_core" version = "0.1.4" dependencies = [ + "aes", + "aes-gcm", "antithesis_sdk", "bitflags 2.9.0", "built", @@ -4101,6 +4210,16 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" +[[package]] +name = "universal-hash" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea" +dependencies = [ + "crypto-common", + "subtle", +] + [[package]] name = "untrusted" version = "0.9.0" diff --git a/bindings/javascript/Cargo.toml b/bindings/javascript/Cargo.toml index 18a9b319d..077f07fb3 100644 --- a/bindings/javascript/Cargo.toml +++ b/bindings/javascript/Cargo.toml @@ -16,5 +16,8 @@ napi = { version = "3.1.3", default-features = false, features = ["napi6"] } napi-derive = { version = "3.1.1", default-features = true } tracing-subscriber = { version = "0.3.19", features = ["env-filter"] } +[features] +encryption = ["turso_core/encryption"] + [build-dependencies] napi-build = "2.2.3" diff --git a/bindings/javascript/src/lib.rs b/bindings/javascript/src/lib.rs index b02503a87..1c0218eeb 100644 --- a/bindings/javascript/src/lib.rs +++ b/bindings/javascript/src/lib.rs @@ -561,6 +561,7 @@ impl turso_core::DatabaseStorage for DatabaseFile { fn read_page( &self, page_idx: usize, + _key: Option<&turso_core::EncryptionKey>, c: turso_core::Completion, ) -> turso_core::Result { let r = c.as_read(); @@ -577,6 +578,7 @@ impl turso_core::DatabaseStorage for DatabaseFile { &self, page_idx: usize, buffer: Arc, + _key: Option<&turso_core::EncryptionKey>, c: turso_core::Completion, ) -> turso_core::Result { let size = buffer.len(); @@ -586,12 +588,13 @@ impl turso_core::DatabaseStorage for DatabaseFile { fn write_pages( &self, - page_idx: usize, + first_page_idx: usize, page_size: usize, buffers: Vec>, + _key: Option<&turso_core::EncryptionKey>, c: turso_core::Completion, ) -> turso_core::Result { - let pos = page_idx.saturating_sub(1) * page_size; + let pos = first_page_idx.saturating_sub(1) * page_size; let c = self.file.pwritev(pos, buffers, c)?; Ok(c) } diff --git a/core/Cargo.toml b/core/Cargo.toml index f51f7c810..ce88e7566 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -27,6 +27,7 @@ omit_autovacuum = [] simulator = ["fuzz", "serde"] serde = ["dep:serde"] series = [] +encryption = [] [target.'cfg(target_os = "linux")'.dependencies] io-uring = { version = "0.7.5", optional = true } @@ -73,6 +74,8 @@ uuid = { version = "1.11.0", features = ["v4", "v7"], optional = true } tempfile = "3.8.0" pack1 = { version = "1.0.0", features = ["bytemuck"] } bytemuck = "1.23.1" +aes-gcm = { version = "0.10.3"} +aes = { version = "0.8.4"} [build-dependencies] chrono = { version = "0.4.38", default-features = false } diff --git a/core/io/mod.rs b/core/io/mod.rs index a54443ce5..beb674d67 100644 --- a/core/io/mod.rs +++ b/core/io/mod.rs @@ -280,6 +280,10 @@ impl ReadCompletion { pub fn callback(&self, bytes_read: Result) { (self.complete)(bytes_read.map(|b| (self.buf.clone(), b))); } + + pub fn buf_arc(&self) -> Arc { + self.buf.clone() + } } pub struct WriteCompletion { diff --git a/core/lib.rs b/core/lib.rs index c7c4dfa04..da49b2b80 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -76,6 +76,7 @@ use std::{ }; #[cfg(feature = "fs")] use storage::database::DatabaseFile; +pub use storage::encryption::EncryptionKey; use storage::page_cache::DumbLruPageCache; use storage::pager::{AtomicDbState, DbState}; use storage::sqlite3_ondisk::PageSize; @@ -425,6 +426,7 @@ impl Database { view_transaction_states: RefCell::new(HashMap::new()), metrics: RefCell::new(ConnectionMetrics::new()), is_nested_stmt: Cell::new(false), + encryption_key: RefCell::new(None), }); let builtin_syms = self.builtin_syms.borrow(); // add built-in extensions symbols to the connection to prevent having to load each time @@ -852,6 +854,7 @@ pub struct Connection { /// Whether the connection is executing a statement initiated by another statement. /// Generally this is only true for ParseSchema. is_nested_stmt: Cell, + encryption_key: RefCell>, } impl Connection { @@ -1925,6 +1928,13 @@ impl Connection { pub fn get_syms_vtab_mods(&self) -> std::collections::HashSet { self.syms.borrow().vtab_modules.keys().cloned().collect() } + + pub fn set_encryption_key(&self, key: Option) { + tracing::trace!("setting encryption key for connection"); + *self.encryption_key.borrow_mut() = key.clone(); + let pager = self.pager.borrow(); + pager.set_encryption_key(key); + } } pub struct Statement { diff --git a/core/pragma.rs b/core/pragma.rs index f4b4b3f44..62be7f313 100644 --- a/core/pragma.rs +++ b/core/pragma.rs @@ -107,6 +107,10 @@ pub fn pragma_for(pragma: &PragmaName) -> Pragma { &["query_only"], ), FreelistCount => Pragma::new(PragmaFlags::Result0, &["freelist_count"]), + EncryptionKey => Pragma::new( + PragmaFlags::Result0 | PragmaFlags::SchemaReq | PragmaFlags::NoColumns1, + &["key"], + ), } } diff --git a/core/storage/btree.rs b/core/storage/btree.rs index de831757d..b0474e344 100644 --- a/core/storage/btree.rs +++ b/core/storage/btree.rs @@ -8541,7 +8541,7 @@ mod tests { }); let _c = pager .db_file - .write_page(current_page as usize, buf.clone(), c)?; + .write_page(current_page as usize, buf.clone(), None, c)?; pager.io.run_once()?; let (page, _c) = cursor.read_page(current_page as usize)?; diff --git a/core/storage/database.rs b/core/storage/database.rs index d7f4cf552..980dce8e4 100644 --- a/core/storage/database.rs +++ b/core/storage/database.rs @@ -1,5 +1,6 @@ use crate::error::LimboError; -use crate::{io::Completion, Buffer, Result}; +use crate::storage::encryption::{decrypt_page, encrypt_page, EncryptionKey}; +use crate::{io::Completion, Buffer, CompletionError, Result}; use std::sync::Arc; use tracing::{instrument, Level}; @@ -10,14 +11,26 @@ use tracing::{instrument, Level}; /// or something like a remote page server service. pub trait DatabaseStorage: Send + Sync { fn read_header(&self, c: Completion) -> Result; - fn read_page(&self, page_idx: usize, c: Completion) -> Result; - fn write_page(&self, page_idx: usize, buffer: Arc, c: Completion) - -> Result; + + fn read_page( + &self, + page_idx: usize, + encryption_key: Option<&EncryptionKey>, + c: Completion, + ) -> Result; + fn write_page( + &self, + page_idx: usize, + buffer: Arc, + encryption_key: Option<&EncryptionKey>, + c: Completion, + ) -> Result; fn write_pages( &self, first_page_idx: usize, page_size: usize, buffers: Vec>, + encryption_key: Option<&EncryptionKey>, c: Completion, ) -> Result; fn sync(&self, c: Completion) -> Result; @@ -41,8 +54,14 @@ impl DatabaseStorage for DatabaseFile { fn read_header(&self, c: Completion) -> Result { self.file.pread(0, c) } + #[instrument(skip_all, level = Level::DEBUG)] - fn read_page(&self, page_idx: usize, c: Completion) -> Result { + fn read_page( + &self, + page_idx: usize, + encryption_key: Option<&EncryptionKey>, + c: Completion, + ) -> Result { let r = c.as_read(); let size = r.buf().len(); assert!(page_idx > 0); @@ -50,7 +69,41 @@ impl DatabaseStorage for DatabaseFile { return Err(LimboError::NotADB); } let pos = (page_idx - 1) * size; - self.file.pread(pos, c) + + if let Some(key) = encryption_key { + let key_clone = key.clone(); + let read_buffer = r.buf_arc(); + let original_c = c.clone(); + + let decrypt_complete = + Box::new(move |res: Result<(Arc, i32), CompletionError>| { + let Ok((buf, bytes_read)) = res else { + return; + }; + if bytes_read > 0 { + match decrypt_page(buf.as_slice(), page_idx, &key_clone) { + Ok(decrypted_data) => { + let original_buf = original_c.as_read().buf(); + original_buf.as_mut_slice().copy_from_slice(&decrypted_data); + original_c.complete(bytes_read); + } + Err(_) => { + tracing::error!( + "Failed to decrypt page data for page_id={page_idx}" + ); + original_c.complete(-1); + } + } + } else { + original_c.complete(bytes_read); + } + }); + + let new_completion = Completion::new_read(read_buffer, decrypt_complete); + self.file.pread(pos, new_completion) + } else { + self.file.pread(pos, c) + } } #[instrument(skip_all, level = Level::DEBUG)] @@ -58,6 +111,7 @@ impl DatabaseStorage for DatabaseFile { &self, page_idx: usize, buffer: Arc, + encryption_key: Option<&EncryptionKey>, c: Completion, ) -> Result { let buffer_size = buffer.len(); @@ -66,21 +120,42 @@ impl DatabaseStorage for DatabaseFile { assert!(buffer_size <= 65536); assert_eq!(buffer_size & (buffer_size - 1), 0); let pos = (page_idx - 1) * buffer_size; + let buffer = { + if let Some(key) = encryption_key { + encrypt_buffer(page_idx, buffer, key) + } else { + buffer + } + }; self.file.pwrite(pos, buffer, c) } fn write_pages( &self, - page_idx: usize, + first_page_idx: usize, page_size: usize, buffers: Vec>, + encryption_key: Option<&EncryptionKey>, c: Completion, ) -> Result { - assert!(page_idx > 0); + assert!(first_page_idx > 0); assert!(page_size >= 512); assert!(page_size <= 65536); assert_eq!(page_size & (page_size - 1), 0); - let pos = (page_idx - 1) * page_size; + + let pos = (first_page_idx - 1) * page_size; + let buffers = { + if let Some(key) = encryption_key { + buffers + .into_iter() + .enumerate() + .map(|(i, buffer)| encrypt_buffer(first_page_idx + i, buffer, key)) + .collect::>() + } else { + buffers + } + }; + let c = self.file.pwritev(pos, buffers, c)?; Ok(c) } @@ -108,3 +183,8 @@ impl DatabaseFile { Self { file } } } + +fn encrypt_buffer(page_idx: usize, buffer: Arc, key: &EncryptionKey) -> Arc { + let encrypted_data = encrypt_page(buffer.as_slice(), page_idx, key).unwrap(); + Arc::new(Buffer::new(encrypted_data.to_vec())) +} diff --git a/core/storage/encryption.rs b/core/storage/encryption.rs new file mode 100644 index 000000000..97c2b3574 --- /dev/null +++ b/core/storage/encryption.rs @@ -0,0 +1,207 @@ +#![allow(unused_variables, dead_code)] +#[cfg(not(feature = "encryption"))] +use crate::LimboError; +use crate::Result; +use aes_gcm::{ + aead::{Aead, AeadCore, KeyInit, OsRng}, + Aes256Gcm, Key, Nonce, +}; +use std::ops::Deref; + +pub const ENCRYPTION_METADATA_SIZE: usize = 28; +pub const ENCRYPTED_PAGE_SIZE: usize = 4096; +pub const ENCRYPTION_NONCE_SIZE: usize = 12; + +#[repr(transparent)] +#[derive(Clone)] +pub struct EncryptionKey([u8; 32]); + +impl EncryptionKey { + pub fn new(key: [u8; 32]) -> Self { + Self(key) + } + + pub fn from_string(s: &str) -> Self { + let mut key = [0u8; 32]; + let bytes = s.as_bytes(); + let len = bytes.len().min(32); + key[..len].copy_from_slice(&bytes[..len]); + Self(key) + } + + pub fn as_bytes(&self) -> &[u8; 32] { + &self.0 + } + + pub fn as_slice(&self) -> &[u8] { + &self.0 + } +} + +impl Deref for EncryptionKey { + type Target = [u8; 32]; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl AsRef<[u8; 32]> for EncryptionKey { + fn as_ref(&self) -> &[u8; 32] { + &self.0 + } +} + +impl std::fmt::Debug for EncryptionKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("EncryptionKey") + .field("key", &"") + .finish() + } +} + +impl Drop for EncryptionKey { + fn drop(&mut self) { + // securely zero out the key bytes before dropping + for byte in self.0.iter_mut() { + unsafe { + std::ptr::write_volatile(byte, 0); + } + } + } +} + +#[cfg(not(feature = "encryption"))] +pub fn encrypt_page(page: &[u8], page_id: usize, key: &EncryptionKey) -> Result> { + Err(LimboError::InvalidArgument( + "encryption is not enabled, cannot encrypt page. enable via passing `--features encryption`".into(), + )) +} + +#[cfg(feature = "encryption")] +pub fn encrypt_page(page: &[u8], page_id: usize, key: &EncryptionKey) -> Result> { + if page_id == 1 { + tracing::debug!("skipping encryption for page 1 (database header)"); + return Ok(page.to_vec()); + } + tracing::debug!("encrypting page {}", page_id); + assert_eq!( + page.len(), + ENCRYPTED_PAGE_SIZE, + "Page data must be exactly {ENCRYPTED_PAGE_SIZE} bytes" + ); + let reserved_bytes = &page[ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE..]; + let reserved_bytes_zeroed = reserved_bytes.iter().all(|&b| b == 0); + assert!( + reserved_bytes_zeroed, + "last reserved bytes must be empty/zero, but found non-zero bytes" + ); + let payload = &page[..ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE]; + let (encrypted, nonce) = encrypt(payload, key)?; + assert_eq!( + encrypted.len(), + ENCRYPTED_PAGE_SIZE - nonce.len(), + "Encrypted page must be exactly {} bytes", + ENCRYPTED_PAGE_SIZE - nonce.len() + ); + let mut result = Vec::with_capacity(ENCRYPTED_PAGE_SIZE); + result.extend_from_slice(&encrypted); + result.extend_from_slice(&nonce); + assert_eq!( + result.len(), + ENCRYPTED_PAGE_SIZE, + "Encrypted page must be exactly {ENCRYPTED_PAGE_SIZE} bytes" + ); + Ok(result) +} + +#[cfg(not(feature = "encryption"))] +pub fn decrypt_page(encrypted_page: &[u8], page_id: usize, key: &EncryptionKey) -> Result> { + Err(LimboError::InvalidArgument( + "encryption is not enabled, cannot decrypt page. enable via passing `--features encryption`".into(), + )) +} + +#[cfg(feature = "encryption")] +pub fn decrypt_page(encrypted_page: &[u8], page_id: usize, key: &EncryptionKey) -> Result> { + if page_id == 1 { + tracing::debug!("skipping decryption for page 1 (database header)"); + return Ok(encrypted_page.to_vec()); + } + tracing::debug!("decrypting page {}", page_id); + assert_eq!( + encrypted_page.len(), + ENCRYPTED_PAGE_SIZE, + "Encrypted page data must be exactly {ENCRYPTED_PAGE_SIZE} bytes" + ); + + let nonce_start = encrypted_page.len() - ENCRYPTION_NONCE_SIZE; + let payload = &encrypted_page[..nonce_start]; + let nonce = &encrypted_page[nonce_start..]; + + let decrypted_data = decrypt(payload, nonce, key)?; + assert_eq!( + decrypted_data.len(), + ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE, + "Decrypted page data must be exactly {} bytes", + ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE + ); + let mut result = Vec::with_capacity(ENCRYPTED_PAGE_SIZE); + result.extend_from_slice(&decrypted_data); + result.resize(ENCRYPTED_PAGE_SIZE, 0); + assert_eq!( + result.len(), + ENCRYPTED_PAGE_SIZE, + "Decrypted page data must be exactly {ENCRYPTED_PAGE_SIZE} bytes" + ); + Ok(result) +} + +fn encrypt(plaintext: &[u8], key: &EncryptionKey) -> Result<(Vec, Vec)> { + let key: &Key = key.as_ref().into(); + let cipher = Aes256Gcm::new(key); + let nonce = Aes256Gcm::generate_nonce(&mut OsRng); + let ciphertext = cipher.encrypt(&nonce, plaintext).unwrap(); + Ok((ciphertext, nonce.to_vec())) +} + +fn decrypt(ciphertext: &[u8], nonce: &[u8], key: &EncryptionKey) -> Result> { + let key: &Key = key.as_ref().into(); + let cipher = Aes256Gcm::new(key); + let nonce = Nonce::from_slice(nonce); + let plaintext = cipher.decrypt(nonce, ciphertext).unwrap(); + Ok(plaintext) +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::Rng; + + #[test] + #[cfg(feature = "encryption")] + fn test_encrypt_decrypt_round_trip() { + let mut rng = rand::thread_rng(); + let data_size = ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE; + + let page_data = { + let mut page = vec![0u8; ENCRYPTED_PAGE_SIZE]; + page.iter_mut() + .take(data_size) + .for_each(|byte| *byte = rng.gen()); + page + }; + + let key = EncryptionKey::from_string("alice and bob use encryption on database"); + + let page_id = 42; + let encrypted = encrypt_page(&page_data, page_id, &key).unwrap(); + assert_eq!(encrypted.len(), ENCRYPTED_PAGE_SIZE); + assert_ne!(&encrypted[..data_size], &page_data[..data_size]); + assert_ne!(&encrypted[..], &page_data[..]); + + let decrypted = decrypt_page(&encrypted, page_id, &key).unwrap(); + assert_eq!(decrypted.len(), ENCRYPTED_PAGE_SIZE); + assert_eq!(decrypted, page_data); + } +} diff --git a/core/storage/mod.rs b/core/storage/mod.rs index 944f192fc..52c2e3f47 100644 --- a/core/storage/mod.rs +++ b/core/storage/mod.rs @@ -13,6 +13,7 @@ pub(crate) mod btree; pub(crate) mod buffer_pool; pub(crate) mod database; +pub(crate) mod encryption; pub(crate) mod page_cache; #[allow(clippy::arc_with_non_send_sync)] pub(crate) mod pager; diff --git a/core/storage/pager.rs b/core/storage/pager.rs index 18bbc1d23..7ea80cde4 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -28,6 +28,7 @@ use super::btree::{btree_init_page, BTreePage}; use super::page_cache::{CacheError, CacheResizeResult, DumbLruPageCache, PageCacheKey}; use super::sqlite3_ondisk::begin_write_btree_page; use super::wal::CheckpointMode; +use crate::storage::encryption::EncryptionKey; /// SQLite's default maximum page count const DEFAULT_MAX_PAGE_COUNT: u32 = 0xfffffffe; @@ -425,6 +426,7 @@ pub struct Pager { header_ref_state: RefCell, #[cfg(not(feature = "omit_autovacuum"))] btree_create_vacuum_full_state: Cell, + pub(crate) encryption_key: RefCell>, } #[derive(Debug, Clone)] @@ -526,6 +528,7 @@ impl Pager { header_ref_state: RefCell::new(HeaderRefState::Start), #[cfg(not(feature = "omit_autovacuum"))] btree_create_vacuum_full_state: Cell::new(BtreeCreateVacuumFullState::Start), + encryption_key: RefCell::new(None), }) } @@ -627,8 +630,8 @@ impl Pager { // Check if the calculated offset for the entry is within the bounds of the actual page data length. if offset_in_ptrmap_page + PTRMAP_ENTRY_SIZE > actual_data_length { return Err(LimboError::InternalError(format!( - "Ptrmap offset {offset_in_ptrmap_page} + entry size {PTRMAP_ENTRY_SIZE} out of bounds for page {ptrmap_pg_no} (actual data len {actual_data_length})" - ))); + "Ptrmap offset {offset_in_ptrmap_page} + entry size {PTRMAP_ENTRY_SIZE} out of bounds for page {ptrmap_pg_no} (actual data len {actual_data_length})" + ))); } let entry_slice = &ptrmap_page_data_slice @@ -670,8 +673,8 @@ impl Pager { || is_ptrmap_page(db_page_no_to_update, page_size) { return Err(LimboError::InternalError(format!( - "Cannot set ptrmap entry for page {db_page_no_to_update}: it's a header/ptrmap page or invalid." - ))); + "Cannot set ptrmap entry for page {db_page_no_to_update}: it's a header/ptrmap page or invalid." + ))); } let ptrmap_pg_no = get_ptrmap_page_no_for_db_page(db_page_no_to_update, page_size); @@ -1006,8 +1009,14 @@ impl Pager { matches!(frame_watermark, Some(0) | None), "frame_watermark must be either None or Some(0) because DB has no WAL and read with other watermark is invalid" ); + page.set_locked(); - let c = self.begin_read_disk_page(page_idx, page.clone(), allow_empty_read)?; + let c = self.begin_read_disk_page( + page_idx, + page.clone(), + allow_empty_read, + self.encryption_key.borrow().as_ref(), + )?; return Ok((page, c)); }; @@ -1020,7 +1029,12 @@ impl Pager { return Ok((page, c)); } - let c = self.begin_read_disk_page(page_idx, page.clone(), allow_empty_read)?; + let c = self.begin_read_disk_page( + page_idx, + page.clone(), + allow_empty_read, + self.encryption_key.borrow().as_ref(), + )?; Ok((page, c)) } @@ -1045,6 +1059,7 @@ impl Pager { page_idx: usize, page: PageRef, allow_empty_read: bool, + encryption_key: Option<&EncryptionKey>, ) -> Result { sqlite3_ondisk::begin_read_page( self.db_file.clone(), @@ -1052,6 +1067,7 @@ impl Pager { page, page_idx, allow_empty_read, + encryption_key, ) } @@ -1957,6 +1973,12 @@ impl Pager { let header = header_ref.borrow_mut(); Ok(IOResult::Done(f(header))) } + + pub fn set_encryption_key(&self, key: Option) { + self.encryption_key.replace(key.clone()); + let Some(wal) = self.wal.as_ref() else { return }; + wal.borrow_mut().set_encryption_key(key) + } } pub fn allocate_new_page(page_id: usize, buffer_pool: &Arc, offset: usize) -> PageRef { @@ -2071,10 +2093,10 @@ mod ptrmap { pub fn serialize(&self, buffer: &mut [u8]) -> Result<()> { if buffer.len() < PTRMAP_ENTRY_SIZE { return Err(LimboError::InternalError(format!( - "Buffer too small to serialize ptrmap entry. Expected at least {} bytes, got {}", - PTRMAP_ENTRY_SIZE, - buffer.len() - ))); + "Buffer too small to serialize ptrmap entry. Expected at least {} bytes, got {}", + PTRMAP_ENTRY_SIZE, + buffer.len() + ))); } buffer[0] = self.entry_type as u8; buffer[1..5].copy_from_slice(&self.parent_page_no.to_be_bytes()); diff --git a/core/storage/sqlite3_ondisk.rs b/core/storage/sqlite3_ondisk.rs index fff5abd5c..d6bc3aace 100644 --- a/core/storage/sqlite3_ondisk.rs +++ b/core/storage/sqlite3_ondisk.rs @@ -59,6 +59,7 @@ use crate::storage::btree::offset::{ use crate::storage::btree::{payload_overflow_threshold_max, payload_overflow_threshold_min}; use crate::storage::buffer_pool::BufferPool; use crate::storage::database::DatabaseStorage; +use crate::storage::encryption::EncryptionKey; use crate::storage::pager::Pager; use crate::storage::wal::{PendingFlush, READMARK_NOT_USED}; use crate::types::{RawSlice, RefValue, SerialType, SerialTypeKind, TextRef, TextSubtype}; @@ -307,6 +308,9 @@ impl Default for DatabaseHeader { page_size: Default::default(), write_version: Version::Wal, read_version: Version::Wal, + #[cfg(feature = "encryption")] + reserved_space: 28, + #[cfg(not(feature = "encryption"))] reserved_space: 0, max_embed_frac: 64, min_embed_frac: 32, @@ -869,6 +873,7 @@ pub fn begin_read_page( page: PageRef, page_idx: usize, allow_empty_read: bool, + encryption_key: Option<&EncryptionKey>, ) -> Result { tracing::trace!("begin_read_btree_page(page_idx = {})", page_idx); let buf = buffer_pool.get_page(); @@ -891,7 +896,7 @@ pub fn begin_read_page( finish_read_page(page_idx, buf, page.clone()); }); let c = Completion::new_read(buf, complete); - db_file.read_page(page_idx, c) + db_file.read_page(page_idx, encryption_key, c) } #[instrument(skip_all, level = Level::INFO)] @@ -942,7 +947,7 @@ pub fn begin_write_btree_page(pager: &Pager, page: &PageRef) -> Result>, flush: &PendingFlush, + encryption_key: Option<&EncryptionKey>, ) -> Result> { if batch.is_empty() { return Ok(Vec::new()); @@ -1039,6 +1045,7 @@ pub fn write_pages_vectored( start_id, page_sz, std::mem::replace(&mut run_bufs, Vec::with_capacity(EST_BUFF_CAPACITY)), + encryption_key, c, ) { Ok(c) => { diff --git a/core/storage/wal.rs b/core/storage/wal.rs index 1e48132fc..e7a4bd936 100644 --- a/core/storage/wal.rs +++ b/core/storage/wal.rs @@ -2,7 +2,7 @@ #![allow(clippy::not_unsafe_ptr_arg_deref)] use std::array; -use std::cell::UnsafeCell; +use std::cell::{RefCell, UnsafeCell}; use std::collections::{BTreeMap, HashMap, HashSet}; use strum::EnumString; use tracing::{instrument, Level}; @@ -11,9 +11,14 @@ use std::fmt::{Debug, Formatter}; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::{cell::Cell, fmt, rc::Rc, sync::Arc}; +use self::sqlite3_ondisk::{checksum_wal, PageContent, WAL_MAGIC_BE, WAL_MAGIC_LE}; +use super::buffer_pool::BufferPool; +use super::pager::{PageRef, Pager}; +use super::sqlite3_ondisk::{self, WalHeader}; use crate::fast_lock::SpinLock; use crate::io::{File, IO}; use crate::result::LimboResult; +use crate::storage::encryption::{decrypt_page, encrypt_page, EncryptionKey}; use crate::storage::sqlite3_ondisk::{ begin_read_wal_frame, begin_read_wal_frame_raw, finish_read_page, prepare_wal_frame, write_pages_vectored, PageSize, WAL_FRAME_HEADER_SIZE, WAL_HEADER_SIZE, @@ -25,12 +30,6 @@ use crate::{ }; use crate::{Completion, Page}; -use self::sqlite3_ondisk::{checksum_wal, PageContent, WAL_MAGIC_BE, WAL_MAGIC_LE}; - -use super::buffer_pool::BufferPool; -use super::pager::{PageRef, Pager}; -use super::sqlite3_ondisk::{self, WalHeader}; - #[derive(Debug, Clone, Default)] pub struct CheckpointResult { /// number of frames in WAL that could have been backfilled @@ -287,6 +286,8 @@ pub trait Wal: Debug { /// Return unique set of pages changed **after** frame_watermark position and until current WAL session max_frame_no fn changed_pages_after(&self, frame_watermark: u64) -> Result>; + fn set_encryption_key(&mut self, key: Option); + #[cfg(debug_assertions)] fn as_any(&self) -> &dyn std::any::Any; } @@ -446,6 +447,8 @@ pub struct WalFile { /// Manages locks needed for checkpointing checkpoint_guard: Option, + + encryption_key: RefCell>, } impl fmt::Debug for WalFile { @@ -913,6 +916,8 @@ impl Wal for WalFile { let offset = self.frame_offset(frame_id); page.set_locked(); let frame = page.clone(); + let page_idx = page.get().id; + let key = self.encryption_key.borrow().clone(); let complete = Box::new(move |res: Result<(Arc, i32), CompletionError>| { let Ok((buf, bytes_read)) = res else { page.clear_locked(); @@ -924,6 +929,17 @@ impl Wal for WalFile { "read({bytes_read}) less than expected({buf_len}): frame_id={frame_id}" ); let frame = frame.clone(); + if let Some(key) = key.clone() { + match decrypt_page(buf.as_slice(), page_idx, &key) { + Ok(decrypted_data) => { + buf.as_mut_slice().copy_from_slice(&decrypted_data); + } + Err(_) => { + tracing::error!("Failed to decrypt page data for frame_id={frame_id}"); + return; + } + } + } finish_read_page(page.get().id, buf, frame); }); begin_read_wal_frame( @@ -1076,6 +1092,21 @@ impl Wal for WalFile { let checksums = self.last_checksum; let page_content = page.get_contents(); let page_buf = page_content.as_ptr(); + + let key = self.encryption_key.borrow(); + let encrypted_data = { + if let Some(key) = key.as_ref() { + Some(encrypt_page(page_buf, page_id, key)?) + } else { + None + } + }; + let data_to_write = if key.as_ref().is_some() { + encrypted_data.as_ref().unwrap().as_slice() + } else { + page_buf + }; + let (frame_checksums, frame_bytes) = prepare_wal_frame( &self.buffer_pool, &header, @@ -1083,7 +1114,7 @@ impl Wal for WalFile { header.page_size, page_id as u32, db_size, - page_buf, + data_to_write, ); let c = Completion::new_write({ @@ -1227,6 +1258,10 @@ impl Wal for WalFile { fn as_any(&self) -> &dyn std::any::Any { self } + + fn set_encryption_key(&mut self, key: Option) { + self.encryption_key.replace(key); + } } impl WalFile { @@ -1266,6 +1301,7 @@ impl WalFile { prev_checkpoint: CheckpointResult::default(), checkpoint_guard: None, header: *header, + encryption_key: RefCell::new(None), } } @@ -1480,6 +1516,7 @@ impl WalFile { pager, std::mem::take(&mut self.ongoing_checkpoint.batch), &self.ongoing_checkpoint.pending_flush, + self.encryption_key.borrow().as_ref(), )?; // batch is queued self.ongoing_checkpoint.state = CheckpointState::AfterFlush; @@ -1893,7 +1930,7 @@ pub mod test { sync::{atomic::Ordering, Arc}, }; #[allow(clippy::arc_with_non_send_sync)] - fn get_database() -> (Arc, std::path::PathBuf) { + pub(crate) fn get_database() -> (Arc, std::path::PathBuf) { let mut path = tempfile::tempdir().unwrap().keep(); let dbpath = path.clone(); path.push("test.db"); diff --git a/core/translate/pragma.rs b/core/translate/pragma.rs index bac0309cd..f8afb272e 100644 --- a/core/translate/pragma.rs +++ b/core/translate/pragma.rs @@ -7,11 +7,15 @@ use std::sync::Arc; use turso_sqlite3_parser::ast::{self, ColumnDefinition, Expr, Literal, Name}; use turso_sqlite3_parser::ast::{PragmaName, QualifiedName}; +use super::integrity_check::translate_integrity_check; use crate::pragma::pragma_for; use crate::schema::Schema; +use crate::storage::encryption::EncryptionKey; use crate::storage::pager::AutoVacuumMode; +use crate::storage::pager::Pager; use crate::storage::sqlite3_ondisk::CacheSize; use crate::storage::wal::CheckpointMode; +use crate::translate::emitter::TransactionMode; use crate::translate::schema::translate_create_table; use crate::util::{normalize_ident, parse_signed_number, parse_string, IOExt as _}; use crate::vdbe::builder::{ProgramBuilder, ProgramBuilderOpts}; @@ -20,10 +24,6 @@ use crate::{bail_parse_error, CaptureDataChangesMode, LimboError, SymbolTable, V use std::str::FromStr; use strum::IntoEnumIterator; -use super::integrity_check::translate_integrity_check; -use crate::storage::pager::Pager; -use crate::translate::emitter::TransactionMode; - fn list_pragmas(program: &mut ProgramBuilder) { for x in PragmaName::iter() { let register = program.emit_string8_new_reg(x.to_string()); @@ -309,6 +309,12 @@ fn update_pragma( connection, program, ), + PragmaName::EncryptionKey => { + let value = parse_string(&value)?; + let key = EncryptionKey::from_string(&value); + connection.set_encryption_key(Some(key)); + Ok((program, TransactionMode::None)) + } } } @@ -566,6 +572,20 @@ fn query_pragma( program.add_pragma_result_column(pragma.to_string()); Ok((program, TransactionMode::None)) } + PragmaName::EncryptionKey => { + let msg = { + if connection.encryption_key.borrow().is_some() { + "encryption key is set for this session" + } else { + "encryption key is not set for this session" + } + }; + let register = program.alloc_register(); + program.emit_string8(msg.to_string(), register); + program.emit_result_row(register, 1); + program.add_pragma_result_column(pragma.to_string()); + Ok((program, TransactionMode::None)) + } } } diff --git a/tests/Cargo.toml b/tests/Cargo.toml index f2936a014..b26ca9b5f 100644 --- a/tests/Cargo.toml +++ b/tests/Cargo.toml @@ -32,3 +32,6 @@ zerocopy = "0.8.26" test-log = { version = "0.2.17", features = ["trace"] } tracing-subscriber = { version = "0.3.19", features = ["env-filter"] } tracing = "0.1.41" + +[features] +encryption = ["turso_core/encryption"] \ No newline at end of file diff --git a/tests/integration/query_processing/encryption.rs b/tests/integration/query_processing/encryption.rs new file mode 100644 index 000000000..c286bde7c --- /dev/null +++ b/tests/integration/query_processing/encryption.rs @@ -0,0 +1,70 @@ +use crate::common::{do_flush, TempDatabase}; +use crate::query_processing::test_write_path::{run_query, run_query_on_row}; +use rand::{rng, RngCore}; +use std::panic; +use turso_core::Row; + +#[test] +fn test_per_page_encryption() -> anyhow::Result<()> { + let _ = env_logger::try_init(); + let db_name = format!("test-{}.db", rng().next_u32()); + let tmp_db = TempDatabase::new(&db_name, false); + let db_path = tmp_db.path.clone(); + + { + let conn = tmp_db.connect_limbo(); + run_query( + &tmp_db, + &conn, + "PRAGMA key = 'super secret key for encryption';", + )?; + run_query( + &tmp_db, + &conn, + "CREATE TABLE test (id INTEGER PRIMARY KEY, value TEXT);", + )?; + run_query( + &tmp_db, + &conn, + "INSERT INTO test (value) VALUES ('Hello, World!')", + )?; + let mut row_count = 0; + run_query_on_row(&tmp_db, &conn, "SELECT * FROM test", |row: &Row| { + assert_eq!(row.get::(0).unwrap(), 1); + assert_eq!(row.get::(1).unwrap(), "Hello, World!"); + row_count += 1; + })?; + assert_eq!(row_count, 1); + do_flush(&conn, &tmp_db)?; + } + + { + // this should panik because we should not be able to access the encrypted database + // without the key + let conn = tmp_db.connect_limbo(); + let should_panic = panic::catch_unwind(panic::AssertUnwindSafe(|| { + run_query_on_row(&tmp_db, &conn, "SELECT * FROM test", |_: &Row| {}).unwrap(); + })); + assert!( + should_panic.is_err(), + "should panic when accessing encrypted DB without key" + ); + } + + { + // let's test the existing db with the key + let existing_db = TempDatabase::new_with_existent(&db_path, false); + let conn = existing_db.connect_limbo(); + run_query( + &existing_db, + &conn, + "PRAGMA key = 'super secret key for encryption';", + )?; + run_query_on_row(&existing_db, &conn, "SELECT * FROM test", |row: &Row| { + assert_eq!(row.get::(0).unwrap(), 1); + assert_eq!(row.get::(1).unwrap(), "Hello, World!"); + })?; + } + + Ok(()) +} diff --git a/tests/integration/query_processing/mod.rs b/tests/integration/query_processing/mod.rs index c77608718..742cdf52c 100644 --- a/tests/integration/query_processing/mod.rs +++ b/tests/integration/query_processing/mod.rs @@ -4,3 +4,6 @@ mod test_write_path; mod test_multi_thread; mod test_transactions; + +#[cfg(feature = "encryption")] +mod encryption; diff --git a/vendored/sqlite3-parser/src/parser/ast/mod.rs b/vendored/sqlite3-parser/src/parser/ast/mod.rs index e8d755799..de096827f 100644 --- a/vendored/sqlite3-parser/src/parser/ast/mod.rs +++ b/vendored/sqlite3-parser/src/parser/ast/mod.rs @@ -1829,6 +1829,11 @@ pub enum PragmaName { IntegrityCheck, /// `journal_mode` pragma JournalMode, + /// encryption key for encrypted databases. This is just called `key` because most + /// extensions use this name instead of `encryption_key`. + #[strum(serialize = "key")] + #[cfg_attr(feature = "serde", serde(rename = "key"))] + EncryptionKey, /// Noop as per SQLite docs LegacyFileFormat, /// Set or get the maximum number of pages in the database file.