From f5f77c0bd164e13e82c32f8bc2124c8bab429a8b Mon Sep 17 00:00:00 2001 From: Jussi Saurio Date: Sun, 19 Jan 2025 15:10:47 +0200 Subject: [PATCH 1/7] Initial virtual table implementation --- Cargo.lock | 8 ++ Cargo.toml | 1 + core/ext/mod.rs | 80 +++++++++++++++++++- core/lib.rs | 74 ++++++++++++++++++- core/schema.rs | 17 +++++ core/translate/delete.rs | 3 +- core/translate/expr.rs | 55 ++++++++++---- core/translate/main_loop.rs | 138 +++++++++++++++++++++++------------ core/translate/optimizer.rs | 26 +++---- core/translate/plan.rs | 37 +++++++++- core/translate/planner.rs | 31 +++++++- core/types.rs | 9 +++ core/util.rs | 75 ++++++++++++++++++- core/vdbe/builder.rs | 6 +- core/vdbe/explain.rs | 57 +++++++++++++++ core/vdbe/insn.rs | 29 ++++++++ core/vdbe/mod.rs | 98 +++++++++++++++++++++++++ extensions/core/src/lib.rs | 75 ++++++++++++++++++- extensions/core/src/types.rs | 4 +- extensions/series/Cargo.toml | 15 ++++ extensions/series/src/lib.rs | 136 ++++++++++++++++++++++++++++++++++ macros/src/args.rs | 22 +++--- macros/src/lib.rs | 132 +++++++++++++++++++++++++++++++-- 23 files changed, 1015 insertions(+), 113 deletions(-) create mode 100644 extensions/series/Cargo.toml create mode 100644 extensions/series/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index f8a0df0e6..26cd80646 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1701,6 +1701,14 @@ dependencies = [ "regex", ] +[[package]] +name = "limbo_series" +version = "0.0.14" +dependencies = [ + "limbo_ext", + "log", +] + [[package]] name = "limbo_sim" version = "0.0.14" diff --git a/Cargo.toml b/Cargo.toml index 754583f42..400595b4a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ members = [ "extensions/percentile", "extensions/time", "extensions/crypto", + "extensions/series", ] exclude = ["perf/latency/limbo"] diff --git a/core/ext/mod.rs b/core/ext/mod.rs index db7876431..6d034e313 100644 --- a/core/ext/mod.rs +++ b/core/ext/mod.rs @@ -1,6 +1,11 @@ -use crate::{function::ExternalFunc, Database}; -use limbo_ext::{ExtensionApi, InitAggFunction, ResultCode, ScalarFunction}; +use crate::{function::ExternalFunc, util::columns_from_create_table_body, Database, VirtualTable}; +use fallible_iterator::FallibleIterator; +use limbo_ext::{ExtensionApi, InitAggFunction, ResultCode, ScalarFunction, VTabModuleImpl}; pub use limbo_ext::{FinalizeFunction, StepFunction, Value as ExtValue, ValueType as ExtValueType}; +use sqlite3_parser::{ + ast::{Cmd, Stmt}, + lexer::sql::Parser, +}; use std::{ ffi::{c_char, c_void, CStr}, rc::Rc, @@ -44,6 +49,48 @@ unsafe extern "C" fn register_aggregate_function( db.register_aggregate_function_impl(&name_str, args, (init_func, step_func, finalize_func)) } +unsafe extern "C" fn register_module( + ctx: *mut c_void, + name: *const c_char, + module: VTabModuleImpl, +) -> ResultCode { + let c_str = unsafe { CStr::from_ptr(name) }; + let name_str = match c_str.to_str() { + Ok(s) => s.to_string(), + Err(_) => return ResultCode::Error, + }; + if ctx.is_null() { + return ResultCode::Error; + } + let db = unsafe { &mut *(ctx as *mut Database) }; + + db.register_module_impl(&name_str, module) +} + +unsafe extern "C" fn declare_vtab( + ctx: *mut c_void, + name: *const c_char, + sql: *const c_char, +) -> ResultCode { + let c_str = unsafe { CStr::from_ptr(name) }; + let name_str = match c_str.to_str() { + Ok(s) => s.to_string(), + Err(_) => return ResultCode::Error, + }; + + let c_str = unsafe { CStr::from_ptr(sql) }; + let sql_str = match c_str.to_str() { + Ok(s) => s.to_string(), + Err(_) => return ResultCode::Error, + }; + + if ctx.is_null() { + return ResultCode::Error; + } + let db = unsafe { &mut *(ctx as *mut Database) }; + db.declare_vtab_impl(&name_str, &sql_str) +} + impl Database { fn register_scalar_function_impl(&self, name: &str, func: ScalarFunction) -> ResultCode { self.syms.borrow_mut().functions.insert( @@ -66,11 +113,40 @@ impl Database { ResultCode::OK } + fn register_module_impl(&mut self, name: &str, module: VTabModuleImpl) -> ResultCode { + self.vtab_modules.insert(name.to_string(), Rc::new(module)); + ResultCode::OK + } + + fn declare_vtab_impl(&mut self, name: &str, sql: &str) -> ResultCode { + let mut parser = Parser::new(sql.as_bytes()); + let cmd = parser.next().unwrap().unwrap(); + let Cmd::Stmt(stmt) = cmd else { + return ResultCode::Error; + }; + let Stmt::CreateTable { body, .. } = stmt else { + return ResultCode::Error; + }; + let Ok(columns) = columns_from_create_table_body(body) else { + return ResultCode::Error; + }; + let vtab_module = self.vtab_modules.get(name).unwrap().clone(); + let vtab = VirtualTable { + name: name.to_string(), + implementation: vtab_module, + columns, + }; + self.syms.borrow_mut().vtabs.insert(name.to_string(), vtab); + ResultCode::OK + } + pub fn build_limbo_ext(&self) -> ExtensionApi { ExtensionApi { ctx: self as *const _ as *mut c_void, register_scalar_function, register_aggregate_function, + register_module, + declare_vtab, } } diff --git a/core/lib.rs b/core/lib.rs index 884aad963..ccc2c2273 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -25,12 +25,13 @@ static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; use fallible_iterator::FallibleIterator; #[cfg(not(target_family = "wasm"))] use libloading::{Library, Symbol}; -use limbo_ext::{ExtensionApi, ExtensionEntryPoint}; +#[cfg(not(target_family = "wasm"))] +use limbo_ext::{ExtensionApi, ExtensionEntryPoint, ResultCode}; +use limbo_ext::{VTabModuleImpl, Value as ExtValue}; use log::trace; use parking_lot::RwLock; -use schema::Schema; -use sqlite3_parser::ast; -use sqlite3_parser::{ast::Cmd, lexer::sql::Parser}; +use schema::{Column, Schema}; +use sqlite3_parser::{ast, ast::Cmd, lexer::sql::Parser}; use std::cell::Cell; use std::collections::HashMap; use std::num::NonZero; @@ -44,9 +45,11 @@ use storage::pager::allocate_page; use storage::sqlite3_ondisk::{DatabaseHeader, DATABASE_HEADER_SIZE}; pub use storage::wal::WalFile; pub use storage::wal::WalFileShared; +use types::OwnedValue; pub use types::Value; use util::parse_schema_rows; use vdbe::builder::QueryMode; +use vdbe::VTabOpaqueCursor; pub use error::LimboError; use translate::select::prepare_select_plan; @@ -82,6 +85,7 @@ pub struct Database { schema: Rc>, header: Rc>, syms: Rc>, + vtab_modules: HashMap>, // Shared structures of a Database are the parts that are common to multiple threads that might // create DB connections. _shared_page_cache: Arc>, @@ -144,6 +148,7 @@ impl Database { _shared_page_cache: _shared_page_cache.clone(), _shared_wal: shared_wal.clone(), syms, + vtab_modules: HashMap::new(), }; if let Err(e) = db.register_builtins() { return Err(LimboError::ExtensionError(e)); @@ -506,10 +511,70 @@ pub type Row = types::Record; pub type StepResult = vdbe::StepResult; +#[derive(Clone, Debug)] +pub struct VirtualTable { + name: String, + pub implementation: Rc, + columns: Vec, +} + +impl VirtualTable { + pub fn open(&self) -> VTabOpaqueCursor { + let cursor = unsafe { (self.implementation.open)() }; + VTabOpaqueCursor::new(cursor) + } + + pub fn filter( + &self, + cursor: &VTabOpaqueCursor, + arg_count: usize, + args: Vec, + ) -> Result<()> { + let mut filter_args = Vec::with_capacity(arg_count); + for i in 0..arg_count { + let ownedvalue_arg = args.get(i).unwrap(); + let extvalue_arg: ExtValue = match ownedvalue_arg { + OwnedValue::Null => Ok(ExtValue::null()), + OwnedValue::Integer(i) => Ok(ExtValue::from_integer(*i)), + OwnedValue::Float(f) => Ok(ExtValue::from_float(*f)), + OwnedValue::Text(t) => Ok(ExtValue::from_text((*t.value).clone())), + OwnedValue::Blob(b) => Ok(ExtValue::from_blob((**b).clone())), + other => Err(LimboError::ExtensionError(format!( + "Unsupported value type: {:?}", + other + ))), + }?; + filter_args.push(extvalue_arg); + } + let rc = unsafe { + (self.implementation.filter)(cursor.as_ptr(), arg_count as i32, filter_args.as_ptr()) + }; + match rc { + ResultCode::OK => Ok(()), + _ => Err(LimboError::ExtensionError("Filter failed".to_string())), + } + } + + pub fn column(&self, cursor: &VTabOpaqueCursor, column: usize) -> Result { + let val = unsafe { (self.implementation.column)(cursor.as_ptr(), column as u32) }; + OwnedValue::from_ffi(&val) + } + + pub fn next(&self, cursor: &VTabOpaqueCursor) -> Result { + let rc = unsafe { (self.implementation.next)(cursor.as_ptr()) }; + match rc { + ResultCode::OK => Ok(true), + ResultCode::EOF => Ok(false), + _ => Err(LimboError::ExtensionError("Next failed".to_string())), + } + } +} + pub(crate) struct SymbolTable { pub functions: HashMap>, #[cfg(not(target_family = "wasm"))] extensions: Vec<(Library, *const ExtensionApi)>, + pub vtabs: HashMap, } impl std::fmt::Debug for SymbolTable { @@ -551,6 +616,7 @@ impl SymbolTable { pub fn new() -> Self { Self { functions: HashMap::new(), + vtabs: HashMap::new(), #[cfg(not(target_family = "wasm"))] extensions: Vec::new(), } diff --git a/core/schema.rs b/core/schema.rs index a5f1e6121..e7688b58e 100644 --- a/core/schema.rs +++ b/core/schema.rs @@ -1,3 +1,4 @@ +use crate::VirtualTable; use crate::{util::normalize_ident, Result}; use core::fmt; use fallible_iterator::FallibleIterator; @@ -47,6 +48,7 @@ impl Schema { pub enum Table { BTree(Rc), Pseudo(Rc), + Virtual(Rc), } impl Table { @@ -54,6 +56,7 @@ impl Table { match self { Table::BTree(table) => table.root_page, Table::Pseudo(_) => unimplemented!(), + Table::Virtual(_) => unimplemented!(), } } @@ -61,6 +64,7 @@ impl Table { match self { Self::BTree(table) => &table.name, Self::Pseudo(_) => "", + Self::Virtual(table) => &table.name, } } @@ -74,6 +78,10 @@ impl Table { .columns .get(index) .expect("column index out of bounds"), + Self::Virtual(table) => table + .columns + .get(index) + .expect("column index out of bounds"), } } @@ -81,6 +89,7 @@ impl Table { match self { Self::BTree(table) => &table.columns, Self::Pseudo(table) => &table.columns, + Self::Virtual(table) => &table.columns, } } @@ -88,6 +97,13 @@ impl Table { match self { Self::BTree(table) => Some(table.clone()), Self::Pseudo(_) => None, + Self::Virtual(_) => None, + } + } + pub fn virtual_table(&self) -> Option> { + match self { + Self::Virtual(table) => Some(table.clone()), + _ => None, } } } @@ -97,6 +113,7 @@ impl PartialEq for Table { match (self, other) { (Self::BTree(a), Self::BTree(b)) => Rc::ptr_eq(a, b), (Self::Pseudo(a), Self::Pseudo(b)) => Rc::ptr_eq(a, b), + (Self::Virtual(a), Self::Virtual(b)) => Rc::ptr_eq(a, b), _ => false, } } diff --git a/core/translate/delete.rs b/core/translate/delete.rs index ffad33d73..675b58f34 100644 --- a/core/translate/delete.rs +++ b/core/translate/delete.rs @@ -7,7 +7,7 @@ use crate::vdbe::builder::{ProgramBuilder, ProgramBuilderOpts, QueryMode}; use crate::{schema::Schema, Result, SymbolTable}; use sqlite3_parser::ast::{Expr, Limit, QualifiedName}; -use super::plan::TableReference; +use super::plan::{TableReference, TableReferenceType}; pub fn translate_delete( query_mode: QueryMode, @@ -48,6 +48,7 @@ pub fn prepare_delete_plan( identifier: table.name.clone(), op: Operation::Scan { iter_dir: None }, join_info: None, + reference_type: TableReferenceType::BTreeTable, }]; let mut where_predicates = vec![]; diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 45fb1a648..8b2e70185 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -13,7 +13,7 @@ use crate::vdbe::{ use crate::Result; use super::emitter::Resolver; -use super::plan::{Operation, TableReference}; +use super::plan::{Operation, TableReference, TableReferenceType}; #[derive(Debug, Clone, Copy)] pub struct ConditionMetadata { @@ -1824,22 +1824,45 @@ pub fn translate_expr( // If we are reading a column from a table, we find the cursor that corresponds to // the table and read the column from the cursor. Operation::Scan { .. } | Operation::Search(_) => { - let cursor_id = program.resolve_cursor_id(&table_reference.identifier); - if *is_rowid_alias { - program.emit_insn(Insn::RowId { - cursor_id, - dest: target_register, - }); - } else { - program.emit_insn(Insn::Column { - cursor_id, - column: *column, - dest: target_register, - }); + match &table_reference.reference_type { + TableReferenceType::BTreeTable => { + let cursor_id = program.resolve_cursor_id(&table_reference.identifier); + if *is_rowid_alias { + program.emit_insn(Insn::RowId { + cursor_id, + dest: target_register, + }); + } else { + program.emit_insn(Insn::Column { + cursor_id, + column: *column, + dest: target_register, + }); + } + let column = table_reference.table.get_column_at(*column); + maybe_apply_affinity(column.ty, target_register, program); + Ok(target_register) + } + TableReferenceType::VirtualTable { .. } => { + let cursor_id = program.resolve_cursor_id(&table_reference.identifier); + program.emit_insn(Insn::VColumn { + cursor_id, + column: *column, + dest: target_register, + }); + Ok(target_register) + } + TableReferenceType::Subquery { + result_columns_start_reg, + } => { + program.emit_insn(Insn::Copy { + src_reg: result_columns_start_reg + *column, + dst_reg: target_register, + amount: 0, + }); + Ok(target_register) + } } - let column = table_reference.table.get_column_at(*column); - maybe_apply_affinity(column.ty, target_register, program); - Ok(target_register) } // If we are reading a column from a subquery, we instead copy the column from the // subquery's result registers. diff --git a/core/translate/main_loop.rs b/core/translate/main_loop.rs index b4fff9c7a..4558693e3 100644 --- a/core/translate/main_loop.rs +++ b/core/translate/main_loop.rs @@ -17,7 +17,7 @@ use super::{ order_by::{order_by_sorter_insert, sorter_insert}, plan::{ IterationDirection, Operation, Search, SelectPlan, SelectQueryType, TableReference, - WhereTerm, + TableReferenceType, WhereTerm, }, }; @@ -78,27 +78,40 @@ pub fn init_loop( } match &table.op { Operation::Scan { .. } => { + let ref_type = &table.reference_type; let cursor_id = program.alloc_cursor_id( Some(table.identifier.clone()), - CursorType::BTreeTable(table.btree().unwrap().clone()), + match ref_type { + TableReferenceType::BTreeTable => { + CursorType::BTreeTable(table.btree().unwrap().clone()) + } + TableReferenceType::VirtualTable { .. } => { + CursorType::VirtualTable(table.virtual_table().unwrap().clone()) + } + other => panic!("Invalid table reference type in Scan: {:?}", other), + }, ); - let root_page = table.table.get_root_page(); - - match mode { - OperationMode::SELECT => { + match (mode, ref_type) { + (OperationMode::SELECT, TableReferenceType::BTreeTable) => { + let root_page = table.btree().unwrap().root_page; program.emit_insn(Insn::OpenReadAsync { cursor_id, root_page, }); program.emit_insn(Insn::OpenReadAwait {}); } - OperationMode::DELETE => { + (OperationMode::DELETE, TableReferenceType::BTreeTable) => { + let root_page = table.btree().unwrap().root_page; program.emit_insn(Insn::OpenWriteAsync { cursor_id, root_page, }); program.emit_insn(Insn::OpenWriteAwait {}); } + (OperationMode::SELECT, TableReferenceType::VirtualTable { .. }) => { + program.emit_insn(Insn::VOpenAsync { cursor_id }); + program.emit_insn(Insn::VOpenAwait {}); + } _ => { unimplemented!() } @@ -245,31 +258,52 @@ pub fn open_loop( } } Operation::Scan { iter_dir } => { + let ref_type = &table.reference_type; let cursor_id = program.resolve_cursor_id(&table.identifier); - if iter_dir - .as_ref() - .is_some_and(|dir| *dir == IterationDirection::Backwards) - { - program.emit_insn(Insn::LastAsync { cursor_id }); - } else { - program.emit_insn(Insn::RewindAsync { cursor_id }); - } - program.emit_insn( + + if !matches!(ref_type, TableReferenceType::VirtualTable { .. }) { if iter_dir .as_ref() .is_some_and(|dir| *dir == IterationDirection::Backwards) { - Insn::LastAwait { - cursor_id, - pc_if_empty: loop_end, - } + program.emit_insn(Insn::LastAsync { cursor_id }); } else { - Insn::RewindAwait { - cursor_id, - pc_if_empty: loop_end, + program.emit_insn(Insn::RewindAsync { cursor_id }); + } + } + match ref_type { + TableReferenceType::BTreeTable => program.emit_insn( + if iter_dir + .as_ref() + .is_some_and(|dir| *dir == IterationDirection::Backwards) + { + Insn::LastAwait { + cursor_id, + pc_if_empty: loop_end, + } + } else { + Insn::RewindAwait { + cursor_id, + pc_if_empty: loop_end, + } + }, + ), + TableReferenceType::VirtualTable { args, .. } => { + let start_reg = program.alloc_registers(args.len()); + let mut cur_reg = start_reg; + for arg in args { + let reg = cur_reg; + cur_reg += 1; + translate_expr(program, Some(tables), arg, reg, &t_ctx.resolver)?; } - }, - ); + program.emit_insn(Insn::VFilter { + cursor_id, + arg_count: args.len(), + args_reg: start_reg, + }); + } + other => panic!("Unsupported table reference type: {:?}", other), + } program.resolve_label(loop_start, program.offset()); for cond in predicates @@ -688,29 +722,41 @@ pub fn close_loop( }); } Operation::Scan { iter_dir, .. } => { + let ref_type = &table.reference_type; program.resolve_label(loop_labels.next, program.offset()); let cursor_id = program.resolve_cursor_id(&table.identifier); - if iter_dir - .as_ref() - .is_some_and(|dir| *dir == IterationDirection::Backwards) - { - program.emit_insn(Insn::PrevAsync { cursor_id }); - } else { - program.emit_insn(Insn::NextAsync { cursor_id }); - } - if iter_dir - .as_ref() - .is_some_and(|dir| *dir == IterationDirection::Backwards) - { - program.emit_insn(Insn::PrevAwait { - cursor_id, - pc_if_next: loop_labels.loop_start, - }); - } else { - program.emit_insn(Insn::NextAwait { - cursor_id, - pc_if_next: loop_labels.loop_start, - }); + match ref_type { + TableReferenceType::BTreeTable { .. } => { + if iter_dir + .as_ref() + .is_some_and(|dir| *dir == IterationDirection::Backwards) + { + program.emit_insn(Insn::PrevAsync { cursor_id }); + } else { + program.emit_insn(Insn::NextAsync { cursor_id }); + } + if iter_dir + .as_ref() + .is_some_and(|dir| *dir == IterationDirection::Backwards) + { + program.emit_insn(Insn::PrevAwait { + cursor_id, + pc_if_next: loop_labels.loop_start, + }); + } else { + program.emit_insn(Insn::NextAwait { + cursor_id, + pc_if_next: loop_labels.loop_start, + }); + } + } + TableReferenceType::VirtualTable { .. } => { + program.emit_insn(Insn::VNext { + cursor_id, + pc_if_next: loop_labels.loop_start, + }); + } + other => unreachable!("Unsupported table reference type: {:?}", other), } } Operation::Search(search) => { diff --git a/core/translate/optimizer.rs b/core/translate/optimizer.rs index 53df77956..73124060f 100644 --- a/core/translate/optimizer.rs +++ b/core/translate/optimizer.rs @@ -204,16 +204,16 @@ fn eliminate_constant_conditions( } fn push_scan_direction(table: &mut TableReference, direction: &Direction) { - match &mut table.op { - Operation::Scan { iter_dir, .. } => { - if iter_dir.is_none() { - match direction { - Direction::Ascending => *iter_dir = Some(IterationDirection::Forwards), - Direction::Descending => *iter_dir = Some(IterationDirection::Backwards), - } + if let Operation::Scan { + ref mut iter_dir, .. + } = table.op + { + if iter_dir.is_none() { + match direction { + Direction::Ascending => *iter_dir = Some(IterationDirection::Forwards), + Direction::Descending => *iter_dir = Some(IterationDirection::Backwards), } } - _ => {} } } @@ -309,12 +309,10 @@ impl Optimizable for ast::Expr { }; let column = table_reference.table.get_column_at(*column); for index in available_indexes_for_table.iter() { - if column - .name - .as_ref() - .map_or(false, |name| *name == index.columns.first().unwrap().name) - { - return Ok(Some(index.clone())); + if let Some(name) = column.name.as_ref() { + if &index.columns.first().unwrap().name == name { + return Ok(Some(index.clone())); + } } } Ok(None) diff --git a/core/translate/plan.rs b/core/translate/plan.rs index e59ffd5e8..43cba8e1b 100644 --- a/core/translate/plan.rs +++ b/core/translate/plan.rs @@ -9,6 +9,7 @@ use crate::{ function::AggFunc, schema::{BTreeTable, Column, Index, Table}, vdbe::BranchOffset, + VirtualTable, }; use crate::{ schema::{PseudoTable, Type}, @@ -197,11 +198,9 @@ pub struct TableReference { pub identifier: String, /// The join info for this table reference, if it is the right side of a join (which all except the first table reference have) pub join_info: Option, + pub reference_type: TableReferenceType, } -/** - A SourceOperator is a reference in the query plan that reads data from a table. -*/ #[derive(Clone, Debug)] pub enum Operation { // Scan operation @@ -226,10 +225,37 @@ pub enum Operation { }, } +/// The type of the table reference, either BTreeTable or Subquery +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum TableReferenceType { + /// A BTreeTable is a table that is stored on disk in a B-tree index. + BTreeTable, + /// A subquery. + Subquery { + /// The index of the first register in the query plan that contains the result columns of the subquery. + result_columns_start_reg: usize, + }, + /// A virtual table. + VirtualTable { + /// Arguments to pass e.g. generate_series(1, 10, 2) + args: Vec, + }, +} + impl TableReference { /// Returns the btree table for this table reference, if it is a BTreeTable. pub fn btree(&self) -> Option> { - self.table.btree() + match &self.reference_type { + TableReferenceType::BTreeTable => self.table.btree(), + TableReferenceType::Subquery { .. } => None, + TableReferenceType::VirtualTable { .. } => None, + } + } + pub fn virtual_table(&self) -> Option> { + match &self.reference_type { + TableReferenceType::VirtualTable { .. } => self.table.virtual_table(), + _ => None, + } } /// Creates a new TableReference for a subquery. @@ -254,6 +280,9 @@ impl TableReference { result_columns_start_reg: 0, // Will be set in the bytecode emission phase }, table, + reference_type: TableReferenceType::Subquery { + result_columns_start_reg: 0, // Will be set in the bytecode emission phase + }, identifier: identifier.clone(), join_info, } diff --git a/core/translate/planner.rs b/core/translate/planner.rs index 95cee1edf..272b788a2 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -1,7 +1,9 @@ +use std::rc::Rc; + use super::{ plan::{ Aggregate, JoinInfo, Operation, Plan, ResultSetColumn, SelectQueryType, TableReference, - WhereTerm, + TableReferenceType, WhereTerm, }, select::prepare_select_plan, SymbolTable, @@ -301,6 +303,7 @@ fn parse_from_clause_table( table: Table::BTree(table.clone()), identifier: alias.unwrap_or(normalized_qualified_name), join_info: None, + reference_type: TableReferenceType::BTreeTable, }) } ast::SelectTable::Select(subselect, maybe_alias) => { @@ -317,8 +320,30 @@ fn parse_from_clause_table( ast::As::Elided(id) => id.0.clone(), }) .unwrap_or(format!("subquery_{}", cur_table_index)); - let table_reference = TableReference::new_subquery(identifier, subplan, None); - Ok(table_reference) + Ok(TableReference::new_subquery(identifier, subplan, None)) + } + ast::SelectTable::TableCall(qualified_name, mut maybe_args, maybe_alias) => { + let normalized_name = normalize_ident(qualified_name.name.0.as_str()); + let Some(vtab) = syms.vtabs.get(&normalized_name) else { + crate::bail_parse_error!("Virtual table {} not found", normalized_name); + }; + let alias = maybe_alias + .as_ref() + .map(|a| match a { + ast::As::As(id) => id.0.clone(), + ast::As::Elided(id) => id.0.clone(), + }) + .unwrap_or(normalized_name); + + Ok(TableReference { + op: Operation::Scan { iter_dir: None }, + join_info: None, + table: Table::Virtual(vtab.clone().into()), + identifier: alias.clone(), + reference_type: TableReferenceType::VirtualTable { + args: maybe_args.take().unwrap_or_default(), + }, + }) } _ => todo!(), } diff --git a/core/types.rs b/core/types.rs index 06c459741..ae31314c1 100644 --- a/core/types.rs +++ b/core/types.rs @@ -6,6 +6,7 @@ use crate::pseudo::PseudoCursor; use crate::storage::btree::BTreeCursor; use crate::storage::sqlite3_ondisk::write_varint; use crate::vdbe::sorter::Sorter; +use crate::vdbe::VTabOpaqueCursor; use crate::Result; use std::fmt::Display; use std::rc::Rc; @@ -670,6 +671,7 @@ pub enum Cursor { Index(BTreeCursor), Pseudo(PseudoCursor), Sorter(Sorter), + Virtual(VTabOpaqueCursor), } impl Cursor { @@ -716,6 +718,13 @@ impl Cursor { _ => panic!("Cursor is not a sorter cursor"), } } + + pub fn as_virtual_mut(&mut self) -> &mut VTabOpaqueCursor { + match self { + Self::Virtual(cursor) => cursor, + _ => panic!("Cursor is not a virtual cursor"), + } + } } pub enum CursorResult { diff --git a/core/util.rs b/core/util.rs index c92c31412..5251b36cf 100644 --- a/core/util.rs +++ b/core/util.rs @@ -1,9 +1,9 @@ use std::{rc::Rc, sync::Arc}; -use sqlite3_parser::ast::{Expr, FunctionTail, Literal}; +use sqlite3_parser::ast::{CreateTableBody, Expr, FunctionTail, Literal}; use crate::{ - schema::{self, Schema}, + schema::{self, Column, Schema, Type}, Result, Statement, StepResult, IO, }; @@ -308,6 +308,77 @@ pub fn exprs_are_equivalent(expr1: &Expr, expr2: &Expr) -> bool { } } +pub fn columns_from_create_table_body(body: CreateTableBody) -> Result, ()> { + let CreateTableBody::ColumnsAndConstraints { columns, .. } = body else { + return Err(()); + }; + + Ok(columns + .into_iter() + .filter_map(|(name, column_def)| { + // if column_def.col_type includes HIDDEN, omit it for now + if let Some(data_type) = column_def.col_type.as_ref() { + if data_type.name.as_str().contains("HIDDEN") { + return None; + } + } + let column = Column { + name: Some(name.0), + ty: match column_def.col_type { + Some(ref data_type) => { + // https://www.sqlite.org/datatype3.html + let type_name = data_type.name.as_str().to_uppercase(); + if type_name.contains("INT") { + Type::Integer + } else if type_name.contains("CHAR") + || type_name.contains("CLOB") + || type_name.contains("TEXT") + { + Type::Text + } else if type_name.contains("BLOB") || type_name.is_empty() { + Type::Blob + } else if type_name.contains("REAL") + || type_name.contains("FLOA") + || type_name.contains("DOUB") + { + Type::Real + } else { + Type::Numeric + } + } + None => Type::Null, + }, + default: column_def + .constraints + .iter() + .find_map(|c| match &c.constraint { + sqlite3_parser::ast::ColumnConstraint::Default(val) => Some(val.clone()), + _ => None, + }), + notnull: column_def.constraints.iter().any(|c| { + matches!( + c.constraint, + sqlite3_parser::ast::ColumnConstraint::NotNull { .. } + ) + }), + ty_str: column_def + .col_type + .clone() + .map(|t| t.name.to_string()) + .unwrap_or_default(), + primary_key: column_def.constraints.iter().any(|c| { + matches!( + c.constraint, + sqlite3_parser::ast::ColumnConstraint::PrimaryKey { .. } + ) + }), + is_rowid_alias: false, + }; + Some(column) + }) + .collect::>()) +} + #[cfg(test)] pub mod tests { use super::*; diff --git a/core/vdbe/builder.rs b/core/vdbe/builder.rs index 21b4b8949..c3ead0d38 100644 --- a/core/vdbe/builder.rs +++ b/core/vdbe/builder.rs @@ -9,7 +9,7 @@ use crate::{ schema::{BTreeTable, Index, PseudoTable}, storage::sqlite3_ondisk::DatabaseHeader, translate::plan::{ResultSetColumn, TableReference}, - Connection, + Connection, VirtualTable, }; use super::{BranchOffset, CursorID, Insn, InsnReference, Program}; @@ -40,6 +40,7 @@ pub enum CursorType { BTreeIndex(Rc), Pseudo(Rc), Sorter, + VirtualTable(Rc), } impl CursorType { @@ -406,6 +407,9 @@ impl ProgramBuilder { Insn::IsNull { reg: _, target_pc } => { resolve(target_pc, "IsNull"); } + Insn::VNext { pc_if_next, .. } => { + resolve(pc_if_next, "VNext"); + } _ => continue, } } diff --git a/core/vdbe/explain.rs b/core/vdbe/explain.rs index e4c302bba..a0bb63023 100644 --- a/core/vdbe/explain.rs +++ b/core/vdbe/explain.rs @@ -363,6 +363,62 @@ pub fn insn_to_str( 0, "".to_string(), ), + Insn::VOpenAsync { cursor_id } => ( + "VOpenAsync", + *cursor_id as i32, + 0, + 0, + OwnedValue::build_text(Rc::new("".to_string())), + 0, + "".to_string(), + ), + Insn::VOpenAwait => ( + "VOpenAwait", + 0, + 0, + 0, + OwnedValue::build_text(Rc::new("".to_string())), + 0, + "".to_string(), + ), + Insn::VFilter { + cursor_id, + arg_count, + args_reg, + } => ( + "VFilter", + *cursor_id as i32, + *arg_count as i32, + *args_reg as i32, + OwnedValue::build_text(Rc::new("".to_string())), + 0, + "".to_string(), + ), + Insn::VColumn { + cursor_id, + column, + dest, + } => ( + "VColumn", + *cursor_id as i32, + *column as i32, + *dest as i32, + OwnedValue::build_text(Rc::new("".to_string())), + 0, + "".to_string(), + ), + Insn::VNext { + cursor_id, + pc_if_next, + } => ( + "VNext", + *cursor_id as i32, + pc_if_next.to_debug_int(), + 0, + OwnedValue::build_text(Rc::new("".to_string())), + 0, + "".to_string(), + ), Insn::OpenPseudo { cursor_id, content_reg, @@ -423,6 +479,7 @@ pub fn insn_to_str( name } CursorType::Sorter => None, + CursorType::VirtualTable(v) => v.columns.get(*column).unwrap().name.as_ref(), }; ( "Column", diff --git a/core/vdbe/insn.rs b/core/vdbe/insn.rs index 1cdb81e25..223f321aa 100644 --- a/core/vdbe/insn.rs +++ b/core/vdbe/insn.rs @@ -213,6 +213,35 @@ pub enum Insn { // Await for the completion of open cursor. OpenReadAwait, + /// Open a cursor for a virtual table. + VOpenAsync { + cursor_id: CursorID, + }, + + /// Await for the completion of open cursor for a virtual table. + VOpenAwait, + + /// Initialize the position of the virtual table cursor. + VFilter { + cursor_id: CursorID, + arg_count: usize, + args_reg: usize, + }, + + /// Read a column from the current row of the virtual table cursor. + VColumn { + cursor_id: CursorID, + column: usize, + dest: usize, + }, + + /// Advance the virtual table cursor to the next row. + /// TODO: async + VNext { + cursor_id: CursorID, + pc_if_next: BranchOffset, + }, + // Open a cursor for a pseudo-table that contains a single row. OpenPseudo { cursor_id: CursorID, diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index e77ccf551..c25dc572b 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -65,6 +65,7 @@ use sorter::Sorter; use std::borrow::BorrowMut; use std::cell::{Cell, RefCell, RefMut}; use std::collections::HashMap; +use std::ffi::c_void; use std::num::NonZero; use std::rc::{Rc, Weak}; @@ -267,6 +268,18 @@ fn get_cursor_as_sorter_mut<'long, 'short>( cursor } +fn get_cursor_as_virtual_mut<'long, 'short>( + cursors: &'short mut RefMut<'long, Vec>>, + cursor_id: CursorID, +) -> &'short mut VTabOpaqueCursor { + let cursor = cursors + .get_mut(cursor_id) + .expect("cursor id out of bounds") + .as_mut() + .expect("cursor not allocated") + .as_virtual_mut(); + cursor +} struct Bitfield([u64; N]); impl Bitfield { @@ -290,6 +303,18 @@ impl Bitfield { } } +pub struct VTabOpaqueCursor(*mut c_void); + +impl VTabOpaqueCursor { + pub fn new(cursor: *mut c_void) -> Self { + Self(cursor) + } + + pub fn as_ptr(&self) -> *mut c_void { + self.0 + } +} + /// The program state describes the environment in which the program executes. pub struct ProgramState { pub pc: InsnReference, @@ -370,6 +395,7 @@ macro_rules! must_be_btree_cursor { CursorType::BTreeIndex(_) => get_cursor_as_index_mut(&mut $cursors, $cursor_id), CursorType::Pseudo(_) => panic!("{} on pseudo cursor", $insn_name), CursorType::Sorter => panic!("{} on sorter cursor", $insn_name), + CursorType::VirtualTable(_) => panic!("{} on virtual table cursor", $insn_name), }; cursor }}; @@ -826,12 +852,79 @@ impl Program { CursorType::Sorter => { panic!("OpenReadAsync on sorter cursor"); } + CursorType::VirtualTable(_) => { + panic!("OpenReadAsync on virtual table cursor, use Insn::VOpenAsync instead"); + } } state.pc += 1; } Insn::OpenReadAwait => { state.pc += 1; } + Insn::VOpenAsync { cursor_id } => { + let (_, cursor_type) = self.cursor_ref.get(*cursor_id).unwrap(); + let CursorType::VirtualTable(virtual_table) = cursor_type else { + panic!("VOpenAsync on non-virtual table cursor"); + }; + let cursor = virtual_table.open(); + state + .cursors + .borrow_mut() + .insert(*cursor_id, Some(Cursor::Virtual(cursor))); + state.pc += 1; + } + Insn::VOpenAwait => { + state.pc += 1; + } + Insn::VFilter { + cursor_id, + arg_count, + args_reg, + } => { + let (_, cursor_type) = self.cursor_ref.get(*cursor_id).unwrap(); + let CursorType::VirtualTable(virtual_table) = cursor_type else { + panic!("VFilter on non-virtual table cursor"); + }; + let mut cursors = state.cursors.borrow_mut(); + let cursor = get_cursor_as_virtual_mut(&mut cursors, *cursor_id); + let mut args = Vec::new(); + for i in 0..*arg_count { + args.push(state.registers[args_reg + i].clone()); + } + virtual_table.filter(cursor, *arg_count, args)?; + state.pc += 1; + } + Insn::VColumn { + cursor_id, + column, + dest, + } => { + let (_, cursor_type) = self.cursor_ref.get(*cursor_id).unwrap(); + let CursorType::VirtualTable(virtual_table) = cursor_type else { + panic!("VColumn on non-virtual table cursor"); + }; + let mut cursors = state.cursors.borrow_mut(); + let cursor = get_cursor_as_virtual_mut(&mut cursors, *cursor_id); + state.registers[*dest] = virtual_table.column(cursor, *column)?; + state.pc += 1; + } + Insn::VNext { + cursor_id, + pc_if_next, + } => { + let (_, cursor_type) = self.cursor_ref.get(*cursor_id).unwrap(); + let CursorType::VirtualTable(virtual_table) = cursor_type else { + panic!("VNextAsync on non-virtual table cursor"); + }; + let mut cursors = state.cursors.borrow_mut(); + let cursor = get_cursor_as_virtual_mut(&mut cursors, *cursor_id); + let has_more = virtual_table.next(cursor)?; + if has_more { + state.pc = pc_if_next.to_offset_int(); + } else { + state.pc += 1; + } + } Insn::OpenPseudo { cursor_id, content_reg: _, @@ -943,6 +1036,11 @@ impl Program { state.registers[*dest] = OwnedValue::Null; } } + CursorType::VirtualTable(_) => { + panic!( + "Insn::Column on virtual table cursor, use Insn::VColumn instead" + ); + } } state.pc += 1; diff --git a/extensions/core/src/lib.rs b/extensions/core/src/lib.rs index 5f9bb09c5..fec363c44 100644 --- a/extensions/core/src/lib.rs +++ b/extensions/core/src/lib.rs @@ -1,5 +1,5 @@ mod types; -pub use limbo_macros::{register_extension, scalar, AggregateDerive}; +pub use limbo_macros::{register_extension, scalar, AggregateDerive, VTabModuleDerive}; use std::os::raw::{c_char, c_void}; pub use types::{ResultCode, Value, ValueType}; @@ -21,6 +21,30 @@ pub struct ExtensionApi { step_func: StepFunction, finalize_func: FinalizeFunction, ) -> ResultCode, + + pub register_module: unsafe extern "C" fn( + ctx: *mut c_void, + name: *const c_char, + module: VTabModuleImpl, + ) -> ResultCode, + + pub declare_vtab: unsafe extern "C" fn( + ctx: *mut c_void, + name: *const c_char, + sql: *const c_char, + ) -> ResultCode, +} + +impl ExtensionApi { + pub fn declare_virtual_table(&self, name: &str, sql: &str) -> ResultCode { + let Ok(name) = std::ffi::CString::new(name) else { + return ResultCode::Error; + }; + let Ok(sql) = std::ffi::CString::new(sql) else { + return ResultCode::Error; + }; + unsafe { (self.declare_vtab)(self.ctx, name.as_ptr(), sql.as_ptr()) } + } } pub type ExtensionEntryPoint = unsafe extern "C" fn(api: *const ExtensionApi) -> ResultCode; @@ -47,3 +71,52 @@ pub trait AggFunc { fn step(state: &mut Self::State, args: &[Value]); fn finalize(state: Self::State) -> Value; } + +#[repr(C)] +#[derive(Clone, Debug)] +pub struct VTabModuleImpl { + pub name: *const c_char, + pub connect: VtabFnConnect, + pub open: VtabFnOpen, + pub filter: VtabFnFilter, + pub column: VtabFnColumn, + pub next: VtabFnNext, + pub eof: VtabFnEof, +} + +pub type VtabFnConnect = unsafe extern "C" fn(api: *const c_void) -> ResultCode; + +pub type VtabFnOpen = unsafe extern "C" fn() -> *mut c_void; + +pub type VtabFnFilter = + unsafe extern "C" fn(cursor: *mut c_void, argc: i32, argv: *const Value) -> ResultCode; + +pub type VtabFnColumn = unsafe extern "C" fn(cursor: *mut c_void, idx: u32) -> Value; + +pub type VtabFnNext = unsafe extern "C" fn(cursor: *mut c_void) -> ResultCode; + +pub type VtabFnEof = unsafe extern "C" fn(cursor: *mut c_void) -> bool; + +pub trait VTabModule: 'static { + type VCursor: VTabCursor; + + fn name() -> &'static str; + fn connect(api: &ExtensionApi) -> ResultCode; + fn open() -> Self::VCursor; + fn filter(cursor: &mut Self::VCursor, arg_count: i32, args: &[Value]) -> ResultCode; + fn column(cursor: &Self::VCursor, idx: u32) -> Value; + fn next(cursor: &mut Self::VCursor) -> ResultCode; + fn eof(cursor: &Self::VCursor) -> bool; +} + +pub trait VTabCursor: Sized { + fn rowid(&self) -> i64; + fn column(&self, idx: u32) -> Value; + fn eof(&self) -> bool; + fn next(&mut self) -> ResultCode; +} + +#[repr(C)] +pub struct VTabImpl { + pub module: VTabModuleImpl, +} diff --git a/extensions/core/src/types.rs b/extensions/core/src/types.rs index 464e07bfd..4a1fa3978 100644 --- a/extensions/core/src/types.rs +++ b/extensions/core/src/types.rs @@ -2,8 +2,8 @@ use std::fmt::Display; /// Error type is of type ExtError which can be /// either a user defined error or an error code -#[derive(Clone, Copy)] #[repr(C)] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] pub enum ResultCode { OK = 0, Error = 1, @@ -20,6 +20,7 @@ pub enum ResultCode { Internal = 12, Unavailable = 13, CustomError = 14, + EOF = 15, } impl ResultCode { @@ -50,6 +51,7 @@ impl Display for ResultCode { ResultCode::Internal => write!(f, "Internal Error"), ResultCode::Unavailable => write!(f, "Unavailable"), ResultCode::CustomError => write!(f, "Error "), + ResultCode::EOF => write!(f, "EOF"), } } } diff --git a/extensions/series/Cargo.toml b/extensions/series/Cargo.toml new file mode 100644 index 000000000..73a634ac7 --- /dev/null +++ b/extensions/series/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "limbo_series" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true + +[lib] +crate-type = ["cdylib", "lib"] + + +[dependencies] +limbo_ext = { path = "../core"} +log = "0.4.20" diff --git a/extensions/series/src/lib.rs b/extensions/series/src/lib.rs new file mode 100644 index 000000000..f438c6fce --- /dev/null +++ b/extensions/series/src/lib.rs @@ -0,0 +1,136 @@ +use limbo_ext::{ + register_extension, ExtensionApi, ResultCode, VTabCursor, VTabModule, VTabModuleDerive, Value, + ValueType, +}; + +register_extension! { + vtabs: { GenerateSeriesVTab } +} + +/// A virtual table that generates a sequence of integers +#[derive(Debug, VTabModuleDerive)] +struct GenerateSeriesVTab; + +impl VTabModule for GenerateSeriesVTab { + type VCursor = GenerateSeriesCursor; + fn name() -> &'static str { + "generate_series" + } + + fn connect(api: &ExtensionApi) -> ResultCode { + // Create table schema + let sql = "CREATE TABLE generate_series( + value INTEGER, + start INTEGER HIDDEN, + stop INTEGER HIDDEN, + step INTEGER HIDDEN + )"; + let name = Self::name(); + api.declare_virtual_table(name, sql) + } + + fn open() -> Self::VCursor { + GenerateSeriesCursor { + start: 0, + stop: 0, + step: 0, + current: 0, + } + } + + fn filter(cursor: &mut Self::VCursor, arg_count: i32, args: &[Value]) -> ResultCode { + // args are the start, stop, and step + if arg_count == 0 || arg_count > 3 { + return ResultCode::InvalidArgs; + } + let start = { + if args[0].value_type() == ValueType::Integer { + args[0].to_integer().unwrap() + } else { + return ResultCode::InvalidArgs; + } + }; + let stop = if args.len() == 1 { + i64::MAX + } else { + if args[1].value_type() == ValueType::Integer { + args[1].to_integer().unwrap() + } else { + return ResultCode::InvalidArgs; + } + }; + let step = if args.len() <= 2 { + 1 + } else { + if args[2].value_type() == ValueType::Integer { + args[2].to_integer().unwrap() + } else { + return ResultCode::InvalidArgs; + } + }; + cursor.start = start; + cursor.current = start; + cursor.stop = stop; + cursor.step = step; + ResultCode::OK + } + + fn column(cursor: &Self::VCursor, idx: u32) -> Value { + cursor.column(idx) + } + + fn next(cursor: &mut Self::VCursor) -> ResultCode { + GenerateSeriesCursor::next(cursor) + } + + fn eof(cursor: &Self::VCursor) -> bool { + cursor.eof() + } +} + +/// The cursor for iterating over the generated sequence +#[derive(Debug)] +struct GenerateSeriesCursor { + start: i64, + stop: i64, + step: i64, + current: i64, +} + +impl GenerateSeriesCursor { + fn next(&mut self) -> ResultCode { + let current = self.current; + + // Check if we've reached the end + if (self.step > 0 && current >= self.stop) || (self.step < 0 && current <= self.stop) { + return ResultCode::EOF; + } + + self.current = current.saturating_add(self.step); + ResultCode::OK + } +} + +impl VTabCursor for GenerateSeriesCursor { + fn next(&mut self) -> ResultCode { + GenerateSeriesCursor::next(self) + } + + fn eof(&self) -> bool { + (self.step > 0 && self.current > self.stop) || (self.step < 0 && self.current < self.stop) + } + + fn column(&self, idx: u32) -> Value { + match idx { + 0 => Value::from_integer(self.current), + 1 => Value::from_integer(self.start), + 2 => Value::from_integer(self.stop), + 3 => Value::from_integer(self.step), + _ => Value::null(), + } + } + + fn rowid(&self) -> i64 { + ((self.current - self.start) / self.step) + 1 + } +} diff --git a/macros/src/args.rs b/macros/src/args.rs index d9e59cbd3..12446b660 100644 --- a/macros/src/args.rs +++ b/macros/src/args.rs @@ -6,31 +6,32 @@ use syn::{Ident, LitStr, Token}; pub(crate) struct RegisterExtensionInput { pub aggregates: Vec, pub scalars: Vec, + pub vtabs: Vec, } impl syn::parse::Parse for RegisterExtensionInput { fn parse(input: syn::parse::ParseStream) -> syn::Result { let mut aggregates = Vec::new(); let mut scalars = Vec::new(); - + let mut vtabs = Vec::new(); while !input.is_empty() { if input.peek(syn::Ident) && input.peek2(Token![:]) { let section_name: Ident = input.parse()?; input.parse::()?; - - if section_name == "aggregates" || section_name == "scalars" { + let names = ["aggregates", "scalars", "vtabs"]; + if names.contains(§ion_name.to_string().as_str()) { let content; syn::braced!(content in input); - let parsed_items = Punctuated::::parse_terminated(&content)? .into_iter() .collect(); - if section_name == "aggregates" { - aggregates = parsed_items; - } else { - scalars = parsed_items; - } + match section_name.to_string().as_str() { + "aggregates" => aggregates = parsed_items, + "scalars" => scalars = parsed_items, + "vtabs" => vtabs = parsed_items, + _ => unreachable!(), + }; if input.peek(Token![,]) { input.parse::()?; @@ -39,13 +40,14 @@ impl syn::parse::Parse for RegisterExtensionInput { return Err(syn::Error::new(section_name.span(), "Unknown section")); } } else { - return Err(input.error("Expected aggregates: or scalars: section")); + return Err(input.error("Expected aggregates:, scalars:, or vtabs: section")); } } Ok(Self { aggregates, scalars, + vtabs, }) } } diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 1e0ef421e..6b0df9679 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -324,6 +324,103 @@ pub fn derive_agg_func(input: TokenStream) -> TokenStream { TokenStream::from(expanded) } +#[proc_macro_derive(VTabModuleDerive)] +pub fn derive_vtab_module(input: TokenStream) -> TokenStream { + let ast = parse_macro_input!(input as DeriveInput); + let struct_name = &ast.ident; + + let register_fn_name = format_ident!("register_{}", struct_name); + let connect_fn_name = format_ident!("connect_{}", struct_name); + let open_fn_name = format_ident!("open_{}", struct_name); + let filter_fn_name = format_ident!("filter_{}", struct_name); + let column_fn_name = format_ident!("column_{}", struct_name); + let next_fn_name = format_ident!("next_{}", struct_name); + let eof_fn_name = format_ident!("eof_{}", struct_name); + + let expanded = quote! { + impl #struct_name { + #[no_mangle] + unsafe extern "C" fn #connect_fn_name( + db: *const ::std::ffi::c_void, + ) -> ::limbo_ext::ResultCode { + let api = unsafe { &*(db as *const ExtensionApi) }; + <#struct_name as ::limbo_ext::VTabModule>::connect(api) + } + + #[no_mangle] + unsafe extern "C" fn #open_fn_name( + ) -> *mut ::std::ffi::c_void { + let cursor = <#struct_name as ::limbo_ext::VTabModule>::open(); + Box::into_raw(Box::new(cursor)) as *mut ::std::ffi::c_void + } + + #[no_mangle] + unsafe extern "C" fn #filter_fn_name( + cursor: *mut ::std::ffi::c_void, + argc: i32, + argv: *const ::limbo_ext::Value, + ) -> ::limbo_ext::ResultCode { + let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; + let args = std::slice::from_raw_parts(argv, argc as usize); + <#struct_name as ::limbo_ext::VTabModule>::filter(cursor, argc, args) + } + + #[no_mangle] + unsafe extern "C" fn #column_fn_name( + cursor: *mut ::std::ffi::c_void, + idx: u32, + ) -> ::limbo_ext::Value { + let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; + <#struct_name as ::limbo_ext::VTabModule>::column(cursor, idx) + } + + #[no_mangle] + unsafe extern "C" fn #next_fn_name( + cursor: *mut ::std::ffi::c_void, + ) -> ::limbo_ext::ResultCode { + let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; + <#struct_name as ::limbo_ext::VTabModule>::next(cursor) + } + + #[no_mangle] + unsafe extern "C" fn #eof_fn_name( + cursor: *mut ::std::ffi::c_void, + ) -> bool { + let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; + <#struct_name as ::limbo_ext::VTabModule>::eof(cursor) + } + + #[no_mangle] + pub unsafe extern "C" fn #register_fn_name( + api: *const ::limbo_ext::ExtensionApi + ) -> ::limbo_ext::ResultCode { + if api.is_null() { + return ::limbo_ext::ResultCode::Error; + } + + let api = &*api; + let name = <#struct_name as ::limbo_ext::VTabModule>::name(); + // name needs to be a c str FFI compatible, NOT CString + let name_c = std::ffi::CString::new(name).unwrap(); + + let module = ::limbo_ext::VTabModuleImpl { + name: name_c.as_ptr(), + connect: Self::#connect_fn_name, + open: Self::#open_fn_name, + filter: Self::#filter_fn_name, + column: Self::#column_fn_name, + next: Self::#next_fn_name, + eof: Self::#eof_fn_name, + }; + + (api.register_module)(api.ctx, name_c.as_ptr(), module) + } + } + }; + + TokenStream::from(expanded) +} + /// Register your extension with 'core' by providing the relevant functions ///```ignore ///use limbo_ext::{register_extension, scalar, Value, AggregateDerive, AggFunc}; @@ -362,6 +459,7 @@ pub fn register_extension(input: TokenStream) -> TokenStream { let RegisterExtensionInput { aggregates, scalars, + vtabs, } = input_ast; let scalar_calls = scalars.iter().map(|scalar_ident| { @@ -388,8 +486,23 @@ pub fn register_extension(input: TokenStream) -> TokenStream { } } }); + let vtab_calls = vtabs.iter().map(|vtab_ident| { + let register_fn = syn::Ident::new(&format!("register_{}", vtab_ident), vtab_ident.span()); + quote! { + { + let result = unsafe{ #vtab_ident::#register_fn(api)}; + if result == ::limbo_ext::ResultCode::OK { + let result = <#vtab_ident as ::limbo_ext::VTabModule>::connect(api); + return result; + } else { + return result; + } + } + } + }); let static_aggregates = aggregate_calls.clone(); let static_scalars = scalar_calls.clone(); + let static_vtabs = vtab_calls.clone(); let expanded = quote! { #[cfg(not(target_family = "wasm"))] @@ -404,20 +517,23 @@ pub fn register_extension(input: TokenStream) -> TokenStream { #(#static_aggregates)* + #(#static_vtabs)* + ::limbo_ext::ResultCode::OK } #[cfg(not(feature = "static"))] - #[no_mangle] - pub unsafe extern "C" fn register_extension(api: &::limbo_ext::ExtensionApi) -> ::limbo_ext::ResultCode { - let api = unsafe { &*api }; - #(#scalar_calls)* + #[no_mangle] + pub unsafe extern "C" fn register_extension(api: &::limbo_ext::ExtensionApi) -> ::limbo_ext::ResultCode { + let api = unsafe { &*api }; + #(#scalar_calls)* - #(#aggregate_calls)* + #(#aggregate_calls)* - ::limbo_ext::ResultCode::OK - } + #(#vtab_calls)* + + ::limbo_ext::ResultCode::OK + } }; - TokenStream::from(expanded) } From 661c74e338287876b930d717bfd456e8400a1644 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Sat, 1 Feb 2025 18:51:27 -0500 Subject: [PATCH 2/7] Apply new planner structure to virtual table impl --- Cargo.lock | 2 ++ core/Cargo.toml | 2 ++ core/ext/mod.rs | 5 ++++ core/util.rs | 5 ++-- core/vdbe/mod.rs | 1 + extensions/core/src/lib.rs | 2 +- extensions/series/Cargo.toml | 8 +++++- extensions/series/src/lib.rs | 53 ++++++++++++++---------------------- macros/src/lib.rs | 24 +++++++++++++--- 9 files changed, 61 insertions(+), 41 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 26cd80646..2c0eb8db1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1616,6 +1616,7 @@ dependencies = [ "limbo_macros", "limbo_percentile", "limbo_regexp", + "limbo_series", "limbo_time", "limbo_uuid", "log", @@ -1707,6 +1708,7 @@ version = "0.0.14" dependencies = [ "limbo_ext", "log", + "mimalloc", ] [[package]] diff --git a/core/Cargo.toml b/core/Cargo.toml index 386bf01c7..687f4ff19 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -27,6 +27,7 @@ percentile = ["limbo_percentile/static"] regexp = ["limbo_regexp/static"] time = ["limbo_time/static"] crypto = ["limbo_crypto/static"] +series = ["limbo_series/static"] [target.'cfg(target_os = "linux")'.dependencies] io-uring = { version = "0.6.1", optional = true } @@ -67,6 +68,7 @@ limbo_regexp = { path = "../extensions/regexp", optional = true, features = ["st limbo_percentile = { path = "../extensions/percentile", optional = true, features = ["static"] } limbo_time = { path = "../extensions/time", optional = true, features = ["static"] } limbo_crypto = { path = "../extensions/crypto", optional = true, features = ["static"] } +limbo_series = { path = "../extensions/series", optional = true, features = ["static"] } miette = "7.4.0" strum = "0.26" parking_lot = "0.12.3" diff --git a/core/ext/mod.rs b/core/ext/mod.rs index 6d034e313..c4b6006e3 100644 --- a/core/ext/mod.rs +++ b/core/ext/mod.rs @@ -131,6 +131,7 @@ impl Database { return ResultCode::Error; }; let vtab_module = self.vtab_modules.get(name).unwrap().clone(); + let vtab = VirtualTable { name: name.to_string(), implementation: vtab_module, @@ -172,6 +173,10 @@ impl Database { if unsafe { !limbo_crypto::register_extension_static(&ext_api).is_ok() } { return Err("Failed to register crypto extension".to_string()); } + #[cfg(feature = "series")] + if unsafe { !limbo_series::register_extension_static(&ext_api).is_ok() } { + return Err("Failed to register series extension".to_string()); + } Ok(()) } } diff --git a/core/util.rs b/core/util.rs index 5251b36cf..654951700 100644 --- a/core/util.rs +++ b/core/util.rs @@ -1,7 +1,6 @@ +use sqlite3_parser::ast::{self, CreateTableBody, Expr, FunctionTail, Literal}; use std::{rc::Rc, sync::Arc}; -use sqlite3_parser::ast::{CreateTableBody, Expr, FunctionTail, Literal}; - use crate::{ schema::{self, Column, Schema, Type}, Result, Statement, StepResult, IO, @@ -308,7 +307,7 @@ pub fn exprs_are_equivalent(expr1: &Expr, expr2: &Expr) -> bool { } } -pub fn columns_from_create_table_body(body: CreateTableBody) -> Result, ()> { +pub fn columns_from_create_table_body(body: ast::CreateTableBody) -> Result, ()> { let CreateTableBody::ColumnsAndConstraints { columns, .. } = body else { return Err(()); }; diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index c25dc572b..5003b72c6 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -280,6 +280,7 @@ fn get_cursor_as_virtual_mut<'long, 'short>( .as_virtual_mut(); cursor } + struct Bitfield([u64; N]); impl Bitfield { diff --git a/extensions/core/src/lib.rs b/extensions/core/src/lib.rs index fec363c44..0e550fca1 100644 --- a/extensions/core/src/lib.rs +++ b/extensions/core/src/lib.rs @@ -99,8 +99,8 @@ pub type VtabFnEof = unsafe extern "C" fn(cursor: *mut c_void) -> bool; pub trait VTabModule: 'static { type VCursor: VTabCursor; + const NAME: &'static str; - fn name() -> &'static str; fn connect(api: &ExtensionApi) -> ResultCode; fn open() -> Self::VCursor; fn filter(cursor: &mut Self::VCursor, arg_count: i32, args: &[Value]) -> ResultCode; diff --git a/extensions/series/Cargo.toml b/extensions/series/Cargo.toml index 73a634ac7..cca322294 100644 --- a/extensions/series/Cargo.toml +++ b/extensions/series/Cargo.toml @@ -6,10 +6,16 @@ edition.workspace = true license.workspace = true repository.workspace = true +[features] +static = ["limbo_ext/static"] + [lib] crate-type = ["cdylib", "lib"] [dependencies] -limbo_ext = { path = "../core"} +limbo_ext = { path = "../core", features = ["static"] } log = "0.4.20" + +[target.'cfg(not(target_family = "wasm"))'.dependencies] +mimalloc = { version = "*", default-features = false } diff --git a/extensions/series/src/lib.rs b/extensions/series/src/lib.rs index f438c6fce..63b6c6227 100644 --- a/extensions/series/src/lib.rs +++ b/extensions/series/src/lib.rs @@ -1,21 +1,27 @@ use limbo_ext::{ register_extension, ExtensionApi, ResultCode, VTabCursor, VTabModule, VTabModuleDerive, Value, - ValueType, }; register_extension! { vtabs: { GenerateSeriesVTab } } +macro_rules! try_option { + ($expr:expr, $err:expr) => { + match $expr { + Some(val) => val, + None => return $err, + } + }; +} + /// A virtual table that generates a sequence of integers #[derive(Debug, VTabModuleDerive)] struct GenerateSeriesVTab; impl VTabModule for GenerateSeriesVTab { type VCursor = GenerateSeriesCursor; - fn name() -> &'static str { - "generate_series" - } + const NAME: &'static str = "generate_series"; fn connect(api: &ExtensionApi) -> ResultCode { // Create table schema @@ -25,8 +31,7 @@ impl VTabModule for GenerateSeriesVTab { stop INTEGER HIDDEN, step INTEGER HIDDEN )"; - let name = Self::name(); - api.declare_virtual_table(name, sql) + api.declare_virtual_table(Self::NAME, sql) } fn open() -> Self::VCursor { @@ -43,35 +48,19 @@ impl VTabModule for GenerateSeriesVTab { if arg_count == 0 || arg_count > 3 { return ResultCode::InvalidArgs; } - let start = { - if args[0].value_type() == ValueType::Integer { - args[0].to_integer().unwrap() - } else { - return ResultCode::InvalidArgs; - } - }; - let stop = if args.len() == 1 { - i64::MAX - } else { - if args[1].value_type() == ValueType::Integer { - args[1].to_integer().unwrap() - } else { - return ResultCode::InvalidArgs; - } - }; - let step = if args.len() <= 2 { - 1 - } else { - if args[2].value_type() == ValueType::Integer { - args[2].to_integer().unwrap() - } else { - return ResultCode::InvalidArgs; - } - }; + let start = try_option!(args[0].to_integer(), ResultCode::InvalidArgs); + let stop = try_option!( + args.get(1).map(|v| v.to_integer().unwrap_or(i64::MAX)), + ResultCode::InvalidArgs + ); + let step = try_option!( + args.get(2).map(|v| v.to_integer().unwrap_or(1)), + ResultCode::InvalidArgs + ); cursor.start = start; cursor.current = start; - cursor.stop = stop; cursor.step = step; + cursor.stop = stop; ResultCode::OK } diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 6b0df9679..8dee8dc66 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -343,6 +343,9 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { unsafe extern "C" fn #connect_fn_name( db: *const ::std::ffi::c_void, ) -> ::limbo_ext::ResultCode { + if db.is_null() { + return ::limbo_ext::ResultCode::Error; + } let api = unsafe { &*(db as *const ExtensionApi) }; <#struct_name as ::limbo_ext::VTabModule>::connect(api) } @@ -360,6 +363,9 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { argc: i32, argv: *const ::limbo_ext::Value, ) -> ::limbo_ext::ResultCode { + if cursor.is_null() { + return ::limbo_ext::ResultCode::Error; + } let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; let args = std::slice::from_raw_parts(argv, argc as usize); <#struct_name as ::limbo_ext::VTabModule>::filter(cursor, argc, args) @@ -370,6 +376,9 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { cursor: *mut ::std::ffi::c_void, idx: u32, ) -> ::limbo_ext::Value { + if cursor.is_null() { + return ::limbo_ext::Value::error(ResultCode::Error); + } let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; <#struct_name as ::limbo_ext::VTabModule>::column(cursor, idx) } @@ -378,6 +387,9 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { unsafe extern "C" fn #next_fn_name( cursor: *mut ::std::ffi::c_void, ) -> ::limbo_ext::ResultCode { + if cursor.is_null() { + return ::limbo_ext::ResultCode::Error; + } let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; <#struct_name as ::limbo_ext::VTabModule>::next(cursor) } @@ -386,6 +398,9 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { unsafe extern "C" fn #eof_fn_name( cursor: *mut ::std::ffi::c_void, ) -> bool { + if cursor.is_null() { + return true; + } let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; <#struct_name as ::limbo_ext::VTabModule>::eof(cursor) } @@ -399,7 +414,7 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { } let api = &*api; - let name = <#struct_name as ::limbo_ext::VTabModule>::name(); + let name = <#struct_name as ::limbo_ext::VTabModule>::NAME; // name needs to be a c str FFI compatible, NOT CString let name_c = std::ffi::CString::new(name).unwrap(); @@ -493,9 +508,9 @@ pub fn register_extension(input: TokenStream) -> TokenStream { let result = unsafe{ #vtab_ident::#register_fn(api)}; if result == ::limbo_ext::ResultCode::OK { let result = <#vtab_ident as ::limbo_ext::VTabModule>::connect(api); - return result; - } else { - return result; + if !result.is_ok() { + return result; + } } } } @@ -535,5 +550,6 @@ pub fn register_extension(input: TokenStream) -> TokenStream { ::limbo_ext::ResultCode::OK } }; + TokenStream::from(expanded) } From d4c06545e14f723e34f2bacdc7b261eb5ceea93f Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Sun, 2 Feb 2025 20:18:36 -0500 Subject: [PATCH 3/7] Refactor vtable impl and remove Rc Refcell from module --- core/lib.rs | 4 +- core/schema.rs | 20 +++------ core/translate/expr.rs | 4 +- core/translate/optimizer.rs | 4 +- core/translate/planner.rs | 2 - core/translate/select.rs | 6 +-- extensions/core/src/lib.rs | 5 +++ extensions/series/src/lib.rs | 13 ++++++ macros/src/lib.rs | 83 ++++++++++++++++++++++++++++++++++++ 9 files changed, 118 insertions(+), 23 deletions(-) diff --git a/core/lib.rs b/core/lib.rs index ccc2c2273..f381a9fce 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -363,7 +363,7 @@ impl Connection { pub fn execute(self: &Rc, sql: impl AsRef) -> Result<()> { let sql = sql.as_ref(); let db = &self.db; - let syms: &SymbolTable = &db.syms.borrow(); + let syms: &SymbolTable = &db.syms.borrow_mut(); let mut parser = Parser::new(sql.as_bytes()); let cmd = parser.next()?; if let Some(cmd) = cmd { @@ -417,7 +417,7 @@ impl Connection { #[cfg(not(target_family = "wasm"))] pub fn load_extension>(&self, path: P) -> Result<()> { - Database::load_extension(self.db.as_ref(), path) + Database::load_extension(&self.db, path) } /// Close a connection and checkpoint. diff --git a/core/schema.rs b/core/schema.rs index e7688b58e..f4a6aee2b 100644 --- a/core/schema.rs +++ b/core/schema.rs @@ -68,20 +68,11 @@ impl Table { } } - pub fn get_column_at(&self, index: usize) -> &Column { + pub fn get_column_at(&self, index: usize) -> Option<&Column> { match self { - Self::BTree(table) => table - .columns - .get(index) - .expect("column index out of bounds"), - Self::Pseudo(table) => table - .columns - .get(index) - .expect("column index out of bounds"), - Self::Virtual(table) => table - .columns - .get(index) - .expect("column index out of bounds"), + Self::BTree(table) => table.columns.get(index), + Self::Pseudo(table) => table.columns.get(index), + Self::Virtual(table) => table.columns.get(index), } } @@ -100,6 +91,7 @@ impl Table { Self::Virtual(_) => None, } } + pub fn virtual_table(&self) -> Option> { match self { Self::Virtual(table) => Some(table.clone()), @@ -172,7 +164,7 @@ impl BTreeTable { sql.push_str(",\n"); } sql.push_str(" "); - sql.push_str(&column.name.as_ref().expect("column name is None")); + sql.push_str(column.name.as_ref().expect("column name is None")); sql.push(' '); sql.push_str(&column.ty.to_string()); } diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 8b2e70185..bef18c9f2 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -1839,7 +1839,9 @@ pub fn translate_expr( dest: target_register, }); } - let column = table_reference.table.get_column_at(*column); + let Some(column) = table_reference.table.get_column_at(*column) else { + crate::bail_parse_error!("column index out of bounds"); + }; maybe_apply_affinity(column.ty, target_register, program); Ok(target_register) } diff --git a/core/translate/optimizer.rs b/core/translate/optimizer.rs index 73124060f..99de57398 100644 --- a/core/translate/optimizer.rs +++ b/core/translate/optimizer.rs @@ -307,7 +307,9 @@ impl Optimizable for ast::Expr { else { return Ok(None); }; - let column = table_reference.table.get_column_at(*column); + let Some(column) = table_reference.table.get_column_at(*column) else { + return Ok(None); + }; for index in available_indexes_for_table.iter() { if let Some(name) = column.name.as_ref() { if &index.columns.first().unwrap().name == name { diff --git a/core/translate/planner.rs b/core/translate/planner.rs index 272b788a2..dcde7dc62 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -1,5 +1,3 @@ -use std::rc::Rc; - use super::{ plan::{ Aggregate, JoinInfo, Operation, Plan, ResultSetColumn, SelectQueryType, TableReference, diff --git a/core/translate/select.rs b/core/translate/select.rs index 2940cfca6..2a055afd2 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -28,9 +28,9 @@ pub fn translate_select( let mut program = ProgramBuilder::new(ProgramBuilderOpts { query_mode, - num_cursors: count_plan_required_cursors(&select), - approx_num_insns: estimate_num_instructions(&select), - approx_num_labels: estimate_num_labels(&select), + num_cursors: count_plan_required_cursors(select), + approx_num_insns: estimate_num_instructions(select), + approx_num_labels: estimate_num_labels(select), }); emit_program(&mut program, select_plan, syms)?; Ok(program) diff --git a/extensions/core/src/lib.rs b/extensions/core/src/lib.rs index 0e550fca1..30ddece57 100644 --- a/extensions/core/src/lib.rs +++ b/extensions/core/src/lib.rs @@ -110,10 +110,15 @@ pub trait VTabModule: 'static { } pub trait VTabCursor: Sized { + type Error; fn rowid(&self) -> i64; fn column(&self, idx: u32) -> Value; fn eof(&self) -> bool; fn next(&mut self) -> ResultCode; + fn set_error(&mut self, error: Self::Error); + fn error(&self) -> Option { + None + } } #[repr(C)] diff --git a/extensions/series/src/lib.rs b/extensions/series/src/lib.rs index 63b6c6227..ef278d451 100644 --- a/extensions/series/src/lib.rs +++ b/extensions/series/src/lib.rs @@ -40,12 +40,14 @@ impl VTabModule for GenerateSeriesVTab { stop: 0, step: 0, current: 0, + error: None, } } fn filter(cursor: &mut Self::VCursor, arg_count: i32, args: &[Value]) -> ResultCode { // args are the start, stop, and step if arg_count == 0 || arg_count > 3 { + cursor.set_error("Expected between 1 and 3 arguments"); return ResultCode::InvalidArgs; } let start = try_option!(args[0].to_integer(), ResultCode::InvalidArgs); @@ -84,6 +86,7 @@ struct GenerateSeriesCursor { stop: i64, step: i64, current: i64, + error: Option<&'static str>, } impl GenerateSeriesCursor { @@ -101,6 +104,8 @@ impl GenerateSeriesCursor { } impl VTabCursor for GenerateSeriesCursor { + type Error = &'static str; + fn next(&mut self) -> ResultCode { GenerateSeriesCursor::next(self) } @@ -119,6 +124,14 @@ impl VTabCursor for GenerateSeriesCursor { } } + fn error(&self) -> Option { + self.error + } + + fn set_error(&mut self, err: &'static str) { + self.error = Some(err); + } + fn rowid(&self) -> i64 { ((self.current - self.start) / self.step) + 1 } diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 8dee8dc66..56d019525 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -324,6 +324,89 @@ pub fn derive_agg_func(input: TokenStream) -> TokenStream { TokenStream::from(expanded) } +/// Macro to derive a VTabModule for your extension. This macro will generate +/// the necessary functions to register your module with core. You must implement +/// the VTabModule trait for your struct, and the VTabCursor trait for your cursor. +/// ```ignore +///#[derive(Debug, VTabModuleDerive)] +///struct CsvVTab; +///impl VTabModule for CsvVTab { +/// type VCursor = CsvCursor; +/// const NAME: &'static str = "csv_data"; +/// +/// /// Declare the schema for your virtual table +/// fn connect(api: &ExtensionApi) -> ResultCode { +/// let sql = "CREATE TABLE csv_data( +/// name TEXT, +/// age TEXT, +/// city TEXT +/// )"; +/// api.declare_virtual_table(Self::NAME, sql) +/// } +/// /// Open the virtual table and return a cursor +/// fn open() -> Self::VCursor { +/// let csv_content = fs::read_to_string("data.csv").unwrap_or_default(); +/// let rows: Vec> = csv_content +/// .lines() +/// .skip(1) +/// .map(|line| { +/// line.split(',') +/// .map(|s| s.trim().to_string()) +/// .collect() +/// }) +/// .collect(); +/// CsvCursor { rows, index: 0 } +/// } +/// /// Filter the virtual table based on arguments (omitted here for simplicity) +/// fn filter(_cursor: &mut Self::VCursor, _arg_count: i32, _args: &[Value]) -> ResultCode { +/// ResultCode::OK +/// } +/// /// Return the value for a given column index +/// fn column(cursor: &Self::VCursor, idx: u32) -> Value { +/// cursor.column(idx) +/// } +/// /// Move the cursor to the next row +/// fn next(cursor: &mut Self::VCursor) -> ResultCode { +/// if cursor.index < cursor.rows.len() - 1 { +/// cursor.index += 1; +/// ResultCode::OK +/// } else { +/// ResultCode::EOF +/// } +/// } +/// fn eof(cursor: &Self::VCursor) -> bool { +/// cursor.index >= cursor.rows.len() +/// } +/// #[derive(Debug)] +/// struct CsvCursor { +/// rows: Vec>, +/// index: usize, +/// +/// impl CsvCursor { +/// /// Returns the value for a given column index. +/// fn column(&self, idx: u32) -> Value { +/// let row = &self.rows[self.index]; +/// if (idx as usize) < row.len() { +/// Value::from_text(&row[idx as usize]) +/// } else { +/// Value::null() +/// } +/// } +/// // Implement the VTabCursor trait for your virtual cursor +/// impl VTabCursor for CsvCursor { +/// fn next(&mut self) -> ResultCode { +/// Self::next(self) +/// } +/// fn eof(&self) -> bool { +/// self.index >= self.rows.len() +/// } +/// fn column(&self, idx: u32) -> Value { +/// self.column(idx) +/// } +/// fn rowid(&self) -> i64 { +/// self.index as i64 +/// } + #[proc_macro_derive(VTabModuleDerive)] pub fn derive_vtab_module(input: TokenStream) -> TokenStream { let ast = parse_macro_input!(input as DeriveInput); From ad30ccdc0e43a99465a32e6997ccd811480b0cbe Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Sun, 2 Feb 2025 20:20:17 -0500 Subject: [PATCH 4/7] Add docs in extension README for vtable modules --- core/lib.rs | 4 +- extensions/core/README.md | 119 +++++++++++++++++++++++++++++++++-- extensions/series/src/lib.rs | 7 +-- macros/src/lib.rs | 2 +- 4 files changed, 119 insertions(+), 13 deletions(-) diff --git a/core/lib.rs b/core/lib.rs index f381a9fce..0c10003ac 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -26,8 +26,8 @@ use fallible_iterator::FallibleIterator; #[cfg(not(target_family = "wasm"))] use libloading::{Library, Symbol}; #[cfg(not(target_family = "wasm"))] -use limbo_ext::{ExtensionApi, ExtensionEntryPoint, ResultCode}; -use limbo_ext::{VTabModuleImpl, Value as ExtValue}; +use limbo_ext::{ExtensionApi, ExtensionEntryPoint}; +use limbo_ext::{ResultCode, VTabModuleImpl, Value as ExtValue}; use log::trace; use parking_lot::RwLock; use schema::{Column, Schema}; diff --git a/extensions/core/README.md b/extensions/core/README.md index bcb7ff86f..6dd187122 100644 --- a/extensions/core/README.md +++ b/extensions/core/README.md @@ -9,7 +9,8 @@ like traditional `sqlite3` extensions, but are able to be written in much more e - [ x ] **Scalar Functions**: Create scalar functions using the `scalar` macro. - [ x ] **Aggregate Functions**: Define aggregate functions with `AggregateDerive` macro and `AggFunc` trait. - - [] **Virtual tables**: TODO + - [ x ] **Virtual tables**: Create a module for a virtual table with the `VTabModuleDerive` macro and `VTabCursor` trait. + - [] **VFS Modules** --- ## Installation @@ -17,24 +18,32 @@ like traditional `sqlite3` extensions, but are able to be written in much more e Add the crate to your `Cargo.toml`: ```toml + +[features] +static = ["limbo_ext/static"] + [dependencies] -limbo_ext = { path = "path/to/limbo/extensions/core" } # temporary until crate is published +limbo_ext = { path = "path/to/limbo/extensions/core", features = ["static"] } # temporary until crate is published + # mimalloc is required if you intend on linking dynamically. It is imported for you by the register_extension # macro, so no configuration is needed. But it must be added to your Cargo.toml [target.'cfg(not(target_family = "wasm"))'.dependencies] mimalloc = { version = "*", default-features = false } -``` -**NOTE** Crate must be of type `cdylib` if you wish to link dynamically -``` +# NOTE: Crate must be of type `cdylib` if you wish to link dynamically [lib] crate-type = ["cdylib", "lib"] ``` -`cargo build` will output a shared library that can be loaded with `.load target/debug/libyour_crate_name` +`cargo build` will output a shared library that can be loaded by the following options: +#### **CLI:** + `.load target/debug/libyour_crate_name` + +#### **SQL:** + `SELECT load_extension('target/debug/libyour_crate_name')` Extensions can be registered with the `register_extension!` macro: @@ -44,6 +53,7 @@ Extensions can be registered with the `register_extension!` macro: register_extension!{ scalars: { double }, // name of your function, if different from attribute name aggregates: { Percentile }, + vtabs: { CsvVTable }, } ``` @@ -140,4 +150,101 @@ impl AggFunc for Percentile { } ``` +### Virtual Table Example: +```rust + +/// Example: A virtual table that operates on a CSV file as a database table. +/// This example assumes that the CSV file is located at "data.csv" in the current directory. +#[derive(Debug, VTabModuleDerive)] +struct CsvVTable; + +impl VTabModule for CsvVTable { + type VCursor = CsvCursor; + /// Declare the name for your virtual table + const NAME: &'static str = "csv_data"; + + /// Declare the table schema and call `api.declare_virtual_table` with the schema sql. + fn connect(api: &ExtensionApi) -> ResultCode { + let sql = "CREATE TABLE csv_data( + name TEXT, + age TEXT, + city TEXT + )"; + api.declare_virtual_table(Self::NAME, sql) + } + + /// Open to return a new cursor: In this simple example, the CSV file is read completely into memory on connect. + fn open() -> Self::VCursor { + // Read CSV file contents from "data.csv" + let csv_content = fs::read_to_string("data.csv").unwrap_or_default(); + // For simplicity, we'll ignore the header row. + let rows: Vec> = csv_content + .lines() + .skip(1) + .map(|line| { + line.split(',') + .map(|s| s.trim().to_string()) + .collect() + }) + .collect(); + CsvCursor { rows, index: 0 } + } + + /// Filter through result columns. (not used in this simple example) + fn filter(_cursor: &mut Self::VCursor, _arg_count: i32, _args: &[Value]) -> ResultCode { + ResultCode::OK + } + + /// Return the value for the column at the given index in the current row. + fn column(cursor: &Self::VCursor, idx: u32) -> Value { + cursor.column(idx) + } + + /// Next advances the cursor to the next row. + fn next(cursor: &mut Self::VCursor) -> ResultCode { + if cursor.index < cursor.rows.len() - 1 { + cursor.index += 1; + ResultCode::OK + } else { + ResultCode::EOF + } + } + + /// Return true if the cursor is at the end. + fn eof(cursor: &Self::VCursor) -> bool { + cursor.index >= cursor.rows.len() + } +} + +/// The cursor for iterating over CSV rows. +#[derive(Debug)] +struct CsvCursor { + rows: Vec>, + index: usize, +} + +/// Implement the VTabCursor trait for your cursor type +impl VTabCursor for CsvCursor { + fn next(&mut self) -> ResultCode { + CsvCursor::next(self) + } + + fn eof(&self) -> bool { + self.index >= self.rows.len() + } + + fn column(&self, idx: u32) -> Value { + let row = &self.rows[self.index]; + if (idx as usize) < row.len() { + Value::from_text(&row[idx as usize]) + } else { + Value::null() + } + } + + fn rowid(&self) -> i64 { + self.index as i64 + } +} +``` diff --git a/extensions/series/src/lib.rs b/extensions/series/src/lib.rs index ef278d451..fdd84c3b2 100644 --- a/extensions/series/src/lib.rs +++ b/extensions/series/src/lib.rs @@ -47,7 +47,6 @@ impl VTabModule for GenerateSeriesVTab { fn filter(cursor: &mut Self::VCursor, arg_count: i32, args: &[Value]) -> ResultCode { // args are the start, stop, and step if arg_count == 0 || arg_count > 3 { - cursor.set_error("Expected between 1 and 3 arguments"); return ResultCode::InvalidArgs; } let start = try_option!(args[0].to_integer(), ResultCode::InvalidArgs); @@ -86,7 +85,7 @@ struct GenerateSeriesCursor { stop: i64, step: i64, current: i64, - error: Option<&'static str>, + error: Option, } impl GenerateSeriesCursor { @@ -104,7 +103,7 @@ impl GenerateSeriesCursor { } impl VTabCursor for GenerateSeriesCursor { - type Error = &'static str; + type Error = ResultCode; fn next(&mut self) -> ResultCode { GenerateSeriesCursor::next(self) @@ -128,7 +127,7 @@ impl VTabCursor for GenerateSeriesCursor { self.error } - fn set_error(&mut self, err: &'static str) { + fn set_error(&mut self, err: ResultCode) { self.error = Some(err); } diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 56d019525..632b95615 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -406,7 +406,7 @@ pub fn derive_agg_func(input: TokenStream) -> TokenStream { /// fn rowid(&self) -> i64 { /// self.index as i64 /// } - +/// #[proc_macro_derive(VTabModuleDerive)] pub fn derive_vtab_module(input: TokenStream) -> TokenStream { let ast = parse_macro_input!(input as DeriveInput); From a8ae95716268162e313762599a5f35c90194a779 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Wed, 5 Feb 2025 13:35:24 -0500 Subject: [PATCH 5/7] Add tests for series extension, finish initial vtable impl --- core/lib.rs | 6 +- extensions/core/src/lib.rs | 2 +- extensions/series/src/lib.rs | 21 ++---- macros/src/lib.rs | 3 + testing/extensions.py | 134 ++++++++++++++++++++++------------- 5 files changed, 98 insertions(+), 68 deletions(-) diff --git a/core/lib.rs b/core/lib.rs index 0c10003ac..41e8a6cd6 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -272,8 +272,8 @@ impl Connection { let sql = sql.as_ref(); trace!("Preparing: {}", sql); let db = &self.db; - let syms: &SymbolTable = &db.syms.borrow(); let mut parser = Parser::new(sql.as_bytes()); + let syms = &db.syms.borrow(); let cmd = parser.next()?; if let Some(cmd) = cmd { match cmd { @@ -363,7 +363,7 @@ impl Connection { pub fn execute(self: &Rc, sql: impl AsRef) -> Result<()> { let sql = sql.as_ref(); let db = &self.db; - let syms: &SymbolTable = &db.syms.borrow_mut(); + let syms: &SymbolTable = &db.syms.borrow(); let mut parser = Parser::new(sql.as_bytes()); let cmd = parser.next()?; if let Some(cmd) = cmd { @@ -551,7 +551,7 @@ impl VirtualTable { }; match rc { ResultCode::OK => Ok(()), - _ => Err(LimboError::ExtensionError("Filter failed".to_string())), + _ => Err(LimboError::ExtensionError(rc.to_string())), } } diff --git a/extensions/core/src/lib.rs b/extensions/core/src/lib.rs index 30ddece57..805051079 100644 --- a/extensions/core/src/lib.rs +++ b/extensions/core/src/lib.rs @@ -110,7 +110,7 @@ pub trait VTabModule: 'static { } pub trait VTabCursor: Sized { - type Error; + type Error: std::fmt::Display; fn rowid(&self) -> i64; fn column(&self, idx: u32) -> Value; fn eof(&self) -> bool; diff --git a/extensions/series/src/lib.rs b/extensions/series/src/lib.rs index fdd84c3b2..9732c909d 100644 --- a/extensions/series/src/lib.rs +++ b/extensions/series/src/lib.rs @@ -88,25 +88,16 @@ struct GenerateSeriesCursor { error: Option, } -impl GenerateSeriesCursor { - fn next(&mut self) -> ResultCode { - let current = self.current; - - // Check if we've reached the end - if (self.step > 0 && current >= self.stop) || (self.step < 0 && current <= self.stop) { - return ResultCode::EOF; - } - - self.current = current.saturating_add(self.step); - ResultCode::OK - } -} - impl VTabCursor for GenerateSeriesCursor { type Error = ResultCode; fn next(&mut self) -> ResultCode { - GenerateSeriesCursor::next(self) + let next_val = self.current.saturating_add(self.step); + if (self.step > 0 && next_val > self.stop) || (self.step < 0 && next_val < self.stop) { + return ResultCode::EOF; + } + self.current = next_val; + ResultCode::OK } fn eof(&self) -> bool { diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 632b95615..2c88776f7 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -463,6 +463,9 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { return ::limbo_ext::Value::error(ResultCode::Error); } let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; + if let Some(err) = <#struct_name as ::limbo_ext::VTabModule>::VCursor::error(cursor) { + return ::limbo_ext::Value::error_with_message(err.to_string()); + } <#struct_name as ::limbo_ext::VTabModule>::column(cursor, idx) } diff --git a/testing/extensions.py b/testing/extensions.py index d4a0a69c0..f2099d101 100755 --- a/testing/extensions.py +++ b/testing/extensions.py @@ -110,14 +110,19 @@ def validate_blob(result): # and assert they are valid hex digits return int(result, 16) is not None + def validate_string_uuid(result): return len(result) == 36 and result.count("-") == 4 -def returns_error(result): +def returns_error_no_func(result): return "error: no such function: " in result +def returns_vtable_parse_err(result): + return "Parse error: Virtual table" in result + + def returns_null(result): return result == "" or result == "\n" @@ -129,6 +134,7 @@ def assert_now_unixtime(result): def assert_specific_time(result): return result == "1736720789" + def test_uuid(pipe): specific_time = "01945ca0-3189-76c0-9a8f-caf310fc8b8e" # these are built into the binary, so we just test they work @@ -165,7 +171,7 @@ def test_regexp(pipe): extension_path = "./target/debug/liblimbo_regexp.so" # before extension loads, assert no function - run_test(pipe, "SELECT regexp('a.c', 'abc');", returns_error) + run_test(pipe, "SELECT regexp('a.c', 'abc');", returns_error_no_func) run_test(pipe, f".load {extension_path}", returns_null) print(f"Extension {extension_path} loaded successfully.") run_test(pipe, "SELECT regexp('a.c', 'abc');", validate_true) @@ -205,13 +211,14 @@ def validate_percentile2(res): def validate_percentile_disc(res): return res == "40.0" + def test_aggregates(pipe): extension_path = "./target/debug/liblimbo_percentile.so" # assert no function before extension loads run_test( pipe, "SELECT median(1);", - returns_error, + returns_error_no_func, "median agg function returns null when ext not loaded", ) run_test( @@ -252,63 +259,55 @@ def test_aggregates(pipe): pipe, "SELECT percentile_disc(value, 0.55) from test;", validate_percentile_disc ) -# Hashes -def validate_blake3(a): - return a == "6437b3ac38465133ffb63b75273a8db548c558465d79db03fd359c6cd5bd9d85" - -def validate_md5(a): - return a == "900150983cd24fb0d6963f7d28e17f72" - -def validate_sha1(a): - return a == "a9993e364706816aba3e25717850c26c9cd0d89d" - -def validate_sha256(a): - return a == "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad" - -def validate_sha384(a): - return a == "cb00753f45a35e8bb5a03d699ac65007272c32ab0eded1631a8b605a43ff5bed8086072ba1e7cc2358baeca134c825a7" - -def validate_sha512(a): - return a == "ddaf35a193617abacc417349ae20413112e6fa4e89a97ea20a9eeee64b55d39a2192992a274fc1a836ba3c23a3feebbd454d4423643ce80e2a9ac94fa54ca49f" # Encoders and decoders def validate_url_encode(a): - return a == f"%2Fhello%3Ftext%3D%28%E0%B2%A0_%E0%B2%A0%29" + return a == "%2Fhello%3Ftext%3D%28%E0%B2%A0_%E0%B2%A0%29" + def validate_url_decode(a): return a == "/hello?text=(ಠ_ಠ)" + def validate_hex_encode(a): return a == "68656c6c6f" + def validate_hex_decode(a): return a == "hello" + def validate_base85_encode(a): return a == "BOu!rDZ" + def validate_base85_decode(a): return a == "hello" + def validate_base32_encode(a): return a == "NBSWY3DP" + def validate_base32_decode(a): return a == "hello" + def validate_base64_encode(a): return a == "aGVsbG8=" + def validate_base64_decode(a): return a == "hello" + def test_crypto(pipe): extension_path = "./target/debug/liblimbo_crypto.so" # assert no function before extension loads run_test( pipe, "SELECT crypto_blake('a');", - returns_error, + lambda res: "Error" in res, "crypto_blake3 returns null when ext not loaded", ) run_test( @@ -321,104 +320,139 @@ def test_crypto(pipe): run_test( pipe, "SELECT crypto_encode(crypto_blake3('abc'), 'hex');", - validate_blake3, - "blake3 should encrypt correctly" + lambda res: res + == "6437b3ac38465133ffb63b75273a8db548c558465d79db03fd359c6cd5bd9d85", + "blake3 should encrypt correctly", ) run_test( pipe, "SELECT crypto_encode(crypto_md5('abc'), 'hex');", - validate_md5, - "md5 should encrypt correctly" + lambda res: res == "900150983cd24fb0d6963f7d28e17f72", + "md5 should encrypt correctly", ) run_test( pipe, "SELECT crypto_encode(crypto_sha1('abc'), 'hex');", - validate_sha1, - "sha1 should encrypt correctly" + lambda res: res == "a9993e364706816aba3e25717850c26c9cd0d89d", + "sha1 should encrypt correctly", ) run_test( pipe, "SELECT crypto_encode(crypto_sha256('abc'), 'hex');", - validate_sha256, - "sha256 should encrypt correctly" + lambda a: a + == "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad", + "sha256 should encrypt correctly", ) run_test( pipe, "SELECT crypto_encode(crypto_sha384('abc'), 'hex');", - validate_sha384, - "sha384 should encrypt correctly" + lambda a: a + == "cb00753f45a35e8bb5a03d699ac65007272c32ab0eded1631a8b605a43ff5bed8086072ba1e7cc2358baeca134c825a7", + "sha384 should encrypt correctly", ) run_test( pipe, "SELECT crypto_encode(crypto_sha512('abc'), 'hex');", - validate_sha512, - "sha512 should encrypt correctly" - ) + lambda a: a + == "ddaf35a193617abacc417349ae20413112e6fa4e89a97ea20a9eeee64b55d39a2192992a274fc1a836ba3c23a3feebbd454d4423643ce80e2a9ac94fa54ca49f", + "sha512 should encrypt correctly", + ) # Encoding and Decoding run_test( pipe, "SELECT crypto_encode('hello', 'base32');", validate_base32_encode, - "base32 should encode correctly" - ) + "base32 should encode correctly", + ) run_test( pipe, "SELECT crypto_decode('NBSWY3DP', 'base32');", validate_base32_decode, - "base32 should decode correctly" + "base32 should decode correctly", ) run_test( pipe, "SELECT crypto_encode('hello', 'base64');", validate_base64_encode, - "base64 should encode correctly" + "base64 should encode correctly", ) run_test( pipe, "SELECT crypto_decode('aGVsbG8=', 'base64');", validate_base64_decode, - "base64 should decode correctly" + "base64 should decode correctly", ) run_test( pipe, "SELECT crypto_encode('hello', 'base85');", validate_base85_encode, - "base85 should encode correctly" + "base85 should encode correctly", ) run_test( pipe, "SELECT crypto_decode('BOu!rDZ', 'base85');", validate_base85_decode, - "base85 should decode correctly" + "base85 should decode correctly", ) run_test( pipe, "SELECT crypto_encode('hello', 'hex');", validate_hex_encode, - "hex should encode correctly" + "hex should encode correctly", ) run_test( pipe, "SELECT crypto_decode('68656c6c6f', 'hex');", validate_hex_decode, - "hex should decode correctly" + "hex should decode correctly", ) - + run_test( pipe, "SELECT crypto_encode('/hello?text=(ಠ_ಠ)', 'url');", validate_url_encode, - "url should encode correctly" + "url should encode correctly", ) run_test( pipe, - f"SELECT crypto_decode('%2Fhello%3Ftext%3D%28%E0%B2%A0_%E0%B2%A0%29', 'url');", + "SELECT crypto_decode('%2Fhello%3Ftext%3D%28%E0%B2%A0_%E0%B2%A0%29', 'url');", validate_url_decode, - "url should decode correctly" + "url should decode correctly", ) + +def test_series(pipe): + ext_path = "./target/debug/liblimbo_series" + run_test( + pipe, + "SELECT * FROM generate_series(1, 10);", + lambda res: "Virtual table generate_series not found" in res, + ) + run_test(pipe, f".load {ext_path}", returns_null) + run_test( + pipe, + "SELECT * FROM generate_series(1, 10);", + lambda res: "Invalid Argument" in res, + ) + run_test( + pipe, + "SELECT * FROM generate_series(1, 10, 2);", + lambda res: res == "1\n3\n5\n7\n9", + ) + run_test( + pipe, + "SELECT * FROM generate_series(1, 10, 2, 3);", + lambda res: "Invalid Argument" in res, + ) + run_test( + pipe, + "SELECT * FROM generate_series(10, 1, -2);", + lambda res: res == "10\n8\n6\n4\n2", + ) + + def main(): pipe = init_limbo() try: @@ -426,6 +460,8 @@ def main(): test_uuid(pipe) test_aggregates(pipe) test_crypto(pipe) + test_series(pipe) + except Exception as e: print(f"Test FAILED: {e}") pipe.terminate() From cd83ac6146e56d603898f672c9c32efd697efe1e Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Wed, 5 Feb 2025 16:14:21 -0500 Subject: [PATCH 6/7] Remove error from vcursor trait in extensions --- extensions/core/src/lib.rs | 4 ---- extensions/series/src/lib.rs | 10 ---------- macros/src/lib.rs | 3 --- 3 files changed, 17 deletions(-) diff --git a/extensions/core/src/lib.rs b/extensions/core/src/lib.rs index 805051079..22d90f572 100644 --- a/extensions/core/src/lib.rs +++ b/extensions/core/src/lib.rs @@ -115,10 +115,6 @@ pub trait VTabCursor: Sized { fn column(&self, idx: u32) -> Value; fn eof(&self) -> bool; fn next(&mut self) -> ResultCode; - fn set_error(&mut self, error: Self::Error); - fn error(&self) -> Option { - None - } } #[repr(C)] diff --git a/extensions/series/src/lib.rs b/extensions/series/src/lib.rs index 9732c909d..83dd334ea 100644 --- a/extensions/series/src/lib.rs +++ b/extensions/series/src/lib.rs @@ -40,7 +40,6 @@ impl VTabModule for GenerateSeriesVTab { stop: 0, step: 0, current: 0, - error: None, } } @@ -85,7 +84,6 @@ struct GenerateSeriesCursor { stop: i64, step: i64, current: i64, - error: Option, } impl VTabCursor for GenerateSeriesCursor { @@ -114,14 +112,6 @@ impl VTabCursor for GenerateSeriesCursor { } } - fn error(&self) -> Option { - self.error - } - - fn set_error(&mut self, err: ResultCode) { - self.error = Some(err); - } - fn rowid(&self) -> i64 { ((self.current - self.start) / self.step) + 1 } diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 2c88776f7..632b95615 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -463,9 +463,6 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { return ::limbo_ext::Value::error(ResultCode::Error); } let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; - if let Some(err) = <#struct_name as ::limbo_ext::VTabModule>::VCursor::error(cursor) { - return ::limbo_ext::Value::error_with_message(err.to_string()); - } <#struct_name as ::limbo_ext::VTabModule>::column(cursor, idx) } From ae88d51e6fa7a41d18ab2e17a31dfbf4c256372c Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Wed, 5 Feb 2025 21:30:55 -0500 Subject: [PATCH 7/7] Remove TableReferenceType enum to clean up planner --- core/ext/mod.rs | 1 + core/lib.rs | 3 +- core/translate/delete.rs | 3 +- core/translate/expr.rs | 67 ++++++++++++++++--------------------- core/translate/main_loop.rs | 43 ++++++++++++------------ core/translate/plan.rs | 32 +++--------------- core/translate/planner.rs | 27 +++++++++------ testing/extensions.py | 2 +- 8 files changed, 76 insertions(+), 102 deletions(-) diff --git a/core/ext/mod.rs b/core/ext/mod.rs index c4b6006e3..67fd78491 100644 --- a/core/ext/mod.rs +++ b/core/ext/mod.rs @@ -136,6 +136,7 @@ impl Database { name: name.to_string(), implementation: vtab_module, columns, + args: None, }; self.syms.borrow_mut().vtabs.insert(name.to_string(), vtab); ResultCode::OK diff --git a/core/lib.rs b/core/lib.rs index 41e8a6cd6..8fde24402 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -514,6 +514,7 @@ pub type StepResult = vdbe::StepResult; #[derive(Clone, Debug)] pub struct VirtualTable { name: String, + args: Option>, pub implementation: Rc, columns: Vec, } @@ -537,7 +538,7 @@ impl VirtualTable { OwnedValue::Null => Ok(ExtValue::null()), OwnedValue::Integer(i) => Ok(ExtValue::from_integer(*i)), OwnedValue::Float(f) => Ok(ExtValue::from_float(*f)), - OwnedValue::Text(t) => Ok(ExtValue::from_text((*t.value).clone())), + OwnedValue::Text(t) => Ok(ExtValue::from_text(t.as_str().to_string())), OwnedValue::Blob(b) => Ok(ExtValue::from_blob((**b).clone())), other => Err(LimboError::ExtensionError(format!( "Unsupported value type: {:?}", diff --git a/core/translate/delete.rs b/core/translate/delete.rs index 675b58f34..ffad33d73 100644 --- a/core/translate/delete.rs +++ b/core/translate/delete.rs @@ -7,7 +7,7 @@ use crate::vdbe::builder::{ProgramBuilder, ProgramBuilderOpts, QueryMode}; use crate::{schema::Schema, Result, SymbolTable}; use sqlite3_parser::ast::{Expr, Limit, QualifiedName}; -use super::plan::{TableReference, TableReferenceType}; +use super::plan::TableReference; pub fn translate_delete( query_mode: QueryMode, @@ -48,7 +48,6 @@ pub fn prepare_delete_plan( identifier: table.name.clone(), op: Operation::Scan { iter_dir: None }, join_info: None, - reference_type: TableReferenceType::BTreeTable, }]; let mut where_predicates = vec![]; diff --git a/core/translate/expr.rs b/core/translate/expr.rs index bef18c9f2..c23cb053a 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -3,7 +3,7 @@ use sqlite3_parser::ast::{self, UnaryOperator}; #[cfg(feature = "json")] use crate::function::JsonFunc; use crate::function::{Func, FuncCtx, MathFuncArity, ScalarFunc, VectorFunc}; -use crate::schema::Type; +use crate::schema::{Table, Type}; use crate::util::normalize_ident; use crate::vdbe::{ builder::ProgramBuilder, @@ -13,7 +13,7 @@ use crate::vdbe::{ use crate::Result; use super::emitter::Resolver; -use super::plan::{Operation, TableReference, TableReferenceType}; +use super::plan::{Operation, TableReference}; #[derive(Debug, Clone, Copy)] pub struct ConditionMetadata { @@ -1823,49 +1823,38 @@ pub fn translate_expr( match table_reference.op { // If we are reading a column from a table, we find the cursor that corresponds to // the table and read the column from the cursor. - Operation::Scan { .. } | Operation::Search(_) => { - match &table_reference.reference_type { - TableReferenceType::BTreeTable => { - let cursor_id = program.resolve_cursor_id(&table_reference.identifier); - if *is_rowid_alias { - program.emit_insn(Insn::RowId { - cursor_id, - dest: target_register, - }); - } else { - program.emit_insn(Insn::Column { - cursor_id, - column: *column, - dest: target_register, - }); - } - let Some(column) = table_reference.table.get_column_at(*column) else { - crate::bail_parse_error!("column index out of bounds"); - }; - maybe_apply_affinity(column.ty, target_register, program); - Ok(target_register) - } - TableReferenceType::VirtualTable { .. } => { - let cursor_id = program.resolve_cursor_id(&table_reference.identifier); - program.emit_insn(Insn::VColumn { + Operation::Scan { .. } | Operation::Search(_) => match &table_reference.table { + Table::BTree(_) => { + let cursor_id = program.resolve_cursor_id(&table_reference.identifier); + if *is_rowid_alias { + program.emit_insn(Insn::RowId { + cursor_id, + dest: target_register, + }); + } else { + program.emit_insn(Insn::Column { cursor_id, column: *column, dest: target_register, }); - Ok(target_register) - } - TableReferenceType::Subquery { - result_columns_start_reg, - } => { - program.emit_insn(Insn::Copy { - src_reg: result_columns_start_reg + *column, - dst_reg: target_register, - amount: 0, - }); - Ok(target_register) } + let Some(column) = table_reference.table.get_column_at(*column) else { + crate::bail_parse_error!("column index out of bounds"); + }; + maybe_apply_affinity(column.ty, target_register, program); + Ok(target_register) } - } + Table::Virtual(_) => { + let cursor_id = program.resolve_cursor_id(&table_reference.identifier); + program.emit_insn(Insn::VColumn { + cursor_id, + column: *column, + dest: target_register, + }); + Ok(target_register) + } + _ => unreachable!(), + }, // If we are reading a column from a subquery, we instead copy the column from the // subquery's result registers. Operation::Subquery { diff --git a/core/translate/main_loop.rs b/core/translate/main_loop.rs index 4558693e3..35cc505c9 100644 --- a/core/translate/main_loop.rs +++ b/core/translate/main_loop.rs @@ -1,6 +1,7 @@ use sqlite3_parser::ast; use crate::{ + schema::Table, translate::result_row::emit_select_result, vdbe::{ builder::{CursorType, ProgramBuilder}, @@ -17,7 +18,7 @@ use super::{ order_by::{order_by_sorter_insert, sorter_insert}, plan::{ IterationDirection, Operation, Search, SelectPlan, SelectQueryType, TableReference, - TableReferenceType, WhereTerm, + WhereTerm, }, }; @@ -78,21 +79,18 @@ pub fn init_loop( } match &table.op { Operation::Scan { .. } => { - let ref_type = &table.reference_type; let cursor_id = program.alloc_cursor_id( Some(table.identifier.clone()), - match ref_type { - TableReferenceType::BTreeTable => { - CursorType::BTreeTable(table.btree().unwrap().clone()) - } - TableReferenceType::VirtualTable { .. } => { + match &table.table { + Table::BTree(_) => CursorType::BTreeTable(table.btree().unwrap().clone()), + Table::Virtual(_) => { CursorType::VirtualTable(table.virtual_table().unwrap().clone()) } other => panic!("Invalid table reference type in Scan: {:?}", other), }, ); - match (mode, ref_type) { - (OperationMode::SELECT, TableReferenceType::BTreeTable) => { + match (mode, &table.table) { + (OperationMode::SELECT, Table::BTree(_)) => { let root_page = table.btree().unwrap().root_page; program.emit_insn(Insn::OpenReadAsync { cursor_id, @@ -100,7 +98,7 @@ pub fn init_loop( }); program.emit_insn(Insn::OpenReadAwait {}); } - (OperationMode::DELETE, TableReferenceType::BTreeTable) => { + (OperationMode::DELETE, Table::BTree(_)) => { let root_page = table.btree().unwrap().root_page; program.emit_insn(Insn::OpenWriteAsync { cursor_id, @@ -108,7 +106,7 @@ pub fn init_loop( }); program.emit_insn(Insn::OpenWriteAwait {}); } - (OperationMode::SELECT, TableReferenceType::VirtualTable { .. }) => { + (OperationMode::SELECT, Table::Virtual(_)) => { program.emit_insn(Insn::VOpenAsync { cursor_id }); program.emit_insn(Insn::VOpenAwait {}); } @@ -258,10 +256,9 @@ pub fn open_loop( } } Operation::Scan { iter_dir } => { - let ref_type = &table.reference_type; let cursor_id = program.resolve_cursor_id(&table.identifier); - if !matches!(ref_type, TableReferenceType::VirtualTable { .. }) { + if !matches!(&table.table, Table::Virtual(_)) { if iter_dir .as_ref() .is_some_and(|dir| *dir == IterationDirection::Backwards) @@ -271,8 +268,8 @@ pub fn open_loop( program.emit_insn(Insn::RewindAsync { cursor_id }); } } - match ref_type { - TableReferenceType::BTreeTable => program.emit_insn( + match &table.table { + Table::BTree(_) => program.emit_insn( if iter_dir .as_ref() .is_some_and(|dir| *dir == IterationDirection::Backwards) @@ -288,13 +285,18 @@ pub fn open_loop( } }, ), - TableReferenceType::VirtualTable { args, .. } => { + Table::Virtual(ref table) => { + let args = if let Some(args) = table.args.as_ref() { + args + } else { + &vec![] + }; let start_reg = program.alloc_registers(args.len()); let mut cur_reg = start_reg; for arg in args { let reg = cur_reg; cur_reg += 1; - translate_expr(program, Some(tables), arg, reg, &t_ctx.resolver)?; + translate_expr(program, Some(tables), &arg, reg, &t_ctx.resolver)?; } program.emit_insn(Insn::VFilter { cursor_id, @@ -722,11 +724,10 @@ pub fn close_loop( }); } Operation::Scan { iter_dir, .. } => { - let ref_type = &table.reference_type; program.resolve_label(loop_labels.next, program.offset()); let cursor_id = program.resolve_cursor_id(&table.identifier); - match ref_type { - TableReferenceType::BTreeTable { .. } => { + match &table.table { + Table::BTree(_) => { if iter_dir .as_ref() .is_some_and(|dir| *dir == IterationDirection::Backwards) @@ -750,7 +751,7 @@ pub fn close_loop( }); } } - TableReferenceType::VirtualTable { .. } => { + Table::Virtual(_) => { program.emit_insn(Insn::VNext { cursor_id, pc_if_next: loop_labels.loop_start, diff --git a/core/translate/plan.rs b/core/translate/plan.rs index 43cba8e1b..8195aea13 100644 --- a/core/translate/plan.rs +++ b/core/translate/plan.rs @@ -198,7 +198,6 @@ pub struct TableReference { pub identifier: String, /// The join info for this table reference, if it is the right side of a join (which all except the first table reference have) pub join_info: Option, - pub reference_type: TableReferenceType, } #[derive(Clone, Debug)] @@ -225,35 +224,17 @@ pub enum Operation { }, } -/// The type of the table reference, either BTreeTable or Subquery -#[derive(Clone, Debug, PartialEq, Eq)] -pub enum TableReferenceType { - /// A BTreeTable is a table that is stored on disk in a B-tree index. - BTreeTable, - /// A subquery. - Subquery { - /// The index of the first register in the query plan that contains the result columns of the subquery. - result_columns_start_reg: usize, - }, - /// A virtual table. - VirtualTable { - /// Arguments to pass e.g. generate_series(1, 10, 2) - args: Vec, - }, -} - impl TableReference { /// Returns the btree table for this table reference, if it is a BTreeTable. pub fn btree(&self) -> Option> { - match &self.reference_type { - TableReferenceType::BTreeTable => self.table.btree(), - TableReferenceType::Subquery { .. } => None, - TableReferenceType::VirtualTable { .. } => None, + match &self.table { + Table::BTree(_) => self.table.btree(), + _ => None, } } pub fn virtual_table(&self) -> Option> { - match &self.reference_type { - TableReferenceType::VirtualTable { .. } => self.table.virtual_table(), + match &self.table { + Table::Virtual(_) => self.table.virtual_table(), _ => None, } } @@ -280,9 +261,6 @@ impl TableReference { result_columns_start_reg: 0, // Will be set in the bytecode emission phase }, table, - reference_type: TableReferenceType::Subquery { - result_columns_start_reg: 0, // Will be set in the bytecode emission phase - }, identifier: identifier.clone(), join_info, } diff --git a/core/translate/planner.rs b/core/translate/planner.rs index dcde7dc62..311458f9f 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -1,7 +1,7 @@ use super::{ plan::{ Aggregate, JoinInfo, Operation, Plan, ResultSetColumn, SelectQueryType, TableReference, - TableReferenceType, WhereTerm, + WhereTerm, }, select::prepare_select_plan, SymbolTable, @@ -11,7 +11,7 @@ use crate::{ schema::{Schema, Table}, util::{exprs_are_equivalent, normalize_ident}, vdbe::BranchOffset, - Result, + Result, VirtualTable, }; use sqlite3_parser::ast::{self, Expr, FromClause, JoinType, Limit, UnaryOperator}; @@ -301,7 +301,6 @@ fn parse_from_clause_table( table: Table::BTree(table.clone()), identifier: alias.unwrap_or(normalized_qualified_name), join_info: None, - reference_type: TableReferenceType::BTreeTable, }) } ast::SelectTable::Select(subselect, maybe_alias) => { @@ -320,9 +319,9 @@ fn parse_from_clause_table( .unwrap_or(format!("subquery_{}", cur_table_index)); Ok(TableReference::new_subquery(identifier, subplan, None)) } - ast::SelectTable::TableCall(qualified_name, mut maybe_args, maybe_alias) => { - let normalized_name = normalize_ident(qualified_name.name.0.as_str()); - let Some(vtab) = syms.vtabs.get(&normalized_name) else { + ast::SelectTable::TableCall(qualified_name, maybe_args, maybe_alias) => { + let normalized_name = &normalize_ident(qualified_name.name.0.as_str()); + let Some(vtab) = syms.vtabs.get(normalized_name) else { crate::bail_parse_error!("Virtual table {} not found", normalized_name); }; let alias = maybe_alias @@ -331,16 +330,22 @@ fn parse_from_clause_table( ast::As::As(id) => id.0.clone(), ast::As::Elided(id) => id.0.clone(), }) - .unwrap_or(normalized_name); + .unwrap_or(normalized_name.to_string()); Ok(TableReference { op: Operation::Scan { iter_dir: None }, join_info: None, - table: Table::Virtual(vtab.clone().into()), + table: Table::Virtual( + VirtualTable { + name: normalized_name.clone(), + args: maybe_args, + implementation: vtab.implementation.clone(), + columns: vtab.columns.clone(), + } + .into(), + ) + .into(), identifier: alias.clone(), - reference_type: TableReferenceType::VirtualTable { - args: maybe_args.take().unwrap_or_default(), - }, }) } _ => todo!(), diff --git a/testing/extensions.py b/testing/extensions.py index f2099d101..cda953f86 100755 --- a/testing/extensions.py +++ b/testing/extensions.py @@ -307,7 +307,7 @@ def test_crypto(pipe): run_test( pipe, "SELECT crypto_blake('a');", - lambda res: "Error" in res, + lambda res: "Parse error" in res, "crypto_blake3 returns null when ext not loaded", ) run_test(