diff --git a/core/lib.rs b/core/lib.rs index ccc2c2273..f381a9fce 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -363,7 +363,7 @@ impl Connection { pub fn execute(self: &Rc, sql: impl AsRef) -> Result<()> { let sql = sql.as_ref(); let db = &self.db; - let syms: &SymbolTable = &db.syms.borrow(); + let syms: &SymbolTable = &db.syms.borrow_mut(); let mut parser = Parser::new(sql.as_bytes()); let cmd = parser.next()?; if let Some(cmd) = cmd { @@ -417,7 +417,7 @@ impl Connection { #[cfg(not(target_family = "wasm"))] pub fn load_extension>(&self, path: P) -> Result<()> { - Database::load_extension(self.db.as_ref(), path) + Database::load_extension(&self.db, path) } /// Close a connection and checkpoint. diff --git a/core/schema.rs b/core/schema.rs index e7688b58e..f4a6aee2b 100644 --- a/core/schema.rs +++ b/core/schema.rs @@ -68,20 +68,11 @@ impl Table { } } - pub fn get_column_at(&self, index: usize) -> &Column { + pub fn get_column_at(&self, index: usize) -> Option<&Column> { match self { - Self::BTree(table) => table - .columns - .get(index) - .expect("column index out of bounds"), - Self::Pseudo(table) => table - .columns - .get(index) - .expect("column index out of bounds"), - Self::Virtual(table) => table - .columns - .get(index) - .expect("column index out of bounds"), + Self::BTree(table) => table.columns.get(index), + Self::Pseudo(table) => table.columns.get(index), + Self::Virtual(table) => table.columns.get(index), } } @@ -100,6 +91,7 @@ impl Table { Self::Virtual(_) => None, } } + pub fn virtual_table(&self) -> Option> { match self { Self::Virtual(table) => Some(table.clone()), @@ -172,7 +164,7 @@ impl BTreeTable { sql.push_str(",\n"); } sql.push_str(" "); - sql.push_str(&column.name.as_ref().expect("column name is None")); + sql.push_str(column.name.as_ref().expect("column name is None")); sql.push(' '); sql.push_str(&column.ty.to_string()); } diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 8b2e70185..bef18c9f2 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -1839,7 +1839,9 @@ pub fn translate_expr( dest: target_register, }); } - let column = table_reference.table.get_column_at(*column); + let Some(column) = table_reference.table.get_column_at(*column) else { + crate::bail_parse_error!("column index out of bounds"); + }; maybe_apply_affinity(column.ty, target_register, program); Ok(target_register) } diff --git a/core/translate/optimizer.rs b/core/translate/optimizer.rs index 73124060f..99de57398 100644 --- a/core/translate/optimizer.rs +++ b/core/translate/optimizer.rs @@ -307,7 +307,9 @@ impl Optimizable for ast::Expr { else { return Ok(None); }; - let column = table_reference.table.get_column_at(*column); + let Some(column) = table_reference.table.get_column_at(*column) else { + return Ok(None); + }; for index in available_indexes_for_table.iter() { if let Some(name) = column.name.as_ref() { if &index.columns.first().unwrap().name == name { diff --git a/core/translate/planner.rs b/core/translate/planner.rs index 272b788a2..dcde7dc62 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -1,5 +1,3 @@ -use std::rc::Rc; - use super::{ plan::{ Aggregate, JoinInfo, Operation, Plan, ResultSetColumn, SelectQueryType, TableReference, diff --git a/core/translate/select.rs b/core/translate/select.rs index 2940cfca6..2a055afd2 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -28,9 +28,9 @@ pub fn translate_select( let mut program = ProgramBuilder::new(ProgramBuilderOpts { query_mode, - num_cursors: count_plan_required_cursors(&select), - approx_num_insns: estimate_num_instructions(&select), - approx_num_labels: estimate_num_labels(&select), + num_cursors: count_plan_required_cursors(select), + approx_num_insns: estimate_num_instructions(select), + approx_num_labels: estimate_num_labels(select), }); emit_program(&mut program, select_plan, syms)?; Ok(program) diff --git a/extensions/core/src/lib.rs b/extensions/core/src/lib.rs index 0e550fca1..30ddece57 100644 --- a/extensions/core/src/lib.rs +++ b/extensions/core/src/lib.rs @@ -110,10 +110,15 @@ pub trait VTabModule: 'static { } pub trait VTabCursor: Sized { + type Error; fn rowid(&self) -> i64; fn column(&self, idx: u32) -> Value; fn eof(&self) -> bool; fn next(&mut self) -> ResultCode; + fn set_error(&mut self, error: Self::Error); + fn error(&self) -> Option { + None + } } #[repr(C)] diff --git a/extensions/series/src/lib.rs b/extensions/series/src/lib.rs index 63b6c6227..ef278d451 100644 --- a/extensions/series/src/lib.rs +++ b/extensions/series/src/lib.rs @@ -40,12 +40,14 @@ impl VTabModule for GenerateSeriesVTab { stop: 0, step: 0, current: 0, + error: None, } } fn filter(cursor: &mut Self::VCursor, arg_count: i32, args: &[Value]) -> ResultCode { // args are the start, stop, and step if arg_count == 0 || arg_count > 3 { + cursor.set_error("Expected between 1 and 3 arguments"); return ResultCode::InvalidArgs; } let start = try_option!(args[0].to_integer(), ResultCode::InvalidArgs); @@ -84,6 +86,7 @@ struct GenerateSeriesCursor { stop: i64, step: i64, current: i64, + error: Option<&'static str>, } impl GenerateSeriesCursor { @@ -101,6 +104,8 @@ impl GenerateSeriesCursor { } impl VTabCursor for GenerateSeriesCursor { + type Error = &'static str; + fn next(&mut self) -> ResultCode { GenerateSeriesCursor::next(self) } @@ -119,6 +124,14 @@ impl VTabCursor for GenerateSeriesCursor { } } + fn error(&self) -> Option { + self.error + } + + fn set_error(&mut self, err: &'static str) { + self.error = Some(err); + } + fn rowid(&self) -> i64 { ((self.current - self.start) / self.step) + 1 } diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 8dee8dc66..56d019525 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -324,6 +324,89 @@ pub fn derive_agg_func(input: TokenStream) -> TokenStream { TokenStream::from(expanded) } +/// Macro to derive a VTabModule for your extension. This macro will generate +/// the necessary functions to register your module with core. You must implement +/// the VTabModule trait for your struct, and the VTabCursor trait for your cursor. +/// ```ignore +///#[derive(Debug, VTabModuleDerive)] +///struct CsvVTab; +///impl VTabModule for CsvVTab { +/// type VCursor = CsvCursor; +/// const NAME: &'static str = "csv_data"; +/// +/// /// Declare the schema for your virtual table +/// fn connect(api: &ExtensionApi) -> ResultCode { +/// let sql = "CREATE TABLE csv_data( +/// name TEXT, +/// age TEXT, +/// city TEXT +/// )"; +/// api.declare_virtual_table(Self::NAME, sql) +/// } +/// /// Open the virtual table and return a cursor +/// fn open() -> Self::VCursor { +/// let csv_content = fs::read_to_string("data.csv").unwrap_or_default(); +/// let rows: Vec> = csv_content +/// .lines() +/// .skip(1) +/// .map(|line| { +/// line.split(',') +/// .map(|s| s.trim().to_string()) +/// .collect() +/// }) +/// .collect(); +/// CsvCursor { rows, index: 0 } +/// } +/// /// Filter the virtual table based on arguments (omitted here for simplicity) +/// fn filter(_cursor: &mut Self::VCursor, _arg_count: i32, _args: &[Value]) -> ResultCode { +/// ResultCode::OK +/// } +/// /// Return the value for a given column index +/// fn column(cursor: &Self::VCursor, idx: u32) -> Value { +/// cursor.column(idx) +/// } +/// /// Move the cursor to the next row +/// fn next(cursor: &mut Self::VCursor) -> ResultCode { +/// if cursor.index < cursor.rows.len() - 1 { +/// cursor.index += 1; +/// ResultCode::OK +/// } else { +/// ResultCode::EOF +/// } +/// } +/// fn eof(cursor: &Self::VCursor) -> bool { +/// cursor.index >= cursor.rows.len() +/// } +/// #[derive(Debug)] +/// struct CsvCursor { +/// rows: Vec>, +/// index: usize, +/// +/// impl CsvCursor { +/// /// Returns the value for a given column index. +/// fn column(&self, idx: u32) -> Value { +/// let row = &self.rows[self.index]; +/// if (idx as usize) < row.len() { +/// Value::from_text(&row[idx as usize]) +/// } else { +/// Value::null() +/// } +/// } +/// // Implement the VTabCursor trait for your virtual cursor +/// impl VTabCursor for CsvCursor { +/// fn next(&mut self) -> ResultCode { +/// Self::next(self) +/// } +/// fn eof(&self) -> bool { +/// self.index >= self.rows.len() +/// } +/// fn column(&self, idx: u32) -> Value { +/// self.column(idx) +/// } +/// fn rowid(&self) -> i64 { +/// self.index as i64 +/// } + #[proc_macro_derive(VTabModuleDerive)] pub fn derive_vtab_module(input: TokenStream) -> TokenStream { let ast = parse_macro_input!(input as DeriveInput);