diff --git a/Cargo.lock b/Cargo.lock index 7a027407a..356dcfb8a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1813,6 +1813,7 @@ dependencies = [ "libm", "limbo_completion", "limbo_crypto", + "limbo_csv", "limbo_ext", "limbo_ext_tests", "limbo_ipaddr", @@ -1862,6 +1863,16 @@ dependencies = [ "urlencoding", ] +[[package]] +name = "limbo_csv" +version = "0.0.20" +dependencies = [ + "csv", + "limbo_ext", + "mimalloc", + "tempfile", +] + [[package]] name = "limbo_ext" version = "0.0.20" diff --git a/Cargo.toml b/Cargo.toml index 31794d093..317330655 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,8 @@ members = [ "core", "extensions/completion", "extensions/core", - "extensions/crypto", + "extensions/crypto", + "extensions/csv", "extensions/percentile", "extensions/regexp", "extensions/series", @@ -40,6 +41,7 @@ repository = "https://github.com/tursodatabase/limbo" limbo_completion = { path = "extensions/completion", version = "0.0.20" } limbo_core = { path = "core", version = "0.0.20" } limbo_crypto = { path = "extensions/crypto", version = "0.0.20" } +limbo_csv = { path = "extensions/csv", version = "0.0.20" } limbo_ext = { path = "extensions/core", version = "0.0.20" } limbo_ext_tests = { path = "extensions/tests", version = "0.0.20" } limbo_ipaddr = { path = "extensions/ipaddr", version = "0.0.20" } diff --git a/core/Cargo.toml b/core/Cargo.toml index 27807f3dd..d32fc4745 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -29,6 +29,7 @@ completion = ["limbo_completion/static"] testvfs = ["limbo_ext_tests/static"] static = ["limbo_ext/static"] fuzz = [] +csv = ["limbo_csv/static"] [target.'cfg(target_os = "linux")'.dependencies] io-uring = { version = "0.7.5", optional = true } @@ -68,6 +69,7 @@ limbo_series = { workspace = true, optional = true, features = ["static"] } limbo_ipaddr = { workspace = true, optional = true, features = ["static"] } limbo_completion = { workspace = true, optional = true, features = ["static"] } limbo_ext_tests = { workspace = true, optional = true, features = ["static"] } +limbo_csv = { workspace = true, optional = true, features = ["static"] } miette = "7.6.0" strum = { workspace = true } parking_lot = "0.12.3" diff --git a/core/ext/mod.rs b/core/ext/mod.rs index 939fe3e05..b6fe6fc44 100644 --- a/core/ext/mod.rs +++ b/core/ext/mod.rs @@ -202,6 +202,10 @@ impl Connection { if unsafe { !limbo_completion::register_extension_static(&mut ext_api).is_ok() } { return Err("Failed to register completion extension".to_string()); } + #[cfg(feature = "csv")] + if unsafe { !limbo_csv::register_extension_static(&mut ext_api).is_ok() } { + return Err("Failed to register csv extension".to_string()); + } #[cfg(feature = "fs")] { let vfslist = add_builtin_vfs_extensions(Some(ext_api)).map_err(|e| e.to_string())?; diff --git a/core/lib.rs b/core/lib.rs index 4e4164a35..1db7001a6 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -45,6 +45,7 @@ use limbo_ext::{ConstraintInfo, IndexInfo, OrderByInfo, ResultCode, VTabKind, VT use limbo_sqlite3_parser::{ast, ast::Cmd, lexer::sql::Parser}; use parking_lot::RwLock; use schema::{Column, Schema}; +use std::ffi::c_void; use std::{ borrow::Cow, cell::{Cell, RefCell, UnsafeCell}, @@ -770,6 +771,7 @@ pub struct VirtualTable { pub implementation: Rc, columns: Vec, kind: VTabKind, + table_ptr: *const c_void, } impl VirtualTable { @@ -815,7 +817,7 @@ impl VirtualTable { ))); } }; - let schema = module.implementation.as_ref().init_schema(args)?; + let (schema, table_ptr) = module.implementation.as_ref().create(args)?; let mut parser = Parser::new(schema.as_bytes()); if let ast::Cmd::Stmt(ast::Stmt::CreateTable { body, .. }) = parser.next()?.ok_or( LimboError::ParseError("Failed to parse schema from virtual table module".to_string()), @@ -827,6 +829,7 @@ impl VirtualTable { columns, args: exprs, kind, + table_ptr, }); return Ok(vtab); } @@ -836,7 +839,7 @@ impl VirtualTable { } pub fn open(&self) -> crate::Result { - let cursor = unsafe { (self.implementation.open)(self.implementation.ctx) }; + let cursor = unsafe { (self.implementation.open)(self.table_ptr) }; VTabOpaqueCursor::new(cursor, self.implementation.close) } @@ -893,10 +896,9 @@ impl VirtualTable { let arg_count = args.len(); let ext_args = args.iter().map(|arg| arg.to_ffi()).collect::>(); let newrowid = 0i64; - let implementation = self.implementation.as_ref(); let rc = unsafe { (self.implementation.update)( - implementation as *const VTabModuleImpl as *const std::ffi::c_void, + self.table_ptr, arg_count as i32, ext_args.as_ptr(), &newrowid as *const _ as *mut i64, @@ -915,12 +917,7 @@ impl VirtualTable { } pub fn destroy(&self) -> Result<()> { - let implementation = self.implementation.as_ref(); - let rc = unsafe { - (self.implementation.destroy)( - implementation as *const VTabModuleImpl as *const std::ffi::c_void, - ) - }; + let rc = unsafe { (self.implementation.destroy)(self.table_ptr) }; match rc { ResultCode::OK => Ok(()), _ => Err(LimboError::ExtensionError(rc.to_string())), diff --git a/core/translate/emitter.rs b/core/translate/emitter.rs index 28d992a54..15d0e8035 100644 --- a/core/translate/emitter.rs +++ b/core/translate/emitter.rs @@ -461,7 +461,7 @@ fn emit_delete_insns( dest: key_reg, }); - if let Some(vtab) = table_reference.virtual_table() { + if let Some(_) = table_reference.virtual_table() { let conflict_action = 0u16; let start_reg = key_reg; @@ -474,7 +474,6 @@ fn emit_delete_insns( cursor_id, arg_count: 2, start_reg, - vtab_ptr: vtab.implementation.as_ref().ctx as usize, conflict_action, }); } else { @@ -1039,13 +1038,12 @@ fn emit_update_insns( flag: 0, table_name: table_ref.identifier.clone(), }); - } else if let Some(vtab) = table_ref.virtual_table() { + } else if let Some(_) = table_ref.virtual_table() { let arg_count = table_ref.columns().len() + 2; program.emit_insn(Insn::VUpdate { cursor_id, arg_count, start_reg: beg, - vtab_ptr: vtab.implementation.as_ref().ctx as usize, conflict_action: 0u16, }); } diff --git a/core/translate/insert.rs b/core/translate/insert.rs index 1c3169b4c..195dfa582 100644 --- a/core/translate/insert.rs +++ b/core/translate/insert.rs @@ -714,7 +714,6 @@ fn translate_virtual_table_insert( cursor_id, arg_count: column_mappings.len() + 2, start_reg: registers_start, - vtab_ptr: virtual_table.implementation.as_ref().ctx as usize, conflict_action, }); diff --git a/core/translate/schema.rs b/core/translate/schema.rs index 2c88a21e7..7e2c865a8 100644 --- a/core/translate/schema.rs +++ b/core/translate/schema.rs @@ -505,7 +505,7 @@ fn create_vtable_body_to_str(vtab: &CreateVirtualTable, module: Rc) -> .collect::>(); let schema = module .implementation - .init_schema(ext_args) + .create_schema(ext_args) .unwrap_or_default(); let vtab_args = if let Some(first_paren) = schema.find('(') { let closing_paren = schema.rfind(')').unwrap_or_default(); diff --git a/core/util.rs b/core/util.rs index 40a2c0902..3415aefdc 100644 --- a/core/util.rs +++ b/core/util.rs @@ -502,7 +502,7 @@ pub fn columns_from_create_table_body(body: &ast::CreateTableBody) -> crate::Res } let column = Column { - name: Some(name.0.clone()), + name: Some(normalize_ident(&name.0)), ty: match column_def.col_type { Some(ref data_type) => { // https://www.sqlite.org/datatype3.html diff --git a/core/vdbe/explain.rs b/core/vdbe/explain.rs index 4de07b2e5..888cf32bf 100644 --- a/core/vdbe/explain.rs +++ b/core/vdbe/explain.rs @@ -416,14 +416,13 @@ pub fn insn_to_str( cursor_id, arg_count, // P2: Number of arguments in argv[] start_reg, // P3: Start register for argv[] - vtab_ptr, // P4: vtab pointer - conflict_action, // P5: Conflict resolution flags + conflict_action, // P4: Conflict resolution flags } => ( "VUpdate", *cursor_id as i32, *arg_count as i32, *start_reg as i32, - Value::build_text(&format!("vtab:{}", vtab_ptr)), + Value::build_text(""), *conflict_action, format!("args=r[{}..{}]", start_reg, start_reg + arg_count - 1), ), diff --git a/core/vdbe/insn.rs b/core/vdbe/insn.rs index 7063ac050..f5ff8352d 100644 --- a/core/vdbe/insn.rs +++ b/core/vdbe/insn.rs @@ -323,8 +323,7 @@ pub enum Insn { cursor_id: usize, // P1: Virtual table cursor number arg_count: usize, // P2: Number of arguments in argv[] start_reg: usize, // P3: Start register for argv[] - vtab_ptr: usize, // P4: vtab pointer - conflict_action: u16, // P5: Conflict resolution flags + conflict_action: u16, // P4: Conflict resolution flags }, /// Advance the virtual table cursor to the next row. diff --git a/extensions/completion/src/lib.rs b/extensions/completion/src/lib.rs index 53358c23c..bfd03756c 100644 --- a/extensions/completion/src/lib.rs +++ b/extensions/completion/src/lib.rs @@ -4,10 +4,12 @@ mod keywords; use keywords::KEYWORDS; -use limbo_ext::{register_extension, ResultCode, VTabCursor, VTabModule, VTabModuleDerive, Value}; +use limbo_ext::{ + register_extension, ResultCode, VTabCursor, VTabModule, VTabModuleDerive, VTable, Value, +}; register_extension! { - vtabs: { CompletionVTab } + vtabs: { CompletionVTabModule } } macro_rules! try_option { @@ -57,73 +59,34 @@ impl Into for CompletionPhase { /// A virtual table that generates candidate completions #[derive(Debug, Default, VTabModuleDerive)] -struct CompletionVTab {} +struct CompletionVTabModule {} -impl VTabModule for CompletionVTab { - type VCursor = CompletionCursor; +impl VTabModule for CompletionVTabModule { + type Table = CompletionTable; const NAME: &'static str = "completion"; const VTAB_KIND: limbo_ext::VTabKind = limbo_ext::VTabKind::TableValuedFunction; - type Error = ResultCode; - fn create_schema(_args: &[Value]) -> String { - "CREATE TABLE completion( + fn create(_args: &[Value]) -> Result<(String, Self::Table), ResultCode> { + let schema = "CREATE TABLE completion( candidate TEXT, prefix TEXT HIDDEN, wholeline TEXT HIDDEN, phase INT HIDDEN )" - .to_string() + .to_string(); + Ok((schema, CompletionTable {})) } +} - fn open(&self) -> Result { +struct CompletionTable {} + +impl VTable for CompletionTable { + type Cursor = CompletionCursor; + type Error = ResultCode; + + fn open(&self) -> Result { Ok(CompletionCursor::default()) } - - fn column(cursor: &Self::VCursor, idx: u32) -> Result { - cursor.column(idx) - } - - fn next(cursor: &mut Self::VCursor) -> ResultCode { - cursor.next() - } - - fn eof(cursor: &Self::VCursor) -> bool { - cursor.eof() - } - - fn filter(cursor: &mut Self::VCursor, args: &[Value], _: Option<(&str, i32)>) -> ResultCode { - if args.is_empty() || args.len() > 2 { - return ResultCode::InvalidArgs; - } - cursor.reset(); - let prefix = try_option!(args[0].to_text(), ResultCode::InvalidArgs); - - let wholeline = args.get(1).map(|v| v.to_text().unwrap_or("")).unwrap_or(""); - - cursor.line = wholeline.to_string(); - cursor.prefix = prefix.to_string(); - - // Currently best index is not implemented so the correct arg parsing is not done here - if !cursor.line.is_empty() && cursor.prefix.is_empty() { - let mut i = cursor.line.len(); - while let Some(ch) = cursor.line.chars().next() { - if i > 0 && (ch.is_alphanumeric() || ch == '_') { - i -= 1; - } else { - break; - } - } - if cursor.line.len() - i > 0 { - // TODO see if need to inclusive range - cursor.prefix = cursor.line[..i].to_string(); - } - } - - cursor.rowid = 0; - cursor.phase = CompletionPhase::Keywords; - - Self::next(cursor) - } } /// The cursor for iterating over the completions @@ -150,6 +113,40 @@ impl CompletionCursor { impl VTabCursor for CompletionCursor { type Error = ResultCode; + fn filter(&mut self, args: &[Value], _: Option<(&str, i32)>) -> ResultCode { + if args.is_empty() || args.len() > 2 { + return ResultCode::InvalidArgs; + } + self.reset(); + let prefix = try_option!(args[0].to_text(), ResultCode::InvalidArgs); + + let wholeline = args.get(1).map(|v| v.to_text().unwrap_or("")).unwrap_or(""); + + self.line = wholeline.to_string(); + self.prefix = prefix.to_string(); + + // Currently best index is not implemented so the correct arg parsing is not done here + if !self.line.is_empty() && self.prefix.is_empty() { + let mut i = self.line.len(); + while let Some(ch) = self.line.chars().next() { + if i > 0 && (ch.is_alphanumeric() || ch == '_') { + i -= 1; + } else { + break; + } + } + if self.line.len() - i > 0 { + // TODO see if need to inclusive range + self.prefix = self.line[..i].to_string(); + } + } + + self.rowid = 0; + self.phase = CompletionPhase::Keywords; + + self.next() + } + fn next(&mut self) -> ResultCode { self.rowid += 1; diff --git a/extensions/core/README.md b/extensions/core/README.md index 0e4246f37..0493eead0 100644 --- a/extensions/core/README.md +++ b/extensions/core/README.md @@ -172,29 +172,35 @@ impl AggFunc for Percentile { /// Example: A virtual table that operates on a CSV file as a database table. /// This example assumes that the CSV file is located at "data.csv" in the current directory. #[derive(Debug, VTabModuleDerive)] -struct CsvVTable; +struct CsvVTableModule; -impl VTabModule for CsvVTable { - type VCursor = CsvCursor; - /// Define your error type. Must impl Display and match VCursor::Error - type Error = &'static str; +impl VTabModule for CsvVTableModule { + type Table = CsvTable; /// Declare the name for your virtual table const NAME: &'static str = "csv_data"; - /// Declare the type of vtable (TableValuedFunction or VirtualTable) const VTAB_KIND: VTabKind = VTabKind::VirtualTable; - /// Function to initialize the schema of your vtable - fn create_schema(_args: &[Value]) -> &'static str { - "CREATE TABLE csv_data( + /// Declare your virtual table and its schema + fn create(args: &[Value]) -> Result<(String, Self::Table), ResultCode> { + let schema = "CREATE TABLE csv_data( name TEXT, age TEXT, city TEXT - )" + )".into(); + Ok((schema, CsvTable {})) } +} + +struct CsvTable {} + +impl VTable for CsvTable { + type Cursor = CsvCursor; + /// Define your error type. Must impl Display and match Cursor::Error + type Error = &'static str; /// Open to return a new cursor: In this simple example, the CSV file is read completely into memory on connect. - fn open(&self) -> Result { + fn open(&self) -> Result { // Read CSV file contents from "data.csv" let csv_content = fs::read_to_string("data.csv").unwrap_or_default(); // For simplicity, we'll ignore the header row. @@ -210,31 +216,6 @@ impl VTabModule for CsvVTable { Ok(CsvCursor { rows, index: 0 }) } - /// Filter through result columns. (not used in this simple example) - fn filter(_cursor: &mut Self::VCursor, _args: &[Value]) -> ResultCode { - ResultCode::OK - } - - /// Return the value for the column at the given index in the current row. - fn column(cursor: &Self::VCursor, idx: u32) -> Result { - cursor.column(idx) - } - - /// Next advances the cursor to the next row. - fn next(cursor: &mut Self::VCursor) -> ResultCode { - if cursor.index < cursor.rows.len() - 1 { - cursor.index += 1; - ResultCode::OK - } else { - ResultCode::EOF - } - } - - /// Return true if the cursor is at the end. - fn eof(cursor: &Self::VCursor) -> bool { - cursor.index >= cursor.rows.len() - } - /// *Optional* methods for non-readonly tables /// Update the value at rowid @@ -263,14 +244,27 @@ struct CsvCursor { impl VTabCursor for CsvCursor { type Error = &'static str; - fn next(&mut self) -> ResultCode { - CsvCursor::next(self) + /// Filter through result columns. (not used in this simple example) + fn filter(&mut self, _args: &[Value], _idx_info: Option<(&str, i32)>) -> ResultCode { + ResultCode::OK } + /// Next advances the cursor to the next row. + fn next(&mut self) -> ResultCode { + if self.index < self.rows.len() - 1 { + self.index += 1; + ResultCode::OK + } else { + ResultCode::EOF + } + } + + /// Return true if the cursor is at the end. fn eof(&self) -> bool { self.index >= self.rows.len() } + /// Return the value for the column at the given index in the current row. fn column(&self, idx: u32) -> Result { let row = &self.rows[self.index]; if (idx as usize) < row.len() { diff --git a/extensions/core/src/lib.rs b/extensions/core/src/lib.rs index 99729de6c..75ea1c77d 100644 --- a/extensions/core/src/lib.rs +++ b/extensions/core/src/lib.rs @@ -17,7 +17,7 @@ pub use vfs_modules::{RegisterVfsFn, VfsExtension, VfsFile, VfsFileImpl, VfsImpl use vtabs::RegisterModuleFn; pub use vtabs::{ ConstraintInfo, ConstraintOp, ConstraintUsage, ExtIndexInfo, IndexInfo, OrderByInfo, - VTabCursor, VTabKind, VTabModule, VTabModuleImpl, + VTabCreateResult, VTabCursor, VTabKind, VTabModule, VTabModuleImpl, VTable, }; pub type ExtResult = std::result::Result; diff --git a/extensions/core/src/vtabs.rs b/extensions/core/src/vtabs.rs index de5794c6e..316e0a5c0 100644 --- a/extensions/core/src/vtabs.rs +++ b/extensions/core/src/vtabs.rs @@ -11,9 +11,8 @@ pub type RegisterModuleFn = unsafe extern "C" fn( #[repr(C)] #[derive(Clone, Debug)] pub struct VTabModuleImpl { - pub ctx: *const c_void, pub name: *const c_char, - pub create_schema: VtabFnCreateSchema, + pub create: VtabFnCreate, pub open: VtabFnOpen, pub close: VtabFnClose, pub filter: VtabFnFilter, @@ -26,24 +25,50 @@ pub struct VTabModuleImpl { pub best_idx: BestIdxFn, } +#[repr(C)] +pub struct VTabCreateResult { + pub code: ResultCode, + pub schema: *const c_char, + pub table: *const c_void, +} + #[cfg(feature = "core_only")] impl VTabModuleImpl { - pub fn init_schema(&self, args: Vec) -> crate::ExtResult { - let schema = unsafe { (self.create_schema)(args.as_ptr(), args.len() as i32) }; - if schema.is_null() { - return Err(ResultCode::InvalidArgs); - } + pub fn create(&self, args: Vec) -> crate::ExtResult<(String, *const c_void)> { + let result = unsafe { (self.create)(args.as_ptr(), args.len() as i32) }; for arg in args { unsafe { arg.__free_internal_type() }; } - let schema = unsafe { std::ffi::CString::from_raw(schema) }; - Ok(schema.to_string_lossy().to_string()) + if !result.code.is_ok() { + return Err(result.code); + } + let schema = unsafe { std::ffi::CString::from_raw(result.schema as *mut _) }; + Ok((schema.to_string_lossy().to_string(), result.table)) + } + + // TODO: This function is temporary and should eventually be removed. + // The only difference from `create` is that it takes ownership of the table instance. + // Currently, it is used to generate virtual table column names that are stored in + // `sqlite_schema` alongside the table's schema. + // However, storing column names is not necessary to match SQLite's behavior. + // SQLite computes the list of columns dynamically each time the `.schema` command + // is executed, using the `shell_add_schema` UDF function. + pub fn create_schema(&self, args: Vec) -> crate::ExtResult { + self.create(args).and_then(|(schema, table)| { + // Drop the allocated table instance to avoid a memory leak. + let result = unsafe { (self.destroy)(table) }; + if result.is_ok() { + Ok(schema) + } else { + Err(result) + } + }) } } -pub type VtabFnCreateSchema = unsafe extern "C" fn(args: *const Value, argc: i32) -> *mut c_char; +pub type VtabFnCreate = unsafe extern "C" fn(args: *const Value, argc: i32) -> VTabCreateResult; -pub type VtabFnOpen = unsafe extern "C" fn(*const c_void) -> *const c_void; +pub type VtabFnOpen = unsafe extern "C" fn(table: *const c_void) -> *const c_void; pub type VtabFnClose = unsafe extern "C" fn(cursor: *const c_void) -> ResultCode; @@ -64,13 +89,14 @@ pub type VtabFnEof = unsafe extern "C" fn(cursor: *const c_void) -> bool; pub type VtabRowIDFn = unsafe extern "C" fn(cursor: *const c_void) -> i64; pub type VtabFnUpdate = unsafe extern "C" fn( - vtab: *const c_void, + table: *const c_void, argc: i32, argv: *const Value, p_out_rowid: *mut i64, ) -> ResultCode; -pub type VtabFnDestroy = unsafe extern "C" fn(vtab: *const c_void) -> ResultCode; +pub type VtabFnDestroy = unsafe extern "C" fn(table: *const c_void) -> ResultCode; + pub type BestIdxFn = unsafe extern "C" fn( constraints: *const ConstraintInfo, constraint_len: i32, @@ -86,21 +112,20 @@ pub enum VTabKind { } pub trait VTabModule: 'static { - type VCursor: VTabCursor; + type Table: VTable; const VTAB_KIND: VTabKind; const NAME: &'static str; + + /// Creates a new instance of a virtual table. + /// Returns a tuple where the first element is the table's schema. + fn create(args: &[Value]) -> Result<(String, Self::Table), ResultCode>; +} + +pub trait VTable { + type Cursor: VTabCursor; type Error: std::fmt::Display; - fn create_schema(args: &[Value]) -> String; - fn open(&self) -> Result; - fn filter( - cursor: &mut Self::VCursor, - args: &[Value], - idx_info: Option<(&str, i32)>, - ) -> ResultCode; - fn column(cursor: &Self::VCursor, idx: u32) -> Result; - fn next(cursor: &mut Self::VCursor) -> ResultCode; - fn eof(cursor: &Self::VCursor) -> bool; + fn open(&self) -> Result; fn update(&mut self, _rowid: i64, _args: &[Value]) -> Result<(), Self::Error> { Ok(()) } @@ -133,6 +158,7 @@ pub trait VTabModule: 'static { pub trait VTabCursor: Sized { type Error: std::fmt::Display; + fn filter(&mut self, args: &[Value], idx_info: Option<(&str, i32)>) -> ResultCode; fn rowid(&self) -> i64; fn column(&self, idx: u32) -> Result; fn eof(&self) -> bool; diff --git a/extensions/csv/Cargo.toml b/extensions/csv/Cargo.toml new file mode 100644 index 000000000..d39da1eec --- /dev/null +++ b/extensions/csv/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "limbo_csv" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +description = "Limbo CSV extension" + +[lib] +crate-type = ["cdylib", "lib"] + +[features] +static = ["limbo_ext/static"] + +[dependencies] +limbo_ext = { workspace = true, features = ["static"] } +csv = "1.3.1" + +[dev-dependencies] +tempfile = "3.19.1" + +[target.'cfg(not(target_family = "wasm"))'.dependencies] +mimalloc = { version = "0.1", default-features = false } diff --git a/extensions/csv/src/lib.rs b/extensions/csv/src/lib.rs new file mode 100644 index 000000000..2f8b2a7aa --- /dev/null +++ b/extensions/csv/src/lib.rs @@ -0,0 +1,895 @@ +//! Port of SQLite's CSV virtual table extension: +//! +//! This extension allows querying CSV files as if they were database tables, +//! using the virtual table mechanism. +//! +//! It supports specifying the CSV input via a filename or raw data string, optional headers, +//! and customizable schema generation. +//! +//! ## Example usage: +//! +//! ```sql +//! CREATE VIRTUAL TABLE temp.my_csv USING csv(filename='data.csv', header=yes); +//! SELECT * FROM my_csv; +//! ``` +//! +//! ## Parameters: +//! - `filename` — path to the CSV file (mutually exclusive with `data=`) +//! - `data` — inline CSV content as a string +//! - `header` — whether the first row contains column names; +//! accepts `yes`/`no`, `on`/`off`, `true`/`false`, or `1`/`0` +//! - `columns` — number of columns +//! - `schema` — optional custom SQL `CREATE TABLE` schema +use limbo_ext::{ + register_extension, ConstraintInfo, IndexInfo, OrderByInfo, ResultCode, VTabCursor, VTabKind, + VTabModule, VTabModuleDerive, VTable, Value, +}; +use std::fs::File; +use std::io::{Read, Seek, SeekFrom}; + +register_extension! { + vtabs: { CsvVTabModule } +} + +#[derive(Debug, VTabModuleDerive, Default)] +struct CsvVTabModule; + +impl CsvVTabModule { + fn parse_arg(arg: &Value) -> Result<(&str, &str), ResultCode> { + if let Some(text) = arg.to_text() { + let mut split = text.splitn(2, '='); + if let Some(name) = split.next() { + if let Some(value) = split.next() { + let name = name.trim(); + let value = value.trim(); + return Ok((name, value)); + } + } + } + Err(ResultCode::InvalidArgs) + } + + fn parse_string(s: &str) -> Result { + let chars: Vec = s.chars().collect(); + let len = chars.len(); + + if len >= 2 && (chars[0] == '"' || chars[0] == '\'') { + let quote = chars[0]; + + if quote != chars[len - 1] { + return Err(ResultCode::InvalidArgs); + } + + let mut result = String::new(); + let mut i = 1; + + while i < len - 1 { + if chars[i] == quote && i + 1 < len - 1 && chars[i + 1] == quote { + // Escaped quote ("" or '') + result.push(quote); + i += 2; + } else { + result.push(chars[i]); + i += 1; + } + } + + Ok(result) + } else { + Ok(s.to_owned()) + } + } + + fn parse_boolean(s: &str) -> Option { + if s.eq_ignore_ascii_case("yes") + || s.eq_ignore_ascii_case("on") + || s.eq_ignore_ascii_case("true") + || s.eq("1") + { + Some(true) + } else if s.eq_ignore_ascii_case("no") + || s.eq_ignore_ascii_case("off") + || s.eq_ignore_ascii_case("false") + || s.eq("0") + { + Some(false) + } else { + None + } + } + + fn escape_double_quote(identifier: &str) -> String { + identifier.replace('"', "\"\"") + } +} + +impl VTabModule for CsvVTabModule { + type Table = CsvTable; + const VTAB_KIND: VTabKind = VTabKind::VirtualTable; + const NAME: &'static str = "csv"; + + fn create(args: &[Value]) -> Result<(String, Self::Table), ResultCode> { + if args.is_empty() { + return Err(ResultCode::InvalidArgs); + } + + let mut filename = None; + let mut data = None; + let mut schema = None; + let mut column_count = None; + let mut header = None; + + for arg in args { + let (name, value) = Self::parse_arg(arg)?; + match name { + "filename" => { + if filename.is_some() { + return Err(ResultCode::InvalidArgs); + } + filename = Some(Self::parse_string(value)?); + } + "data" => { + if data.is_some() { + return Err(ResultCode::InvalidArgs); + } + data = Some(Self::parse_string(value)?); + } + "schema" => { + if schema.is_some() { + return Err(ResultCode::InvalidArgs); + } + schema = Some(Self::parse_string(value)?); + } + "columns" => { + if column_count.is_some() { + return Err(ResultCode::InvalidArgs); + } + let n: u32 = value.parse().map_err(|_| ResultCode::InvalidArgs)?; + if n == 0 { + return Err(ResultCode::InvalidArgs); + } + column_count = Some(n); + } + "header" => { + if header.is_some() { + return Err(ResultCode::InvalidArgs); + } + header = Some(Self::parse_boolean(value).ok_or(ResultCode::InvalidArgs)?); + } + _ => { + return Err(ResultCode::InvalidArgs); + } + } + } + + if filename.is_some() == data.is_some() { + return Err(ResultCode::InvalidArgs); + } + + let mut columns: Vec = Vec::new(); + + let mut table = CsvTable { + column_count, + filename, + data, + header: header.unwrap_or(false), + first_row_position: csv::Position::new(), + }; + + if table.header || (column_count.is_none() && schema.is_none()) { + let mut reader = table.new_reader()?; + if table.header { + let headers = reader.headers().map_err(|_| ResultCode::Error)?; + if column_count.is_none() && schema.is_none() { + columns = headers + .into_iter() + .map(|header| Self::escape_double_quote(header)) + .collect(); + } + if columns.is_empty() { + columns.push("(NULL)".to_owned()); + } + table.first_row_position = reader.position().clone(); + } else { + let mut record = csv::ByteRecord::new(); + if reader + .read_byte_record(&mut record) + .map_err(|_| ResultCode::Error)? + { + for (i, _) in record.iter().enumerate() { + columns.push(format!("c{i}")); + } + } + if columns.is_empty() { + columns.push("c0".to_owned()); + } + } + } else if let Some(count) = column_count { + for i in 0..count { + columns.push(format!("c{i}")); + } + } + + if schema.is_none() { + let mut sql = String::from("CREATE TABLE x("); + for (i, col) in columns.iter().enumerate() { + sql.push('"'); + sql.push_str(col); + sql.push_str("\" TEXT"); + if i < columns.len() - 1 { + sql.push_str(", "); + } + } + sql.push_str(")"); + schema = Some(sql); + } + + Ok((schema.unwrap(), table)) + } +} + +struct CsvTable { + filename: Option, + data: Option, + header: bool, + column_count: Option, + first_row_position: csv::Position, +} + +impl CsvTable { + fn new_reader(&self) -> Result, ResultCode> { + let mut builder = csv::ReaderBuilder::new(); + builder.has_headers(self.header).delimiter(b',').quote(b'"'); + + match (&self.filename, &self.data) { + (Some(path), None) => { + let file = File::open(path).map_err(|_| ResultCode::Error)?; + Ok(builder.from_reader(ReadSource::File(file))) + } + (None, Some(data)) => { + let cursor = std::io::Cursor::new(data.clone().into_bytes()); + Ok(builder.from_reader(ReadSource::Memory(cursor))) + } + _ => Err(ResultCode::Internal), + } + } +} + +impl VTable for CsvTable { + type Cursor = CsvCursor; + type Error = ResultCode; + + fn open(&self) -> Result { + match self.new_reader() { + Ok(reader) => Ok(CsvCursor::new(reader, self)), + Err(_) => Err(ResultCode::Error), + } + } + + fn update(&mut self, _rowid: i64, _args: &[Value]) -> Result<(), Self::Error> { + Err(ResultCode::ReadOnly) + } + + fn insert(&mut self, _args: &[Value]) -> Result { + Err(ResultCode::ReadOnly) + } + + fn delete(&mut self, _rowid: i64) -> Result<(), Self::Error> { + Err(ResultCode::ReadOnly) + } + + fn best_index(_constraints: &[ConstraintInfo], _order_by: &[OrderByInfo]) -> IndexInfo { + // Only a forward full table scan is supported. + IndexInfo { + idx_num: -1, + idx_str: None, + order_by_consumed: false, + estimated_cost: 1_000_000., + ..Default::default() + } + } +} + +enum ReadSource { + File(File), + Memory(std::io::Cursor>), +} + +impl Read for ReadSource { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + match self { + ReadSource::File(f) => f.read(buf), + ReadSource::Memory(c) => c.read(buf), + } + } +} + +impl Seek for ReadSource { + fn seek(&mut self, pos: SeekFrom) -> std::io::Result { + match self { + ReadSource::File(f) => f.seek(pos), + ReadSource::Memory(c) => c.seek(pos), + } + } +} + +struct CsvCursor { + column_count: Option, + reader: csv::Reader, + row_number: usize, + current_row: csv::StringRecord, + eof: bool, + first_row_position: csv::Position, +} + +impl CsvCursor { + fn new(reader: csv::Reader, table: &CsvTable) -> Self { + CsvCursor { + column_count: table.column_count, + reader, + row_number: 0, + current_row: csv::StringRecord::new(), + eof: false, + first_row_position: table.first_row_position.clone(), + } + } +} + +impl VTabCursor for CsvCursor { + type Error = ResultCode; + + fn filter(&mut self, _args: &[Value], _idx_info: Option<(&str, i32)>) -> ResultCode { + let offset_first_row = self.first_row_position.clone(); + if self.reader.seek(offset_first_row).is_err() { + return ResultCode::Error; + }; + self.row_number = 0; + self.next() + } + + fn rowid(&self) -> i64 { + self.row_number as i64 + } + + fn column(&self, idx: u32) -> Result { + if let Some(count) = self.column_count { + if idx >= count { + return Ok(Value::null()); + } + } + let value = self + .current_row + .get(idx as usize) + .map_or(Value::null(), |s| Value::from_text(s.to_owned())); + Ok(value) + } + + fn eof(&self) -> bool { + self.eof + } + + fn next(&mut self) -> ResultCode { + { + self.eof = self.reader.is_done(); + if self.eof { + return ResultCode::EOF; + } + + match self.reader.read_record(&mut self.current_row) { + Ok(more) => { + self.eof = !more; + if self.eof { + return ResultCode::EOF; + } + } + Err(_) => return ResultCode::Error, + } + } + + self.row_number += 1; + ResultCode::OK + } +} + +#[cfg(test)] +mod tests { + use super::*; + use limbo_ext::{Value, ValueType}; + use std::io::Write; + use tempfile::NamedTempFile; + + fn write_csv(content: &str) -> NamedTempFile { + let mut tmp = NamedTempFile::new().expect("Failed to create temp file"); + write!(tmp, "{}", content).unwrap(); + tmp + } + + fn new_table(args: Vec<&str>) -> CsvTable { + try_new_table(args).unwrap().1 + } + + fn try_new_table(args: Vec<&str>) -> Result<(String, CsvTable), ResultCode> { + let args = &args + .iter() + .map(|s| Value::from_text(s.to_string())) + .collect::>(); + CsvVTabModule::create(args) + } + + fn read_rows(mut cursor: CsvCursor, column_count: u32) -> Vec>> { + let mut results = vec![]; + cursor.filter(&[], None); + + while !cursor.eof() { + let mut row = vec![]; + + for i in 0..column_count { + let cell = match cursor.column(i) { + Ok(v) => match v.value_type() { + ValueType::Null => None, + ValueType::Text => v.to_text().map(|s| s.to_owned()), + _ => panic!("Unexpected column type"), + }, + Err(_) => panic!("Error reading column"), + }; + row.push(cell); + } + + results.push(row); + cursor.next(); + } + + results + } + + macro_rules! cell { + ($x:expr) => { + Some($x.to_owned()) + }; + } + + #[test] + fn test_file_with_header() { + let file = write_csv("id,name\n1,Alice\n2,Bob\n"); + let table = new_table(vec![ + &format!("filename={}", file.path().to_string_lossy()), + "header=true", + ]); + let cursor = table.open().unwrap(); + let rows = read_rows(cursor, 2); + assert_eq!( + rows, + vec![ + vec![cell!("1"), cell!("Alice")], + vec![cell!("2"), cell!("Bob")] + ] + ); + } + + #[test] + fn test_data_with_header() { + let table = new_table(vec!["data=id,name\n1,Alice\n2,Bob\n", "header=true"]); + let cursor = table.open().unwrap(); + let rows = read_rows(cursor, 2); + assert_eq!( + rows, + vec![ + vec![cell!("1"), cell!("Alice")], + vec![cell!("2"), cell!("Bob")] + ] + ); + } + + #[test] + fn test_file_without_header() { + let file = write_csv("1,Alice\n2,Bob\n"); + let table = new_table(vec![ + &format!("filename={}", file.path().to_string_lossy()), + "header=false", + ]); + let cursor = table.open().unwrap(); + let rows = read_rows(cursor, 2); + assert_eq!( + rows, + vec![ + vec![cell!("1"), cell!("Alice")], + vec![cell!("2"), cell!("Bob")] + ] + ); + } + + #[test] + fn test_data_without_header() { + let table = new_table(vec!["data=1,Alice\n2,Bob\n", "header=false"]); + let cursor = table.open().unwrap(); + let rows = read_rows(cursor, 2); + assert_eq!( + rows, + vec![ + vec![cell!("1"), cell!("Alice")], + vec![cell!("2"), cell!("Bob")] + ] + ); + } + + #[test] + fn test_empty_file_with_header() { + let file = write_csv("id,name\n"); + let table = new_table(vec![ + &format!("filename={}", file.path().to_string_lossy()), + "header=true", + ]); + let cursor = table.open().unwrap(); + let rows = read_rows(cursor, 2); + assert!(rows.is_empty()); + } + + #[test] + fn test_empty_data_with_header() { + let table = new_table(vec!["data=id,name\n", "header=true"]); + let cursor = table.open().unwrap(); + let rows = read_rows(cursor, 2); + assert!(rows.is_empty()); + } + + #[test] + fn test_empty_file_no_header() { + let file = write_csv(""); + let (schema, table) = try_new_table(vec![ + &format!("filename={}", file.path().to_string_lossy()), + "header=false", + ]) + .unwrap(); + let cursor = table.open().unwrap(); + let rows = read_rows(cursor, 2); + assert!(rows.is_empty()); + assert_eq!(schema, "CREATE TABLE x(\"c0\" TEXT)"); + } + + #[test] + fn test_empty_data_no_header() { + let (schema, table) = try_new_table(vec!["data=", "header=false"]).unwrap(); + let cursor = table.open().unwrap(); + let rows = read_rows(cursor, 2); + assert!(rows.is_empty()); + assert_eq!(schema, "CREATE TABLE x(\"c0\" TEXT)"); + } + + #[test] + fn test_empty_file_with_header_enabled() { + let file = write_csv(""); + let (schema, table) = try_new_table(vec![ + &format!("filename={}", file.path().to_string_lossy()), + "header=true", + ]) + .unwrap(); + let cursor = table.open().unwrap(); + let rows = read_rows(cursor, 2); + assert!(rows.is_empty()); + assert_eq!(schema, "CREATE TABLE x(\"(NULL)\" TEXT)"); + } + + #[test] + fn test_empty_data_with_header_enabled() { + let (schema, table) = try_new_table(vec!["data=", "header=true"]).unwrap(); + let cursor = table.open().unwrap(); + let rows = read_rows(cursor, 2); + assert!(rows.is_empty()); + assert_eq!(schema, "CREATE TABLE x(\"(NULL)\" TEXT)"); + } + + #[test] + fn test_quoted_field() { + let file = write_csv("id,name\n1,\"A,l,i,c,e\"\n"); + let table = new_table(vec![ + &format!("filename={}", file.path().to_string_lossy()), + "header=true", + ]); + let cursor = table.open().unwrap(); + let rows = read_rows(cursor, 2); + assert_eq!(rows, vec![vec![cell!("1"), cell!("A,l,i,c,e")],]); + } + + #[test] + fn test_quote_inside_field() { + let file = write_csv("\"aaa\",\"b\"\"bb\",\"ccc\"\n"); + let table = new_table(vec![ + &format!("filename={}", file.path().to_string_lossy()), + "header=false", + ]); + let cursor = table.open().unwrap(); + let rows = read_rows(cursor, 3); + assert_eq!( + rows, + vec![vec![cell!("aaa"), cell!("b\"bb"), cell!("ccc")],] + ); + } + + #[test] + fn test_custom_schema() { + let file = write_csv("1,Alice\n2,Bob\n"); + let table = new_table(vec![ + &format!("filename={}", file.path().to_string_lossy()), + "header=false", + "schema=CREATE TABLE x(id INT, name TEXT)", + ]); + let cursor = table.open().unwrap(); + let rows = read_rows(cursor, 2); + assert_eq!( + rows, + vec![ + vec![cell!("1"), cell!("Alice")], + vec![cell!("2"), cell!("Bob")] + ] + ); + } + + #[test] + fn test_more_than_one_filename_argument() { + let result = try_new_table(vec!["filename=aaa.csv", "filename=bbb.csv"]); + assert!(matches!(result, Err(ResultCode::InvalidArgs))); + } + + #[test] + fn test_more_than_one_data_argument() { + let result = try_new_table(vec!["data=1,Alice\n2,Bob\n", "data=3,Alice\n4,Bob\n"]); + assert!(matches!(result, Err(ResultCode::InvalidArgs))); + } + + #[test] + fn test_more_than_one_schema_argument() { + let result = try_new_table(vec![ + "schema=CREATE TABLE x(id INT, name TEXT)", + "schema=CREATE TABLE x(key INT, value TEXT)", + ]); + assert!(matches!(result, Err(ResultCode::InvalidArgs))); + } + + #[test] + fn test_more_than_one_columns_argument() { + let result = try_new_table(vec!["columns=2", "columns=6"]); + assert!(matches!(result, Err(ResultCode::InvalidArgs))); + } + + #[test] + fn test_more_than_one_header_argument() { + let result = try_new_table(vec!["header=true", "header=false"]); + assert!(matches!(result, Err(ResultCode::InvalidArgs))); + } + + #[test] + fn test_unrecognized_argument() { + let result = try_new_table(vec!["non_existent=abc"]); + assert!(matches!(result, Err(ResultCode::InvalidArgs))); + } + + #[test] + fn test_missing_filename_and_data() { + let result = try_new_table(vec!["header=false"]); + assert!(matches!(result, Err(ResultCode::InvalidArgs))); + } + + #[test] + fn test_conflicting_filename_and_data() { + let result = try_new_table(vec!["filename=a.csv", "data=id,name\n1,Alice\n2,Bob\n"]); + assert!(matches!(result, Err(ResultCode::InvalidArgs))); + } + + #[test] + fn test_header_argument_parsing() { + let true_values = ["true", "TRUE", "yes", "on", "1"]; + let false_values = ["false", "FALSE", "no", "off", "0"]; + + for &val in &true_values { + let result = try_new_table(vec![ + "data=id,name\n1,Alice\n2,Bob\n", + &format!("header={}", val), + ]); + assert!(result.is_ok(), "Expected Ok for header='{}'", val); + assert_eq!( + result.unwrap().1.header, + true, + "Expected true for '{}'", + val + ); + } + + for &val in &false_values { + let result = try_new_table(vec![ + "data=id,name\n1,Alice\n2,Bob\n", + &format!("header={}", val), + ]); + assert!(result.is_ok(), "Expected Ok for header='{}'", val); + assert_eq!( + result.unwrap().1.header, + false, + "Expected false for '{}'", + val + ); + } + } + + #[test] + fn test_invalid_header_argument() { + let invalid_values = ["tru", "2", "maybe", "onoff", "", "\"true\""]; + + for &val in &invalid_values { + let result = try_new_table(vec![ + "data=id,name\n1,Alice\n2,Bob\n", + &format!("header={}", val), + ]); + assert!(matches!(result, Err(ResultCode::InvalidArgs))); + } + } + + #[test] + fn test_arguments_with_whitespace() { + let table = new_table(vec![ + " data = id,name\n1,Alice\n2,Bob\n ", + " header = true ", + ]); + let cursor = table.open().unwrap(); + let rows = read_rows(cursor, 2); + assert_eq!( + rows, + vec![ + vec![cell!("1"), cell!("Alice")], + vec![cell!("2"), cell!("Bob")] + ] + ); + } + + #[test] + fn test_unparsable_argument() { + let unparsable_arguments = [ + "header", + "schema='CREATE TABLE x(id INT, name TEXT)", + "schema=\"CREATE TABLE x(id INT, name TEXT)", + "schema=\"CREATE TABLE x(id INT, name TEXT)'", + ]; + + for &val in &unparsable_arguments { + let result = try_new_table(vec!["data=id,name\n1,Alice\n2,Bob\n", val]); + assert!(matches!(result, Err(ResultCode::InvalidArgs))); + } + } + + #[test] + fn test_escaped_quote() { + let quotes = ["'", "\""]; + + for "e in "es { + let table = new_table(vec![&format!( + "data={}aa{}{}bb{}", + quote, quote, quote, quote + )]); + let cursor = table.open().unwrap(); + let rows = read_rows(cursor, 1); + assert_eq!(rows, vec![vec![cell!(format!("aa{}bb", quote))]]); + } + } + + #[test] + fn test_unescaped_quote() { + let cases = [("", "'"), ("", "\""), ("'", "\""), ("\"", "'")]; + + for &case in &cases { + let (outer, inner) = case; + let table = new_table(vec![&format!( + "data={}aa{}{}bb{}", + outer, inner, inner, outer + )]); + let cursor = table.open().unwrap(); + let rows = read_rows(cursor, 1); + assert_eq!(rows, vec![vec![cell!(format!("aa{}{}bb", inner, inner))]]); + } + } + + #[test] + fn test_non_existent_file() { + let result = try_new_table(vec!["filename=non_existent.csv"]); + assert!(matches!(result, Err(ResultCode::Error))); + } + + #[test] + fn test_invalid_columns_argument() { + let invalid_values = ["0", "-2", "\"2\"", "'2'"]; + + for &val in &invalid_values { + let result = try_new_table(vec![ + "data=id,name\n1,Alice\n2,Bob\n", + &format!("columns={}", val), + ]); + assert!(matches!(result, Err(ResultCode::InvalidArgs))); + } + } + + #[test] + fn test_more_columns_than_in_file() { + let file = write_csv("1,Alice\n2,Bob\n"); + let table = new_table(vec![ + &format!("filename={}", file.path().to_string_lossy()), + "header=false", + "columns=4", + ]); + let cursor = table.open().unwrap(); + let rows = read_rows(cursor, 4); + assert_eq!( + rows, + vec![ + vec![cell!("1"), cell!("Alice"), None, None], + vec![cell!("2"), cell!("Bob"), None, None] + ] + ); + } + + #[test] + fn test_fewer_columns_than_in_file() { + let file = write_csv("1,Alice\n2,Bob\n"); + let table = new_table(vec![ + &format!("filename={}", file.path().to_string_lossy()), + "header=false", + "columns=1", + ]); + let cursor = table.open().unwrap(); + let rows = read_rows(cursor, 1); + assert_eq!(rows, vec![vec![cell!("1")], vec![cell!("2")]]); + } + + #[test] + fn test_fewer_columns_than_in_schema() { + let file = write_csv("1,Alice,2002\n2,Bob,2000\n"); + let table = new_table(vec![ + &format!("filename={}", file.path().to_string_lossy()), + "header=false", + "columns=1", + "schema='CREATE TABLE x(id INT, name TEXT)'", + ]); + let cursor = table.open().unwrap(); + let rows = read_rows(cursor, 2); + assert_eq!(rows, vec![vec![cell!("1"), None], vec![cell!("2"), None]]); + } + + #[test] + fn test_more_columns_than_in_schema() { + let file = write_csv("1,Alice,2002\n2,Bob,2000\n"); + let table = new_table(vec![ + &format!("filename={}", file.path().to_string_lossy()), + "header=false", + "columns=5", + "schema='CREATE TABLE x(id INT, name TEXT)'", + ]); + let cursor = table.open().unwrap(); + let rows = read_rows(cursor, 2); + assert_eq!( + rows, + vec![ + vec![cell!("1"), cell!("Alice")], + vec![cell!("2"), cell!("Bob")] + ] + ); + } + + #[test] + fn test_double_quote_in_header() { + let file = write_csv("id,first\"name\n1,Alice\n2,Bob\n"); + let (schema, table) = try_new_table(vec![ + &format!("filename={}", file.path().to_string_lossy()), + "header=true", + ]) + .unwrap(); + let cursor = table.open().unwrap(); + let rows = read_rows(cursor, 2); + assert_eq!( + rows, + vec![ + vec![cell!("1"), cell!("Alice")], + vec![cell!("2"), cell!("Bob")] + ] + ); + assert_eq!( + schema, + "CREATE TABLE x(\"id\" TEXT, \"first\"\"name\" TEXT)" + ); + } +} diff --git a/extensions/series/src/lib.rs b/extensions/series/src/lib.rs index 21d3a89fa..f609ded91 100644 --- a/extensions/series/src/lib.rs +++ b/extensions/series/src/lib.rs @@ -1,9 +1,10 @@ use limbo_ext::{ - register_extension, ResultCode, VTabCursor, VTabKind, VTabModule, VTabModuleDerive, Value, + register_extension, ResultCode, VTabCursor, VTabKind, VTabModule, VTabModuleDerive, VTable, + Value, }; register_extension! { - vtabs: { GenerateSeriesVTab } + vtabs: { GenerateSeriesVTabModule } } macro_rules! try_option { @@ -17,26 +18,32 @@ macro_rules! try_option { /// A virtual table that generates a sequence of integers #[derive(Debug, VTabModuleDerive, Default)] -struct GenerateSeriesVTab; +struct GenerateSeriesVTabModule; -impl VTabModule for GenerateSeriesVTab { - type VCursor = GenerateSeriesCursor; - type Error = ResultCode; +impl VTabModule for GenerateSeriesVTabModule { + type Table = GenerateSeriesTable; const NAME: &'static str = "generate_series"; const VTAB_KIND: VTabKind = VTabKind::TableValuedFunction; - fn create_schema(_args: &[Value]) -> String { - // Create table schema - "CREATE TABLE generate_series( + fn create(_args: &[Value]) -> Result<(String, Self::Table), ResultCode> { + let schema = "CREATE TABLE generate_series( value INTEGER, start INTEGER HIDDEN, stop INTEGER HIDDEN, step INTEGER HIDDEN )" - .into() + .into(); + Ok((schema, GenerateSeriesTable {})) } +} - fn open(&self) -> Result { +struct GenerateSeriesTable {} + +impl VTable for GenerateSeriesTable { + type Cursor = GenerateSeriesCursor; + type Error = ResultCode; + + fn open(&self) -> Result { Ok(GenerateSeriesCursor { start: 0, stop: 0, @@ -44,53 +51,6 @@ impl VTabModule for GenerateSeriesVTab { current: 0, }) } - - fn filter(cursor: &mut Self::VCursor, args: &[Value], _: Option<(&str, i32)>) -> ResultCode { - // args are the start, stop, and step - if args.is_empty() || args.len() > 3 { - return ResultCode::InvalidArgs; - } - let start = try_option!(args[0].to_integer(), ResultCode::InvalidArgs); - let stop = try_option!( - args.get(1).map(|v| v.to_integer().unwrap_or(i64::MAX)), - ResultCode::EOF // Sqlite returns an empty series for wacky args - ); - let mut step = args - .get(2) - .map(|v| v.to_integer().unwrap_or(1)) - .unwrap_or(1); - - // Convert zero step to 1, matching SQLite behavior - if step == 0 { - step = 1; - } - - cursor.start = start; - cursor.step = step; - cursor.stop = stop; - - // Set initial value based on range validity - // For invalid input SQLite returns an empty series - cursor.current = if cursor.is_invalid_range() { - return ResultCode::EOF; - } else { - start - }; - - ResultCode::OK - } - - fn column(cursor: &Self::VCursor, idx: u32) -> Result { - cursor.column(idx) - } - - fn next(cursor: &mut Self::VCursor) -> ResultCode { - cursor.next() - } - - fn eof(cursor: &Self::VCursor) -> bool { - cursor.eof() - } } /// The cursor for iterating over the generated sequence @@ -128,6 +88,41 @@ impl GenerateSeriesCursor { impl VTabCursor for GenerateSeriesCursor { type Error = ResultCode; + fn filter(&mut self, args: &[Value], _: Option<(&str, i32)>) -> ResultCode { + // args are the start, stop, and step + if args.is_empty() || args.len() > 3 { + return ResultCode::InvalidArgs; + } + let start = try_option!(args[0].to_integer(), ResultCode::InvalidArgs); + let stop = try_option!( + args.get(1).map(|v| v.to_integer().unwrap_or(i64::MAX)), + ResultCode::EOF // Sqlite returns an empty series for wacky args + ); + let mut step = args + .get(2) + .map(|v| v.to_integer().unwrap_or(1)) + .unwrap_or(1); + + // Convert zero step to 1, matching SQLite behavior + if step == 0 { + step = 1; + } + + self.start = start; + self.step = step; + self.stop = stop; + + // Set initial value based on range validity + // For invalid input SQLite returns an empty series + self.current = if self.is_invalid_range() { + return ResultCode::EOF; + } else { + start + }; + + ResultCode::OK + } + fn next(&mut self) -> ResultCode { if self.eof() { return ResultCode::EOF; @@ -229,7 +224,7 @@ mod tests { } // Helper function to collect all values from a cursor, returns Result with error code fn collect_series(series: Series) -> Result, ResultCode> { - let tbl = GenerateSeriesVTab; + let tbl = GenerateSeriesTable {}; let mut cursor = tbl.open()?; // Create args array for filter @@ -240,7 +235,7 @@ mod tests { ]; // Initialize cursor through filter - match GenerateSeriesVTab::filter(&mut cursor, &args, None) { + match cursor.filter(&args, None) { ResultCode::OK => (), ResultCode::EOF => return Ok(vec![]), err => return Err(err), @@ -255,7 +250,7 @@ mod tests { (series.stop - series.start) / series.step + 1 ); } - match GenerateSeriesVTab::next(&mut cursor) { + match cursor.next() { ResultCode::OK => (), ResultCode::EOF => break, err => return Err(err), @@ -546,7 +541,7 @@ mod tests { let start = series.start; let stop = series.stop; let step = series.step; - let tbl = GenerateSeriesVTab {}; + let tbl = GenerateSeriesTable {}; let mut cursor = tbl.open().unwrap(); let args = vec![ @@ -556,12 +551,12 @@ mod tests { ]; // Initialize cursor through filter - GenerateSeriesVTab::filter(&mut cursor, &args, None); + cursor.filter(&args, None); let mut rowids = vec![]; - while !GenerateSeriesVTab::eof(&cursor) { + while !cursor.eof() { let cur_rowid = cursor.rowid(); - match GenerateSeriesVTab::next(&mut cursor) { + match cursor.next() { ResultCode::OK => rowids.push(cur_rowid), ResultCode::EOF => break, err => panic!( diff --git a/extensions/tests/src/lib.rs b/extensions/tests/src/lib.rs index 5c6495595..b70b5497b 100644 --- a/extensions/tests/src/lib.rs +++ b/extensions/tests/src/lib.rs @@ -1,7 +1,8 @@ use lazy_static::lazy_static; use limbo_ext::{ register_extension, scalar, ConstraintInfo, ConstraintOp, ConstraintUsage, ExtResult, - IndexInfo, OrderByInfo, ResultCode, VTabCursor, VTabKind, VTabModule, VTabModuleDerive, Value, + IndexInfo, OrderByInfo, ResultCode, VTabCursor, VTabKind, VTabModule, VTabModuleDerive, VTable, + Value, }; #[cfg(not(target_family = "wasm"))] use limbo_ext::{VfsDerive, VfsExtension, VfsFile}; @@ -11,7 +12,7 @@ use std::io::{Read, Seek, SeekFrom, Write}; use std::sync::Mutex; register_extension! { - vtabs: { KVStoreVTab }, + vtabs: { KVStoreVTabModule }, scalars: { test_scalar }, vfs: { TestFS }, } @@ -21,7 +22,7 @@ lazy_static! { } #[derive(VTabModuleDerive, Default)] -pub struct KVStoreVTab; +pub struct KVStoreVTabModule; /// the cursor holds a snapshot of (rowid, key, value) in memory. pub struct KVStoreCursor { @@ -29,17 +30,114 @@ pub struct KVStoreCursor { index: Option, } -impl VTabModule for KVStoreVTab { - type VCursor = KVStoreCursor; +impl VTabModule for KVStoreVTabModule { + type Table = KVStoreTable; const VTAB_KIND: VTabKind = VTabKind::VirtualTable; const NAME: &'static str = "kv_store"; + + fn create(_args: &[Value]) -> Result<(String, Self::Table), ResultCode> { + let schema = "CREATE TABLE x (key TEXT PRIMARY KEY, value TEXT);".to_string(); + Ok((schema, KVStoreTable {})) + } +} + +fn hash_key(key: &str) -> i64 { + use std::hash::{Hash, Hasher}; + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + key.hash(&mut hasher); + hasher.finish() as i64 +} + +impl VTabCursor for KVStoreCursor { type Error = String; - fn create_schema(_args: &[Value]) -> String { - "CREATE TABLE x (key TEXT PRIMARY KEY, value TEXT);".to_string() + fn filter(&mut self, args: &[Value], idx_str: Option<(&str, i32)>) -> ResultCode { + match idx_str { + Some(("key_eq", 1)) => { + let key = args + .first() + .and_then(|v| v.to_text()) + .map(|s| s.to_string()); + log::debug!("idx_str found: key_eq\n value: {:?}", key); + if let Some(key) = key { + let rowid = hash_key(&key); + let store = GLOBAL_STORE.lock().unwrap(); + if let Some((k, v)) = store.get(&rowid) { + self.rows.push((rowid, k.clone(), v.clone())); + self.index = Some(0); + } else { + self.rows.clear(); + self.index = None; + return ResultCode::EOF; + } + return ResultCode::OK; + } + self.rows.clear(); + self.index = None; + ResultCode::OK + } + _ => { + let store = GLOBAL_STORE.lock().unwrap(); + self.rows = store + .iter() + .map(|(&rowid, (k, v))| (rowid, k.clone(), v.clone())) + .collect(); + self.rows.sort_by_key(|(rowid, _, _)| *rowid); + if self.rows.is_empty() { + self.index = None; + ResultCode::EOF + } else { + self.index = Some(0); + ResultCode::OK + } + } + } } - fn open(&self) -> Result { + fn rowid(&self) -> i64 { + if self.index.is_some_and(|c| c < self.rows.len()) { + self.rows[self.index.unwrap_or(0)].0 + } else { + log::error!("rowid: -1"); + -1 + } + } + + fn column(&self, idx: u32) -> Result { + if self.index.is_some_and(|c| c >= self.rows.len()) { + return Err("cursor out of range".into()); + } + if let Some((_, ref key, ref val)) = self.rows.get(self.index.unwrap_or(0)) { + match idx { + 0 => Ok(Value::from_text(key.clone())), // key + 1 => Ok(Value::from_text(val.clone())), // value + _ => Err("Invalid column".into()), + } + } else { + Err("Invalid Column".into()) + } + } + + fn eof(&self) -> bool { + self.index.is_some_and(|s| s >= self.rows.len()) || self.index.is_none() + } + + fn next(&mut self) -> ResultCode { + self.index = Some(self.index.unwrap_or(0) + 1); + if self.index.is_some_and(|c| c >= self.rows.len()) { + return ResultCode::EOF; + } + ResultCode::OK + } +} + +pub struct KVStoreTable {} + +impl VTable for KVStoreTable { + type Cursor = KVStoreCursor; + type Error = String; + + fn open(&self) -> Result { let _ = env_logger::try_init(); Ok(KVStoreCursor { rows: Vec::new(), @@ -88,53 +186,6 @@ impl VTabModule for KVStoreVTab { } } - fn filter( - cursor: &mut Self::VCursor, - args: &[Value], - idx_str: Option<(&str, i32)>, - ) -> ResultCode { - match idx_str { - Some(("key_eq", 1)) => { - let key = args - .first() - .and_then(|v| v.to_text()) - .map(|s| s.to_string()); - log::debug!("idx_str found: key_eq\n value: {:?}", key); - if let Some(key) = key { - let rowid = hash_key(&key); - let store = GLOBAL_STORE.lock().unwrap(); - if let Some((k, v)) = store.get(&rowid) { - cursor.rows.push((rowid, k.clone(), v.clone())); - cursor.index = Some(0); - } else { - cursor.rows.clear(); - cursor.index = None; - return ResultCode::EOF; - } - return ResultCode::OK; - } - cursor.rows.clear(); - cursor.index = None; - ResultCode::OK - } - _ => { - let store = GLOBAL_STORE.lock().unwrap(); - cursor.rows = store - .iter() - .map(|(&rowid, (k, v))| (rowid, k.clone(), v.clone())) - .collect(); - cursor.rows.sort_by_key(|(rowid, _, _)| *rowid); - if cursor.rows.is_empty() { - cursor.index = None; - ResultCode::EOF - } else { - cursor.index = Some(0); - ResultCode::OK - } - } - } - } - fn insert(&mut self, values: &[Value]) -> Result { let key = values .first() @@ -169,71 +220,12 @@ impl VTabModule for KVStoreVTab { Ok(()) } - fn eof(cursor: &Self::VCursor) -> bool { - cursor.index.is_some_and(|s| s >= cursor.rows.len()) || cursor.index.is_none() - } - - fn next(cursor: &mut Self::VCursor) -> ResultCode { - cursor.index = Some(cursor.index.unwrap_or(0) + 1); - if cursor.index.is_some_and(|c| c >= cursor.rows.len()) { - return ResultCode::EOF; - } - ResultCode::OK - } - - fn column(cursor: &Self::VCursor, idx: u32) -> Result { - if cursor.index.is_some_and(|c| c >= cursor.rows.len()) { - return Err("cursor out of range".into()); - } - if let Some((_, ref key, ref val)) = cursor.rows.get(cursor.index.unwrap_or(0)) { - match idx { - 0 => Ok(Value::from_text(key.clone())), // key - 1 => Ok(Value::from_text(val.clone())), // value - _ => Err("Invalid column".into()), - } - } else { - Err("Invalid Column".into()) - } - } - fn destroy(&mut self) -> Result<(), Self::Error> { println!("VDestroy called"); Ok(()) } } -fn hash_key(key: &str) -> i64 { - use std::hash::{Hash, Hasher}; - let mut hasher = std::collections::hash_map::DefaultHasher::new(); - key.hash(&mut hasher); - hasher.finish() as i64 -} - -impl VTabCursor for KVStoreCursor { - type Error = String; - - fn rowid(&self) -> i64 { - if self.index.is_some_and(|c| c < self.rows.len()) { - self.rows[self.index.unwrap_or(0)].0 - } else { - log::error!("rowid: -1"); - -1 - } - } - - fn column(&self, idx: u32) -> Result { - ::column(self, idx) - } - - fn eof(&self) -> bool { - ::eof(self) - } - - fn next(&mut self) -> ResultCode { - ::next(self) - } -} - pub struct TestFile { file: File, } diff --git a/macros/src/ext/vtab_derive.rs b/macros/src/ext/vtab_derive.rs index 0dd67ca0d..0bcb3c53a 100644 --- a/macros/src/ext/vtab_derive.rs +++ b/macros/src/ext/vtab_derive.rs @@ -7,7 +7,7 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { let struct_name = &ast.ident; let register_fn_name = format_ident!("register_{}", struct_name); - let create_schema_fn_name = format_ident!("create_schema_{}", struct_name); + let create_fn_name = format_ident!("create_{}", struct_name); let open_fn_name = format_ident!("open_{}", struct_name); let close_fn_name = format_ident!("close_{}", struct_name); let filter_fn_name = format_ident!("filter_{}", struct_name); @@ -22,26 +22,40 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { let expanded = quote! { impl #struct_name { #[no_mangle] - unsafe extern "C" fn #create_schema_fn_name( + unsafe extern "C" fn #create_fn_name( argv: *const ::limbo_ext::Value, argc: i32 - ) -> *mut ::std::ffi::c_char { + ) -> ::limbo_ext::VTabCreateResult { let args = if argv.is_null() { &Vec::new() } else { ::std::slice::from_raw_parts(argv, argc as usize) }; - let sql = <#struct_name as ::limbo_ext::VTabModule>::create_schema(&args); - ::std::ffi::CString::new(sql).unwrap().into_raw() + match <#struct_name as ::limbo_ext::VTabModule>::create(&args) { + Ok((schema, table)) => { + ::limbo_ext::VTabCreateResult { + code: ::limbo_ext::ResultCode::OK, + schema: ::std::ffi::CString::new(schema).unwrap().into_raw(), + table: ::std::boxed::Box::into_raw(::std::boxed::Box::new(table)) as *const ::std::ffi::c_void, + } + }, + Err(e) => { + ::limbo_ext::VTabCreateResult { + code: e, + schema: ::std::ptr::null(), + table: ::std::ptr::null(), + } + } + } } #[no_mangle] - unsafe extern "C" fn #open_fn_name(ctx: *const ::std::ffi::c_void) -> *const ::std::ffi::c_void { - if ctx.is_null() { + unsafe extern "C" fn #open_fn_name(table: *const ::std::ffi::c_void) -> *const ::std::ffi::c_void { + if table.is_null() { return ::std::ptr::null(); } - let ctx = ctx as *const #struct_name; - let ctx: &#struct_name = &*ctx; - if let Ok(cursor) = <#struct_name as ::limbo_ext::VTabModule>::open(ctx) { + let table = table as *const <#struct_name as ::limbo_ext::VTabModule>::Table; + let table: &<#struct_name as ::limbo_ext::VTabModule>::Table = &*table; + if let Ok(cursor) = <#struct_name as ::limbo_ext::VTabModule>::Table::open(table) { return ::std::boxed::Box::into_raw(::std::boxed::Box::new(cursor)) as *const ::std::ffi::c_void; } else { return ::std::ptr::null(); @@ -55,8 +69,9 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { if cursor.is_null() { return ::limbo_ext::ResultCode::Error; } - let boxed_cursor = ::std::boxed::Box::from_raw(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor); - boxed_cursor.close() + let cursor = cursor as *mut <<#struct_name as ::limbo_ext::VTabModule>::Table as ::limbo_ext::VTable>::Cursor; + let cursor = ::std::boxed::Box::from_raw(cursor); + <<#struct_name as ::limbo_ext::VTabModule>::Table as ::limbo_ext::VTable>::Cursor::close(&*cursor) } #[no_mangle] @@ -70,14 +85,14 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { if cursor.is_null() { return ::limbo_ext::ResultCode::Error; } - let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; + let cursor = &mut *(cursor as *mut <<#struct_name as ::limbo_ext::VTabModule>::Table as ::limbo_ext::VTable>::Cursor); let args = ::std::slice::from_raw_parts(argv, argc as usize); let idx_str = if idx_str.is_null() { None } else { Some((unsafe { ::std::ffi::CStr::from_ptr(idx_str).to_str().unwrap() }, idx_num)) }; - <#struct_name as ::limbo_ext::VTabModule>::filter(cursor, args, idx_str) + <<#struct_name as ::limbo_ext::VTabModule>::Table as ::limbo_ext::VTable>::Cursor::filter(cursor, args, idx_str) } #[no_mangle] @@ -88,8 +103,8 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { if cursor.is_null() { return ::limbo_ext::Value::error(::limbo_ext::ResultCode::Error); } - let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; - match <#struct_name as ::limbo_ext::VTabModule>::column(cursor, idx) { + let cursor = &*(cursor as *const <<#struct_name as ::limbo_ext::VTabModule>::Table as ::limbo_ext::VTable>::Cursor); + match <<#struct_name as ::limbo_ext::VTabModule>::Table as ::limbo_ext::VTable>::Cursor::column(cursor, idx) { Ok(val) => val, Err(e) => ::limbo_ext::Value::error_with_message(e.to_string()) } @@ -102,8 +117,8 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { if cursor.is_null() { return ::limbo_ext::ResultCode::Error; } - let cursor = &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor); - <#struct_name as ::limbo_ext::VTabModule>::next(cursor) + let cursor = &mut *(cursor as *mut <<#struct_name as ::limbo_ext::VTabModule>::Table as ::limbo_ext::VTable>::Cursor); + <<#struct_name as ::limbo_ext::VTabModule>::Table as ::limbo_ext::VTable>::Cursor::next(cursor) } #[no_mangle] @@ -113,22 +128,22 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { if cursor.is_null() { return true; } - let cursor = &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor); - <#struct_name as ::limbo_ext::VTabModule>::eof(cursor) + let cursor = &*(cursor as *const <<#struct_name as ::limbo_ext::VTabModule>::Table as ::limbo_ext::VTable>::Cursor); + <<#struct_name as ::limbo_ext::VTabModule>::Table as ::limbo_ext::VTable>::Cursor::eof(cursor) } #[no_mangle] unsafe extern "C" fn #update_fn_name( - vtab: *const ::std::ffi::c_void, + table: *const ::std::ffi::c_void, argc: i32, argv: *const ::limbo_ext::Value, p_out_rowid: *mut i64, ) -> ::limbo_ext::ResultCode { - if vtab.is_null() { + if table.is_null() { return ::limbo_ext::ResultCode::Error; } - let vtab = &mut *(vtab as *mut #struct_name); + let table = &mut *(table as *mut <#struct_name as ::limbo_ext::VTabModule>::Table); let args = ::std::slice::from_raw_parts(argv, argc as usize); let old_rowid = match args.get(0).map(|v| v.value_type()) { @@ -143,21 +158,21 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { match (old_rowid, new_rowid) { // DELETE: old_rowid provided, no new_rowid (Some(old), None) => { - if <#struct_name as VTabModule>::delete(vtab, old).is_err() { + if <#struct_name as VTabModule>::Table::delete(table, old).is_err() { return ::limbo_ext::ResultCode::Error; } return ::limbo_ext::ResultCode::OK; } // UPDATE: old_rowid provided and new_rowid may exist (Some(old), Some(new)) => { - if <#struct_name as VTabModule>::update(vtab, old, &columns).is_err() { + if <#struct_name as VTabModule>::Table::update(table, old, &columns).is_err() { return ::limbo_ext::ResultCode::Error; } return ::limbo_ext::ResultCode::OK; } // INSERT: no old_rowid (old_rowid = None) (None, _) => { - if let Ok(rowid) = <#struct_name as VTabModule>::insert(vtab, &columns) { + if let Ok(rowid) = <#struct_name as VTabModule>::Table::insert(table, &columns) { if !p_out_rowid.is_null() { *p_out_rowid = rowid; return ::limbo_ext::ResultCode::RowID; @@ -170,24 +185,26 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { } #[no_mangle] - pub unsafe extern "C" fn #rowid_fn_name(ctx: *const ::std::ffi::c_void) -> i64 { - if ctx.is_null() { + pub unsafe extern "C" fn #rowid_fn_name(cursor: *const ::std::ffi::c_void) -> i64 { + if cursor.is_null() { return -1; } - let cursor = &*(ctx as *const <#struct_name as ::limbo_ext::VTabModule>::VCursor); - <<#struct_name as ::limbo_ext::VTabModule>::VCursor as ::limbo_ext::VTabCursor>::rowid(cursor) + let cursor = &*(cursor as *const <<#struct_name as ::limbo_ext::VTabModule>::Table as ::limbo_ext::VTable>::Cursor); + <<#struct_name as ::limbo_ext::VTabModule>::Table as ::limbo_ext::VTable>::Cursor::rowid(cursor) } #[no_mangle] unsafe extern "C" fn #destroy_fn_name( - vtab: *const ::std::ffi::c_void, + table: *const ::std::ffi::c_void, ) -> ::limbo_ext::ResultCode { - if vtab.is_null() { + if table.is_null() { return ::limbo_ext::ResultCode::Error; } - let vtab = &mut *(vtab as *mut #struct_name); - if <#struct_name as VTabModule>::destroy(vtab).is_err() { + // Take ownership of the table so it can be properly dropped. + let mut table: ::std::boxed::Box<<#struct_name as ::limbo_ext::VTabModule>::Table> = + ::std::boxed::Box::from_raw(table as *mut <#struct_name as ::limbo_ext::VTabModule>::Table); + if <#struct_name as VTabModule>::Table::destroy(&mut *table).is_err() { return ::limbo_ext::ResultCode::Error; } @@ -203,7 +220,7 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { ) -> ::limbo_ext::ExtIndexInfo { let constraints = if n_constraints > 0 { std::slice::from_raw_parts(constraints, n_constraints as usize) } else { &[] }; let order_by = if n_order_by > 0 { std::slice::from_raw_parts(order_by, n_order_by as usize) } else { &[] }; - <#struct_name as ::limbo_ext::VTabModule>::best_index(constraints, order_by).to_ffi() + <#struct_name as ::limbo_ext::VTabModule>::Table::best_index(constraints, order_by).to_ffi() } #[no_mangle] @@ -216,11 +233,9 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { let api = &*api; let name = <#struct_name as ::limbo_ext::VTabModule>::NAME; let name_c = ::std::ffi::CString::new(name).unwrap().into_raw() as *const ::std::ffi::c_char; - let table_instance = ::std::boxed::Box::into_raw(::std::boxed::Box::new(#struct_name::default())); let module = ::limbo_ext::VTabModuleImpl { - ctx: table_instance as *const ::std::ffi::c_void, name: name_c, - create_schema: Self::#create_schema_fn_name, + create: Self::#create_fn_name, open: Self::#open_fn_name, close: Self::#close_fn_name, filter: Self::#filter_fn_name, diff --git a/macros/src/lib.rs b/macros/src/lib.rs index e173d47ba..b65af2234 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -230,55 +230,47 @@ pub fn derive_agg_func(input: TokenStream) -> TokenStream { /// Macro to derive a VTabModule for your extension. This macro will generate /// the necessary functions to register your module with core. You must implement -/// the VTabModule trait for your struct, and the VTabCursor trait for your cursor. +/// the VTabModule, VTable, and VTabCursor traits. /// ```ignore -///#[derive(Debug, VTabModuleDerive)] -///struct CsvVTab; -///impl VTabModule for CsvVTab { -/// type VCursor = CsvCursor; -/// const NAME: &'static str = "csv_data"; +/// #[derive(Debug, VTabModuleDerive)] +/// struct CsvVTabModule; /// -/// /// Declare the schema for your virtual table -/// fn create_schema(args: &[&str]) -> &'static str { -/// let sql = "CREATE TABLE csv_data( -/// name TEXT, -/// age TEXT, -/// city TEXT -/// )" -/// } -/// /// Open the virtual table and return a cursor -/// fn open() -> Self::VCursor { -/// let csv_content = fs::read_to_string("data.csv").unwrap_or_default(); -/// let rows: Vec> = csv_content -/// .lines() -/// .skip(1) -/// .map(|line| { -/// line.split(',') -/// .map(|s| s.trim().to_string()) -/// .collect() -/// }) -/// .collect(); -/// CsvCursor { rows, index: 0 } -/// } -/// /// Filter the virtual table based on arguments (omitted here for simplicity) -/// fn filter(_cursor: &mut Self::VCursor, _arg_count: i32, _args: &[Value]) -> ResultCode { -/// ResultCode::OK -/// } -/// /// Return the value for a given column index -/// fn column(cursor: &Self::VCursor, idx: u32) -> Value { -/// cursor.column(idx) +/// impl VTabModule for CsvVTabModule { +/// type Table = CsvTable; +/// const NAME: &'static str = "csv_data"; +/// const VTAB_KIND: VTabKind = VTabKind::VirtualTable; +/// +/// /// Declare your virtual table and its schema +/// fn create(args: &[Value]) -> Result<(String, Self::Table), ResultCode> { +/// let schema = "CREATE TABLE csv_data( +/// name TEXT, +/// age TEXT, +/// city TEXT +/// )".into(); +/// Ok((schema, CsvTable {})) /// } -/// /// Move the cursor to the next row -/// fn next(cursor: &mut Self::VCursor) -> ResultCode { -/// if cursor.index < cursor.rows.len() - 1 { -/// cursor.index += 1; -/// ResultCode::OK -/// } else { -/// ResultCode::EOF -/// } -/// } -/// fn eof(cursor: &Self::VCursor) -> bool { -/// cursor.index >= cursor.rows.len() +/// } +/// +/// struct CsvTable {} +/// +/// // Implement the VTable trait for your virtual table +/// impl VTable for CsvTable { +/// type Cursor = CsvCursor; +/// type Error = &'static str; +/// +/// /// Open the virtual table and return a cursor +/// fn open(&self) -> Result { +/// let csv_content = fs::read_to_string("data.csv").unwrap_or_default(); +/// let rows: Vec> = csv_content +/// .lines() +/// .skip(1) +/// .map(|line| { +/// line.split(',') +/// .map(|s| s.trim().to_string()) +/// .collect() +/// }) +/// .collect(); +/// Ok(CsvCursor { rows, index: 0 }) /// } /// /// /// **Optional** methods for non-readonly tables: @@ -287,23 +279,28 @@ pub fn derive_agg_func(input: TokenStream) -> TokenStream { /// fn update(&mut self, rowid: i64, args: &[Value]) -> Result, Self::Error> { /// Ok(None)// return Ok(None) for read-only /// } +/// /// /// Insert a new row with the provided values, return the new rowid /// fn insert(&mut self, args: &[Value]) -> Result<(), Self::Error> { /// Ok(()) // /// } +/// /// /// Delete the row with the provided rowid /// fn delete(&mut self, rowid: i64) -> Result<(), Self::Error> { /// Ok(()) /// } +/// /// /// Destroy the virtual table. Any cleanup logic for when the table is deleted comes heres /// fn destroy(&mut self) -> Result<(), Self::Error> { /// Ok(()) /// } +/// } /// /// #[derive(Debug)] /// struct CsvCursor { /// rows: Vec>, /// index: usize, +/// } /// /// impl CsvCursor { /// /// Returns the value for a given column index. @@ -315,20 +312,40 @@ pub fn derive_agg_func(input: TokenStream) -> TokenStream { /// Value::null() /// } /// } +/// } +/// /// // Implement the VTabCursor trait for your virtual cursor /// impl VTabCursor for CsvCursor { -/// fn next(&mut self) -> ResultCode { -/// Self::next(self) -/// } +/// type Error = &'static str; +/// +/// /// Filter the virtual table based on arguments (omitted here for simplicity) +/// fn filter(&mut self, _args: &[Value], _idx_info: Option<(&str, i32)>) -> ResultCode { +/// ResultCode::OK +/// } +/// +/// /// Move the cursor to the next row +/// fn next(&mut self) -> ResultCode { +/// if self.index < self.rows.len() - 1 { +/// self.index += 1; +/// ResultCode::OK +/// } else { +/// ResultCode::EOF +/// } +/// } +/// /// fn eof(&self) -> bool { /// self.index >= self.rows.len() /// } -/// fn column(&self, idx: u32) -> Value { +/// +/// /// Return the value for a given column index +/// fn column(&self, idx: u32) -> Result { /// self.column(idx) /// } +/// /// fn rowid(&self) -> i64 { /// self.index as i64 /// } +/// } /// #[proc_macro_derive(VTabModuleDerive)] pub fn derive_vtab_module(input: TokenStream) -> TokenStream { diff --git a/scripts/publish-crates.sh b/scripts/publish-crates.sh index 5b1615d53..efd3cb3d9 100755 --- a/scripts/publish-crates.sh +++ b/scripts/publish-crates.sh @@ -5,6 +5,7 @@ cargo publish -p limbo_ext cargo publish -p limbo_ext_tests cargo publish -p limbo_completion cargo publish -p limbo_crypto +cargo publish -p limbo_csv cargo publish -p limbo_percentile cargo publish -p limbo_regexp cargo publish -p limbo_series diff --git a/testing/cli_tests/extensions.py b/testing/cli_tests/extensions.py index 3d1b4b62e..59d9e060e 100755 --- a/testing/cli_tests/extensions.py +++ b/testing/cli_tests/extensions.py @@ -621,6 +621,89 @@ def test_create_virtual_table(): limbo.quit() +def test_csv(): + limbo = TestLimboShell() + ext_path = "./target/debug/liblimbo_csv" + limbo.execute_dot(f".load {ext_path}") + + limbo.run_test_fn( + "CREATE VIRTUAL TABLE temp.csv USING csv(filename=./testing/test_files/test.csv);", + null, + "Create virtual table from CSV file" + ) + limbo.run_test_fn( + "SELECT * FROM temp.csv;", + lambda res: res == "1|2.0|String'1\n3|4.0|String2", + "Read all rows from CSV table" + ) + limbo.run_test_fn( + "SELECT * FROM temp.csv WHERE c2 = 'String2';", + lambda res: res == "3|4.0|String2", + "Filter rows with WHERE clause" + ) + limbo.run_test_fn( + "INSERT INTO temp.csv VALUES (5, 6.0, 'String3');", + lambda res: "Virtual table update failed" in res, + "INSERT into CSV table should fail" + ) + limbo.run_test_fn( + "UPDATE temp.csv SET c0 = 10 WHERE c1 = '2.0';", + lambda res: "Virtual table update failed" in res, + "UPDATE on CSV table should fail" + ) + limbo.run_test_fn( + "DELETE FROM temp.csv WHERE c1 = '2.0';", + lambda res: "Virtual table update failed" in res, + "DELETE on CSV table should fail" + ) + limbo.run_test_fn( + "DROP TABLE temp.csv;", + null, + "Drop CSV table" + ) + limbo.run_test_fn( + "SELECT * FROM temp.csv;", + lambda res: "Parse error: Table csv not found" in res, + "Query dropped CSV table should fail" + ) + limbo.run_test_fn( + "create virtual table t1 using csv(data='1'\\'2');", + lambda res: "unrecognized token at" in res, + "Create CSV table with malformed escape sequence" + ) + limbo.run_test_fn( + "create virtual table t1 using csv(data=\"12');", + lambda res: "non-terminated literal at" in res, + "Create CSV table with unterminated quoted string" + ) + + limbo.run_debug("create virtual table t1 using csv(data='');") + limbo.run_test_fn( + "SELECT c0 FROM t1;", + lambda res: res == "", + "Empty CSV table without a header should have one column: 'c0'" + ) + limbo.run_test_fn( + "SELECT c1 FROM t1;", + lambda res: "Parse error: Column c1 not found" in res, + "Empty CSV table without header should not have columns other than 'c0'" + ) + + limbo.run_debug("create virtual table t2 using csv(data='', header=true);") + limbo.run_test_fn( + "SELECT \"(NULL)\" FROM t2;", + lambda res: res == "", + "Empty CSV table with header should have one column named '(NULL)'" + ) + limbo.run_test_fn( + "SELECT c0 FROM t2;", + lambda res: "Parse error: Column c0 not found" in res, + "Empty CSV table with header should not have columns other than '(NULL)'" + ) + + limbo.quit() + + def cleanup(): if os.path.exists("testing/vfs.db"): os.remove("testing/vfs.db") @@ -641,6 +724,7 @@ def main(): test_kv() test_drop_virtual_table() test_create_virtual_table() + test_csv() except Exception as e: console.error(f"Test FAILED: {e}") cleanup()