diff --git a/cli/app.rs b/cli/app.rs index 121cd0c4e..4bb389384 100644 --- a/cli/app.rs +++ b/cli/app.rs @@ -137,11 +137,14 @@ impl Limbo { let conn = db.connect()?; (io, conn) }; - let mut ext_api = conn.build_turso_ext(); - if unsafe { !limbo_completion::register_extension_static(&mut ext_api).is_ok() } { - return Err(anyhow!( - "Failed to register completion extension".to_string() - )); + unsafe { + let mut ext_api = conn._build_turso_ext(); + if !limbo_completion::register_extension_static(&mut ext_api).is_ok() { + return Err(anyhow!( + "Failed to register completion extension".to_string() + )); + } + conn._free_extension_ctx(ext_api); } let interrupt_count = Arc::new(AtomicUsize::new(0)); { diff --git a/core/ext/dynamic.rs b/core/ext/dynamic.rs index 60f76547f..851e38b67 100644 --- a/core/ext/dynamic.rs +++ b/core/ext/dynamic.rs @@ -35,7 +35,7 @@ impl Connection { ) -> crate::Result<()> { use turso_ext::ExtensionApiRef; - let api = Box::new(self.build_turso_ext()); + let api = Box::new(unsafe { self._build_turso_ext() }); let lib = unsafe { Library::new(path).map_err(|e| LimboError::ExtensionError(e.to_string()))? }; let entry: Symbol = unsafe { diff --git a/core/ext/mod.rs b/core/ext/mod.rs index 89c3e1a61..52761741b 100644 --- a/core/ext/mod.rs +++ b/core/ext/mod.rs @@ -1,23 +1,76 @@ #[cfg(feature = "fs")] mod dynamic; mod vtab_xconnect; -use crate::vtab::VirtualTable; +use crate::schema::{Schema, Table}; #[cfg(all(target_os = "linux", feature = "io_uring"))] use crate::UringIO; use crate::{function::ExternalFunc, Connection, Database, LimboError, IO}; +use crate::{vtab::VirtualTable, SymbolTable}; #[cfg(feature = "fs")] pub use dynamic::{add_builtin_vfs_extensions, add_vfs_module, list_vfs_modules, VfsMod}; use std::{ ffi::{c_char, c_void, CStr, CString}, rc::Rc, - sync::Arc, + sync::{Arc, Mutex}, }; use turso_ext::{ ExtensionApi, InitAggFunction, ResultCode, ScalarFunction, VTabKind, VTabModuleImpl, }; pub use turso_ext::{FinalizeFunction, StepFunction, Value as ExtValue, ValueType as ExtValueType}; pub use vtab_xconnect::{close, execute, prepare_stmt}; -type ExternAggFunc = (InitAggFunction, StepFunction, FinalizeFunction); + +/// The context passed to extensions to register with Core +/// along with the function pointers +#[repr(C)] +pub struct ExtensionCtx { + syms: *mut SymbolTable, + schema: *mut c_void, +} + +pub(crate) unsafe extern "C" fn register_vtab_module( + ctx: *mut c_void, + name: *const c_char, + module: VTabModuleImpl, + kind: VTabKind, +) -> ResultCode { + if name.is_null() || ctx.is_null() { + return ResultCode::Error; + } + + let c_str = unsafe { CString::from_raw(name as *mut c_char) }; + let name_str = match c_str.to_str() { + Ok(s) => s.to_string(), + Err(_) => return ResultCode::Error, + }; + + let ext_ctx = unsafe { &mut *(ctx as *mut ExtensionCtx) }; + let module = Rc::new(module); + let vmodule = VTabImpl { + module_kind: kind, + implementation: module, + }; + + unsafe { + let syms = &mut *ext_ctx.syms; + syms.vtab_modules.insert(name_str.clone(), vmodule.into()); + + if kind == VTabKind::TableValuedFunction { + if let Ok(vtab) = VirtualTable::function(&name_str, syms) { + // Use the schema handler to insert the table + let table = Arc::new(Table::Virtual(vtab)); + let mutex = &*(ext_ctx.schema as *mut Mutex>); + let Ok(guard) = mutex.lock() else { + return ResultCode::Error; + }; + let schema_ptr = Arc::as_ptr(&*guard) as *mut Schema; + (*schema_ptr).tables.insert(name_str, table); + } else { + return ResultCode::Error; + } + } + } + ResultCode::OK +} #[derive(Clone)] pub struct VTabImpl { @@ -38,8 +91,14 @@ pub(crate) unsafe extern "C" fn register_scalar_function( if ctx.is_null() { return ResultCode::Error; } - let conn = unsafe { &*(ctx as *const Connection) }; - conn.register_scalar_function_impl(&name_str, func) + let ext_ctx = unsafe { &mut *(ctx as *mut ExtensionCtx) }; + unsafe { + (*ext_ctx.syms).functions.insert( + name_str.clone(), + Rc::new(ExternalFunc::new_scalar(name_str, func)), + ); + } + ResultCode::OK } pub(crate) unsafe extern "C" fn register_aggregate_function( @@ -58,30 +117,18 @@ pub(crate) unsafe extern "C" fn register_aggregate_function( if ctx.is_null() { return ResultCode::Error; } - let conn = unsafe { &*(ctx as *const Connection) }; - conn.register_aggregate_function_impl(&name_str, args, (init_func, step_func, finalize_func)) -} - -pub(crate) unsafe extern "C" fn register_vtab_module( - ctx: *mut c_void, - name: *const c_char, - module: VTabModuleImpl, - kind: VTabKind, -) -> ResultCode { - if name.is_null() || ctx.is_null() { - return ResultCode::Error; + let ext_ctx = unsafe { &mut *(ctx as *mut ExtensionCtx) }; + unsafe { + (*ext_ctx.syms).functions.insert( + name_str.clone(), + Rc::new(ExternalFunc::new_aggregate( + name_str, + args, + (init_func, step_func, finalize_func), + )), + ); } - let c_str = unsafe { CString::from_raw(name as *mut _) }; - let name_str = match c_str.to_str() { - Ok(s) => s.to_string(), - Err(_) => return ResultCode::Error, - }; - if ctx.is_null() { - return ResultCode::Error; - } - let conn = unsafe { &mut *(ctx as *mut Connection) }; - - conn.register_vtab_module_impl(&name_str, module, kind) + ResultCode::OK } impl Database { @@ -110,58 +157,74 @@ impl Database { let db = Self::open_file(io.clone(), path, false, false)?; Ok((io, db)) } + + /// Register any built-in extensions that can be stored on the Database so we do not have + /// to register these once-per-connection, and the connection can just extend its symbol table + pub fn register_global_builtin_extensions(&self) -> Result<(), String> { + let syms = self.builtin_syms.as_ptr(); + // Pass the mutex pointer and the appropriate handler + let schema_mutex_ptr = &self.schema as *const Mutex> as *mut Mutex>; + let ctx = Box::into_raw(Box::new(ExtensionCtx { + syms, + schema: schema_mutex_ptr as *mut c_void, + })); + let mut ext_api = ExtensionApi { + ctx: ctx as *mut c_void, + register_scalar_function, + register_aggregate_function, + register_vtab_module, + #[cfg(feature = "fs")] + vfs_interface: turso_ext::VfsInterface { + register_vfs: dynamic::register_vfs, + builtin_vfs: std::ptr::null_mut(), + builtin_vfs_count: 0, + }, + }; + + #[cfg(feature = "uuid")] + crate::uuid::register_extension(&mut ext_api); + #[cfg(feature = "series")] + crate::series::register_extension(&mut ext_api); + #[cfg(feature = "fs")] + { + let vfslist = add_builtin_vfs_extensions(Some(ext_api)).map_err(|e| e.to_string())?; + for (name, vfs) in vfslist { + add_vfs_module(name, vfs); + } + } + let _ = unsafe { Box::from_raw(ctx) }; + Ok(()) + } } impl Connection { - fn register_scalar_function_impl(&self, name: &str, func: ScalarFunction) -> ResultCode { - self.syms.borrow_mut().functions.insert( - name.to_string(), - Rc::new(ExternalFunc::new_scalar(name.to_string(), func)), - ); - ResultCode::OK - } - - fn register_aggregate_function_impl( - &self, - name: &str, - args: i32, - func: ExternAggFunc, - ) -> ResultCode { - self.syms.borrow_mut().functions.insert( - name.to_string(), - Rc::new(ExternalFunc::new_aggregate(name.to_string(), args, func)), - ); - ResultCode::OK - } - - fn register_vtab_module_impl( - &mut self, - name: &str, - module: VTabModuleImpl, - kind: VTabKind, - ) -> ResultCode { - let module = Rc::new(module); - let vmodule = VTabImpl { - module_kind: kind, - implementation: module, + /// Build the connection's extension api context for manually registering an extension. + /// you probably want to use `Connection::load_extension(path)`. + /// + /// # Safety + /// Only to be used when registering a staticly linked extension manually. + /// You should only ever call this method on your applications startup, + /// The caller is responsible for calling `_free_extension_ctx` after registering the + /// extension. + /// + /// usage: + /// ```ignore + /// let ext_api = conn._build_turso_ext(); + /// unsafe { + /// my_extension::register_extension(&mut ext_api); + /// conn._free_extension_ctx(ext_api); + /// } + ///``` + pub unsafe fn _build_turso_ext(&self) -> ExtensionApi { + let schema_mutex_ptr = + &self._db.schema as *const Mutex> as *mut Mutex>; + let ctx = ExtensionCtx { + syms: self.syms.as_ptr(), + schema: schema_mutex_ptr as *mut c_void, }; - self.syms - .borrow_mut() - .vtab_modules - .insert(name.to_string(), vmodule.into()); - if kind == VTabKind::TableValuedFunction { - if let Ok(vtab) = VirtualTable::function(name, &self.syms.borrow()) { - self.with_schema_mut(|schema| schema.add_virtual_table(vtab)); - } else { - return ResultCode::Error; - } - } - ResultCode::OK - } - - pub fn build_turso_ext(&self) -> ExtensionApi { + let ctx = Box::into_raw(Box::new(ctx)) as *mut c_void; ExtensionApi { - ctx: self as *const _ as *mut c_void, + ctx, register_scalar_function, register_aggregate_function, register_vtab_module, @@ -174,20 +237,13 @@ impl Connection { } } - pub fn register_builtins(&self) -> Result<(), String> { - #[allow(unused_variables)] - let mut ext_api = self.build_turso_ext(); - #[cfg(feature = "uuid")] - crate::uuid::register_extension(&mut ext_api); - #[cfg(feature = "series")] - crate::series::register_extension(&mut ext_api); - #[cfg(feature = "fs")] - { - let vfslist = add_builtin_vfs_extensions(Some(ext_api)).map_err(|e| e.to_string())?; - for (name, vfs) in vfslist { - add_vfs_module(name, vfs); - } + /// Free the connection's extension libary context after registering an extension manually. + /// # Safety + /// Only to be used if you have previously called Connection::build_turso_ext + pub unsafe fn _free_extension_ctx(&self, api: ExtensionApi) { + if api.ctx.is_null() { + return; } - Ok(()) + let _ = unsafe { Box::from_raw(api.ctx as *mut ExtensionCtx) }; } } diff --git a/core/lib.rs b/core/lib.rs index 331e0384a..aa139b319 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -62,7 +62,6 @@ pub use io::{ use parking_lot::RwLock; use schema::Schema; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Mutex; use std::{ borrow::Cow, cell::{Cell, RefCell, UnsafeCell}, @@ -72,7 +71,7 @@ use std::{ num::NonZero, ops::Deref, rc::Rc, - sync::Arc, + sync::{Arc, Mutex}, }; #[cfg(feature = "fs")] use storage::database::DatabaseFile; @@ -121,6 +120,7 @@ pub struct Database { db_state: Arc, init_lock: Arc>, open_flags: OpenFlags, + builtin_syms: RefCell, } unsafe impl Send for Database {} @@ -245,7 +245,7 @@ impl Database { }; let shared_page_cache = Arc::new(RwLock::new(DumbLruPageCache::default())); - + let syms = SymbolTable::new(); let db = Arc::new(Database { mv_store, path: path.to_string(), @@ -253,11 +253,14 @@ impl Database { _shared_page_cache: shared_page_cache.clone(), maybe_shared_wal: RwLock::new(maybe_shared_wal), db_file, + builtin_syms: syms.into(), io: io.clone(), open_flags: flags, db_state: Arc::new(AtomicUsize::new(db_state)), init_lock: Arc::new(Mutex::new(())), }); + db.register_global_builtin_extensions() + .expect("unable to register global extensions"); // Check: https://github.com/tursodatabase/turso/pull/1761#discussion_r2154013123 if db_state == DB_STATE_INITIALIZED { @@ -315,10 +318,9 @@ impl Database { capture_data_changes: RefCell::new(CaptureDataChangesMode::Off), closed: Cell::new(false), }); - - if let Err(e) = conn.register_builtins() { - return Err(LimboError::ExtensionError(e)); - } + let builtin_syms = self.builtin_syms.borrow(); + // add built-in extensions symbols to the connection to prevent having to load each time + conn.syms.borrow_mut().extend(&builtin_syms); Ok(conn) } @@ -1222,6 +1224,18 @@ impl SymbolTable { ) -> Option> { self.functions.get(name).cloned() } + + pub fn extend(&mut self, other: &SymbolTable) { + for (name, func) in &other.functions { + self.functions.insert(name.clone(), func.clone()); + } + for (name, vtab) in &other.vtabs { + self.vtabs.insert(name.clone(), vtab.clone()); + } + for (name, module) in &other.vtab_modules { + self.vtab_modules.insert(name.clone(), module.clone()); + } + } } pub struct QueryRunner<'a> { diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index 5c6671f76..af396384b 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -31,8 +31,11 @@ use crate::{ }; use std::ops::DerefMut; use std::sync::atomic::AtomicUsize; -use std::sync::Mutex; -use std::{borrow::BorrowMut, rc::Rc, sync::Arc}; +use std::{ + borrow::BorrowMut, + rc::Rc, + sync::{Arc, Mutex}, +}; use crate::{pseudo::PseudoCursor, result::LimboResult};