diff --git a/Cargo.lock b/Cargo.lock index f8a0df0e6..26cd80646 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1701,6 +1701,14 @@ dependencies = [ "regex", ] +[[package]] +name = "limbo_series" +version = "0.0.14" +dependencies = [ + "limbo_ext", + "log", +] + [[package]] name = "limbo_sim" version = "0.0.14" diff --git a/Cargo.toml b/Cargo.toml index 754583f42..400595b4a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ members = [ "extensions/percentile", "extensions/time", "extensions/crypto", + "extensions/series", ] exclude = ["perf/latency/limbo"] diff --git a/core/ext/mod.rs b/core/ext/mod.rs index db7876431..6d034e313 100644 --- a/core/ext/mod.rs +++ b/core/ext/mod.rs @@ -1,6 +1,11 @@ -use crate::{function::ExternalFunc, Database}; -use limbo_ext::{ExtensionApi, InitAggFunction, ResultCode, ScalarFunction}; +use crate::{function::ExternalFunc, util::columns_from_create_table_body, Database, VirtualTable}; +use fallible_iterator::FallibleIterator; +use limbo_ext::{ExtensionApi, InitAggFunction, ResultCode, ScalarFunction, VTabModuleImpl}; pub use limbo_ext::{FinalizeFunction, StepFunction, Value as ExtValue, ValueType as ExtValueType}; +use sqlite3_parser::{ + ast::{Cmd, Stmt}, + lexer::sql::Parser, +}; use std::{ ffi::{c_char, c_void, CStr}, rc::Rc, @@ -44,6 +49,48 @@ unsafe extern "C" fn register_aggregate_function( db.register_aggregate_function_impl(&name_str, args, (init_func, step_func, finalize_func)) } +unsafe extern "C" fn register_module( + ctx: *mut c_void, + name: *const c_char, + module: VTabModuleImpl, +) -> ResultCode { + let c_str = unsafe { CStr::from_ptr(name) }; + let name_str = match c_str.to_str() { + Ok(s) => s.to_string(), + Err(_) => return ResultCode::Error, + }; + if ctx.is_null() { + return ResultCode::Error; + } + let db = unsafe { &mut *(ctx as *mut Database) }; + + db.register_module_impl(&name_str, module) +} + +unsafe extern "C" fn declare_vtab( + ctx: *mut c_void, + name: *const c_char, + sql: *const c_char, +) -> ResultCode { + let c_str = unsafe { CStr::from_ptr(name) }; + let name_str = match c_str.to_str() { + Ok(s) => s.to_string(), + Err(_) => return ResultCode::Error, + }; + + let c_str = unsafe { CStr::from_ptr(sql) }; + let sql_str = match c_str.to_str() { + Ok(s) => s.to_string(), + Err(_) => return ResultCode::Error, + }; + + if ctx.is_null() { + return ResultCode::Error; + } + let db = unsafe { &mut *(ctx as *mut Database) }; + db.declare_vtab_impl(&name_str, &sql_str) +} + impl Database { fn register_scalar_function_impl(&self, name: &str, func: ScalarFunction) -> ResultCode { self.syms.borrow_mut().functions.insert( @@ -66,11 +113,40 @@ impl Database { ResultCode::OK } + fn register_module_impl(&mut self, name: &str, module: VTabModuleImpl) -> ResultCode { + self.vtab_modules.insert(name.to_string(), Rc::new(module)); + ResultCode::OK + } + + fn declare_vtab_impl(&mut self, name: &str, sql: &str) -> ResultCode { + let mut parser = Parser::new(sql.as_bytes()); + let cmd = parser.next().unwrap().unwrap(); + let Cmd::Stmt(stmt) = cmd else { + return ResultCode::Error; + }; + let Stmt::CreateTable { body, .. } = stmt else { + return ResultCode::Error; + }; + let Ok(columns) = columns_from_create_table_body(body) else { + return ResultCode::Error; + }; + let vtab_module = self.vtab_modules.get(name).unwrap().clone(); + let vtab = VirtualTable { + name: name.to_string(), + implementation: vtab_module, + columns, + }; + self.syms.borrow_mut().vtabs.insert(name.to_string(), vtab); + ResultCode::OK + } + pub fn build_limbo_ext(&self) -> ExtensionApi { ExtensionApi { ctx: self as *const _ as *mut c_void, register_scalar_function, register_aggregate_function, + register_module, + declare_vtab, } } diff --git a/core/lib.rs b/core/lib.rs index 884aad963..ccc2c2273 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -25,12 +25,13 @@ static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; use fallible_iterator::FallibleIterator; #[cfg(not(target_family = "wasm"))] use libloading::{Library, Symbol}; -use limbo_ext::{ExtensionApi, ExtensionEntryPoint}; +#[cfg(not(target_family = "wasm"))] +use limbo_ext::{ExtensionApi, ExtensionEntryPoint, ResultCode}; +use limbo_ext::{VTabModuleImpl, Value as ExtValue}; use log::trace; use parking_lot::RwLock; -use schema::Schema; -use sqlite3_parser::ast; -use sqlite3_parser::{ast::Cmd, lexer::sql::Parser}; +use schema::{Column, Schema}; +use sqlite3_parser::{ast, ast::Cmd, lexer::sql::Parser}; use std::cell::Cell; use std::collections::HashMap; use std::num::NonZero; @@ -44,9 +45,11 @@ use storage::pager::allocate_page; use storage::sqlite3_ondisk::{DatabaseHeader, DATABASE_HEADER_SIZE}; pub use storage::wal::WalFile; pub use storage::wal::WalFileShared; +use types::OwnedValue; pub use types::Value; use util::parse_schema_rows; use vdbe::builder::QueryMode; +use vdbe::VTabOpaqueCursor; pub use error::LimboError; use translate::select::prepare_select_plan; @@ -82,6 +85,7 @@ pub struct Database { schema: Rc>, header: Rc>, syms: Rc>, + vtab_modules: HashMap>, // Shared structures of a Database are the parts that are common to multiple threads that might // create DB connections. _shared_page_cache: Arc>, @@ -144,6 +148,7 @@ impl Database { _shared_page_cache: _shared_page_cache.clone(), _shared_wal: shared_wal.clone(), syms, + vtab_modules: HashMap::new(), }; if let Err(e) = db.register_builtins() { return Err(LimboError::ExtensionError(e)); @@ -506,10 +511,70 @@ pub type Row = types::Record; pub type StepResult = vdbe::StepResult; +#[derive(Clone, Debug)] +pub struct VirtualTable { + name: String, + pub implementation: Rc, + columns: Vec, +} + +impl VirtualTable { + pub fn open(&self) -> VTabOpaqueCursor { + let cursor = unsafe { (self.implementation.open)() }; + VTabOpaqueCursor::new(cursor) + } + + pub fn filter( + &self, + cursor: &VTabOpaqueCursor, + arg_count: usize, + args: Vec, + ) -> Result<()> { + let mut filter_args = Vec::with_capacity(arg_count); + for i in 0..arg_count { + let ownedvalue_arg = args.get(i).unwrap(); + let extvalue_arg: ExtValue = match ownedvalue_arg { + OwnedValue::Null => Ok(ExtValue::null()), + OwnedValue::Integer(i) => Ok(ExtValue::from_integer(*i)), + OwnedValue::Float(f) => Ok(ExtValue::from_float(*f)), + OwnedValue::Text(t) => Ok(ExtValue::from_text((*t.value).clone())), + OwnedValue::Blob(b) => Ok(ExtValue::from_blob((**b).clone())), + other => Err(LimboError::ExtensionError(format!( + "Unsupported value type: {:?}", + other + ))), + }?; + filter_args.push(extvalue_arg); + } + let rc = unsafe { + (self.implementation.filter)(cursor.as_ptr(), arg_count as i32, filter_args.as_ptr()) + }; + match rc { + ResultCode::OK => Ok(()), + _ => Err(LimboError::ExtensionError("Filter failed".to_string())), + } + } + + pub fn column(&self, cursor: &VTabOpaqueCursor, column: usize) -> Result { + let val = unsafe { (self.implementation.column)(cursor.as_ptr(), column as u32) }; + OwnedValue::from_ffi(&val) + } + + pub fn next(&self, cursor: &VTabOpaqueCursor) -> Result { + let rc = unsafe { (self.implementation.next)(cursor.as_ptr()) }; + match rc { + ResultCode::OK => Ok(true), + ResultCode::EOF => Ok(false), + _ => Err(LimboError::ExtensionError("Next failed".to_string())), + } + } +} + pub(crate) struct SymbolTable { pub functions: HashMap>, #[cfg(not(target_family = "wasm"))] extensions: Vec<(Library, *const ExtensionApi)>, + pub vtabs: HashMap, } impl std::fmt::Debug for SymbolTable { @@ -551,6 +616,7 @@ impl SymbolTable { pub fn new() -> Self { Self { functions: HashMap::new(), + vtabs: HashMap::new(), #[cfg(not(target_family = "wasm"))] extensions: Vec::new(), } diff --git a/core/schema.rs b/core/schema.rs index a5f1e6121..e7688b58e 100644 --- a/core/schema.rs +++ b/core/schema.rs @@ -1,3 +1,4 @@ +use crate::VirtualTable; use crate::{util::normalize_ident, Result}; use core::fmt; use fallible_iterator::FallibleIterator; @@ -47,6 +48,7 @@ impl Schema { pub enum Table { BTree(Rc), Pseudo(Rc), + Virtual(Rc), } impl Table { @@ -54,6 +56,7 @@ impl Table { match self { Table::BTree(table) => table.root_page, Table::Pseudo(_) => unimplemented!(), + Table::Virtual(_) => unimplemented!(), } } @@ -61,6 +64,7 @@ impl Table { match self { Self::BTree(table) => &table.name, Self::Pseudo(_) => "", + Self::Virtual(table) => &table.name, } } @@ -74,6 +78,10 @@ impl Table { .columns .get(index) .expect("column index out of bounds"), + Self::Virtual(table) => table + .columns + .get(index) + .expect("column index out of bounds"), } } @@ -81,6 +89,7 @@ impl Table { match self { Self::BTree(table) => &table.columns, Self::Pseudo(table) => &table.columns, + Self::Virtual(table) => &table.columns, } } @@ -88,6 +97,13 @@ impl Table { match self { Self::BTree(table) => Some(table.clone()), Self::Pseudo(_) => None, + Self::Virtual(_) => None, + } + } + pub fn virtual_table(&self) -> Option> { + match self { + Self::Virtual(table) => Some(table.clone()), + _ => None, } } } @@ -97,6 +113,7 @@ impl PartialEq for Table { match (self, other) { (Self::BTree(a), Self::BTree(b)) => Rc::ptr_eq(a, b), (Self::Pseudo(a), Self::Pseudo(b)) => Rc::ptr_eq(a, b), + (Self::Virtual(a), Self::Virtual(b)) => Rc::ptr_eq(a, b), _ => false, } } diff --git a/core/translate/delete.rs b/core/translate/delete.rs index ffad33d73..675b58f34 100644 --- a/core/translate/delete.rs +++ b/core/translate/delete.rs @@ -7,7 +7,7 @@ use crate::vdbe::builder::{ProgramBuilder, ProgramBuilderOpts, QueryMode}; use crate::{schema::Schema, Result, SymbolTable}; use sqlite3_parser::ast::{Expr, Limit, QualifiedName}; -use super::plan::TableReference; +use super::plan::{TableReference, TableReferenceType}; pub fn translate_delete( query_mode: QueryMode, @@ -48,6 +48,7 @@ pub fn prepare_delete_plan( identifier: table.name.clone(), op: Operation::Scan { iter_dir: None }, join_info: None, + reference_type: TableReferenceType::BTreeTable, }]; let mut where_predicates = vec![]; diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 45fb1a648..8b2e70185 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -13,7 +13,7 @@ use crate::vdbe::{ use crate::Result; use super::emitter::Resolver; -use super::plan::{Operation, TableReference}; +use super::plan::{Operation, TableReference, TableReferenceType}; #[derive(Debug, Clone, Copy)] pub struct ConditionMetadata { @@ -1824,22 +1824,45 @@ pub fn translate_expr( // If we are reading a column from a table, we find the cursor that corresponds to // the table and read the column from the cursor. Operation::Scan { .. } | Operation::Search(_) => { - let cursor_id = program.resolve_cursor_id(&table_reference.identifier); - if *is_rowid_alias { - program.emit_insn(Insn::RowId { - cursor_id, - dest: target_register, - }); - } else { - program.emit_insn(Insn::Column { - cursor_id, - column: *column, - dest: target_register, - }); + match &table_reference.reference_type { + TableReferenceType::BTreeTable => { + let cursor_id = program.resolve_cursor_id(&table_reference.identifier); + if *is_rowid_alias { + program.emit_insn(Insn::RowId { + cursor_id, + dest: target_register, + }); + } else { + program.emit_insn(Insn::Column { + cursor_id, + column: *column, + dest: target_register, + }); + } + let column = table_reference.table.get_column_at(*column); + maybe_apply_affinity(column.ty, target_register, program); + Ok(target_register) + } + TableReferenceType::VirtualTable { .. } => { + let cursor_id = program.resolve_cursor_id(&table_reference.identifier); + program.emit_insn(Insn::VColumn { + cursor_id, + column: *column, + dest: target_register, + }); + Ok(target_register) + } + TableReferenceType::Subquery { + result_columns_start_reg, + } => { + program.emit_insn(Insn::Copy { + src_reg: result_columns_start_reg + *column, + dst_reg: target_register, + amount: 0, + }); + Ok(target_register) + } } - let column = table_reference.table.get_column_at(*column); - maybe_apply_affinity(column.ty, target_register, program); - Ok(target_register) } // If we are reading a column from a subquery, we instead copy the column from the // subquery's result registers. diff --git a/core/translate/main_loop.rs b/core/translate/main_loop.rs index b4fff9c7a..4558693e3 100644 --- a/core/translate/main_loop.rs +++ b/core/translate/main_loop.rs @@ -17,7 +17,7 @@ use super::{ order_by::{order_by_sorter_insert, sorter_insert}, plan::{ IterationDirection, Operation, Search, SelectPlan, SelectQueryType, TableReference, - WhereTerm, + TableReferenceType, WhereTerm, }, }; @@ -78,27 +78,40 @@ pub fn init_loop( } match &table.op { Operation::Scan { .. } => { + let ref_type = &table.reference_type; let cursor_id = program.alloc_cursor_id( Some(table.identifier.clone()), - CursorType::BTreeTable(table.btree().unwrap().clone()), + match ref_type { + TableReferenceType::BTreeTable => { + CursorType::BTreeTable(table.btree().unwrap().clone()) + } + TableReferenceType::VirtualTable { .. } => { + CursorType::VirtualTable(table.virtual_table().unwrap().clone()) + } + other => panic!("Invalid table reference type in Scan: {:?}", other), + }, ); - let root_page = table.table.get_root_page(); - - match mode { - OperationMode::SELECT => { + match (mode, ref_type) { + (OperationMode::SELECT, TableReferenceType::BTreeTable) => { + let root_page = table.btree().unwrap().root_page; program.emit_insn(Insn::OpenReadAsync { cursor_id, root_page, }); program.emit_insn(Insn::OpenReadAwait {}); } - OperationMode::DELETE => { + (OperationMode::DELETE, TableReferenceType::BTreeTable) => { + let root_page = table.btree().unwrap().root_page; program.emit_insn(Insn::OpenWriteAsync { cursor_id, root_page, }); program.emit_insn(Insn::OpenWriteAwait {}); } + (OperationMode::SELECT, TableReferenceType::VirtualTable { .. }) => { + program.emit_insn(Insn::VOpenAsync { cursor_id }); + program.emit_insn(Insn::VOpenAwait {}); + } _ => { unimplemented!() } @@ -245,31 +258,52 @@ pub fn open_loop( } } Operation::Scan { iter_dir } => { + let ref_type = &table.reference_type; let cursor_id = program.resolve_cursor_id(&table.identifier); - if iter_dir - .as_ref() - .is_some_and(|dir| *dir == IterationDirection::Backwards) - { - program.emit_insn(Insn::LastAsync { cursor_id }); - } else { - program.emit_insn(Insn::RewindAsync { cursor_id }); - } - program.emit_insn( + + if !matches!(ref_type, TableReferenceType::VirtualTable { .. }) { if iter_dir .as_ref() .is_some_and(|dir| *dir == IterationDirection::Backwards) { - Insn::LastAwait { - cursor_id, - pc_if_empty: loop_end, - } + program.emit_insn(Insn::LastAsync { cursor_id }); } else { - Insn::RewindAwait { - cursor_id, - pc_if_empty: loop_end, + program.emit_insn(Insn::RewindAsync { cursor_id }); + } + } + match ref_type { + TableReferenceType::BTreeTable => program.emit_insn( + if iter_dir + .as_ref() + .is_some_and(|dir| *dir == IterationDirection::Backwards) + { + Insn::LastAwait { + cursor_id, + pc_if_empty: loop_end, + } + } else { + Insn::RewindAwait { + cursor_id, + pc_if_empty: loop_end, + } + }, + ), + TableReferenceType::VirtualTable { args, .. } => { + let start_reg = program.alloc_registers(args.len()); + let mut cur_reg = start_reg; + for arg in args { + let reg = cur_reg; + cur_reg += 1; + translate_expr(program, Some(tables), arg, reg, &t_ctx.resolver)?; } - }, - ); + program.emit_insn(Insn::VFilter { + cursor_id, + arg_count: args.len(), + args_reg: start_reg, + }); + } + other => panic!("Unsupported table reference type: {:?}", other), + } program.resolve_label(loop_start, program.offset()); for cond in predicates @@ -688,29 +722,41 @@ pub fn close_loop( }); } Operation::Scan { iter_dir, .. } => { + let ref_type = &table.reference_type; program.resolve_label(loop_labels.next, program.offset()); let cursor_id = program.resolve_cursor_id(&table.identifier); - if iter_dir - .as_ref() - .is_some_and(|dir| *dir == IterationDirection::Backwards) - { - program.emit_insn(Insn::PrevAsync { cursor_id }); - } else { - program.emit_insn(Insn::NextAsync { cursor_id }); - } - if iter_dir - .as_ref() - .is_some_and(|dir| *dir == IterationDirection::Backwards) - { - program.emit_insn(Insn::PrevAwait { - cursor_id, - pc_if_next: loop_labels.loop_start, - }); - } else { - program.emit_insn(Insn::NextAwait { - cursor_id, - pc_if_next: loop_labels.loop_start, - }); + match ref_type { + TableReferenceType::BTreeTable { .. } => { + if iter_dir + .as_ref() + .is_some_and(|dir| *dir == IterationDirection::Backwards) + { + program.emit_insn(Insn::PrevAsync { cursor_id }); + } else { + program.emit_insn(Insn::NextAsync { cursor_id }); + } + if iter_dir + .as_ref() + .is_some_and(|dir| *dir == IterationDirection::Backwards) + { + program.emit_insn(Insn::PrevAwait { + cursor_id, + pc_if_next: loop_labels.loop_start, + }); + } else { + program.emit_insn(Insn::NextAwait { + cursor_id, + pc_if_next: loop_labels.loop_start, + }); + } + } + TableReferenceType::VirtualTable { .. } => { + program.emit_insn(Insn::VNext { + cursor_id, + pc_if_next: loop_labels.loop_start, + }); + } + other => unreachable!("Unsupported table reference type: {:?}", other), } } Operation::Search(search) => { diff --git a/core/translate/optimizer.rs b/core/translate/optimizer.rs index 53df77956..73124060f 100644 --- a/core/translate/optimizer.rs +++ b/core/translate/optimizer.rs @@ -204,16 +204,16 @@ fn eliminate_constant_conditions( } fn push_scan_direction(table: &mut TableReference, direction: &Direction) { - match &mut table.op { - Operation::Scan { iter_dir, .. } => { - if iter_dir.is_none() { - match direction { - Direction::Ascending => *iter_dir = Some(IterationDirection::Forwards), - Direction::Descending => *iter_dir = Some(IterationDirection::Backwards), - } + if let Operation::Scan { + ref mut iter_dir, .. + } = table.op + { + if iter_dir.is_none() { + match direction { + Direction::Ascending => *iter_dir = Some(IterationDirection::Forwards), + Direction::Descending => *iter_dir = Some(IterationDirection::Backwards), } } - _ => {} } } @@ -309,12 +309,10 @@ impl Optimizable for ast::Expr { }; let column = table_reference.table.get_column_at(*column); for index in available_indexes_for_table.iter() { - if column - .name - .as_ref() - .map_or(false, |name| *name == index.columns.first().unwrap().name) - { - return Ok(Some(index.clone())); + if let Some(name) = column.name.as_ref() { + if &index.columns.first().unwrap().name == name { + return Ok(Some(index.clone())); + } } } Ok(None) diff --git a/core/translate/plan.rs b/core/translate/plan.rs index e59ffd5e8..43cba8e1b 100644 --- a/core/translate/plan.rs +++ b/core/translate/plan.rs @@ -9,6 +9,7 @@ use crate::{ function::AggFunc, schema::{BTreeTable, Column, Index, Table}, vdbe::BranchOffset, + VirtualTable, }; use crate::{ schema::{PseudoTable, Type}, @@ -197,11 +198,9 @@ pub struct TableReference { pub identifier: String, /// The join info for this table reference, if it is the right side of a join (which all except the first table reference have) pub join_info: Option, + pub reference_type: TableReferenceType, } -/** - A SourceOperator is a reference in the query plan that reads data from a table. -*/ #[derive(Clone, Debug)] pub enum Operation { // Scan operation @@ -226,10 +225,37 @@ pub enum Operation { }, } +/// The type of the table reference, either BTreeTable or Subquery +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum TableReferenceType { + /// A BTreeTable is a table that is stored on disk in a B-tree index. + BTreeTable, + /// A subquery. + Subquery { + /// The index of the first register in the query plan that contains the result columns of the subquery. + result_columns_start_reg: usize, + }, + /// A virtual table. + VirtualTable { + /// Arguments to pass e.g. generate_series(1, 10, 2) + args: Vec, + }, +} + impl TableReference { /// Returns the btree table for this table reference, if it is a BTreeTable. pub fn btree(&self) -> Option> { - self.table.btree() + match &self.reference_type { + TableReferenceType::BTreeTable => self.table.btree(), + TableReferenceType::Subquery { .. } => None, + TableReferenceType::VirtualTable { .. } => None, + } + } + pub fn virtual_table(&self) -> Option> { + match &self.reference_type { + TableReferenceType::VirtualTable { .. } => self.table.virtual_table(), + _ => None, + } } /// Creates a new TableReference for a subquery. @@ -254,6 +280,9 @@ impl TableReference { result_columns_start_reg: 0, // Will be set in the bytecode emission phase }, table, + reference_type: TableReferenceType::Subquery { + result_columns_start_reg: 0, // Will be set in the bytecode emission phase + }, identifier: identifier.clone(), join_info, } diff --git a/core/translate/planner.rs b/core/translate/planner.rs index 95cee1edf..272b788a2 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -1,7 +1,9 @@ +use std::rc::Rc; + use super::{ plan::{ Aggregate, JoinInfo, Operation, Plan, ResultSetColumn, SelectQueryType, TableReference, - WhereTerm, + TableReferenceType, WhereTerm, }, select::prepare_select_plan, SymbolTable, @@ -301,6 +303,7 @@ fn parse_from_clause_table( table: Table::BTree(table.clone()), identifier: alias.unwrap_or(normalized_qualified_name), join_info: None, + reference_type: TableReferenceType::BTreeTable, }) } ast::SelectTable::Select(subselect, maybe_alias) => { @@ -317,8 +320,30 @@ fn parse_from_clause_table( ast::As::Elided(id) => id.0.clone(), }) .unwrap_or(format!("subquery_{}", cur_table_index)); - let table_reference = TableReference::new_subquery(identifier, subplan, None); - Ok(table_reference) + Ok(TableReference::new_subquery(identifier, subplan, None)) + } + ast::SelectTable::TableCall(qualified_name, mut maybe_args, maybe_alias) => { + let normalized_name = normalize_ident(qualified_name.name.0.as_str()); + let Some(vtab) = syms.vtabs.get(&normalized_name) else { + crate::bail_parse_error!("Virtual table {} not found", normalized_name); + }; + let alias = maybe_alias + .as_ref() + .map(|a| match a { + ast::As::As(id) => id.0.clone(), + ast::As::Elided(id) => id.0.clone(), + }) + .unwrap_or(normalized_name); + + Ok(TableReference { + op: Operation::Scan { iter_dir: None }, + join_info: None, + table: Table::Virtual(vtab.clone().into()), + identifier: alias.clone(), + reference_type: TableReferenceType::VirtualTable { + args: maybe_args.take().unwrap_or_default(), + }, + }) } _ => todo!(), } diff --git a/core/types.rs b/core/types.rs index 06c459741..ae31314c1 100644 --- a/core/types.rs +++ b/core/types.rs @@ -6,6 +6,7 @@ use crate::pseudo::PseudoCursor; use crate::storage::btree::BTreeCursor; use crate::storage::sqlite3_ondisk::write_varint; use crate::vdbe::sorter::Sorter; +use crate::vdbe::VTabOpaqueCursor; use crate::Result; use std::fmt::Display; use std::rc::Rc; @@ -670,6 +671,7 @@ pub enum Cursor { Index(BTreeCursor), Pseudo(PseudoCursor), Sorter(Sorter), + Virtual(VTabOpaqueCursor), } impl Cursor { @@ -716,6 +718,13 @@ impl Cursor { _ => panic!("Cursor is not a sorter cursor"), } } + + pub fn as_virtual_mut(&mut self) -> &mut VTabOpaqueCursor { + match self { + Self::Virtual(cursor) => cursor, + _ => panic!("Cursor is not a virtual cursor"), + } + } } pub enum CursorResult { diff --git a/core/util.rs b/core/util.rs index c92c31412..5251b36cf 100644 --- a/core/util.rs +++ b/core/util.rs @@ -1,9 +1,9 @@ use std::{rc::Rc, sync::Arc}; -use sqlite3_parser::ast::{Expr, FunctionTail, Literal}; +use sqlite3_parser::ast::{CreateTableBody, Expr, FunctionTail, Literal}; use crate::{ - schema::{self, Schema}, + schema::{self, Column, Schema, Type}, Result, Statement, StepResult, IO, }; @@ -308,6 +308,77 @@ pub fn exprs_are_equivalent(expr1: &Expr, expr2: &Expr) -> bool { } } +pub fn columns_from_create_table_body(body: CreateTableBody) -> Result, ()> { + let CreateTableBody::ColumnsAndConstraints { columns, .. } = body else { + return Err(()); + }; + + Ok(columns + .into_iter() + .filter_map(|(name, column_def)| { + // if column_def.col_type includes HIDDEN, omit it for now + if let Some(data_type) = column_def.col_type.as_ref() { + if data_type.name.as_str().contains("HIDDEN") { + return None; + } + } + let column = Column { + name: Some(name.0), + ty: match column_def.col_type { + Some(ref data_type) => { + // https://www.sqlite.org/datatype3.html + let type_name = data_type.name.as_str().to_uppercase(); + if type_name.contains("INT") { + Type::Integer + } else if type_name.contains("CHAR") + || type_name.contains("CLOB") + || type_name.contains("TEXT") + { + Type::Text + } else if type_name.contains("BLOB") || type_name.is_empty() { + Type::Blob + } else if type_name.contains("REAL") + || type_name.contains("FLOA") + || type_name.contains("DOUB") + { + Type::Real + } else { + Type::Numeric + } + } + None => Type::Null, + }, + default: column_def + .constraints + .iter() + .find_map(|c| match &c.constraint { + sqlite3_parser::ast::ColumnConstraint::Default(val) => Some(val.clone()), + _ => None, + }), + notnull: column_def.constraints.iter().any(|c| { + matches!( + c.constraint, + sqlite3_parser::ast::ColumnConstraint::NotNull { .. } + ) + }), + ty_str: column_def + .col_type + .clone() + .map(|t| t.name.to_string()) + .unwrap_or_default(), + primary_key: column_def.constraints.iter().any(|c| { + matches!( + c.constraint, + sqlite3_parser::ast::ColumnConstraint::PrimaryKey { .. } + ) + }), + is_rowid_alias: false, + }; + Some(column) + }) + .collect::>()) +} + #[cfg(test)] pub mod tests { use super::*; diff --git a/core/vdbe/builder.rs b/core/vdbe/builder.rs index 21b4b8949..c3ead0d38 100644 --- a/core/vdbe/builder.rs +++ b/core/vdbe/builder.rs @@ -9,7 +9,7 @@ use crate::{ schema::{BTreeTable, Index, PseudoTable}, storage::sqlite3_ondisk::DatabaseHeader, translate::plan::{ResultSetColumn, TableReference}, - Connection, + Connection, VirtualTable, }; use super::{BranchOffset, CursorID, Insn, InsnReference, Program}; @@ -40,6 +40,7 @@ pub enum CursorType { BTreeIndex(Rc), Pseudo(Rc), Sorter, + VirtualTable(Rc), } impl CursorType { @@ -406,6 +407,9 @@ impl ProgramBuilder { Insn::IsNull { reg: _, target_pc } => { resolve(target_pc, "IsNull"); } + Insn::VNext { pc_if_next, .. } => { + resolve(pc_if_next, "VNext"); + } _ => continue, } } diff --git a/core/vdbe/explain.rs b/core/vdbe/explain.rs index e4c302bba..a0bb63023 100644 --- a/core/vdbe/explain.rs +++ b/core/vdbe/explain.rs @@ -363,6 +363,62 @@ pub fn insn_to_str( 0, "".to_string(), ), + Insn::VOpenAsync { cursor_id } => ( + "VOpenAsync", + *cursor_id as i32, + 0, + 0, + OwnedValue::build_text(Rc::new("".to_string())), + 0, + "".to_string(), + ), + Insn::VOpenAwait => ( + "VOpenAwait", + 0, + 0, + 0, + OwnedValue::build_text(Rc::new("".to_string())), + 0, + "".to_string(), + ), + Insn::VFilter { + cursor_id, + arg_count, + args_reg, + } => ( + "VFilter", + *cursor_id as i32, + *arg_count as i32, + *args_reg as i32, + OwnedValue::build_text(Rc::new("".to_string())), + 0, + "".to_string(), + ), + Insn::VColumn { + cursor_id, + column, + dest, + } => ( + "VColumn", + *cursor_id as i32, + *column as i32, + *dest as i32, + OwnedValue::build_text(Rc::new("".to_string())), + 0, + "".to_string(), + ), + Insn::VNext { + cursor_id, + pc_if_next, + } => ( + "VNext", + *cursor_id as i32, + pc_if_next.to_debug_int(), + 0, + OwnedValue::build_text(Rc::new("".to_string())), + 0, + "".to_string(), + ), Insn::OpenPseudo { cursor_id, content_reg, @@ -423,6 +479,7 @@ pub fn insn_to_str( name } CursorType::Sorter => None, + CursorType::VirtualTable(v) => v.columns.get(*column).unwrap().name.as_ref(), }; ( "Column", diff --git a/core/vdbe/insn.rs b/core/vdbe/insn.rs index 1cdb81e25..223f321aa 100644 --- a/core/vdbe/insn.rs +++ b/core/vdbe/insn.rs @@ -213,6 +213,35 @@ pub enum Insn { // Await for the completion of open cursor. OpenReadAwait, + /// Open a cursor for a virtual table. + VOpenAsync { + cursor_id: CursorID, + }, + + /// Await for the completion of open cursor for a virtual table. + VOpenAwait, + + /// Initialize the position of the virtual table cursor. + VFilter { + cursor_id: CursorID, + arg_count: usize, + args_reg: usize, + }, + + /// Read a column from the current row of the virtual table cursor. + VColumn { + cursor_id: CursorID, + column: usize, + dest: usize, + }, + + /// Advance the virtual table cursor to the next row. + /// TODO: async + VNext { + cursor_id: CursorID, + pc_if_next: BranchOffset, + }, + // Open a cursor for a pseudo-table that contains a single row. OpenPseudo { cursor_id: CursorID, diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index e77ccf551..c25dc572b 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -65,6 +65,7 @@ use sorter::Sorter; use std::borrow::BorrowMut; use std::cell::{Cell, RefCell, RefMut}; use std::collections::HashMap; +use std::ffi::c_void; use std::num::NonZero; use std::rc::{Rc, Weak}; @@ -267,6 +268,18 @@ fn get_cursor_as_sorter_mut<'long, 'short>( cursor } +fn get_cursor_as_virtual_mut<'long, 'short>( + cursors: &'short mut RefMut<'long, Vec>>, + cursor_id: CursorID, +) -> &'short mut VTabOpaqueCursor { + let cursor = cursors + .get_mut(cursor_id) + .expect("cursor id out of bounds") + .as_mut() + .expect("cursor not allocated") + .as_virtual_mut(); + cursor +} struct Bitfield([u64; N]); impl Bitfield { @@ -290,6 +303,18 @@ impl Bitfield { } } +pub struct VTabOpaqueCursor(*mut c_void); + +impl VTabOpaqueCursor { + pub fn new(cursor: *mut c_void) -> Self { + Self(cursor) + } + + pub fn as_ptr(&self) -> *mut c_void { + self.0 + } +} + /// The program state describes the environment in which the program executes. pub struct ProgramState { pub pc: InsnReference, @@ -370,6 +395,7 @@ macro_rules! must_be_btree_cursor { CursorType::BTreeIndex(_) => get_cursor_as_index_mut(&mut $cursors, $cursor_id), CursorType::Pseudo(_) => panic!("{} on pseudo cursor", $insn_name), CursorType::Sorter => panic!("{} on sorter cursor", $insn_name), + CursorType::VirtualTable(_) => panic!("{} on virtual table cursor", $insn_name), }; cursor }}; @@ -826,12 +852,79 @@ impl Program { CursorType::Sorter => { panic!("OpenReadAsync on sorter cursor"); } + CursorType::VirtualTable(_) => { + panic!("OpenReadAsync on virtual table cursor, use Insn::VOpenAsync instead"); + } } state.pc += 1; } Insn::OpenReadAwait => { state.pc += 1; } + Insn::VOpenAsync { cursor_id } => { + let (_, cursor_type) = self.cursor_ref.get(*cursor_id).unwrap(); + let CursorType::VirtualTable(virtual_table) = cursor_type else { + panic!("VOpenAsync on non-virtual table cursor"); + }; + let cursor = virtual_table.open(); + state + .cursors + .borrow_mut() + .insert(*cursor_id, Some(Cursor::Virtual(cursor))); + state.pc += 1; + } + Insn::VOpenAwait => { + state.pc += 1; + } + Insn::VFilter { + cursor_id, + arg_count, + args_reg, + } => { + let (_, cursor_type) = self.cursor_ref.get(*cursor_id).unwrap(); + let CursorType::VirtualTable(virtual_table) = cursor_type else { + panic!("VFilter on non-virtual table cursor"); + }; + let mut cursors = state.cursors.borrow_mut(); + let cursor = get_cursor_as_virtual_mut(&mut cursors, *cursor_id); + let mut args = Vec::new(); + for i in 0..*arg_count { + args.push(state.registers[args_reg + i].clone()); + } + virtual_table.filter(cursor, *arg_count, args)?; + state.pc += 1; + } + Insn::VColumn { + cursor_id, + column, + dest, + } => { + let (_, cursor_type) = self.cursor_ref.get(*cursor_id).unwrap(); + let CursorType::VirtualTable(virtual_table) = cursor_type else { + panic!("VColumn on non-virtual table cursor"); + }; + let mut cursors = state.cursors.borrow_mut(); + let cursor = get_cursor_as_virtual_mut(&mut cursors, *cursor_id); + state.registers[*dest] = virtual_table.column(cursor, *column)?; + state.pc += 1; + } + Insn::VNext { + cursor_id, + pc_if_next, + } => { + let (_, cursor_type) = self.cursor_ref.get(*cursor_id).unwrap(); + let CursorType::VirtualTable(virtual_table) = cursor_type else { + panic!("VNextAsync on non-virtual table cursor"); + }; + let mut cursors = state.cursors.borrow_mut(); + let cursor = get_cursor_as_virtual_mut(&mut cursors, *cursor_id); + let has_more = virtual_table.next(cursor)?; + if has_more { + state.pc = pc_if_next.to_offset_int(); + } else { + state.pc += 1; + } + } Insn::OpenPseudo { cursor_id, content_reg: _, @@ -943,6 +1036,11 @@ impl Program { state.registers[*dest] = OwnedValue::Null; } } + CursorType::VirtualTable(_) => { + panic!( + "Insn::Column on virtual table cursor, use Insn::VColumn instead" + ); + } } state.pc += 1; diff --git a/extensions/core/src/lib.rs b/extensions/core/src/lib.rs index 5f9bb09c5..fec363c44 100644 --- a/extensions/core/src/lib.rs +++ b/extensions/core/src/lib.rs @@ -1,5 +1,5 @@ mod types; -pub use limbo_macros::{register_extension, scalar, AggregateDerive}; +pub use limbo_macros::{register_extension, scalar, AggregateDerive, VTabModuleDerive}; use std::os::raw::{c_char, c_void}; pub use types::{ResultCode, Value, ValueType}; @@ -21,6 +21,30 @@ pub struct ExtensionApi { step_func: StepFunction, finalize_func: FinalizeFunction, ) -> ResultCode, + + pub register_module: unsafe extern "C" fn( + ctx: *mut c_void, + name: *const c_char, + module: VTabModuleImpl, + ) -> ResultCode, + + pub declare_vtab: unsafe extern "C" fn( + ctx: *mut c_void, + name: *const c_char, + sql: *const c_char, + ) -> ResultCode, +} + +impl ExtensionApi { + pub fn declare_virtual_table(&self, name: &str, sql: &str) -> ResultCode { + let Ok(name) = std::ffi::CString::new(name) else { + return ResultCode::Error; + }; + let Ok(sql) = std::ffi::CString::new(sql) else { + return ResultCode::Error; + }; + unsafe { (self.declare_vtab)(self.ctx, name.as_ptr(), sql.as_ptr()) } + } } pub type ExtensionEntryPoint = unsafe extern "C" fn(api: *const ExtensionApi) -> ResultCode; @@ -47,3 +71,52 @@ pub trait AggFunc { fn step(state: &mut Self::State, args: &[Value]); fn finalize(state: Self::State) -> Value; } + +#[repr(C)] +#[derive(Clone, Debug)] +pub struct VTabModuleImpl { + pub name: *const c_char, + pub connect: VtabFnConnect, + pub open: VtabFnOpen, + pub filter: VtabFnFilter, + pub column: VtabFnColumn, + pub next: VtabFnNext, + pub eof: VtabFnEof, +} + +pub type VtabFnConnect = unsafe extern "C" fn(api: *const c_void) -> ResultCode; + +pub type VtabFnOpen = unsafe extern "C" fn() -> *mut c_void; + +pub type VtabFnFilter = + unsafe extern "C" fn(cursor: *mut c_void, argc: i32, argv: *const Value) -> ResultCode; + +pub type VtabFnColumn = unsafe extern "C" fn(cursor: *mut c_void, idx: u32) -> Value; + +pub type VtabFnNext = unsafe extern "C" fn(cursor: *mut c_void) -> ResultCode; + +pub type VtabFnEof = unsafe extern "C" fn(cursor: *mut c_void) -> bool; + +pub trait VTabModule: 'static { + type VCursor: VTabCursor; + + fn name() -> &'static str; + fn connect(api: &ExtensionApi) -> ResultCode; + fn open() -> Self::VCursor; + fn filter(cursor: &mut Self::VCursor, arg_count: i32, args: &[Value]) -> ResultCode; + fn column(cursor: &Self::VCursor, idx: u32) -> Value; + fn next(cursor: &mut Self::VCursor) -> ResultCode; + fn eof(cursor: &Self::VCursor) -> bool; +} + +pub trait VTabCursor: Sized { + fn rowid(&self) -> i64; + fn column(&self, idx: u32) -> Value; + fn eof(&self) -> bool; + fn next(&mut self) -> ResultCode; +} + +#[repr(C)] +pub struct VTabImpl { + pub module: VTabModuleImpl, +} diff --git a/extensions/core/src/types.rs b/extensions/core/src/types.rs index 464e07bfd..4a1fa3978 100644 --- a/extensions/core/src/types.rs +++ b/extensions/core/src/types.rs @@ -2,8 +2,8 @@ use std::fmt::Display; /// Error type is of type ExtError which can be /// either a user defined error or an error code -#[derive(Clone, Copy)] #[repr(C)] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] pub enum ResultCode { OK = 0, Error = 1, @@ -20,6 +20,7 @@ pub enum ResultCode { Internal = 12, Unavailable = 13, CustomError = 14, + EOF = 15, } impl ResultCode { @@ -50,6 +51,7 @@ impl Display for ResultCode { ResultCode::Internal => write!(f, "Internal Error"), ResultCode::Unavailable => write!(f, "Unavailable"), ResultCode::CustomError => write!(f, "Error "), + ResultCode::EOF => write!(f, "EOF"), } } } diff --git a/extensions/series/Cargo.toml b/extensions/series/Cargo.toml new file mode 100644 index 000000000..73a634ac7 --- /dev/null +++ b/extensions/series/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "limbo_series" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true + +[lib] +crate-type = ["cdylib", "lib"] + + +[dependencies] +limbo_ext = { path = "../core"} +log = "0.4.20" diff --git a/extensions/series/src/lib.rs b/extensions/series/src/lib.rs new file mode 100644 index 000000000..f438c6fce --- /dev/null +++ b/extensions/series/src/lib.rs @@ -0,0 +1,136 @@ +use limbo_ext::{ + register_extension, ExtensionApi, ResultCode, VTabCursor, VTabModule, VTabModuleDerive, Value, + ValueType, +}; + +register_extension! { + vtabs: { GenerateSeriesVTab } +} + +/// A virtual table that generates a sequence of integers +#[derive(Debug, VTabModuleDerive)] +struct GenerateSeriesVTab; + +impl VTabModule for GenerateSeriesVTab { + type VCursor = GenerateSeriesCursor; + fn name() -> &'static str { + "generate_series" + } + + fn connect(api: &ExtensionApi) -> ResultCode { + // Create table schema + let sql = "CREATE TABLE generate_series( + value INTEGER, + start INTEGER HIDDEN, + stop INTEGER HIDDEN, + step INTEGER HIDDEN + )"; + let name = Self::name(); + api.declare_virtual_table(name, sql) + } + + fn open() -> Self::VCursor { + GenerateSeriesCursor { + start: 0, + stop: 0, + step: 0, + current: 0, + } + } + + fn filter(cursor: &mut Self::VCursor, arg_count: i32, args: &[Value]) -> ResultCode { + // args are the start, stop, and step + if arg_count == 0 || arg_count > 3 { + return ResultCode::InvalidArgs; + } + let start = { + if args[0].value_type() == ValueType::Integer { + args[0].to_integer().unwrap() + } else { + return ResultCode::InvalidArgs; + } + }; + let stop = if args.len() == 1 { + i64::MAX + } else { + if args[1].value_type() == ValueType::Integer { + args[1].to_integer().unwrap() + } else { + return ResultCode::InvalidArgs; + } + }; + let step = if args.len() <= 2 { + 1 + } else { + if args[2].value_type() == ValueType::Integer { + args[2].to_integer().unwrap() + } else { + return ResultCode::InvalidArgs; + } + }; + cursor.start = start; + cursor.current = start; + cursor.stop = stop; + cursor.step = step; + ResultCode::OK + } + + fn column(cursor: &Self::VCursor, idx: u32) -> Value { + cursor.column(idx) + } + + fn next(cursor: &mut Self::VCursor) -> ResultCode { + GenerateSeriesCursor::next(cursor) + } + + fn eof(cursor: &Self::VCursor) -> bool { + cursor.eof() + } +} + +/// The cursor for iterating over the generated sequence +#[derive(Debug)] +struct GenerateSeriesCursor { + start: i64, + stop: i64, + step: i64, + current: i64, +} + +impl GenerateSeriesCursor { + fn next(&mut self) -> ResultCode { + let current = self.current; + + // Check if we've reached the end + if (self.step > 0 && current >= self.stop) || (self.step < 0 && current <= self.stop) { + return ResultCode::EOF; + } + + self.current = current.saturating_add(self.step); + ResultCode::OK + } +} + +impl VTabCursor for GenerateSeriesCursor { + fn next(&mut self) -> ResultCode { + GenerateSeriesCursor::next(self) + } + + fn eof(&self) -> bool { + (self.step > 0 && self.current > self.stop) || (self.step < 0 && self.current < self.stop) + } + + fn column(&self, idx: u32) -> Value { + match idx { + 0 => Value::from_integer(self.current), + 1 => Value::from_integer(self.start), + 2 => Value::from_integer(self.stop), + 3 => Value::from_integer(self.step), + _ => Value::null(), + } + } + + fn rowid(&self) -> i64 { + ((self.current - self.start) / self.step) + 1 + } +} diff --git a/macros/src/args.rs b/macros/src/args.rs index d9e59cbd3..12446b660 100644 --- a/macros/src/args.rs +++ b/macros/src/args.rs @@ -6,31 +6,32 @@ use syn::{Ident, LitStr, Token}; pub(crate) struct RegisterExtensionInput { pub aggregates: Vec, pub scalars: Vec, + pub vtabs: Vec, } impl syn::parse::Parse for RegisterExtensionInput { fn parse(input: syn::parse::ParseStream) -> syn::Result { let mut aggregates = Vec::new(); let mut scalars = Vec::new(); - + let mut vtabs = Vec::new(); while !input.is_empty() { if input.peek(syn::Ident) && input.peek2(Token![:]) { let section_name: Ident = input.parse()?; input.parse::()?; - - if section_name == "aggregates" || section_name == "scalars" { + let names = ["aggregates", "scalars", "vtabs"]; + if names.contains(§ion_name.to_string().as_str()) { let content; syn::braced!(content in input); - let parsed_items = Punctuated::::parse_terminated(&content)? .into_iter() .collect(); - if section_name == "aggregates" { - aggregates = parsed_items; - } else { - scalars = parsed_items; - } + match section_name.to_string().as_str() { + "aggregates" => aggregates = parsed_items, + "scalars" => scalars = parsed_items, + "vtabs" => vtabs = parsed_items, + _ => unreachable!(), + }; if input.peek(Token![,]) { input.parse::()?; @@ -39,13 +40,14 @@ impl syn::parse::Parse for RegisterExtensionInput { return Err(syn::Error::new(section_name.span(), "Unknown section")); } } else { - return Err(input.error("Expected aggregates: or scalars: section")); + return Err(input.error("Expected aggregates:, scalars:, or vtabs: section")); } } Ok(Self { aggregates, scalars, + vtabs, }) } } diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 1e0ef421e..6b0df9679 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -324,6 +324,103 @@ pub fn derive_agg_func(input: TokenStream) -> TokenStream { TokenStream::from(expanded) } +#[proc_macro_derive(VTabModuleDerive)] +pub fn derive_vtab_module(input: TokenStream) -> TokenStream { + let ast = parse_macro_input!(input as DeriveInput); + let struct_name = &ast.ident; + + let register_fn_name = format_ident!("register_{}", struct_name); + let connect_fn_name = format_ident!("connect_{}", struct_name); + let open_fn_name = format_ident!("open_{}", struct_name); + let filter_fn_name = format_ident!("filter_{}", struct_name); + let column_fn_name = format_ident!("column_{}", struct_name); + let next_fn_name = format_ident!("next_{}", struct_name); + let eof_fn_name = format_ident!("eof_{}", struct_name); + + let expanded = quote! { + impl #struct_name { + #[no_mangle] + unsafe extern "C" fn #connect_fn_name( + db: *const ::std::ffi::c_void, + ) -> ::limbo_ext::ResultCode { + let api = unsafe { &*(db as *const ExtensionApi) }; + <#struct_name as ::limbo_ext::VTabModule>::connect(api) + } + + #[no_mangle] + unsafe extern "C" fn #open_fn_name( + ) -> *mut ::std::ffi::c_void { + let cursor = <#struct_name as ::limbo_ext::VTabModule>::open(); + Box::into_raw(Box::new(cursor)) as *mut ::std::ffi::c_void + } + + #[no_mangle] + unsafe extern "C" fn #filter_fn_name( + cursor: *mut ::std::ffi::c_void, + argc: i32, + argv: *const ::limbo_ext::Value, + ) -> ::limbo_ext::ResultCode { + let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; + let args = std::slice::from_raw_parts(argv, argc as usize); + <#struct_name as ::limbo_ext::VTabModule>::filter(cursor, argc, args) + } + + #[no_mangle] + unsafe extern "C" fn #column_fn_name( + cursor: *mut ::std::ffi::c_void, + idx: u32, + ) -> ::limbo_ext::Value { + let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; + <#struct_name as ::limbo_ext::VTabModule>::column(cursor, idx) + } + + #[no_mangle] + unsafe extern "C" fn #next_fn_name( + cursor: *mut ::std::ffi::c_void, + ) -> ::limbo_ext::ResultCode { + let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; + <#struct_name as ::limbo_ext::VTabModule>::next(cursor) + } + + #[no_mangle] + unsafe extern "C" fn #eof_fn_name( + cursor: *mut ::std::ffi::c_void, + ) -> bool { + let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; + <#struct_name as ::limbo_ext::VTabModule>::eof(cursor) + } + + #[no_mangle] + pub unsafe extern "C" fn #register_fn_name( + api: *const ::limbo_ext::ExtensionApi + ) -> ::limbo_ext::ResultCode { + if api.is_null() { + return ::limbo_ext::ResultCode::Error; + } + + let api = &*api; + let name = <#struct_name as ::limbo_ext::VTabModule>::name(); + // name needs to be a c str FFI compatible, NOT CString + let name_c = std::ffi::CString::new(name).unwrap(); + + let module = ::limbo_ext::VTabModuleImpl { + name: name_c.as_ptr(), + connect: Self::#connect_fn_name, + open: Self::#open_fn_name, + filter: Self::#filter_fn_name, + column: Self::#column_fn_name, + next: Self::#next_fn_name, + eof: Self::#eof_fn_name, + }; + + (api.register_module)(api.ctx, name_c.as_ptr(), module) + } + } + }; + + TokenStream::from(expanded) +} + /// Register your extension with 'core' by providing the relevant functions ///```ignore ///use limbo_ext::{register_extension, scalar, Value, AggregateDerive, AggFunc}; @@ -362,6 +459,7 @@ pub fn register_extension(input: TokenStream) -> TokenStream { let RegisterExtensionInput { aggregates, scalars, + vtabs, } = input_ast; let scalar_calls = scalars.iter().map(|scalar_ident| { @@ -388,8 +486,23 @@ pub fn register_extension(input: TokenStream) -> TokenStream { } } }); + let vtab_calls = vtabs.iter().map(|vtab_ident| { + let register_fn = syn::Ident::new(&format!("register_{}", vtab_ident), vtab_ident.span()); + quote! { + { + let result = unsafe{ #vtab_ident::#register_fn(api)}; + if result == ::limbo_ext::ResultCode::OK { + let result = <#vtab_ident as ::limbo_ext::VTabModule>::connect(api); + return result; + } else { + return result; + } + } + } + }); let static_aggregates = aggregate_calls.clone(); let static_scalars = scalar_calls.clone(); + let static_vtabs = vtab_calls.clone(); let expanded = quote! { #[cfg(not(target_family = "wasm"))] @@ -404,20 +517,23 @@ pub fn register_extension(input: TokenStream) -> TokenStream { #(#static_aggregates)* + #(#static_vtabs)* + ::limbo_ext::ResultCode::OK } #[cfg(not(feature = "static"))] - #[no_mangle] - pub unsafe extern "C" fn register_extension(api: &::limbo_ext::ExtensionApi) -> ::limbo_ext::ResultCode { - let api = unsafe { &*api }; - #(#scalar_calls)* + #[no_mangle] + pub unsafe extern "C" fn register_extension(api: &::limbo_ext::ExtensionApi) -> ::limbo_ext::ResultCode { + let api = unsafe { &*api }; + #(#scalar_calls)* - #(#aggregate_calls)* + #(#aggregate_calls)* - ::limbo_ext::ResultCode::OK - } + #(#vtab_calls)* + + ::limbo_ext::ResultCode::OK + } }; - TokenStream::from(expanded) }