Sketch out runtime extension loading

This commit is contained in:
PThorpe92
2025-01-08 23:16:57 -05:00
parent bfbaa80bdc
commit 0a10d893d9
15 changed files with 291 additions and 33 deletions

30
Cargo.lock generated
View File

@@ -564,7 +564,7 @@ version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef552e6f588e446098f6ba40d89ac146c8c7b64aade83c051ee00bb5d2bc18d"
dependencies = [
"uuid",
"uuid 1.11.0",
]
[[package]]
@@ -694,6 +694,10 @@ dependencies = [
"str-buf",
]
[[package]]
name = "extension_api"
version = "0.0.11"
[[package]]
name = "fallible-iterator"
version = "0.2.0"
@@ -1137,6 +1141,16 @@ version = "0.2.169"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a"
[[package]]
name = "libloading"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
dependencies = [
"cfg-if",
"windows-targets 0.52.6",
]
[[package]]
name = "libmimalloc-sys"
version = "0.1.39"
@@ -1204,6 +1218,7 @@ dependencies = [
"cfg_block",
"chrono",
"criterion",
"extension_api",
"fallible-iterator 0.3.0",
"getrandom",
"hex",
@@ -1212,6 +1227,7 @@ dependencies = [
"jsonb",
"julian_day_converter",
"libc",
"libloading",
"limbo_macros",
"log",
"miette",
@@ -1232,7 +1248,7 @@ dependencies = [
"sqlite3-parser",
"tempfile",
"thiserror 1.0.69",
"uuid",
"uuid 1.11.0",
]
[[package]]
@@ -2260,7 +2276,7 @@ dependencies = [
"debugid",
"memmap2",
"stable_deref_trait",
"uuid",
"uuid 1.11.0",
]
[[package]]
@@ -2502,6 +2518,14 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]]
name = "uuid"
version = "0.0.11"
dependencies = [
"extension_api",
"uuid 1.11.0",
]
[[package]]
name = "uuid"
version = "1.11.0"

View File

@@ -11,7 +11,7 @@ members = [
"sqlite3",
"core",
"simulator",
"test", "macros",
"test", "macros", "extension_api", "extensions/uuid",
]
exclude = ["perf/latency/limbo"]

View File

