diff --git a/COMPAT.md b/COMPAT.md index 5584899f4..5cd309bf2 100644 --- a/COMPAT.md +++ b/COMPAT.md @@ -120,7 +120,7 @@ Feature support of [sqlite expr syntax](https://www.sqlite.org/lang_expr.html). | like(X,Y,Z) | Yes | | | likelihood(X,Y) | No | | | likely(X) | No | | -| load_extension(X) | No | | +| load_extension(X) | Yes | sqlite3 extensions not yet supported | | load_extension(X,Y) | No | | | lower(X) | Yes | | | ltrim(X) | Yes | | diff --git a/cli/app.rs b/cli/app.rs index a0a7b5d99..26a99c470 100644 --- a/cli/app.rs +++ b/cli/app.rs @@ -325,7 +325,10 @@ impl Limbo { #[cfg(not(target_family = "wasm"))] fn handle_load_extension(&mut self, path: &str) -> Result<(), String> { - self.conn.load_extension(path).map_err(|e| e.to_string()) + let ext_path = limbo_core::resolve_ext_path(path).map_err(|e| e.to_string())?; + self.conn + .load_extension(ext_path) + .map_err(|e| e.to_string()) } fn display_in_memory(&mut self) -> std::io::Result<()> { diff --git a/core/function.rs b/core/function.rs index 68b5ef005..75a90e798 100644 --- a/core/function.rs +++ b/core/function.rs @@ -136,6 +136,8 @@ pub enum ScalarFunc { ZeroBlob, LastInsertRowid, Replace, + #[cfg(not(target_family = "wasm"))] + LoadExtension, } impl Display for ScalarFunc { @@ -185,6 +187,8 @@ impl Display for ScalarFunc { Self::LastInsertRowid => "last_insert_rowid".to_string(), Self::Replace => "replace".to_string(), Self::DateTime => "datetime".to_string(), + #[cfg(not(target_family = "wasm"))] + Self::LoadExtension => "load_extension".to_string(), }; write!(f, "{}", str) } @@ -426,6 +430,8 @@ impl Func { "tan" => Ok(Self::Math(MathFunc::Tan)), "tanh" => Ok(Self::Math(MathFunc::Tanh)), "trunc" => Ok(Self::Math(MathFunc::Trunc)), + #[cfg(not(target_family = "wasm"))] + "load_extension" => Ok(Self::Scalar(ScalarFunc::LoadExtension)), _ => Err(()), } } diff --git a/core/lib.rs b/core/lib.rs index 8f8914b0f..a0bd6d98c 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -182,7 +182,7 @@ impl Database { } #[cfg(not(target_family = "wasm"))] - pub fn load_extension(&self, path: &str) -> Result<()> { + pub fn load_extension>(&self, path: P) -> Result<()> { let api = Box::new(self.build_limbo_extension()); let lib = unsafe { Library::new(path).map_err(|e| LimboError::ExtensionError(e.to_string()))? }; @@ -401,7 +401,7 @@ impl Connection { } #[cfg(not(target_family = "wasm"))] - pub fn load_extension(&self, path: &str) -> Result<()> { + pub fn load_extension>(&self, path: P) -> Result<()> { Database::load_extension(self.db.as_ref(), path) } @@ -515,6 +515,33 @@ impl std::fmt::Debug for SymbolTable { } } +fn is_shared_library(path: &std::path::Path) -> bool { + path.extension() + .map_or(false, |ext| ext == "so" || ext == "dylib" || ext == "dll") +} + +pub fn resolve_ext_path(extpath: &str) -> Result { + let path = std::path::Path::new(extpath); + if !path.exists() { + if is_shared_library(path) { + return Err(LimboError::ExtensionError(format!( + "Extension file not found: {}", + extpath + ))); + }; + let maybe = path.with_extension(std::env::consts::DLL_EXTENSION); + maybe + .exists() + .then_some(maybe) + .ok_or(LimboError::ExtensionError(format!( + "Extension file not found: {}", + extpath + ))) + } else { + Ok(path.to_path_buf()) + } +} + impl SymbolTable { pub fn new() -> Self { Self { diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 5fbc5b260..4a0677d8d 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -1092,6 +1092,19 @@ pub fn translate_expr( }); Ok(target_register) } + #[cfg(not(target_family = "wasm"))] + ScalarFunc::LoadExtension => { + let args = expect_arguments_exact!(args, 1, srf); + let reg = + translate_and_mark(program, referenced_tables, &args[0], resolver)?; + program.emit_insn(Insn::Function { + constant_mask: 0, + start_reg: reg, + dest: target_register, + func: func_ctx, + }); + Ok(target_register) + } ScalarFunc::Random => { if args.is_some() { crate::bail_parse_error!( diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index 580834d39..9da225f44 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -41,7 +41,7 @@ use crate::{ json::json_arrow_extract, json::json_arrow_shift_extract, json::json_error_position, json::json_extract, json::json_type, }; -use crate::{Connection, Result, Rows, TransactionState, DATABASE_VERSION}; +use crate::{resolve_ext_path, Connection, Result, Rows, TransactionState, DATABASE_VERSION}; use datetime::{exec_date, exec_datetime_full, exec_julianday, exec_time, exec_unixepoch}; use insn::{ exec_add, exec_bit_and, exec_bit_not, exec_bit_or, exec_divide, exec_multiply, exec_remainder, @@ -1863,6 +1863,14 @@ impl Program { let replacement = &state.registers[*start_reg + 2]; state.registers[*dest] = exec_replace(source, pattern, replacement); } + #[cfg(not(target_family = "wasm"))] + ScalarFunc::LoadExtension => { + let extension = &state.registers[*start_reg]; + let ext = resolve_ext_path(&extension.to_string())?; + if let Some(conn) = self.connection.upgrade() { + conn.load_extension(ext)?; + } + } }, crate::function::Func::External(f) => { call_external_function! {f.func, *dest, state, arg_count, *start_reg };