diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index 1f99101c2..8e9be5eaf 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -317,7 +317,7 @@ impl Drop for Connection { #[allow(clippy::arc_with_non_send_sync)] #[pyfunction(signature = (path))] pub fn connect(path: &str) -> Result { - match turso_core::Connection::from_uri(path, true, false, false, false) { + match turso_core::Connection::from_uri(path, true, false, false, false, false) { Ok((io, conn)) => Ok(Connection { conn, _io: io }), Err(e) => Err(PyErr::new::(format!( "Failed to create connection: {e:?}" diff --git a/cli/app.rs b/cli/app.rs index 5b4a6aff7..64729e035 100644 --- a/cli/app.rs +++ b/cli/app.rs @@ -78,6 +78,8 @@ pub struct Opts { pub mcp: bool, #[clap(long, help = "Enable experimental logical log feature")] pub experimental_logical_log: bool, + #[clap(long, help = "Enable experimental encryption feature")] + pub experimental_encryption: bool, } const PROMPT: &str = "turso> "; @@ -170,6 +172,7 @@ impl Limbo { opts.experimental_mvcc, opts.experimental_views, opts.experimental_strict, + opts.experimental_encryption, )? } else { let flags = if opts.readonly { @@ -186,6 +189,7 @@ impl Limbo { .with_indexes(indexes_enabled) .with_views(opts.experimental_views) .with_strict(opts.experimental_strict) + .with_encryption(opts.experimental_encryption) .turso_cli(), None, )?; diff --git a/cli/mcp_server.rs b/cli/mcp_server.rs index 5efb934d4..8c09cd075 100644 --- a/cli/mcp_server.rs +++ b/cli/mcp_server.rs @@ -408,7 +408,7 @@ impl TursoMcpServer { // Open the new database connection let conn = if path == ":memory:" || path.contains([':', '?', '&', '#']) { - match Connection::from_uri(&path, true, false, false, false) { + match Connection::from_uri(&path, true, false, false, false, false) { Ok((_io, c)) => c, Err(e) => return format!("Failed to open database '{path}': {e}"), } diff --git a/core/incremental/cursor.rs b/core/incremental/cursor.rs index 20bce4205..67814bd0d 100644 --- a/core/incremental/cursor.rs +++ b/core/incremental/cursor.rs @@ -317,6 +317,7 @@ mod tests { enable_views: true, enable_strict: false, enable_load_extension: false, + enable_encryption: false, }, None, )?; diff --git a/core/lib.rs b/core/lib.rs index 5404c40f8..167da3478 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -108,6 +108,7 @@ pub struct DatabaseOpts { pub enable_indexes: bool, pub enable_views: bool, pub enable_strict: bool, + pub enable_encryption: bool, enable_load_extension: bool, } @@ -118,6 +119,7 @@ impl Default for DatabaseOpts { enable_indexes: true, enable_views: false, enable_strict: false, + enable_encryption: false, enable_load_extension: false, } } @@ -153,6 +155,11 @@ impl DatabaseOpts { self.enable_strict = enable; self } + + pub fn with_encryption(mut self, enable: bool) -> Self { + self.enable_encryption = enable; + self + } } #[derive(Clone, Debug, Default)] @@ -493,6 +500,7 @@ impl Database { #[instrument(skip_all, level = Level::INFO)] pub fn connect(self: &Arc) -> Result> { let pager = self.init_pager(None)?; + pager.enable_encryption(self.opts.enable_encryption); let pager = Arc::new(pager); if self.mv_store.is_some() { @@ -1382,6 +1390,8 @@ impl Connection { mvcc: bool, views: bool, strict: bool, + // flag to opt-in encryption support + encryption: bool, ) -> Result<(Arc, Arc)> { use crate::util::MEMORY_PATH; let opts = OpenOptions::parse(uri)?; @@ -1396,7 +1406,8 @@ impl Connection { .with_mvcc(mvcc) .with_indexes(use_indexes) .with_views(views) - .with_strict(strict), + .with_strict(strict) + .with_encryption(encryption), None, )?; let conn = db.connect()?; @@ -1424,7 +1435,8 @@ impl Connection { .with_mvcc(mvcc) .with_indexes(use_indexes) .with_views(views) - .with_strict(strict), + .with_strict(strict) + .with_encryption(encryption), encryption_opts.clone(), )?; if let Some(modeof) = opts.modeof { @@ -1819,6 +1831,7 @@ impl Connection { } self.pager.write().clear_page_cache(); let pager = self.db.init_pager(Some(size.get() as usize))?; + pager.enable_encryption(self.db.opts.enable_encryption); *self.pager.write() = Arc::new(pager); self.pager.read().set_initial_page_size(size); diff --git a/core/storage/pager.rs b/core/storage/pager.rs index 04539c2f5..a12bddb42 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -535,6 +535,8 @@ pub struct Pager { #[cfg(not(feature = "omit_autovacuum"))] vacuum_state: RwLock, pub(crate) io_ctx: RwLock, + /// encryption is an opt-in feature. we will enable it only if the flag is passed + enable_encryption: AtomicBool, } #[cfg(not(feature = "omit_autovacuum"))] @@ -645,6 +647,7 @@ impl Pager { btree_create_vacuum_full_state: BtreeCreateVacuumFullState::Start, }), io_ctx: RwLock::new(IOContext::default()), + enable_encryption: AtomicBool::new(false), }) } @@ -2407,6 +2410,14 @@ impl Pager { cipher_mode: CipherMode, key: &EncryptionKey, ) -> Result<()> { + // we will set the encryption context only if the encryption is opted-in. + if !self.enable_encryption.load(Ordering::SeqCst) { + return Err(LimboError::InvalidArgument( + "encryption is an opt in feature. enable it via passing `--experimental-encryption`" + .into(), + )); + } + let page_size = self.get_page_size_unchecked().get() as usize; let encryption_ctx = EncryptionContext::new(cipher_mode, key, page_size)?; { @@ -2432,6 +2443,12 @@ impl Pager { pub fn set_reserved_space_bytes(&self, value: u8) { self.set_reserved_space(value); } + + /// Encryption is an opt-in feature. If the flag is passed, then enable the encryption on + /// pager, which is then used to set it on the IOContext. + pub fn enable_encryption(&self, enable: bool) { + self.enable_encryption.store(enable, Ordering::SeqCst); + } } pub fn allocate_new_page(page_id: i64, buffer_pool: &Arc, offset: usize) -> PageRef { diff --git a/tests/integration/common.rs b/tests/integration/common.rs index 2a687ade2..7f6f0f6bf 100644 --- a/tests/integration/common.rs +++ b/tests/integration/common.rs @@ -29,7 +29,9 @@ impl TempDatabase { io.clone(), path.to_str().unwrap(), turso_core::OpenFlags::default(), - turso_core::DatabaseOpts::new().with_indexes(enable_indexes), + turso_core::DatabaseOpts::new() + .with_indexes(enable_indexes) + .with_encryption(true), None, ) .unwrap(); diff --git a/tests/integration/query_processing/encryption.rs b/tests/integration/query_processing/encryption.rs index 3fc7a8e63..dd390b4f3 100644 --- a/tests/integration/query_processing/encryption.rs +++ b/tests/integration/query_processing/encryption.rs @@ -3,6 +3,8 @@ use rand::{rng, RngCore}; use std::panic; use turso_core::Row; +const ENABLE_ENCRYPTION: bool = true; + #[test] fn test_per_page_encryption() -> anyhow::Result<()> { let _ = env_logger::try_init(); @@ -44,7 +46,8 @@ fn test_per_page_encryption() -> anyhow::Result<()> { "file:{}?cipher=aegis256&hexkey=b1bbfda4f589dc9daaf004fe21111e00dc00c98237102f5c7002a5669fc76327", db_path.to_str().unwrap() ); - let (_io, conn) = turso_core::Connection::from_uri(&uri, true, false, false, false)?; + let (_io, conn) = + turso_core::Connection::from_uri(&uri, true, false, false, false, ENABLE_ENCRYPTION)?; let mut row_count = 0; run_query_on_row(&tmp_db, &conn, "SELECT * FROM test", |row: &Row| { assert_eq!(row.get::(0).unwrap(), 1); @@ -59,7 +62,8 @@ fn test_per_page_encryption() -> anyhow::Result<()> { "file:{}?cipher=aegis256&hexkey=b1bbfda4f589dc9daaf004fe21111e00dc00c98237102f5c7002a5669fc76327", db_path.to_str().unwrap() ); - let (_io, conn) = turso_core::Connection::from_uri(&uri, true, false, false, false)?; + let (_io, conn) = + turso_core::Connection::from_uri(&uri, true, false, false, false, ENABLE_ENCRYPTION)?; run_query( &tmp_db, &conn, @@ -73,7 +77,8 @@ fn test_per_page_encryption() -> anyhow::Result<()> { "file:{}?cipher=aegis256&hexkey=b1bbfda4f589dc9daaf004fe21111e00dc00c98237102f5c7002a5669fc76327", db_path.to_str().unwrap() ); - let (_io, conn) = turso_core::Connection::from_uri(&uri, true, false, false, false)?; + let (_io, conn) = + turso_core::Connection::from_uri(&uri, true, false, false, false, ENABLE_ENCRYPTION)?; run_query( &tmp_db, &conn, @@ -95,7 +100,8 @@ fn test_per_page_encryption() -> anyhow::Result<()> { "file:{}?cipher=aegis256&hexkey=b1bbfda4f589dc9daaf004fe21111e00dc00c98237102f5c7002a5669fc76377", db_path.to_str().unwrap() ); - let (_io, conn) = turso_core::Connection::from_uri(&uri, true, false, false, false)?; + let (_io, conn) = + turso_core::Connection::from_uri(&uri, true, false, false, false, ENABLE_ENCRYPTION)?; let should_panic = panic::catch_unwind(panic::AssertUnwindSafe(|| { run_query_on_row(&tmp_db, &conn, "SELECT * FROM test", |_row: &Row| {}).unwrap(); })); @@ -108,7 +114,8 @@ fn test_per_page_encryption() -> anyhow::Result<()> { //test connecting to encrypted db using insufficient encryption parameters in URI.This should panic. let uri = format!("file:{}?cipher=aegis256", db_path.to_str().unwrap()); let should_panic = panic::catch_unwind(panic::AssertUnwindSafe(|| { - turso_core::Connection::from_uri(&uri, true, false, false, false).unwrap(); + turso_core::Connection::from_uri(&uri, true, false, false, false, ENABLE_ENCRYPTION) + .unwrap(); })); assert!( should_panic.is_err(), @@ -121,7 +128,8 @@ fn test_per_page_encryption() -> anyhow::Result<()> { db_path.to_str().unwrap() ); let should_panic = panic::catch_unwind(panic::AssertUnwindSafe(|| { - turso_core::Connection::from_uri(&uri, true, false, false, false).unwrap(); + turso_core::Connection::from_uri(&uri, true, false, false, false, ENABLE_ENCRYPTION) + .unwrap(); })); assert!( should_panic.is_err(), @@ -187,7 +195,8 @@ fn test_non_4k_page_size_encryption() -> anyhow::Result<()> { "file:{}?cipher=aegis256&hexkey=b1bbfda4f589dc9daaf004fe21111e00dc00c98237102f5c7002a5669fc76327", db_path.to_str().unwrap() ); - let (_io, conn) = turso_core::Connection::from_uri(&uri, true, false, false, false)?; + let (_io, conn) = + turso_core::Connection::from_uri(&uri, true, false, false, false, ENABLE_ENCRYPTION)?; run_query_on_row(&tmp_db, &conn, "SELECT * FROM test", |row: &Row| { assert_eq!(row.get::(0).unwrap(), 1); assert_eq!(row.get::(1).unwrap(), "Hello, World!"); @@ -245,8 +254,15 @@ fn test_corruption_turso_magic_bytes() -> anyhow::Result<()> { ); let should_panic = panic::catch_unwind(panic::AssertUnwindSafe(|| { - let (_io, conn) = - turso_core::Connection::from_uri(&uri, true, false, false, false).unwrap(); + let (_io, conn) = turso_core::Connection::from_uri( + &uri, + true, + false, + false, + false, + ENABLE_ENCRYPTION, + ) + .unwrap(); run_query_on_row(&tmp_db, &conn, "SELECT * FROM test", |_row: &Row| {}).unwrap(); })); @@ -330,8 +346,15 @@ fn test_corruption_associated_data_bytes() -> anyhow::Result<()> { ); let should_panic = panic::catch_unwind(panic::AssertUnwindSafe(|| { - let (_io, conn) = - turso_core::Connection::from_uri(&uri, true, false, false, false).unwrap(); + let (_io, conn) = turso_core::Connection::from_uri( + &uri, + true, + false, + false, + false, + ENABLE_ENCRYPTION, + ) + .unwrap(); run_query_on_row(&test_tmp_db, &conn, "SELECT * FROM test", |_row: &Row| {}) .unwrap(); }));