@@ -129,6 +129,8 @@ pub enum Command {
Tables,
/// Import data from FILE into TABLE
Import,
/// Loads an extension library
LoadExtension,
}
impl Command {
@@ -141,7 +143,12 @@ impl Command {
| Self::ShowInfo
| Self::Tables
| Self::SetOutput => 0,
Self::Open | Self::OutputMode | Self::Cwd | Self::Echo | Self::NullValue => 1,
Self::Open
| Self::OutputMode
| Self::Cwd
| Self::Echo
| Self::NullValue
| Self::LoadExtension => 1,
Self::Import => 2,
} + 1) // argv0
}
@@ -160,6 +167,7 @@ impl Command {
Self::NullValue => ".nullvalue <string>",
Self::Echo => ".echo on|off",
Self::Tables => ".tables",
Self::LoadExtension => ".load",
Self::Import => &IMPORT_HELP,
}
}
@@ -182,6 +190,7 @@ impl FromStr for Command {
".nullvalue" => Ok(Self::NullValue),
".echo" => Ok(Self::Echo),
".import" => Ok(Self::Import),
".load" => Ok(Self::LoadExtension),
_ => Err("Unknown command".to_string()),
}
}
@@ -314,6 +323,16 @@ impl Limbo {
};
}
fn handle_load_extension(&mut self) -> Result<(), String> {
let mut args = self.input_buff.split_whitespace();
let _ = args.next();
let lib = args
.next()
.ok_or("No library specified")
.map_err(|e| e.to_string())?;
self.conn.load_extension(lib).map_err(|e| e.to_string())
}
fn display_in_memory(&mut self) -> std::io::Result<()> {
if self.opts.db_file == ":memory:" {
self.writeln("Connected to a transient in-memory database.")?;
@@ -537,6 +556,11 @@ impl Limbo {
let _ = self.writeln(e.to_string());
};
}
Command::LoadExtension => {
if let Err(e) = self.handle_load_extension() {
let _ = self.writeln(e.to_string());
}
}
}
} else {
let _ = self.write_fmt(format_args!(

View File

@@ -35,6 +35,7 @@ rustix = "0.38.34"
mimalloc = { version = "*", default-features = false }
[dependencies]
extension_api = { path = "../extension_api" }
cfg_block = "0.1.1"
fallible-iterator = "0.3.0"
hex = "0.4.3"
@@ -58,6 +59,7 @@ bumpalo = { version = "3.16.0", features = ["collections", "boxed"] }
limbo_macros = { path = "../macros" }
uuid = { version = "1.11.0", features = ["v4", "v7"], optional = true }
miette = "7.4.0"
libloading = "0.8.6"
[target.'cfg(not(target_family = "windows"))'.dev-dependencies]
pprof = { version = "0.14.0", features = ["criterion", "flamegraph"] }

View File

@@ -39,6 +39,8 @@ pub enum LimboError {
InvalidModifier(String),
#[error("Runtime error: {0}")]
Constraint(String),
#[error("Extension error: {0}")]
ExtensionError(String),
}
#[macro_export]

View File

@@ -1,8 +1,39 @@
#[cfg(feature = "uuid")]
mod uuid;
use crate::{function::ExternalFunc, Database};
use std::sync::Arc;
use extension_api::{AggregateFunction, ExtensionApi, Result, ScalarFunction, VirtualTable};
#[cfg(feature = "uuid")]
pub use uuid::{exec_ts_from_uuid7, exec_uuid, exec_uuidblob, exec_uuidstr, UuidFunc};
impl ExtensionApi for Database {
fn register_scalar_function(
&self,
name: &str,
func: Arc<dyn ScalarFunction>,
) -> extension_api::Result<()> {
let ext_func = ExternalFunc::new(name, func.clone());
self.syms
.borrow_mut()
.functions
.insert(name.to_string(), Arc::new(ext_func));
Ok(())
}
fn register_aggregate_function(
&self,
_name: &str,
_func: Arc<dyn AggregateFunction>,
) -> Result<()> {
todo!("implement aggregate function registration");
}
fn register_virtual_table(&self, _name: &str, _table: Arc<dyn VirtualTable>) -> Result<()> {
todo!("implement virtual table registration");
}
}
#[derive(Debug, Clone, PartialEq)]
pub enum ExtFunc {
#[cfg(feature = "uuid")]
@@ -31,7 +62,7 @@ impl ExtFunc {
}
}
pub fn init(db: &mut crate::Database) {
#[cfg(feature = "uuid")]
uuid::init(db);
}
//pub fn init(db: &mut crate::Database) {
// #[cfg(feature = "uuid")]
// uuid::init(db);
//}

View File

@@ -136,9 +136,9 @@ fn uuid_to_unix(uuid: &[u8; 16]) -> u64 {
| (uuid[5] as u64)
}
pub fn init(db: &mut Database) {
db.define_scalar_function("uuid4", |_args| exec_uuid4());
}
//pub fn init(db: &mut Database) {
// db.define_scalar_function("uuid4", |_args| exec_uuid4());
//}
#[cfg(test)]
#[cfg(feature = "uuid")]

View File

@@ -1,11 +1,20 @@
use crate::ext::ExtFunc;
use std::fmt;
use std::fmt::{Debug, Display};
use std::rc::Rc;
use std::sync::Arc;
pub struct ExternalFunc {
pub name: String,
pub func: Box<dyn Fn(&[crate::types::Value]) -> crate::Result<crate::types::OwnedValue>>,
pub func: Arc<dyn extension_api::ScalarFunction>,
}
impl ExternalFunc {
pub fn new(name: &str, func: Arc<dyn extension_api::ScalarFunction>) -> Self {
Self {
name: name.to_string(),
func,
}
}
}
impl Debug for ExternalFunc {
@@ -300,7 +309,7 @@ pub enum Func {
#[cfg(feature = "json")]
Json(JsonFunc),
Extension(ExtFunc),
External(Rc<ExternalFunc>),
External(Arc<ExternalFunc>),
}
impl Display for Func {

View File

@@ -17,7 +17,9 @@ mod vdbe;
#[global_allocator]
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
use extension_api::{Extension, ExtensionApi};
use fallible_iterator::FallibleIterator;
use libloading::{Library, Symbol};
use log::trace;
use schema::Schema;
use sqlite3_parser::ast;
@@ -34,12 +36,11 @@ use storage::pager::allocate_page;
use storage::sqlite3_ondisk::{DatabaseHeader, DATABASE_HEADER_SIZE};
pub use storage::wal::WalFile;
pub use storage::wal::WalFileShared;
pub use types::Value;
use util::parse_schema_rows;
use translate::select::prepare_select_plan;
use types::OwnedValue;
pub use error::LimboError;
use translate::select::prepare_select_plan;
pub type Result<T> = std::result::Result<T, error::LimboError>;
use crate::translate::optimizer::optimize_plan;
@@ -56,8 +57,6 @@ pub use storage::pager::Page;
pub use storage::pager::Pager;
pub use storage::wal::CheckpointStatus;
pub use storage::wal::Wal;
pub use types::Value;
pub static DATABASE_VERSION: OnceLock<String> = OnceLock::new();
#[derive(Clone)]
@@ -135,11 +134,11 @@ impl Database {
_shared_wal: shared_wal.clone(),
syms,
};
ext::init(&mut db);
// ext::init(&mut db);
let db = Arc::new(db);
let conn = Rc::new(Connection {
db: db.clone(),
pager: pager,
pager,
schema: schema.clone(),
header,
transaction_state: RefCell::new(TransactionState::None),
@@ -169,16 +168,31 @@ impl Database {
pub fn define_scalar_function<S: AsRef<str>>(
&self,
name: S,
func: impl Fn(&[Value]) -> Result<OwnedValue> + 'static,
func: Arc<dyn extension_api::ScalarFunction>,
) {
let func = function::ExternalFunc {
name: name.as_ref().to_string(),
func: Box::new(func),
func: func.clone(),
};
self.syms
.borrow_mut()
.functions
.insert(name.as_ref().to_string(), Rc::new(func));
.insert(name.as_ref().to_string(), Arc::new(func));
}
pub fn load_extension(&self, path: &str) -> Result<()> {
let lib =
unsafe { Library::new(path).map_err(|e| LimboError::ExtensionError(e.to_string()))? };
unsafe {
let register: Symbol<unsafe extern "C" fn(&dyn ExtensionApi) -> Box<dyn Extension>> =
lib.get(b"register_extension")
.map_err(|e| LimboError::ExtensionError(e.to_string()))?;
let extension = register(self);
extension
.load()
.map_err(|e| LimboError::ExtensionError(e.to_string()))?;
}
Ok(())
}
}
@@ -372,6 +386,10 @@ impl Connection {
Ok(())
}
pub fn load_extension(&self, path: &str) -> Result<()> {
Database::load_extension(self.db.as_ref(), path)
}
/// Close a connection and checkpoint.
pub fn close(&self) -> Result<()> {
loop {
@@ -468,15 +486,24 @@ impl Rows {
}
}
#[derive(Debug)]
pub(crate) struct SymbolTable {
pub functions: HashMap<String, Rc<crate::function::ExternalFunc>>,
pub functions: HashMap<String, Arc<crate::function::ExternalFunc>>,
extensions: Vec<Rc<dyn Extension>>,
}
impl std::fmt::Debug for SymbolTable {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SymbolTable")
.field("functions", &self.functions)
.finish()
}
}
impl SymbolTable {
pub fn new() -> Self {
Self {
functions: HashMap::new(),
extensions: Vec::new(),
}
}
@@ -484,7 +511,7 @@ impl SymbolTable {
&self,
name: &str,
_arg_count: usize,
) -> Option<Rc<crate::function::ExternalFunc>> {
) -> Option<Arc<crate::function::ExternalFunc>> {
self.functions.get(name).cloned()
}
}

View File

@@ -1,8 +1,8 @@
use std::fmt::Display;
use std::rc::Rc;
use crate::error::LimboError;
use crate::Result;
use extension_api::Value as ExtValue;
use std::fmt::Display;
use std::rc::Rc;
use crate::storage::sqlite3_ondisk::write_varint;
@@ -15,6 +15,45 @@ pub enum Value<'a> {
Blob(&'a Vec<u8>),
}
impl From<&OwnedValue> for extension_api::Value {
fn from(value: &OwnedValue) -> Self {
match value {
OwnedValue::Null => extension_api::Value::Null,
OwnedValue::Integer(i) => extension_api::Value::Integer(*i),
OwnedValue::Float(f) => extension_api::Value::Float(*f),
OwnedValue::Text(text) => extension_api::Value::Text(text.value.to_string()),
OwnedValue::Blob(blob) => extension_api::Value::Blob(blob.to_vec()),
OwnedValue::Agg(_) => {
panic!("Cannot convert Aggregate context to extension_api::Value")
} // Handle appropriately
OwnedValue::Record(_) => panic!("Cannot convert Record to extension_api::Value"), // Handle appropriately
}
}
}
impl From<ExtValue> for OwnedValue {
fn from(value: ExtValue) -> Self {
match value {
ExtValue::Null => OwnedValue::Null,
ExtValue::Integer(i) => OwnedValue::Integer(i),
ExtValue::Float(f) => OwnedValue::Float(f),
ExtValue::Text(text) => OwnedValue::Text(LimboText::new(Rc::new(text.to_string()))),
ExtValue::Blob(blob) => OwnedValue::Blob(Rc::new(blob.to_vec())),
}
}
}
impl<'a> From<&'a crate::Value<'a>> for ExtValue {
fn from(value: &'a crate::Value<'a>) -> Self {
match value {
crate::Value::Null => extension_api::Value::Null,
crate::Value::Integer(i) => extension_api::Value::Integer(*i),
crate::Value::Float(f) => extension_api::Value::Float(*f),
crate::Value::Text(t) => extension_api::Value::Text(t.to_string()),
crate::Value::Blob(b) => extension_api::Value::Blob(b.to_vec()),
}
}
}
impl Display for Value<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {

View File

@@ -1872,8 +1872,13 @@ impl Program {
_ => unreachable!(), // when more extension types are added
},
crate::function::Func::External(f) => {
let result = (f.func)(&[])?;
state.registers[*dest] = result;
let values = &state.registers[*start_reg..*start_reg + arg_count];
let args: Vec<_> = values.into_iter().map(|v| v.into()).collect();
let result = f
.func
.execute(args.as_slice())
.map_err(|e| LimboError::ExtensionError(e.to_string()))?;
state.registers[*dest] = result.into();
}
crate::function::Func::Math(math_func) => match math_func.arity() {
MathFuncArity::Nullary => match math_func {

9
extension_api/Cargo.toml Normal file
View File

@@ -0,0 +1,9 @@
[package]
name = "extension_api"
version.workspace = true
authors.workspace = true
edition.workspace = true
license.workspace = true
repository.workspace = true
[dependencies]

75
extension_api/src/lib.rs Normal file
View File

@@ -0,0 +1,75 @@
use std::any::Any;
use std::rc::Rc;
use std::sync::Arc;
pub type Result<T> = std::result::Result<T, LimboApiError>;
pub trait Extension {
fn load(&self) -> Result<()>;
}
#[derive(Debug)]
pub enum LimboApiError {
ConnectionError(String),
RegisterFunctionError(String),
ValueError(String),
VTableError(String),
}
impl From<std::io::Error> for LimboApiError {
fn from(e: std::io::Error) -> Self {
Self::ConnectionError(e.to_string())
}
}
impl std::fmt::Display for LimboApiError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Self::ConnectionError(e) => write!(f, "Connection error: {e}"),
Self::RegisterFunctionError(e) => write!(f, "Register function error: {e}"),
Self::ValueError(e) => write!(f, "Value error: {e}"),
Self::VTableError(e) => write!(f, "VTable error: {e}"),
}
}
}
pub trait ExtensionApi {
fn register_scalar_function(&self, name: &str, func: Arc<dyn ScalarFunction>) -> Result<()>;
fn register_aggregate_function(
&self,
name: &str,
func: Arc<dyn AggregateFunction>,
) -> Result<()>;
fn register_virtual_table(&self, name: &str, table: Arc<dyn VirtualTable>) -> Result<()>;
}
pub trait ScalarFunction {
fn execute(&self, args: &[Value]) -> Result<Value>;
}
pub trait AggregateFunction {
fn init(&self) -> Box<dyn Any>;
fn step(&self, state: &mut dyn Any, args: &[Value]) -> Result<()>;
fn finalize(&self, state: Box<dyn Any>) -> Result<Value>;
}
pub trait VirtualTable {
fn schema(&self) -> &'static str;
fn create_cursor(&self) -> Box<dyn Cursor>;
}
pub trait Cursor {
fn next(&mut self) -> Result<Option<Row>>;
}
pub struct Row {
pub values: Vec<Value>,
}
pub enum Value {
Text(String),
Blob(Vec<u8>),
Integer(i64),
Float(f64),
Null,
}

View File

@@ -0,0 +1,11 @@
[package]
name = "uuid"
version.workspace = true
authors.workspace = true
edition.workspace = true
license.workspace = true
repository.workspace = true
[dependencies]
extension_api = { path = "../../extension_api"}
uuid = { version = "1.11.0", features = ["v4", "v7"] }

View File