From 0a10d893d9de12ca206ef810e216cce82f5dd820 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Wed, 8 Jan 2025 23:16:57 -0500 Subject: [PATCH 1/9] Sketch out runtime extension loading --- Cargo.lock | 30 +++++++++++++-- Cargo.toml | 2 +- cli/app.rs | 26 ++++++++++++- core/Cargo.toml | 2 + core/error.rs | 2 + core/ext/mod.rs | 39 ++++++++++++++++++-- core/ext/uuid.rs | 6 +-- core/function.rs | 15 ++++++-- core/lib.rs | 53 ++++++++++++++++++++------- core/types.rs | 45 +++++++++++++++++++++-- core/vdbe/mod.rs | 9 ++++- extension_api/Cargo.toml | 9 +++++ extension_api/src/lib.rs | 75 ++++++++++++++++++++++++++++++++++++++ extensions/uuid/Cargo.toml | 11 ++++++ extensions/uuid/src/lib.rs | 0 15 files changed, 291 insertions(+), 33 deletions(-) create mode 100644 extension_api/Cargo.toml create mode 100644 extension_api/src/lib.rs create mode 100644 extensions/uuid/Cargo.toml create mode 100644 extensions/uuid/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index bb69e5c0f..b3f3bb7e0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -564,7 +564,7 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef552e6f588e446098f6ba40d89ac146c8c7b64aade83c051ee00bb5d2bc18d" dependencies = [ - "uuid", + "uuid 1.11.0", ] [[package]] @@ -694,6 +694,10 @@ dependencies = [ "str-buf", ] +[[package]] +name = "extension_api" +version = "0.0.11" + [[package]] name = "fallible-iterator" version = "0.2.0" @@ -1137,6 +1141,16 @@ version = "0.2.169" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" +[[package]] +name = "libloading" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" +dependencies = [ + "cfg-if", + "windows-targets 0.52.6", +] + [[package]] name = "libmimalloc-sys" version = "0.1.39" @@ -1204,6 +1218,7 @@ dependencies = [ "cfg_block", "chrono", "criterion", + "extension_api", "fallible-iterator 0.3.0", "getrandom", "hex", @@ -1212,6 +1227,7 @@ dependencies = [ "jsonb", "julian_day_converter", "libc", + "libloading", "limbo_macros", "log", "miette", @@ -1232,7 +1248,7 @@ dependencies = [ "sqlite3-parser", "tempfile", "thiserror 1.0.69", - "uuid", + "uuid 1.11.0", ] [[package]] @@ -2260,7 +2276,7 @@ dependencies = [ "debugid", "memmap2", "stable_deref_trait", - "uuid", + "uuid 1.11.0", ] [[package]] @@ -2502,6 +2518,14 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" +[[package]] +name = "uuid" +version = "0.0.11" +dependencies = [ + "extension_api", + "uuid 1.11.0", +] + [[package]] name = "uuid" version = "1.11.0" diff --git a/Cargo.toml b/Cargo.toml index 0d2fc81be..40f5b2bd4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,7 @@ members = [ "sqlite3", "core", "simulator", - "test", "macros", + "test", "macros", "extension_api", "extensions/uuid", ] exclude = ["perf/latency/limbo"] diff --git a/cli/app.rs b/cli/app.rs index f114b63a2..325f6fe19 100644 --- a/cli/app.rs +++ b/cli/app.rs @@ -129,6 +129,8 @@ pub enum Command { Tables, /// Import data from FILE into TABLE Import, + /// Loads an extension library + LoadExtension, } impl Command { @@ -141,7 +143,12 @@ impl Command { | Self::ShowInfo | Self::Tables | Self::SetOutput => 0, - Self::Open | Self::OutputMode | Self::Cwd | Self::Echo | Self::NullValue => 1, + Self::Open + | Self::OutputMode + | Self::Cwd + | Self::Echo + | Self::NullValue + | Self::LoadExtension => 1, Self::Import => 2, } + 1) // argv0 } @@ -160,6 +167,7 @@ impl Command { Self::NullValue => ".nullvalue ", Self::Echo => ".echo on|off", Self::Tables => ".tables", + Self::LoadExtension => ".load", Self::Import => &IMPORT_HELP, } } @@ -182,6 +190,7 @@ impl FromStr for Command { ".nullvalue" => Ok(Self::NullValue), ".echo" => Ok(Self::Echo), ".import" => Ok(Self::Import), + ".load" => Ok(Self::LoadExtension), _ => Err("Unknown command".to_string()), } } @@ -314,6 +323,16 @@ impl Limbo { }; } + fn handle_load_extension(&mut self) -> Result<(), String> { + let mut args = self.input_buff.split_whitespace(); + let _ = args.next(); + let lib = args + .next() + .ok_or("No library specified") + .map_err(|e| e.to_string())?; + self.conn.load_extension(lib).map_err(|e| e.to_string()) + } + fn display_in_memory(&mut self) -> std::io::Result<()> { if self.opts.db_file == ":memory:" { self.writeln("Connected to a transient in-memory database.")?; @@ -537,6 +556,11 @@ impl Limbo { let _ = self.writeln(e.to_string()); }; } + Command::LoadExtension => { + if let Err(e) = self.handle_load_extension() { + let _ = self.writeln(e.to_string()); + } + } } } else { let _ = self.write_fmt(format_args!( diff --git a/core/Cargo.toml b/core/Cargo.toml index 0daa58c0d..2bbc1347d 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -35,6 +35,7 @@ rustix = "0.38.34" mimalloc = { version = "*", default-features = false } [dependencies] +extension_api = { path = "../extension_api" } cfg_block = "0.1.1" fallible-iterator = "0.3.0" hex = "0.4.3" @@ -58,6 +59,7 @@ bumpalo = { version = "3.16.0", features = ["collections", "boxed"] } limbo_macros = { path = "../macros" } uuid = { version = "1.11.0", features = ["v4", "v7"], optional = true } miette = "7.4.0" +libloading = "0.8.6" [target.'cfg(not(target_family = "windows"))'.dev-dependencies] pprof = { version = "0.14.0", features = ["criterion", "flamegraph"] } diff --git a/core/error.rs b/core/error.rs index 646e85825..e3e176b79 100644 --- a/core/error.rs +++ b/core/error.rs @@ -39,6 +39,8 @@ pub enum LimboError { InvalidModifier(String), #[error("Runtime error: {0}")] Constraint(String), + #[error("Extension error: {0}")] + ExtensionError(String), } #[macro_export] diff --git a/core/ext/mod.rs b/core/ext/mod.rs index cea65a98d..c1718dcc8 100644 --- a/core/ext/mod.rs +++ b/core/ext/mod.rs @@ -1,8 +1,39 @@ #[cfg(feature = "uuid")] mod uuid; +use crate::{function::ExternalFunc, Database}; +use std::sync::Arc; + +use extension_api::{AggregateFunction, ExtensionApi, Result, ScalarFunction, VirtualTable}; #[cfg(feature = "uuid")] pub use uuid::{exec_ts_from_uuid7, exec_uuid, exec_uuidblob, exec_uuidstr, UuidFunc}; +impl ExtensionApi for Database { + fn register_scalar_function( + &self, + name: &str, + func: Arc, + ) -> extension_api::Result<()> { + let ext_func = ExternalFunc::new(name, func.clone()); + self.syms + .borrow_mut() + .functions + .insert(name.to_string(), Arc::new(ext_func)); + Ok(()) + } + + fn register_aggregate_function( + &self, + _name: &str, + _func: Arc, + ) -> Result<()> { + todo!("implement aggregate function registration"); + } + + fn register_virtual_table(&self, _name: &str, _table: Arc) -> Result<()> { + todo!("implement virtual table registration"); + } +} + #[derive(Debug, Clone, PartialEq)] pub enum ExtFunc { #[cfg(feature = "uuid")] @@ -31,7 +62,7 @@ impl ExtFunc { } } -pub fn init(db: &mut crate::Database) { - #[cfg(feature = "uuid")] - uuid::init(db); -} +//pub fn init(db: &mut crate::Database) { +// #[cfg(feature = "uuid")] +// uuid::init(db); +//} diff --git a/core/ext/uuid.rs b/core/ext/uuid.rs index 92fdd831a..37e496f00 100644 --- a/core/ext/uuid.rs +++ b/core/ext/uuid.rs @@ -136,9 +136,9 @@ fn uuid_to_unix(uuid: &[u8; 16]) -> u64 { | (uuid[5] as u64) } -pub fn init(db: &mut Database) { - db.define_scalar_function("uuid4", |_args| exec_uuid4()); -} +//pub fn init(db: &mut Database) { +// db.define_scalar_function("uuid4", |_args| exec_uuid4()); +//} #[cfg(test)] #[cfg(feature = "uuid")] diff --git a/core/function.rs b/core/function.rs index 060a677c3..3987b2585 100644 --- a/core/function.rs +++ b/core/function.rs @@ -1,11 +1,20 @@ use crate::ext::ExtFunc; use std::fmt; use std::fmt::{Debug, Display}; -use std::rc::Rc; +use std::sync::Arc; pub struct ExternalFunc { pub name: String, - pub func: Box crate::Result>, + pub func: Arc, +} + +impl ExternalFunc { + pub fn new(name: &str, func: Arc) -> Self { + Self { + name: name.to_string(), + func, + } + } } impl Debug for ExternalFunc { @@ -300,7 +309,7 @@ pub enum Func { #[cfg(feature = "json")] Json(JsonFunc), Extension(ExtFunc), - External(Rc), + External(Arc), } impl Display for Func { diff --git a/core/lib.rs b/core/lib.rs index a80fab83a..21b7cf102 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -17,7 +17,9 @@ mod vdbe; #[global_allocator] static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; +use extension_api::{Extension, ExtensionApi}; use fallible_iterator::FallibleIterator; +use libloading::{Library, Symbol}; use log::trace; use schema::Schema; use sqlite3_parser::ast; @@ -34,12 +36,11 @@ use storage::pager::allocate_page; use storage::sqlite3_ondisk::{DatabaseHeader, DATABASE_HEADER_SIZE}; pub use storage::wal::WalFile; pub use storage::wal::WalFileShared; +pub use types::Value; use util::parse_schema_rows; -use translate::select::prepare_select_plan; -use types::OwnedValue; - pub use error::LimboError; +use translate::select::prepare_select_plan; pub type Result = std::result::Result; use crate::translate::optimizer::optimize_plan; @@ -56,8 +57,6 @@ pub use storage::pager::Page; pub use storage::pager::Pager; pub use storage::wal::CheckpointStatus; pub use storage::wal::Wal; -pub use types::Value; - pub static DATABASE_VERSION: OnceLock = OnceLock::new(); #[derive(Clone)] @@ -135,11 +134,11 @@ impl Database { _shared_wal: shared_wal.clone(), syms, }; - ext::init(&mut db); + // ext::init(&mut db); let db = Arc::new(db); let conn = Rc::new(Connection { db: db.clone(), - pager: pager, + pager, schema: schema.clone(), header, transaction_state: RefCell::new(TransactionState::None), @@ -169,16 +168,31 @@ impl Database { pub fn define_scalar_function>( &self, name: S, - func: impl Fn(&[Value]) -> Result + 'static, + func: Arc, ) { let func = function::ExternalFunc { name: name.as_ref().to_string(), - func: Box::new(func), + func: func.clone(), }; self.syms .borrow_mut() .functions - .insert(name.as_ref().to_string(), Rc::new(func)); + .insert(name.as_ref().to_string(), Arc::new(func)); + } + + pub fn load_extension(&self, path: &str) -> Result<()> { + let lib = + unsafe { Library::new(path).map_err(|e| LimboError::ExtensionError(e.to_string()))? }; + unsafe { + let register: Symbol Box> = + lib.get(b"register_extension") + .map_err(|e| LimboError::ExtensionError(e.to_string()))?; + let extension = register(self); + extension + .load() + .map_err(|e| LimboError::ExtensionError(e.to_string()))?; + } + Ok(()) } } @@ -372,6 +386,10 @@ impl Connection { Ok(()) } + pub fn load_extension(&self, path: &str) -> Result<()> { + Database::load_extension(self.db.as_ref(), path) + } + /// Close a connection and checkpoint. pub fn close(&self) -> Result<()> { loop { @@ -468,15 +486,24 @@ impl Rows { } } -#[derive(Debug)] pub(crate) struct SymbolTable { - pub functions: HashMap>, + pub functions: HashMap>, + extensions: Vec>, +} + +impl std::fmt::Debug for SymbolTable { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SymbolTable") + .field("functions", &self.functions) + .finish() + } } impl SymbolTable { pub fn new() -> Self { Self { functions: HashMap::new(), + extensions: Vec::new(), } } @@ -484,7 +511,7 @@ impl SymbolTable { &self, name: &str, _arg_count: usize, - ) -> Option> { + ) -> Option> { self.functions.get(name).cloned() } } diff --git a/core/types.rs b/core/types.rs index d9a496bfb..3d3756799 100644 --- a/core/types.rs +++ b/core/types.rs @@ -1,8 +1,8 @@ -use std::fmt::Display; -use std::rc::Rc; - use crate::error::LimboError; use crate::Result; +use extension_api::Value as ExtValue; +use std::fmt::Display; +use std::rc::Rc; use crate::storage::sqlite3_ondisk::write_varint; @@ -15,6 +15,45 @@ pub enum Value<'a> { Blob(&'a Vec), } +impl From<&OwnedValue> for extension_api::Value { + fn from(value: &OwnedValue) -> Self { + match value { + OwnedValue::Null => extension_api::Value::Null, + OwnedValue::Integer(i) => extension_api::Value::Integer(*i), + OwnedValue::Float(f) => extension_api::Value::Float(*f), + OwnedValue::Text(text) => extension_api::Value::Text(text.value.to_string()), + OwnedValue::Blob(blob) => extension_api::Value::Blob(blob.to_vec()), + OwnedValue::Agg(_) => { + panic!("Cannot convert Aggregate context to extension_api::Value") + } // Handle appropriately + OwnedValue::Record(_) => panic!("Cannot convert Record to extension_api::Value"), // Handle appropriately + } + } +} +impl From for OwnedValue { + fn from(value: ExtValue) -> Self { + match value { + ExtValue::Null => OwnedValue::Null, + ExtValue::Integer(i) => OwnedValue::Integer(i), + ExtValue::Float(f) => OwnedValue::Float(f), + ExtValue::Text(text) => OwnedValue::Text(LimboText::new(Rc::new(text.to_string()))), + ExtValue::Blob(blob) => OwnedValue::Blob(Rc::new(blob.to_vec())), + } + } +} + +impl<'a> From<&'a crate::Value<'a>> for ExtValue { + fn from(value: &'a crate::Value<'a>) -> Self { + match value { + crate::Value::Null => extension_api::Value::Null, + crate::Value::Integer(i) => extension_api::Value::Integer(*i), + crate::Value::Float(f) => extension_api::Value::Float(*f), + crate::Value::Text(t) => extension_api::Value::Text(t.to_string()), + crate::Value::Blob(b) => extension_api::Value::Blob(b.to_vec()), + } + } +} + impl Display for Value<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index 28db889cf..f6903619b 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -1872,8 +1872,13 @@ impl Program { _ => unreachable!(), // when more extension types are added }, crate::function::Func::External(f) => { - let result = (f.func)(&[])?; - state.registers[*dest] = result; + let values = &state.registers[*start_reg..*start_reg + arg_count]; + let args: Vec<_> = values.into_iter().map(|v| v.into()).collect(); + let result = f + .func + .execute(args.as_slice()) + .map_err(|e| LimboError::ExtensionError(e.to_string()))?; + state.registers[*dest] = result.into(); } crate::function::Func::Math(math_func) => match math_func.arity() { MathFuncArity::Nullary => match math_func { diff --git a/extension_api/Cargo.toml b/extension_api/Cargo.toml new file mode 100644 index 000000000..73056af33 --- /dev/null +++ b/extension_api/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "extension_api" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true + +[dependencies] diff --git a/extension_api/src/lib.rs b/extension_api/src/lib.rs new file mode 100644 index 000000000..7ee26e329 --- /dev/null +++ b/extension_api/src/lib.rs @@ -0,0 +1,75 @@ +use std::any::Any; +use std::rc::Rc; +use std::sync::Arc; + +pub type Result = std::result::Result; + +pub trait Extension { + fn load(&self) -> Result<()>; +} + +#[derive(Debug)] +pub enum LimboApiError { + ConnectionError(String), + RegisterFunctionError(String), + ValueError(String), + VTableError(String), +} + +impl From for LimboApiError { + fn from(e: std::io::Error) -> Self { + Self::ConnectionError(e.to_string()) + } +} + +impl std::fmt::Display for LimboApiError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Self::ConnectionError(e) => write!(f, "Connection error: {e}"), + Self::RegisterFunctionError(e) => write!(f, "Register function error: {e}"), + Self::ValueError(e) => write!(f, "Value error: {e}"), + Self::VTableError(e) => write!(f, "VTable error: {e}"), + } + } +} + +pub trait ExtensionApi { + fn register_scalar_function(&self, name: &str, func: Arc) -> Result<()>; + fn register_aggregate_function( + &self, + name: &str, + func: Arc, + ) -> Result<()>; + fn register_virtual_table(&self, name: &str, table: Arc) -> Result<()>; +} + +pub trait ScalarFunction { + fn execute(&self, args: &[Value]) -> Result; +} + +pub trait AggregateFunction { + fn init(&self) -> Box; + fn step(&self, state: &mut dyn Any, args: &[Value]) -> Result<()>; + fn finalize(&self, state: Box) -> Result; +} + +pub trait VirtualTable { + fn schema(&self) -> &'static str; + fn create_cursor(&self) -> Box; +} + +pub trait Cursor { + fn next(&mut self) -> Result>; +} + +pub struct Row { + pub values: Vec, +} + +pub enum Value { + Text(String), + Blob(Vec), + Integer(i64), + Float(f64), + Null, +} diff --git a/extensions/uuid/Cargo.toml b/extensions/uuid/Cargo.toml new file mode 100644 index 000000000..c6ae90bdf --- /dev/null +++ b/extensions/uuid/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "uuid" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true + +[dependencies] +extension_api = { path = "../../extension_api"} +uuid = { version = "1.11.0", features = ["v4", "v7"] } diff --git a/extensions/uuid/src/lib.rs b/extensions/uuid/src/lib.rs new file mode 100644 index 000000000..e69de29bb From 3412a3d4c26d0d8a5fc8d2b5d2fd1a929f34bc7a Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Sat, 11 Jan 2025 22:48:06 -0500 Subject: [PATCH 2/9] Rough design for extension api/draft extension --- Cargo.lock | 33 +-- Cargo.toml | 2 +- cli/app.rs | 14 +- core/Cargo.toml | 2 +- core/ext/mod.rs | 93 +++---- core/function.rs | 18 +- core/lib.rs | 45 ++-- core/translate/emitter.rs | 9 +- core/translate/expr.rs | 56 ----- core/translate/planner.rs | 19 +- core/translate/select.rs | 30 ++- core/types.rs | 79 +++--- core/vdbe/mod.rs | 57 ++--- extension_api/src/lib.rs | 75 ------ extensions/uuid/Cargo.toml | 8 +- extensions/uuid/src/lib.rs | 62 +++++ {extension_api => limbo_extension}/Cargo.toml | 2 +- limbo_extension/src/lib.rs | 233 ++++++++++++++++++ 18 files changed, 489 insertions(+), 348 deletions(-) delete mode 100644 extension_api/src/lib.rs rename {extension_api => limbo_extension}/Cargo.toml (86%) create mode 100644 limbo_extension/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index b3f3bb7e0..eb09669a3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -564,7 +564,7 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef552e6f588e446098f6ba40d89ac146c8c7b64aade83c051ee00bb5d2bc18d" dependencies = [ - "uuid 1.11.0", + "uuid", ] [[package]] @@ -694,10 +694,6 @@ dependencies = [ "str-buf", ] -[[package]] -name = "extension_api" -version = "0.0.11" - [[package]] name = "fallible-iterator" version = "0.2.0" @@ -1218,7 +1214,6 @@ dependencies = [ "cfg_block", "chrono", "criterion", - "extension_api", "fallible-iterator 0.3.0", "getrandom", "hex", @@ -1228,6 +1223,7 @@ dependencies = [ "julian_day_converter", "libc", "libloading", + "limbo_extension", "limbo_macros", "log", "miette", @@ -1248,10 +1244,11 @@ dependencies = [ "sqlite3-parser", "tempfile", "thiserror 1.0.69", - "uuid 1.11.0", + "uuid", ] [[package]] +<<<<<<< HEAD name = "limbo_libsql" version = "0.0.12" dependencies = [ @@ -1260,6 +1257,10 @@ dependencies = [ "tokio", ] +[[package]] +name = "limbo_extension" +version = "0.0.11" + [[package]] name = "limbo_macros" version = "0.0.12" @@ -1288,6 +1289,14 @@ dependencies = [ "log", ] +[[package]] +name = "limbo_uuid" +version = "0.0.11" +dependencies = [ + "limbo_extension", + "uuid", +] + [[package]] name = "linux-raw-sys" version = "0.4.14" @@ -2276,7 +2285,7 @@ dependencies = [ "debugid", "memmap2", "stable_deref_trait", - "uuid 1.11.0", + "uuid", ] [[package]] @@ -2518,14 +2527,6 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" -[[package]] -name = "uuid" -version = "0.0.11" -dependencies = [ - "extension_api", - "uuid 1.11.0", -] - [[package]] name = "uuid" version = "1.11.0" diff --git a/Cargo.toml b/Cargo.toml index 40f5b2bd4..44ec1ef15 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,7 @@ members = [ "sqlite3", "core", "simulator", - "test", "macros", "extension_api", "extensions/uuid", + "test", "macros", "limbo_extension", "extensions/uuid", ] exclude = ["perf/latency/limbo"] diff --git a/cli/app.rs b/cli/app.rs index 325f6fe19..62108ca6d 100644 --- a/cli/app.rs +++ b/cli/app.rs @@ -323,14 +323,8 @@ impl Limbo { }; } - fn handle_load_extension(&mut self) -> Result<(), String> { - let mut args = self.input_buff.split_whitespace(); - let _ = args.next(); - let lib = args - .next() - .ok_or("No library specified") - .map_err(|e| e.to_string())?; - self.conn.load_extension(lib).map_err(|e| e.to_string()) + fn handle_load_extension(&mut self, path: &str) -> Result<(), String> { + self.conn.load_extension(path).map_err(|e| e.to_string()) } fn display_in_memory(&mut self) -> std::io::Result<()> { @@ -557,8 +551,8 @@ impl Limbo { }; } Command::LoadExtension => { - if let Err(e) = self.handle_load_extension() { - let _ = self.writeln(e.to_string()); + if let Err(e) = self.handle_load_extension(args[1]) { + let _ = self.writeln(&e); } } } diff --git a/core/Cargo.toml b/core/Cargo.toml index 2bbc1347d..c6042a298 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -35,7 +35,7 @@ rustix = "0.38.34" mimalloc = { version = "*", default-features = false } [dependencies] -extension_api = { path = "../extension_api" } +limbo_extension = { path = "../limbo_extension" } cfg_block = "0.1.1" fallible-iterator = "0.3.0" hex = "0.4.3" diff --git a/core/ext/mod.rs b/core/ext/mod.rs index c1718dcc8..79936079e 100644 --- a/core/ext/mod.rs +++ b/core/ext/mod.rs @@ -1,68 +1,41 @@ -#[cfg(feature = "uuid")] -mod uuid; use crate::{function::ExternalFunc, Database}; -use std::sync::Arc; +use limbo_extension::{ExtensionApi, ResultCode, ScalarFunction, RESULT_ERROR, RESULT_OK}; +pub use limbo_extension::{Value as ExtValue, ValueType as ExtValueType}; +use std::{ + ffi::{c_char, c_void, CStr}, + rc::Rc, +}; -use extension_api::{AggregateFunction, ExtensionApi, Result, ScalarFunction, VirtualTable}; -#[cfg(feature = "uuid")] -pub use uuid::{exec_ts_from_uuid7, exec_uuid, exec_uuidblob, exec_uuidstr, UuidFunc}; - -impl ExtensionApi for Database { - fn register_scalar_function( - &self, - name: &str, - func: Arc, - ) -> extension_api::Result<()> { - let ext_func = ExternalFunc::new(name, func.clone()); - self.syms - .borrow_mut() - .functions - .insert(name.to_string(), Arc::new(ext_func)); - Ok(()) - } - - fn register_aggregate_function( - &self, - _name: &str, - _func: Arc, - ) -> Result<()> { - todo!("implement aggregate function registration"); - } - - fn register_virtual_table(&self, _name: &str, _table: Arc) -> Result<()> { - todo!("implement virtual table registration"); - } +extern "C" fn register_scalar_function( + ctx: *mut c_void, + name: *const c_char, + func: ScalarFunction, +) -> ResultCode { + let c_str = unsafe { CStr::from_ptr(name) }; + let name_str = match c_str.to_str() { + Ok(s) => s.to_string(), + Err(_) => return RESULT_ERROR, + }; + let db = unsafe { &*(ctx as *const Database) }; + db.register_scalar_function_impl(name_str, func) } -#[derive(Debug, Clone, PartialEq)] -pub enum ExtFunc { - #[cfg(feature = "uuid")] - Uuid(UuidFunc), -} +impl Database { + fn register_scalar_function_impl(&self, name: String, func: ScalarFunction) -> ResultCode { + self.syms.borrow_mut().functions.insert( + name.to_string(), + Rc::new(ExternalFunc { + name: name.to_string(), + func, + }), + ); + RESULT_OK + } -#[allow(unreachable_patterns)] // TODO: remove when more extension funcs added -impl std::fmt::Display for ExtFunc { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - #[cfg(feature = "uuid")] - Self::Uuid(uuidfn) => write!(f, "{}", uuidfn), - _ => write!(f, "unknown"), + pub fn build_limbo_extension(&self) -> ExtensionApi { + ExtensionApi { + ctx: self as *const _ as *mut c_void, + register_scalar_function, } } } - -#[allow(unreachable_patterns)] -impl ExtFunc { - pub fn resolve_function(name: &str, num_args: usize) -> Option { - match name { - #[cfg(feature = "uuid")] - name => UuidFunc::resolve_function(name, num_args), - _ => None, - } - } -} - -//pub fn init(db: &mut crate::Database) { -// #[cfg(feature = "uuid")] -// uuid::init(db); -//} diff --git a/core/function.rs b/core/function.rs index 3987b2585..68b5ef005 100644 --- a/core/function.rs +++ b/core/function.rs @@ -1,15 +1,16 @@ -use crate::ext::ExtFunc; use std::fmt; use std::fmt::{Debug, Display}; -use std::sync::Arc; +use std::rc::Rc; + +use limbo_extension::ScalarFunction; pub struct ExternalFunc { pub name: String, - pub func: Arc, + pub func: ScalarFunction, } impl ExternalFunc { - pub fn new(name: &str, func: Arc) -> Self { + pub fn new(name: &str, func: ScalarFunction) -> Self { Self { name: name.to_string(), func, @@ -308,8 +309,7 @@ pub enum Func { Math(MathFunc), #[cfg(feature = "json")] Json(JsonFunc), - Extension(ExtFunc), - External(Arc), + External(Rc), } impl Display for Func { @@ -320,7 +320,6 @@ impl Display for Func { Self::Math(math_func) => write!(f, "{}", math_func), #[cfg(feature = "json")] Self::Json(json_func) => write!(f, "{}", json_func), - Self::Extension(ext_func) => write!(f, "{}", ext_func), Self::External(generic_func) => write!(f, "{}", generic_func), } } @@ -427,10 +426,7 @@ impl Func { "tan" => Ok(Self::Math(MathFunc::Tan)), "tanh" => Ok(Self::Math(MathFunc::Tanh)), "trunc" => Ok(Self::Math(MathFunc::Trunc)), - _ => match ExtFunc::resolve_function(name, arg_count) { - Some(ext_func) => Ok(Self::Extension(ext_func)), - None => Err(()), - }, + _ => Err(()), } } } diff --git a/core/lib.rs b/core/lib.rs index 21b7cf102..4b3e8155b 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -17,9 +17,9 @@ mod vdbe; #[global_allocator] static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; -use extension_api::{Extension, ExtensionApi}; use fallible_iterator::FallibleIterator; use libloading::{Library, Symbol}; +use limbo_extension::{ExtensionApi, ExtensionEntryPoint, RESULT_OK}; use log::trace; use schema::Schema; use sqlite3_parser::ast; @@ -134,7 +134,6 @@ impl Database { _shared_wal: shared_wal.clone(), syms, }; - // ext::init(&mut db); let db = Arc::new(db); let conn = Rc::new(Connection { db: db.clone(), @@ -168,31 +167,37 @@ impl Database { pub fn define_scalar_function>( &self, name: S, - func: Arc, + func: limbo_extension::ScalarFunction, ) { let func = function::ExternalFunc { name: name.as_ref().to_string(), - func: func.clone(), + func, }; self.syms .borrow_mut() .functions - .insert(name.as_ref().to_string(), Arc::new(func)); + .insert(name.as_ref().to_string(), func.into()); } pub fn load_extension(&self, path: &str) -> Result<()> { + let api = Box::new(self.build_limbo_extension()); let lib = unsafe { Library::new(path).map_err(|e| LimboError::ExtensionError(e.to_string()))? }; - unsafe { - let register: Symbol Box> = - lib.get(b"register_extension") - .map_err(|e| LimboError::ExtensionError(e.to_string()))?; - let extension = register(self); - extension - .load() - .map_err(|e| LimboError::ExtensionError(e.to_string()))?; + let entry: Symbol = unsafe { + lib.get(b"register_extension") + .map_err(|e| LimboError::ExtensionError(e.to_string()))? + }; + let api_ptr: *const ExtensionApi = Box::into_raw(api); + let result_code = entry(api_ptr); + if result_code == RESULT_OK { + self.syms.borrow_mut().extensions.push((lib, api_ptr)); + Ok(()) + } else { + let _ = unsafe { Box::from_raw(api_ptr.cast_mut()) }; // own this again so we dont leak + Err(LimboError::ExtensionError( + "Extension registration failed".to_string(), + )) } - Ok(()) } } @@ -321,7 +326,11 @@ impl Connection { Cmd::ExplainQueryPlan(stmt) => { match stmt { ast::Stmt::Select(select) => { - let mut plan = prepare_select_plan(&self.schema.borrow(), *select)?; + let mut plan = prepare_select_plan( + &self.schema.borrow(), + *select, + &self.db.syms.borrow(), + )?; optimize_plan(&mut plan)?; println!("{}", plan); } @@ -487,8 +496,8 @@ impl Rows { } pub(crate) struct SymbolTable { - pub functions: HashMap>, - extensions: Vec>, + pub functions: HashMap>, + extensions: Vec<(libloading::Library, *const ExtensionApi)>, } impl std::fmt::Debug for SymbolTable { @@ -511,7 +520,7 @@ impl SymbolTable { &self, name: &str, _arg_count: usize, - ) -> Option> { + ) -> Option> { self.functions.get(name).cloned() } } diff --git a/core/translate/emitter.rs b/core/translate/emitter.rs index e376da160..aa1aaecf9 100644 --- a/core/translate/emitter.rs +++ b/core/translate/emitter.rs @@ -35,14 +35,13 @@ impl<'a> Resolver<'a> { } pub fn resolve_function(&self, func_name: &str, arg_count: usize) -> Option { - let func_type = match Func::resolve_function(&func_name, arg_count).ok() { + match Func::resolve_function(func_name, arg_count).ok() { Some(func) => Some(func), None => self .symbol_table - .resolve_function(&func_name, arg_count) - .map(|func| Func::External(func)), - }; - func_type + .resolve_function(func_name, arg_count) + .map(|arg| Func::External(arg.clone())), + } } pub fn resolve_cached_expr_reg(&self, expr: &ast::Expr) -> Option { diff --git a/core/translate/expr.rs b/core/translate/expr.rs index dd35a5d26..a464e5279 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -1,7 +1,5 @@ use sqlite3_parser::ast::{self, UnaryOperator}; -#[cfg(feature = "uuid")] -use crate::ext::{ExtFunc, UuidFunc}; #[cfg(feature = "json")] use crate::function::JsonFunc; use crate::function::{Func, FuncCtx, MathFuncArity, ScalarFunc}; @@ -1428,60 +1426,6 @@ pub fn translate_expr( } } } - Func::Extension(ext_func) => match ext_func { - #[cfg(feature = "uuid")] - ExtFunc::Uuid(ref uuid_fn) => match uuid_fn { - UuidFunc::UuidStr | UuidFunc::UuidBlob | UuidFunc::Uuid7TS => { - let args = expect_arguments_exact!(args, 1, ext_func); - let regs = program.alloc_register(); - translate_expr(program, referenced_tables, &args[0], regs, resolver)?; - program.emit_insn(Insn::Function { - constant_mask: 0, - start_reg: regs, - dest: target_register, - func: func_ctx, - }); - Ok(target_register) - } - UuidFunc::Uuid4Str => { - if args.is_some() { - crate::bail_parse_error!( - "{} function with arguments", - ext_func.to_string() - ); - } - let regs = program.alloc_register(); - program.emit_insn(Insn::Function { - constant_mask: 0, - start_reg: regs, - dest: target_register, - func: func_ctx, - }); - Ok(target_register) - } - UuidFunc::Uuid7 => { - let args = expect_arguments_max!(args, 1, ext_func); - let mut start_reg = None; - if let Some(arg) = args.first() { - start_reg = Some(translate_and_mark( - program, - referenced_tables, - arg, - resolver, - )?); - } - program.emit_insn(Insn::Function { - constant_mask: 0, - start_reg: start_reg.unwrap_or(target_register), - dest: target_register, - func: func_ctx, - }); - Ok(target_register) - } - }, - #[allow(unreachable_patterns)] - _ => unreachable!("{ext_func} not implemented yet"), - }, Func::Math(math_func) => match math_func.arity() { MathFuncArity::Nullary => { if args.is_some() { diff --git a/core/translate/planner.rs b/core/translate/planner.rs index 64a0ffb04..f5c835f8e 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -1,6 +1,7 @@ use super::{ plan::{Aggregate, Plan, SelectQueryType, SourceOperator, TableReference, TableReferenceType}, select::prepare_select_plan, + SymbolTable, }; use crate::{ function::Func, @@ -259,6 +260,7 @@ fn parse_from_clause_table( table: ast::SelectTable, operator_id_counter: &mut OperatorIdCounter, cur_table_index: usize, + syms: &SymbolTable, ) -> Result<(TableReference, SourceOperator)> { match table { ast::SelectTable::Table(qualified_name, maybe_alias, _) => { @@ -289,7 +291,7 @@ fn parse_from_clause_table( )) } ast::SelectTable::Select(subselect, maybe_alias) => { - let Plan::Select(mut subplan) = prepare_select_plan(schema, *subselect)? else { + let Plan::Select(mut subplan) = prepare_select_plan(schema, *subselect, syms)? else { unreachable!(); }; subplan.query_type = SelectQueryType::Subquery { @@ -322,6 +324,7 @@ pub fn parse_from( schema: &Schema, mut from: Option, operator_id_counter: &mut OperatorIdCounter, + syms: &SymbolTable, ) -> Result<(SourceOperator, Vec)> { if from.as_ref().and_then(|f| f.select.as_ref()).is_none() { return Ok(( @@ -339,7 +342,7 @@ pub fn parse_from( let select_owned = *std::mem::take(&mut from_owned.select).unwrap(); let joins_owned = std::mem::take(&mut from_owned.joins).unwrap_or_default(); let (table_reference, mut operator) = - parse_from_clause_table(schema, select_owned, operator_id_counter, table_index)?; + parse_from_clause_table(schema, select_owned, operator_id_counter, table_index, syms)?; tables.push(table_reference); table_index += 1; @@ -350,7 +353,14 @@ pub fn parse_from( is_outer_join: outer, using, predicates, - } = parse_join(schema, join, operator_id_counter, &mut tables, table_index)?; + } = parse_join( + schema, + join, + operator_id_counter, + &mut tables, + table_index, + syms, + )?; operator = SourceOperator::Join { left: Box::new(operator), right: Box::new(right), @@ -394,6 +404,7 @@ fn parse_join( operator_id_counter: &mut OperatorIdCounter, tables: &mut Vec, table_index: usize, + syms: &SymbolTable, ) -> Result { let ast::JoinedSelectTable { operator: join_operator, @@ -402,7 +413,7 @@ fn parse_join( } = join; let (table_reference, source_operator) = - parse_from_clause_table(schema, table, operator_id_counter, table_index)?; + parse_from_clause_table(schema, table, operator_id_counter, table_index, syms)?; tables.push(table_reference); diff --git a/core/translate/select.rs b/core/translate/select.rs index b1be01169..dfbd4c2fb 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -20,12 +20,15 @@ pub fn translate_select( select: ast::Select, syms: &SymbolTable, ) -> Result<()> { - let mut select_plan = prepare_select_plan(schema, select)?; optimize_plan(&mut select_plan)?; emit_program(program, select_plan, syms) } -pub fn prepare_select_plan(schema: &Schema, select: ast::Select) -> Result { +pub fn prepare_select_plan( + schema: &Schema, + select: ast::Select, + syms: &SymbolTable, +) -> Result { match *select.body.select { ast::OneSelect::Select { mut columns, @@ -42,7 +45,8 @@ pub fn prepare_select_plan(schema: &Schema, select: ast::Select) -> Result let mut operator_id_counter = OperatorIdCounter::new(); // Parse the FROM clause - let (source, referenced_tables) = parse_from(schema, from, &mut operator_id_counter)?; + let (source, referenced_tables) = + parse_from(schema, from, &mut operator_id_counter, syms)?; let mut plan = SelectPlan { source, @@ -142,7 +146,25 @@ pub fn prepare_select_plan(schema: &Schema, select: ast::Select) -> Result contains_aggregates, }); } - _ => {} + Err(_) => { + if syms.functions.contains_key(&name.0) { + // TODO: future extensions can be aggregate functions + log::debug!( + "Resolving {} function from symbol table", + name.0 + ); + plan.result_columns.push(ResultSetColumn { + name: get_name( + maybe_alias.as_ref(), + expr, + &plan.referenced_tables, + || format!("expr_{}", result_column_idx), + ), + expr: expr.clone(), + contains_aggregates: false, + }); + } + } } } ast::Expr::FunctionCallStar { diff --git a/core/types.rs b/core/types.rs index 3d3756799..b35d0dc98 100644 --- a/core/types.rs +++ b/core/types.rs @@ -1,11 +1,10 @@ use crate::error::LimboError; +use crate::ext::{ExtValue, ExtValueType}; +use crate::storage::sqlite3_ondisk::write_varint; use crate::Result; -use extension_api::Value as ExtValue; use std::fmt::Display; use std::rc::Rc; -use crate::storage::sqlite3_ondisk::write_varint; - #[derive(Debug, Clone, PartialEq)] pub enum Value<'a> { Null, @@ -15,45 +14,6 @@ pub enum Value<'a> { Blob(&'a Vec), } -impl From<&OwnedValue> for extension_api::Value { - fn from(value: &OwnedValue) -> Self { - match value { - OwnedValue::Null => extension_api::Value::Null, - OwnedValue::Integer(i) => extension_api::Value::Integer(*i), - OwnedValue::Float(f) => extension_api::Value::Float(*f), - OwnedValue::Text(text) => extension_api::Value::Text(text.value.to_string()), - OwnedValue::Blob(blob) => extension_api::Value::Blob(blob.to_vec()), - OwnedValue::Agg(_) => { - panic!("Cannot convert Aggregate context to extension_api::Value") - } // Handle appropriately - OwnedValue::Record(_) => panic!("Cannot convert Record to extension_api::Value"), // Handle appropriately - } - } -} -impl From for OwnedValue { - fn from(value: ExtValue) -> Self { - match value { - ExtValue::Null => OwnedValue::Null, - ExtValue::Integer(i) => OwnedValue::Integer(i), - ExtValue::Float(f) => OwnedValue::Float(f), - ExtValue::Text(text) => OwnedValue::Text(LimboText::new(Rc::new(text.to_string()))), - ExtValue::Blob(blob) => OwnedValue::Blob(Rc::new(blob.to_vec())), - } - } -} - -impl<'a> From<&'a crate::Value<'a>> for ExtValue { - fn from(value: &'a crate::Value<'a>) -> Self { - match value { - crate::Value::Null => extension_api::Value::Null, - crate::Value::Integer(i) => extension_api::Value::Integer(*i), - crate::Value::Float(f) => extension_api::Value::Float(*f), - crate::Value::Text(t) => extension_api::Value::Text(t.to_string()), - crate::Value::Blob(b) => extension_api::Value::Blob(b.to_vec()), - } - } -} - impl Display for Value<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -132,6 +92,41 @@ impl Display for OwnedValue { } } } +impl OwnedValue { + pub fn to_ffi(&self) -> ExtValue { + match self { + Self::Null => ExtValue::null(), + Self::Integer(i) => ExtValue::from_integer(*i), + Self::Float(fl) => ExtValue::from_float(*fl), + Self::Text(s) => ExtValue::from_text(s.value.to_string()), + Self::Blob(b) => ExtValue::from_blob(b), + Self::Agg(_) => todo!(), + Self::Record(_) => todo!(), + } + } + pub fn from_ffi(v: &ExtValue) -> Self { + match v.value_type { + ExtValueType::Null => OwnedValue::Null, + ExtValueType::Integer => OwnedValue::Integer(v.integer), + ExtValueType::Float => OwnedValue::Float(v.float), + ExtValueType::Text => { + if v.text.is_null() { + OwnedValue::Null + } else { + OwnedValue::build_text(std::rc::Rc::new(v.text.to_string())) + } + } + ExtValueType::Blob => { + if v.blob.data.is_null() { + OwnedValue::Null + } else { + let bytes = unsafe { std::slice::from_raw_parts(v.blob.data, v.blob.size) }; + OwnedValue::Blob(std::rc::Rc::new(bytes.to_vec())) + } + } + } + } +} #[derive(Debug, Clone, PartialEq)] pub enum AggContext { diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index f6903619b..075dd5dab 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -25,8 +25,7 @@ pub mod likeop; pub mod sorter; use crate::error::{LimboError, SQLITE_CONSTRAINT_PRIMARYKEY}; -#[cfg(feature = "uuid")] -use crate::ext::{exec_ts_from_uuid7, exec_uuid, exec_uuidblob, exec_uuidstr, ExtFunc, UuidFunc}; +use crate::ext::ExtValue; use crate::function::{AggFunc, FuncCtx, MathFunc, MathFuncArity, ScalarFunc}; use crate::pseudo::PseudoCursor; use crate::result::LimboResult; @@ -53,9 +52,10 @@ use rand::distributions::{Distribution, Uniform}; use rand::{thread_rng, Rng}; use regex::{Regex, RegexBuilder}; use sorter::Sorter; -use std::borrow::{Borrow, BorrowMut}; +use std::borrow::BorrowMut; use std::cell::RefCell; use std::collections::{BTreeMap, HashMap}; +use std::os::raw::c_void; use std::rc::{Rc, Weak}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -1838,47 +1838,20 @@ impl Program { state.registers[*dest] = exec_replace(source, pattern, replacement); } }, - #[allow(unreachable_patterns)] - crate::function::Func::Extension(extfn) => match extfn { - #[cfg(feature = "uuid")] - ExtFunc::Uuid(uuidfn) => match uuidfn { - UuidFunc::Uuid4Str => { - state.registers[*dest] = exec_uuid(uuidfn, None)? - } - UuidFunc::Uuid7 => match arg_count { - 0 => { - state.registers[*dest] = - exec_uuid(uuidfn, None).unwrap_or(OwnedValue::Null); - } - 1 => { - let reg_value = state.registers[*start_reg].borrow(); - state.registers[*dest] = exec_uuid(uuidfn, Some(reg_value)) - .unwrap_or(OwnedValue::Null); - } - _ => unreachable!(), - }, - _ => { - // remaining accept 1 arg - let reg_value = state.registers[*start_reg].borrow(); - state.registers[*dest] = match uuidfn { - UuidFunc::Uuid7TS => Some(exec_ts_from_uuid7(reg_value)), - UuidFunc::UuidStr => exec_uuidstr(reg_value).ok(), - UuidFunc::UuidBlob => exec_uuidblob(reg_value).ok(), - _ => unreachable!(), - } - .unwrap_or(OwnedValue::Null); - } - }, - _ => unreachable!(), // when more extension types are added - }, crate::function::Func::External(f) => { let values = &state.registers[*start_reg..*start_reg + arg_count]; - let args: Vec<_> = values.into_iter().map(|v| v.into()).collect(); - let result = f - .func - .execute(args.as_slice()) - .map_err(|e| LimboError::ExtensionError(e.to_string()))?; - state.registers[*dest] = result.into(); + let c_values: Vec<*const c_void> = values + .iter() + .map(|ov| &ov.to_ffi() as *const _ as *const c_void) + .collect(); + let argv_ptr = if c_values.is_empty() { + std::ptr::null() + } else { + c_values.as_ptr() + }; + let result_c_value: ExtValue = (f.func)(arg_count as i32, argv_ptr); + let result_ov = OwnedValue::from_ffi(&result_c_value); + state.registers[*dest] = result_ov; } crate::function::Func::Math(math_func) => match math_func.arity() { MathFuncArity::Nullary => match math_func { diff --git a/extension_api/src/lib.rs b/extension_api/src/lib.rs deleted file mode 100644 index 7ee26e329..000000000 --- a/extension_api/src/lib.rs +++ /dev/null @@ -1,75 +0,0 @@ -use std::any::Any; -use std::rc::Rc; -use std::sync::Arc; - -pub type Result = std::result::Result; - -pub trait Extension { - fn load(&self) -> Result<()>; -} - -#[derive(Debug)] -pub enum LimboApiError { - ConnectionError(String), - RegisterFunctionError(String), - ValueError(String), - VTableError(String), -} - -impl From for LimboApiError { - fn from(e: std::io::Error) -> Self { - Self::ConnectionError(e.to_string()) - } -} - -impl std::fmt::Display for LimboApiError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - Self::ConnectionError(e) => write!(f, "Connection error: {e}"), - Self::RegisterFunctionError(e) => write!(f, "Register function error: {e}"), - Self::ValueError(e) => write!(f, "Value error: {e}"), - Self::VTableError(e) => write!(f, "VTable error: {e}"), - } - } -} - -pub trait ExtensionApi { - fn register_scalar_function(&self, name: &str, func: Arc) -> Result<()>; - fn register_aggregate_function( - &self, - name: &str, - func: Arc, - ) -> Result<()>; - fn register_virtual_table(&self, name: &str, table: Arc) -> Result<()>; -} - -pub trait ScalarFunction { - fn execute(&self, args: &[Value]) -> Result; -} - -pub trait AggregateFunction { - fn init(&self) -> Box; - fn step(&self, state: &mut dyn Any, args: &[Value]) -> Result<()>; - fn finalize(&self, state: Box) -> Result; -} - -pub trait VirtualTable { - fn schema(&self) -> &'static str; - fn create_cursor(&self) -> Box; -} - -pub trait Cursor { - fn next(&mut self) -> Result>; -} - -pub struct Row { - pub values: Vec, -} - -pub enum Value { - Text(String), - Blob(Vec), - Integer(i64), - Float(f64), - Null, -} diff --git a/extensions/uuid/Cargo.toml b/extensions/uuid/Cargo.toml index c6ae90bdf..1b9eb10a2 100644 --- a/extensions/uuid/Cargo.toml +++ b/extensions/uuid/Cargo.toml @@ -1,11 +1,15 @@ [package] -name = "uuid" +name = "limbo_uuid" version.workspace = true authors.workspace = true edition.workspace = true license.workspace = true repository.workspace = true +[lib] +crate-type = ["cdylib"] + + [dependencies] -extension_api = { path = "../../extension_api"} +limbo_extension = { path = "../../limbo_extension"} uuid = { version = "1.11.0", features = ["v4", "v7"] } diff --git a/extensions/uuid/src/lib.rs b/extensions/uuid/src/lib.rs index e69de29bb..c9950be8f 100644 --- a/extensions/uuid/src/lib.rs +++ b/extensions/uuid/src/lib.rs @@ -0,0 +1,62 @@ +use limbo_extension::{ + declare_scalar_functions, register_extension, register_scalar_functions, Value, +}; + +register_extension! { + scalars: { + "uuid4_str" => uuid4_str, + "uuid4" => uuid4_blob, + "uuid_str" => uuid_str, + "uuid_blob" => uuid_blob, + }, +} + +declare_scalar_functions! { + #[args(min = 0, max = 0)] + fn uuid4_str(_args: &[Value]) -> Value { + let uuid = uuid::Uuid::new_v4().to_string(); + Value::from_text(uuid) + } + #[args(min = 0, max = 0)] + fn uuid4_blob(_args: &[Value]) -> Value { + let uuid = uuid::Uuid::new_v4(); + let bytes = uuid.as_bytes(); + Value::from_blob(bytes) + } + + #[args(min = 1, max = 1)] + fn uuid_str(args: &[Value]) -> Value { + if args.len() != 1 { + return Value::null(); + } + if args[0].value_type != limbo_extension::ValueType::Blob { + return Value::null(); + } + let data_ptr = args[0].blob.data; + let size = args[0].blob.size; + if data_ptr.is_null() || size != 16 { + return Value::null(); + } + let slice = unsafe{ std::slice::from_raw_parts(data_ptr, size)}; + let parsed = uuid::Uuid::from_slice(slice).ok().map(|u| u.to_string()); + match parsed { + Some(s) => Value::from_text(s), + None => Value::null() + } + } + + #[args(min = 1, max = 1)] + fn uuid_blob(args: &[Value]) -> Value { + if args.len() != 1 { + return Value::null(); + } + if args[0].value_type != limbo_extension::ValueType::Text { + return Value::null(); + } + let text = args[0].text.to_string(); + match uuid::Uuid::parse_str(&text) { + Ok(uuid) => Value::from_blob(uuid.as_bytes()), + Err(_) => Value::null() + } + } +} diff --git a/extension_api/Cargo.toml b/limbo_extension/Cargo.toml similarity index 86% rename from extension_api/Cargo.toml rename to limbo_extension/Cargo.toml index 73056af33..d3ac246d2 100644 --- a/extension_api/Cargo.toml +++ b/limbo_extension/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "extension_api" +name = "limbo_extension" version.workspace = true authors.workspace = true edition.workspace = true diff --git a/limbo_extension/src/lib.rs b/limbo_extension/src/lib.rs new file mode 100644 index 000000000..6704f63ec --- /dev/null +++ b/limbo_extension/src/lib.rs @@ -0,0 +1,233 @@ +use std::ffi::CString; +use std::os::raw::{c_char, c_void}; + +pub type ResultCode = i32; + +pub const RESULT_OK: ResultCode = 0; +pub const RESULT_ERROR: ResultCode = 1; +// TODO: more error types + +pub type ExtensionEntryPoint = extern "C" fn(api: *const ExtensionApi) -> ResultCode; +pub type ScalarFunction = extern "C" fn(argc: i32, *const *const c_void) -> Value; + +#[repr(C)] +pub struct ExtensionApi { + pub ctx: *mut c_void, + pub register_scalar_function: + extern "C" fn(ctx: *mut c_void, name: *const c_char, func: ScalarFunction) -> ResultCode, +} + +#[macro_export] +macro_rules! register_extension { + ( + scalars: { $( $scalar_name:expr => $scalar_func:ident ),* $(,)? }, + //aggregates: { $( $agg_name:expr => ($step_func:ident, $finalize_func:ident) ),* $(,)? }, + //virtual_tables: { $( $vt_name:expr => $vt_impl:expr ),* $(,)? } + ) => { + #[no_mangle] + pub unsafe extern "C" fn register_extension(api: *const $crate::ExtensionApi) -> $crate::ResultCode { + if api.is_null() { + return $crate::RESULT_ERROR; + } + + register_scalar_functions! { api, $( $scalar_name => $scalar_func ),* } + // TODO: + //register_aggregate_functions! { $( $agg_name => ($step_func, $finalize_func) ),* } + //register_virtual_tables! { $( $vt_name => $vt_impl ),* } + $crate::RESULT_OK + } + } +} + +#[macro_export] +macro_rules! register_scalar_functions { + ( $api:expr, $( $fname:expr => $fptr:ident ),* ) => { + unsafe { + $( + let cname = std::ffi::CString::new($fname).unwrap(); + ((*$api).register_scalar_function)((*$api).ctx, cname.as_ptr(), $fptr); + )* + } + } +} + +/// Provide a cleaner interface to define scalar functions to extension authors +/// . e.g. +/// ``` +/// fn scalar_func(args: &[Value]) -> Value { +/// if args.len() != 1 { +/// return Value::null(); +/// } +/// Value::from_integer(args[0].integer * 2) +/// } +/// ``` +/// +#[macro_export] +macro_rules! declare_scalar_functions { + ( + $( + #[args(min = $min_args:literal, max = $max_args:literal)] + fn $func_name:ident ($args:ident : &[Value]) -> Value $body:block + )* + ) => { + $( + extern "C" fn $func_name( + argc: i32, + argv: *const *const std::os::raw::c_void + ) -> $crate::Value { + if !($min_args..=$max_args).contains(&argc) { + println!("{}: Invalid argument count", stringify!($func_name)); + return $crate::Value::null();// TODO: error code + } + if argc == 0 || argv.is_null() { + let $args: &[$crate::Value] = &[]; + $body + } else { + unsafe { + let ptr_slice = std::slice::from_raw_parts(argv, argc as usize); + let mut values = Vec::with_capacity(argc as usize); + for &ptr in ptr_slice { + let val_ptr = ptr as *const $crate::Value; + if val_ptr.is_null() { + values.push($crate::Value::null()); + } else { + values.push(std::ptr::read(val_ptr)); + } + } + let $args: &[$crate::Value] = &values[..]; + $body + } + } + } + )* + }; +} + +#[derive(PartialEq, Eq)] +#[repr(C)] +pub enum ValueType { + Null, + Integer, + Float, + Text, + Blob, +} + +// TODO: perf, these can be better expressed +#[repr(C)] +pub struct Value { + pub value_type: ValueType, + pub integer: i64, + pub float: f64, + pub text: TextValue, + pub blob: Blob, +} + +#[repr(C)] +pub struct TextValue { + text: *const c_char, + len: usize, +} + +impl std::fmt::Display for TextValue { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if self.text.is_null() { + return write!(f, ""); + } + let slice = unsafe { std::slice::from_raw_parts(self.text as *const u8, self.len) }; + match std::str::from_utf8(slice) { + Ok(s) => write!(f, "{}", s), + Err(e) => write!(f, "", e), + } + } +} + +impl TextValue { + pub fn is_null(&self) -> bool { + self.text.is_null() + } + + pub fn new(text: *const c_char, len: usize) -> Self { + Self { text, len } + } + + pub fn null() -> Self { + Self { + text: std::ptr::null(), + len: 0, + } + } +} + +#[repr(C)] +pub struct Blob { + pub data: *const u8, + pub size: usize, +} + +impl Blob { + pub fn new(data: *const u8, size: usize) -> Self { + Self { data, size } + } + pub fn null() -> Self { + Self { + data: std::ptr::null(), + size: 0, + } + } +} + +impl Value { + pub fn null() -> Self { + Self { + value_type: ValueType::Null, + integer: 0, + float: 0.0, + text: TextValue::null(), + blob: Blob::null(), + } + } + + pub fn from_integer(value: i64) -> Self { + Self { + value_type: ValueType::Integer, + integer: value, + float: 0.0, + text: TextValue::null(), + blob: Blob::null(), + } + } + pub fn from_float(value: f64) -> Self { + Self { + value_type: ValueType::Float, + integer: 0, + float: value, + text: TextValue::null(), + blob: Blob::null(), + } + } + + pub fn from_text(value: String) -> Self { + let cstr = CString::new(&*value).unwrap(); + let ptr = cstr.as_ptr(); + let len = value.len(); + std::mem::forget(cstr); + Self { + value_type: ValueType::Text, + integer: 0, + float: 0.0, + text: TextValue::new(ptr, len), + blob: Blob::null(), + } + } + + pub fn from_blob(value: &[u8]) -> Self { + Self { + value_type: ValueType::Blob, + integer: 0, + float: 0.0, + text: TextValue::null(), + blob: Blob::new(value.as_ptr(), value.len()), + } + } +} From c565fba1951c3a9f92cd3ceed837a96c4bd39385 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Sun, 12 Jan 2025 11:35:02 -0500 Subject: [PATCH 3/9] Adjust types in extension API --- Cargo.lock | 4 + core/ext/mod.rs | 4 +- core/types.rs | 47 ++++++++---- extensions/uuid/Cargo.toml | 3 +- extensions/uuid/src/lib.rs | 35 ++++----- limbo_extension/Cargo.toml | 1 + limbo_extension/src/lib.rs | 145 ++++++++++++++++++++----------------- 7 files changed, 139 insertions(+), 100 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index eb09669a3..d3fffc7a0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1260,6 +1260,9 @@ dependencies = [ [[package]] name = "limbo_extension" version = "0.0.11" +dependencies = [ + "log", +] [[package]] name = "limbo_macros" @@ -1294,6 +1297,7 @@ name = "limbo_uuid" version = "0.0.11" dependencies = [ "limbo_extension", + "log", "uuid", ] diff --git a/core/ext/mod.rs b/core/ext/mod.rs index 79936079e..179b3e7c6 100644 --- a/core/ext/mod.rs +++ b/core/ext/mod.rs @@ -1,6 +1,8 @@ use crate::{function::ExternalFunc, Database}; +pub use limbo_extension::{ + Blob as ExtBlob, TextValue as ExtTextValue, Value as ExtValue, ValueType as ExtValueType, +}; use limbo_extension::{ExtensionApi, ResultCode, ScalarFunction, RESULT_ERROR, RESULT_OK}; -pub use limbo_extension::{Value as ExtValue, ValueType as ExtValueType}; use std::{ ffi::{c_char, c_void, CStr}, rc::Rc, diff --git a/core/types.rs b/core/types.rs index b35d0dc98..92d02fce5 100644 --- a/core/types.rs +++ b/core/types.rs @@ -1,5 +1,5 @@ use crate::error::LimboError; -use crate::ext::{ExtValue, ExtValueType}; +use crate::ext::{ExtBlob, ExtTextValue, ExtValue, ExtValueType}; use crate::storage::sqlite3_ondisk::write_varint; use crate::Result; use std::fmt::Display; @@ -92,37 +92,54 @@ impl Display for OwnedValue { } } } + impl OwnedValue { pub fn to_ffi(&self) -> ExtValue { match self { Self::Null => ExtValue::null(), Self::Integer(i) => ExtValue::from_integer(*i), Self::Float(fl) => ExtValue::from_float(*fl), - Self::Text(s) => ExtValue::from_text(s.value.to_string()), - Self::Blob(b) => ExtValue::from_blob(b), - Self::Agg(_) => todo!(), - Self::Record(_) => todo!(), + Self::Text(text) => ExtValue::from_text(text.value.to_string()), + Self::Blob(blob) => ExtValue::from_blob(blob.to_vec()), + Self::Agg(_) => todo!("Aggregate values not yet supported"), + Self::Record(_) => todo!("Record values not yet supported"), } } + pub fn from_ffi(v: &ExtValue) -> Self { + if v.value.is_null() { + return OwnedValue::Null; + } match v.value_type { ExtValueType::Null => OwnedValue::Null, - ExtValueType::Integer => OwnedValue::Integer(v.integer), - ExtValueType::Float => OwnedValue::Float(v.float), + ExtValueType::Integer => { + let int_ptr = v.value as *mut i64; + let integer = unsafe { *int_ptr }; + OwnedValue::Integer(integer) + } + ExtValueType::Float => { + let float_ptr = v.value as *mut f64; + let float = unsafe { *float_ptr }; + OwnedValue::Float(float) + } ExtValueType::Text => { - if v.text.is_null() { + if v.value.is_null() { OwnedValue::Null } else { - OwnedValue::build_text(std::rc::Rc::new(v.text.to_string())) + let Some(text) = ExtTextValue::from_value(v) else { + return OwnedValue::Null; + }; + OwnedValue::build_text(std::rc::Rc::new(unsafe { text.as_str().to_string() })) } } ExtValueType::Blob => { - if v.blob.data.is_null() { - OwnedValue::Null - } else { - let bytes = unsafe { std::slice::from_raw_parts(v.blob.data, v.blob.size) }; - OwnedValue::Blob(std::rc::Rc::new(bytes.to_vec())) - } + let blob_ptr = v.value as *mut ExtBlob; + let blob = unsafe { + let slice = + std::slice::from_raw_parts((*blob_ptr).data, (*blob_ptr).size as usize); + slice.to_vec() + }; + OwnedValue::Blob(std::rc::Rc::new(blob)) } } } diff --git a/extensions/uuid/Cargo.toml b/extensions/uuid/Cargo.toml index 1b9eb10a2..ed2c43e87 100644 --- a/extensions/uuid/Cargo.toml +++ b/extensions/uuid/Cargo.toml @@ -7,9 +7,10 @@ license.workspace = true repository.workspace = true [lib] -crate-type = ["cdylib"] +crate-type = ["cdylib", "lib"] [dependencies] limbo_extension = { path = "../../limbo_extension"} uuid = { version = "1.11.0", features = ["v4", "v7"] } +log = "0.4.20" diff --git a/extensions/uuid/src/lib.rs b/extensions/uuid/src/lib.rs index c9950be8f..940658155 100644 --- a/extensions/uuid/src/lib.rs +++ b/extensions/uuid/src/lib.rs @@ -1,5 +1,5 @@ use limbo_extension::{ - declare_scalar_functions, register_extension, register_scalar_functions, Value, + declare_scalar_functions, register_extension, register_scalar_functions, Blob, TextValue, Value, }; register_extension! { @@ -17,46 +17,47 @@ declare_scalar_functions! { let uuid = uuid::Uuid::new_v4().to_string(); Value::from_text(uuid) } + #[args(min = 0, max = 0)] fn uuid4_blob(_args: &[Value]) -> Value { let uuid = uuid::Uuid::new_v4(); let bytes = uuid.as_bytes(); - Value::from_blob(bytes) + Value::from_blob(bytes.to_vec()) } #[args(min = 1, max = 1)] fn uuid_str(args: &[Value]) -> Value { - if args.len() != 1 { - return Value::null(); - } if args[0].value_type != limbo_extension::ValueType::Blob { + log::debug!("uuid_str was passed a non-blob arg"); return Value::null(); } - let data_ptr = args[0].blob.data; - let size = args[0].blob.size; - if data_ptr.is_null() || size != 16 { - return Value::null(); - } - let slice = unsafe{ std::slice::from_raw_parts(data_ptr, size)}; + if let Some(blob) = Blob::from_value(&args[0]) { + let slice = unsafe{ std::slice::from_raw_parts(blob.data, blob.size as usize)}; let parsed = uuid::Uuid::from_slice(slice).ok().map(|u| u.to_string()); match parsed { Some(s) => Value::from_text(s), None => Value::null() } + } else { + Value::null() + } } #[args(min = 1, max = 1)] fn uuid_blob(args: &[Value]) -> Value { - if args.len() != 1 { - return Value::null(); - } if args[0].value_type != limbo_extension::ValueType::Text { + log::debug!("uuid_blob was passed a non-text arg"); return Value::null(); } - let text = args[0].text.to_string(); - match uuid::Uuid::parse_str(&text) { - Ok(uuid) => Value::from_blob(uuid.as_bytes()), + if let Some(text) = TextValue::from_value(&args[0]) { + match uuid::Uuid::parse_str(unsafe {text.as_str()}) { + Ok(uuid) => { + Value::from_blob(uuid.as_bytes().to_vec()) + } Err(_) => Value::null() } + } else { + Value::null() + } } } diff --git a/limbo_extension/Cargo.toml b/limbo_extension/Cargo.toml index d3ac246d2..2928ed853 100644 --- a/limbo_extension/Cargo.toml +++ b/limbo_extension/Cargo.toml @@ -7,3 +7,4 @@ license.workspace = true repository.workspace = true [dependencies] +log = "0.4.20" diff --git a/limbo_extension/src/lib.rs b/limbo_extension/src/lib.rs index 6704f63ec..3daf99755 100644 --- a/limbo_extension/src/lib.rs +++ b/limbo_extension/src/lib.rs @@ -1,4 +1,3 @@ -use std::ffi::CString; use std::os::raw::{c_char, c_void}; pub type ResultCode = i32; @@ -76,8 +75,7 @@ macro_rules! declare_scalar_functions { argv: *const *const std::os::raw::c_void ) -> $crate::Value { if !($min_args..=$max_args).contains(&argc) { - println!("{}: Invalid argument count", stringify!($func_name)); - return $crate::Value::null();// TODO: error code + return $crate::Value::null(); } if argc == 0 || argv.is_null() { let $args: &[$crate::Value] = &[]; @@ -103,8 +101,8 @@ macro_rules! declare_scalar_functions { }; } -#[derive(PartialEq, Eq)] #[repr(C)] +#[derive(PartialEq, Eq)] pub enum ValueType { Null, Integer, @@ -113,45 +111,20 @@ pub enum ValueType { Blob, } -// TODO: perf, these can be better expressed #[repr(C)] pub struct Value { pub value_type: ValueType, - pub integer: i64, - pub float: f64, - pub text: TextValue, - pub blob: Blob, + pub value: *mut c_void, } #[repr(C)] pub struct TextValue { - text: *const c_char, - len: usize, + pub text: *const u8, + pub len: u32, } -impl std::fmt::Display for TextValue { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if self.text.is_null() { - return write!(f, ""); - } - let slice = unsafe { std::slice::from_raw_parts(self.text as *const u8, self.len) }; - match std::str::from_utf8(slice) { - Ok(s) => write!(f, "{}", s), - Err(e) => write!(f, "", e), - } - } -} - -impl TextValue { - pub fn is_null(&self) -> bool { - self.text.is_null() - } - - pub fn new(text: *const c_char, len: usize) -> Self { - Self { text, len } - } - - pub fn null() -> Self { +impl Default for TextValue { + fn default() -> Self { Self { text: std::ptr::null(), len: 0, @@ -159,21 +132,49 @@ impl TextValue { } } +impl TextValue { + pub fn new(text: *const u8, len: usize) -> Self { + Self { + text, + len: len as u32, + } + } + + pub fn from_value(value: &Value) -> Option<&Self> { + if value.value_type != ValueType::Text { + return None; + } + unsafe { Some(&*(value.value as *const TextValue)) } + } + + /// # Safety + /// The caller must ensure that the text is a valid UTF-8 string + pub unsafe fn as_str(&self) -> &str { + if self.text.is_null() { + return ""; + } + unsafe { + std::str::from_utf8_unchecked(std::slice::from_raw_parts(self.text, self.len as usize)) + } + } +} + #[repr(C)] pub struct Blob { pub data: *const u8, - pub size: usize, + pub size: u64, } impl Blob { - pub fn new(data: *const u8, size: usize) -> Self { + pub fn new(data: *const u8, size: u64) -> Self { Self { data, size } } - pub fn null() -> Self { - Self { - data: std::ptr::null(), - size: 0, + + pub fn from_value(value: &Value) -> Option<&Self> { + if value.value_type != ValueType::Blob { + return None; } + unsafe { Some(&*(value.value as *const Blob)) } } } @@ -181,53 +182,65 @@ impl Value { pub fn null() -> Self { Self { value_type: ValueType::Null, - integer: 0, - float: 0.0, - text: TextValue::null(), - blob: Blob::null(), + value: std::ptr::null_mut(), } } pub fn from_integer(value: i64) -> Self { + let boxed = Box::new(value); Self { value_type: ValueType::Integer, - integer: value, - float: 0.0, - text: TextValue::null(), - blob: Blob::null(), + value: Box::into_raw(boxed) as *mut c_void, } } + pub fn from_float(value: f64) -> Self { + let boxed = Box::new(value); Self { value_type: ValueType::Float, - integer: 0, - float: value, - text: TextValue::null(), - blob: Blob::null(), + value: Box::into_raw(boxed) as *mut c_void, } } - pub fn from_text(value: String) -> Self { - let cstr = CString::new(&*value).unwrap(); - let ptr = cstr.as_ptr(); - let len = value.len(); - std::mem::forget(cstr); + pub fn from_text(s: String) -> Self { + let text_value = TextValue::new(s.as_ptr(), s.len()); + let boxed_text = Box::new(text_value); + std::mem::forget(s); Self { value_type: ValueType::Text, - integer: 0, - float: 0.0, - text: TextValue::new(ptr, len), - blob: Blob::null(), + value: Box::into_raw(boxed_text) as *mut c_void, } } - pub fn from_blob(value: &[u8]) -> Self { + pub fn from_blob(value: Vec) -> Self { + let boxed = Box::new(Blob::new(value.as_ptr(), value.len() as u64)); + std::mem::forget(value); Self { value_type: ValueType::Blob, - integer: 0, - float: 0.0, - text: TextValue::null(), - blob: Blob::new(value.as_ptr(), value.len()), + value: Box::into_raw(boxed) as *mut c_void, } } + + pub unsafe fn free(&mut self) { + if self.value.is_null() { + return; + } + match self.value_type { + ValueType::Integer => { + let _ = Box::from_raw(self.value as *mut i64); + } + ValueType::Float => { + let _ = Box::from_raw(self.value as *mut f64); + } + ValueType::Text => { + let _ = Box::from_raw(self.value as *mut TextValue); + } + ValueType::Blob => { + let _ = Box::from_raw(self.value as *mut Blob); + } + ValueType::Null => {} + } + + self.value = std::ptr::null_mut(); + } } From 852817c9ff4f3b5833f46a9baea17fbd5cb29943 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Sun, 12 Jan 2025 12:19:05 -0500 Subject: [PATCH 4/9] Have args macro in extension take a range --- extensions/uuid/src/lib.rs | 43 ++++++++++++++++++++++++++++++++++---- limbo_extension/src/lib.rs | 17 +++++++++------ 2 files changed, 50 insertions(+), 10 deletions(-) diff --git a/extensions/uuid/src/lib.rs b/extensions/uuid/src/lib.rs index 940658155..7166e4bb1 100644 --- a/extensions/uuid/src/lib.rs +++ b/extensions/uuid/src/lib.rs @@ -6,26 +6,61 @@ register_extension! { scalars: { "uuid4_str" => uuid4_str, "uuid4" => uuid4_blob, + "uuid7_str" => uuid7_str, + "uuid7" => uuid7_blob, "uuid_str" => uuid_str, "uuid_blob" => uuid_blob, }, } declare_scalar_functions! { - #[args(min = 0, max = 0)] + #[args(0)] fn uuid4_str(_args: &[Value]) -> Value { let uuid = uuid::Uuid::new_v4().to_string(); Value::from_text(uuid) } - #[args(min = 0, max = 0)] + #[args(0)] fn uuid4_blob(_args: &[Value]) -> Value { let uuid = uuid::Uuid::new_v4(); let bytes = uuid.as_bytes(); Value::from_blob(bytes.to_vec()) } - #[args(min = 1, max = 1)] + #[args(0..=1)] + fn uuid7_str(args: &[Value]) -> Value { + let timestamp = if args.is_empty() { + let ctx = uuid::ContextV7::new(); + uuid::Timestamp::now(ctx) + } else if args[0].value_type == limbo_extension::ValueType::Integer { + let ctx = uuid::ContextV7::new(); + let int = args[0].value as i64; + uuid::Timestamp::from_unix(ctx, int as u64, 0) + } else { + return Value::null(); + }; + let uuid = uuid::Uuid::new_v7(timestamp); + Value::from_text(uuid.to_string()) + } + + #[args(0..=1)] + fn uuid7_blob(args: &[Value]) -> Value { + let timestamp = if args.is_empty() { + let ctx = uuid::ContextV7::new(); + uuid::Timestamp::now(ctx) + } else if args[0].value_type == limbo_extension::ValueType::Integer { + let ctx = uuid::ContextV7::new(); + let int = args[0].value as i64; + uuid::Timestamp::from_unix(ctx, int as u64, 0) + } else { + return Value::null(); + }; + let uuid = uuid::Uuid::new_v7(timestamp); + let bytes = uuid.as_bytes(); + Value::from_blob(bytes.to_vec()) + } + + #[args(1)] fn uuid_str(args: &[Value]) -> Value { if args[0].value_type != limbo_extension::ValueType::Blob { log::debug!("uuid_str was passed a non-blob arg"); @@ -43,7 +78,7 @@ declare_scalar_functions! { } } - #[args(min = 1, max = 1)] + #[args(1)] fn uuid_blob(args: &[Value]) -> Value { if args[0].value_type != limbo_extension::ValueType::Text { log::debug!("uuid_blob was passed a non-text arg"); diff --git a/limbo_extension/src/lib.rs b/limbo_extension/src/lib.rs index 3daf99755..a1e77d53d 100644 --- a/limbo_extension/src/lib.rs +++ b/limbo_extension/src/lib.rs @@ -53,6 +53,7 @@ macro_rules! register_scalar_functions { /// Provide a cleaner interface to define scalar functions to extension authors /// . e.g. /// ``` +/// #[args(1)] /// fn scalar_func(args: &[Value]) -> Value { /// if args.len() != 1 { /// return Value::null(); @@ -65,7 +66,7 @@ macro_rules! register_scalar_functions { macro_rules! declare_scalar_functions { ( $( - #[args(min = $min_args:literal, max = $max_args:literal)] + #[args($($args_count:tt)+)] fn $func_name:ident ($args:ident : &[Value]) -> Value $body:block )* ) => { @@ -74,28 +75,32 @@ macro_rules! declare_scalar_functions { argc: i32, argv: *const *const std::os::raw::c_void ) -> $crate::Value { - if !($min_args..=$max_args).contains(&argc) { + let valid_args = { + match argc { + $($args_count)+ => true, + _ => false, + } + }; + if !valid_args { return $crate::Value::null(); } if argc == 0 || argv.is_null() { let $args: &[$crate::Value] = &[]; $body } else { - unsafe { - let ptr_slice = std::slice::from_raw_parts(argv, argc as usize); + let ptr_slice = unsafe{ std::slice::from_raw_parts(argv, argc as usize)}; let mut values = Vec::with_capacity(argc as usize); for &ptr in ptr_slice { let val_ptr = ptr as *const $crate::Value; if val_ptr.is_null() { values.push($crate::Value::null()); } else { - values.push(std::ptr::read(val_ptr)); + unsafe{values.push(std::ptr::read(val_ptr))}; } } let $args: &[$crate::Value] = &values[..]; $body } - } } )* }; From 98eff6cf7a2a75bdce9e78acf949fff1367ad83d Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Sun, 12 Jan 2025 15:24:50 -0500 Subject: [PATCH 5/9] Enable passing arguments to external functions --- core/translate/expr.rs | 12 +++++- core/translate/select.rs | 11 +++--- core/vdbe/mod.rs | 42 +++++++++++++------- extensions/uuid/src/lib.rs | 34 ++++++++++++++++- limbo_extension/src/lib.rs | 78 +++++++++++++++++++++++++++----------- 5 files changed, 132 insertions(+), 45 deletions(-) diff --git a/core/translate/expr.rs b/core/translate/expr.rs index a464e5279..d13902433 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -762,13 +762,23 @@ pub fn translate_expr( crate::bail_parse_error!("aggregation function in non-aggregation context") } Func::External(_) => { - let regs = program.alloc_register(); + let regs = program.alloc_registers(args_count); + for (i, arg_expr) in args.iter().enumerate() { + translate_expr( + program, + referenced_tables, + &arg_expr[i], + regs + i, + resolver, + )?; + } program.emit_insn(Insn::Function { constant_mask: 0, start_reg: regs, dest: target_register, func: func_ctx, }); + Ok(target_register) } #[cfg(feature = "json")] diff --git a/core/translate/select.rs b/core/translate/select.rs index dfbd4c2fb..fa5361205 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -148,10 +148,9 @@ pub fn prepare_select_plan( } Err(_) => { if syms.functions.contains_key(&name.0) { - // TODO: future extensions can be aggregate functions - log::debug!( - "Resolving {} function from symbol table", - name.0 + let contains_aggregates = resolve_aggregates( + expr, + &mut aggregate_expressions, ); plan.result_columns.push(ResultSetColumn { name: get_name( @@ -161,7 +160,7 @@ pub fn prepare_select_plan( || format!("expr_{}", result_column_idx), ), expr: expr.clone(), - contains_aggregates: false, + contains_aggregates, }); } } @@ -202,7 +201,7 @@ pub fn prepare_select_plan( } expr => { let contains_aggregates = - resolve_aggregates(&expr, &mut aggregate_expressions); + resolve_aggregates(expr, &mut aggregate_expressions); plan.result_columns.push(ResultSetColumn { name: get_name( maybe_alias.as_ref(), diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index 075dd5dab..580834d39 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -55,7 +55,6 @@ use sorter::Sorter; use std::borrow::BorrowMut; use std::cell::RefCell; use std::collections::{BTreeMap, HashMap}; -use std::os::raw::c_void; use std::rc::{Rc, Weak}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -147,6 +146,33 @@ macro_rules! return_if_io { }; } +macro_rules! call_external_function { + ( + $func_ptr:expr, + $dest_register:expr, + $state:expr, + $arg_count:expr, + $start_reg:expr + ) => {{ + if $arg_count == 0 { + let result_c_value: ExtValue = ($func_ptr)(0, std::ptr::null()); + let result_ov = OwnedValue::from_ffi(&result_c_value); + $state.registers[$dest_register] = result_ov; + } else { + let register_slice = &$state.registers[$start_reg..$start_reg + $arg_count]; + let mut ext_values: Vec = Vec::with_capacity($arg_count); + for ov in register_slice.iter() { + let val = ov.to_ffi(); + ext_values.push(val); + } + let argv_ptr = ext_values.as_ptr(); + let result_c_value: ExtValue = ($func_ptr)($arg_count as i32, argv_ptr); + let result_ov = OwnedValue::from_ffi(&result_c_value); + $state.registers[$dest_register] = result_ov; + } + }}; +} + struct RegexCache { like: HashMap, glob: HashMap, @@ -1839,19 +1865,7 @@ impl Program { } }, crate::function::Func::External(f) => { - let values = &state.registers[*start_reg..*start_reg + arg_count]; - let c_values: Vec<*const c_void> = values - .iter() - .map(|ov| &ov.to_ffi() as *const _ as *const c_void) - .collect(); - let argv_ptr = if c_values.is_empty() { - std::ptr::null() - } else { - c_values.as_ptr() - }; - let result_c_value: ExtValue = (f.func)(arg_count as i32, argv_ptr); - let result_ov = OwnedValue::from_ffi(&result_c_value); - state.registers[*dest] = result_ov; + call_external_function! {f.func, *dest, state, arg_count, *start_reg }; } crate::function::Func::Math(math_func) => match math_func.arity() { MathFuncArity::Nullary => match math_func { diff --git a/extensions/uuid/src/lib.rs b/extensions/uuid/src/lib.rs index 7166e4bb1..7380f82d3 100644 --- a/extensions/uuid/src/lib.rs +++ b/extensions/uuid/src/lib.rs @@ -1,5 +1,6 @@ use limbo_extension::{ - declare_scalar_functions, register_extension, register_scalar_functions, Blob, TextValue, Value, + declare_scalar_functions, register_extension, register_scalar_functions, Blob, TextValue, + Value, ValueType, }; register_extension! { @@ -10,6 +11,7 @@ register_extension! { "uuid7" => uuid7_blob, "uuid_str" => uuid_str, "uuid_blob" => uuid_blob, + "exec_ts_from_uuid7" => exec_ts_from_uuid7, }, } @@ -60,6 +62,26 @@ declare_scalar_functions! { Value::from_blob(bytes.to_vec()) } + #[args(1)] + fn exec_ts_from_uuid7(args: &[Value]) -> Value { + match args[0].value_type { + ValueType::Blob => { + let blob = Blob::from_value(&args[0]).unwrap(); + let slice = unsafe{ std::slice::from_raw_parts(blob.data, blob.size as usize)}; + let uuid = uuid::Uuid::from_slice(slice).unwrap(); + let unix = uuid_to_unix(uuid.as_bytes()); + Value::from_integer(unix as i64) + } + ValueType::Text => { + let text = TextValue::from_value(&args[0]).unwrap(); + let uuid = uuid::Uuid::parse_str(unsafe {text.as_str()}).unwrap(); + let unix = uuid_to_unix(uuid.as_bytes()); + Value::from_integer(unix as i64) + } + _ => Value::null(), + } + } + #[args(1)] fn uuid_str(args: &[Value]) -> Value { if args[0].value_type != limbo_extension::ValueType::Blob { @@ -96,3 +118,13 @@ declare_scalar_functions! { } } } + +#[inline(always)] +fn uuid_to_unix(uuid: &[u8; 16]) -> u64 { + ((uuid[0] as u64) << 40) + | ((uuid[1] as u64) << 32) + | ((uuid[2] as u64) << 24) + | ((uuid[3] as u64) << 16) + | ((uuid[4] as u64) << 8) + | (uuid[5] as u64) +} diff --git a/limbo_extension/src/lib.rs b/limbo_extension/src/lib.rs index a1e77d53d..7bfc881e5 100644 --- a/limbo_extension/src/lib.rs +++ b/limbo_extension/src/lib.rs @@ -7,7 +7,7 @@ pub const RESULT_ERROR: ResultCode = 1; // TODO: more error types pub type ExtensionEntryPoint = extern "C" fn(api: *const ExtensionApi) -> ResultCode; -pub type ScalarFunction = extern "C" fn(argc: i32, *const *const c_void) -> Value; +pub type ScalarFunction = extern "C" fn(argc: i32, *const Value) -> Value; #[repr(C)] pub struct ExtensionApi { @@ -54,12 +54,13 @@ macro_rules! register_scalar_functions { /// . e.g. /// ``` /// #[args(1)] -/// fn scalar_func(args: &[Value]) -> Value { -/// if args.len() != 1 { -/// return Value::null(); -/// } +/// fn scalar_double(args: &[Value]) -> Value { /// Value::from_integer(args[0].integer * 2) /// } +/// +/// #[args(0..=2)] +/// fn scalar_sum(args: &[Value]) -> Value { +/// Value::from_integer(args.iter().map(|v| v.integer).sum()) /// ``` /// #[macro_export] @@ -73,7 +74,7 @@ macro_rules! declare_scalar_functions { $( extern "C" fn $func_name( argc: i32, - argv: *const *const std::os::raw::c_void + argv: *const $crate::Value ) -> $crate::Value { let valid_args = { match argc { @@ -85,22 +86,14 @@ macro_rules! declare_scalar_functions { return $crate::Value::null(); } if argc == 0 || argv.is_null() { + log::debug!("{} was called with no arguments", stringify!($func_name)); let $args: &[$crate::Value] = &[]; $body } else { - let ptr_slice = unsafe{ std::slice::from_raw_parts(argv, argc as usize)}; - let mut values = Vec::with_capacity(argc as usize); - for &ptr in ptr_slice { - let val_ptr = ptr as *const $crate::Value; - if val_ptr.is_null() { - values.push($crate::Value::null()); - } else { - unsafe{values.push(std::ptr::read(val_ptr))}; - } - } - let $args: &[$crate::Value] = &values[..]; - $body - } + let ptr_slice = unsafe{ std::slice::from_raw_parts(argv, argc as usize)}; + let $args: &[$crate::Value] = ptr_slice; + $body + } } )* }; @@ -122,12 +115,42 @@ pub struct Value { pub value: *mut c_void, } +impl std::fmt::Debug for Value { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.value_type { + ValueType::Null => write!(f, "Value {{ Null }}"), + ValueType::Integer => write!(f, "Value {{ Integer: {} }}", unsafe { + *(self.value as *const i64) + }), + ValueType::Float => write!(f, "Value {{ Float: {} }}", unsafe { + *(self.value as *const f64) + }), + ValueType::Text => write!(f, "Value {{ Text: {:?} }}", unsafe { + &*(self.value as *const TextValue) + }), + ValueType::Blob => write!(f, "Value {{ Blob: {:?} }}", unsafe { + &*(self.value as *const Blob) + }), + } + } +} + #[repr(C)] pub struct TextValue { pub text: *const u8, pub len: u32, } +impl std::fmt::Debug for TextValue { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "TextValue {{ text: {:?}, len: {} }}", + self.text, self.len + ) + } +} + impl Default for TextValue { fn default() -> Self { Self { @@ -170,6 +193,12 @@ pub struct Blob { pub size: u64, } +impl std::fmt::Debug for Blob { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Blob {{ data: {:?}, size: {} }}", self.data, self.size) + } +} + impl Blob { pub fn new(data: *const u8, size: u64) -> Self { Self { data, size } @@ -208,12 +237,15 @@ impl Value { } pub fn from_text(s: String) -> Self { - let text_value = TextValue::new(s.as_ptr(), s.len()); - let boxed_text = Box::new(text_value); - std::mem::forget(s); + let buffer = s.into_boxed_str(); + let ptr = buffer.as_ptr(); + let len = buffer.len(); + std::mem::forget(buffer); + let text_value = TextValue::new(ptr, len); + let text_box = Box::new(text_value); Self { value_type: ValueType::Text, - value: Box::into_raw(boxed_text) as *mut c_void, + value: Box::into_raw(text_box) as *mut c_void, } } From 6e05258d368aa3bbe2ded88f4b49cd0234b72618 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Sun, 12 Jan 2025 15:48:26 -0500 Subject: [PATCH 6/9] Add safety comments and clean up extension types --- core/types.rs | 12 ++++-------- extensions/uuid/src/lib.rs | 11 ++++++----- limbo_extension/src/lib.rs | 29 ++++++++++++++++++----------- 3 files changed, 28 insertions(+), 24 deletions(-) diff --git a/core/types.rs b/core/types.rs index 92d02fce5..3d84da194 100644 --- a/core/types.rs +++ b/core/types.rs @@ -123,14 +123,10 @@ impl OwnedValue { OwnedValue::Float(float) } ExtValueType::Text => { - if v.value.is_null() { - OwnedValue::Null - } else { - let Some(text) = ExtTextValue::from_value(v) else { - return OwnedValue::Null; - }; - OwnedValue::build_text(std::rc::Rc::new(unsafe { text.as_str().to_string() })) - } + let Some(text) = (unsafe { ExtTextValue::from_value(v) }) else { + return OwnedValue::Null; + }; + OwnedValue::build_text(std::rc::Rc::new(unsafe { text.as_str().to_string() })) } ExtValueType::Blob => { let blob_ptr = v.value as *mut ExtBlob; diff --git a/extensions/uuid/src/lib.rs b/extensions/uuid/src/lib.rs index 7380f82d3..df8581955 100644 --- a/extensions/uuid/src/lib.rs +++ b/extensions/uuid/src/lib.rs @@ -73,7 +73,9 @@ declare_scalar_functions! { Value::from_integer(unix as i64) } ValueType::Text => { - let text = TextValue::from_value(&args[0]).unwrap(); + let Some(text) = (unsafe {TextValue::from_value(&args[0])}) else { + return Value::null(); + }; let uuid = uuid::Uuid::parse_str(unsafe {text.as_str()}).unwrap(); let unix = uuid_to_unix(uuid.as_bytes()); Value::from_integer(unix as i64) @@ -106,16 +108,15 @@ declare_scalar_functions! { log::debug!("uuid_blob was passed a non-text arg"); return Value::null(); } - if let Some(text) = TextValue::from_value(&args[0]) { + let Some(text) = (unsafe { TextValue::from_value(&args[0])}) else { + return Value::null(); + }; match uuid::Uuid::parse_str(unsafe {text.as_str()}) { Ok(uuid) => { Value::from_blob(uuid.as_bytes().to_vec()) } Err(_) => Value::null() } - } else { - Value::null() - } } } diff --git a/limbo_extension/src/lib.rs b/limbo_extension/src/lib.rs index 7bfc881e5..f8eb19dcf 100644 --- a/limbo_extension/src/lib.rs +++ b/limbo_extension/src/lib.rs @@ -137,8 +137,8 @@ impl std::fmt::Debug for Value { #[repr(C)] pub struct TextValue { - pub text: *const u8, - pub len: u32, + text: *const u8, + len: u32, } impl std::fmt::Debug for TextValue { @@ -168,22 +168,27 @@ impl TextValue { } } - pub fn from_value(value: &Value) -> Option<&Self> { + /// # Safety + /// Safe to call if the pointer is null, returns None + /// if the value is not a text type or if the value is null + pub unsafe fn from_value(value: &Value) -> Option<&Self> { if value.value_type != ValueType::Text { return None; } - unsafe { Some(&*(value.value as *const TextValue)) } + if value.value.is_null() { + return None; + } + Some(&*(value.value as *const TextValue)) } /// # Safety - /// The caller must ensure that the text is a valid UTF-8 string + /// If self.text is null we safely return an empty string but + /// the caller must ensure that the underlying value is valid utf8 pub unsafe fn as_str(&self) -> &str { if self.text.is_null() { return ""; } - unsafe { - std::str::from_utf8_unchecked(std::slice::from_raw_parts(self.text, self.len as usize)) - } + std::str::from_utf8_unchecked(std::slice::from_raw_parts(self.text, self.len as usize)) } } @@ -258,7 +263,11 @@ impl Value { } } - pub unsafe fn free(&mut self) { + /// # Safety + /// consumes the value while freeing the underlying memory with null check. + /// however this does assume that the type was properly constructed with + /// the appropriate value_type and value. + pub unsafe fn free(self) { if self.value.is_null() { return; } @@ -277,7 +286,5 @@ impl Value { } ValueType::Null => {} } - - self.value = std::ptr::null_mut(); } } From 3099e5c9ba237948dd0928f6ae0e9a904da8e24d Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Sun, 12 Jan 2025 16:35:56 -0500 Subject: [PATCH 7/9] Improve api, standardize conversions between types, finish extension --- core/ext/mod.rs | 7 +++-- core/lib.rs | 6 ++-- core/types.rs | 11 +++---- extensions/uuid/src/lib.rs | 61 ++++++++++++++++++++++++-------------- limbo_extension/src/lib.rs | 47 ++++++++++++++++++++++++----- 5 files changed, 92 insertions(+), 40 deletions(-) diff --git a/core/ext/mod.rs b/core/ext/mod.rs index 179b3e7c6..f1758324b 100644 --- a/core/ext/mod.rs +++ b/core/ext/mod.rs @@ -1,8 +1,6 @@ use crate::{function::ExternalFunc, Database}; -pub use limbo_extension::{ - Blob as ExtBlob, TextValue as ExtTextValue, Value as ExtValue, ValueType as ExtValueType, -}; use limbo_extension::{ExtensionApi, ResultCode, ScalarFunction, RESULT_ERROR, RESULT_OK}; +pub use limbo_extension::{Value as ExtValue, ValueType as ExtValueType}; use std::{ ffi::{c_char, c_void, CStr}, rc::Rc, @@ -18,6 +16,9 @@ extern "C" fn register_scalar_function( Ok(s) => s.to_string(), Err(_) => return RESULT_ERROR, }; + if ctx.is_null() { + return RESULT_ERROR; + } let db = unsafe { &*(ctx as *const Database) }; db.register_scalar_function_impl(name_str, func) } diff --git a/core/lib.rs b/core/lib.rs index 4b3e8155b..f99bd3db3 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -126,7 +126,7 @@ impl Database { let header = db_header; let schema = Rc::new(RefCell::new(Schema::new())); let syms = Rc::new(RefCell::new(SymbolTable::new())); - let mut db = Database { + let db = Database { pager: pager.clone(), schema: schema.clone(), header: header.clone(), @@ -193,7 +193,9 @@ impl Database { self.syms.borrow_mut().extensions.push((lib, api_ptr)); Ok(()) } else { - let _ = unsafe { Box::from_raw(api_ptr.cast_mut()) }; // own this again so we dont leak + if !api_ptr.is_null() { + let _ = unsafe { Box::from_raw(api_ptr.cast_mut()) }; + } Err(LimboError::ExtensionError( "Extension registration failed".to_string(), )) diff --git a/core/types.rs b/core/types.rs index 3d84da194..82e8c787b 100644 --- a/core/types.rs +++ b/core/types.rs @@ -1,5 +1,5 @@ use crate::error::LimboError; -use crate::ext::{ExtBlob, ExtTextValue, ExtValue, ExtValueType}; +use crate::ext::{ExtValue, ExtValueType}; use crate::storage::sqlite3_ondisk::write_varint; use crate::Result; use std::fmt::Display; @@ -123,16 +123,17 @@ impl OwnedValue { OwnedValue::Float(float) } ExtValueType::Text => { - let Some(text) = (unsafe { ExtTextValue::from_value(v) }) else { + let Some(text) = v.to_text() else { return OwnedValue::Null; }; OwnedValue::build_text(std::rc::Rc::new(unsafe { text.as_str().to_string() })) } ExtValueType::Blob => { - let blob_ptr = v.value as *mut ExtBlob; + let Some(blob_ptr) = v.to_blob() else { + return OwnedValue::Null; + }; let blob = unsafe { - let slice = - std::slice::from_raw_parts((*blob_ptr).data, (*blob_ptr).size as usize); + let slice = std::slice::from_raw_parts(blob_ptr.data, blob_ptr.size as usize); slice.to_vec() }; OwnedValue::Blob(std::rc::Rc::new(blob)) diff --git a/extensions/uuid/src/lib.rs b/extensions/uuid/src/lib.rs index df8581955..d88f2f887 100644 --- a/extensions/uuid/src/lib.rs +++ b/extensions/uuid/src/lib.rs @@ -1,6 +1,5 @@ use limbo_extension::{ - declare_scalar_functions, register_extension, register_scalar_functions, Blob, TextValue, - Value, ValueType, + declare_scalar_functions, register_extension, register_scalar_functions, Value, ValueType, }; register_extension! { @@ -11,7 +10,7 @@ register_extension! { "uuid7" => uuid7_blob, "uuid_str" => uuid_str, "uuid_blob" => uuid_blob, - "exec_ts_from_uuid7" => exec_ts_from_uuid7, + "uuid7_timestamp_ms" => exec_ts_from_uuid7, }, } @@ -32,14 +31,35 @@ declare_scalar_functions! { #[args(0..=1)] fn uuid7_str(args: &[Value]) -> Value { let timestamp = if args.is_empty() { - let ctx = uuid::ContextV7::new(); + let ctx = uuid::ContextV7::new(); uuid::Timestamp::now(ctx) - } else if args[0].value_type == limbo_extension::ValueType::Integer { + } else { + let arg = &args[0]; + match arg.value_type { + ValueType::Integer => { let ctx = uuid::ContextV7::new(); - let int = args[0].value as i64; + let Some(int) = arg.to_integer() else { + return Value::null(); + }; uuid::Timestamp::from_unix(ctx, int as u64, 0) - } else { - return Value::null(); + } + ValueType::Text => { + let Some(text) = arg.to_text() else { + return Value::null(); + }; + let parsed = unsafe{text.as_str()}.parse::(); + match parsed { + Ok(unix) => { + if unix <= 0 { + return Value::null(); + } + uuid::Timestamp::from_unix(uuid::ContextV7::new(), unix as u64, 0) + } + Err(_) => return Value::null(), + } + } + _ => return Value::null(), + } }; let uuid = uuid::Uuid::new_v7(timestamp); Value::from_text(uuid.to_string()) @@ -52,7 +72,9 @@ declare_scalar_functions! { uuid::Timestamp::now(ctx) } else if args[0].value_type == limbo_extension::ValueType::Integer { let ctx = uuid::ContextV7::new(); - let int = args[0].value as i64; + let Some(int) = args[0].to_integer() else { + return Value::null(); + }; uuid::Timestamp::from_unix(ctx, int as u64, 0) } else { return Value::null(); @@ -66,14 +88,16 @@ declare_scalar_functions! { fn exec_ts_from_uuid7(args: &[Value]) -> Value { match args[0].value_type { ValueType::Blob => { - let blob = Blob::from_value(&args[0]).unwrap(); + let Some(blob) = &args[0].to_blob() else { + return Value::null(); + }; let slice = unsafe{ std::slice::from_raw_parts(blob.data, blob.size as usize)}; let uuid = uuid::Uuid::from_slice(slice).unwrap(); let unix = uuid_to_unix(uuid.as_bytes()); Value::from_integer(unix as i64) } ValueType::Text => { - let Some(text) = (unsafe {TextValue::from_value(&args[0])}) else { + let Some(text) = args[0].to_text() else { return Value::null(); }; let uuid = uuid::Uuid::parse_str(unsafe {text.as_str()}).unwrap(); @@ -86,29 +110,20 @@ declare_scalar_functions! { #[args(1)] fn uuid_str(args: &[Value]) -> Value { - if args[0].value_type != limbo_extension::ValueType::Blob { - log::debug!("uuid_str was passed a non-blob arg"); + let Some(blob) = args[0].to_blob() else { return Value::null(); - } - if let Some(blob) = Blob::from_value(&args[0]) { + }; let slice = unsafe{ std::slice::from_raw_parts(blob.data, blob.size as usize)}; let parsed = uuid::Uuid::from_slice(slice).ok().map(|u| u.to_string()); match parsed { Some(s) => Value::from_text(s), None => Value::null() } - } else { - Value::null() - } } #[args(1)] fn uuid_blob(args: &[Value]) -> Value { - if args[0].value_type != limbo_extension::ValueType::Text { - log::debug!("uuid_blob was passed a non-text arg"); - return Value::null(); - } - let Some(text) = (unsafe { TextValue::from_value(&args[0])}) else { + let Some(text) = args[0].to_text() else { return Value::null(); }; match uuid::Uuid::parse_str(unsafe {text.as_str()}) { diff --git a/limbo_extension/src/lib.rs b/limbo_extension/src/lib.rs index f8eb19dcf..d07cc8ea7 100644 --- a/limbo_extension/src/lib.rs +++ b/limbo_extension/src/lib.rs @@ -208,13 +208,6 @@ impl Blob { pub fn new(data: *const u8, size: u64) -> Self { Self { data, size } } - - pub fn from_value(value: &Value) -> Option<&Self> { - if value.value_type != ValueType::Blob { - return None; - } - unsafe { Some(&*(value.value as *const Blob)) } - } } impl Value { @@ -225,6 +218,46 @@ impl Value { } } + pub fn to_float(&self) -> Option { + if self.value_type != ValueType::Float { + return None; + } + if self.value.is_null() { + return None; + } + Some(unsafe { *(self.value as *const f64) }) + } + + pub fn to_text(&self) -> Option<&TextValue> { + if self.value_type != ValueType::Text { + return None; + } + if self.value.is_null() { + return None; + } + unsafe { Some(&*(self.value as *const TextValue)) } + } + + pub fn to_blob(&self) -> Option<&Blob> { + if self.value_type != ValueType::Blob { + return None; + } + if self.value.is_null() { + return None; + } + unsafe { Some(&*(self.value as *const Blob)) } + } + + pub fn to_integer(&self) -> Option { + if self.value_type != ValueType::Integer { + return None; + } + if self.value.is_null() { + return None; + } + Some(unsafe { *(self.value as *const i64) }) + } + pub fn from_integer(value: i64) -> Self { let boxed = Box::new(value); Self { From e4ce6402ebe357ea2de795c749556f86e9cd9039 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Sun, 12 Jan 2025 16:47:43 -0500 Subject: [PATCH 8/9] Remove previous uuid implementation --- core/ext/uuid.rs | 343 ------------------------------------- extensions/uuid/src/lib.rs | 4 +- 2 files changed, 3 insertions(+), 344 deletions(-) delete mode 100644 core/ext/uuid.rs diff --git a/core/ext/uuid.rs b/core/ext/uuid.rs deleted file mode 100644 index 37e496f00..000000000 --- a/core/ext/uuid.rs +++ /dev/null @@ -1,343 +0,0 @@ -use super::ExtFunc; -use crate::{ - types::{LimboText, OwnedValue}, - Database, LimboError, -}; -use std::rc::Rc; -use uuid::{ContextV7, Timestamp, Uuid}; - -#[derive(Debug, Clone, Copy, PartialEq)] -pub enum UuidFunc { - Uuid4Str, - Uuid7, - Uuid7TS, - UuidStr, - UuidBlob, -} - -impl UuidFunc { - pub fn resolve_function(name: &str, num_args: usize) -> Option { - match name { - "uuid4_str" => Some(ExtFunc::Uuid(Self::Uuid4Str)), - "uuid7" if num_args < 2 => Some(ExtFunc::Uuid(Self::Uuid7)), - "uuid_str" if num_args == 1 => Some(ExtFunc::Uuid(Self::UuidStr)), - "uuid_blob" if num_args == 1 => Some(ExtFunc::Uuid(Self::UuidBlob)), - "uuid7_timestamp_ms" if num_args == 1 => Some(ExtFunc::Uuid(Self::Uuid7TS)), - // postgres_compatability - "gen_random_uuid" => Some(ExtFunc::Uuid(Self::Uuid4Str)), - _ => None, - } - } -} - -impl std::fmt::Display for UuidFunc { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Uuid4Str => write!(f, "uuid4_str"), - Self::Uuid7 => write!(f, "uuid7"), - Self::Uuid7TS => write!(f, "uuid7_timestamp_ms"), - Self::UuidStr => write!(f, "uuid_str"), - Self::UuidBlob => write!(f, "uuid_blob"), - } - } -} - -pub fn exec_uuid(var: &UuidFunc, sec: Option<&OwnedValue>) -> crate::Result { - match var { - UuidFunc::Uuid4Str => Ok(OwnedValue::Text(LimboText::new(Rc::new( - Uuid::new_v4().to_string(), - )))), - UuidFunc::Uuid7 => { - let uuid = match sec { - Some(OwnedValue::Integer(ref seconds)) => { - let ctx = ContextV7::new(); - if *seconds < 0 { - // not valid unix timestamp, error or null? - return Ok(OwnedValue::Null); - } - Uuid::new_v7(Timestamp::from_unix(ctx, *seconds as u64, 0)) - } - _ => Uuid::now_v7(), - }; - Ok(OwnedValue::Blob(Rc::new(uuid.into_bytes().to_vec()))) - } - _ => unreachable!(), - } -} - -pub fn exec_uuid4() -> crate::Result { - Ok(OwnedValue::Blob(Rc::new( - Uuid::new_v4().into_bytes().to_vec(), - ))) -} - -pub fn exec_uuidstr(reg: &OwnedValue) -> crate::Result { - match reg { - OwnedValue::Blob(blob) => { - let uuid = Uuid::from_slice(blob).map_err(|e| LimboError::ParseError(e.to_string()))?; - Ok(OwnedValue::Text(LimboText::new(Rc::new(uuid.to_string())))) - } - OwnedValue::Text(ref val) => { - let uuid = - Uuid::parse_str(&val.value).map_err(|e| LimboError::ParseError(e.to_string()))?; - Ok(OwnedValue::Text(LimboText::new(Rc::new(uuid.to_string())))) - } - OwnedValue::Null => Ok(OwnedValue::Null), - _ => Err(LimboError::ParseError( - "Invalid argument type for UUID function".to_string(), - )), - } -} - -pub fn exec_uuidblob(reg: &OwnedValue) -> crate::Result { - match reg { - OwnedValue::Text(val) => { - let uuid = - Uuid::parse_str(&val.value).map_err(|e| LimboError::ParseError(e.to_string()))?; - Ok(OwnedValue::Blob(Rc::new(uuid.as_bytes().to_vec()))) - } - OwnedValue::Blob(blob) => { - let uuid = Uuid::from_slice(blob).map_err(|e| LimboError::ParseError(e.to_string()))?; - Ok(OwnedValue::Blob(Rc::new(uuid.as_bytes().to_vec()))) - } - OwnedValue::Null => Ok(OwnedValue::Null), - _ => Err(LimboError::ParseError( - "Invalid argument type for UUID function".to_string(), - )), - } -} - -pub fn exec_ts_from_uuid7(reg: &OwnedValue) -> OwnedValue { - let uuid = match reg { - OwnedValue::Blob(blob) => { - Uuid::from_slice(blob).map_err(|e| LimboError::ParseError(e.to_string())) - } - OwnedValue::Text(val) => { - Uuid::parse_str(&val.value).map_err(|e| LimboError::ParseError(e.to_string())) - } - _ => Err(LimboError::ParseError( - "Invalid argument type for UUID function".to_string(), - )), - }; - match uuid { - Ok(uuid) => OwnedValue::Integer(uuid_to_unix(uuid.as_bytes()) as i64), - // display error? sqlean seems to set value to null - Err(_) => OwnedValue::Null, - } -} - -#[inline(always)] -fn uuid_to_unix(uuid: &[u8; 16]) -> u64 { - ((uuid[0] as u64) << 40) - | ((uuid[1] as u64) << 32) - | ((uuid[2] as u64) << 24) - | ((uuid[3] as u64) << 16) - | ((uuid[4] as u64) << 8) - | (uuid[5] as u64) -} - -//pub fn init(db: &mut Database) { -// db.define_scalar_function("uuid4", |_args| exec_uuid4()); -//} - -#[cfg(test)] -#[cfg(feature = "uuid")] -pub mod test { - use super::UuidFunc; - use crate::types::OwnedValue; - #[test] - fn test_exec_uuid_v4blob() { - use super::exec_uuid4; - use uuid::Uuid; - let owned_val = exec_uuid4(); - match owned_val { - Ok(OwnedValue::Blob(blob)) => { - assert_eq!(blob.len(), 16); - let uuid = Uuid::from_slice(&blob); - assert!(uuid.is_ok()); - assert_eq!(uuid.unwrap().get_version_num(), 4); - } - _ => panic!("exec_uuid did not return a Blob variant"), - } - } - - #[test] - fn test_exec_uuid_v4str() { - use super::{exec_uuid, UuidFunc}; - use uuid::Uuid; - let func = UuidFunc::Uuid4Str; - let owned_val = exec_uuid(&func, None); - match owned_val { - Ok(OwnedValue::Text(v4str)) => { - assert_eq!(v4str.value.len(), 36); - let uuid = Uuid::parse_str(&v4str.value); - assert!(uuid.is_ok()); - assert_eq!(uuid.unwrap().get_version_num(), 4); - } - _ => panic!("exec_uuid did not return a Blob variant"), - } - } - - #[test] - fn test_exec_uuid_v7_now() { - use super::{exec_uuid, UuidFunc}; - use uuid::Uuid; - let func = UuidFunc::Uuid7; - let owned_val = exec_uuid(&func, None); - match owned_val { - Ok(OwnedValue::Blob(blob)) => { - assert_eq!(blob.len(), 16); - let uuid = Uuid::from_slice(&blob); - assert!(uuid.is_ok()); - assert_eq!(uuid.unwrap().get_version_num(), 7); - } - _ => panic!("exec_uuid did not return a Blob variant"), - } - } - - #[test] - fn test_exec_uuid_v7_with_input() { - use super::{exec_uuid, UuidFunc}; - use uuid::Uuid; - let func = UuidFunc::Uuid7; - let owned_val = exec_uuid(&func, Some(&OwnedValue::Integer(946702800))); - match owned_val { - Ok(OwnedValue::Blob(blob)) => { - assert_eq!(blob.len(), 16); - let uuid = Uuid::from_slice(&blob); - assert!(uuid.is_ok()); - assert_eq!(uuid.unwrap().get_version_num(), 7); - } - _ => panic!("exec_uuid did not return a Blob variant"), - } - } - - #[test] - fn test_exec_uuid_v7_now_to_timestamp() { - use super::{exec_ts_from_uuid7, exec_uuid, UuidFunc}; - use uuid::Uuid; - let func = UuidFunc::Uuid7; - let owned_val = exec_uuid(&func, None); - match owned_val { - Ok(OwnedValue::Blob(ref blob)) => { - assert_eq!(blob.len(), 16); - let uuid = Uuid::from_slice(blob); - assert!(uuid.is_ok()); - assert_eq!(uuid.unwrap().get_version_num(), 7); - } - _ => panic!("exec_uuid did not return a Blob variant"), - } - let result = exec_ts_from_uuid7(&owned_val.expect("uuid7")); - if let OwnedValue::Integer(ref ts) = result { - let unixnow = (std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_secs() - * 1000) as i64; - assert!(*ts >= unixnow - 1000); - } - } - - #[test] - fn test_exec_uuid_v7_to_timestamp() { - use super::{exec_ts_from_uuid7, exec_uuid, UuidFunc}; - use uuid::Uuid; - let func = UuidFunc::Uuid7; - let owned_val = exec_uuid(&func, Some(&OwnedValue::Integer(946702800))); - match owned_val { - Ok(OwnedValue::Blob(ref blob)) => { - assert_eq!(blob.len(), 16); - let uuid = Uuid::from_slice(blob); - assert!(uuid.is_ok()); - assert_eq!(uuid.unwrap().get_version_num(), 7); - } - _ => panic!("exec_uuid did not return a Blob variant"), - } - let result = exec_ts_from_uuid7(&owned_val.expect("uuid7")); - assert_eq!(result, OwnedValue::Integer(946702800 * 1000)); - if let OwnedValue::Integer(ts) = result { - let time = chrono::DateTime::from_timestamp(ts / 1000, 0); - assert_eq!( - time.unwrap(), - "2000-01-01T05:00:00Z" - .parse::>() - .unwrap() - ); - } - } - - #[test] - fn test_exec_uuid_v4_str_to_blob() { - use super::{exec_uuid, exec_uuidblob, UuidFunc}; - use uuid::Uuid; - let owned_val = exec_uuidblob( - &exec_uuid(&UuidFunc::Uuid4Str, None).expect("uuid v4 string to generate"), - ); - match owned_val { - Ok(OwnedValue::Blob(blob)) => { - assert_eq!(blob.len(), 16); - let uuid = Uuid::from_slice(&blob); - assert!(uuid.is_ok()); - assert_eq!(uuid.unwrap().get_version_num(), 4); - } - _ => panic!("exec_uuid did not return a Blob variant"), - } - } - - #[test] - fn test_exec_uuid_v7_str_to_blob() { - use super::{exec_uuid, exec_uuidblob, exec_uuidstr, UuidFunc}; - use uuid::Uuid; - // convert a v7 blob to a string then back to a blob - let owned_val = exec_uuidblob( - &exec_uuidstr(&exec_uuid(&UuidFunc::Uuid7, None).expect("uuid v7 blob to generate")) - .expect("uuid v7 string to generate"), - ); - match owned_val { - Ok(OwnedValue::Blob(blob)) => { - assert_eq!(blob.len(), 16); - let uuid = Uuid::from_slice(&blob); - assert!(uuid.is_ok()); - assert_eq!(uuid.unwrap().get_version_num(), 7); - } - _ => panic!("exec_uuid did not return a Blob variant"), - } - } - - #[test] - fn test_exec_uuid_v4_blob_to_str() { - use super::{exec_uuid4, exec_uuidstr}; - use uuid::Uuid; - // convert a v4 blob to a string - let owned_val = exec_uuidstr(&exec_uuid4().expect("uuid v7 blob to generate")); - match owned_val { - Ok(OwnedValue::Text(v4str)) => { - assert_eq!(v4str.value.len(), 36); - let uuid = Uuid::parse_str(&v4str.value); - assert!(uuid.is_ok()); - assert_eq!(uuid.unwrap().get_version_num(), 4); - } - _ => panic!("exec_uuid did not return a Blob variant"), - } - } - - #[test] - fn test_exec_uuid_v7_blob_to_str() { - use super::{exec_uuid, exec_uuidstr}; - use uuid::Uuid; - // convert a v7 blob to a string - let owned_val = exec_uuidstr( - &exec_uuid(&UuidFunc::Uuid7, Some(&OwnedValue::Integer(123456789))) - .expect("uuid v7 blob to generate"), - ); - match owned_val { - Ok(OwnedValue::Text(v7str)) => { - assert_eq!(v7str.value.len(), 36); - let uuid = Uuid::parse_str(&v7str.value); - assert!(uuid.is_ok()); - assert_eq!(uuid.unwrap().get_version_num(), 7); - } - _ => panic!("exec_uuid did not return a Blob variant"), - } - } -} diff --git a/extensions/uuid/src/lib.rs b/extensions/uuid/src/lib.rs index d88f2f887..92e9d5d4b 100644 --- a/extensions/uuid/src/lib.rs +++ b/extensions/uuid/src/lib.rs @@ -100,7 +100,9 @@ declare_scalar_functions! { let Some(text) = args[0].to_text() else { return Value::null(); }; - let uuid = uuid::Uuid::parse_str(unsafe {text.as_str()}).unwrap(); + let Ok(uuid) = uuid::Uuid::parse_str(unsafe {text.as_str()}) else { + return Value::null(); + }; let unix = uuid_to_unix(uuid.as_bytes()); Value::from_integer(unix as i64) } From 9c208dc866d0e4cb0e4053cdb0a83b9b71ab0038 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Sun, 12 Jan 2025 17:32:20 -0500 Subject: [PATCH 9/9] Add tests for first extension --- Cargo.lock | 17 +++-- Makefile | 9 ++- cli/app.rs | 5 +- core/lib.rs | 7 ++ core/translate/select.rs | 1 + core/types.rs | 25 +++---- extensions/uuid/src/lib.rs | 19 +++-- limbo_extension/src/lib.rs | 62 ++++++----------- testing/extensions.py | 139 +++++++++++++++++++++++++++++++++++++ 9 files changed, 205 insertions(+), 79 deletions(-) create mode 100755 testing/extensions.py diff --git a/Cargo.lock b/Cargo.lock index d3fffc7a0..7de3e3500 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1248,7 +1248,13 @@ dependencies = [ ] [[package]] -<<<<<<< HEAD +name = "limbo_extension" +version = "0.0.12" +dependencies = [ + "log", +] + +[[package]] name = "limbo_libsql" version = "0.0.12" dependencies = [ @@ -1257,13 +1263,6 @@ dependencies = [ "tokio", ] -[[package]] -name = "limbo_extension" -version = "0.0.11" -dependencies = [ - "log", -] - [[package]] name = "limbo_macros" version = "0.0.12" @@ -1294,7 +1293,7 @@ dependencies = [ [[package]] name = "limbo_uuid" -version = "0.0.11" +version = "0.0.12" dependencies = [ "limbo_extension", "log", diff --git a/Makefile b/Makefile index 30d84bcc4..109a3f147 100644 --- a/Makefile +++ b/Makefile @@ -62,10 +62,15 @@ limbo-wasm: cargo build --package limbo-wasm --target wasm32-wasi .PHONY: limbo-wasm -test: limbo test-compat test-sqlite3 test-shell +test: limbo test-compat test-sqlite3 test-shell test-extensions .PHONY: test -test-shell: limbo +test-extensions: limbo + cargo build --package limbo_uuid + ./testing/extensions.py +.PHONY: test-extensions + +test-shell: limbo SQLITE_EXEC=$(SQLITE_EXEC) ./testing/shelltests.py .PHONY: test-shell diff --git a/cli/app.rs b/cli/app.rs index 62108ca6d..a0a7b5d99 100644 --- a/cli/app.rs +++ b/cli/app.rs @@ -323,6 +323,7 @@ impl Limbo { }; } + #[cfg(not(target_family = "wasm"))] fn handle_load_extension(&mut self, path: &str) -> Result<(), String> { self.conn.load_extension(path).map_err(|e| e.to_string()) } @@ -550,7 +551,9 @@ impl Limbo { let _ = self.writeln(e.to_string()); }; } - Command::LoadExtension => { + Command::LoadExtension => + { + #[cfg(not(target_family = "wasm"))] if let Err(e) = self.handle_load_extension(args[1]) { let _ = self.writeln(&e); } diff --git a/core/lib.rs b/core/lib.rs index f99bd3db3..8f8914b0f 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -18,7 +18,9 @@ mod vdbe; static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; use fallible_iterator::FallibleIterator; +#[cfg(not(target_family = "wasm"))] use libloading::{Library, Symbol}; +#[cfg(not(target_family = "wasm"))] use limbo_extension::{ExtensionApi, ExtensionEntryPoint, RESULT_OK}; use log::trace; use schema::Schema; @@ -179,6 +181,7 @@ impl Database { .insert(name.as_ref().to_string(), func.into()); } + #[cfg(not(target_family = "wasm"))] pub fn load_extension(&self, path: &str) -> Result<()> { let api = Box::new(self.build_limbo_extension()); let lib = @@ -397,6 +400,7 @@ impl Connection { Ok(()) } + #[cfg(not(target_family = "wasm"))] pub fn load_extension(&self, path: &str) -> Result<()> { Database::load_extension(self.db.as_ref(), path) } @@ -499,6 +503,7 @@ impl Rows { pub(crate) struct SymbolTable { pub functions: HashMap>, + #[cfg(not(target_family = "wasm"))] extensions: Vec<(libloading::Library, *const ExtensionApi)>, } @@ -514,6 +519,8 @@ impl SymbolTable { pub fn new() -> Self { Self { functions: HashMap::new(), + // TODO: wasm libs will be very different + #[cfg(not(target_family = "wasm"))] extensions: Vec::new(), } } diff --git a/core/translate/select.rs b/core/translate/select.rs index fa5361205..768474a8e 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -20,6 +20,7 @@ pub fn translate_select( select: ast::Select, syms: &SymbolTable, ) -> Result<()> { + let mut select_plan = prepare_select_plan(schema, select, syms)?; optimize_plan(&mut select_plan)?; emit_program(program, select_plan, syms) } diff --git a/core/types.rs b/core/types.rs index 82e8c787b..81d9d5328 100644 --- a/core/types.rs +++ b/core/types.rs @@ -107,35 +107,30 @@ impl OwnedValue { } pub fn from_ffi(v: &ExtValue) -> Self { - if v.value.is_null() { - return OwnedValue::Null; - } - match v.value_type { + match v.value_type() { ExtValueType::Null => OwnedValue::Null, ExtValueType::Integer => { - let int_ptr = v.value as *mut i64; - let integer = unsafe { *int_ptr }; - OwnedValue::Integer(integer) + let Some(int) = v.to_integer() else { + return OwnedValue::Null; + }; + OwnedValue::Integer(int) } ExtValueType::Float => { - let float_ptr = v.value as *mut f64; - let float = unsafe { *float_ptr }; + let Some(float) = v.to_float() else { + return OwnedValue::Null; + }; OwnedValue::Float(float) } ExtValueType::Text => { let Some(text) = v.to_text() else { return OwnedValue::Null; }; - OwnedValue::build_text(std::rc::Rc::new(unsafe { text.as_str().to_string() })) + OwnedValue::build_text(std::rc::Rc::new(text)) } ExtValueType::Blob => { - let Some(blob_ptr) = v.to_blob() else { + let Some(blob) = v.to_blob() else { return OwnedValue::Null; }; - let blob = unsafe { - let slice = std::slice::from_raw_parts(blob_ptr.data, blob_ptr.size as usize); - slice.to_vec() - }; OwnedValue::Blob(std::rc::Rc::new(blob)) } } diff --git a/extensions/uuid/src/lib.rs b/extensions/uuid/src/lib.rs index 92e9d5d4b..e6a3a4f9b 100644 --- a/extensions/uuid/src/lib.rs +++ b/extensions/uuid/src/lib.rs @@ -35,7 +35,7 @@ declare_scalar_functions! { uuid::Timestamp::now(ctx) } else { let arg = &args[0]; - match arg.value_type { + match arg.value_type() { ValueType::Integer => { let ctx = uuid::ContextV7::new(); let Some(int) = arg.to_integer() else { @@ -47,8 +47,7 @@ declare_scalar_functions! { let Some(text) = arg.to_text() else { return Value::null(); }; - let parsed = unsafe{text.as_str()}.parse::(); - match parsed { + match text.parse::() { Ok(unix) => { if unix <= 0 { return Value::null(); @@ -70,7 +69,7 @@ declare_scalar_functions! { let timestamp = if args.is_empty() { let ctx = uuid::ContextV7::new(); uuid::Timestamp::now(ctx) - } else if args[0].value_type == limbo_extension::ValueType::Integer { + } else if args[0].value_type() == limbo_extension::ValueType::Integer { let ctx = uuid::ContextV7::new(); let Some(int) = args[0].to_integer() else { return Value::null(); @@ -86,13 +85,12 @@ declare_scalar_functions! { #[args(1)] fn exec_ts_from_uuid7(args: &[Value]) -> Value { - match args[0].value_type { + match args[0].value_type() { ValueType::Blob => { let Some(blob) = &args[0].to_blob() else { return Value::null(); }; - let slice = unsafe{ std::slice::from_raw_parts(blob.data, blob.size as usize)}; - let uuid = uuid::Uuid::from_slice(slice).unwrap(); + let uuid = uuid::Uuid::from_slice(blob.as_slice()).unwrap(); let unix = uuid_to_unix(uuid.as_bytes()); Value::from_integer(unix as i64) } @@ -100,7 +98,7 @@ declare_scalar_functions! { let Some(text) = args[0].to_text() else { return Value::null(); }; - let Ok(uuid) = uuid::Uuid::parse_str(unsafe {text.as_str()}) else { + let Ok(uuid) = uuid::Uuid::parse_str(&text) else { return Value::null(); }; let unix = uuid_to_unix(uuid.as_bytes()); @@ -115,8 +113,7 @@ declare_scalar_functions! { let Some(blob) = args[0].to_blob() else { return Value::null(); }; - let slice = unsafe{ std::slice::from_raw_parts(blob.data, blob.size as usize)}; - let parsed = uuid::Uuid::from_slice(slice).ok().map(|u| u.to_string()); + let parsed = uuid::Uuid::from_slice(blob.as_slice()).ok().map(|u| u.to_string()); match parsed { Some(s) => Value::from_text(s), None => Value::null() @@ -128,7 +125,7 @@ declare_scalar_functions! { let Some(text) = args[0].to_text() else { return Value::null(); }; - match uuid::Uuid::parse_str(unsafe {text.as_str()}) { + match uuid::Uuid::parse_str(&text) { Ok(uuid) => { Value::from_blob(uuid.as_bytes().to_vec()) } diff --git a/limbo_extension/src/lib.rs b/limbo_extension/src/lib.rs index d07cc8ea7..0666c588b 100644 --- a/limbo_extension/src/lib.rs +++ b/limbo_extension/src/lib.rs @@ -50,19 +50,6 @@ macro_rules! register_scalar_functions { } } -/// Provide a cleaner interface to define scalar functions to extension authors -/// . e.g. -/// ``` -/// #[args(1)] -/// fn scalar_double(args: &[Value]) -> Value { -/// Value::from_integer(args[0].integer * 2) -/// } -/// -/// #[args(0..=2)] -/// fn scalar_sum(args: &[Value]) -> Value { -/// Value::from_integer(args.iter().map(|v| v.integer).sum()) -/// ``` -/// #[macro_export] macro_rules! declare_scalar_functions { ( @@ -100,7 +87,7 @@ macro_rules! declare_scalar_functions { } #[repr(C)] -#[derive(PartialEq, Eq)] +#[derive(PartialEq, Eq, Clone, Copy)] pub enum ValueType { Null, Integer, @@ -111,8 +98,8 @@ pub enum ValueType { #[repr(C)] pub struct Value { - pub value_type: ValueType, - pub value: *mut c_void, + value_type: ValueType, + value: *mut c_void, } impl std::fmt::Debug for Value { @@ -161,41 +148,27 @@ impl Default for TextValue { } impl TextValue { - pub fn new(text: *const u8, len: usize) -> Self { + pub(crate) fn new(text: *const u8, len: usize) -> Self { Self { text, len: len as u32, } } - /// # Safety - /// Safe to call if the pointer is null, returns None - /// if the value is not a text type or if the value is null - pub unsafe fn from_value(value: &Value) -> Option<&Self> { - if value.value_type != ValueType::Text { - return None; - } - if value.value.is_null() { - return None; - } - Some(&*(value.value as *const TextValue)) - } - - /// # Safety - /// If self.text is null we safely return an empty string but - /// the caller must ensure that the underlying value is valid utf8 - pub unsafe fn as_str(&self) -> &str { + fn as_str(&self) -> &str { if self.text.is_null() { return ""; } - std::str::from_utf8_unchecked(std::slice::from_raw_parts(self.text, self.len as usize)) + unsafe { + std::str::from_utf8_unchecked(std::slice::from_raw_parts(self.text, self.len as usize)) + } } } #[repr(C)] pub struct Blob { - pub data: *const u8, - pub size: u64, + data: *const u8, + size: u64, } impl std::fmt::Debug for Blob { @@ -218,6 +191,10 @@ impl Value { } } + pub fn value_type(&self) -> ValueType { + self.value_type + } + pub fn to_float(&self) -> Option { if self.value_type != ValueType::Float { return None; @@ -228,24 +205,27 @@ impl Value { Some(unsafe { *(self.value as *const f64) }) } - pub fn to_text(&self) -> Option<&TextValue> { + pub fn to_text(&self) -> Option { if self.value_type != ValueType::Text { return None; } if self.value.is_null() { return None; } - unsafe { Some(&*(self.value as *const TextValue)) } + let txt = unsafe { &*(self.value as *const TextValue) }; + Some(String::from(txt.as_str())) } - pub fn to_blob(&self) -> Option<&Blob> { + pub fn to_blob(&self) -> Option> { if self.value_type != ValueType::Blob { return None; } if self.value.is_null() { return None; } - unsafe { Some(&*(self.value as *const Blob)) } + let blob = unsafe { &*(self.value as *const Blob) }; + let slice = unsafe { std::slice::from_raw_parts(blob.data, blob.size as usize) }; + Some(slice.to_vec()) } pub fn to_integer(&self) -> Option { diff --git a/testing/extensions.py b/testing/extensions.py new file mode 100755 index 000000000..74383be94 --- /dev/null +++ b/testing/extensions.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python3 +import os +import subprocess +import select +import time +import uuid + +sqlite_exec = "./target/debug/limbo" +sqlite_flags = os.getenv("SQLITE_FLAGS", "-q").split(" ") + + +def init_limbo(): + pipe = subprocess.Popen( + [sqlite_exec, *sqlite_flags], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + bufsize=0, + ) + return pipe + + +def execute_sql(pipe, sql): + end_suffix = "END_OF_RESULT" + write_to_pipe(pipe, sql) + write_to_pipe(pipe, f"SELECT '{end_suffix}';\n") + stdout = pipe.stdout + stderr = pipe.stderr + output = "" + while True: + ready_to_read, _, error_in_pipe = select.select( + [stdout, stderr], [], [stdout, stderr] + ) + ready_to_read_or_err = set(ready_to_read + error_in_pipe) + if stderr in ready_to_read_or_err: + exit_on_error(stderr) + + if stdout in ready_to_read_or_err: + fragment = stdout.read(select.PIPE_BUF) + output += fragment.decode() + if output.rstrip().endswith(end_suffix): + output = output.rstrip().removesuffix(end_suffix) + break + output = strip_each_line(output) + return output + + +def strip_each_line(lines: str) -> str: + lines = lines.split("\n") + lines = [line.strip() for line in lines if line != ""] + return "\n".join(lines) + + +def write_to_pipe(pipe, command): + if pipe.stdin is None: + raise RuntimeError("Failed to write to shell") + pipe.stdin.write((command + "\n").encode()) + pipe.stdin.flush() + + +def exit_on_error(stderr): + while True: + ready_to_read, _, _ = select.select([stderr], [], []) + if not ready_to_read: + break + print(stderr.read().decode(), end="") + exit(1) + + +def run_test(pipe, sql, validator=None): + print(f"Running test: {sql}") + result = execute_sql(pipe, sql) + if validator is not None: + if not validator(result): + print(f"Test FAILED: {sql}") + print(f"Returned: {result}") + raise Exception("Validation failed") + print("Test PASSED") + + +def validate_blob(result): + # HACK: blobs are difficult to test because the shell + # tries to return them as utf8 strings, so we call hex + # and assert they are valid hex digits + return int(result, 16) is not None + + +def validate_string_uuid(result): + return len(result) == 36 and result.count("-") == 4 + + +def returns_null(result): + return result == "" or result == b"\n" or result == b"" + + +def assert_now_unixtime(result): + return result == str(int(time.time())) + + +def assert_specific_time(result): + return result == "1736720789" + + +def main(): + specific_time = "01945ca0-3189-76c0-9a8f-caf310fc8b8e" + extension_path = "./target/debug/liblimbo_uuid.so" + pipe = init_limbo() + try: + # before extension loads, assert no function + run_test(pipe, "SELECT uuid4();", returns_null) + run_test(pipe, "SELECT uuid4_str();", returns_null) + run_test(pipe, f".load {extension_path}", returns_null) + print("Extension loaded successfully.") + run_test(pipe, "SELECT hex(uuid4());", validate_blob) + run_test(pipe, "SELECT uuid4_str();", validate_string_uuid) + run_test(pipe, "SELECT hex(uuid7());", validate_blob) + run_test( + pipe, + "SELECT uuid7_timestamp_ms(uuid7()) / 1000;", + ) + run_test(pipe, "SELECT uuid7_str();", validate_string_uuid) + run_test(pipe, "SELECT uuid_str(uuid7());", validate_string_uuid) + run_test(pipe, "SELECT hex(uuid_blob(uuid7_str()));", validate_blob) + run_test(pipe, "SELECT uuid_str(uuid_blob(uuid7_str()));", validate_string_uuid) + run_test( + pipe, + f"SELECT uuid7_timestamp_ms('{specific_time}') / 1000;", + assert_specific_time, + ) + except Exception as e: + print(f"Test FAILED: {e}") + pipe.terminate() + exit(1) + pipe.terminate() + print("All tests passed successfully.") + + +if __name__ == "__main__": + main()