From 0a10d893d9de12ca206ef810e216cce82f5dd820 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Wed, 8 Jan 2025 23:16:57 -0500 Subject: [PATCH] Sketch out runtime extension loading --- Cargo.lock | 30 +++++++++++++-- Cargo.toml | 2 +- cli/app.rs | 26 ++++++++++++- core/Cargo.toml | 2 + core/error.rs | 2 + core/ext/mod.rs | 39 ++++++++++++++++++-- core/ext/uuid.rs | 6 +-- core/function.rs | 15 ++++++-- core/lib.rs | 53 ++++++++++++++++++++------- core/types.rs | 45 +++++++++++++++++++++-- core/vdbe/mod.rs | 9 ++++- extension_api/Cargo.toml | 9 +++++ extension_api/src/lib.rs | 75 ++++++++++++++++++++++++++++++++++++++ extensions/uuid/Cargo.toml | 11 ++++++ extensions/uuid/src/lib.rs | 0 15 files changed, 291 insertions(+), 33 deletions(-) create mode 100644 extension_api/Cargo.toml create mode 100644 extension_api/src/lib.rs create mode 100644 extensions/uuid/Cargo.toml create mode 100644 extensions/uuid/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index bb69e5c0f..b3f3bb7e0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -564,7 +564,7 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef552e6f588e446098f6ba40d89ac146c8c7b64aade83c051ee00bb5d2bc18d" dependencies = [ - "uuid", + "uuid 1.11.0", ] [[package]] @@ -694,6 +694,10 @@ dependencies = [ "str-buf", ] +[[package]] +name = "extension_api" +version = "0.0.11" + [[package]] name = "fallible-iterator" version = "0.2.0" @@ -1137,6 +1141,16 @@ version = "0.2.169" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" +[[package]] +name = "libloading" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" +dependencies = [ + "cfg-if", + "windows-targets 0.52.6", +] + [[package]] name = "libmimalloc-sys" version = "0.1.39" @@ -1204,6 +1218,7 @@ dependencies = [ "cfg_block", "chrono", "criterion", + "extension_api", "fallible-iterator 0.3.0", "getrandom", "hex", @@ -1212,6 +1227,7 @@ dependencies = [ "jsonb", "julian_day_converter", "libc", + "libloading", "limbo_macros", "log", "miette", @@ -1232,7 +1248,7 @@ dependencies = [ "sqlite3-parser", "tempfile", "thiserror 1.0.69", - "uuid", + "uuid 1.11.0", ] [[package]] @@ -2260,7 +2276,7 @@ dependencies = [ "debugid", "memmap2", "stable_deref_trait", - "uuid", + "uuid 1.11.0", ] [[package]] @@ -2502,6 +2518,14 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" +[[package]] +name = "uuid" +version = "0.0.11" +dependencies = [ + "extension_api", + "uuid 1.11.0", +] + [[package]] name = "uuid" version = "1.11.0" diff --git a/Cargo.toml b/Cargo.toml index 0d2fc81be..40f5b2bd4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,7 @@ members = [ "sqlite3", "core", "simulator", - "test", "macros", + "test", "macros", "extension_api", "extensions/uuid", ] exclude = ["perf/latency/limbo"] diff --git a/cli/app.rs b/cli/app.rs index f114b63a2..325f6fe19 100644 --- a/cli/app.rs +++ b/cli/app.rs @@ -129,6 +129,8 @@ pub enum Command { Tables, /// Import data from FILE into TABLE Import, + /// Loads an extension library + LoadExtension, } impl Command { @@ -141,7 +143,12 @@ impl Command { | Self::ShowInfo | Self::Tables | Self::SetOutput => 0, - Self::Open | Self::OutputMode | Self::Cwd | Self::Echo | Self::NullValue => 1, + Self::Open + | Self::OutputMode + | Self::Cwd + | Self::Echo + | Self::NullValue + | Self::LoadExtension => 1, Self::Import => 2, } + 1) // argv0 } @@ -160,6 +167,7 @@ impl Command { Self::NullValue => ".nullvalue ", Self::Echo => ".echo on|off", Self::Tables => ".tables", + Self::LoadExtension => ".load", Self::Import => &IMPORT_HELP, } } @@ -182,6 +190,7 @@ impl FromStr for Command { ".nullvalue" => Ok(Self::NullValue), ".echo" => Ok(Self::Echo), ".import" => Ok(Self::Import), + ".load" => Ok(Self::LoadExtension), _ => Err("Unknown command".to_string()), } } @@ -314,6 +323,16 @@ impl Limbo { }; } + fn handle_load_extension(&mut self) -> Result<(), String> { + let mut args = self.input_buff.split_whitespace(); + let _ = args.next(); + let lib = args + .next() + .ok_or("No library specified") + .map_err(|e| e.to_string())?; + self.conn.load_extension(lib).map_err(|e| e.to_string()) + } + fn display_in_memory(&mut self) -> std::io::Result<()> { if self.opts.db_file == ":memory:" { self.writeln("Connected to a transient in-memory database.")?; @@ -537,6 +556,11 @@ impl Limbo { let _ = self.writeln(e.to_string()); }; } + Command::LoadExtension => { + if let Err(e) = self.handle_load_extension() { + let _ = self.writeln(e.to_string()); + } + } } } else { let _ = self.write_fmt(format_args!( diff --git a/core/Cargo.toml b/core/Cargo.toml index 0daa58c0d..2bbc1347d 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -35,6 +35,7 @@ rustix = "0.38.34" mimalloc = { version = "*", default-features = false } [dependencies] +extension_api = { path = "../extension_api" } cfg_block = "0.1.1" fallible-iterator = "0.3.0" hex = "0.4.3" @@ -58,6 +59,7 @@ bumpalo = { version = "3.16.0", features = ["collections", "boxed"] } limbo_macros = { path = "../macros" } uuid = { version = "1.11.0", features = ["v4", "v7"], optional = true } miette = "7.4.0" +libloading = "0.8.6" [target.'cfg(not(target_family = "windows"))'.dev-dependencies] pprof = { version = "0.14.0", features = ["criterion", "flamegraph"] } diff --git a/core/error.rs b/core/error.rs index 646e85825..e3e176b79 100644 --- a/core/error.rs +++ b/core/error.rs @@ -39,6 +39,8 @@ pub enum LimboError { InvalidModifier(String), #[error("Runtime error: {0}")] Constraint(String), + #[error("Extension error: {0}")] + ExtensionError(String), } #[macro_export] diff --git a/core/ext/mod.rs b/core/ext/mod.rs index cea65a98d..c1718dcc8 100644 --- a/core/ext/mod.rs +++ b/core/ext/mod.rs @@ -1,8 +1,39 @@ #[cfg(feature = "uuid")] mod uuid; +use crate::{function::ExternalFunc, Database}; +use std::sync::Arc; + +use extension_api::{AggregateFunction, ExtensionApi, Result, ScalarFunction, VirtualTable}; #[cfg(feature = "uuid")] pub use uuid::{exec_ts_from_uuid7, exec_uuid, exec_uuidblob, exec_uuidstr, UuidFunc}; +impl ExtensionApi for Database { + fn register_scalar_function( + &self, + name: &str, + func: Arc, + ) -> extension_api::Result<()> { + let ext_func = ExternalFunc::new(name, func.clone()); + self.syms + .borrow_mut() + .functions + .insert(name.to_string(), Arc::new(ext_func)); + Ok(()) + } + + fn register_aggregate_function( + &self, + _name: &str, + _func: Arc, + ) -> Result<()> { + todo!("implement aggregate function registration"); + } + + fn register_virtual_table(&self, _name: &str, _table: Arc) -> Result<()> { + todo!("implement virtual table registration"); + } +} + #[derive(Debug, Clone, PartialEq)] pub enum ExtFunc { #[cfg(feature = "uuid")] @@ -31,7 +62,7 @@ impl ExtFunc { } } -pub fn init(db: &mut crate::Database) { - #[cfg(feature = "uuid")] - uuid::init(db); -} +//pub fn init(db: &mut crate::Database) { +// #[cfg(feature = "uuid")] +// uuid::init(db); +//} diff --git a/core/ext/uuid.rs b/core/ext/uuid.rs index 92fdd831a..37e496f00 100644 --- a/core/ext/uuid.rs +++ b/core/ext/uuid.rs @@ -136,9 +136,9 @@ fn uuid_to_unix(uuid: &[u8; 16]) -> u64 { | (uuid[5] as u64) } -pub fn init(db: &mut Database) { - db.define_scalar_function("uuid4", |_args| exec_uuid4()); -} +//pub fn init(db: &mut Database) { +// db.define_scalar_function("uuid4", |_args| exec_uuid4()); +//} #[cfg(test)] #[cfg(feature = "uuid")] diff --git a/core/function.rs b/core/function.rs index 060a677c3..3987b2585 100644 --- a/core/function.rs +++ b/core/function.rs @@ -1,11 +1,20 @@ use crate::ext::ExtFunc; use std::fmt; use std::fmt::{Debug, Display}; -use std::rc::Rc; +use std::sync::Arc; pub struct ExternalFunc { pub name: String, - pub func: Box crate::Result>, + pub func: Arc, +} + +impl ExternalFunc { + pub fn new(name: &str, func: Arc) -> Self { + Self { + name: name.to_string(), + func, + } + } } impl Debug for ExternalFunc { @@ -300,7 +309,7 @@ pub enum Func { #[cfg(feature = "json")] Json(JsonFunc), Extension(ExtFunc), - External(Rc), + External(Arc), } impl Display for Func { diff --git a/core/lib.rs b/core/lib.rs index a80fab83a..21b7cf102 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -17,7 +17,9 @@ mod vdbe; #[global_allocator] static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; +use extension_api::{Extension, ExtensionApi}; use fallible_iterator::FallibleIterator; +use libloading::{Library, Symbol}; use log::trace; use schema::Schema; use sqlite3_parser::ast; @@ -34,12 +36,11 @@ use storage::pager::allocate_page; use storage::sqlite3_ondisk::{DatabaseHeader, DATABASE_HEADER_SIZE}; pub use storage::wal::WalFile; pub use storage::wal::WalFileShared; +pub use types::Value; use util::parse_schema_rows; -use translate::select::prepare_select_plan; -use types::OwnedValue; - pub use error::LimboError; +use translate::select::prepare_select_plan; pub type Result = std::result::Result; use crate::translate::optimizer::optimize_plan; @@ -56,8 +57,6 @@ pub use storage::pager::Page; pub use storage::pager::Pager; pub use storage::wal::CheckpointStatus; pub use storage::wal::Wal; -pub use types::Value; - pub static DATABASE_VERSION: OnceLock = OnceLock::new(); #[derive(Clone)] @@ -135,11 +134,11 @@ impl Database { _shared_wal: shared_wal.clone(), syms, }; - ext::init(&mut db); + // ext::init(&mut db); let db = Arc::new(db); let conn = Rc::new(Connection { db: db.clone(), - pager: pager, + pager, schema: schema.clone(), header, transaction_state: RefCell::new(TransactionState::None), @@ -169,16 +168,31 @@ impl Database { pub fn define_scalar_function>( &self, name: S, - func: impl Fn(&[Value]) -> Result + 'static, + func: Arc, ) { let func = function::ExternalFunc { name: name.as_ref().to_string(), - func: Box::new(func), + func: func.clone(), }; self.syms .borrow_mut() .functions - .insert(name.as_ref().to_string(), Rc::new(func)); + .insert(name.as_ref().to_string(), Arc::new(func)); + } + + pub fn load_extension(&self, path: &str) -> Result<()> { + let lib = + unsafe { Library::new(path).map_err(|e| LimboError::ExtensionError(e.to_string()))? }; + unsafe { + let register: Symbol Box> = + lib.get(b"register_extension") + .map_err(|e| LimboError::ExtensionError(e.to_string()))?; + let extension = register(self); + extension + .load() + .map_err(|e| LimboError::ExtensionError(e.to_string()))?; + } + Ok(()) } } @@ -372,6 +386,10 @@ impl Connection { Ok(()) } + pub fn load_extension(&self, path: &str) -> Result<()> { + Database::load_extension(self.db.as_ref(), path) + } + /// Close a connection and checkpoint. pub fn close(&self) -> Result<()> { loop { @@ -468,15 +486,24 @@ impl Rows { } } -#[derive(Debug)] pub(crate) struct SymbolTable { - pub functions: HashMap>, + pub functions: HashMap>, + extensions: Vec>, +} + +impl std::fmt::Debug for SymbolTable { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SymbolTable") + .field("functions", &self.functions) + .finish() + } } impl SymbolTable { pub fn new() -> Self { Self { functions: HashMap::new(), + extensions: Vec::new(), } } @@ -484,7 +511,7 @@ impl SymbolTable { &self, name: &str, _arg_count: usize, - ) -> Option> { + ) -> Option> { self.functions.get(name).cloned() } } diff --git a/core/types.rs b/core/types.rs index d9a496bfb..3d3756799 100644 --- a/core/types.rs +++ b/core/types.rs @@ -1,8 +1,8 @@ -use std::fmt::Display; -use std::rc::Rc; - use crate::error::LimboError; use crate::Result; +use extension_api::Value as ExtValue; +use std::fmt::Display; +use std::rc::Rc; use crate::storage::sqlite3_ondisk::write_varint; @@ -15,6 +15,45 @@ pub enum Value<'a> { Blob(&'a Vec), } +impl From<&OwnedValue> for extension_api::Value { + fn from(value: &OwnedValue) -> Self { + match value { + OwnedValue::Null => extension_api::Value::Null, + OwnedValue::Integer(i) => extension_api::Value::Integer(*i), + OwnedValue::Float(f) => extension_api::Value::Float(*f), + OwnedValue::Text(text) => extension_api::Value::Text(text.value.to_string()), + OwnedValue::Blob(blob) => extension_api::Value::Blob(blob.to_vec()), + OwnedValue::Agg(_) => { + panic!("Cannot convert Aggregate context to extension_api::Value") + } // Handle appropriately + OwnedValue::Record(_) => panic!("Cannot convert Record to extension_api::Value"), // Handle appropriately + } + } +} +impl From for OwnedValue { + fn from(value: ExtValue) -> Self { + match value { + ExtValue::Null => OwnedValue::Null, + ExtValue::Integer(i) => OwnedValue::Integer(i), + ExtValue::Float(f) => OwnedValue::Float(f), + ExtValue::Text(text) => OwnedValue::Text(LimboText::new(Rc::new(text.to_string()))), + ExtValue::Blob(blob) => OwnedValue::Blob(Rc::new(blob.to_vec())), + } + } +} + +impl<'a> From<&'a crate::Value<'a>> for ExtValue { + fn from(value: &'a crate::Value<'a>) -> Self { + match value { + crate::Value::Null => extension_api::Value::Null, + crate::Value::Integer(i) => extension_api::Value::Integer(*i), + crate::Value::Float(f) => extension_api::Value::Float(*f), + crate::Value::Text(t) => extension_api::Value::Text(t.to_string()), + crate::Value::Blob(b) => extension_api::Value::Blob(b.to_vec()), + } + } +} + impl Display for Value<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index 28db889cf..f6903619b 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -1872,8 +1872,13 @@ impl Program { _ => unreachable!(), // when more extension types are added }, crate::function::Func::External(f) => { - let result = (f.func)(&[])?; - state.registers[*dest] = result; + let values = &state.registers[*start_reg..*start_reg + arg_count]; + let args: Vec<_> = values.into_iter().map(|v| v.into()).collect(); + let result = f + .func + .execute(args.as_slice()) + .map_err(|e| LimboError::ExtensionError(e.to_string()))?; + state.registers[*dest] = result.into(); } crate::function::Func::Math(math_func) => match math_func.arity() { MathFuncArity::Nullary => match math_func { diff --git a/extension_api/Cargo.toml b/extension_api/Cargo.toml new file mode 100644 index 000000000..73056af33 --- /dev/null +++ b/extension_api/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "extension_api" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true + +[dependencies] diff --git a/extension_api/src/lib.rs b/extension_api/src/lib.rs new file mode 100644 index 000000000..7ee26e329 --- /dev/null +++ b/extension_api/src/lib.rs @@ -0,0 +1,75 @@ +use std::any::Any; +use std::rc::Rc; +use std::sync::Arc; + +pub type Result = std::result::Result; + +pub trait Extension { + fn load(&self) -> Result<()>; +} + +#[derive(Debug)] +pub enum LimboApiError { + ConnectionError(String), + RegisterFunctionError(String), + ValueError(String), + VTableError(String), +} + +impl From for LimboApiError { + fn from(e: std::io::Error) -> Self { + Self::ConnectionError(e.to_string()) + } +} + +impl std::fmt::Display for LimboApiError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Self::ConnectionError(e) => write!(f, "Connection error: {e}"), + Self::RegisterFunctionError(e) => write!(f, "Register function error: {e}"), + Self::ValueError(e) => write!(f, "Value error: {e}"), + Self::VTableError(e) => write!(f, "VTable error: {e}"), + } + } +} + +pub trait ExtensionApi { + fn register_scalar_function(&self, name: &str, func: Arc) -> Result<()>; + fn register_aggregate_function( + &self, + name: &str, + func: Arc, + ) -> Result<()>; + fn register_virtual_table(&self, name: &str, table: Arc) -> Result<()>; +} + +pub trait ScalarFunction { + fn execute(&self, args: &[Value]) -> Result; +} + +pub trait AggregateFunction { + fn init(&self) -> Box; + fn step(&self, state: &mut dyn Any, args: &[Value]) -> Result<()>; + fn finalize(&self, state: Box) -> Result; +} + +pub trait VirtualTable { + fn schema(&self) -> &'static str; + fn create_cursor(&self) -> Box; +} + +pub trait Cursor { + fn next(&mut self) -> Result>; +} + +pub struct Row { + pub values: Vec, +} + +pub enum Value { + Text(String), + Blob(Vec), + Integer(i64), + Float(f64), + Null, +} diff --git a/extensions/uuid/Cargo.toml b/extensions/uuid/Cargo.toml new file mode 100644 index 000000000..c6ae90bdf --- /dev/null +++ b/extensions/uuid/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "uuid" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true + +[dependencies] +extension_api = { path = "../../extension_api"} +uuid = { version = "1.11.0", features = ["v4", "v7"] } diff --git a/extensions/uuid/src/lib.rs b/extensions/uuid/src/lib.rs new file mode 100644 index 000000000..e69de29bb