diff --git a/COMPAT.md b/COMPAT.md index 05ab1dfb8..bc20ce0ed 100644 --- a/COMPAT.md +++ b/COMPAT.md @@ -572,14 +572,14 @@ Modifiers: | Trace | No | | | Transaction | Yes | | | VBegin | No | | -| VColumn | No | | -| VCreate | No | | +| VColumn | Yes | | +| VCreate | Yes | | | VDestroy | No | | -| VFilter | No | | -| VNext | No | | -| VOpen | No | | +| VFilter | Yes | | +| VNext | Yes | | +| VOpen | Yes |VOpenAsync| | VRename | No | | -| VUpdate | No | | +| VUpdate | Yes | | | Vacuum | No | | | Variable | No | | | VerifyCookie | No | | diff --git a/Cargo.lock b/Cargo.lock index 812d3b644..c5af95f10 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1670,6 +1670,15 @@ dependencies = [ "limbo_macros", ] +[[package]] +name = "limbo_kv" +version = "0.0.15" +dependencies = [ + "lazy_static", + "limbo_ext", + "mimalloc", +] + [[package]] name = "limbo_macros" version = "0.0.15" diff --git a/Cargo.toml b/Cargo.toml index d7fe68fed..425dc07cb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,8 @@ members = [ "cli", "core", "extensions/core", - "extensions/crypto", + "extensions/crypto", + "extensions/kvstore", "extensions/percentile", "extensions/regexp", "extensions/series", diff --git a/core/error.rs b/core/error.rs index 53308114c..7832747eb 100644 --- a/core/error.rs +++ b/core/error.rs @@ -76,5 +76,11 @@ macro_rules! bail_constraint_error { }; } +impl From for LimboError { + fn from(err: limbo_ext::ResultCode) -> Self { + LimboError::ExtensionError(err.to_string()) + } +} + pub const SQLITE_CONSTRAINT: usize = 19; pub const SQLITE_CONSTRAINT_PRIMARYKEY: usize = SQLITE_CONSTRAINT | (6 << 8); diff --git a/core/ext/mod.rs b/core/ext/mod.rs index a4f5d6cc3..1402c4098 100644 --- a/core/ext/mod.rs +++ b/core/ext/mod.rs @@ -1,17 +1,20 @@ -use crate::{function::ExternalFunc, util::columns_from_create_table_body, Database, VirtualTable}; -use fallible_iterator::FallibleIterator; -use limbo_ext::{ExtensionApi, InitAggFunction, ResultCode, ScalarFunction, VTabModuleImpl}; -pub use limbo_ext::{FinalizeFunction, StepFunction, Value as ExtValue, ValueType as ExtValueType}; -use limbo_sqlite3_parser::{ - ast::{Cmd, Stmt}, - lexer::sql::Parser, +use crate::{function::ExternalFunc, Database}; +use limbo_ext::{ + ExtensionApi, InitAggFunction, ResultCode, ScalarFunction, VTabKind, VTabModuleImpl, }; +pub use limbo_ext::{FinalizeFunction, StepFunction, Value as ExtValue, ValueType as ExtValueType}; use std::{ - ffi::{c_char, c_void, CStr}, + ffi::{c_char, c_void, CStr, CString}, rc::Rc, }; type ExternAggFunc = (InitAggFunction, StepFunction, FinalizeFunction); +#[derive(Clone)] +pub struct VTabImpl { + pub module_kind: VTabKind, + pub implementation: Rc, +} + unsafe extern "C" fn register_scalar_function( ctx: *mut c_void, name: *const c_char, @@ -53,8 +56,12 @@ unsafe extern "C" fn register_module( ctx: *mut c_void, name: *const c_char, module: VTabModuleImpl, + kind: VTabKind, ) -> ResultCode { - let c_str = unsafe { CStr::from_ptr(name) }; + if name.is_null() || ctx.is_null() { + return ResultCode::Error; + } + let c_str = unsafe { CString::from_raw(name as *mut i8) }; let name_str = match c_str.to_str() { Ok(s) => s.to_string(), Err(_) => return ResultCode::Error, @@ -64,31 +71,7 @@ unsafe extern "C" fn register_module( } let db = unsafe { &mut *(ctx as *mut Database) }; - db.register_module_impl(&name_str, module) -} - -unsafe extern "C" fn declare_vtab( - ctx: *mut c_void, - name: *const c_char, - sql: *const c_char, -) -> ResultCode { - let c_str = unsafe { CStr::from_ptr(name) }; - let name_str = match c_str.to_str() { - Ok(s) => s.to_string(), - Err(_) => return ResultCode::Error, - }; - - let c_str = unsafe { CStr::from_ptr(sql) }; - let sql_str = match c_str.to_str() { - Ok(s) => s.to_string(), - Err(_) => return ResultCode::Error, - }; - - if ctx.is_null() { - return ResultCode::Error; - } - let db = unsafe { &mut *(ctx as *mut Database) }; - db.declare_vtab_impl(&name_str, &sql_str) + db.register_module_impl(&name_str, module, kind) } impl Database { @@ -113,32 +96,21 @@ impl Database { ResultCode::OK } - fn register_module_impl(&mut self, name: &str, module: VTabModuleImpl) -> ResultCode { - self.vtab_modules.insert(name.to_string(), Rc::new(module)); - ResultCode::OK - } - - fn declare_vtab_impl(&mut self, name: &str, sql: &str) -> ResultCode { - let mut parser = Parser::new(sql.as_bytes()); - let cmd = parser.next().unwrap().unwrap(); - let Cmd::Stmt(stmt) = cmd else { - return ResultCode::Error; + fn register_module_impl( + &mut self, + name: &str, + module: VTabModuleImpl, + kind: VTabKind, + ) -> ResultCode { + let module = Rc::new(module); + let vmodule = VTabImpl { + module_kind: kind, + implementation: module, }; - let Stmt::CreateTable { body, .. } = stmt else { - return ResultCode::Error; - }; - let Ok(columns) = columns_from_create_table_body(*body) else { - return ResultCode::Error; - }; - let vtab_module = self.vtab_modules.get(name).unwrap().clone(); - - let vtab = VirtualTable { - name: name.to_string(), - implementation: vtab_module, - columns, - args: None, - }; - self.syms.borrow_mut().vtabs.insert(name.to_string(), vtab); + self.syms + .borrow_mut() + .vtab_modules + .insert(name.to_string(), vmodule.into()); ResultCode::OK } @@ -148,7 +120,6 @@ impl Database { register_scalar_function, register_aggregate_function, register_module, - declare_vtab, } } diff --git a/core/lib.rs b/core/lib.rs index 5459a0b1f..55ea42e18 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -28,7 +28,7 @@ use fallible_iterator::FallibleIterator; use libloading::{Library, Symbol}; #[cfg(not(target_family = "wasm"))] use limbo_ext::{ExtensionApi, ExtensionEntryPoint}; -use limbo_ext::{ResultCode, VTabModuleImpl, Value as ExtValue}; +use limbo_ext::{ResultCode, VTabKind, VTabModuleImpl, Value as ExtValue}; use limbo_sqlite3_parser::{ast, ast::Cmd, lexer::sql::Parser}; use parking_lot::RwLock; use schema::{Column, Schema}; @@ -50,7 +50,7 @@ pub use storage::wal::WalFile; pub use storage::wal::WalFileShared; use types::OwnedValue; pub use types::Value; -use util::parse_schema_rows; +use util::{columns_from_create_table_body, parse_schema_rows}; use vdbe::builder::QueryMode; use vdbe::VTabOpaqueCursor; @@ -88,7 +88,6 @@ pub struct Database { schema: Rc>, header: Rc>, syms: Rc>, - vtab_modules: HashMap>, // Shared structures of a Database are the parts that are common to multiple threads that might // create DB connections. _shared_page_cache: Arc>, @@ -150,8 +149,7 @@ impl Database { header: header.clone(), _shared_page_cache: _shared_page_cache.clone(), _shared_wal: shared_wal.clone(), - syms, - vtab_modules: HashMap::new(), + syms: syms.clone(), }; if let Err(e) = db.register_builtins() { return Err(LimboError::ExtensionError(e)); @@ -170,7 +168,7 @@ impl Database { }); let rows = conn.query("SELECT * FROM sqlite_schema")?; let mut schema = schema.borrow_mut(); - parse_schema_rows(rows, &mut schema, io)?; + parse_schema_rows(rows, &mut schema, io, &syms.borrow())?; Ok(db) } @@ -277,10 +275,9 @@ impl Connection { pub fn prepare(self: &Rc, sql: impl AsRef) -> Result { let sql = sql.as_ref(); tracing::trace!("Preparing: {}", sql); - let db = &self.db; let mut parser = Parser::new(sql.as_bytes()); - let syms = &db.syms.borrow(); let cmd = parser.next()?; + let syms = self.db.syms.borrow(); if let Some(cmd) = cmd { match cmd { Cmd::Stmt(stmt) => { @@ -290,7 +287,7 @@ impl Connection { self.header.clone(), self.pager.clone(), Rc::downgrade(self), - syms, + &syms, QueryMode::Normal, )?); Ok(Statement::new(program, self.pager.clone())) @@ -316,7 +313,7 @@ impl Connection { pub(crate) fn run_cmd(self: &Rc, cmd: Cmd) -> Result> { let db = self.db.clone(); - let syms: &SymbolTable = &db.syms.borrow(); + let syms = db.syms.borrow(); match cmd { Cmd::Stmt(stmt) => { let program = Rc::new(translate::translate( @@ -325,7 +322,7 @@ impl Connection { self.header.clone(), self.pager.clone(), Rc::downgrade(self), - syms, + &syms, QueryMode::Normal, )?); let stmt = Statement::new(program, self.pager.clone()); @@ -338,7 +335,7 @@ impl Connection { self.header.clone(), self.pager.clone(), Rc::downgrade(self), - syms, + &syms, QueryMode::Explain, )?; program.explain(); @@ -347,12 +344,8 @@ impl Connection { Cmd::ExplainQueryPlan(stmt) => { match stmt { ast::Stmt::Select(select) => { - let mut plan = prepare_select_plan( - &self.schema.borrow(), - *select, - &self.db.syms.borrow(), - None, - )?; + let mut plan = + prepare_select_plan(&self.schema.borrow(), *select, &syms, None)?; optimize_plan(&mut plan, &self.schema.borrow())?; println!("{}", plan); } @@ -369,10 +362,9 @@ impl Connection { pub fn execute(self: &Rc, sql: impl AsRef) -> Result<()> { let sql = sql.as_ref(); - let db = &self.db; - let syms: &SymbolTable = &db.syms.borrow(); let mut parser = Parser::new(sql.as_bytes()); let cmd = parser.next()?; + let syms = self.db.syms.borrow(); if let Some(cmd) = cmd { match cmd { Cmd::Explain(stmt) => { @@ -382,7 +374,7 @@ impl Connection { self.header.clone(), self.pager.clone(), Rc::downgrade(self), - syms, + &syms, QueryMode::Explain, )?; program.explain(); @@ -395,7 +387,7 @@ impl Connection { self.header.clone(), self.pager.clone(), Rc::downgrade(self), - syms, + &syms, QueryMode::Normal, )?; @@ -531,8 +523,54 @@ pub struct VirtualTable { } impl VirtualTable { - pub fn open(&self) -> VTabOpaqueCursor { - let cursor = unsafe { (self.implementation.open)() }; + pub(crate) fn rowid(&self, cursor: &VTabOpaqueCursor) -> i64 { + unsafe { (self.implementation.rowid)(cursor.as_ptr()) } + } + /// takes ownership of the provided Args + pub(crate) fn from_args( + tbl_name: Option<&str>, + module_name: &str, + args: Vec, + syms: &SymbolTable, + kind: VTabKind, + exprs: Option>, + ) -> Result> { + let module = syms + .vtab_modules + .get(module_name) + .ok_or(LimboError::ExtensionError(format!( + "Virtual table module not found: {}", + module_name + )))?; + if let VTabKind::VirtualTable = kind { + if module.module_kind != VTabKind::VirtualTable { + return Err(LimboError::ExtensionError(format!( + "{} is not a virtual table module", + module_name + ))); + } + }; + let schema = module.implementation.as_ref().init_schema(args)?; + let mut parser = Parser::new(schema.as_bytes()); + if let ast::Cmd::Stmt(ast::Stmt::CreateTable { body, .. }) = parser.next()?.ok_or( + LimboError::ParseError("Failed to parse schema from virtual table module".to_string()), + )? { + let columns = columns_from_create_table_body(&body)?; + let vtab = Rc::new(VirtualTable { + name: tbl_name.unwrap_or(module_name).to_owned(), + implementation: module.implementation.clone(), + columns, + args: exprs, + }); + return Ok(vtab); + } + Err(crate::LimboError::ParseError( + "Failed to parse schema from virtual table module".to_string(), + )) + } + + pub fn open(&self) -> crate::Result { + let cursor = unsafe { (self.implementation.open)(self.implementation.ctx) }; VTabOpaqueCursor::new(cursor) } @@ -570,7 +608,7 @@ impl VirtualTable { pub fn column(&self, cursor: &VTabOpaqueCursor, column: usize) -> Result { let val = unsafe { (self.implementation.column)(cursor.as_ptr(), column as u32) }; - OwnedValue::from_ffi(&val) + OwnedValue::from_ffi(val) } pub fn next(&self, cursor: &VTabOpaqueCursor) -> Result { @@ -581,13 +619,39 @@ impl VirtualTable { _ => Err(LimboError::ExtensionError("Next failed".to_string())), } } + + pub fn update(&self, args: &[OwnedValue]) -> Result> { + let arg_count = args.len(); + let ext_args = args.iter().map(|arg| arg.to_ffi()).collect::>(); + let newrowid = 0i64; + let implementation = self.implementation.as_ref(); + let rc = unsafe { + (self.implementation.update)( + implementation as *const VTabModuleImpl as *const std::ffi::c_void, + arg_count as i32, + ext_args.as_ptr(), + &newrowid as *const _ as *mut i64, + ) + }; + for arg in ext_args { + unsafe { + arg.free(); + } + } + match rc { + ResultCode::OK => Ok(None), + ResultCode::RowID => Ok(Some(newrowid)), + _ => Err(LimboError::ExtensionError(rc.to_string())), + } + } } pub(crate) struct SymbolTable { pub functions: HashMap>, #[cfg(not(target_family = "wasm"))] extensions: Vec<(Library, *const ExtensionApi)>, - pub vtabs: HashMap, + pub vtabs: HashMap>, + pub vtab_modules: HashMap>, } impl std::fmt::Debug for SymbolTable { @@ -632,6 +696,7 @@ impl SymbolTable { vtabs: HashMap::new(), #[cfg(not(target_family = "wasm"))] extensions: Vec::new(), + vtab_modules: HashMap::new(), } } diff --git a/core/schema.rs b/core/schema.rs index 38bcf86a5..22130a825 100644 --- a/core/schema.rs +++ b/core/schema.rs @@ -12,29 +12,46 @@ use std::rc::Rc; use tracing::trace; pub struct Schema { - pub tables: HashMap>, + pub tables: HashMap>, // table_name to list of indexes for the table pub indexes: HashMap>>, } impl Schema { pub fn new() -> Self { - let mut tables: HashMap> = HashMap::new(); + let mut tables: HashMap> = HashMap::new(); let indexes: HashMap>> = HashMap::new(); - tables.insert("sqlite_schema".to_string(), Rc::new(sqlite_schema_table())); + tables.insert( + "sqlite_schema".to_string(), + Rc::new(Table::BTree(sqlite_schema_table().into())), + ); Self { tables, indexes } } - pub fn add_table(&mut self, table: Rc) { + pub fn add_btree_table(&mut self, table: Rc) { let name = normalize_ident(&table.name); - self.tables.insert(name, table); + self.tables.insert(name, Table::BTree(table).into()); } - pub fn get_table(&self, name: &str) -> Option> { + pub fn add_virtual_table(&mut self, table: Rc) { + let name = normalize_ident(&table.name); + self.tables.insert(name, Table::Virtual(table).into()); + } + + pub fn get_table(&self, name: &str) -> Option> { let name = normalize_ident(name); self.tables.get(&name).cloned() } + pub fn get_btree_table(&self, name: &str) -> Option> { + let name = normalize_ident(name); + if let Some(table) = self.tables.get(&name) { + table.btree() + } else { + None + } + } + pub fn add_index(&mut self, index: Rc) { let table_name = normalize_ident(&index.table_name); self.indexes diff --git a/core/translate/delete.rs b/core/translate/delete.rs index 6a55bfc03..1e0d64a98 100644 --- a/core/translate/delete.rs +++ b/core/translate/delete.rs @@ -25,7 +25,7 @@ pub fn translate_delete( let mut program = ProgramBuilder::new(ProgramBuilderOpts { query_mode, num_cursors: 1, - approx_num_insns: estimate_num_instructions(&delete), + approx_num_insns: estimate_num_instructions(delete), approx_num_labels: 0, }); emit_program(&mut program, delete_plan, syms)?; @@ -42,10 +42,17 @@ pub fn prepare_delete_plan( Some(table) => table, None => crate::bail_corrupt_error!("Parse error: no such table: {}", tbl_name), }; - + let table = if let Some(table) = table.virtual_table() { + Table::Virtual(table.clone()) + } else if let Some(table) = table.btree() { + Table::BTree(table.clone()) + } else { + crate::bail_corrupt_error!("Table is neither a virtual table nor a btree table"); + }; + let name = tbl_name.name.0.as_str().to_string(); let table_references = vec![TableReference { - table: Table::BTree(table.clone()), - identifier: table.name.clone(), + table, + identifier: name, op: Operation::Scan { iter_dir: None }, join_info: None, }]; diff --git a/core/translate/emitter.rs b/core/translate/emitter.rs index 8de249de3..1fe085c94 100644 --- a/core/translate/emitter.rs +++ b/core/translate/emitter.rs @@ -292,7 +292,7 @@ pub fn emit_query<'a>( fn emit_program_for_delete( program: &mut ProgramBuilder, - mut plan: DeletePlan, + plan: DeletePlan, syms: &SymbolTable, ) -> Result<()> { let (mut t_ctx, init_label, start_offset) = prologue( @@ -304,6 +304,7 @@ fn emit_program_for_delete( // No rows will be read from source table loops if there is a constant false condition eg. WHERE 0 let after_main_loop_label = program.allocate_label(); + t_ctx.label_main_loop_end = Some(after_main_loop_label); if plan.contains_constant_false_condition { program.emit_insn(Insn::Goto { target_pc: after_main_loop_label, @@ -322,10 +323,9 @@ fn emit_program_for_delete( open_loop( program, &mut t_ctx, - &mut plan.table_references, + &plan.table_references, &plan.where_clause, )?; - emit_delete_insns(program, &mut t_ctx, &plan.table_references, &plan.limit)?; // Clean up and close the main execution loop @@ -364,8 +364,27 @@ fn emit_delete_insns( cursor_id, dest: key_reg, }); - program.emit_insn(Insn::DeleteAsync { cursor_id }); - program.emit_insn(Insn::DeleteAwait { cursor_id }); + + if let Some(vtab) = table_reference.virtual_table() { + let conflict_action = 0u16; + let start_reg = key_reg; + + let new_rowid_reg = program.alloc_register(); + program.emit_insn(Insn::Null { + dest: new_rowid_reg, + dest_end: None, + }); + program.emit_insn(Insn::VUpdate { + cursor_id, + arg_count: 2, + start_reg, + vtab_ptr: vtab.implementation.as_ref().ctx as usize, + conflict_action, + }); + } else { + program.emit_insn(Insn::DeleteAsync { cursor_id }); + program.emit_insn(Insn::DeleteAwait { cursor_id }); + } if let Some(limit) = limit { let limit_reg = program.alloc_register(); program.emit_insn(Insn::Integer { diff --git a/core/translate/insert.rs b/core/translate/insert.rs index eb36b7e75..53368d30b 100644 --- a/core/translate/insert.rs +++ b/core/translate/insert.rs @@ -1,15 +1,15 @@ use std::ops::Deref; +use std::rc::Rc; use limbo_sqlite3_parser::ast::{ - DistinctNames, Expr, InsertBody, QualifiedName, ResolveType, ResultColumn, With, + DistinctNames, Expr, InsertBody, OneSelect, QualifiedName, ResolveType, ResultColumn, With, }; use crate::error::SQLITE_CONSTRAINT_PRIMARYKEY; -use crate::schema::BTreeTable; +use crate::schema::Table; use crate::util::normalize_ident; use crate::vdbe::builder::{ProgramBuilderOpts, QueryMode}; use crate::vdbe::BranchOffset; -use crate::Result; use crate::{ schema::{Column, Schema}, translate::expr::translate_expr, @@ -19,6 +19,7 @@ use crate::{ }, SymbolTable, }; +use crate::{Result, VirtualTable}; use super::emitter::Resolver; @@ -46,32 +47,45 @@ pub fn translate_insert( if on_conflict.is_some() { crate::bail_parse_error!("ON CONFLICT clause is not supported"); } + + let table_name = &tbl_name.name; + let table = match schema.get_table(table_name.0.as_str()) { + Some(table) => table, + None => crate::bail_corrupt_error!("Parse error: no such table: {}", table_name), + }; let resolver = Resolver::new(syms); + if let Some(virtual_table) = &table.virtual_table() { + translate_virtual_table_insert( + &mut program, + virtual_table.clone(), + columns, + body, + on_conflict, + &resolver, + )?; + return Ok(program); + } let init_label = program.allocate_label(); program.emit_insn(Insn::Init { target_pc: init_label, }); let start_offset = program.offset(); - // open table - let table_name = &tbl_name.name; - - let table = match schema.get_table(table_name.0.as_str()) { - Some(table) => table, - None => crate::bail_corrupt_error!("Parse error: no such table: {}", table_name), + let Some(btree_table) = table.btree() else { + crate::bail_corrupt_error!("Parse error: no such table: {}", table_name); }; - if !table.has_rowid { + if !btree_table.has_rowid { crate::bail_parse_error!("INSERT into WITHOUT ROWID table is not supported"); } let cursor_id = program.alloc_cursor_id( Some(table_name.0.clone()), - CursorType::BTreeTable(table.clone()), + CursorType::BTreeTable(btree_table.clone()), ); - let root_page = table.root_page; + let root_page = btree_table.root_page; let values = match body { InsertBody::Select(select, None) => match &select.body.select.deref() { - limbo_sqlite3_parser::ast::OneSelect::Values(values) => values, + OneSelect::Values(values) => values, _ => todo!(), }, _ => todo!(), @@ -79,9 +93,9 @@ pub fn translate_insert( let column_mappings = resolve_columns_for_insert(&table, columns, values)?; // Check if rowid was provided (through INTEGER PRIMARY KEY as a rowid alias) - let rowid_alias_index = table.columns.iter().position(|c| c.is_rowid_alias); + let rowid_alias_index = btree_table.columns.iter().position(|c| c.is_rowid_alias); let has_user_provided_rowid = { - assert_eq!(column_mappings.len(), table.columns.len()); + assert_eq!(column_mappings.len(), btree_table.columns.len()); if let Some(index) = rowid_alias_index { column_mappings[index].value_index.is_some() } else { @@ -91,7 +105,7 @@ pub fn translate_insert( // allocate a register for each column in the table. if not provided by user, they will simply be set as null. // allocate an extra register for rowid regardless of whether user provided a rowid alias column. - let num_cols = table.columns.len(); + let num_cols = btree_table.columns.len(); let rowid_reg = program.alloc_registers(num_cols + 1); let column_registers_start = rowid_reg + 1; let rowid_alias_reg = { @@ -217,7 +231,7 @@ pub fn translate_insert( target_pc: make_record_label, }); let rowid_column_name = if let Some(index) = rowid_alias_index { - &table + btree_table .columns .get(index) .unwrap() @@ -302,7 +316,7 @@ struct ColumnMapping<'a> { /// - Named columns map to their corresponding value index /// - Unspecified columns map to None fn resolve_columns_for_insert<'a>( - table: &'a BTreeTable, + table: &'a Table, columns: &Option, values: &[Vec], ) -> Result>> { @@ -310,7 +324,7 @@ fn resolve_columns_for_insert<'a>( crate::bail_parse_error!("no values to insert"); } - let table_columns = &table.columns; + let table_columns = &table.columns(); // Case 1: No columns specified - map values to columns in order if columns.is_none() { @@ -318,7 +332,7 @@ fn resolve_columns_for_insert<'a>( if num_values > table_columns.len() { crate::bail_parse_error!( "table {} has {} columns but {} values were supplied", - &table.name, + &table.get_name(), table_columns.len(), num_values ); @@ -361,7 +375,11 @@ fn resolve_columns_for_insert<'a>( }); if table_index.is_none() { - crate::bail_parse_error!("table {} has no column named {}", &table.name, column_name); + crate::bail_parse_error!( + "table {} has no column named {}", + &table.get_name(), + column_name + ); } mappings[table_index.unwrap()].value_index = Some(value_index); @@ -425,3 +443,100 @@ fn populate_column_registers( } Ok(()) } + +fn translate_virtual_table_insert( + program: &mut ProgramBuilder, + virtual_table: Rc, + columns: &Option, + body: &InsertBody, + on_conflict: &Option, + resolver: &Resolver, +) -> Result<()> { + let init_label = program.allocate_label(); + program.emit_insn(Insn::Init { + target_pc: init_label, + }); + let start_offset = program.offset(); + + let values = match body { + InsertBody::Select(select, None) => match &select.body.select.deref() { + OneSelect::Values(values) => values, + _ => crate::bail_parse_error!("Virtual tables only support VALUES clause in INSERT"), + }, + InsertBody::DefaultValues => &vec![], + _ => crate::bail_parse_error!("Unsupported INSERT body for virtual tables"), + }; + + let table = Table::Virtual(virtual_table.clone()); + let column_mappings = resolve_columns_for_insert(&table, columns, values)?; + + let value_registers_start = program.alloc_registers(values[0].len()); + for (i, expr) in values[0].iter().enumerate() { + translate_expr(program, None, expr, value_registers_start + i, resolver)?; + } + /* * + * Inserts for virtual tables are done in a single step. + * argv[0] = (NULL for insert) + * argv[1] = (NULL for insert) + * argv[2..] = column values + * */ + + let rowid_reg = program.alloc_registers(column_mappings.len() + 3); + let insert_rowid_reg = rowid_reg + 1; // argv[1] = insert_rowid + let data_start_reg = rowid_reg + 2; // argv[2..] = column values + + program.emit_insn(Insn::Null { + dest: rowid_reg, + dest_end: None, + }); + program.emit_insn(Insn::Null { + dest: insert_rowid_reg, + dest_end: None, + }); + + for (i, mapping) in column_mappings.iter().enumerate() { + let target_reg = data_start_reg + i; + if let Some(value_index) = mapping.value_index { + program.emit_insn(Insn::Copy { + src_reg: value_registers_start + value_index, + dst_reg: target_reg, + amount: 1, + }); + } else { + program.emit_insn(Insn::Null { + dest: target_reg, + dest_end: None, + }); + } + } + + let conflict_action = on_conflict.as_ref().map(|c| c.bit_value()).unwrap_or(0) as u16; + + let cursor_id = program.alloc_cursor_id( + Some(virtual_table.name.clone()), + CursorType::VirtualTable(virtual_table.clone()), + ); + + program.emit_insn(Insn::VUpdate { + cursor_id, + arg_count: column_mappings.len() + 2, + start_reg: rowid_reg, + vtab_ptr: virtual_table.implementation.as_ref().ctx as usize, + conflict_action, + }); + + let halt_label = program.allocate_label(); + program.emit_insn(Insn::Halt { + err_code: 0, + description: String::new(), + }); + + program.resolve_label(halt_label, program.offset()); + program.resolve_label(init_label, program.offset()); + + program.emit_insn(Insn::Goto { + target_pc: start_offset, + }); + + Ok(()) +} diff --git a/core/translate/main_loop.rs b/core/translate/main_loop.rs index c2e18597c..b70541c77 100644 --- a/core/translate/main_loop.rs +++ b/core/translate/main_loop.rs @@ -110,6 +110,10 @@ pub fn init_loop( program.emit_insn(Insn::VOpenAsync { cursor_id }); program.emit_insn(Insn::VOpenAwait {}); } + (OperationMode::DELETE, Table::Virtual(_)) => { + program.emit_insn(Insn::VOpenAsync { cursor_id }); + program.emit_insn(Insn::VOpenAwait {}); + } _ => { unimplemented!() } @@ -286,22 +290,23 @@ pub fn open_loop( }, ), Table::Virtual(ref table) => { - let args = if let Some(args) = table.args.as_ref() { - args - } else { - &vec![] - }; - let start_reg = program.alloc_registers(args.len()); + let start_reg = program + .alloc_registers(table.args.as_ref().map(|a| a.len()).unwrap_or(0)); let mut cur_reg = start_reg; + let args = match table.args.as_ref() { + Some(args) => args, + None => &vec![], + }; for arg in args { let reg = cur_reg; cur_reg += 1; - translate_expr(program, Some(tables), &arg, reg, &t_ctx.resolver)?; + let _ = + translate_expr(program, Some(tables), arg, reg, &t_ctx.resolver)?; } program.emit_insn(Insn::VFilter { cursor_id, pc_if_empty: loop_end, - arg_count: args.len(), + arg_count: table.args.as_ref().map_or(0, |args| args.len()), args_reg: start_reg, }); } @@ -675,9 +680,9 @@ fn emit_loop_source( ); let offset_jump_to = t_ctx .labels_main_loop - .get(0) + .first() .map(|l| l.next) - .or_else(|| t_ctx.label_main_loop_end); + .or(t_ctx.label_main_loop_end); emit_select_result( program, t_ctx, diff --git a/core/translate/mod.rs b/core/translate/mod.rs index 3ee8f4ce0..fe49d05ce 100644 --- a/core/translate/mod.rs +++ b/core/translate/mod.rs @@ -33,8 +33,7 @@ use crate::vdbe::builder::{CursorType, ProgramBuilderOpts, QueryMode}; use crate::vdbe::{builder::ProgramBuilder, insn::Insn, Program}; use crate::{bail_parse_error, Connection, LimboError, Result, SymbolTable}; use insert::translate_insert; -use limbo_sqlite3_parser::ast::{self, fmt::ToTokens}; -use limbo_sqlite3_parser::ast::{Delete, Insert}; +use limbo_sqlite3_parser::ast::{self, fmt::ToTokens, CreateVirtualTable, Delete, Insert}; use select::translate_select; use std::cell::RefCell; use std::fmt::Display; @@ -74,8 +73,8 @@ pub fn translate( } ast::Stmt::CreateTrigger { .. } => bail_parse_error!("CREATE TRIGGER not supported yet"), ast::Stmt::CreateView { .. } => bail_parse_error!("CREATE VIEW not supported yet"), - ast::Stmt::CreateVirtualTable { .. } => { - bail_parse_error!("CREATE VIRTUAL TABLE not supported yet") + ast::Stmt::CreateVirtualTable(vtab) => { + translate_create_virtual_table(*vtab, schema, query_mode)? } ast::Stmt::Delete(delete) => { let Delete { @@ -94,7 +93,7 @@ pub fn translate( ast::Stmt::DropView { .. } => bail_parse_error!("DROP VIEW not supported yet"), ast::Stmt::Pragma(name, body) => pragma::translate_pragma( query_mode, - &schema, + schema, &name, body.map(|b| *b), database_header.clone(), @@ -187,6 +186,7 @@ impl SchemaEntryType { } } } +const SQLITE_TABLEID: &str = "sqlite_schema"; fn emit_schema_entry( program: &mut ProgramBuilder, @@ -209,11 +209,18 @@ fn emit_schema_entry( program.emit_string8_new_reg(tbl_name.to_string()); let rootpage_reg = program.alloc_register(); - program.emit_insn(Insn::Copy { - src_reg: root_page_reg, - dst_reg: rootpage_reg, - amount: 1, - }); + if root_page_reg == 0 { + program.emit_insn(Insn::Integer { + dest: rootpage_reg, + value: 0, // virtual tables in sqlite always have rootpage=0 + }); + } else { + program.emit_insn(Insn::Copy { + src_reg: root_page_reg, + dst_reg: rootpage_reg, + amount: 1, + }); + } let sql_reg = program.alloc_register(); if let Some(sql) = sql { @@ -455,10 +462,9 @@ fn translate_create_table( }); } - let table_id = "sqlite_schema".to_string(); - let table = schema.get_table(&table_id).unwrap(); + let table = schema.get_btree_table(SQLITE_TABLEID).unwrap(); let sqlite_schema_cursor_id = program.alloc_cursor_id( - Some(table_id.to_owned()), + Some(SQLITE_TABLEID.to_owned()), CursorType::BTreeTable(table.clone()), ); program.emit_insn(Insn::OpenWriteAsync { @@ -546,3 +552,136 @@ fn create_table_body_to_str(tbl_name: &ast::QualifiedName, body: &ast::CreateTab } sql } + +fn create_vtable_body_to_str(vtab: &CreateVirtualTable) -> String { + let args = if let Some(args) = &vtab.args { + args.iter() + .map(|arg| arg.to_string()) + .collect::>() + .join(", ") + } else { + "".to_string() + }; + let if_not_exists = if vtab.if_not_exists { + "IF NOT EXISTS " + } else { + "" + }; + format!( + "CREATE VIRTUAL TABLE {} {} USING {}{}", + vtab.tbl_name.name.0, + if_not_exists, + vtab.module_name.0, + if args.is_empty() { + String::new() + } else { + format!("({})", args) + } + ) +} + +fn translate_create_virtual_table( + vtab: CreateVirtualTable, + schema: &Schema, + query_mode: QueryMode, +) -> Result { + let ast::CreateVirtualTable { + if_not_exists, + tbl_name, + module_name, + args, + } = &vtab; + + let table_name = tbl_name.name.0.clone(); + let module_name_str = module_name.0.clone(); + let args_vec = args.clone().unwrap_or_default(); + + if schema.get_table(&table_name).is_some() && *if_not_exists { + let mut program = ProgramBuilder::new(ProgramBuilderOpts { + query_mode, + num_cursors: 1, + approx_num_insns: 5, + approx_num_labels: 1, + }); + let init_label = program.emit_init(); + program.emit_halt(); + program.resolve_label(init_label, program.offset()); + program.emit_transaction(true); + program.emit_constant_insns(); + return Ok(program); + } + + let mut program = ProgramBuilder::new(ProgramBuilderOpts { + query_mode, + num_cursors: 2, + approx_num_insns: 40, + approx_num_labels: 2, + }); + + let module_name_reg = program.emit_string8_new_reg(module_name_str.clone()); + let table_name_reg = program.emit_string8_new_reg(table_name.clone()); + + let args_reg = if !args_vec.is_empty() { + let args_start = program.alloc_register(); + + // Emit string8 instructions for each arg + for (i, arg) in args_vec.iter().enumerate() { + program.emit_string8(arg.clone(), args_start + i); + } + let args_record_reg = program.alloc_register(); + + // VCreate expects an array of args as a record + program.emit_insn(Insn::MakeRecord { + start_reg: args_start, + count: args_vec.len(), + dest_reg: args_record_reg, + }); + Some(args_record_reg) + } else { + None + }; + + program.emit_insn(Insn::VCreate { + module_name: module_name_reg, + table_name: table_name_reg, + args_reg, + }); + + let table = schema.get_btree_table(SQLITE_TABLEID).unwrap(); + let sqlite_schema_cursor_id = program.alloc_cursor_id( + Some(SQLITE_TABLEID.to_owned()), + CursorType::BTreeTable(table.clone()), + ); + program.emit_insn(Insn::OpenWriteAsync { + cursor_id: sqlite_schema_cursor_id, + root_page: 1, + }); + program.emit_insn(Insn::OpenWriteAwait {}); + + let sql = create_vtable_body_to_str(&vtab); + emit_schema_entry( + &mut program, + sqlite_schema_cursor_id, + SchemaEntryType::Table, + &tbl_name.name.0, + &tbl_name.name.0, + 0, // virtual tables dont have a root page + Some(sql), + ); + + let parse_schema_where_clause = format!("tbl_name = '{}' AND type != 'trigger'", table_name); + program.emit_insn(Insn::ParseSchema { + db: sqlite_schema_cursor_id, + where_clause: parse_schema_where_clause, + }); + + let init_label = program.emit_init(); + let start_offset = program.offset(); + program.emit_halt(); + program.resolve_label(init_label, program.offset()); + program.emit_transaction(true); + program.emit_constant_insns(); + program.emit_goto(start_offset); + + Ok(program) +} diff --git a/core/translate/planner.rs b/core/translate/planner.rs index ffc7ee992..953a15e59 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -9,9 +9,9 @@ use super::{ use crate::{ function::Func, schema::{Schema, Table}, - util::{exprs_are_equivalent, normalize_ident}, + util::{exprs_are_equivalent, normalize_ident, vtable_args}, vdbe::BranchOffset, - Result, VirtualTable, + Result, }; use limbo_sqlite3_parser::ast::{ self, Expr, FromClause, JoinType, Limit, Materialized, UnaryOperator, With, @@ -310,9 +310,18 @@ fn parse_from_clause_table<'a>( ast::As::Elided(id) => id, }) .map(|a| a.0); + let tbl_ref = if let Table::Virtual(tbl) = table.as_ref() { + Table::Virtual(tbl.clone()) + } else if let Table::BTree(table) = table.as_ref() { + Table::BTree(table.clone()) + } else { + return Err(crate::LimboError::InvalidArgument( + "Table type not supported".to_string(), + )); + }; scope.tables.push(TableReference { op: Operation::Scan { iter_dir: None }, - table: Table::BTree(table.clone()), + table: tbl_ref, identifier: alias.unwrap_or(normalized_qualified_name), join_info: None, }); @@ -369,9 +378,18 @@ fn parse_from_clause_table<'a>( } ast::SelectTable::TableCall(qualified_name, maybe_args, maybe_alias) => { let normalized_name = &normalize_ident(qualified_name.name.0.as_str()); - let Some(vtab) = syms.vtabs.get(normalized_name) else { - crate::bail_parse_error!("Virtual table {} not found", normalized_name); + let args = match maybe_args { + Some(ref args) => vtable_args(args), + None => vec![], }; + let vtab = crate::VirtualTable::from_args( + None, + normalized_name, + args, + syms, + limbo_ext::VTabKind::TableValuedFunction, + maybe_args, + )?; let alias = maybe_alias .as_ref() .map(|a| match a { @@ -383,18 +401,10 @@ fn parse_from_clause_table<'a>( scope.tables.push(TableReference { op: Operation::Scan { iter_dir: None }, join_info: None, - table: Table::Virtual( - VirtualTable { - name: normalized_name.clone(), - args: maybe_args, - implementation: vtab.implementation.clone(), - columns: vtab.columns.clone(), - } - .into(), - ) - .into(), - identifier: alias.clone(), + table: Table::Virtual(vtab), + identifier: alias, }); + Ok(()) } _ => todo!(), diff --git a/core/translate/pragma.rs b/core/translate/pragma.rs index 1fd738acd..b33f38011 100644 --- a/core/translate/pragma.rs +++ b/core/translate/pragma.rs @@ -218,7 +218,7 @@ fn query_pragma( program.alloc_register(); program.alloc_register(); if let Some(table) = table { - for (i, column) in table.columns.iter().enumerate() { + for (i, column) in table.columns().iter().enumerate() { // cid program.emit_int(i as i64, base_reg); // name diff --git a/core/types.rs b/core/types.rs index 7eee76e94..f1dafd31e 100644 --- a/core/types.rs +++ b/core/types.rs @@ -223,8 +223,8 @@ impl OwnedValue { } } - pub fn from_ffi(v: &ExtValue) -> Result { - match v.value_type() { + pub fn from_ffi(v: ExtValue) -> Result { + let res = match v.value_type() { ExtValueType::Null => Ok(OwnedValue::Null), ExtValueType::Integer => { let Some(int) = v.to_integer() else { @@ -259,7 +259,11 @@ impl OwnedValue { (code, None) => Err(LimboError::ExtensionError(code.to_string())), } } + }; + unsafe { + v.free(); } + res } } @@ -281,8 +285,7 @@ impl AggContext { if let Self::External(ext_state) = self { if ext_state.finalized_value.is_none() { let final_value = unsafe { (ext_state.finalize_fn)(ext_state.state) }; - ext_state.cache_final_value(OwnedValue::from_ffi(&final_value)?); - unsafe { final_value.free() }; + ext_state.cache_final_value(OwnedValue::from_ffi(final_value)?); } } Ok(()) diff --git a/core/util.rs b/core/util.rs index 7bc73a42f..3ed12c6d6 100644 --- a/core/util.rs +++ b/core/util.rs @@ -5,7 +5,7 @@ use std::{rc::Rc, sync::Arc}; use crate::{ schema::{self, Column, Schema, Type}, types::OwnedValue, - LimboError, OpenFlags, Result, Statement, StepResult, IO, + LimboError, OpenFlags, Result, Statement, StepResult, SymbolTable, IO, }; // https://sqlite.org/lang_keywords.html @@ -30,6 +30,7 @@ pub fn parse_schema_rows( rows: Option, schema: &mut Schema, io: Arc, + syms: &SymbolTable, ) -> Result<()> { if let Some(mut rows) = rows { let mut automatic_indexes = Vec::new(); @@ -38,15 +39,21 @@ pub fn parse_schema_rows( StepResult::Row => { let row = rows.row().unwrap(); let ty = row.get::<&str>(0)?; - if ty != "table" && ty != "index" { + if !["table", "index"].contains(&ty) { continue; } match ty { "table" => { let root_page: i64 = row.get::(3)?; let sql: &str = row.get::<&str>(4)?; - let table = schema::BTreeTable::from_sql(sql, root_page as usize)?; - schema.add_table(Rc::new(table)); + if root_page == 0 && sql.to_lowercase().contains("virtual") { + let name: &str = row.get::<&str>(1)?; + let vtab = syms.vtabs.get(name).unwrap().clone(); + schema.add_virtual_table(vtab); + } else { + let table = schema::BTreeTable::from_sql(sql, root_page as usize)?; + schema.add_btree_table(Rc::new(table)); + } } "index" => { let root_page: i64 = row.get::(3)?; @@ -85,7 +92,7 @@ pub fn parse_schema_rows( } for (index_name, table_name, root_page) in automatic_indexes { // We need to process these after all tables are loaded into memory due to the schema.get_table() call - let table = schema.get_table(&table_name).unwrap(); + let table = schema.get_btree_table(&table_name).unwrap(); let index = schema::Index::automatic_from_primary_key(&table, &index_name, root_page as usize)?; schema.add_index(Rc::new(index)); @@ -309,9 +316,11 @@ pub fn exprs_are_equivalent(expr1: &Expr, expr2: &Expr) -> bool { } } -pub fn columns_from_create_table_body(body: ast::CreateTableBody) -> Result, ()> { +pub fn columns_from_create_table_body(body: &ast::CreateTableBody) -> crate::Result> { let CreateTableBody::ColumnsAndConstraints { columns, .. } = body else { - return Err(()); + return Err(crate::LimboError::ParseError( + "CREATE TABLE body must contain columns and constraints".to_string(), + )); }; Ok(columns @@ -324,7 +333,7 @@ pub fn columns_from_create_table_body(body: ast::CreateTableBody) -> Result { // https://www.sqlite.org/datatype3.html @@ -782,6 +791,35 @@ pub fn text_to_real(text: &str) -> (OwnedValue, CastTextToRealResultCode) { return (OwnedValue::Float(0.0), CastTextToRealResultCode::NotValid); } +// for TVF's we need these at planning time so we cannot emit translate_expr +pub fn vtable_args(args: &[ast::Expr]) -> Vec { + let mut vtable_args = Vec::new(); + for arg in args { + match arg { + Expr::Literal(lit) => match lit { + Literal::Numeric(i) => { + if i.contains('.') { + vtable_args.push(limbo_ext::Value::from_float(i.parse().unwrap())); + } else { + vtable_args.push(limbo_ext::Value::from_integer(i.parse().unwrap())); + } + } + Literal::String(s) => { + vtable_args.push(limbo_ext::Value::from_text(s.clone())); + } + Literal::Blob(b) => { + vtable_args.push(limbo_ext::Value::from_blob(b.as_bytes().into())); + } + _ => { + vtable_args.push(limbo_ext::Value::null()); + } + }, + _ => vtable_args.push(limbo_ext::Value::null()), + } + } + vtable_args +} + #[cfg(test)] pub mod tests { use super::*; diff --git a/core/vdbe/builder.rs b/core/vdbe/builder.rs index 716610b71..470997158 100644 --- a/core/vdbe/builder.rs +++ b/core/vdbe/builder.rs @@ -50,7 +50,7 @@ impl CursorType { } } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Copy)] pub enum QueryMode { Normal, Explain, diff --git a/core/vdbe/explain.rs b/core/vdbe/explain.rs index 2b61310d5..a609fc667 100644 --- a/core/vdbe/explain.rs +++ b/core/vdbe/explain.rs @@ -381,6 +381,19 @@ pub fn insn_to_str( 0, "".to_string(), ), + Insn::VCreate { + table_name, + module_name, + args_reg, + } => ( + "VCreate", + *table_name as i32, + *module_name as i32, + args_reg.unwrap_or(0) as i32, + OwnedValue::build_text(""), + 0, + format!("table={}, module={}", table_name, module_name), + ), Insn::VFilter { cursor_id, pc_if_empty, @@ -408,6 +421,21 @@ pub fn insn_to_str( 0, "".to_string(), ), + Insn::VUpdate { + cursor_id, + arg_count, // P2: Number of arguments in argv[] + start_reg, // P3: Start register for argv[] + vtab_ptr, // P4: vtab pointer + conflict_action, // P5: Conflict resolution flags + } => ( + "VUpdate", + *cursor_id as i32, + *arg_count as i32, + *start_reg as i32, + OwnedValue::build_text(&format!("vtab:{}", vtab_ptr)), + *conflict_action, + format!("args=r[{}..{}]", start_reg, start_reg + arg_count - 1), + ), Insn::VNext { cursor_id, pc_if_next, diff --git a/core/vdbe/insn.rs b/core/vdbe/insn.rs index 01b6b5469..44bc5e52f 100644 --- a/core/vdbe/insn.rs +++ b/core/vdbe/insn.rs @@ -220,6 +220,13 @@ pub enum Insn { /// Await for the completion of open cursor for a virtual table. VOpenAwait, + /// Create a new virtual table. + VCreate { + module_name: usize, // P1: Name of the module that contains the virtual table implementation + table_name: usize, // P2: Name of the virtual table + args_reg: Option, + }, + /// Initialize the position of the virtual table cursor. VFilter { cursor_id: CursorID, @@ -235,6 +242,15 @@ pub enum Insn { dest: usize, }, + /// `VUpdate`: Virtual Table Insert/Update/Delete Instruction + VUpdate { + cursor_id: usize, // P1: Virtual table cursor number + arg_count: usize, // P2: Number of arguments in argv[] + start_reg: usize, // P3: Start register for argv[] + vtab_ptr: usize, // P4: vtab pointer + conflict_action: u16, // P5: Conflict resolution flags + }, + /// Advance the virtual table cursor to the next row. /// TODO: async VNext { diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index d11acbf32..8939025e9 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -173,13 +173,11 @@ macro_rules! call_external_function { ) => {{ if $arg_count == 0 { let result_c_value: ExtValue = unsafe { ($func_ptr)(0, std::ptr::null()) }; - match OwnedValue::from_ffi(&result_c_value) { + match OwnedValue::from_ffi(result_c_value) { Ok(result_ov) => { $state.registers[$dest_register] = result_ov; - unsafe { result_c_value.free() }; } Err(e) => { - unsafe { result_c_value.free() }; return Err(e); } } @@ -192,13 +190,14 @@ macro_rules! call_external_function { } let argv_ptr = ext_values.as_ptr(); let result_c_value: ExtValue = unsafe { ($func_ptr)($arg_count as i32, argv_ptr) }; - match OwnedValue::from_ffi(&result_c_value) { + for arg in ext_values { + unsafe { arg.free() }; + } + match OwnedValue::from_ffi(result_c_value) { Ok(result_ov) => { $state.registers[$dest_register] = result_ov; - unsafe { result_c_value.free() }; } Err(e) => { - unsafe { result_c_value.free() }; return Err(e); } } @@ -308,14 +307,19 @@ impl Bitfield { } } -pub struct VTabOpaqueCursor(*mut c_void); +pub struct VTabOpaqueCursor(*const c_void); impl VTabOpaqueCursor { - pub fn new(cursor: *mut c_void) -> Self { - Self(cursor) + pub fn new(cursor: *const c_void) -> Result { + if cursor.is_null() { + return Err(LimboError::InternalError( + "VTabOpaqueCursor: cursor is null".into(), + )); + } + Ok(Self(cursor)) } - pub fn as_ptr(&self) -> *mut c_void { + pub fn as_ptr(&self) -> *const c_void { self.0 } } @@ -870,13 +874,54 @@ impl Program { let CursorType::VirtualTable(virtual_table) = cursor_type else { panic!("VOpenAsync on non-virtual table cursor"); }; - let cursor = virtual_table.open(); + let cursor = virtual_table.open()?; state .cursors .borrow_mut() .insert(*cursor_id, Some(Cursor::Virtual(cursor))); state.pc += 1; } + Insn::VCreate { + module_name, + table_name, + args_reg, + } => { + let module_name = state.registers[*module_name].to_string(); + let table_name = state.registers[*table_name].to_string(); + let args = if let Some(args_reg) = args_reg { + if let OwnedValue::Record(rec) = &state.registers[*args_reg] { + rec.get_values().iter().map(|v| v.to_ffi()).collect() + } else { + return Err(LimboError::InternalError( + "VCreate: args_reg is not a record".to_string(), + )); + } + } else { + vec![] + }; + let Some(conn) = self.connection.upgrade() else { + return Err(crate::LimboError::ExtensionError( + "Failed to upgrade Connection".to_string(), + )); + }; + let table = crate::VirtualTable::from_args( + Some(&table_name), + &module_name, + args, + &conn.db.syms.borrow(), + limbo_ext::VTabKind::VirtualTable, + None, + )?; + { + conn.db + .syms + .as_ref() + .borrow_mut() + .vtabs + .insert(table_name, table.clone()); + } + state.pc += 1; + } Insn::VOpenAwait => { state.pc += 1; } @@ -917,6 +962,59 @@ impl Program { state.registers[*dest] = virtual_table.column(cursor, *column)?; state.pc += 1; } + Insn::VUpdate { + cursor_id, + arg_count, + start_reg, + conflict_action, + .. + } => { + let (_, cursor_type) = self.cursor_ref.get(*cursor_id).unwrap(); + let CursorType::VirtualTable(virtual_table) = cursor_type else { + panic!("VUpdate on non-virtual table cursor"); + }; + + if *arg_count < 2 { + return Err(LimboError::InternalError( + "VUpdate: arg_count must be at least 2 (rowid and insert_rowid)" + .to_string(), + )); + } + let mut argv = Vec::with_capacity(*arg_count); + for i in 0..*arg_count { + if let Some(value) = state.registers.get(*start_reg + i) { + argv.push(value.clone()); + } else { + return Err(LimboError::InternalError(format!( + "VUpdate: register out of bounds at {}", + *start_reg + i + ))); + } + } + let result = virtual_table.update(&argv); + match result { + Ok(Some(new_rowid)) => { + if *conflict_action == 5 { + // ResolveType::Replace + if let Some(conn) = self.connection.upgrade() { + conn.update_last_rowid(new_rowid as u64); + } + } + state.pc += 1; + } + Ok(None) => { + // no-op or successful update without rowid return + state.pc += 1; + } + Err(e) => { + // virtual table update failed + return Err(LimboError::ExtensionError(format!( + "Virtual table update failed: {}", + e + ))); + } + } + } Insn::VNext { cursor_id, pc_if_next, @@ -1257,11 +1355,30 @@ impl Program { } } - let cursor = get_cursor_as_table_mut(&mut cursors, *cursor_id); - if let Some(ref rowid) = cursor.rowid()? { - state.registers[*dest] = OwnedValue::Integer(*rowid as i64); + if let Some(Cursor::Table(btree_cursor)) = cursors.get_mut(*cursor_id).unwrap() + { + if let Some(ref rowid) = btree_cursor.rowid()? { + state.registers[*dest] = OwnedValue::Integer(*rowid as i64); + } else { + state.registers[*dest] = OwnedValue::Null; + } + } else if let Some(Cursor::Virtual(virtual_cursor)) = + cursors.get_mut(*cursor_id).unwrap() + { + let (_, cursor_type) = self.cursor_ref.get(*cursor_id).unwrap(); + let CursorType::VirtualTable(virtual_table) = cursor_type else { + panic!("VUpdate on non-virtual table cursor"); + }; + let rowid = virtual_table.rowid(virtual_cursor); + if rowid != 0 { + state.registers[*dest] = OwnedValue::Integer(rowid); + } else { + state.registers[*dest] = OwnedValue::Null; + } } else { - state.registers[*dest] = OwnedValue::Null; + return Err(LimboError::InternalError( + "RowId: cursor is not a table or virtual cursor".to_string(), + )); } state.pc += 1; } @@ -1744,6 +1861,9 @@ impl Program { } let argv_ptr = ext_values.as_ptr(); unsafe { step_fn(state_ptr, argc as i32, argv_ptr) }; + for ext_value in ext_values { + unsafe { ext_value.free() }; + } } } }; @@ -2744,8 +2864,13 @@ impl Program { where_clause ))?; let mut schema = RefCell::borrow_mut(&conn.schema); - // TODO: This function below is synchronous, make it not async - parse_schema_rows(Some(stmt), &mut schema, conn.pager.io.clone())?; + // TODO: This function below is synchronous, make it async + parse_schema_rows( + Some(stmt), + &mut schema, + conn.pager.io.clone(), + &conn.db.syms.borrow(), + )?; state.pc += 1; } Insn::ReadCookie { db, dest, cookie } => { diff --git a/extensions/core/README.md b/extensions/core/README.md index f2aedc960..226ed09e0 100644 --- a/extensions/core/README.md +++ b/extensions/core/README.md @@ -95,6 +95,9 @@ impl AggFunc for Percentile { /// The state to track during the steps type State = (Vec, Option, Option); // Tracks the values, Percentile, and errors + /// Define your error type, must impl Display + type Error = String; + /// Define the name you wish to call your function by. /// e.g. SELECT percentile(value, 40); const NAME: &str = "percentile"; @@ -129,15 +132,15 @@ impl AggFunc for Percentile { } /// A function to finalize the state into a value to be returned as a result /// or an error (if you chose to track an error state as well) - fn finalize(state: Self::State) -> Value { + fn finalize(state: Self::State) -> Result { let (mut values, p_value, error) = state; if let Some(error) = error { - return Value::custom_error(error); + return Err(error); } if values.is_empty() { - return Value::null(); + return Ok(Value::null()); } values.sort_by(|a, b| a.partial_cmp(b).unwrap()); @@ -145,7 +148,7 @@ impl AggFunc for Percentile { let p = p_value.unwrap(); let index = (p * (n - 1.0) / 100.0).floor() as usize; - Value::from_float(values[index]) + Ok(Value::from_float(values[index])) } } ``` @@ -161,21 +164,25 @@ struct CsvVTable; impl VTabModule for CsvVTable { type VCursor = CsvCursor; + /// Define your error type. Must impl Display and match VCursor::Error + type Error = &'static str; /// Declare the name for your virtual table const NAME: &'static str = "csv_data"; - /// Declare the table schema and call `api.declare_virtual_table` with the schema sql. - fn connect(api: &ExtensionApi) -> ResultCode { - let sql = "CREATE TABLE csv_data( + /// Declare the type of vtable (TableValuedFunction or VirtualTable) + const VTAB_KIND: VTabKind = VTabKind::VirtualTable; + + /// Function to initialize the schema of your vtable + fn create_schema(_args: &[Value]) -> &'static str { + "CREATE TABLE csv_data( name TEXT, age TEXT, city TEXT - )"; - api.declare_virtual_table(Self::NAME, sql) + )" } /// Open to return a new cursor: In this simple example, the CSV file is read completely into memory on connect. - fn open() -> Self::VCursor { + fn open(&self) -> Result { // Read CSV file contents from "data.csv" let csv_content = fs::read_to_string("data.csv").unwrap_or_default(); // For simplicity, we'll ignore the header row. @@ -188,16 +195,16 @@ impl VTabModule for CsvVTable { .collect() }) .collect(); - CsvCursor { rows, index: 0 } + Ok(CsvCursor { rows, index: 0 }) } /// Filter through result columns. (not used in this simple example) - fn filter(_cursor: &mut Self::VCursor, _arg_count: i32, _args: &[Value]) -> ResultCode { + fn filter(_cursor: &mut Self::VCursor, _args: &[Value]) -> ResultCode { ResultCode::OK } /// Return the value for the column at the given index in the current row. - fn column(cursor: &Self::VCursor, idx: u32) -> Value { + fn column(cursor: &Self::VCursor, idx: u32) -> Result { cursor.column(idx) } @@ -215,6 +222,22 @@ impl VTabModule for CsvVTable { fn eof(cursor: &Self::VCursor) -> bool { cursor.index >= cursor.rows.len() } + + /// *Optional* methods for non-readonly tables + + /// Update the value at rowid + fn update(&mut self, _rowid: i64, _args: &[Value]) -> Result<(), Self::Error> { + Ok(()) + } + + /// Insert the value(s) + fn insert(&mut self, _args: &[Value]) -> Result { + Ok(0) + } + /// Delete the value at rowid + fn delete(&mut self, _rowid: i64) -> Result<(), Self::Error> { + Ok(()) + } } /// The cursor for iterating over CSV rows. @@ -226,6 +249,8 @@ struct CsvCursor { /// Implement the VTabCursor trait for your cursor type impl VTabCursor for CsvCursor { + type Error = &'static str; + fn next(&mut self) -> ResultCode { CsvCursor::next(self) } @@ -234,12 +259,12 @@ impl VTabCursor for CsvCursor { self.index >= self.rows.len() } - fn column(&self, idx: u32) -> Value { + fn column(&self, idx: u32) -> Result { let row = &self.rows[self.index]; if (idx as usize) < row.len() { - Value::from_text(&row[idx as usize]) + Ok(Value::from_text(&row[idx as usize])) } else { - Value::null() + Ok(Value::null()) } } diff --git a/extensions/core/src/lib.rs b/extensions/core/src/lib.rs index 22d90f572..2951d591e 100644 --- a/extensions/core/src/lib.rs +++ b/extensions/core/src/lib.rs @@ -1,63 +1,48 @@ mod types; pub use limbo_macros::{register_extension, scalar, AggregateDerive, VTabModuleDerive}; -use std::os::raw::{c_char, c_void}; +use std::{ + fmt::Display, + os::raw::{c_char, c_void}, +}; pub use types::{ResultCode, Value, ValueType}; +pub type ExtResult = std::result::Result; + #[repr(C)] pub struct ExtensionApi { pub ctx: *mut c_void, - - pub register_scalar_function: unsafe extern "C" fn( - ctx: *mut c_void, - name: *const c_char, - func: ScalarFunction, - ) -> ResultCode, - - pub register_aggregate_function: unsafe extern "C" fn( - ctx: *mut c_void, - name: *const c_char, - args: i32, - init_func: InitAggFunction, - step_func: StepFunction, - finalize_func: FinalizeFunction, - ) -> ResultCode, - - pub register_module: unsafe extern "C" fn( - ctx: *mut c_void, - name: *const c_char, - module: VTabModuleImpl, - ) -> ResultCode, - - pub declare_vtab: unsafe extern "C" fn( - ctx: *mut c_void, - name: *const c_char, - sql: *const c_char, - ) -> ResultCode, -} - -impl ExtensionApi { - pub fn declare_virtual_table(&self, name: &str, sql: &str) -> ResultCode { - let Ok(name) = std::ffi::CString::new(name) else { - return ResultCode::Error; - }; - let Ok(sql) = std::ffi::CString::new(sql) else { - return ResultCode::Error; - }; - unsafe { (self.declare_vtab)(self.ctx, name.as_ptr(), sql.as_ptr()) } - } + pub register_scalar_function: RegisterScalarFn, + pub register_aggregate_function: RegisterAggFn, + pub register_module: RegisterModuleFn, } pub type ExtensionEntryPoint = unsafe extern "C" fn(api: *const ExtensionApi) -> ResultCode; + pub type ScalarFunction = unsafe extern "C" fn(argc: i32, *const Value) -> Value; +pub type RegisterScalarFn = + unsafe extern "C" fn(ctx: *mut c_void, name: *const c_char, func: ScalarFunction) -> ResultCode; + +pub type RegisterAggFn = unsafe extern "C" fn( + ctx: *mut c_void, + name: *const c_char, + args: i32, + init: InitAggFunction, + step: StepFunction, + finalize: FinalizeFunction, +) -> ResultCode; + +pub type RegisterModuleFn = unsafe extern "C" fn( + ctx: *mut c_void, + name: *const c_char, + module: VTabModuleImpl, + kind: VTabKind, +) -> ResultCode; + pub type InitAggFunction = unsafe extern "C" fn() -> *mut AggCtx; pub type StepFunction = unsafe extern "C" fn(ctx: *mut AggCtx, argc: i32, argv: *const Value); pub type FinalizeFunction = unsafe extern "C" fn(ctx: *mut AggCtx) -> Value; -pub trait Scalar { - fn call(&self, args: &[Value]) -> Value; -} - #[repr(C)] pub struct AggCtx { pub state: *mut c_void, @@ -65,59 +50,99 @@ pub struct AggCtx { pub trait AggFunc { type State: Default; + type Error: Display; const NAME: &'static str; const ARGS: i32; fn step(state: &mut Self::State, args: &[Value]); - fn finalize(state: Self::State) -> Value; + fn finalize(state: Self::State) -> Result; } #[repr(C)] #[derive(Clone, Debug)] pub struct VTabModuleImpl { + pub ctx: *const c_void, pub name: *const c_char, - pub connect: VtabFnConnect, + pub create_schema: VtabFnCreateSchema, pub open: VtabFnOpen, pub filter: VtabFnFilter, pub column: VtabFnColumn, pub next: VtabFnNext, pub eof: VtabFnEof, + pub update: VtabFnUpdate, + pub rowid: VtabRowIDFn, } -pub type VtabFnConnect = unsafe extern "C" fn(api: *const c_void) -> ResultCode; +impl VTabModuleImpl { + pub fn init_schema(&self, args: Vec) -> ExtResult { + let schema = unsafe { (self.create_schema)(args.as_ptr(), args.len() as i32) }; + if schema.is_null() { + return Err(ResultCode::InvalidArgs); + } + for arg in args { + unsafe { arg.free() }; + } + let schema = unsafe { std::ffi::CString::from_raw(schema) }; + Ok(schema.to_string_lossy().to_string()) + } +} -pub type VtabFnOpen = unsafe extern "C" fn() -> *mut c_void; +pub type VtabFnCreateSchema = unsafe extern "C" fn(args: *const Value, argc: i32) -> *mut c_char; + +pub type VtabFnOpen = unsafe extern "C" fn(*const c_void) -> *const c_void; pub type VtabFnFilter = - unsafe extern "C" fn(cursor: *mut c_void, argc: i32, argv: *const Value) -> ResultCode; + unsafe extern "C" fn(cursor: *const c_void, argc: i32, argv: *const Value) -> ResultCode; -pub type VtabFnColumn = unsafe extern "C" fn(cursor: *mut c_void, idx: u32) -> Value; +pub type VtabFnColumn = unsafe extern "C" fn(cursor: *const c_void, idx: u32) -> Value; -pub type VtabFnNext = unsafe extern "C" fn(cursor: *mut c_void) -> ResultCode; +pub type VtabFnNext = unsafe extern "C" fn(cursor: *const c_void) -> ResultCode; -pub type VtabFnEof = unsafe extern "C" fn(cursor: *mut c_void) -> bool; +pub type VtabFnEof = unsafe extern "C" fn(cursor: *const c_void) -> bool; + +pub type VtabRowIDFn = unsafe extern "C" fn(cursor: *const c_void) -> i64; + +pub type VtabFnUpdate = unsafe extern "C" fn( + vtab: *const c_void, + argc: i32, + argv: *const Value, + p_out_rowid: *mut i64, +) -> ResultCode; + +#[repr(C)] +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum VTabKind { + VirtualTable, + TableValuedFunction, +} pub trait VTabModule: 'static { - type VCursor: VTabCursor; + type VCursor: VTabCursor; + const VTAB_KIND: VTabKind; const NAME: &'static str; + type Error: std::fmt::Display; - fn connect(api: &ExtensionApi) -> ResultCode; - fn open() -> Self::VCursor; - fn filter(cursor: &mut Self::VCursor, arg_count: i32, args: &[Value]) -> ResultCode; - fn column(cursor: &Self::VCursor, idx: u32) -> Value; + fn create_schema(args: &[Value]) -> String; + fn open(&self) -> Result; + fn filter(cursor: &mut Self::VCursor, args: &[Value]) -> ResultCode; + fn column(cursor: &Self::VCursor, idx: u32) -> Result; fn next(cursor: &mut Self::VCursor) -> ResultCode; fn eof(cursor: &Self::VCursor) -> bool; + fn update(&mut self, _rowid: i64, _args: &[Value]) -> Result<(), Self::Error> { + Ok(()) + } + fn insert(&mut self, _args: &[Value]) -> Result { + Ok(0) + } + fn delete(&mut self, _rowid: i64) -> Result<(), Self::Error> { + Ok(()) + } } pub trait VTabCursor: Sized { type Error: std::fmt::Display; fn rowid(&self) -> i64; - fn column(&self, idx: u32) -> Value; + fn column(&self, idx: u32) -> Result; fn eof(&self) -> bool; fn next(&mut self) -> ResultCode; } - -#[repr(C)] -pub struct VTabImpl { - pub module: VTabModuleImpl, -} diff --git a/extensions/core/src/types.rs b/extensions/core/src/types.rs index c29768d7f..f08fe099e 100644 --- a/extensions/core/src/types.rs +++ b/extensions/core/src/types.rs @@ -21,6 +21,8 @@ pub enum ResultCode { Unavailable = 13, CustomError = 14, EOF = 15, + ReadOnly = 16, + RowID = 17, } impl ResultCode { @@ -52,6 +54,8 @@ impl Display for ResultCode { ResultCode::Unavailable => write!(f, "Unavailable"), ResultCode::CustomError => write!(f, "Error "), ResultCode::EOF => write!(f, "EOF"), + ResultCode::ReadOnly => write!(f, "Read Only"), + ResultCode::RowID => write!(f, "RowID"), } } } @@ -403,6 +407,7 @@ impl Value { } } + /// Extension authors should __not__ use this function. /// # Safety /// consumes the value while freeing the underlying memory with null check. /// however this does assume that the type was properly constructed with diff --git a/extensions/kvstore/Cargo.toml b/extensions/kvstore/Cargo.toml new file mode 100644 index 000000000..cac010bb6 --- /dev/null +++ b/extensions/kvstore/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "limbo_kv" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true + +[lib] +crate-type = ["cdylib", "lib"] + +[features] +static= [ "limbo_ext/static" ] + +[dependencies] +lazy_static = "1.5.0" +limbo_ext = { workspace = true, features = ["static"] } + +[target.'cfg(not(target_family = "wasm"))'.dependencies] +mimalloc = { version = "*", default-features = false } diff --git a/extensions/kvstore/src/lib.rs b/extensions/kvstore/src/lib.rs new file mode 100644 index 000000000..a9de7c71d --- /dev/null +++ b/extensions/kvstore/src/lib.rs @@ -0,0 +1,147 @@ +use lazy_static::lazy_static; +use limbo_ext::{ + register_extension, ResultCode, VTabCursor, VTabKind, VTabModule, VTabModuleDerive, Value, +}; +use std::collections::BTreeMap; +use std::sync::Mutex; + +lazy_static! { + static ref GLOBAL_STORE: Mutex> = Mutex::new(BTreeMap::new()); +} + +register_extension! { + vtabs: { KVStoreVTab }, +} + +#[derive(VTabModuleDerive, Default)] +pub struct KVStoreVTab; + +/// the cursor holds a snapshot of (rowid, key, value) in memory. +pub struct KVStoreCursor { + rows: Vec<(i64, String, String)>, + index: Option, +} + +impl VTabModule for KVStoreVTab { + type VCursor = KVStoreCursor; + const VTAB_KIND: VTabKind = VTabKind::VirtualTable; + const NAME: &'static str = "kv_store"; + type Error = String; + + fn create_schema(_args: &[Value]) -> String { + "CREATE TABLE x (key TEXT PRIMARY KEY, value TEXT);".to_string() + } + + fn open(&self) -> Result { + Ok(KVStoreCursor { + rows: Vec::new(), + index: None, + }) + } + + fn filter(cursor: &mut Self::VCursor, _args: &[Value]) -> ResultCode { + let store = GLOBAL_STORE.lock().unwrap(); + cursor.rows = store + .iter() + .map(|(&rowid, (k, v))| (rowid, k.clone(), v.clone())) + .collect(); + cursor.rows.sort_by_key(|(rowid, _, _)| *rowid); + + if cursor.rows.is_empty() { + cursor.index = None; + return ResultCode::EOF; + } else { + cursor.index = Some(0); + } + ResultCode::OK + } + + fn insert(&mut self, values: &[Value]) -> Result { + let key = values + .first() + .and_then(|v| v.to_text()) + .ok_or("Missing key")? + .to_string(); + let val = values + .get(1) + .and_then(|v| v.to_text()) + .ok_or("Missing value")? + .to_string(); + let rowid = hash_key(&key); + { + let mut store = GLOBAL_STORE.lock().unwrap(); + store.insert(rowid, (key, val)); + } + Ok(rowid) + } + + fn delete(&mut self, rowid: i64) -> Result<(), Self::Error> { + let mut store = GLOBAL_STORE.lock().unwrap(); + store.remove(&rowid); + Ok(()) + } + + fn update(&mut self, rowid: i64, values: &[Value]) -> Result<(), Self::Error> { + { + let mut store = GLOBAL_STORE.lock().unwrap(); + store.remove(&rowid); + } + let _ = self.insert(values)?; + Ok(()) + } + fn eof(cursor: &Self::VCursor) -> bool { + cursor.index.is_some_and(|s| s >= cursor.rows.len()) || cursor.index.is_none() + } + + fn next(cursor: &mut Self::VCursor) -> ResultCode { + cursor.index = Some(cursor.index.unwrap_or(0) + 1); + if cursor.index.is_some_and(|c| c >= cursor.rows.len()) { + return ResultCode::EOF; + } + ResultCode::OK + } + + fn column(cursor: &Self::VCursor, idx: u32) -> Result { + if cursor.index.is_some_and(|c| c >= cursor.rows.len()) { + return Err("cursor out of range".into()); + } + let (_, ref key, ref val) = cursor.rows[cursor.index.unwrap_or(0)]; + match idx { + 0 => Ok(Value::from_text(key.clone())), // key + 1 => Ok(Value::from_text(val.clone())), // value + _ => Err("Invalid column".into()), + } + } +} + +fn hash_key(key: &str) -> i64 { + use std::hash::{Hash, Hasher}; + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + key.hash(&mut hasher); + hasher.finish() as i64 +} + +impl VTabCursor for KVStoreCursor { + type Error = String; + + fn rowid(&self) -> i64 { + if self.index.is_some_and(|c| c < self.rows.len()) { + self.rows[self.index.unwrap_or(0)].0 + } else { + println!("rowid: -1"); + -1 + } + } + + fn column(&self, idx: u32) -> Result { + ::column(self, idx) + } + + fn eof(&self) -> bool { + ::eof(self) + } + + fn next(&mut self) -> ResultCode { + ::next(self) + } +} diff --git a/extensions/percentile/src/lib.rs b/extensions/percentile/src/lib.rs index 9f81a6674..4b0b7bd83 100644 --- a/extensions/percentile/src/lib.rs +++ b/extensions/percentile/src/lib.rs @@ -9,6 +9,7 @@ struct Median; impl AggFunc for Median { type State = Vec; + type Error = &'static str; const NAME: &'static str = "median"; const ARGS: i32 = 1; @@ -18,9 +19,9 @@ impl AggFunc for Median { } } - fn finalize(state: Self::State) -> Value { + fn finalize(state: Self::State) -> Result { if state.is_empty() { - return Value::null(); + return Ok(Value::null()); } let mut sorted = state; @@ -28,11 +29,11 @@ impl AggFunc for Median { let len = sorted.len(); if len % 2 == 1 { - Value::from_float(sorted[len / 2]) + Ok(Value::from_float(sorted[len / 2])) } else { let mid1 = sorted[len / 2 - 1]; let mid2 = sorted[len / 2]; - Value::from_float((mid1 + mid2) / 2.0) + Ok(Value::from_float((mid1 + mid2) / 2.0)) } } } @@ -41,8 +42,8 @@ impl AggFunc for Median { struct Percentile; impl AggFunc for Percentile { - type State = (Vec, Option, Option<&'static str>); - + type State = (Vec, Option, Option); + type Error = &'static str; const NAME: &'static str = "percentile"; const ARGS: i32 = 2; @@ -69,16 +70,16 @@ impl AggFunc for Percentile { } } - fn finalize(state: Self::State) -> Value { + fn finalize(state: Self::State) -> Result { let (mut values, p_value, err_value) = state; if values.is_empty() { - return Value::null(); + return Ok(Value::null()); } if let Some(err) = err_value { - return Value::error_with_message(err.into()); + return Err(err); } if values.len() == 1 { - return Value::from_float(values[0]); + return Ok(Value::from_float(values[0])); } let p = p_value.unwrap(); @@ -89,10 +90,12 @@ impl AggFunc for Percentile { let upper = index.ceil() as usize; if lower == upper { - Value::from_float(values[lower]) + Ok(Value::from_float(values[lower])) } else { let weight = index - lower as f64; - Value::from_float(values[lower] * (1.0 - weight) + values[upper] * weight) + Ok(Value::from_float( + values[lower] * (1.0 - weight) + values[upper] * weight, + )) } } } @@ -101,8 +104,8 @@ impl AggFunc for Percentile { struct PercentileCont; impl AggFunc for PercentileCont { - type State = (Vec, Option, Option<&'static str>); - + type State = (Vec, Option, Option); + type Error = &'static str; const NAME: &'static str = "percentile_cont"; const ARGS: i32 = 2; @@ -129,16 +132,16 @@ impl AggFunc for PercentileCont { } } - fn finalize(state: Self::State) -> Value { + fn finalize(state: Self::State) -> Result { let (mut values, p_value, err_state) = state; if values.is_empty() { - return Value::null(); + return Ok(Value::null()); } if let Some(err) = err_state { - return Value::error_with_message(err.into()); + return Err(err); } if values.len() == 1 { - return Value::from_float(values[0]); + return Ok(Value::from_float(values[0])); } let p = p_value.unwrap(); @@ -149,10 +152,12 @@ impl AggFunc for PercentileCont { let upper = index.ceil() as usize; if lower == upper { - Value::from_float(values[lower]) + Ok(Value::from_float(values[lower])) } else { let weight = index - lower as f64; - Value::from_float(values[lower] * (1.0 - weight) + values[upper] * weight) + Ok(Value::from_float( + values[lower] * (1.0 - weight) + values[upper] * weight, + )) } } } @@ -161,8 +166,8 @@ impl AggFunc for PercentileCont { struct PercentileDisc; impl AggFunc for PercentileDisc { - type State = (Vec, Option, Option<&'static str>); - + type State = (Vec, Option, Option); + type Error = &'static str; const NAME: &'static str = "percentile_disc"; const ARGS: i32 = 2; @@ -170,19 +175,19 @@ impl AggFunc for PercentileDisc { Percentile::step(state, args); } - fn finalize(state: Self::State) -> Value { + fn finalize(state: Self::State) -> Result { let (mut values, p_value, err_value) = state; if values.is_empty() { - return Value::null(); + return Ok(Value::null()); } if let Some(err) = err_value { - return Value::error_with_message(err.into()); + return Err(err); } let p = p_value.unwrap(); values.sort_by(|a, b| a.partial_cmp(b).unwrap()); let n = values.len() as f64; let index = (p * (n - 1.0)).floor() as usize; - Value::from_float(values[index]) + Ok(Value::from_float(values[index])) } } diff --git a/extensions/series/src/lib.rs b/extensions/series/src/lib.rs index df2d67da1..43028eed5 100644 --- a/extensions/series/src/lib.rs +++ b/extensions/series/src/lib.rs @@ -1,5 +1,5 @@ use limbo_ext::{ - register_extension, ExtensionApi, ResultCode, VTabCursor, VTabModule, VTabModuleDerive, Value, + register_extension, ResultCode, VTabCursor, VTabKind, VTabModule, VTabModuleDerive, Value, }; register_extension! { @@ -16,36 +16,38 @@ macro_rules! try_option { } /// A virtual table that generates a sequence of integers -#[derive(Debug, VTabModuleDerive)] +#[derive(Debug, VTabModuleDerive, Default)] struct GenerateSeriesVTab; impl VTabModule for GenerateSeriesVTab { type VCursor = GenerateSeriesCursor; + type Error = ResultCode; const NAME: &'static str = "generate_series"; + const VTAB_KIND: VTabKind = VTabKind::TableValuedFunction; - fn connect(api: &ExtensionApi) -> ResultCode { + fn create_schema(_args: &[Value]) -> String { // Create table schema - let sql = "CREATE TABLE generate_series( + "CREATE TABLE generate_series( value INTEGER, start INTEGER HIDDEN, stop INTEGER HIDDEN, step INTEGER HIDDEN - )"; - api.declare_virtual_table(Self::NAME, sql) + )" + .into() } - fn open() -> Self::VCursor { - GenerateSeriesCursor { + fn open(&self) -> Result { + Ok(GenerateSeriesCursor { start: 0, stop: 0, step: 0, current: 0, - } + }) } - fn filter(cursor: &mut Self::VCursor, arg_count: i32, args: &[Value]) -> ResultCode { + fn filter(cursor: &mut Self::VCursor, args: &[Value]) -> ResultCode { // args are the start, stop, and step - if arg_count == 0 || arg_count > 3 { + if args.is_empty() || args.len() > 3 { return ResultCode::InvalidArgs; } let start = try_option!(args[0].to_integer(), ResultCode::InvalidArgs); @@ -78,7 +80,7 @@ impl VTabModule for GenerateSeriesVTab { ResultCode::OK } - fn column(cursor: &Self::VCursor, idx: u32) -> Value { + fn column(cursor: &Self::VCursor, idx: u32) -> Result { cursor.column(idx) } @@ -163,14 +165,14 @@ impl VTabCursor for GenerateSeriesCursor { false } - fn column(&self, idx: u32) -> Value { - match idx { + fn column(&self, idx: u32) -> Result { + Ok(match idx { 0 => Value::from_integer(self.current), 1 => Value::from_integer(self.start), 2 => Value::from_integer(self.stop), 3 => Value::from_integer(self.step), _ => Value::null(), - } + }) } fn rowid(&self) -> i64 { @@ -227,7 +229,8 @@ mod tests { } // Helper function to collect all values from a cursor, returns Result with error code fn collect_series(series: Series) -> Result, ResultCode> { - let mut cursor = GenerateSeriesVTab::open(); + let tbl = GenerateSeriesVTab; + let mut cursor = tbl.open()?; // Create args array for filter let args = vec![ @@ -237,7 +240,7 @@ mod tests { ]; // Initialize cursor through filter - match GenerateSeriesVTab::filter(&mut cursor, 3, &args) { + match GenerateSeriesVTab::filter(&mut cursor, &args) { ResultCode::OK => (), ResultCode::EOF => return Ok(vec![]), err => return Err(err), @@ -245,7 +248,7 @@ mod tests { let mut values = Vec::new(); loop { - values.push(cursor.column(0).to_integer().unwrap()); + values.push(cursor.column(0)?.to_integer().unwrap()); if values.len() > 1000 { panic!( "Generated more than 1000 values, expected this many: {:?}", @@ -543,8 +546,8 @@ mod tests { let start = series.start; let stop = series.stop; let step = series.step; - - let mut cursor = GenerateSeriesVTab::open(); + let tbl = GenerateSeriesVTab::default(); + let mut cursor = tbl.open().unwrap(); let args = vec![ Value::from_integer(start), @@ -553,7 +556,7 @@ mod tests { ]; // Initialize cursor through filter - GenerateSeriesVTab::filter(&mut cursor, 3, &args); + GenerateSeriesVTab::filter(&mut cursor, &args); let mut rowids = vec![]; while !GenerateSeriesVTab::eof(&cursor) { diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 632b95615..e85b144e5 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -171,7 +171,7 @@ pub fn scalar(attr: TokenStream, input: TokenStream) -> TokenStream { let fn_body = &ast.block; let alias_check = if let Some(alias) = &scalar_info.alias { quote! { - let Ok(alias_c_name) = std::ffi::CString::new(#alias) else { + let Ok(alias_c_name) = ::std::ffi::CString::new(#alias) else { return ::limbo_ext::ResultCode::Error; }; (api.register_scalar_function)( @@ -193,7 +193,7 @@ pub fn scalar(attr: TokenStream, input: TokenStream) -> TokenStream { return ::limbo_ext::ResultCode::Error; } let api = unsafe { &*api }; - let Ok(c_name) = std::ffi::CString::new(#name) else { + let Ok(c_name) = ::std::ffi::CString::new(#name) else { return ::limbo_ext::ResultCode::Error; }; (api.register_scalar_function)( @@ -232,6 +232,7 @@ pub fn scalar(attr: TokenStream, input: TokenStream) -> TokenStream { /// ///impl AggFunc for SumPlusOne { /// type State = i64; +/// type Error = &'static str; /// const NAME: &'static str = "sum_plus_one"; /// const ARGS: i32 = 1; /// fn step(state: &mut Self::State, args: &[Value]) { @@ -240,8 +241,8 @@ pub fn scalar(attr: TokenStream, input: TokenStream) -> TokenStream { /// }; /// *state += val; /// } -/// fn finalize(state: Self::State) -> Value { -/// Value::from_integer(state + 1) +/// fn finalize(state: Self::State) -> Result { +/// Ok(Value::from_integer(state + 1)) /// } ///} /// ``` @@ -259,11 +260,11 @@ pub fn derive_agg_func(input: TokenStream) -> TokenStream { impl #struct_name { #[no_mangle] pub extern "C" fn #init_fn_name() -> *mut ::limbo_ext::AggCtx { - let state = Box::new(<#struct_name as ::limbo_ext::AggFunc>::State::default()); - let ctx = Box::new(::limbo_ext::AggCtx { - state: Box::into_raw(state) as *mut ::std::os::raw::c_void, + let state = ::std::boxed::Box::new(<#struct_name as ::limbo_ext::AggFunc>::State::default()); + let ctx = ::std::boxed::Box::new(::limbo_ext::AggCtx { + state: ::std::boxed::Box::into_raw(state) as *mut ::std::os::raw::c_void, }); - Box::into_raw(ctx) + ::std::boxed::Box::into_raw(ctx) } #[no_mangle] @@ -275,7 +276,7 @@ pub fn derive_agg_func(input: TokenStream) -> TokenStream { unsafe { let ctx = &mut *ctx; let state = &mut *(ctx.state as *mut <#struct_name as ::limbo_ext::AggFunc>::State); - let args = std::slice::from_raw_parts(argv, argc as usize); + let args = ::std::slice::from_raw_parts(argv, argc as usize); <#struct_name as ::limbo_ext::AggFunc>::step(state, args); } } @@ -286,8 +287,13 @@ pub fn derive_agg_func(input: TokenStream) -> TokenStream { ) -> ::limbo_ext::Value { unsafe { let ctx = &mut *ctx; - let state = Box::from_raw(ctx.state as *mut <#struct_name as ::limbo_ext::AggFunc>::State); - <#struct_name as ::limbo_ext::AggFunc>::finalize(*state) + let state = ::std::boxed::Box::from_raw(ctx.state as *mut <#struct_name as ::limbo_ext::AggFunc>::State); + match <#struct_name as ::limbo_ext::AggFunc>::finalize(*state) { + Ok(val) => val, + Err(e) => { + ::limbo_ext::Value::error_with_message(e.to_string()) + } + } } } @@ -301,7 +307,7 @@ pub fn derive_agg_func(input: TokenStream) -> TokenStream { let api = &*api; let name_str = #struct_name::NAME; - let c_name = match std::ffi::CString::new(name_str) { + let c_name = match ::std::ffi::CString::new(name_str) { Ok(cname) => cname, Err(_) => return ::limbo_ext::ResultCode::Error, }; @@ -335,13 +341,12 @@ pub fn derive_agg_func(input: TokenStream) -> TokenStream { /// const NAME: &'static str = "csv_data"; /// /// /// Declare the schema for your virtual table -/// fn connect(api: &ExtensionApi) -> ResultCode { +/// fn create_schema(args: &[&str]) -> &'static str { /// let sql = "CREATE TABLE csv_data( /// name TEXT, /// age TEXT, /// city TEXT -/// )"; -/// api.declare_virtual_table(Self::NAME, sql) +/// )" /// } /// /// Open the virtual table and return a cursor /// fn open() -> Self::VCursor { @@ -377,6 +382,22 @@ pub fn derive_agg_func(input: TokenStream) -> TokenStream { /// fn eof(cursor: &Self::VCursor) -> bool { /// cursor.index >= cursor.rows.len() /// } +/// +/// /// **Optional** methods for non-readonly tables: +/// +/// /// Update the row with the provided values, return the new rowid +/// fn update(&mut self, rowid: i64, args: &[Value]) -> Result, Self::Error> { +/// Ok(None)// return Ok(None) for read-only +/// } +/// /// Insert a new row with the provided values, return the new rowid +/// fn insert(&mut self, args: &[Value]) -> Result<(), Self::Error> { +/// Ok(()) // +/// } +/// /// Delete the row with the provided rowid +/// fn delete(&mut self, rowid: i64) -> Result<(), Self::Error> { +/// Ok(()) +/// } +/// /// #[derive(Debug)] /// struct CsvCursor { /// rows: Vec>, @@ -384,7 +405,7 @@ pub fn derive_agg_func(input: TokenStream) -> TokenStream { /// /// impl CsvCursor { /// /// Returns the value for a given column index. -/// fn column(&self, idx: u32) -> Value { +/// fn column(&self, idx: u32) -> Result { /// let row = &self.rows[self.index]; /// if (idx as usize) < row.len() { /// Value::from_text(&row[idx as usize]) @@ -413,36 +434,47 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { let struct_name = &ast.ident; let register_fn_name = format_ident!("register_{}", struct_name); - let connect_fn_name = format_ident!("connect_{}", struct_name); + let create_schema_fn_name = format_ident!("create_schema_{}", struct_name); let open_fn_name = format_ident!("open_{}", struct_name); let filter_fn_name = format_ident!("filter_{}", struct_name); let column_fn_name = format_ident!("column_{}", struct_name); let next_fn_name = format_ident!("next_{}", struct_name); let eof_fn_name = format_ident!("eof_{}", struct_name); + let update_fn_name = format_ident!("update_{}", struct_name); + let rowid_fn_name = format_ident!("rowid_{}", struct_name); let expanded = quote! { impl #struct_name { #[no_mangle] - unsafe extern "C" fn #connect_fn_name( - db: *const ::std::ffi::c_void, - ) -> ::limbo_ext::ResultCode { - if db.is_null() { - return ::limbo_ext::ResultCode::Error; - } - let api = unsafe { &*(db as *const ExtensionApi) }; - <#struct_name as ::limbo_ext::VTabModule>::connect(api) + unsafe extern "C" fn #create_schema_fn_name( + argv: *const ::limbo_ext::Value, argc: i32 + ) -> *mut ::std::ffi::c_char { + let args = if argv.is_null() { + &Vec::new() + } else { + ::std::slice::from_raw_parts(argv, argc as usize) + }; + let sql = <#struct_name as ::limbo_ext::VTabModule>::create_schema(&args); + ::std::ffi::CString::new(sql).unwrap().into_raw() } #[no_mangle] - unsafe extern "C" fn #open_fn_name( - ) -> *mut ::std::ffi::c_void { - let cursor = <#struct_name as ::limbo_ext::VTabModule>::open(); - Box::into_raw(Box::new(cursor)) as *mut ::std::ffi::c_void + unsafe extern "C" fn #open_fn_name(ctx: *const ::std::ffi::c_void) -> *const ::std::ffi::c_void { + if ctx.is_null() { + return ::std::ptr::null(); + } + let ctx = ctx as *const #struct_name; + let ctx: &#struct_name = &*ctx; + if let Ok(cursor) = <#struct_name as ::limbo_ext::VTabModule>::open(ctx) { + return ::std::boxed::Box::into_raw(::std::boxed::Box::new(cursor)) as *const ::std::ffi::c_void; + } else { + return ::std::ptr::null(); + } } #[no_mangle] unsafe extern "C" fn #filter_fn_name( - cursor: *mut ::std::ffi::c_void, + cursor: *const ::std::ffi::c_void, argc: i32, argv: *const ::limbo_ext::Value, ) -> ::limbo_ext::ResultCode { @@ -450,44 +482,108 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { return ::limbo_ext::ResultCode::Error; } let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; - let args = std::slice::from_raw_parts(argv, argc as usize); - <#struct_name as ::limbo_ext::VTabModule>::filter(cursor, argc, args) + let args = ::std::slice::from_raw_parts(argv, argc as usize); + <#struct_name as ::limbo_ext::VTabModule>::filter(cursor, args) } #[no_mangle] unsafe extern "C" fn #column_fn_name( - cursor: *mut ::std::ffi::c_void, + cursor: *const ::std::ffi::c_void, idx: u32, ) -> ::limbo_ext::Value { if cursor.is_null() { - return ::limbo_ext::Value::error(ResultCode::Error); + return ::limbo_ext::Value::error(::limbo_ext::ResultCode::Error); } let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; - <#struct_name as ::limbo_ext::VTabModule>::column(cursor, idx) + match <#struct_name as ::limbo_ext::VTabModule>::column(cursor, idx) { + Ok(val) => val, + Err(e) => ::limbo_ext::Value::error_with_message(e.to_string()) + } } #[no_mangle] unsafe extern "C" fn #next_fn_name( - cursor: *mut ::std::ffi::c_void, + cursor: *const ::std::ffi::c_void, ) -> ::limbo_ext::ResultCode { if cursor.is_null() { return ::limbo_ext::ResultCode::Error; } - let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; + let cursor = &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor); <#struct_name as ::limbo_ext::VTabModule>::next(cursor) } #[no_mangle] unsafe extern "C" fn #eof_fn_name( - cursor: *mut ::std::ffi::c_void, + cursor: *const ::std::ffi::c_void, ) -> bool { if cursor.is_null() { return true; } - let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; + let cursor = &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor); <#struct_name as ::limbo_ext::VTabModule>::eof(cursor) } + #[no_mangle] + unsafe extern "C" fn #update_fn_name( + vtab: *const ::std::ffi::c_void, + argc: i32, + argv: *const ::limbo_ext::Value, + p_out_rowid: *mut i64, + ) -> ::limbo_ext::ResultCode { + if vtab.is_null() { + return ::limbo_ext::ResultCode::Error; + } + + let vtab = &mut *(vtab as *mut #struct_name); + let args = ::std::slice::from_raw_parts(argv, argc as usize); + + let old_rowid = match args.get(0).map(|v| v.value_type()) { + Some(::limbo_ext::ValueType::Integer) => args.get(0).unwrap().to_integer(), + _ => None, + }; + let new_rowid = match args.get(1).map(|v| v.value_type()) { + Some(::limbo_ext::ValueType::Integer) => args.get(1).unwrap().to_integer(), + _ => None, + }; + let columns = &args[2..]; + match (old_rowid, new_rowid) { + // DELETE: old_rowid provided, no new_rowid + (Some(old), None) => { + if <#struct_name as VTabModule>::delete(vtab, old).is_err() { + return ::limbo_ext::ResultCode::Error; + } + return ::limbo_ext::ResultCode::OK; + } + // UPDATE: old_rowid provided and new_rowid may exist + (Some(old), Some(new)) => { + if <#struct_name as VTabModule>::update(vtab, old, &columns).is_err() { + return ::limbo_ext::ResultCode::Error; + } + return ::limbo_ext::ResultCode::OK; + } + // INSERT: no old_rowid (old_rowid = None) + (None, _) => { + if let Ok(rowid) = <#struct_name as VTabModule>::insert(vtab, &columns) { + if !p_out_rowid.is_null() { + *p_out_rowid = rowid; + return ::limbo_ext::ResultCode::RowID; + } + return ::limbo_ext::ResultCode::OK; + } + } + } + return ::limbo_ext::ResultCode::Error; + } + + #[no_mangle] + pub unsafe extern "C" fn #rowid_fn_name(ctx: *const ::std::ffi::c_void) -> i64 { + if ctx.is_null() { + return -1; + } + let cursor = &*(ctx as *const <#struct_name as ::limbo_ext::VTabModule>::VCursor); + <<#struct_name as ::limbo_ext::VTabModule>::VCursor as ::limbo_ext::VTabCursor>::rowid(cursor) + } + #[no_mangle] pub unsafe extern "C" fn #register_fn_name( api: *const ::limbo_ext::ExtensionApi @@ -495,23 +591,23 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { if api.is_null() { return ::limbo_ext::ResultCode::Error; } - let api = &*api; let name = <#struct_name as ::limbo_ext::VTabModule>::NAME; - // name needs to be a c str FFI compatible, NOT CString - let name_c = std::ffi::CString::new(name).unwrap(); - + let name_c = ::std::ffi::CString::new(name).unwrap().into_raw() as *const ::std::ffi::c_char; + let table_instance = ::std::boxed::Box::into_raw(::std::boxed::Box::new(#struct_name::default())); let module = ::limbo_ext::VTabModuleImpl { - name: name_c.as_ptr(), - connect: Self::#connect_fn_name, + ctx: table_instance as *const ::std::ffi::c_void, + name: name_c, + create_schema: Self::#create_schema_fn_name, open: Self::#open_fn_name, filter: Self::#filter_fn_name, column: Self::#column_fn_name, next: Self::#next_fn_name, eof: Self::#eof_fn_name, + update: Self::#update_fn_name, + rowid: Self::#rowid_fn_name, }; - - (api.register_module)(api.ctx, name_c.as_ptr(), module) + (api.register_module)(api.ctx, name_c, module, <#struct_name as ::limbo_ext::VTabModule>::VTAB_KIND) } } }; @@ -589,11 +685,8 @@ pub fn register_extension(input: TokenStream) -> TokenStream { quote! { { let result = unsafe{ #vtab_ident::#register_fn(api)}; - if result == ::limbo_ext::ResultCode::OK { - let result = <#vtab_ident as ::limbo_ext::VTabModule>::connect(api); - if !result.is_ok() { - return result; - } + if !result.is_ok() { + return result; } } } diff --git a/testing/extensions.py b/testing/extensions.py index 76af242dd..8bff11bc2 100755 --- a/testing/extensions.py +++ b/testing/extensions.py @@ -443,7 +443,7 @@ def test_series(pipe): run_test( pipe, "SELECT * FROM generate_series(1, 10);", - lambda res: "Virtual table generate_series not found" in res, + lambda res: "Virtual table module not found: generate_series" in res, ) run_test(pipe, f".load {ext_path}", returns_null) run_test( @@ -468,6 +468,80 @@ def test_series(pipe): ) +def test_kv(pipe): + ext_path = "./target/debug/liblimbo_kv" + run_test( + pipe, + "create virtual table t using kv_store;", + lambda res: "Virtual table module not found: kv_store" in res, + ) + run_test(pipe, f".load {ext_path}", returns_null) + run_test( + pipe, + "create virtual table t using kv_store;", + returns_null, + "can create kv_store vtable", + ) + run_test( + pipe, + "insert into t values ('hello', 'world');", + returns_null, + "can insert into kv_store vtable", + ) + run_test( + pipe, + "select value from t where key = 'hello';", + lambda res: "world" == res, + "can select from kv_store", + ) + run_test( + pipe, + "delete from t where key = 'hello';", + returns_null, + "can delete from kv_store", + ) + run_test(pipe, "insert into t values ('other', 'value');", returns_null) + run_test( + pipe, + "select value from t where key = 'hello';", + lambda res: "" == res, + "proper data is deleted", + ) + run_test( + pipe, + "select * from t;", + lambda res: "other|value" == res, + "can select after deletion", + ) + run_test( + pipe, + "delete from t where key = 'other';", + returns_null, + "can delete from kv_store", + ) + run_test( + pipe, + "select * from t;", + lambda res: "" == res, + "can select empty table without error", + ) + run_test( + pipe, + "delete from t;", + returns_null, + "can delete from empty table without error", + ) + for i in range(100): + write_to_pipe(pipe, f"insert into t values ('key{i}', 'val{i}');") + run_test( + pipe, "select count(*) from t;", lambda res: "100" == res, "can insert 100 rows" + ) + run_test(pipe, "delete from t limit 96;", returns_null, "can delete 96 rows") + run_test( + pipe, "select count(*) from t;", lambda res: "4" == res, "four rows remain" + ) + + def main(): pipe = init_limbo() try: @@ -476,6 +550,7 @@ def main(): test_aggregates(pipe) test_crypto(pipe) test_series(pipe) + test_kv(pipe) except Exception as e: print(f"Test FAILED: {e}") diff --git a/vendored/sqlite3-parser/src/parser/ast/mod.rs b/vendored/sqlite3-parser/src/parser/ast/mod.rs index ad359cc0b..1aac9c2c4 100644 --- a/vendored/sqlite3-parser/src/parser/ast/mod.rs +++ b/vendored/sqlite3-parser/src/parser/ast/mod.rs @@ -1724,6 +1724,18 @@ pub enum ResolveType { /// `REPLACE` Replace, } +impl ResolveType { + /// Get the OE_XXX bit value + pub fn bit_value(&self) -> usize { + match self { + ResolveType::Rollback => 1, + ResolveType::Abort => 2, + ResolveType::Fail => 3, + ResolveType::Ignore => 4, + ResolveType::Replace => 5, + } + } +} /// `WITH` clause // https://sqlite.org/lang_with.html