diff --git a/Cargo.lock b/Cargo.lock index bb69e5c0f..7de3e3500 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1137,6 +1137,16 @@ version = "0.2.169" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" +[[package]] +name = "libloading" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" +dependencies = [ + "cfg-if", + "windows-targets 0.52.6", +] + [[package]] name = "libmimalloc-sys" version = "0.1.39" @@ -1212,6 +1222,8 @@ dependencies = [ "jsonb", "julian_day_converter", "libc", + "libloading", + "limbo_extension", "limbo_macros", "log", "miette", @@ -1235,6 +1247,13 @@ dependencies = [ "uuid", ] +[[package]] +name = "limbo_extension" +version = "0.0.12" +dependencies = [ + "log", +] + [[package]] name = "limbo_libsql" version = "0.0.12" @@ -1272,6 +1291,15 @@ dependencies = [ "log", ] +[[package]] +name = "limbo_uuid" +version = "0.0.12" +dependencies = [ + "limbo_extension", + "log", + "uuid", +] + [[package]] name = "linux-raw-sys" version = "0.4.14" diff --git a/Cargo.toml b/Cargo.toml index 0d2fc81be..44ec1ef15 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,7 @@ members = [ "sqlite3", "core", "simulator", - "test", "macros", + "test", "macros", "limbo_extension", "extensions/uuid", ] exclude = ["perf/latency/limbo"] diff --git a/Makefile b/Makefile index 30d84bcc4..109a3f147 100644 --- a/Makefile +++ b/Makefile @@ -62,10 +62,15 @@ limbo-wasm: cargo build --package limbo-wasm --target wasm32-wasi .PHONY: limbo-wasm -test: limbo test-compat test-sqlite3 test-shell +test: limbo test-compat test-sqlite3 test-shell test-extensions .PHONY: test -test-shell: limbo +test-extensions: limbo + cargo build --package limbo_uuid + ./testing/extensions.py +.PHONY: test-extensions + +test-shell: limbo SQLITE_EXEC=$(SQLITE_EXEC) ./testing/shelltests.py .PHONY: 
test-shell diff --git a/cli/app.rs b/cli/app.rs index f114b63a2..a0a7b5d99 100644 --- a/cli/app.rs +++ b/cli/app.rs @@ -129,6 +129,8 @@ pub enum Command { Tables, /// Import data from FILE into TABLE Import, + /// Loads an extension library + LoadExtension, } impl Command { @@ -141,7 +143,12 @@ impl Command { | Self::ShowInfo | Self::Tables | Self::SetOutput => 0, - Self::Open | Self::OutputMode | Self::Cwd | Self::Echo | Self::NullValue => 1, + Self::Open + | Self::OutputMode + | Self::Cwd + | Self::Echo + | Self::NullValue + | Self::LoadExtension => 1, Self::Import => 2, } + 1) // argv0 } @@ -160,6 +167,7 @@ impl Command { Self::NullValue => ".nullvalue ", Self::Echo => ".echo on|off", Self::Tables => ".tables", + Self::LoadExtension => ".load", Self::Import => &IMPORT_HELP, } } @@ -182,6 +190,7 @@ impl FromStr for Command { ".nullvalue" => Ok(Self::NullValue), ".echo" => Ok(Self::Echo), ".import" => Ok(Self::Import), + ".load" => Ok(Self::LoadExtension), _ => Err("Unknown command".to_string()), } } @@ -314,6 +323,11 @@ impl Limbo { }; } + #[cfg(not(target_family = "wasm"))] + fn handle_load_extension(&mut self, path: &str) -> Result<(), String> { + self.conn.load_extension(path).map_err(|e| e.to_string()) + } + fn display_in_memory(&mut self) -> std::io::Result<()> { if self.opts.db_file == ":memory:" { self.writeln("Connected to a transient in-memory database.")?; @@ -537,6 +551,13 @@ impl Limbo { let _ = self.writeln(e.to_string()); }; } + Command::LoadExtension => + { + #[cfg(not(target_family = "wasm"))] + if let Err(e) = self.handle_load_extension(args[1]) { + let _ = self.writeln(&e); + } + } } } else { let _ = self.write_fmt(format_args!( diff --git a/core/Cargo.toml b/core/Cargo.toml index 0daa58c0d..c6042a298 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -35,6 +35,7 @@ rustix = "0.38.34" mimalloc = { version = "*", default-features = false } [dependencies] +limbo_extension = { path = "../limbo_extension" } cfg_block = "0.1.1" fallible-iterator = 
"0.3.0" hex = "0.4.3" @@ -58,6 +59,7 @@ bumpalo = { version = "3.16.0", features = ["collections", "boxed"] } limbo_macros = { path = "../macros" } uuid = { version = "1.11.0", features = ["v4", "v7"], optional = true } miette = "7.4.0" +libloading = "0.8.6" [target.'cfg(not(target_family = "windows"))'.dev-dependencies] pprof = { version = "0.14.0", features = ["criterion", "flamegraph"] } diff --git a/core/error.rs b/core/error.rs index 646e85825..e3e176b79 100644 --- a/core/error.rs +++ b/core/error.rs @@ -39,6 +39,8 @@ pub enum LimboError { InvalidModifier(String), #[error("Runtime error: {0}")] Constraint(String), + #[error("Extension error: {0}")] + ExtensionError(String), } #[macro_export] diff --git a/core/ext/mod.rs b/core/ext/mod.rs index cea65a98d..f1758324b 100644 --- a/core/ext/mod.rs +++ b/core/ext/mod.rs @@ -1,37 +1,44 @@ -#[cfg(feature = "uuid")] -mod uuid; -#[cfg(feature = "uuid")] -pub use uuid::{exec_ts_from_uuid7, exec_uuid, exec_uuidblob, exec_uuidstr, UuidFunc}; +use crate::{function::ExternalFunc, Database}; +use limbo_extension::{ExtensionApi, ResultCode, ScalarFunction, RESULT_ERROR, RESULT_OK}; +pub use limbo_extension::{Value as ExtValue, ValueType as ExtValueType}; +use std::{ + ffi::{c_char, c_void, CStr}, + rc::Rc, +}; -#[derive(Debug, Clone, PartialEq)] -pub enum ExtFunc { - #[cfg(feature = "uuid")] - Uuid(UuidFunc), +extern "C" fn register_scalar_function( + ctx: *mut c_void, + name: *const c_char, + func: ScalarFunction, +) -> ResultCode { + let c_str = unsafe { CStr::from_ptr(name) }; + let name_str = match c_str.to_str() { + Ok(s) => s.to_string(), + Err(_) => return RESULT_ERROR, + }; + if ctx.is_null() { + return RESULT_ERROR; + } + let db = unsafe { &*(ctx as *const Database) }; + db.register_scalar_function_impl(name_str, func) } -#[allow(unreachable_patterns)] // TODO: remove when more extension funcs added -impl std::fmt::Display for ExtFunc { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match 
self { - #[cfg(feature = "uuid")] - Self::Uuid(uuidfn) => write!(f, "{}", uuidfn), - _ => write!(f, "unknown"), +impl Database { + fn register_scalar_function_impl(&self, name: String, func: ScalarFunction) -> ResultCode { + self.syms.borrow_mut().functions.insert( + name.to_string(), + Rc::new(ExternalFunc { + name: name.to_string(), + func, + }), + ); + RESULT_OK + } + + pub fn build_limbo_extension(&self) -> ExtensionApi { + ExtensionApi { + ctx: self as *const _ as *mut c_void, + register_scalar_function, } } } - -#[allow(unreachable_patterns)] -impl ExtFunc { - pub fn resolve_function(name: &str, num_args: usize) -> Option { - match name { - #[cfg(feature = "uuid")] - name => UuidFunc::resolve_function(name, num_args), - _ => None, - } - } -} - -pub fn init(db: &mut crate::Database) { - #[cfg(feature = "uuid")] - uuid::init(db); -} diff --git a/core/ext/uuid.rs b/core/ext/uuid.rs deleted file mode 100644 index 92fdd831a..000000000 --- a/core/ext/uuid.rs +++ /dev/null @@ -1,343 +0,0 @@ -use super::ExtFunc; -use crate::{ - types::{LimboText, OwnedValue}, - Database, LimboError, -}; -use std::rc::Rc; -use uuid::{ContextV7, Timestamp, Uuid}; - -#[derive(Debug, Clone, Copy, PartialEq)] -pub enum UuidFunc { - Uuid4Str, - Uuid7, - Uuid7TS, - UuidStr, - UuidBlob, -} - -impl UuidFunc { - pub fn resolve_function(name: &str, num_args: usize) -> Option { - match name { - "uuid4_str" => Some(ExtFunc::Uuid(Self::Uuid4Str)), - "uuid7" if num_args < 2 => Some(ExtFunc::Uuid(Self::Uuid7)), - "uuid_str" if num_args == 1 => Some(ExtFunc::Uuid(Self::UuidStr)), - "uuid_blob" if num_args == 1 => Some(ExtFunc::Uuid(Self::UuidBlob)), - "uuid7_timestamp_ms" if num_args == 1 => Some(ExtFunc::Uuid(Self::Uuid7TS)), - // postgres_compatability - "gen_random_uuid" => Some(ExtFunc::Uuid(Self::Uuid4Str)), - _ => None, - } - } -} - -impl std::fmt::Display for UuidFunc { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Uuid4Str => write!(f, 
"uuid4_str"), - Self::Uuid7 => write!(f, "uuid7"), - Self::Uuid7TS => write!(f, "uuid7_timestamp_ms"), - Self::UuidStr => write!(f, "uuid_str"), - Self::UuidBlob => write!(f, "uuid_blob"), - } - } -} - -pub fn exec_uuid(var: &UuidFunc, sec: Option<&OwnedValue>) -> crate::Result { - match var { - UuidFunc::Uuid4Str => Ok(OwnedValue::Text(LimboText::new(Rc::new( - Uuid::new_v4().to_string(), - )))), - UuidFunc::Uuid7 => { - let uuid = match sec { - Some(OwnedValue::Integer(ref seconds)) => { - let ctx = ContextV7::new(); - if *seconds < 0 { - // not valid unix timestamp, error or null? - return Ok(OwnedValue::Null); - } - Uuid::new_v7(Timestamp::from_unix(ctx, *seconds as u64, 0)) - } - _ => Uuid::now_v7(), - }; - Ok(OwnedValue::Blob(Rc::new(uuid.into_bytes().to_vec()))) - } - _ => unreachable!(), - } -} - -pub fn exec_uuid4() -> crate::Result { - Ok(OwnedValue::Blob(Rc::new( - Uuid::new_v4().into_bytes().to_vec(), - ))) -} - -pub fn exec_uuidstr(reg: &OwnedValue) -> crate::Result { - match reg { - OwnedValue::Blob(blob) => { - let uuid = Uuid::from_slice(blob).map_err(|e| LimboError::ParseError(e.to_string()))?; - Ok(OwnedValue::Text(LimboText::new(Rc::new(uuid.to_string())))) - } - OwnedValue::Text(ref val) => { - let uuid = - Uuid::parse_str(&val.value).map_err(|e| LimboError::ParseError(e.to_string()))?; - Ok(OwnedValue::Text(LimboText::new(Rc::new(uuid.to_string())))) - } - OwnedValue::Null => Ok(OwnedValue::Null), - _ => Err(LimboError::ParseError( - "Invalid argument type for UUID function".to_string(), - )), - } -} - -pub fn exec_uuidblob(reg: &OwnedValue) -> crate::Result { - match reg { - OwnedValue::Text(val) => { - let uuid = - Uuid::parse_str(&val.value).map_err(|e| LimboError::ParseError(e.to_string()))?; - Ok(OwnedValue::Blob(Rc::new(uuid.as_bytes().to_vec()))) - } - OwnedValue::Blob(blob) => { - let uuid = Uuid::from_slice(blob).map_err(|e| LimboError::ParseError(e.to_string()))?; - Ok(OwnedValue::Blob(Rc::new(uuid.as_bytes().to_vec()))) - } - 
OwnedValue::Null => Ok(OwnedValue::Null), - _ => Err(LimboError::ParseError( - "Invalid argument type for UUID function".to_string(), - )), - } -} - -pub fn exec_ts_from_uuid7(reg: &OwnedValue) -> OwnedValue { - let uuid = match reg { - OwnedValue::Blob(blob) => { - Uuid::from_slice(blob).map_err(|e| LimboError::ParseError(e.to_string())) - } - OwnedValue::Text(val) => { - Uuid::parse_str(&val.value).map_err(|e| LimboError::ParseError(e.to_string())) - } - _ => Err(LimboError::ParseError( - "Invalid argument type for UUID function".to_string(), - )), - }; - match uuid { - Ok(uuid) => OwnedValue::Integer(uuid_to_unix(uuid.as_bytes()) as i64), - // display error? sqlean seems to set value to null - Err(_) => OwnedValue::Null, - } -} - -#[inline(always)] -fn uuid_to_unix(uuid: &[u8; 16]) -> u64 { - ((uuid[0] as u64) << 40) - | ((uuid[1] as u64) << 32) - | ((uuid[2] as u64) << 24) - | ((uuid[3] as u64) << 16) - | ((uuid[4] as u64) << 8) - | (uuid[5] as u64) -} - -pub fn init(db: &mut Database) { - db.define_scalar_function("uuid4", |_args| exec_uuid4()); -} - -#[cfg(test)] -#[cfg(feature = "uuid")] -pub mod test { - use super::UuidFunc; - use crate::types::OwnedValue; - #[test] - fn test_exec_uuid_v4blob() { - use super::exec_uuid4; - use uuid::Uuid; - let owned_val = exec_uuid4(); - match owned_val { - Ok(OwnedValue::Blob(blob)) => { - assert_eq!(blob.len(), 16); - let uuid = Uuid::from_slice(&blob); - assert!(uuid.is_ok()); - assert_eq!(uuid.unwrap().get_version_num(), 4); - } - _ => panic!("exec_uuid did not return a Blob variant"), - } - } - - #[test] - fn test_exec_uuid_v4str() { - use super::{exec_uuid, UuidFunc}; - use uuid::Uuid; - let func = UuidFunc::Uuid4Str; - let owned_val = exec_uuid(&func, None); - match owned_val { - Ok(OwnedValue::Text(v4str)) => { - assert_eq!(v4str.value.len(), 36); - let uuid = Uuid::parse_str(&v4str.value); - assert!(uuid.is_ok()); - assert_eq!(uuid.unwrap().get_version_num(), 4); - } - _ => panic!("exec_uuid did not return a Blob 
variant"), - } - } - - #[test] - fn test_exec_uuid_v7_now() { - use super::{exec_uuid, UuidFunc}; - use uuid::Uuid; - let func = UuidFunc::Uuid7; - let owned_val = exec_uuid(&func, None); - match owned_val { - Ok(OwnedValue::Blob(blob)) => { - assert_eq!(blob.len(), 16); - let uuid = Uuid::from_slice(&blob); - assert!(uuid.is_ok()); - assert_eq!(uuid.unwrap().get_version_num(), 7); - } - _ => panic!("exec_uuid did not return a Blob variant"), - } - } - - #[test] - fn test_exec_uuid_v7_with_input() { - use super::{exec_uuid, UuidFunc}; - use uuid::Uuid; - let func = UuidFunc::Uuid7; - let owned_val = exec_uuid(&func, Some(&OwnedValue::Integer(946702800))); - match owned_val { - Ok(OwnedValue::Blob(blob)) => { - assert_eq!(blob.len(), 16); - let uuid = Uuid::from_slice(&blob); - assert!(uuid.is_ok()); - assert_eq!(uuid.unwrap().get_version_num(), 7); - } - _ => panic!("exec_uuid did not return a Blob variant"), - } - } - - #[test] - fn test_exec_uuid_v7_now_to_timestamp() { - use super::{exec_ts_from_uuid7, exec_uuid, UuidFunc}; - use uuid::Uuid; - let func = UuidFunc::Uuid7; - let owned_val = exec_uuid(&func, None); - match owned_val { - Ok(OwnedValue::Blob(ref blob)) => { - assert_eq!(blob.len(), 16); - let uuid = Uuid::from_slice(blob); - assert!(uuid.is_ok()); - assert_eq!(uuid.unwrap().get_version_num(), 7); - } - _ => panic!("exec_uuid did not return a Blob variant"), - } - let result = exec_ts_from_uuid7(&owned_val.expect("uuid7")); - if let OwnedValue::Integer(ref ts) = result { - let unixnow = (std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_secs() - * 1000) as i64; - assert!(*ts >= unixnow - 1000); - } - } - - #[test] - fn test_exec_uuid_v7_to_timestamp() { - use super::{exec_ts_from_uuid7, exec_uuid, UuidFunc}; - use uuid::Uuid; - let func = UuidFunc::Uuid7; - let owned_val = exec_uuid(&func, Some(&OwnedValue::Integer(946702800))); - match owned_val { - Ok(OwnedValue::Blob(ref blob)) => { - assert_eq!(blob.len(), 
16); - let uuid = Uuid::from_slice(blob); - assert!(uuid.is_ok()); - assert_eq!(uuid.unwrap().get_version_num(), 7); - } - _ => panic!("exec_uuid did not return a Blob variant"), - } - let result = exec_ts_from_uuid7(&owned_val.expect("uuid7")); - assert_eq!(result, OwnedValue::Integer(946702800 * 1000)); - if let OwnedValue::Integer(ts) = result { - let time = chrono::DateTime::from_timestamp(ts / 1000, 0); - assert_eq!( - time.unwrap(), - "2000-01-01T05:00:00Z" - .parse::>() - .unwrap() - ); - } - } - - #[test] - fn test_exec_uuid_v4_str_to_blob() { - use super::{exec_uuid, exec_uuidblob, UuidFunc}; - use uuid::Uuid; - let owned_val = exec_uuidblob( - &exec_uuid(&UuidFunc::Uuid4Str, None).expect("uuid v4 string to generate"), - ); - match owned_val { - Ok(OwnedValue::Blob(blob)) => { - assert_eq!(blob.len(), 16); - let uuid = Uuid::from_slice(&blob); - assert!(uuid.is_ok()); - assert_eq!(uuid.unwrap().get_version_num(), 4); - } - _ => panic!("exec_uuid did not return a Blob variant"), - } - } - - #[test] - fn test_exec_uuid_v7_str_to_blob() { - use super::{exec_uuid, exec_uuidblob, exec_uuidstr, UuidFunc}; - use uuid::Uuid; - // convert a v7 blob to a string then back to a blob - let owned_val = exec_uuidblob( - &exec_uuidstr(&exec_uuid(&UuidFunc::Uuid7, None).expect("uuid v7 blob to generate")) - .expect("uuid v7 string to generate"), - ); - match owned_val { - Ok(OwnedValue::Blob(blob)) => { - assert_eq!(blob.len(), 16); - let uuid = Uuid::from_slice(&blob); - assert!(uuid.is_ok()); - assert_eq!(uuid.unwrap().get_version_num(), 7); - } - _ => panic!("exec_uuid did not return a Blob variant"), - } - } - - #[test] - fn test_exec_uuid_v4_blob_to_str() { - use super::{exec_uuid4, exec_uuidstr}; - use uuid::Uuid; - // convert a v4 blob to a string - let owned_val = exec_uuidstr(&exec_uuid4().expect("uuid v7 blob to generate")); - match owned_val { - Ok(OwnedValue::Text(v4str)) => { - assert_eq!(v4str.value.len(), 36); - let uuid = Uuid::parse_str(&v4str.value); - 
assert!(uuid.is_ok()); - assert_eq!(uuid.unwrap().get_version_num(), 4); - } - _ => panic!("exec_uuid did not return a Blob variant"), - } - } - - #[test] - fn test_exec_uuid_v7_blob_to_str() { - use super::{exec_uuid, exec_uuidstr}; - use uuid::Uuid; - // convert a v7 blob to a string - let owned_val = exec_uuidstr( - &exec_uuid(&UuidFunc::Uuid7, Some(&OwnedValue::Integer(123456789))) - .expect("uuid v7 blob to generate"), - ); - match owned_val { - Ok(OwnedValue::Text(v7str)) => { - assert_eq!(v7str.value.len(), 36); - let uuid = Uuid::parse_str(&v7str.value); - assert!(uuid.is_ok()); - assert_eq!(uuid.unwrap().get_version_num(), 7); - } - _ => panic!("exec_uuid did not return a Blob variant"), - } - } -} diff --git a/core/function.rs b/core/function.rs index 060a677c3..68b5ef005 100644 --- a/core/function.rs +++ b/core/function.rs @@ -1,11 +1,21 @@ -use crate::ext::ExtFunc; use std::fmt; use std::fmt::{Debug, Display}; use std::rc::Rc; +use limbo_extension::ScalarFunction; + pub struct ExternalFunc { pub name: String, - pub func: Box crate::Result>, + pub func: ScalarFunction, +} + +impl ExternalFunc { + pub fn new(name: &str, func: ScalarFunction) -> Self { + Self { + name: name.to_string(), + func, + } + } } impl Debug for ExternalFunc { @@ -299,7 +309,6 @@ pub enum Func { Math(MathFunc), #[cfg(feature = "json")] Json(JsonFunc), - Extension(ExtFunc), External(Rc), } @@ -311,7 +320,6 @@ impl Display for Func { Self::Math(math_func) => write!(f, "{}", math_func), #[cfg(feature = "json")] Self::Json(json_func) => write!(f, "{}", json_func), - Self::Extension(ext_func) => write!(f, "{}", ext_func), Self::External(generic_func) => write!(f, "{}", generic_func), } } @@ -418,10 +426,7 @@ impl Func { "tan" => Ok(Self::Math(MathFunc::Tan)), "tanh" => Ok(Self::Math(MathFunc::Tanh)), "trunc" => Ok(Self::Math(MathFunc::Trunc)), - _ => match ExtFunc::resolve_function(name, arg_count) { - Some(ext_func) => Ok(Self::Extension(ext_func)), - None => Err(()), - }, + _ => 
Err(()), } } } diff --git a/core/lib.rs b/core/lib.rs index a80fab83a..8f8914b0f 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -18,6 +18,10 @@ mod vdbe; static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; use fallible_iterator::FallibleIterator; +#[cfg(not(target_family = "wasm"))] +use libloading::{Library, Symbol}; +#[cfg(not(target_family = "wasm"))] +use limbo_extension::{ExtensionApi, ExtensionEntryPoint, RESULT_OK}; use log::trace; use schema::Schema; use sqlite3_parser::ast; @@ -34,12 +38,11 @@ use storage::pager::allocate_page; use storage::sqlite3_ondisk::{DatabaseHeader, DATABASE_HEADER_SIZE}; pub use storage::wal::WalFile; pub use storage::wal::WalFileShared; +pub use types::Value; use util::parse_schema_rows; -use translate::select::prepare_select_plan; -use types::OwnedValue; - pub use error::LimboError; +use translate::select::prepare_select_plan; pub type Result = std::result::Result; use crate::translate::optimizer::optimize_plan; @@ -56,8 +59,6 @@ pub use storage::pager::Page; pub use storage::pager::Pager; pub use storage::wal::CheckpointStatus; pub use storage::wal::Wal; -pub use types::Value; - pub static DATABASE_VERSION: OnceLock = OnceLock::new(); #[derive(Clone)] @@ -127,7 +128,7 @@ impl Database { let header = db_header; let schema = Rc::new(RefCell::new(Schema::new())); let syms = Rc::new(RefCell::new(SymbolTable::new())); - let mut db = Database { + let db = Database { pager: pager.clone(), schema: schema.clone(), header: header.clone(), @@ -135,11 +136,10 @@ impl Database { _shared_wal: shared_wal.clone(), syms, }; - ext::init(&mut db); let db = Arc::new(db); let conn = Rc::new(Connection { db: db.clone(), - pager: pager, + pager, schema: schema.clone(), header, transaction_state: RefCell::new(TransactionState::None), @@ -169,16 +169,40 @@ impl Database { pub fn define_scalar_function>( &self, name: S, - func: impl Fn(&[Value]) -> Result + 'static, + func: limbo_extension::ScalarFunction, ) { let func = function::ExternalFunc { name: 
name.as_ref().to_string(), - func: Box::new(func), + func, }; self.syms .borrow_mut() .functions - .insert(name.as_ref().to_string(), Rc::new(func)); + .insert(name.as_ref().to_string(), func.into()); + } + + #[cfg(not(target_family = "wasm"))] + pub fn load_extension(&self, path: &str) -> Result<()> { + let api = Box::new(self.build_limbo_extension()); + let lib = + unsafe { Library::new(path).map_err(|e| LimboError::ExtensionError(e.to_string()))? }; + let entry: Symbol = unsafe { + lib.get(b"register_extension") + .map_err(|e| LimboError::ExtensionError(e.to_string()))? + }; + let api_ptr: *const ExtensionApi = Box::into_raw(api); + let result_code = entry(api_ptr); + if result_code == RESULT_OK { + self.syms.borrow_mut().extensions.push((lib, api_ptr)); + Ok(()) + } else { + if !api_ptr.is_null() { + let _ = unsafe { Box::from_raw(api_ptr.cast_mut()) }; + } + Err(LimboError::ExtensionError( + "Extension registration failed".to_string(), + )) + } } } @@ -307,7 +331,11 @@ impl Connection { Cmd::ExplainQueryPlan(stmt) => { match stmt { ast::Stmt::Select(select) => { - let mut plan = prepare_select_plan(&self.schema.borrow(), *select)?; + let mut plan = prepare_select_plan( + &self.schema.borrow(), + *select, + &self.db.syms.borrow(), + )?; optimize_plan(&mut plan)?; println!("{}", plan); } @@ -372,6 +400,11 @@ impl Connection { Ok(()) } + #[cfg(not(target_family = "wasm"))] + pub fn load_extension(&self, path: &str) -> Result<()> { + Database::load_extension(self.db.as_ref(), path) + } + /// Close a connection and checkpoint. 
pub fn close(&self) -> Result<()> { loop { @@ -468,15 +501,27 @@ impl Rows { } } -#[derive(Debug)] pub(crate) struct SymbolTable { pub functions: HashMap>, + #[cfg(not(target_family = "wasm"))] + extensions: Vec<(libloading::Library, *const ExtensionApi)>, +} + +impl std::fmt::Debug for SymbolTable { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SymbolTable") + .field("functions", &self.functions) + .finish() + } } impl SymbolTable { pub fn new() -> Self { Self { functions: HashMap::new(), + // TODO: wasm libs will be very different + #[cfg(not(target_family = "wasm"))] + extensions: Vec::new(), } } diff --git a/core/translate/emitter.rs b/core/translate/emitter.rs index e376da160..aa1aaecf9 100644 --- a/core/translate/emitter.rs +++ b/core/translate/emitter.rs @@ -35,14 +35,13 @@ impl<'a> Resolver<'a> { } pub fn resolve_function(&self, func_name: &str, arg_count: usize) -> Option { - let func_type = match Func::resolve_function(&func_name, arg_count).ok() { + match Func::resolve_function(func_name, arg_count).ok() { Some(func) => Some(func), None => self .symbol_table - .resolve_function(&func_name, arg_count) - .map(|func| Func::External(func)), - }; - func_type + .resolve_function(func_name, arg_count) + .map(|arg| Func::External(arg.clone())), + } } pub fn resolve_cached_expr_reg(&self, expr: &ast::Expr) -> Option { diff --git a/core/translate/expr.rs b/core/translate/expr.rs index dd35a5d26..d13902433 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -1,7 +1,5 @@ use sqlite3_parser::ast::{self, UnaryOperator}; -#[cfg(feature = "uuid")] -use crate::ext::{ExtFunc, UuidFunc}; #[cfg(feature = "json")] use crate::function::JsonFunc; use crate::function::{Func, FuncCtx, MathFuncArity, ScalarFunc}; @@ -764,13 +762,23 @@ pub fn translate_expr( crate::bail_parse_error!("aggregation function in non-aggregation context") } Func::External(_) => { - let regs = program.alloc_register(); + let regs = 
program.alloc_registers(args_count); + for (i, arg_expr) in args.iter().enumerate() { + translate_expr( + program, + referenced_tables, + &arg_expr[i], + regs + i, + resolver, + )?; + } program.emit_insn(Insn::Function { constant_mask: 0, start_reg: regs, dest: target_register, func: func_ctx, }); + Ok(target_register) } #[cfg(feature = "json")] @@ -1428,60 +1436,6 @@ pub fn translate_expr( } } } - Func::Extension(ext_func) => match ext_func { - #[cfg(feature = "uuid")] - ExtFunc::Uuid(ref uuid_fn) => match uuid_fn { - UuidFunc::UuidStr | UuidFunc::UuidBlob | UuidFunc::Uuid7TS => { - let args = expect_arguments_exact!(args, 1, ext_func); - let regs = program.alloc_register(); - translate_expr(program, referenced_tables, &args[0], regs, resolver)?; - program.emit_insn(Insn::Function { - constant_mask: 0, - start_reg: regs, - dest: target_register, - func: func_ctx, - }); - Ok(target_register) - } - UuidFunc::Uuid4Str => { - if args.is_some() { - crate::bail_parse_error!( - "{} function with arguments", - ext_func.to_string() - ); - } - let regs = program.alloc_register(); - program.emit_insn(Insn::Function { - constant_mask: 0, - start_reg: regs, - dest: target_register, - func: func_ctx, - }); - Ok(target_register) - } - UuidFunc::Uuid7 => { - let args = expect_arguments_max!(args, 1, ext_func); - let mut start_reg = None; - if let Some(arg) = args.first() { - start_reg = Some(translate_and_mark( - program, - referenced_tables, - arg, - resolver, - )?); - } - program.emit_insn(Insn::Function { - constant_mask: 0, - start_reg: start_reg.unwrap_or(target_register), - dest: target_register, - func: func_ctx, - }); - Ok(target_register) - } - }, - #[allow(unreachable_patterns)] - _ => unreachable!("{ext_func} not implemented yet"), - }, Func::Math(math_func) => match math_func.arity() { MathFuncArity::Nullary => { if args.is_some() { diff --git a/core/translate/planner.rs b/core/translate/planner.rs index 64a0ffb04..f5c835f8e 100644 --- a/core/translate/planner.rs +++ 
b/core/translate/planner.rs @@ -1,6 +1,7 @@ use super::{ plan::{Aggregate, Plan, SelectQueryType, SourceOperator, TableReference, TableReferenceType}, select::prepare_select_plan, + SymbolTable, }; use crate::{ function::Func, @@ -259,6 +260,7 @@ fn parse_from_clause_table( table: ast::SelectTable, operator_id_counter: &mut OperatorIdCounter, cur_table_index: usize, + syms: &SymbolTable, ) -> Result<(TableReference, SourceOperator)> { match table { ast::SelectTable::Table(qualified_name, maybe_alias, _) => { @@ -289,7 +291,7 @@ fn parse_from_clause_table( )) } ast::SelectTable::Select(subselect, maybe_alias) => { - let Plan::Select(mut subplan) = prepare_select_plan(schema, *subselect)? else { + let Plan::Select(mut subplan) = prepare_select_plan(schema, *subselect, syms)? else { unreachable!(); }; subplan.query_type = SelectQueryType::Subquery { @@ -322,6 +324,7 @@ pub fn parse_from( schema: &Schema, mut from: Option, operator_id_counter: &mut OperatorIdCounter, + syms: &SymbolTable, ) -> Result<(SourceOperator, Vec)> { if from.as_ref().and_then(|f| f.select.as_ref()).is_none() { return Ok(( @@ -339,7 +342,7 @@ pub fn parse_from( let select_owned = *std::mem::take(&mut from_owned.select).unwrap(); let joins_owned = std::mem::take(&mut from_owned.joins).unwrap_or_default(); let (table_reference, mut operator) = - parse_from_clause_table(schema, select_owned, operator_id_counter, table_index)?; + parse_from_clause_table(schema, select_owned, operator_id_counter, table_index, syms)?; tables.push(table_reference); table_index += 1; @@ -350,7 +353,14 @@ pub fn parse_from( is_outer_join: outer, using, predicates, - } = parse_join(schema, join, operator_id_counter, &mut tables, table_index)?; + } = parse_join( + schema, + join, + operator_id_counter, + &mut tables, + table_index, + syms, + )?; operator = SourceOperator::Join { left: Box::new(operator), right: Box::new(right), @@ -394,6 +404,7 @@ fn parse_join( operator_id_counter: &mut OperatorIdCounter, tables: &mut 
Vec, table_index: usize, + syms: &SymbolTable, ) -> Result { let ast::JoinedSelectTable { operator: join_operator, @@ -402,7 +413,7 @@ fn parse_join( } = join; let (table_reference, source_operator) = - parse_from_clause_table(schema, table, operator_id_counter, table_index)?; + parse_from_clause_table(schema, table, operator_id_counter, table_index, syms)?; tables.push(table_reference); diff --git a/core/translate/select.rs b/core/translate/select.rs index b1be01169..768474a8e 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -20,12 +20,16 @@ pub fn translate_select( select: ast::Select, syms: &SymbolTable, ) -> Result<()> { - let mut select_plan = prepare_select_plan(schema, select)?; + let mut select_plan = prepare_select_plan(schema, select, syms)?; optimize_plan(&mut select_plan)?; emit_program(program, select_plan, syms) } -pub fn prepare_select_plan(schema: &Schema, select: ast::Select) -> Result { +pub fn prepare_select_plan( + schema: &Schema, + select: ast::Select, + syms: &SymbolTable, +) -> Result { match *select.body.select { ast::OneSelect::Select { mut columns, @@ -42,7 +46,8 @@ pub fn prepare_select_plan(schema: &Schema, select: ast::Select) -> Result let mut operator_id_counter = OperatorIdCounter::new(); // Parse the FROM clause - let (source, referenced_tables) = parse_from(schema, from, &mut operator_id_counter)?; + let (source, referenced_tables) = + parse_from(schema, from, &mut operator_id_counter, syms)?; let mut plan = SelectPlan { source, @@ -142,7 +147,24 @@ pub fn prepare_select_plan(schema: &Schema, select: ast::Select) -> Result contains_aggregates, }); } - _ => {} + Err(_) => { + if syms.functions.contains_key(&name.0) { + let contains_aggregates = resolve_aggregates( + expr, + &mut aggregate_expressions, + ); + plan.result_columns.push(ResultSetColumn { + name: get_name( + maybe_alias.as_ref(), + expr, + &plan.referenced_tables, + || format!("expr_{}", result_column_idx), + ), + expr: expr.clone(), + 
contains_aggregates, + }); + } + } } } ast::Expr::FunctionCallStar { @@ -180,7 +202,7 @@ pub fn prepare_select_plan(schema: &Schema, select: ast::Select) -> Result } expr => { let contains_aggregates = - resolve_aggregates(&expr, &mut aggregate_expressions); + resolve_aggregates(expr, &mut aggregate_expressions); plan.result_columns.push(ResultSetColumn { name: get_name( maybe_alias.as_ref(), diff --git a/core/types.rs b/core/types.rs index d9a496bfb..81d9d5328 100644 --- a/core/types.rs +++ b/core/types.rs @@ -1,11 +1,10 @@ +use crate::error::LimboError; +use crate::ext::{ExtValue, ExtValueType}; +use crate::storage::sqlite3_ondisk::write_varint; +use crate::Result; use std::fmt::Display; use std::rc::Rc; -use crate::error::LimboError; -use crate::Result; - -use crate::storage::sqlite3_ondisk::write_varint; - #[derive(Debug, Clone, PartialEq)] pub enum Value<'a> { Null, @@ -94,6 +93,50 @@ impl Display for OwnedValue { } } +impl OwnedValue { + pub fn to_ffi(&self) -> ExtValue { + match self { + Self::Null => ExtValue::null(), + Self::Integer(i) => ExtValue::from_integer(*i), + Self::Float(fl) => ExtValue::from_float(*fl), + Self::Text(text) => ExtValue::from_text(text.value.to_string()), + Self::Blob(blob) => ExtValue::from_blob(blob.to_vec()), + Self::Agg(_) => todo!("Aggregate values not yet supported"), + Self::Record(_) => todo!("Record values not yet supported"), + } + } + + pub fn from_ffi(v: &ExtValue) -> Self { + match v.value_type() { + ExtValueType::Null => OwnedValue::Null, + ExtValueType::Integer => { + let Some(int) = v.to_integer() else { + return OwnedValue::Null; + }; + OwnedValue::Integer(int) + } + ExtValueType::Float => { + let Some(float) = v.to_float() else { + return OwnedValue::Null; + }; + OwnedValue::Float(float) + } + ExtValueType::Text => { + let Some(text) = v.to_text() else { + return OwnedValue::Null; + }; + OwnedValue::build_text(std::rc::Rc::new(text)) + } + ExtValueType::Blob => { + let Some(blob) = v.to_blob() else { + return 
OwnedValue::Null; + }; + OwnedValue::Blob(std::rc::Rc::new(blob)) + } + } + } +} + #[derive(Debug, Clone, PartialEq)] pub enum AggContext { Avg(OwnedValue, OwnedValue), // acc and count diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index 28db889cf..580834d39 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -25,8 +25,7 @@ pub mod likeop; pub mod sorter; use crate::error::{LimboError, SQLITE_CONSTRAINT_PRIMARYKEY}; -#[cfg(feature = "uuid")] -use crate::ext::{exec_ts_from_uuid7, exec_uuid, exec_uuidblob, exec_uuidstr, ExtFunc, UuidFunc}; +use crate::ext::ExtValue; use crate::function::{AggFunc, FuncCtx, MathFunc, MathFuncArity, ScalarFunc}; use crate::pseudo::PseudoCursor; use crate::result::LimboResult; @@ -53,7 +52,7 @@ use rand::distributions::{Distribution, Uniform}; use rand::{thread_rng, Rng}; use regex::{Regex, RegexBuilder}; use sorter::Sorter; -use std::borrow::{Borrow, BorrowMut}; +use std::borrow::BorrowMut; use std::cell::RefCell; use std::collections::{BTreeMap, HashMap}; use std::rc::{Rc, Weak}; @@ -147,6 +146,33 @@ macro_rules! return_if_io { }; } +macro_rules! 
/// Calls an extension scalar function and stores its result in a register.
///
/// * `$func_ptr` - the extension's function pointer, called as `(argc, argv)`.
/// * `$dest_register` - index into `$state.registers` that receives the result.
/// * `$state` - program state owning the `registers` vector.
/// * `$arg_count` / `$start_reg` - the arguments are the registers
///   `$start_reg .. $start_reg + $arg_count`.
///
/// With zero arguments the function is called with a null `argv`. Otherwise
/// each register is converted with `OwnedValue::to_ffi`, the function is
/// called with a pointer to those values, and the converted values are
/// released afterwards so that a call no longer leaks one allocation per
/// argument.
///
/// The returned value is converted with `OwnedValue::from_ffi` but is not
/// freed here.
// NOTE(review): the result was allocated inside the extension library; it is
// left alone because releasing it from this side is only sound if both sides
// share an allocator - confirm before adding a free for it.
macro_rules! call_external_function {
    (
        $func_ptr:expr,
        $dest_register:expr,
        $state:expr,
        $arg_count:expr,
        $start_reg:expr
    ) => {{
        if $arg_count == 0 {
            let result_c_value: ExtValue = ($func_ptr)(0, std::ptr::null());
            $state.registers[$dest_register] = OwnedValue::from_ffi(&result_c_value);
        } else {
            let ext_values: Vec<ExtValue> = $state.registers[$start_reg..$start_reg + $arg_count]
                .iter()
                .map(|ov| ov.to_ffi())
                .collect();
            let result_c_value: ExtValue = ($func_ptr)($arg_count as i32, ext_values.as_ptr());
            $state.registers[$dest_register] = OwnedValue::from_ffi(&result_c_value);
            for arg in ext_values {
                // SAFETY: every element was produced by `to_ffi()` just above, so
                // its tag matches its payload and it is released exactly once.
                // Assumes the callee does not keep the argument pointers after
                // returning.
                unsafe { arg.free() };
            }
        }
    }};
}
- } - UuidFunc::Uuid7 => match arg_count { - 0 => { - state.registers[*dest] = - exec_uuid(uuidfn, None).unwrap_or(OwnedValue::Null); - } - 1 => { - let reg_value = state.registers[*start_reg].borrow(); - state.registers[*dest] = exec_uuid(uuidfn, Some(reg_value)) - .unwrap_or(OwnedValue::Null); - } - _ => unreachable!(), - }, - _ => { - // remaining accept 1 arg - let reg_value = state.registers[*start_reg].borrow(); - state.registers[*dest] = match uuidfn { - UuidFunc::Uuid7TS => Some(exec_ts_from_uuid7(reg_value)), - UuidFunc::UuidStr => exec_uuidstr(reg_value).ok(), - UuidFunc::UuidBlob => exec_uuidblob(reg_value).ok(), - _ => unreachable!(), - } - .unwrap_or(OwnedValue::Null); - } - }, - _ => unreachable!(), // when more extension types are added - }, crate::function::Func::External(f) => { - let result = (f.func)(&[])?; - state.registers[*dest] = result; + call_external_function! {f.func, *dest, state, arg_count, *start_reg }; } crate::function::Func::Math(math_func) => match math_func.arity() { MathFuncArity::Nullary => match math_func { diff --git a/extensions/uuid/Cargo.toml b/extensions/uuid/Cargo.toml new file mode 100644 index 000000000..ed2c43e87 --- /dev/null +++ b/extensions/uuid/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "limbo_uuid" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true + +[lib] +crate-type = ["cdylib", "lib"] + + +[dependencies] +limbo_extension = { path = "../../limbo_extension"} +uuid = { version = "1.11.0", features = ["v4", "v7"] } +log = "0.4.20" diff --git a/extensions/uuid/src/lib.rs b/extensions/uuid/src/lib.rs new file mode 100644 index 000000000..e6a3a4f9b --- /dev/null +++ b/extensions/uuid/src/lib.rs @@ -0,0 +1,145 @@ +use limbo_extension::{ + declare_scalar_functions, register_extension, register_scalar_functions, Value, ValueType, +}; + +register_extension! 
{ + scalars: { + "uuid4_str" => uuid4_str, + "uuid4" => uuid4_blob, + "uuid7_str" => uuid7_str, + "uuid7" => uuid7_blob, + "uuid_str" => uuid_str, + "uuid_blob" => uuid_blob, + "uuid7_timestamp_ms" => exec_ts_from_uuid7, + }, +} + +declare_scalar_functions! { + #[args(0)] + fn uuid4_str(_args: &[Value]) -> Value { + let uuid = uuid::Uuid::new_v4().to_string(); + Value::from_text(uuid) + } + + #[args(0)] + fn uuid4_blob(_args: &[Value]) -> Value { + let uuid = uuid::Uuid::new_v4(); + let bytes = uuid.as_bytes(); + Value::from_blob(bytes.to_vec()) + } + + #[args(0..=1)] + fn uuid7_str(args: &[Value]) -> Value { + let timestamp = if args.is_empty() { + let ctx = uuid::ContextV7::new(); + uuid::Timestamp::now(ctx) + } else { + let arg = &args[0]; + match arg.value_type() { + ValueType::Integer => { + let ctx = uuid::ContextV7::new(); + let Some(int) = arg.to_integer() else { + return Value::null(); + }; + uuid::Timestamp::from_unix(ctx, int as u64, 0) + } + ValueType::Text => { + let Some(text) = arg.to_text() else { + return Value::null(); + }; + match text.parse::() { + Ok(unix) => { + if unix <= 0 { + return Value::null(); + } + uuid::Timestamp::from_unix(uuid::ContextV7::new(), unix as u64, 0) + } + Err(_) => return Value::null(), + } + } + _ => return Value::null(), + } + }; + let uuid = uuid::Uuid::new_v7(timestamp); + Value::from_text(uuid.to_string()) + } + + #[args(0..=1)] + fn uuid7_blob(args: &[Value]) -> Value { + let timestamp = if args.is_empty() { + let ctx = uuid::ContextV7::new(); + uuid::Timestamp::now(ctx) + } else if args[0].value_type() == limbo_extension::ValueType::Integer { + let ctx = uuid::ContextV7::new(); + let Some(int) = args[0].to_integer() else { + return Value::null(); + }; + uuid::Timestamp::from_unix(ctx, int as u64, 0) + } else { + return Value::null(); + }; + let uuid = uuid::Uuid::new_v7(timestamp); + let bytes = uuid.as_bytes(); + Value::from_blob(bytes.to_vec()) + } + + #[args(1)] + fn exec_ts_from_uuid7(args: &[Value]) -> Value 
{ + match args[0].value_type() { + ValueType::Blob => { + let Some(blob) = &args[0].to_blob() else { + return Value::null(); + }; + let uuid = uuid::Uuid::from_slice(blob.as_slice()).unwrap(); + let unix = uuid_to_unix(uuid.as_bytes()); + Value::from_integer(unix as i64) + } + ValueType::Text => { + let Some(text) = args[0].to_text() else { + return Value::null(); + }; + let Ok(uuid) = uuid::Uuid::parse_str(&text) else { + return Value::null(); + }; + let unix = uuid_to_unix(uuid.as_bytes()); + Value::from_integer(unix as i64) + } + _ => Value::null(), + } + } + + #[args(1)] + fn uuid_str(args: &[Value]) -> Value { + let Some(blob) = args[0].to_blob() else { + return Value::null(); + }; + let parsed = uuid::Uuid::from_slice(blob.as_slice()).ok().map(|u| u.to_string()); + match parsed { + Some(s) => Value::from_text(s), + None => Value::null() + } + } + + #[args(1)] + fn uuid_blob(args: &[Value]) -> Value { + let Some(text) = args[0].to_text() else { + return Value::null(); + }; + match uuid::Uuid::parse_str(&text) { + Ok(uuid) => { + Value::from_blob(uuid.as_bytes().to_vec()) + } + Err(_) => Value::null() + } + } +} + +#[inline(always)] +fn uuid_to_unix(uuid: &[u8; 16]) -> u64 { + ((uuid[0] as u64) << 40) + | ((uuid[1] as u64) << 32) + | ((uuid[2] as u64) << 24) + | ((uuid[3] as u64) << 16) + | ((uuid[4] as u64) << 8) + | (uuid[5] as u64) +} diff --git a/limbo_extension/Cargo.toml b/limbo_extension/Cargo.toml new file mode 100644 index 000000000..2928ed853 --- /dev/null +++ b/limbo_extension/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "limbo_extension" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true + +[dependencies] +log = "0.4.20" diff --git a/limbo_extension/src/lib.rs b/limbo_extension/src/lib.rs new file mode 100644 index 000000000..0666c588b --- /dev/null +++ b/limbo_extension/src/lib.rs @@ -0,0 +1,303 @@ +use std::os::raw::{c_char, c_void}; + +pub type ResultCode = i32; + 
+pub const RESULT_OK: ResultCode = 0; +pub const RESULT_ERROR: ResultCode = 1; +// TODO: more error types + +pub type ExtensionEntryPoint = extern "C" fn(api: *const ExtensionApi) -> ResultCode; +pub type ScalarFunction = extern "C" fn(argc: i32, *const Value) -> Value; + +#[repr(C)] +pub struct ExtensionApi { + pub ctx: *mut c_void, + pub register_scalar_function: + extern "C" fn(ctx: *mut c_void, name: *const c_char, func: ScalarFunction) -> ResultCode, +} + +#[macro_export] +macro_rules! register_extension { + ( + scalars: { $( $scalar_name:expr => $scalar_func:ident ),* $(,)? }, + //aggregates: { $( $agg_name:expr => ($step_func:ident, $finalize_func:ident) ),* $(,)? }, + //virtual_tables: { $( $vt_name:expr => $vt_impl:expr ),* $(,)? } + ) => { + #[no_mangle] + pub unsafe extern "C" fn register_extension(api: *const $crate::ExtensionApi) -> $crate::ResultCode { + if api.is_null() { + return $crate::RESULT_ERROR; + } + + register_scalar_functions! { api, $( $scalar_name => $scalar_func ),* } + // TODO: + //register_aggregate_functions! { $( $agg_name => ($step_func, $finalize_func) ),* } + //register_virtual_tables! { $( $vt_name => $vt_impl ),* } + $crate::RESULT_OK + } + } +} + +#[macro_export] +macro_rules! register_scalar_functions { + ( $api:expr, $( $fname:expr => $fptr:ident ),* ) => { + unsafe { + $( + let cname = std::ffi::CString::new($fname).unwrap(); + ((*$api).register_scalar_function)((*$api).ctx, cname.as_ptr(), $fptr); + )* + } + } +} + +#[macro_export] +macro_rules! 
/// Declares `extern "C"` scalar functions from safe-looking Rust bodies.
///
/// Each function is written as
/// `#[args(PATTERN)] fn name(args: &[Value]) -> Value { ... }`, where
/// `PATTERN` is a `match` pattern over the argument count (`0`, `1`,
/// `0..=1`, ...).
///
/// The generated wrapper returns `Value::null()` when the count does not
/// match the pattern, or when arguments are announced (`argc > 0`) but `argv`
/// is null. The latter used to run the body with an empty slice, so a body
/// indexing `args[0]` panicked across the FFI boundary. A count of zero (or
/// less) runs the body with an empty slice.
///
/// The expansion no longer calls `log::debug!`, so the calling crate does not
/// need a `log` dependency, and the body is expanded once instead of twice.
#[macro_export]
macro_rules! declare_scalar_functions {
    (
        $(
            #[args($($args_count:tt)+)]
            fn $func_name:ident ($args:ident : &[Value]) -> Value $body:block
        )*
    ) => {
        $(
            extern "C" fn $func_name(
                argc: i32,
                argv: *const $crate::Value
            ) -> $crate::Value {
                let valid_args = {
                    match argc {
                        $($args_count)+ => true,
                        _ => false,
                    }
                };
                if !valid_args {
                    return $crate::Value::null();
                }
                let $args: &[$crate::Value] = if argc <= 0 {
                    &[]
                } else if argv.is_null() {
                    return $crate::Value::null();
                } else {
                    // SAFETY: `argv` is non-null and the caller promises it
                    // points at `argc` consecutive values; `argc` is positive.
                    unsafe { ::std::slice::from_raw_parts(argv, argc as usize) }
                };
                $body
            }
        )*
    };
}

/// Tag describing which payload a [`Value`] carries.
#[repr(C)]
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum ValueType {
    Null,
    Integer,
    Float,
    Text,
    Blob,
}

/// A tagged, heap-allocated value passed by value across the extension ABI.
///
/// `value` is null for `Null`; otherwise it points at a boxed `i64`, `f64`,
/// [`TextValue`] or [`Blob`] according to `value_type`. The type has no
/// `Drop` impl: whoever owns a value releases it with [`Value::free`].
#[repr(C)]
pub struct Value {
    value_type: ValueType,
    value: *mut c_void,
}

impl std::fmt::Debug for Value {
    /// Formats the tag plus payload; a null payload pointer prints `<null>`
    /// instead of being dereferenced.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.value_type {
            ValueType::Null => write!(f, "Value {{ Null }}"),
            ValueType::Integer => match self.to_integer() {
                Some(int) => write!(f, "Value {{ Integer: {} }}", int),
                None => write!(f, "Value {{ Integer: <null> }}"),
            },
            ValueType::Float => match self.to_float() {
                Some(float) => write!(f, "Value {{ Float: {} }}", float),
                None => write!(f, "Value {{ Float: <null> }}"),
            },
            // SAFETY (both arms): the requested type is the payload type of the tag.
            ValueType::Text => match unsafe { self.payload::<TextValue>(ValueType::Text) } {
                Some(text) => write!(f, "Value {{ Text: {:?} }}", text),
                None => write!(f, "Value {{ Text: <null> }}"),
            },
            ValueType::Blob => match unsafe { self.payload::<Blob>(ValueType::Blob) } {
                Some(blob) => write!(f, "Value {{ Blob: {:?} }}", blob),
                None => write!(f, "Value {{ Blob: <null> }}"),
            },
        }
    }
}

/// Pointer/length descriptor of a text payload.
#[repr(C)]
pub struct TextValue {
    text: *const u8,
    len: u32,
}

impl std::fmt::Debug for TextValue {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "TextValue {{ text: {:?}, len: {} }}",
            self.text, self.len
        )
    }
}

impl Default for TextValue {
    fn default() -> Self {
        Self {
            text: std::ptr::null(),
            len: 0,
        }
    }
}

impl TextValue {
    /// Wraps `len` bytes at `text`. `len` must fit in `u32`; the only caller,
    /// [`Value::from_text`], checks this before calling.
    pub(crate) fn new(text: *const u8, len: usize) -> Self {
        Self {
            text,
            len: len as u32,
        }
    }

    /// Borrows the text, or `None` when the bytes are not valid UTF-8.
    /// A null pointer or zero length reads as the empty string.
    fn as_str(&self) -> Option<&str> {
        if self.text.is_null() || self.len == 0 {
            return Some("");
        }
        // SAFETY: non-null, and the descriptor promises `len` readable bytes.
        let bytes = unsafe { std::slice::from_raw_parts(self.text, self.len as usize) };
        std::str::from_utf8(bytes).ok()
    }
}

/// Pointer/size descriptor of a blob payload.
#[repr(C)]
pub struct Blob {
    data: *const u8,
    size: u64,
}

impl std::fmt::Debug for Blob {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Blob {{ data: {:?}, size: {} }}", self.data, self.size)
    }
}

impl Blob {
    /// Describes `size` bytes starting at `data`; nothing is copied.
    pub fn new(data: *const u8, size: u64) -> Self {
        Self { data, size }
    }
}

/// Releases a byte buffer that was leaked from a `Box<[u8]>` / `Box<str>`.
///
/// # Safety
/// `ptr`/`len` must come from such a box (capacity equal to `len`) and must
/// not be used afterwards. Null or empty buffers are ignored.
unsafe fn free_bytes(ptr: *mut u8, len: usize) {
    if ptr.is_null() || len == 0 {
        return;
    }
    unsafe { drop(Box::from_raw(std::ptr::slice_from_raw_parts_mut(ptr, len))) };
}

impl Value {
    /// The SQL NULL value; owns no memory.
    pub fn null() -> Self {
        Self {
            value_type: ValueType::Null,
            value: std::ptr::null_mut(),
        }
    }

    /// The tag of this value.
    pub fn value_type(&self) -> ValueType {
        self.value_type
    }

    /// Borrows the payload as `T` when the tag is `expected` and the pointer
    /// is non-null.
    ///
    /// # Safety
    /// `T` must be the payload type that belongs to `expected`.
    unsafe fn payload<T>(&self, expected: ValueType) -> Option<&T> {
        if self.value_type != expected || self.value.is_null() {
            return None;
        }
        Some(unsafe { &*(self.value as *const T) })
    }

    /// The float payload, or `None` for any other tag or a null pointer.
    pub fn to_float(&self) -> Option<f64> {
        // SAFETY: `Float` values point at an `f64`.
        unsafe { self.payload::<f64>(ValueType::Float) }.copied()
    }

    /// An owned copy of the text payload. `None` for any other tag, a null
    /// pointer, or bytes that are not valid UTF-8.
    pub fn to_text(&self) -> Option<String> {
        // SAFETY: `Text` values point at a `TextValue`.
        let text = unsafe { self.payload::<TextValue>(ValueType::Text) }?;
        text.as_str().map(String::from)
    }

    /// An owned copy of the blob payload, or `None` for any other tag or a
    /// null pointer. A blob with null data or zero size is an empty vector.
    pub fn to_blob(&self) -> Option<Vec<u8>> {
        // SAFETY: `Blob` values point at a `Blob`.
        let blob = unsafe { self.payload::<Blob>(ValueType::Blob) }?;
        if blob.data.is_null() || blob.size == 0 {
            return Some(Vec::new());
        }
        // SAFETY: non-null, and the descriptor promises `size` readable bytes.
        let bytes = unsafe { std::slice::from_raw_parts(blob.data, blob.size as usize) };
        Some(bytes.to_vec())
    }

    /// The integer payload, or `None` for any other tag or a null pointer.
    pub fn to_integer(&self) -> Option<i64> {
        // SAFETY: `Integer` values point at an `i64`.
        unsafe { self.payload::<i64>(ValueType::Integer) }.copied()
    }

    /// Boxes `value` as an `Integer`.
    pub fn from_integer(value: i64) -> Self {
        Self {
            value_type: ValueType::Integer,
            value: Box::into_raw(Box::new(value)) as *mut c_void,
        }
    }

    /// Boxes `value` as a `Float`.
    pub fn from_float(value: f64) -> Self {
        Self {
            value_type: ValueType::Float,
            value: Box::into_raw(Box::new(value)) as *mut c_void,
        }
    }

    /// Takes ownership of `s` as a `Text` value.
    ///
    /// The length field is 32 bits wide; a longer string cannot be described
    /// and yields [`Value::null`] instead of a silently truncated length.
    pub fn from_text(s: String) -> Self {
        let len = s.len();
        if len > u32::MAX as usize {
            return Self::null();
        }
        // A boxed str has capacity == len, which is what `free` relies on.
        let ptr = Box::into_raw(s.into_boxed_str()) as *mut u8;
        let descriptor = Box::new(TextValue::new(ptr, len));
        Self {
            value_type: ValueType::Text,
            value: Box::into_raw(descriptor) as *mut c_void,
        }
    }

    /// Takes ownership of `value` as a `Blob`.
    pub fn from_blob(value: Vec<u8>) -> Self {
        // Shrink to an exact-size allocation so `free` can rebuild it from
        // pointer + length alone.
        let bytes = value.into_boxed_slice();
        let size = bytes.len() as u64;
        let data = Box::into_raw(bytes) as *mut u8;
        let descriptor = Box::new(Blob::new(data, size));
        Self {
            value_type: ValueType::Blob,
            value: Box::into_raw(descriptor) as *mut c_void,
        }
    }

    /// Consumes the value and releases everything it owns: the boxed payload
    /// and, for text and blobs, the byte buffer as well (the buffer used to
    /// be leaked).
    ///
    /// # Safety
    /// The value must have been built by one of the `from_*` constructors of
    /// this same build of the library (so tag, payload and allocator match),
    /// and no copy of it may be used or freed afterwards.
    pub unsafe fn free(self) {
        if self.value.is_null() {
            return;
        }
        unsafe {
            match self.value_type {
                ValueType::Integer => drop(Box::from_raw(self.value as *mut i64)),
                ValueType::Float => drop(Box::from_raw(self.value as *mut f64)),
                ValueType::Text => {
                    let text = Box::from_raw(self.value as *mut TextValue);
                    free_bytes(text.text as *mut u8, text.len as usize);
                }
                ValueType::Blob => {
                    let blob = Box::from_raw(self.value as *mut Blob);
                    free_bytes(blob.data as *mut u8, blob.size as usize);
                }
                ValueType::Null => {}
            }
        }
    }
}
def execute_sql(pipe, sql):
    """Send ``sql`` to the shell and return everything it prints for it.

    A sentinel ``SELECT`` is sent after the statement; stdout is read until
    the sentinel shows up at the end of the accumulated output. The sentinel
    is removed and the remaining lines are stripped (see ``strip_each_line``).

    Any activity on stderr ends the run through ``exit_on_error``.

    Raises:
        RuntimeError: stdout reached end-of-file before the sentinel was
            seen (previously this looped forever on empty reads).
    """
    end_suffix = "END_OF_RESULT"
    write_to_pipe(pipe, sql)
    write_to_pipe(pipe, f"SELECT '{end_suffix}';\n")
    stdout = pipe.stdout
    stderr = pipe.stderr
    marker = end_suffix.encode()
    # Collect raw bytes and decode once at the end, so a multi-byte character
    # split across two reads cannot break decoding.
    raw = b""
    while True:
        ready_to_read, _, error_in_pipe = select.select(
            [stdout, stderr], [], [stdout, stderr]
        )
        ready_to_read_or_err = set(ready_to_read + error_in_pipe)
        if stderr in ready_to_read_or_err:
            exit_on_error(stderr)

        if stdout in ready_to_read_or_err:
            fragment = stdout.read(select.PIPE_BUF)
            if not fragment:
                raise RuntimeError("shell closed stdout before finishing the query")
            raw += fragment
            if raw.rstrip().endswith(marker):
                break
    output = raw.decode().rstrip().removesuffix(end_suffix)
    return strip_each_line(output)


def strip_each_line(lines: str) -> str:
    """Drop empty lines, then strip surrounding whitespace from the rest."""
    kept = [line.strip() for line in lines.split("\n") if line != ""]
    return "\n".join(kept)


def write_to_pipe(pipe, command):
    """Write ``command`` plus a newline to the shell's stdin and flush.

    Raises:
        RuntimeError: the process has no stdin pipe.
    """
    if pipe.stdin is None:
        raise RuntimeError("Failed to write to shell")
    pipe.stdin.write((command + "\n").encode())
    pipe.stdin.flush()


def exit_on_error(stderr):
    """Print what the shell wrote to stderr and exit with status 1.

    One bounded read is used, as in ``execute_sql``: an unbounded ``read()``
    on the unbuffered pipe waits for end-of-file, i.e. until the shell
    closes its stderr.
    """
    message = stderr.read(select.PIPE_BUF) or b""
    print(message.decode(errors="replace"), end="")
    raise SystemExit(1)


def run_test(pipe, sql, validator=None):
    """Run ``sql`` and, when a validator is given, check the output with it.

    Raises:
        Exception: the validator returned a falsy value.
    """
    print(f"Running test: {sql}")
    result = execute_sql(pipe, sql)
    if validator is not None:
        if not validator(result):
            print(f"Test FAILED: {sql}")
            print(f"Returned: {result}")
            raise Exception("Validation failed")
    print("Test PASSED")


def validate_blob(result):
    """Return True when ``result`` is a non-empty string of hex digits.

    Blobs are hard to compare directly because the shell prints them as
    UTF-8 text, so the tests wrap them in ``hex()`` and check the digits.
    The previous ``int(result, 16) is not None`` could only be true or raise.
    """
    hex_digits = "0123456789abcdefABCDEF"
    return bool(result) and all(ch in hex_digits for ch in result)


def validate_string_uuid(result):
    """Return True for a 36-character string containing four hyphens."""
    return len(result) == 36 and result.count("-") == 4


def returns_null(result):
    """Return True when the shell printed nothing (empty text or bytes)."""
    return result in ("", b"\n", b"")


def assert_now_unixtime(result, tolerance=2):
    """Return True when ``result`` is a unix time within ``tolerance`` seconds of now.

    Exact equality would fail whenever the clock ticks over between the
    query and this check. Non-numeric output is a failure.
    """
    try:
        reported = int(result)
    except (TypeError, ValueError):
        return False
    return abs(int(time.time()) - reported) <= tolerance


def assert_specific_time(result):
    """Return True when ``result`` is the timestamp encoded in the fixed test UUID."""
    return result == "1736720789"


def main():
    """Load the uuid extension into a fresh shell and exercise its functions."""
    specific_time = "01945ca0-3189-76c0-9a8f-caf310fc8b8e"
    extension_path = "./target/debug/liblimbo_uuid.so"
    pipe = init_limbo()
    try:
        # before extension loads, assert no function
        run_test(pipe, "SELECT uuid4();", returns_null)
        run_test(pipe, "SELECT uuid4_str();", returns_null)
        run_test(pipe, f".load {extension_path}", returns_null)
        print("Extension loaded successfully.")
        run_test(pipe, "SELECT hex(uuid4());", validate_blob)
        run_test(pipe, "SELECT uuid4_str();", validate_string_uuid)
        run_test(pipe, "SELECT hex(uuid7());", validate_blob)
        run_test(
            pipe,
            "SELECT uuid7_timestamp_ms(uuid7()) / 1000;",
            assert_now_unixtime,
        )
        run_test(pipe, "SELECT uuid7_str();", validate_string_uuid)
        run_test(pipe, "SELECT uuid_str(uuid7());", validate_string_uuid)
        run_test(pipe, "SELECT hex(uuid_blob(uuid7_str()));", validate_blob)
        run_test(pipe, "SELECT uuid_str(uuid_blob(uuid7_str()));", validate_string_uuid)
        run_test(
            pipe,
            f"SELECT uuid7_timestamp_ms('{specific_time}') / 1000;",
            assert_specific_time,
        )
    except Exception as e:
        print(f"Test FAILED: {e}")
        raise SystemExit(1)
    finally:
        # Also runs when exit_on_error raises SystemExit, so the shell
        # process is never left behind.
        pipe.terminate()
    print("All tests passed successfully.")


if __name__ == "__main__":
    main()