diff --git a/core/ext/dynamic.rs b/core/ext/dynamic.rs index df342caca..17138f268 100644 --- a/core/ext/dynamic.rs +++ b/core/ext/dynamic.rs @@ -6,6 +6,7 @@ use libloading::{Library, Symbol}; use limbo_ext::{ExtensionApi, ExtensionApiRef, ExtensionEntryPoint, ResultCode, VfsImpl}; use std::{ ffi::{c_char, CString}, + rc::Rc, sync::{Arc, Mutex, OnceLock}, }; @@ -29,7 +30,10 @@ unsafe impl Send for VfsMod {} unsafe impl Sync for VfsMod {} impl Connection { - pub fn load_extension>(&self, path: P) -> crate::Result<()> { + pub fn load_extension>( + self: &Rc, + path: P, + ) -> crate::Result<()> { use limbo_ext::ExtensionApiRef; let api = Box::new(self.build_limbo_ext()); @@ -44,7 +48,15 @@ impl Connection { let result_code = unsafe { entry(api_ptr) }; if result_code.is_ok() { let extensions = get_extension_libraries(); - extensions.lock().unwrap().push((Arc::new(lib), api_ref)); + extensions + .lock() + .map_err(|_| { + LimboError::ExtensionError("Error locking extension libraries".to_string()) + })? + .push((Arc::new(lib), api_ref)); + { + self.parse_schema_rows()?; + } Ok(()) } else { if !api_ptr.is_null() { diff --git a/core/lib.rs b/core/lib.rs index e364b226d..cd728fa64 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -70,7 +70,7 @@ use vdbe::{builder::QueryMode, VTabOpaqueCursor}; pub type Result = std::result::Result; pub static DATABASE_VERSION: OnceLock = OnceLock::new(); -#[derive(Clone, PartialEq, Eq)] +#[derive(Clone, Copy, PartialEq, Eq)] enum TransactionState { Write, Read, @@ -158,7 +158,13 @@ impl Database { .try_write() .expect("lock on schema should succeed first try"); let syms = conn.syms.borrow(); - parse_schema_rows(rows, &mut schema, io, syms.deref(), None)?; + if let Err(LimboError::ExtensionError(e)) = + parse_schema_rows(rows, &mut schema, io, &syms, None) + { + // this means that a vtab exists and we no longer have the module loaded. we print + // a warning to the user to load the module + eprintln!("Warning: {}", e); + } } Ok(db) } @@ -186,9 +192,9 @@ impl Database { schema: self.schema.clone(), header: self.header.clone(), last_insert_rowid: Cell::new(0), - auto_commit: RefCell::new(true), + auto_commit: Cell::new(true), mv_transactions: RefCell::new(Vec::new()), - transaction_state: RefCell::new(TransactionState::None), + transaction_state: Cell::new(TransactionState::None), last_change: Cell::new(0), syms: RefCell::new(SymbolTable::new()), total_changes: Cell::new(0), @@ -278,9 +284,9 @@ pub struct Connection { pager: Rc, schema: Arc>, header: Arc>, - auto_commit: RefCell, + auto_commit: Cell, mv_transactions: RefCell>, - transaction_state: RefCell, + transaction_state: Cell, last_insert_rowid: Cell, last_change: Cell, total_changes: Cell, @@ -517,7 +523,26 @@ impl Connection { } pub fn get_auto_commit(&self) -> bool { - *self.auto_commit.borrow() + self.auto_commit.get() + } + + pub fn parse_schema_rows(self: &Rc) -> Result<()> { + let rows = self.query("SELECT * FROM sqlite_schema")?; + let mut schema = self + .schema + .try_write() + .expect("lock on schema should succeed first try"); + { + let syms = self.syms.borrow(); + if let Err(LimboError::ExtensionError(e)) = + parse_schema_rows(rows, &mut schema, self.pager.io.clone(), &syms, None) + { + // this means that a vtab exists and we no longer have the module loaded. we print + // a warning to the user to load the module + eprintln!("Warning: {}", e); + } + } + Ok(()) } } @@ -630,7 +655,7 @@ impl VirtualTable { module_name )))?; if let VTabKind::VirtualTable = kind { - if module.module_kind != VTabKind::VirtualTable { + if module.module_kind == VTabKind::TableValuedFunction { return Err(LimboError::ExtensionError(format!( "{} is not a virtual table module", module_name diff --git a/core/translate/emitter.rs b/core/translate/emitter.rs index 23b937019..514bc21ab 100644 --- a/core/translate/emitter.rs +++ b/core/translate/emitter.rs @@ -135,6 +135,7 @@ fn prologue<'a>( Ok((t_ctx, init_label, start_offset)) } +#[derive(Clone, Copy, Debug)] pub enum TransactionMode { None, Read, diff --git a/core/util.rs b/core/util.rs index f17699233..3d12a2c6e 100644 --- a/core/util.rs +++ b/core/util.rs @@ -60,7 +60,35 @@ pub fn parse_schema_rows( let sql: &str = row.get::<&str>(4)?; if root_page == 0 && sql.to_lowercase().contains("create virtual") { let name: &str = row.get::<&str>(1)?; - let vtab = syms.vtabs.get(name).unwrap().clone(); + // a virtual table is found in the sqlite_schema, but it's no + // longer in the in-memory schema. We need to recreate it if + // the module is loaded in the symbol table. + let vtab = if let Some(vtab) = syms.vtabs.get(name) { + vtab.clone() + } else { + let mod_name = module_name_from_sql(sql)?; + if let Some(vmod) = syms.vtab_modules.get(mod_name) { + if let limbo_ext::VTabKind::VirtualTable = vmod.module_kind + { + crate::VirtualTable::from_args( + Some(name), + mod_name, + module_args_from_sql(sql)?, + syms, + vmod.module_kind, + None, + )? + } else { + return Err(LimboError::Corrupt("Table valued function: {name} registered as virtual table in schema".to_string())); + } + } else { + // the extension isn't loaded, so we emit a warning. + return Err(LimboError::ExtensionError(format!( + "Virtual table module '{}' not found\nPlease load extension", + &mod_name + ))); + } + }; schema.add_virtual_table(vtab); } else { let table = schema::BTreeTable::from_sql(sql, root_page as usize)?; @@ -132,6 +160,99 @@ pub fn check_ident_equivalency(ident1: &str, ident2: &str) -> bool { strip_quotes(ident1).eq_ignore_ascii_case(strip_quotes(ident2)) } +fn module_name_from_sql(sql: &str) -> Result<&str> { + if let Some(start) = sql.find("USING") { + let start = start + 6; + // stop at the first space, semicolon, or parenthesis + let end = sql[start..] + .find(|c: char| c.is_whitespace() || c == ';' || c == '(') + .unwrap_or(sql.len() - start) + + start; + Ok(sql[start..end].trim()) + } else { + Err(LimboError::InvalidArgument( + "Expected 'USING' in module name".to_string(), + )) + } +} + +// CREATE VIRTUAL TABLE table_name USING module_name(arg1, arg2, ...); +// CREATE VIRTUAL TABLE table_name USING module_name; +fn module_args_from_sql(sql: &str) -> Result> { + if !sql.contains('(') { + return Ok(vec![]); + } + let start = sql.find('(').ok_or_else(|| { + LimboError::InvalidArgument("Expected '(' in module argument list".to_string()) + })? + 1; + let end = sql.rfind(')').ok_or_else(|| { + LimboError::InvalidArgument("Expected ')' in module argument list".to_string()) + })?; + + let mut args = Vec::new(); + let mut current_arg = String::new(); + let mut chars = sql[start..end].chars().peekable(); + let mut in_quotes = false; + + while let Some(c) = chars.next() { + match c { + '\'' => { + if in_quotes { + if chars.peek() == Some(&'\'') { + // Escaped quote + current_arg.push('\''); + chars.next(); + } else { + in_quotes = false; + args.push(limbo_ext::Value::from_text(current_arg.trim().to_string())); + current_arg.clear(); + // Skip until comma or end + while let Some(&nc) = chars.peek() { + if nc == ',' { + chars.next(); // Consume comma + break; + } else if nc.is_whitespace() { + chars.next(); + } else { + return Err(LimboError::InvalidArgument( + "Unexpected characters after quoted argument".to_string(), + )); + } + } + } + } else { + in_quotes = true; + } + } + ',' => { + if !in_quotes { + if !current_arg.trim().is_empty() { + args.push(limbo_ext::Value::from_text(current_arg.trim().to_string())); + current_arg.clear(); + } + } else { + current_arg.push(c); + } + } + _ => { + current_arg.push(c); + } + } + } + + if !current_arg.trim().is_empty() && !in_quotes { + args.push(limbo_ext::Value::from_text(current_arg.trim().to_string())); + } + + if in_quotes { + return Err(LimboError::InvalidArgument( + "Unterminated string literal in module arguments".to_string(), + )); + } + + Ok(args) +} + pub fn check_literal_equivalency(lhs: &Literal, rhs: &Literal) -> bool { match (lhs, rhs) { (Literal::Numeric(n1), Literal::Numeric(n2)) => cmp_numeric_strings(n1, n2), @@ -1632,4 +1753,88 @@ pub mod tests { Ok((OwnedValueType::Float, "1.23e4")) ); } + + #[test] + fn test_module_name_basic() { + let sql = "CREATE VIRTUAL TABLE x USING y;"; + assert_eq!(module_name_from_sql(sql).unwrap(), "y"); + } + + #[test] + fn test_module_name_with_args() { + let sql = "CREATE VIRTUAL TABLE x USING modname('a', 'b');"; + assert_eq!(module_name_from_sql(sql).unwrap(), "modname"); + } + + #[test] + fn test_module_name_missing_using() { + let sql = "CREATE VIRTUAL TABLE x (a, b);"; + assert!(module_name_from_sql(sql).is_err()); + } + + #[test] + fn test_module_name_no_semicolon() { + let sql = "CREATE VIRTUAL TABLE x USING limbo(a, b)"; + assert_eq!(module_name_from_sql(sql).unwrap(), "limbo"); + } + + #[test] + fn test_module_name_no_semicolon_or_args() { + let sql = "CREATE VIRTUAL TABLE x USING limbo"; + assert_eq!(module_name_from_sql(sql).unwrap(), "limbo"); + } + + #[test] + fn test_module_args_none() { + let sql = "CREATE VIRTUAL TABLE x USING modname;"; + let args = module_args_from_sql(sql).unwrap(); + assert_eq!(args.len(), 0); + } + + #[test] + fn test_module_args_basic() { + let sql = "CREATE VIRTUAL TABLE x USING modname('arg1', 'arg2');"; + let args = module_args_from_sql(sql).unwrap(); + assert_eq!(args.len(), 2); + assert_eq!("arg1", args[0].to_text().unwrap()); + assert_eq!("arg2", args[1].to_text().unwrap()); + for arg in args { + unsafe { arg.__free_internal_type() } + } + } + + #[test] + fn test_module_args_with_escaped_quote() { + let sql = "CREATE VIRTUAL TABLE x USING modname('a''b', 'c');"; + let args = module_args_from_sql(sql).unwrap(); + assert_eq!(args.len(), 2); + assert_eq!(args[0].to_text().unwrap(), "a'b"); + assert_eq!(args[1].to_text().unwrap(), "c"); + for arg in args { + unsafe { arg.__free_internal_type() } + } + } + + #[test] + fn test_module_args_unterminated_string() { + let sql = "CREATE VIRTUAL TABLE x USING modname('arg1, 'arg2');"; + assert!(module_args_from_sql(sql).is_err()); + } + + #[test] + fn test_module_args_extra_garbage_after_quote() { + let sql = "CREATE VIRTUAL TABLE x USING modname('arg1'x);"; + assert!(module_args_from_sql(sql).is_err()); + } + + #[test] + fn test_module_args_trailing_comma() { + let sql = "CREATE VIRTUAL TABLE x USING modname('arg1',);"; + let args = module_args_from_sql(sql).unwrap(); + assert_eq!(args.len(), 1); + assert_eq!("arg1", args[0].to_text().unwrap()); + for arg in args { + unsafe { arg.__free_internal_type() } + } + } } diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index 09c283ecd..654e9a2c5 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -917,12 +917,21 @@ pub fn op_vcreate( "Failed to upgrade Connection".to_string(), )); }; + let mod_type = conn + .syms + .borrow() + .vtab_modules + .get(&module_name) + .ok_or_else(|| { + crate::LimboError::ExtensionError(format!("Module {} not found", module_name)) + })? + .module_kind; let table = crate::VirtualTable::from_args( Some(&table_name), &module_name, args, &conn.syms.borrow(), - limbo_ext::VTabKind::VirtualTable, + mod_type, None, )?; { @@ -1542,8 +1551,8 @@ pub fn op_transaction( } } else { let connection = program.connection.upgrade().unwrap(); - let current_state = connection.transaction_state.borrow().clone(); - let (new_transaction_state, updated) = match (¤t_state, write) { + let current_state = connection.transaction_state.get(); + let (new_transaction_state, updated) = match (current_state, write) { (TransactionState::Write, true) => (TransactionState::Write, false), (TransactionState::Write, false) => (TransactionState::Write, false), (TransactionState::Read, true) => (TransactionState::Write, true), @@ -1597,7 +1606,7 @@ pub fn op_auto_commit( }; } - if *auto_commit != *conn.auto_commit.borrow() { + if *auto_commit != conn.auto_commit.get() { if *rollback { todo!("Rollback is not implemented"); } else { @@ -4227,13 +4236,15 @@ pub fn op_parse_schema( ))?; let mut schema = conn.schema.write(); // TODO: This function below is synchronous, make it async - parse_schema_rows( - Some(stmt), - &mut schema, - conn.pager.io.clone(), - &conn.syms.borrow(), - state.mv_tx_id, - )?; + { + parse_schema_rows( + Some(stmt), + &mut schema, + conn.pager.io.clone(), + &conn.syms.borrow(), + state.mv_tx_id, + )?; + } state.pc += 1; Ok(InsnFunctionStepResult::Step) } diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index 8794b208a..550f21164 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -386,7 +386,7 @@ impl Program { ) -> Result { if let Some(mv_store) = mv_store { let conn = self.connection.upgrade().unwrap(); - let auto_commit = *conn.auto_commit.borrow(); + let auto_commit = conn.auto_commit.get(); if auto_commit { let mut mv_transactions = conn.mv_transactions.borrow_mut(); for tx_id in mv_transactions.iter() { @@ -400,7 +400,7 @@ impl Program { .connection .upgrade() .expect("only weak ref to connection?"); - let auto_commit = *connection.auto_commit.borrow(); + let auto_commit = connection.auto_commit.get(); tracing::trace!("Halt auto_commit {}", auto_commit); assert!( program_state.halt_state.is_none() @@ -409,7 +409,7 @@ impl Program { if program_state.halt_state.is_some() { self.step_end_write_txn(&pager, &mut program_state.halt_state, connection.deref()) } else if auto_commit { - let current_state = connection.transaction_state.borrow().clone(); + let current_state = connection.transaction_state.get(); match current_state { TransactionState::Write => self.step_end_write_txn( &pager, diff --git a/testing/cli_tests/extensions.py b/testing/cli_tests/extensions.py index bab8cb74f..ac870ee4d 100755 --- a/testing/cli_tests/extensions.py +++ b/testing/cli_tests/extensions.py @@ -345,10 +345,10 @@ def test_kv(): limbo = TestLimboShell() limbo.run_test_fn( "create virtual table t using kv_store;", - lambda res: "Virtual table module not found: kv_store" in res, + lambda res: "Module kv_store not found" in res, ) limbo.execute_dot(f".load {ext_path}") - limbo.debug_print( + limbo.execute_dot( "create virtual table t using kv_store;", ) limbo.run_test_fn(".schema", lambda res: "CREATE VIRTUAL TABLE t" in res)