diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index 8e9be5eaf..1c13e8d66 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -317,7 +317,7 @@ impl Drop for Connection { #[allow(clippy::arc_with_non_send_sync)] #[pyfunction(signature = (path))] pub fn connect(path: &str) -> Result { - match turso_core::Connection::from_uri(path, true, false, false, false, false) { + match turso_core::Connection::from_uri(path, true, false, false, false, false, false) { Ok((io, conn)) => Ok(Connection { conn, _io: io }), Err(e) => Err(PyErr::new::(format!( "Failed to create connection: {e:?}" diff --git a/cli/app.rs b/cli/app.rs index 923d280a3..52423a4d3 100644 --- a/cli/app.rs +++ b/cli/app.rs @@ -78,6 +78,8 @@ pub struct Opts { pub mcp: bool, #[clap(long, help = "Enable experimental encryption feature")] pub experimental_encryption: bool, + #[clap(long, help = "Enable experimental custom modules feature")] + pub experimental_custom_modules: bool, } const PROMPT: &str = "turso> "; @@ -192,6 +194,7 @@ impl Limbo { opts.experimental_views, opts.experimental_strict, opts.experimental_encryption, + opts.experimental_custom_modules, )? } else { let flags = if opts.readonly { @@ -209,6 +212,7 @@ impl Limbo { .with_views(opts.experimental_views) .with_strict(opts.experimental_strict) .with_encryption(opts.experimental_encryption) + .with_custom_modules(opts.experimental_custom_modules) .turso_cli(), None, )?; diff --git a/cli/mcp_server.rs b/cli/mcp_server.rs index 8c09cd075..0798e130b 100644 --- a/cli/mcp_server.rs +++ b/cli/mcp_server.rs @@ -408,7 +408,7 @@ impl TursoMcpServer { // Open the new database connection let conn = if path == ":memory:" || path.contains([':', '?', '&', '#']) { - match Connection::from_uri(&path, true, false, false, false, false) { + match Connection::from_uri(&path, true, false, false, false, false, false) { Ok((_io, c)) => c, Err(e) => return format!("Failed to open database '{path}': {e}"), } diff --git a/core/incremental/cursor.rs b/core/incremental/cursor.rs index 12e3c49c6..cd7af669b 100644 --- a/core/incremental/cursor.rs +++ b/core/incremental/cursor.rs @@ -316,6 +316,7 @@ mod tests { enable_strict: false, enable_load_extension: false, enable_encryption: false, + enable_custom_modules: false, }, None, )?; diff --git a/core/lib.rs b/core/lib.rs index dae370af4..99d79113c 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -113,6 +113,7 @@ pub struct DatabaseOpts { pub enable_views: bool, pub enable_strict: bool, pub enable_encryption: bool, + pub enable_custom_modules: bool, enable_load_extension: bool, } @@ -124,6 +125,7 @@ impl Default for DatabaseOpts { enable_views: false, enable_strict: false, enable_encryption: false, + enable_custom_modules: false, enable_load_extension: false, } } @@ -164,6 +166,11 @@ impl DatabaseOpts { self.enable_encryption = enable; self } + + pub fn with_custom_modules(mut self, enable: bool) -> Self { + self.enable_custom_modules = enable; + self + } } #[derive(Clone, Debug, Default)] @@ -870,6 +877,10 @@ impl Database { self.opts.enable_views } + pub fn experimental_custom_modules_enabled(&self) -> bool { + self.opts.enable_custom_modules + } + pub fn experimental_strict_enabled(&self) -> bool { self.opts.enable_strict } @@ -1474,6 +1485,8 @@ impl Connection { strict: bool, // flag to opt-in encryption support encryption: bool, + // flag to opt-in custom modules support + custom_modules: bool, ) -> Result<(Arc, Arc)> { use crate::util::MEMORY_PATH; let opts = OpenOptions::parse(uri)?; @@ -1489,7 +1502,8 @@ impl Connection { .with_indexes(use_indexes) .with_views(views) .with_strict(strict) - .with_encryption(encryption), + .with_encryption(encryption) + .with_custom_modules(custom_modules), None, )?; let conn = db.connect()?; @@ -1518,7 +1532,8 @@ impl Connection { .with_indexes(use_indexes) .with_views(views) .with_strict(strict) - .with_encryption(encryption), + .with_encryption(encryption) + .with_custom_modules(custom_modules), encryption_opts.clone(), )?; if let Some(modeof) = opts.modeof { @@ -1997,6 +2012,10 @@ impl Connection { self.db.experimental_views_enabled() } + pub fn experimental_custom_modules_enabled(&self) -> bool { + self.db.experimental_custom_modules_enabled() + } + pub fn experimental_strict_enabled(&self) -> bool { self.db.experimental_strict_enabled() } diff --git a/core/translate/index.rs b/core/translate/index.rs index b680f560c..c79789b0d 100644 --- a/core/translate/index.rs +++ b/core/translate/index.rs @@ -29,15 +29,32 @@ use super::schema::{emit_schema_entry, SchemaEntryType, SQLITE_TABLEID}; #[allow(clippy::too_many_arguments)] pub fn translate_create_index( - unique_if_not_exists: (bool, bool), - resolver: &Resolver, - idx_name: &Name, - tbl_name: &Name, - columns: &[SortedColumn], mut program: ProgramBuilder, connection: &Arc, - where_clause: Option>, + resolver: &Resolver, + stmt: ast::Stmt, ) -> crate::Result { + let sql = stmt.to_string(); + let ast::Stmt::CreateIndex { + unique, + if_not_exists, + idx_name, + tbl_name, + columns, + where_clause, + with_clause, + using, + } = stmt + else { + panic!("translate_create_index must be called with CreateIndex AST node"); + }; + + if !connection.experimental_custom_modules_enabled() + && (using.is_some() || !with_clause.is_empty()) + { + bail_parse_error!("custom modules is an experimental feature. Enable with --experimental-custom-modules flag") + } + let original_idx_name = idx_name; let original_tbl_name = tbl_name; let idx_name = normalize_ident(idx_name.as_str()); diff --git a/core/translate/mod.rs b/core/translate/mod.rs index 758031544..71fa52fb7 100644 --- a/core/translate/mod.rs +++ b/core/translate/mod.rs @@ -152,23 +152,9 @@ pub fn translate_inner( } ast::Stmt::Begin { typ, name } => translate_tx_begin(typ, name, resolver.schema, program)?, ast::Stmt::Commit { name } => translate_tx_commit(name, program)?, - ast::Stmt::CreateIndex { - unique, - if_not_exists, - idx_name, - tbl_name, - columns, - where_clause, - } => translate_create_index( - (unique, if_not_exists), - resolver, - &idx_name.name, - &tbl_name, - &columns, - program, - connection, - where_clause, - )?, + ast::Stmt::CreateIndex { .. } => { + translate_create_index(program, connection, resolver, stmt)? + } ast::Stmt::CreateTable { temporary, if_not_exists, diff --git a/tests/integration/query_processing/encryption.rs b/tests/integration/query_processing/encryption.rs index dd390b4f3..980443ea0 100644 --- a/tests/integration/query_processing/encryption.rs +++ b/tests/integration/query_processing/encryption.rs @@ -46,8 +46,15 @@ fn test_per_page_encryption() -> anyhow::Result<()> { "file:{}?cipher=aegis256&hexkey=b1bbfda4f589dc9daaf004fe21111e00dc00c98237102f5c7002a5669fc76327", db_path.to_str().unwrap() ); - let (_io, conn) = - turso_core::Connection::from_uri(&uri, true, false, false, false, ENABLE_ENCRYPTION)?; + let (_io, conn) = turso_core::Connection::from_uri( + &uri, + true, + false, + false, + false, + ENABLE_ENCRYPTION, + false, + )?; let mut row_count = 0; run_query_on_row(&tmp_db, &conn, "SELECT * FROM test", |row: &Row| { assert_eq!(row.get::(0).unwrap(), 1); @@ -62,8 +69,15 @@ fn test_per_page_encryption() -> anyhow::Result<()> { "file:{}?cipher=aegis256&hexkey=b1bbfda4f589dc9daaf004fe21111e00dc00c98237102f5c7002a5669fc76327", db_path.to_str().unwrap() ); - let (_io, conn) = - turso_core::Connection::from_uri(&uri, true, false, false, false, ENABLE_ENCRYPTION)?; + let (_io, conn) = turso_core::Connection::from_uri( + &uri, + true, + false, + false, + false, + ENABLE_ENCRYPTION, + false, + )?; run_query( &tmp_db, &conn, @@ -77,8 +91,15 @@ fn test_per_page_encryption() -> anyhow::Result<()> { "file:{}?cipher=aegis256&hexkey=b1bbfda4f589dc9daaf004fe21111e00dc00c98237102f5c7002a5669fc76327", db_path.to_str().unwrap() ); - let (_io, conn) = - turso_core::Connection::from_uri(&uri, true, false, false, false, ENABLE_ENCRYPTION)?; + let (_io, conn) = turso_core::Connection::from_uri( + &uri, + true, + false, + false, + false, + ENABLE_ENCRYPTION, + false, + )?; run_query( &tmp_db, &conn, @@ -100,8 +121,15 @@ fn test_per_page_encryption() -> anyhow::Result<()> { "file:{}?cipher=aegis256&hexkey=b1bbfda4f589dc9daaf004fe21111e00dc00c98237102f5c7002a5669fc76377", db_path.to_str().unwrap() ); - let (_io, conn) = - turso_core::Connection::from_uri(&uri, true, false, false, false, ENABLE_ENCRYPTION)?; + let (_io, conn) = turso_core::Connection::from_uri( + &uri, + true, + false, + false, + false, + ENABLE_ENCRYPTION, + false, + )?; let should_panic = panic::catch_unwind(panic::AssertUnwindSafe(|| { run_query_on_row(&tmp_db, &conn, "SELECT * FROM test", |_row: &Row| {}).unwrap(); })); @@ -114,8 +142,16 @@ fn test_per_page_encryption() -> anyhow::Result<()> { //test connecting to encrypted db using insufficient encryption parameters in URI.This should panic. let uri = format!("file:{}?cipher=aegis256", db_path.to_str().unwrap()); let should_panic = panic::catch_unwind(panic::AssertUnwindSafe(|| { - turso_core::Connection::from_uri(&uri, true, false, false, false, ENABLE_ENCRYPTION) - .unwrap(); + turso_core::Connection::from_uri( + &uri, + true, + false, + false, + false, + ENABLE_ENCRYPTION, + false, + ) + .unwrap(); })); assert!( should_panic.is_err(), @@ -128,8 +164,16 @@ fn test_per_page_encryption() -> anyhow::Result<()> { db_path.to_str().unwrap() ); let should_panic = panic::catch_unwind(panic::AssertUnwindSafe(|| { - turso_core::Connection::from_uri(&uri, true, false, false, false, ENABLE_ENCRYPTION) - .unwrap(); + turso_core::Connection::from_uri( + &uri, + true, + false, + false, + false, + ENABLE_ENCRYPTION, + false, + ) + .unwrap(); })); assert!( should_panic.is_err(), @@ -195,8 +239,15 @@ fn test_non_4k_page_size_encryption() -> anyhow::Result<()> { "file:{}?cipher=aegis256&hexkey=b1bbfda4f589dc9daaf004fe21111e00dc00c98237102f5c7002a5669fc76327", db_path.to_str().unwrap() ); - let (_io, conn) = - turso_core::Connection::from_uri(&uri, true, false, false, false, ENABLE_ENCRYPTION)?; + let (_io, conn) = turso_core::Connection::from_uri( + &uri, + true, + false, + false, + false, + ENABLE_ENCRYPTION, + false, + )?; run_query_on_row(&tmp_db, &conn, "SELECT * FROM test", |row: &Row| { assert_eq!(row.get::(0).unwrap(), 1); assert_eq!(row.get::(1).unwrap(), "Hello, World!"); @@ -261,6 +312,7 @@ fn test_corruption_turso_magic_bytes() -> anyhow::Result<()> { false, false, ENABLE_ENCRYPTION, + false, ) .unwrap(); run_query_on_row(&tmp_db, &conn, "SELECT * FROM test", |_row: &Row| {}).unwrap(); @@ -353,6 +405,7 @@ fn test_corruption_associated_data_bytes() -> anyhow::Result<()> { false, false, ENABLE_ENCRYPTION, + false, ) .unwrap(); run_query_on_row(&test_tmp_db, &conn, "SELECT * FROM test", |_row: &Row| {})