diff --git a/core/error.rs b/core/error.rs index 53308114c..7832747eb 100644 --- a/core/error.rs +++ b/core/error.rs @@ -76,5 +76,11 @@ macro_rules! bail_constraint_error { }; } +impl From for LimboError { + fn from(err: limbo_ext::ResultCode) -> Self { + LimboError::ExtensionError(err.to_string()) + } +} + pub const SQLITE_CONSTRAINT: usize = 19; pub const SQLITE_CONSTRAINT_PRIMARYKEY: usize = SQLITE_CONSTRAINT | (6 << 8); diff --git a/core/ext/mod.rs b/core/ext/mod.rs index a4f5d6cc3..3ea7d9692 100644 --- a/core/ext/mod.rs +++ b/core/ext/mod.rs @@ -1,17 +1,20 @@ -use crate::{function::ExternalFunc, util::columns_from_create_table_body, Database, VirtualTable}; -use fallible_iterator::FallibleIterator; -use limbo_ext::{ExtensionApi, InitAggFunction, ResultCode, ScalarFunction, VTabModuleImpl}; -pub use limbo_ext::{FinalizeFunction, StepFunction, Value as ExtValue, ValueType as ExtValueType}; -use limbo_sqlite3_parser::{ - ast::{Cmd, Stmt}, - lexer::sql::Parser, +use crate::{function::ExternalFunc, Database}; +use limbo_ext::{ + ExtensionApi, InitAggFunction, ResultCode, ScalarFunction, VTabKind, VTabModuleImpl, }; +pub use limbo_ext::{FinalizeFunction, StepFunction, Value as ExtValue, ValueType as ExtValueType}; use std::{ - ffi::{c_char, c_void, CStr}, + ffi::{c_char, c_void, CStr, CString}, rc::Rc, }; type ExternAggFunc = (InitAggFunction, StepFunction, FinalizeFunction); +#[derive(Clone)] +pub struct VTabImpl { + pub module_type: VTabKind, + pub implementation: Rc, +} + unsafe extern "C" fn register_scalar_function( ctx: *mut c_void, name: *const c_char, @@ -53,8 +56,12 @@ unsafe extern "C" fn register_module( ctx: *mut c_void, name: *const c_char, module: VTabModuleImpl, + kind: VTabKind, ) -> ResultCode { - let c_str = unsafe { CStr::from_ptr(name) }; + if name.is_null() || ctx.is_null() { + return ResultCode::Error; + } + let c_str = unsafe { CString::from_raw(name as *mut i8) }; let name_str = match c_str.to_str() { Ok(s) => s.to_string(), Err(_) => return ResultCode::Error, @@ -64,31 +71,7 @@ unsafe extern "C" fn register_module( } let db = unsafe { &mut *(ctx as *mut Database) }; - db.register_module_impl(&name_str, module) -} - -unsafe extern "C" fn declare_vtab( - ctx: *mut c_void, - name: *const c_char, - sql: *const c_char, -) -> ResultCode { - let c_str = unsafe { CStr::from_ptr(name) }; - let name_str = match c_str.to_str() { - Ok(s) => s.to_string(), - Err(_) => return ResultCode::Error, - }; - - let c_str = unsafe { CStr::from_ptr(sql) }; - let sql_str = match c_str.to_str() { - Ok(s) => s.to_string(), - Err(_) => return ResultCode::Error, - }; - - if ctx.is_null() { - return ResultCode::Error; - } - let db = unsafe { &mut *(ctx as *mut Database) }; - db.declare_vtab_impl(&name_str, &sql_str) + db.register_module_impl(&name_str, module, kind) } impl Database { @@ -113,32 +96,22 @@ impl Database { ResultCode::OK } - fn register_module_impl(&mut self, name: &str, module: VTabModuleImpl) -> ResultCode { - self.vtab_modules.insert(name.to_string(), Rc::new(module)); - ResultCode::OK - } - - fn declare_vtab_impl(&mut self, name: &str, sql: &str) -> ResultCode { - let mut parser = Parser::new(sql.as_bytes()); - let cmd = parser.next().unwrap().unwrap(); - let Cmd::Stmt(stmt) = cmd else { - return ResultCode::Error; + fn register_module_impl( + &mut self, + name: &str, + module: VTabModuleImpl, + kind: VTabKind, + ) -> ResultCode { + let module = Rc::new(module); + let vmodule = VTabImpl { + module_type: kind, + implementation: module, }; - let Stmt::CreateTable { body, .. } = stmt else { - return ResultCode::Error; - }; - let Ok(columns) = columns_from_create_table_body(*body) else { - return ResultCode::Error; - }; - let vtab_module = self.vtab_modules.get(name).unwrap().clone(); - - let vtab = VirtualTable { - name: name.to_string(), - implementation: vtab_module, - columns, - args: None, - }; - self.syms.borrow_mut().vtabs.insert(name.to_string(), vtab); + self.syms + .borrow_mut() + .vtab_modules + .insert(name.to_string(), vmodule.into()); + println!("Registered module: {}", name); ResultCode::OK } @@ -148,7 +121,6 @@ impl Database { register_scalar_function, register_aggregate_function, register_module, - declare_vtab, } } diff --git a/core/lib.rs b/core/lib.rs index 00889db98..d9727fd98 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -27,7 +27,7 @@ use fallible_iterator::FallibleIterator; use libloading::{Library, Symbol}; #[cfg(not(target_family = "wasm"))] use limbo_ext::{ExtensionApi, ExtensionEntryPoint}; -use limbo_ext::{ResultCode, VTabModuleImpl, Value as ExtValue}; +use limbo_ext::{ResultCode, VTabKind, VTabModuleImpl, Value as ExtValue}; use limbo_sqlite3_parser::{ast, ast::Cmd, lexer::sql::Parser}; use parking_lot::RwLock; use schema::{Column, Schema}; @@ -49,7 +49,7 @@ pub use storage::wal::WalFile; pub use storage::wal::WalFileShared; use types::OwnedValue; pub use types::Value; -use util::parse_schema_rows; +use util::{columns_from_create_table_body, parse_schema_rows}; use vdbe::builder::QueryMode; use vdbe::VTabOpaqueCursor; @@ -87,7 +87,6 @@ pub struct Database { schema: Rc>, header: Rc>, syms: Rc>, - vtab_modules: HashMap>, // Shared structures of a Database are the parts that are common to multiple threads that might // create DB connections. _shared_page_cache: Arc>, @@ -149,8 +148,7 @@ impl Database { header: header.clone(), _shared_page_cache: _shared_page_cache.clone(), _shared_wal: shared_wal.clone(), - syms, - vtab_modules: HashMap::new(), + syms: syms.clone(), }; if let Err(e) = db.register_builtins() { return Err(LimboError::ExtensionError(e)); @@ -169,7 +167,7 @@ impl Database { }); let rows = conn.query("SELECT * FROM sqlite_schema")?; let mut schema = schema.borrow_mut(); - parse_schema_rows(rows, &mut schema, io)?; + parse_schema_rows(rows, &mut schema, io, &syms.borrow())?; Ok(db) } @@ -276,10 +274,9 @@ impl Connection { pub fn prepare(self: &Rc, sql: impl AsRef) -> Result { let sql = sql.as_ref(); tracing::trace!("Preparing: {}", sql); - let db = &self.db; let mut parser = Parser::new(sql.as_bytes()); - let syms = &db.syms.borrow(); let cmd = parser.next()?; + let syms = self.db.syms.borrow(); if let Some(cmd) = cmd { match cmd { Cmd::Stmt(stmt) => { @@ -289,7 +286,7 @@ impl Connection { self.header.clone(), self.pager.clone(), Rc::downgrade(self), - syms, + &syms, QueryMode::Normal, )?); Ok(Statement::new(program, self.pager.clone())) @@ -315,7 +312,7 @@ impl Connection { pub(crate) fn run_cmd(self: &Rc, cmd: Cmd) -> Result> { let db = self.db.clone(); - let syms: &SymbolTable = &db.syms.borrow(); + let syms = db.syms.borrow(); match cmd { Cmd::Stmt(stmt) => { let program = Rc::new(translate::translate( @@ -324,7 +321,7 @@ impl Connection { self.header.clone(), self.pager.clone(), Rc::downgrade(self), - syms, + &syms, QueryMode::Normal, )?); let stmt = Statement::new(program, self.pager.clone()); @@ -337,7 +334,7 @@ impl Connection { self.header.clone(), self.pager.clone(), Rc::downgrade(self), - syms, + &syms, QueryMode::Explain, )?; program.explain(); @@ -346,12 +343,8 @@ impl Connection { Cmd::ExplainQueryPlan(stmt) => { match stmt { ast::Stmt::Select(select) => { - let mut plan = prepare_select_plan( - &self.schema.borrow(), - *select, - &self.db.syms.borrow(), - None, - )?; + let mut plan = + prepare_select_plan(&self.schema.borrow(), *select, &syms, None)?; optimize_plan(&mut plan, &self.schema.borrow())?; println!("{}", plan); } @@ -368,10 +361,9 @@ impl Connection { pub fn execute(self: &Rc, sql: impl AsRef) -> Result<()> { let sql = sql.as_ref(); - let db = &self.db; - let syms: &SymbolTable = &db.syms.borrow(); let mut parser = Parser::new(sql.as_bytes()); let cmd = parser.next()?; + let syms = self.db.syms.borrow(); if let Some(cmd) = cmd { match cmd { Cmd::Explain(stmt) => { @@ -381,7 +373,7 @@ impl Connection { self.header.clone(), self.pager.clone(), Rc::downgrade(self), - syms, + &syms, QueryMode::Explain, )?; program.explain(); @@ -394,7 +386,7 @@ impl Connection { self.header.clone(), self.pager.clone(), Rc::downgrade(self), - syms, + &syms, QueryMode::Normal, )?; @@ -524,14 +516,73 @@ pub type StepResult = vdbe::StepResult; #[derive(Clone, Debug)] pub struct VirtualTable { name: String, - args: Option>, + args: Option>, pub implementation: Rc, columns: Vec, } impl VirtualTable { + pub(crate) fn from_args( + tbl_name: Option<&str>, + module_name: &str, + args: &[String], + syms: &SymbolTable, + kind: VTabKind, + ) -> Result> { + let module = syms + .vtab_modules + .get(module_name) + .ok_or(LimboError::ExtensionError(format!( + "Virtual table module not found: {}", + module_name + )))?; + if let VTabKind::VirtualTable = kind { + if module.module_type != VTabKind::VirtualTable { + return Err(LimboError::ExtensionError(format!( + "Virtual table module {} is not a virtual table", + module_name + ))); + } + }; + let schema = module.implementation.as_ref().init_schema(args)?; + let mut parser = Parser::new(schema.as_bytes()); + parser.reset(schema.as_bytes()); + println!("Schema: {}", schema); + if let ast::Cmd::Stmt(ast::Stmt::CreateTable { body, .. }) = parser.next()?.ok_or( + LimboError::ParseError("Failed to parse schema from virtual table module".to_string()), + )? { + let columns = columns_from_create_table_body(&body)?; + let vtab = Rc::new(VirtualTable { + name: tbl_name.unwrap_or(module_name).to_owned(), + args: Some(args.to_vec()), + implementation: module.implementation.clone(), + columns, + }); + return Ok(vtab); + } + Err(crate::LimboError::ParseError( + "Failed to parse schema from virtual table module".to_string(), + )) + } + pub fn open(&self) -> VTabOpaqueCursor { - let cursor = unsafe { (self.implementation.open)() }; + let args = if let Some(args) = &self.args { + args.iter() + .map(|e| std::ffi::CString::new(e.to_string()).unwrap().into_raw()) + .collect() + } else { + Vec::new() + }; + let cursor = + unsafe { (self.implementation.open)(args.as_slice().as_ptr(), args.len() as i32) }; + // free the CString pointers + for arg in args { + unsafe { + if !arg.is_null() { + let _ = std::ffi::CString::from_raw(arg); + } + } + } VTabOpaqueCursor::new(cursor) } @@ -580,13 +631,51 @@ impl VirtualTable { _ => Err(LimboError::ExtensionError("Next failed".to_string())), } } + + pub fn update(&self, args: &[OwnedValue], rowid: Option) -> Result> { + let arg_count = args.len(); + let mut ext_args = Vec::with_capacity(arg_count); + for i in 0..arg_count { + let ownedvalue_arg = args.get(i).unwrap(); + let extvalue_arg: ExtValue = match ownedvalue_arg { + OwnedValue::Null => Ok(ExtValue::null()), + OwnedValue::Integer(i) => Ok(ExtValue::from_integer(*i)), + OwnedValue::Float(f) => Ok(ExtValue::from_float(*f)), + OwnedValue::Text(t) => Ok(ExtValue::from_text(t.as_str().to_string())), + OwnedValue::Blob(b) => Ok(ExtValue::from_blob((**b).clone())), + other => Err(LimboError::ExtensionError(format!( + "Unsupported value type: {:?}", + other + ))), + }?; + ext_args.push(extvalue_arg); + } + let rowid = rowid.unwrap_or(-1); + let newrowid = 0i64; + let implementation = self.implementation.as_ref(); + let rc = unsafe { + (self.implementation.update)( + implementation as *const VTabModuleImpl as *mut std::ffi::c_void, + arg_count as i32, + ext_args.as_ptr(), + rowid, + &newrowid as *const _ as *mut i64, + ) + }; + match rc { + ResultCode::OK => Ok(None), + ResultCode::RowID => Ok(Some(newrowid)), + _ => Err(LimboError::ExtensionError(rc.to_string())), + } + } } pub(crate) struct SymbolTable { pub functions: HashMap>, #[cfg(not(target_family = "wasm"))] extensions: Vec<(Library, *const ExtensionApi)>, - pub vtabs: HashMap, + pub vtabs: HashMap>, + pub vtab_modules: HashMap>, } impl std::fmt::Debug for SymbolTable { @@ -631,6 +720,7 @@ impl SymbolTable { vtabs: HashMap::new(), #[cfg(not(target_family = "wasm"))] extensions: Vec::new(), + vtab_modules: HashMap::new(), } } diff --git a/core/schema.rs b/core/schema.rs index 0395b2c28..884867066 100644 --- a/core/schema.rs +++ b/core/schema.rs @@ -12,29 +12,46 @@ use std::rc::Rc; use tracing::trace; pub struct Schema { - pub tables: HashMap>, + pub tables: HashMap>, // table_name to list of indexes for the table pub indexes: HashMap>>, } impl Schema { pub fn new() -> Self { - let mut tables: HashMap> = HashMap::new(); + let mut tables: HashMap> = HashMap::new(); let indexes: HashMap>> = HashMap::new(); - tables.insert("sqlite_schema".to_string(), Rc::new(sqlite_schema_table())); + tables.insert( + "sqlite_schema".to_string(), + Rc::new(Table::BTree(sqlite_schema_table().into())), + ); Self { tables, indexes } } - pub fn add_table(&mut self, table: Rc) { + pub fn add_btree_table(&mut self, table: Rc) { let name = normalize_ident(&table.name); - self.tables.insert(name, table); + self.tables.insert(name, Table::BTree(table).into()); } - pub fn get_table(&self, name: &str) -> Option> { + pub fn add_virtual_table(&mut self, table: Rc) { + let name = normalize_ident(&table.name); + self.tables.insert(name, Table::Virtual(table).into()); + } + + pub fn get_table(&self, name: &str) -> Option> { let name = normalize_ident(name); self.tables.get(&name).cloned() } + pub fn get_btree_table(&self, name: &str) -> Option> { + let name = normalize_ident(name); + if let Some(table) = self.tables.get(&name) { + table.btree() + } else { + None + } + } + pub fn add_index(&mut self, index: Rc) { let table_name = normalize_ident(&index.table_name); self.indexes diff --git a/core/translate/delete.rs b/core/translate/delete.rs index 6a55bfc03..81f8ba6ef 100644 --- a/core/translate/delete.rs +++ b/core/translate/delete.rs @@ -42,7 +42,10 @@ pub fn prepare_delete_plan( Some(table) => table, None => crate::bail_corrupt_error!("Parse error: no such table: {}", tbl_name), }; - + //if let Some(table) = table.virtual_table() { + // // TODO: emit VUpdate + //} + let table = table.btree().unwrap(); let table_references = vec![TableReference { table: Table::BTree(table.clone()), identifier: table.name.clone(), diff --git a/core/translate/insert.rs b/core/translate/insert.rs index eb36b7e75..5f933e93a 100644 --- a/core/translate/insert.rs +++ b/core/translate/insert.rs @@ -1,15 +1,15 @@ use std::ops::Deref; +use std::rc::Rc; use limbo_sqlite3_parser::ast::{ - DistinctNames, Expr, InsertBody, QualifiedName, ResolveType, ResultColumn, With, + DistinctNames, Expr, InsertBody, OneSelect, QualifiedName, ResolveType, ResultColumn, With, }; use crate::error::SQLITE_CONSTRAINT_PRIMARYKEY; -use crate::schema::BTreeTable; +use crate::schema::Table; use crate::util::normalize_ident; use crate::vdbe::builder::{ProgramBuilderOpts, QueryMode}; use crate::vdbe::BranchOffset; -use crate::Result; use crate::{ schema::{Column, Schema}, translate::expr::translate_expr, @@ -19,6 +19,7 @@ use crate::{ }, SymbolTable, }; +use crate::{Result, VirtualTable}; use super::emitter::Resolver; @@ -46,32 +47,45 @@ pub fn translate_insert( if on_conflict.is_some() { crate::bail_parse_error!("ON CONFLICT clause is not supported"); } + + let table_name = &tbl_name.name; + let table = match schema.get_table(table_name.0.as_str()) { + Some(table) => table, + None => crate::bail_corrupt_error!("Parse error: no such table: {}", table_name), + }; let resolver = Resolver::new(syms); + if let Some(virtual_table) = &table.virtual_table() { + translate_virtual_table_insert( + &mut program, + virtual_table.clone(), + columns, + body, + on_conflict, + &resolver, + ); + return Ok(program); + } let init_label = program.allocate_label(); program.emit_insn(Insn::Init { target_pc: init_label, }); let start_offset = program.offset(); - // open table - let table_name = &tbl_name.name; - - let table = match schema.get_table(table_name.0.as_str()) { - Some(table) => table, - None => crate::bail_corrupt_error!("Parse error: no such table: {}", table_name), + let Some(btree_table) = table.btree() else { + crate::bail_corrupt_error!("Parse error: no such table: {}", table_name); }; - if !table.has_rowid { + if !btree_table.has_rowid { crate::bail_parse_error!("INSERT into WITHOUT ROWID table is not supported"); } let cursor_id = program.alloc_cursor_id( Some(table_name.0.clone()), - CursorType::BTreeTable(table.clone()), + CursorType::BTreeTable(btree_table.clone()), ); - let root_page = table.root_page; + let root_page = btree_table.root_page; let values = match body { InsertBody::Select(select, None) => match &select.body.select.deref() { - limbo_sqlite3_parser::ast::OneSelect::Values(values) => values, + OneSelect::Values(values) => values, _ => todo!(), }, _ => todo!(), @@ -79,9 +93,9 @@ pub fn translate_insert( let column_mappings = resolve_columns_for_insert(&table, columns, values)?; // Check if rowid was provided (through INTEGER PRIMARY KEY as a rowid alias) - let rowid_alias_index = table.columns.iter().position(|c| c.is_rowid_alias); + let rowid_alias_index = btree_table.columns.iter().position(|c| c.is_rowid_alias); let has_user_provided_rowid = { - assert_eq!(column_mappings.len(), table.columns.len()); + assert_eq!(column_mappings.len(), btree_table.columns.len()); if let Some(index) = rowid_alias_index { column_mappings[index].value_index.is_some() } else { @@ -91,7 +105,7 @@ pub fn translate_insert( // allocate a register for each column in the table. if not provided by user, they will simply be set as null. // allocate an extra register for rowid regardless of whether user provided a rowid alias column. - let num_cols = table.columns.len(); + let num_cols = btree_table.columns.len(); let rowid_reg = program.alloc_registers(num_cols + 1); let column_registers_start = rowid_reg + 1; let rowid_alias_reg = { @@ -108,7 +122,7 @@ pub fn translate_insert( let inserting_multiple_rows = values.len() > 1; - // Multiple rows - use coroutine for value population + // multiple rows - use coroutine for value population if inserting_multiple_rows { let yield_reg = program.alloc_register(); let jump_on_definition_label = program.allocate_label(); @@ -217,7 +231,7 @@ pub fn translate_insert( target_pc: make_record_label, }); let rowid_column_name = if let Some(index) = rowid_alias_index { - &table + btree_table .columns .get(index) .unwrap() @@ -302,7 +316,7 @@ struct ColumnMapping<'a> { /// - Named columns map to their corresponding value index /// - Unspecified columns map to None fn resolve_columns_for_insert<'a>( - table: &'a BTreeTable, + table: &'a Table, columns: &Option, values: &[Vec], ) -> Result>> { @@ -310,7 +324,7 @@ fn resolve_columns_for_insert<'a>( crate::bail_parse_error!("no values to insert"); } - let table_columns = &table.columns; + let table_columns = &table.columns(); // Case 1: No columns specified - map values to columns in order if columns.is_none() { @@ -318,7 +332,7 @@ fn resolve_columns_for_insert<'a>( if num_values > table_columns.len() { crate::bail_parse_error!( "table {} has {} columns but {} values were supplied", - &table.name, + &table.get_name(), table_columns.len(), num_values ); @@ -361,7 +375,11 @@ fn resolve_columns_for_insert<'a>( }); if table_index.is_none() { - crate::bail_parse_error!("table {} has no column named {}", &table.name, column_name); + crate::bail_parse_error!( + "table {} has no column named {}", + &table.get_name(), + column_name + ); } mappings[table_index.unwrap()].value_index = Some(value_index); @@ -425,3 +443,95 @@ fn populate_column_registers( } Ok(()) } + +fn translate_virtual_table_insert( + program: &mut ProgramBuilder, + virtual_table: Rc, + columns: &Option, + body: &InsertBody, + on_conflict: &Option, + resolver: &Resolver, +) -> Result<()> { + let init_label = program.allocate_label(); + program.emit_insn(Insn::Init { + target_pc: init_label, + }); + let start_offset = program.offset(); + + let values = match body { + InsertBody::Select(select, None) => match &select.body.select.deref() { + OneSelect::Values(values) => values, + _ => crate::bail_parse_error!("Virtual tables only support VALUES clause in INSERT"), + }, + InsertBody::DefaultValues => &vec![], + _ => crate::bail_parse_error!("Unsupported INSERT body for virtual tables"), + }; + + let table = Table::Virtual(virtual_table.clone()); + let column_mappings = resolve_columns_for_insert(&table, columns, values)?; + + let value_registers_start = program.alloc_registers(values[0].len()); + for (i, expr) in values[0].iter().enumerate() { + translate_expr(program, None, expr, value_registers_start + i, resolver)?; + } + + let start_reg = program.alloc_registers(column_mappings.len() + 3); + let rowid_reg = start_reg; // argv[0] = rowid + let insert_rowid_reg = start_reg + 1; // argv[1] = insert_rowid + let data_start_reg = start_reg + 2; // argv[2..] = column values + + program.emit_insn(Insn::Null { + dest: rowid_reg, + dest_end: None, + }); + program.emit_insn(Insn::Null { + dest: insert_rowid_reg, + dest_end: None, + }); + + for (i, mapping) in column_mappings.iter().enumerate() { + let target_reg = data_start_reg + i; + if let Some(value_index) = mapping.value_index { + program.emit_insn(Insn::Copy { + src_reg: value_registers_start + value_index, + dst_reg: target_reg, + amount: 1, + }); + } else { + program.emit_insn(Insn::Null { + dest: target_reg, + dest_end: None, + }); + } + } + + let conflict_action = on_conflict.as_ref().map(|c| c.bit_value()).unwrap_or(0) as u16; + + let cursor_id = program.alloc_cursor_id( + Some(virtual_table.name.clone()), + CursorType::VirtualTable(virtual_table.clone()), + ); + + program.emit_insn(Insn::VUpdate { + cursor_id, + arg_count: column_mappings.len() + 2, + start_reg, + vtab_ptr: virtual_table.implementation.as_ref().ctx as usize, + conflict_action, + }); + + let halt_label = program.allocate_label(); + program.emit_insn(Insn::Halt { + err_code: 0, + description: String::new(), + }); + + program.resolve_label(halt_label, program.offset()); + program.resolve_label(init_label, program.offset()); + + program.emit_insn(Insn::Goto { + target_pc: start_offset, + }); + + Ok(()) +} diff --git a/core/translate/main_loop.rs b/core/translate/main_loop.rs index a9fa9a158..a0e4a13c4 100644 --- a/core/translate/main_loop.rs +++ b/core/translate/main_loop.rs @@ -293,10 +293,32 @@ pub fn open_loop( }; let start_reg = program.alloc_registers(args.len()); let mut cur_reg = start_reg; - for arg in args { + + for arg_str in args { let reg = cur_reg; cur_reg += 1; - translate_expr(program, Some(tables), &arg, reg, &t_ctx.resolver)?; + + if let Ok(i) = arg_str.parse::() { + program.emit_insn(Insn::Integer { + value: i, + dest: reg, + }); + } else if let Ok(f) = arg_str.parse::() { + program.emit_insn(Insn::Real { + value: f, + dest: reg, + }); + } else if arg_str.starts_with('"') && arg_str.ends_with('"') { + program.emit_insn(Insn::String8 { + value: arg_str.trim_matches('"').to_string(), + dest: reg, + }); + } else { + program.emit_insn(Insn::String8 { + value: arg_str.clone(), + dest: reg, + }); + } } program.emit_insn(Insn::VFilter { cursor_id, diff --git a/core/translate/mod.rs b/core/translate/mod.rs index 3ee8f4ce0..7df6258ec 100644 --- a/core/translate/mod.rs +++ b/core/translate/mod.rs @@ -33,8 +33,7 @@ use crate::vdbe::builder::{CursorType, ProgramBuilderOpts, QueryMode}; use crate::vdbe::{builder::ProgramBuilder, insn::Insn, Program}; use crate::{bail_parse_error, Connection, LimboError, Result, SymbolTable}; use insert::translate_insert; -use limbo_sqlite3_parser::ast::{self, fmt::ToTokens}; -use limbo_sqlite3_parser::ast::{Delete, Insert}; +use limbo_sqlite3_parser::ast::{self, fmt::ToTokens, CreateVirtualTable, Delete, Insert}; use select::translate_select; use std::cell::RefCell; use std::fmt::Display; @@ -74,8 +73,8 @@ pub fn translate( } ast::Stmt::CreateTrigger { .. } => bail_parse_error!("CREATE TRIGGER not supported yet"), ast::Stmt::CreateView { .. } => bail_parse_error!("CREATE VIEW not supported yet"), - ast::Stmt::CreateVirtualTable { .. } => { - bail_parse_error!("CREATE VIRTUAL TABLE not supported yet") + ast::Stmt::CreateVirtualTable(vtab) => { + translate_create_virtual_table(*vtab, schema, query_mode)? } ast::Stmt::Delete(delete) => { let Delete { @@ -94,7 +93,7 @@ pub fn translate( ast::Stmt::DropView { .. } => bail_parse_error!("DROP VIEW not supported yet"), ast::Stmt::Pragma(name, body) => pragma::translate_pragma( query_mode, - &schema, + schema, &name, body.map(|b| *b), database_header.clone(), @@ -177,6 +176,7 @@ addr opcode p1 p2 p3 p4 p5 comment enum SchemaEntryType { Table, Index, + Virtual, } impl SchemaEntryType { @@ -184,9 +184,11 @@ impl SchemaEntryType { match self { SchemaEntryType::Table => "table", SchemaEntryType::Index => "index", + SchemaEntryType::Virtual => "virtual", } } } +const SQLITE_TABLEID: &str = "sqlite_schema"; fn emit_schema_entry( program: &mut ProgramBuilder, @@ -209,11 +211,18 @@ fn emit_schema_entry( program.emit_string8_new_reg(tbl_name.to_string()); let rootpage_reg = program.alloc_register(); - program.emit_insn(Insn::Copy { - src_reg: root_page_reg, - dst_reg: rootpage_reg, - amount: 1, - }); + if matches!(entry_type, SchemaEntryType::Virtual) { + program.emit_insn(Insn::Integer { + dest: rootpage_reg, + value: 0, // virtual tables in sqlite always have rootpage=0 + }); + } else { + program.emit_insn(Insn::Copy { + src_reg: root_page_reg, + dst_reg: rootpage_reg, + amount: 1, + }); + } let sql_reg = program.alloc_register(); if let Some(sql) = sql { @@ -455,10 +464,9 @@ fn translate_create_table( }); } - let table_id = "sqlite_schema".to_string(); - let table = schema.get_table(&table_id).unwrap(); + let table = schema.get_btree_table(SQLITE_TABLEID).unwrap(); let sqlite_schema_cursor_id = program.alloc_cursor_id( - Some(table_id.to_owned()), + Some(SQLITE_TABLEID.to_owned()), CursorType::BTreeTable(table.clone()), ); program.emit_insn(Insn::OpenWriteAsync { @@ -546,3 +554,132 @@ fn create_table_body_to_str(tbl_name: &ast::QualifiedName, body: &ast::CreateTab } sql } + +fn create_vtable_body_to_str(vtab: &CreateVirtualTable) -> String { + let args = if let Some(args) = &vtab.args { + args.iter() + .map(|arg| arg.to_string()) + .collect::>() + .join(", ") + } else { + "".to_string() + }; + let if_not_exists = if vtab.if_not_exists { + "IF NOT EXISTS " + } else { + "" + }; + format!( + "CREATE VIRTUAL TABLE {} {} USING {}{}", + vtab.tbl_name.name.0, + vtab.module_name.0, + if_not_exists, + if args.is_empty() { + String::new() + } else { + format!("({})", args) + } + ) +} + +fn translate_create_virtual_table( + vtab: CreateVirtualTable, + schema: &Schema, + query_mode: QueryMode, +) -> Result { + let ast::CreateVirtualTable { + if_not_exists, + tbl_name, + module_name, + args, + } = &vtab; + + let table_name = tbl_name.name.0.clone(); + let module_name_str = module_name.0.clone(); + let args_vec = args.clone().unwrap_or_default(); + + if schema.get_table(&table_name).is_some() && *if_not_exists { + let mut program = ProgramBuilder::new(ProgramBuilderOpts { + query_mode, + num_cursors: 1, + approx_num_insns: 5, + approx_num_labels: 1, + }); + let init_label = program.emit_init(); + program.emit_halt(); + program.resolve_label(init_label, program.offset()); + program.emit_transaction(true); + program.emit_constant_insns(); + return Ok(program); + } + + let mut program = ProgramBuilder::new(ProgramBuilderOpts { + query_mode, + num_cursors: 2, + approx_num_insns: 40, + approx_num_labels: 2, + }); + + let module_name_reg = program.emit_string8_new_reg(module_name_str.clone()); + let table_name_reg = program.emit_string8_new_reg(table_name.clone()); + + let args_reg = if !args_vec.is_empty() { + let args_start = program.alloc_register(); + for (i, arg) in args_vec.iter().enumerate() { + program.emit_string8(arg.clone(), args_start + i); + } + let args_record_reg = program.alloc_register(); + program.emit_insn(Insn::MakeRecord { + start_reg: args_start, + count: args_vec.len(), + dest_reg: args_record_reg, + }); + Some(args_record_reg) + } else { + None + }; + + program.emit_insn(Insn::VCreate { + module_name: module_name_reg, + table_name: table_name_reg, + args_reg, + }); + + let table = schema.get_btree_table(SQLITE_TABLEID).unwrap(); + let sqlite_schema_cursor_id = program.alloc_cursor_id( + Some(SQLITE_TABLEID.to_owned()), + CursorType::BTreeTable(table.clone()), + ); + program.emit_insn(Insn::OpenWriteAsync { + cursor_id: sqlite_schema_cursor_id, + root_page: 1, + }); + program.emit_insn(Insn::OpenWriteAwait {}); + + let sql = create_vtable_body_to_str(&vtab); + emit_schema_entry( + &mut program, + sqlite_schema_cursor_id, + SchemaEntryType::Virtual, + &tbl_name.name.0, + &tbl_name.name.0, + 0, // virtual tables dont have a root page + Some(sql), + ); + + let parse_schema_where_clause = format!("tbl_name = '{}' AND type != 'trigger'", table_name); + program.emit_insn(Insn::ParseSchema { + db: sqlite_schema_cursor_id, + where_clause: parse_schema_where_clause, + }); + + let init_label = program.emit_init(); + let start_offset = program.offset(); + program.emit_halt(); + program.resolve_label(init_label, program.offset()); + program.emit_transaction(true); + program.emit_constant_insns(); + program.emit_goto(start_offset); + + Ok(program) +} diff --git a/core/translate/planner.rs b/core/translate/planner.rs index 138d1cbc0..c9c266a13 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -11,7 +11,7 @@ use crate::{ schema::{Schema, Table}, util::{exprs_are_equivalent, normalize_ident}, vdbe::BranchOffset, - Result, VirtualTable, + Result, }; use limbo_sqlite3_parser::ast::{ self, Expr, FromClause, JoinType, Limit, Materialized, UnaryOperator, With, @@ -303,7 +303,7 @@ fn parse_from_clause_table<'a>( return Ok(()); }; // Check if our top level schema has this table. - if let Some(table) = schema.get_table(&normalized_qualified_name) { + if let Some(table) = schema.get_btree_table(&normalized_qualified_name) { let alias = maybe_alias .map(|a| match a { ast::As::As(id) => id, @@ -369,9 +369,16 @@ fn parse_from_clause_table<'a>( } ast::SelectTable::TableCall(qualified_name, maybe_args, maybe_alias) => { let normalized_name = &normalize_ident(qualified_name.name.0.as_str()); - let Some(vtab) = syms.vtabs.get(normalized_name) else { - crate::bail_parse_error!("Virtual table {} not found", normalized_name); - }; + let vtab = crate::VirtualTable::from_args( + None, + normalized_name, + &maybe_args + .as_ref() + .map(|a| a.iter().map(|s| s.to_string()).collect::>()) + .unwrap_or_default(), + syms, + limbo_ext::VTabKind::TableValuedFunction, + )?; let alias = maybe_alias .as_ref() .map(|a| match a { @@ -383,18 +390,10 @@ fn parse_from_clause_table<'a>( scope.tables.push(TableReference { op: Operation::Scan { iter_dir: None }, join_info: None, - table: Table::Virtual( - VirtualTable { - name: normalized_name.clone(), - args: maybe_args, - implementation: vtab.implementation.clone(), - columns: vtab.columns.clone(), - } - .into(), - ) - .into(), - identifier: alias.clone(), + table: Table::Virtual(vtab), + identifier: alias, }); + Ok(()) } _ => todo!(), @@ -611,7 +610,7 @@ fn parse_join<'a>( constraint, } = join; - parse_from_clause_table(schema, table, scope, syms)?; + parse_from_clause_table(schema, table, scope, &syms)?; let (outer, natural) = match join_operator { ast::JoinOperator::TypedJoin(Some(join_type)) => { diff --git a/core/translate/pragma.rs b/core/translate/pragma.rs index 1fd738acd..b33f38011 100644 --- a/core/translate/pragma.rs +++ b/core/translate/pragma.rs @@ -218,7 +218,7 @@ fn query_pragma( program.alloc_register(); program.alloc_register(); if let Some(table) = table { - for (i, column) in table.columns.iter().enumerate() { + for (i, column) in table.columns().iter().enumerate() { // cid program.emit_int(i as i64, base_reg); // name diff --git a/core/util.rs b/core/util.rs index 2ee09189d..ddaf931a9 100644 --- a/core/util.rs +++ b/core/util.rs @@ -3,7 +3,7 @@ use std::{rc::Rc, sync::Arc}; use crate::{ schema::{self, Column, Schema, Type}, - Result, Statement, StepResult, IO, + Result, Statement, StepResult, SymbolTable, IO, }; // https://sqlite.org/lang_keywords.html @@ -28,6 +28,7 @@ pub fn parse_schema_rows( rows: Option, schema: &mut Schema, io: Arc, + syms: &SymbolTable, ) -> Result<()> { if let Some(mut rows) = rows { let mut automatic_indexes = Vec::new(); @@ -36,7 +37,7 @@ pub fn parse_schema_rows( StepResult::Row => { let row = rows.row().unwrap(); let ty = row.get::<&str>(0)?; - if ty != "table" && ty != "index" { + if !["table", "index", "virtual"].contains(&ty) { continue; } match ty { @@ -44,7 +45,12 @@ pub fn parse_schema_rows( let root_page: i64 = row.get::(3)?; let sql: &str = row.get::<&str>(4)?; let table = schema::BTreeTable::from_sql(sql, root_page as usize)?; - schema.add_table(Rc::new(table)); + schema.add_btree_table(Rc::new(table)); + } + "virtual" => { + let name: &str = row.get::<&str>(1)?; + let vtab = syms.vtabs.get(name).unwrap().clone(); + schema.add_virtual_table(vtab); } "index" => { let root_page: i64 = row.get::(3)?; @@ -83,7 +89,7 @@ pub fn parse_schema_rows( } for (index_name, table_name, root_page) in automatic_indexes { // We need to process these after all tables are loaded into memory due to the schema.get_table() call - let table = schema.get_table(&table_name).unwrap(); + let table = schema.get_btree_table(&table_name).unwrap(); let index = schema::Index::automatic_from_primary_key(&table, &index_name, root_page as usize)?; schema.add_index(Rc::new(index)); @@ -307,9 +313,11 @@ pub fn exprs_are_equivalent(expr1: &Expr, expr2: &Expr) -> bool { } } -pub fn columns_from_create_table_body(body: ast::CreateTableBody) -> Result, ()> { +pub fn columns_from_create_table_body(body: &ast::CreateTableBody) -> crate::Result> { let CreateTableBody::ColumnsAndConstraints { columns, .. } = body else { - return Err(()); + return Err(crate::LimboError::ParseError( + "CREATE TABLE body must contain columns and constraints".to_string(), + )); }; Ok(columns @@ -322,7 +330,7 @@ pub fn columns_from_create_table_body(body: ast::CreateTableBody) -> Result { // https://www.sqlite.org/datatype3.html diff --git a/core/vdbe/builder.rs b/core/vdbe/builder.rs index 716610b71..470997158 100644 --- a/core/vdbe/builder.rs +++ b/core/vdbe/builder.rs @@ -50,7 +50,7 @@ impl CursorType { } } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Copy)] pub enum QueryMode { Normal, Explain, diff --git a/core/vdbe/explain.rs b/core/vdbe/explain.rs index 2b61310d5..a609fc667 100644 --- a/core/vdbe/explain.rs +++ b/core/vdbe/explain.rs @@ -381,6 +381,19 @@ pub fn insn_to_str( 0, "".to_string(), ), + Insn::VCreate { + table_name, + module_name, + args_reg, + } => ( + "VCreate", + *table_name as i32, + *module_name as i32, + args_reg.unwrap_or(0) as i32, + OwnedValue::build_text(""), + 0, + format!("table={}, module={}", table_name, module_name), + ), Insn::VFilter { cursor_id, pc_if_empty, @@ -408,6 +421,21 @@ pub fn insn_to_str( 0, "".to_string(), ), + Insn::VUpdate { + cursor_id, + arg_count, // P2: Number of arguments in argv[] + start_reg, // P3: Start register for argv[] + vtab_ptr, // P4: vtab pointer + conflict_action, // P5: Conflict resolution flags + } => ( + "VUpdate", + *cursor_id as i32, + *arg_count as i32, + *start_reg as i32, + OwnedValue::build_text(&format!("vtab:{}", vtab_ptr)), + *conflict_action, + format!("args=r[{}..{}]", start_reg, start_reg + arg_count - 1), + ), Insn::VNext { cursor_id, pc_if_next, diff --git a/core/vdbe/insn.rs b/core/vdbe/insn.rs index 1fad2c479..d6b25046c 100644 --- a/core/vdbe/insn.rs +++ b/core/vdbe/insn.rs @@ -220,6 +220,13 @@ pub enum Insn { /// Await for the completion of open cursor for a virtual table. VOpenAwait, + /// Create a new virtual table. + VCreate { + module_name: usize, // P1: Name of the module that contains the virtual table implementation + table_name: usize, // P2: Name of the virtual table + args_reg: Option, + }, + /// Initialize the position of the virtual table cursor. VFilter { cursor_id: CursorID, @@ -235,6 +242,15 @@ pub enum Insn { dest: usize, }, + /// `VUpdate`: Virtual Table Insert/Update/Delete Instruction + VUpdate { + cursor_id: usize, // P1: Virtual table cursor number + arg_count: usize, // P2: Number of arguments in argv[] + start_reg: usize, // P3: Start register for argv[] + vtab_ptr: usize, // P4: vtab pointer + conflict_action: u16, // P5: Conflict resolution flags + }, + /// Advance the virtual table cursor to the next row. /// TODO: async VNext { diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index a81a1c687..a881c0d56 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -873,6 +873,46 @@ impl Program { .insert(*cursor_id, Some(Cursor::Virtual(cursor))); state.pc += 1; } + Insn::VCreate { + module_name, + table_name, + args_reg, + } => { + let module_name = state.registers[*module_name].to_string(); + let table_name = state.registers[*table_name].to_string(); + let args = if let Some(args_reg) = args_reg { + if let OwnedValue::Record(rec) = &state.registers[*args_reg] { + rec.get_values().iter().map(|v| v.to_string()).collect() + } else { + return Err(LimboError::InternalError( + "VCreate: args_reg is not a record".to_string(), + )); + } + } else { + vec![] + }; + let Some(conn) = self.connection.upgrade() else { + return Err(crate::LimboError::ExtensionError( + "Failed to upgrade Connection".to_string(), + )); + }; + let table = crate::VirtualTable::from_args( + Some(&table_name), + &module_name, + &args, + &conn.db.syms.borrow(), + limbo_ext::VTabKind::VirtualTable, + )?; + { + conn.db + .syms + .as_ref() + .borrow_mut() + .vtabs + .insert(table_name, table.clone()); + } + state.pc += 1; + } Insn::VOpenAwait => { state.pc += 1; } @@ -913,6 +953,68 @@ impl Program { state.registers[*dest] = virtual_table.column(cursor, *column)?; state.pc += 1; } + Insn::VUpdate { + cursor_id, + arg_count, + start_reg, + conflict_action, + .. + } => { + let (_, cursor_type) = self.cursor_ref.get(*cursor_id).unwrap(); + let CursorType::VirtualTable(virtual_table) = cursor_type else { + panic!("VUpdate on non-virtual table cursor"); + }; + + if *arg_count < 2 { + return Err(LimboError::InternalError( + "VUpdate: arg_count must be at least 2 (rowid and insert_rowid)" + .to_string(), + )); + } + + let mut argv = Vec::with_capacity(*arg_count); + for i in 0..*arg_count { + if let Some(value) = state.registers.get(*start_reg + i) { + argv.push(value.clone()); + } else { + return Err(LimboError::InternalError(format!( + "VUpdate: register out of bounds at {}", + *start_reg + i + ))); + } + } + + let current_rowid = match argv.first() { + Some(OwnedValue::Integer(rowid)) => Some(*rowid), + _ => None, + }; + let insert_rowid = match argv.get(1) { + Some(OwnedValue::Integer(rowid)) => Some(*rowid), + _ => None, + }; + + let result = virtual_table.update(&argv, insert_rowid); + + match result { + Ok(Some(new_rowid)) => { + if *conflict_action == 5 { + if let Some(conn) = self.connection.upgrade() { + conn.update_last_rowid(new_rowid as u64); + } + } + state.pc += 1; + } + Ok(None) => { + state.pc += 1; + } + Err(e) => { + return Err(LimboError::ExtensionError(format!( + "Virtual table update failed: {}", + e + ))); + } + } + } Insn::VNext { cursor_id, pc_if_next, @@ -2724,8 +2826,13 @@ impl Program { where_clause ))?; let mut schema = RefCell::borrow_mut(&conn.schema); - // TODO: This function below is synchronous, make it not async - parse_schema_rows(Some(stmt), &mut schema, conn.pager.io.clone())?; + // TODO: This function below is synchronous, make it async + parse_schema_rows( + Some(stmt), + &mut schema, + conn.pager.io.clone(), + &conn.db.syms.borrow(), + )?; state.pc += 1; } Insn::ReadCookie { db, dest, cookie } => { diff --git a/extensions/core/src/lib.rs b/extensions/core/src/lib.rs index d06340aa2..2b123a3a2 100644 --- a/extensions/core/src/lib.rs +++ b/extensions/core/src/lib.rs @@ -6,34 +6,20 @@ use std::{ }; pub use types::{ResultCode, Value, ValueType}; +pub type ExtResult = std::result::Result; + #[repr(C)] pub struct ExtensionApi { pub ctx: *mut c_void, pub register_scalar_function: RegisterScalarFn, pub register_aggregate_function: RegisterAggFn, pub register_module: RegisterModuleFn, - pub declare_vtab: DeclareVTabFn, -} - -impl ExtensionApi { - pub fn declare_virtual_table(&self, name: &str, sql: &str) -> ResultCode { - let Ok(name) = std::ffi::CString::new(name) else { - return ResultCode::Error; - }; - let Ok(sql) = std::ffi::CString::new(sql) else { - return ResultCode::Error; - }; - unsafe { (self.declare_vtab)(self.ctx, name.as_ptr(), sql.as_ptr()) } - } } pub type ExtensionEntryPoint = unsafe extern "C" fn(api: *const ExtensionApi) -> ResultCode; pub type ScalarFunction = unsafe extern "C" fn(argc: i32, *const Value) -> Value; -pub type DeclareVTabFn = - unsafe extern "C" fn(ctx: *mut c_void, name: *const c_char, sql: *const c_char) -> ResultCode; - pub type RegisterScalarFn = unsafe extern "C" fn(ctx: *mut c_void, name: *const c_char, func: ScalarFunction) -> ResultCode; @@ -50,6 +36,7 @@ pub type RegisterModuleFn = unsafe extern "C" fn( ctx: *mut c_void, name: *const c_char, module: VTabModuleImpl, + kind: VTabKind, ) -> ResultCode; pub type InitAggFunction = unsafe extern "C" fn() -> *mut AggCtx; @@ -74,18 +61,39 @@ pub trait AggFunc { #[repr(C)] #[derive(Clone, Debug)] pub struct VTabModuleImpl { + pub ctx: *mut c_void, pub name: *const c_char, - pub connect: VtabFnConnect, + pub create_schema: VtabFnCreateSchema, pub open: VtabFnOpen, pub filter: VtabFnFilter, pub column: VtabFnColumn, pub next: VtabFnNext, pub eof: VtabFnEof, + pub update: VtabFnUpdate, } -pub type VtabFnConnect = unsafe extern "C" fn(api: *const c_void) -> ResultCode; +impl VTabModuleImpl { + pub fn init_schema(&self, args: &[String]) -> ExtResult { + let c_args = args + .iter() + .map(|s| std::ffi::CString::new(s.as_bytes()).unwrap().into_raw()) + .collect::>(); + let schema = unsafe { (self.create_schema)(c_args.as_ptr(), c_args.len() as i32) }; + c_args.into_iter().for_each(|s| unsafe { + let _ = std::ffi::CString::from_raw(s); + }); + if schema.is_null() { + return Err(ResultCode::InvalidArgs); + } + let schema = unsafe { std::ffi::CString::from_raw(schema) }; + Ok(schema.to_string_lossy().to_string()) + } +} -pub type VtabFnOpen = unsafe extern "C" fn() -> *mut c_void; +pub type VtabFnCreateSchema = + unsafe extern "C" fn(args: *const *mut c_char, argc: i32) -> *mut c_char; + +pub type VtabFnOpen = unsafe extern "C" fn(args: *const *mut c_char, argc: i32) -> *mut c_void; pub type VtabFnFilter = unsafe extern "C" fn(cursor: *mut c_void, argc: i32, argv: *const Value) -> ResultCode; @@ -96,17 +104,34 @@ pub type VtabFnNext = unsafe extern "C" fn(cursor: *mut c_void) -> ResultCode; pub type VtabFnEof = unsafe extern "C" fn(cursor: *mut c_void) -> bool; +pub type VtabFnUpdate = unsafe extern "C" fn( + vtab: *mut c_void, + argc: i32, + argv: *const Value, + rowid: i64, + p_out_rowid: *mut i64, +) -> ResultCode; + +#[repr(C)] +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum VTabKind { + VirtualTable, + TableValuedFunction, +} + pub trait VTabModule: 'static { type VCursor: VTabCursor; + const VTAB_KIND: VTabKind; const NAME: &'static str; type Error: std::fmt::Display; - fn init_sql() -> &'static str; - fn open() -> Result; + fn create_schema(args: &[String]) -> String; + fn open(args: &[String]) -> Result; fn filter(cursor: &mut Self::VCursor, arg_count: i32, args: &[Value]) -> ResultCode; fn column(cursor: &Self::VCursor, idx: u32) -> Result; fn next(cursor: &mut Self::VCursor) -> ResultCode; fn eof(cursor: &Self::VCursor) -> bool; + fn update(&mut self, args: &[Value], rowid: Option) -> Result, Self::Error>; } pub trait VTabCursor: Sized { @@ -116,8 +141,3 @@ pub trait VTabCursor: Sized { fn eof(&self) -> bool; fn next(&mut self) -> ResultCode; } - -#[repr(C)] -pub struct VTabImpl { - pub module: VTabModuleImpl, -} diff --git a/extensions/core/src/types.rs b/extensions/core/src/types.rs index c29768d7f..55d39ac42 100644 --- a/extensions/core/src/types.rs +++ b/extensions/core/src/types.rs @@ -21,6 +21,8 @@ pub enum ResultCode { Unavailable = 13, CustomError = 14, EOF = 15, + ReadOnly = 16, + RowID = 17, } impl ResultCode { @@ -52,6 +54,8 @@ impl Display for ResultCode { ResultCode::Unavailable => write!(f, "Unavailable"), ResultCode::CustomError => write!(f, "Error "), ResultCode::EOF => write!(f, "EOF"), + ResultCode::ReadOnly => write!(f, "Read Only"), + ResultCode::RowID => write!(f, "RowID"), } } } diff --git a/extensions/kvstore/Cargo.toml b/extensions/kvstore/Cargo.toml new file mode 100644 index 000000000..81c1f804d --- /dev/null +++ b/extensions/kvstore/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "limbo_kv" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true + +[lib] +crate-type = ["cdylib", "lib"] + +[features] +static= [ "limbo_ext/static" ] + +[dependencies] +limbo_ext = { workspace = true, features = ["static"] } + +[target.'cfg(not(target_family = "wasm"))'.dependencies] +mimalloc = { version = "*", default-features = false } diff --git a/extensions/kvstore/src/lib.rs b/extensions/kvstore/src/lib.rs new file mode 100644 index 000000000..184467aca --- /dev/null +++ b/extensions/kvstore/src/lib.rs @@ -0,0 +1,103 @@ +use limbo_ext::{ + register_extension, ResultCode, VTabCursor, VTabKind, VTabModule, VTabModuleDerive, Value, +}; +use std::collections::HashMap; + +register_extension! { + vtabs: { KVStoreVTab }, +} + +#[derive(VTabModuleDerive, Default)] +pub struct KVStoreVTab { + store: HashMap, +} + +pub struct KVStoreCursor { + keys: Vec, + values: Vec, + index: usize, +} + +impl VTabModule for KVStoreVTab { + type VCursor = KVStoreCursor; + const VTAB_KIND: VTabKind = VTabKind::VirtualTable; + const NAME: &'static str = "kv_store"; + type Error = String; + + fn create_schema(_args: &[String]) -> String { + "CREATE TABLE x (key TEXT PRIMARY KEY, value TEXT);".to_string() + } + + fn open(_args: &[String]) -> Result { + Ok(KVStoreCursor { + keys: Vec::new(), + values: Vec::new(), + index: 0, + }) + } + + fn filter(cursor: &mut Self::VCursor, _arg_count: i32, _args: &[Value]) -> ResultCode { + cursor.index = 0; + ResultCode::OK + } + + fn column(cursor: &Self::VCursor, idx: u32) -> Result { + match idx { + 0 => Ok(Value::from_text(cursor.keys[cursor.index].clone())), + 1 => Ok(Value::from_text(cursor.values[cursor.index].clone())), + _ => Err("Invalid column".into()), + } + } + + fn next(cursor: &mut Self::VCursor) -> ResultCode { + cursor.index += 1; + ResultCode::OK + } + + fn eof(cursor: &Self::VCursor) -> bool { + cursor.index >= cursor.keys.len() + } + + fn update(&mut self, args: &[Value], rowid: Option) -> Result, Self::Error> { + match args.len() { + 1 => { + let key = args[0].to_text().ok_or("Invalid key")?; + // Handle DELETE + self.store.remove(key); + Ok(None) + } + 2 => { + let key = args[0].to_text().ok_or("Invalid key")?; + let value = args[1].to_text().ok_or("Invalid value")?; + // Handle INSERT / UPDATE + self.store.insert(key.to_string(), value.to_string()); + Ok(Some(rowid.unwrap_or(0))) + } + _ => { + println!("args: {:?}", args); + Err("Invalid arguments for update".into()) + } + } + } +} + +impl VTabCursor for KVStoreCursor { + type Error = String; + fn rowid(&self) -> i64 { + self.index as i64 + } + fn column(&self, idx: u32) -> Result { + match idx { + 0 => Ok(Value::from_text(self.keys[self.index].clone())), + 1 => Ok(Value::from_text(self.values[self.index].clone())), + _ => Err("Invalid column".into()), + } + } + fn eof(&self) -> bool { + self.index >= self.keys.len() + } + fn next(&mut self) -> ResultCode { + self.index += 1; + ResultCode::OK + } +} diff --git a/extensions/series/src/lib.rs b/extensions/series/src/lib.rs index 6e86f0c93..161bfe886 100644 --- a/extensions/series/src/lib.rs +++ b/extensions/series/src/lib.rs @@ -1,4 +1,6 @@ -use limbo_ext::{register_extension, ResultCode, VTabCursor, VTabModule, VTabModuleDerive, Value}; +use limbo_ext::{ + register_extension, ResultCode, VTabCursor, VTabKind, VTabModule, VTabModuleDerive, Value, +}; register_extension! { vtabs: { GenerateSeriesVTab } @@ -14,16 +16,16 @@ macro_rules! try_option { } /// A virtual table that generates a sequence of integers -#[derive(Debug, VTabModuleDerive)] +#[derive(Debug, VTabModuleDerive, Default)] struct GenerateSeriesVTab; impl VTabModule for GenerateSeriesVTab { type VCursor = GenerateSeriesCursor; type Error = ResultCode; - const NAME: &'static str = "generate_series"; + const VTAB_KIND: VTabKind = VTabKind::TableValuedFunction; - fn init_sql() -> &'static str { + fn create_schema(_args: &[String]) -> String { // Create table schema "CREATE TABLE generate_series( value INTEGER, @@ -31,9 +33,10 @@ impl VTabModule for GenerateSeriesVTab { stop INTEGER HIDDEN, step INTEGER HIDDEN )" + .into() } - fn open() -> Result { + fn open(_args: &[String]) -> Result { Ok(GenerateSeriesCursor { start: 0, stop: 0, @@ -88,6 +91,10 @@ impl VTabModule for GenerateSeriesVTab { fn eof(cursor: &Self::VCursor) -> bool { cursor.eof() } + + fn update(&mut self, _args: &[Value], _rowid: Option) -> Result, Self::Error> { + Ok(None) + } } /// The cursor for iterating over the generated sequence diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 089579081..1f221b03f 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -341,7 +341,7 @@ pub fn derive_agg_func(input: TokenStream) -> TokenStream { /// const NAME: &'static str = "csv_data"; /// /// /// Declare the schema for your virtual table -/// fn init_sql() -> &'static str { +/// fn create_schema(args: &[&str]) -> &'static str { /// let sql = "CREATE TABLE csv_data( /// name TEXT, /// age TEXT, @@ -382,6 +382,12 @@ pub fn derive_agg_func(input: TokenStream) -> TokenStream { /// fn eof(cursor: &Self::VCursor) -> bool { /// cursor.index >= cursor.rows.len() /// } +/// +/// /// Update the row with the provided values, return the new rowid if provided +/// fn update(&mut self, args: &[Value], rowid: Option) -> Result, Self::Error> { +/// Ok(None)// return Ok(None) for read-only +/// } +/// /// #[derive(Debug)] /// struct CsvCursor { /// rows: Vec>, @@ -389,7 +395,7 @@ pub fn derive_agg_func(input: TokenStream) -> TokenStream { /// /// impl CsvCursor { /// /// Returns the value for a given column index. -/// fn column(&self, idx: u32) -> Value { +/// fn column(&self, idx: u32) -> Result { /// let row = &self.rows[self.index]; /// if (idx as usize) < row.len() { /// Value::from_text(&row[idx as usize]) @@ -418,31 +424,45 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { let struct_name = &ast.ident; let register_fn_name = format_ident!("register_{}", struct_name); - let connect_fn_name = format_ident!("connect_{}", struct_name); + let create_schema_fn_name = format_ident!("create_schema_{}", struct_name); let open_fn_name = format_ident!("open_{}", struct_name); let filter_fn_name = format_ident!("filter_{}", struct_name); let column_fn_name = format_ident!("column_{}", struct_name); let next_fn_name = format_ident!("next_{}", struct_name); let eof_fn_name = format_ident!("eof_{}", struct_name); + let update_fn_name = format_ident!("update_{}", struct_name); let expanded = quote! { impl #struct_name { #[no_mangle] - unsafe extern "C" fn #connect_fn_name( - db: *const ::std::ffi::c_void - ) -> ::limbo_ext::ResultCode { - let api = &*(db as *const ::limbo_ext::ExtensionApi); - let sql = <#struct_name as ::limbo_ext::VTabModule>::init_sql(); - api.declare_virtual_table(<#struct_name as ::limbo_ext::VTabModule>::NAME, sql) + unsafe extern "C" fn #create_schema_fn_name( + argv: *const *mut ::std::ffi::c_char, argc: i32 + ) -> *mut ::std::ffi::c_char { + let args = if argv.is_null() { + Vec::new() + } else { + ::std::slice::from_raw_parts(argv, argc as usize).iter().map(|s| { + ::std::ffi::CStr::from_ptr(*s).to_string_lossy().to_string() + }).collect::>() + }; + let sql = <#struct_name as ::limbo_ext::VTabModule>::create_schema(&args); + ::std::ffi::CString::new(sql).unwrap().into_raw() } #[no_mangle] - unsafe extern "C" fn #open_fn_name( - ) -> *mut ::std::ffi::c_void { - if let Ok(cursor) = <#struct_name as ::limbo_ext::VTabModule>::open() { - ::std::boxed::Box::into_raw(::std::boxed::Box::new(cursor)) as *mut ::std::ffi::c_void + unsafe extern "C" fn #open_fn_name(argv: *const *mut ::std::ffi::c_char, argc: i32) -> *mut ::std::ffi::c_void { + let args = if argv.is_null() { + Vec::new() } else { - ::std::ptr::null_mut() + ::std::slice::from_raw_parts(argv, argc as usize).iter().map(|s| { + ::std::ffi::CStr::from_ptr(*s).to_string_lossy().to_string() + }).collect::>() + }; + let schema = <#struct_name as ::limbo_ext::VTabModule>::create_schema(&args); + if let Ok(cursor) = <#struct_name as ::limbo_ext::VTabModule>::open(&args) { + return ::std::boxed::Box::into_raw(::std::boxed::Box::new(cursor)) as *mut ::std::ffi::c_void; + } else { + return ::std::ptr::null_mut(); } } @@ -497,6 +517,37 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { <#struct_name as ::limbo_ext::VTabModule>::eof(cursor) } + #[no_mangle] + unsafe extern "C" fn #update_fn_name( + vtab: *mut ::std::ffi::c_void, + argc: i32, + argv: *const ::limbo_ext::Value, + rowid: i64, + p_out_rowid: *mut i64, + ) -> ::limbo_ext::ResultCode { + if vtab.is_null() { + return ::limbo_ext::ResultCode::Error; + } + let vtab = &mut *(vtab as *mut #struct_name); + let args = ::std::slice::from_raw_parts(argv, argc as usize); + let rowid = if rowid == -1 { + None + } else { + Some(rowid as i64) + }; + let result = <#struct_name as ::limbo_ext::VTabModule>::update(vtab, args, rowid); + match result { + Ok(Some(rowid)) => { + // set the output rowid if it was provided + *p_out_rowid = rowid; + ::limbo_ext::ResultCode::RowID + } + Ok(None) => ::limbo_ext::ResultCode::OK, + Err(_) => ::limbo_ext::ResultCode::Error, + } + } + + #[no_mangle] pub unsafe extern "C" fn #register_fn_name( api: *const ::limbo_ext::ExtensionApi @@ -506,20 +557,20 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { } let api = &*api; let name = <#struct_name as ::limbo_ext::VTabModule>::NAME; - // name needs to be a c str FFI compatible, NOT CString let name_c = ::std::ffi::CString::new(name).unwrap().into_raw() as *const ::std::ffi::c_char; - + let table_instance = ::std::boxed::Box::into_raw(::std::boxed::Box::new(#struct_name::default())); let module = ::limbo_ext::VTabModuleImpl { + ctx: table_instance as *mut ::std::ffi::c_void, name: name_c, - connect: Self::#connect_fn_name, + create_schema: Self::#create_schema_fn_name, open: Self::#open_fn_name, filter: Self::#filter_fn_name, column: Self::#column_fn_name, next: Self::#next_fn_name, eof: Self::#eof_fn_name, + update: Self::#update_fn_name, }; - - (api.register_module)(api.ctx, name_c, module) + (api.register_module)(api.ctx, name_c, module, <#struct_name as ::limbo_ext::VTabModule>::VTAB_KIND) } } }; @@ -594,16 +645,11 @@ pub fn register_extension(input: TokenStream) -> TokenStream { }); let vtab_calls = vtabs.iter().map(|vtab_ident| { let register_fn = syn::Ident::new(&format!("register_{}", vtab_ident), vtab_ident.span()); - let connect_fn = syn::Ident::new(&format!("connect_{}", vtab_ident), vtab_ident.span()); quote! { { let result = unsafe{ #vtab_ident::#register_fn(api)}; - if result == ::limbo_ext::ResultCode::OK { - let api = api as *const _ as *const ::std::ffi::c_void; - let result = #vtab_ident::#connect_fn(api); - if !result.is_ok() { - return result; - } + if !result.is_ok() { + return result; } } } diff --git a/vendored/sqlite3-parser/src/parser/ast/mod.rs b/vendored/sqlite3-parser/src/parser/ast/mod.rs index ad359cc0b..1aac9c2c4 100644 --- a/vendored/sqlite3-parser/src/parser/ast/mod.rs +++ b/vendored/sqlite3-parser/src/parser/ast/mod.rs @@ -1724,6 +1724,18 @@ pub enum ResolveType { /// `REPLACE` Replace, } +impl ResolveType { + /// Get the OE_XXX bit value + pub fn bit_value(&self) -> usize { + match self { + ResolveType::Rollback => 1, + ResolveType::Abort => 2, + ResolveType::Fail => 3, + ResolveType::Ignore => 4, + ResolveType::Replace => 5, + } + } +} /// `WITH` clause // https://sqlite.org/lang_with.html