From d775b3ea5a756317e81be9ba7c7535dc4325bb81 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Wed, 12 Feb 2025 22:43:26 -0500 Subject: [PATCH 01/13] Improve extension API with results, fix paths in proc macros --- extensions/core/README.md | 35 ++++++++------- extensions/core/src/lib.rs | 76 ++++++++++++++++---------------- extensions/percentile/src/lib.rs | 57 +++++++++++++----------- extensions/series/src/lib.rs | 33 +++++++------- macros/src/lib.rs | 72 +++++++++++++++++------------- 5 files changed, 146 insertions(+), 127 deletions(-) diff --git a/extensions/core/README.md b/extensions/core/README.md index 6dd187122..ddcfe413d 100644 --- a/extensions/core/README.md +++ b/extensions/core/README.md @@ -95,6 +95,9 @@ impl AggFunc for Percentile { /// The state to track during the steps type State = (Vec, Option, Option); // Tracks the values, Percentile, and errors + /// Define your error type, must impl Display + type Error = String; + /// Define the name you wish to call your function by. /// e.g. SELECT percentile(value, 40); const NAME: &str = "percentile"; @@ -129,15 +132,15 @@ impl AggFunc for Percentile { } /// A function to finalize the state into a value to be returned as a result /// or an error (if you chose to track an error state as well) - fn finalize(state: Self::State) -> Value { + fn finalize(state: Self::State) -> Result { let (mut values, p_value, error) = state; if let Some(error) = error { - return Value::custom_error(error); + return Err(error); } if values.is_empty() { - return Value::null(); + return Ok(Value::null()); } values.sort_by(|a, b| a.partial_cmp(b).unwrap()); @@ -145,7 +148,7 @@ impl AggFunc for Percentile { let p = p_value.unwrap(); let index = (p * (n - 1.0) / 100.0).floor() as usize; - Value::from_float(values[index]) + Ok(Value::from_float(values[index])) } } ``` @@ -161,21 +164,21 @@ struct CsvVTable; impl VTabModule for CsvVTable { type VCursor = CsvCursor; + /// Define your error type. Must impl Display and match VCursor::Error + type Error = &'static str; /// Declare the name for your virtual table const NAME: &'static str = "csv_data"; - /// Declare the table schema and call `api.declare_virtual_table` with the schema sql. - fn connect(api: &ExtensionApi) -> ResultCode { - let sql = "CREATE TABLE csv_data( + fn init_sql() -> &'static str { + "CREATE TABLE csv_data( name TEXT, age TEXT, city TEXT - )"; - api.declare_virtual_table(Self::NAME, sql) + )" } /// Open to return a new cursor: In this simple example, the CSV file is read completely into memory on connect. - fn open() -> Self::VCursor { + fn open() -> Result { // Read CSV file contents from "data.csv" let csv_content = fs::read_to_string("data.csv").unwrap_or_default(); // For simplicity, we'll ignore the header row. @@ -188,7 +191,7 @@ impl VTabModule for CsvVTable { .collect() }) .collect(); - CsvCursor { rows, index: 0 } + Ok(CsvCursor { rows, index: 0 }) } /// Filter through result columns. (not used in this simple example) @@ -197,7 +200,7 @@ impl VTabModule for CsvVTable { } /// Return the value for the column at the given index in the current row. - fn column(cursor: &Self::VCursor, idx: u32) -> Value { + fn column(cursor: &Self::VCursor, idx: u32) -> Result { cursor.column(idx) } @@ -226,6 +229,8 @@ struct CsvCursor { /// Implement the VTabCursor trait for your cursor type impl VTabCursor for CsvCursor { + type Error = &'static str; + fn next(&mut self) -> ResultCode { CsvCursor::next(self) } @@ -234,12 +239,12 @@ impl VTabCursor for CsvCursor { self.index >= self.rows.len() } - fn column(&self, idx: u32) -> Value { + fn column(&self, idx: u32) -> Result { let row = &self.rows[self.index]; if (idx as usize) < row.len() { - Value::from_text(&row[idx as usize]) + Ok(Value::from_text(&row[idx as usize])) } else { - Value::null() + Ok(Value::null()) } } diff --git a/extensions/core/src/lib.rs b/extensions/core/src/lib.rs index 22d90f572..d06340aa2 100644 --- a/extensions/core/src/lib.rs +++ b/extensions/core/src/lib.rs @@ -1,38 +1,18 @@ mod types; pub use limbo_macros::{register_extension, scalar, AggregateDerive, VTabModuleDerive}; -use std::os::raw::{c_char, c_void}; +use std::{ + fmt::Display, + os::raw::{c_char, c_void}, +}; pub use types::{ResultCode, Value, ValueType}; #[repr(C)] pub struct ExtensionApi { pub ctx: *mut c_void, - - pub register_scalar_function: unsafe extern "C" fn( - ctx: *mut c_void, - name: *const c_char, - func: ScalarFunction, - ) -> ResultCode, - - pub register_aggregate_function: unsafe extern "C" fn( - ctx: *mut c_void, - name: *const c_char, - args: i32, - init_func: InitAggFunction, - step_func: StepFunction, - finalize_func: FinalizeFunction, - ) -> ResultCode, - - pub register_module: unsafe extern "C" fn( - ctx: *mut c_void, - name: *const c_char, - module: VTabModuleImpl, - ) -> ResultCode, - - pub declare_vtab: unsafe extern "C" fn( - ctx: *mut c_void, - name: *const c_char, - sql: *const c_char, - ) -> ResultCode, + pub register_scalar_function: RegisterScalarFn, + pub register_aggregate_function: RegisterAggFn, + pub register_module: RegisterModuleFn, + pub declare_vtab: DeclareVTabFn, } impl ExtensionApi { @@ -48,16 +28,34 @@ impl ExtensionApi { } pub type ExtensionEntryPoint = unsafe extern "C" fn(api: *const ExtensionApi) -> ResultCode; + pub type ScalarFunction = unsafe extern "C" fn(argc: i32, *const Value) -> Value; +pub type DeclareVTabFn = + unsafe extern "C" fn(ctx: *mut c_void, name: *const c_char, sql: *const c_char) -> ResultCode; + +pub type RegisterScalarFn = + unsafe extern "C" fn(ctx: *mut c_void, name: *const c_char, func: ScalarFunction) -> ResultCode; + +pub type RegisterAggFn = unsafe extern "C" fn( + ctx: *mut c_void, + name: *const c_char, + args: i32, + init: InitAggFunction, + step: StepFunction, + finalize: FinalizeFunction, +) -> ResultCode; + +pub type RegisterModuleFn = unsafe extern "C" fn( + ctx: *mut c_void, + name: *const c_char, + module: VTabModuleImpl, +) -> ResultCode; + pub type InitAggFunction = unsafe extern "C" fn() -> *mut AggCtx; pub type StepFunction = unsafe extern "C" fn(ctx: *mut AggCtx, argc: i32, argv: *const Value); pub type FinalizeFunction = unsafe extern "C" fn(ctx: *mut AggCtx) -> Value; -pub trait Scalar { - fn call(&self, args: &[Value]) -> Value; -} - #[repr(C)] pub struct AggCtx { pub state: *mut c_void, @@ -65,11 +63,12 @@ pub struct AggCtx { pub trait AggFunc { type State: Default; + type Error: Display; const NAME: &'static str; const ARGS: i32; fn step(state: &mut Self::State, args: &[Value]); - fn finalize(state: Self::State) -> Value; + fn finalize(state: Self::State) -> Result; } #[repr(C)] @@ -98,13 +97,14 @@ pub type VtabFnNext = unsafe extern "C" fn(cursor: *mut c_void) -> ResultCode; pub type VtabFnEof = unsafe extern "C" fn(cursor: *mut c_void) -> bool; pub trait VTabModule: 'static { - type VCursor: VTabCursor; + type VCursor: VTabCursor; const NAME: &'static str; + type Error: std::fmt::Display; - fn connect(api: &ExtensionApi) -> ResultCode; - fn open() -> Self::VCursor; + fn init_sql() -> &'static str; + fn open() -> Result; fn filter(cursor: &mut Self::VCursor, arg_count: i32, args: &[Value]) -> ResultCode; - fn column(cursor: &Self::VCursor, idx: u32) -> Value; + fn column(cursor: &Self::VCursor, idx: u32) -> Result; fn next(cursor: &mut Self::VCursor) -> ResultCode; fn eof(cursor: &Self::VCursor) -> bool; } @@ -112,7 +112,7 @@ pub trait VTabModule: 'static { pub trait VTabCursor: Sized { type Error: std::fmt::Display; fn rowid(&self) -> i64; - fn column(&self, idx: u32) -> Value; + fn column(&self, idx: u32) -> Result; fn eof(&self) -> bool; fn next(&mut self) -> ResultCode; } diff --git a/extensions/percentile/src/lib.rs b/extensions/percentile/src/lib.rs index 9f81a6674..4b0b7bd83 100644 --- a/extensions/percentile/src/lib.rs +++ b/extensions/percentile/src/lib.rs @@ -9,6 +9,7 @@ struct Median; impl AggFunc for Median { type State = Vec; + type Error = &'static str; const NAME: &'static str = "median"; const ARGS: i32 = 1; @@ -18,9 +19,9 @@ impl AggFunc for Median { } } - fn finalize(state: Self::State) -> Value { + fn finalize(state: Self::State) -> Result { if state.is_empty() { - return Value::null(); + return Ok(Value::null()); } let mut sorted = state; @@ -28,11 +29,11 @@ impl AggFunc for Median { let len = sorted.len(); if len % 2 == 1 { - Value::from_float(sorted[len / 2]) + Ok(Value::from_float(sorted[len / 2])) } else { let mid1 = sorted[len / 2 - 1]; let mid2 = sorted[len / 2]; - Value::from_float((mid1 + mid2) / 2.0) + Ok(Value::from_float((mid1 + mid2) / 2.0)) } } } @@ -41,8 +42,8 @@ impl AggFunc for Median { struct Percentile; impl AggFunc for Percentile { - type State = (Vec, Option, Option<&'static str>); - + type State = (Vec, Option, Option); + type Error = &'static str; const NAME: &'static str = "percentile"; const ARGS: i32 = 2; @@ -69,16 +70,16 @@ impl AggFunc for Percentile { } } - fn finalize(state: Self::State) -> Value { + fn finalize(state: Self::State) -> Result { let (mut values, p_value, err_value) = state; if values.is_empty() { - return Value::null(); + return Ok(Value::null()); } if let Some(err) = err_value { - return Value::error_with_message(err.into()); + return Err(err); } if values.len() == 1 { - return Value::from_float(values[0]); + return Ok(Value::from_float(values[0])); } let p = p_value.unwrap(); @@ -89,10 +90,12 @@ impl AggFunc for Percentile { let upper = index.ceil() as usize; if lower == upper { - Value::from_float(values[lower]) + Ok(Value::from_float(values[lower])) } else { let weight = index - lower as f64; - Value::from_float(values[lower] * (1.0 - weight) + values[upper] * weight) + Ok(Value::from_float( + values[lower] * (1.0 - weight) + values[upper] * weight, + )) } } } @@ -101,8 +104,8 @@ impl AggFunc for Percentile { struct PercentileCont; impl AggFunc for PercentileCont { - type State = (Vec, Option, Option<&'static str>); - + type State = (Vec, Option, Option); + type Error = &'static str; const NAME: &'static str = "percentile_cont"; const ARGS: i32 = 2; @@ -129,16 +132,16 @@ impl AggFunc for PercentileCont { } } - fn finalize(state: Self::State) -> Value { + fn finalize(state: Self::State) -> Result { let (mut values, p_value, err_state) = state; if values.is_empty() { - return Value::null(); + return Ok(Value::null()); } if let Some(err) = err_state { - return Value::error_with_message(err.into()); + return Err(err); } if values.len() == 1 { - return Value::from_float(values[0]); + return Ok(Value::from_float(values[0])); } let p = p_value.unwrap(); @@ -149,10 +152,12 @@ impl AggFunc for PercentileCont { let upper = index.ceil() as usize; if lower == upper { - Value::from_float(values[lower]) + Ok(Value::from_float(values[lower])) } else { let weight = index - lower as f64; - Value::from_float(values[lower] * (1.0 - weight) + values[upper] * weight) + Ok(Value::from_float( + values[lower] * (1.0 - weight) + values[upper] * weight, + )) } } } @@ -161,8 +166,8 @@ impl AggFunc for PercentileCont { struct PercentileDisc; impl AggFunc for PercentileDisc { - type State = (Vec, Option, Option<&'static str>); - + type State = (Vec, Option, Option); + type Error = &'static str; const NAME: &'static str = "percentile_disc"; const ARGS: i32 = 2; @@ -170,19 +175,19 @@ impl AggFunc for PercentileDisc { Percentile::step(state, args); } - fn finalize(state: Self::State) -> Value { + fn finalize(state: Self::State) -> Result { let (mut values, p_value, err_value) = state; if values.is_empty() { - return Value::null(); + return Ok(Value::null()); } if let Some(err) = err_value { - return Value::error_with_message(err.into()); + return Err(err); } let p = p_value.unwrap(); values.sort_by(|a, b| a.partial_cmp(b).unwrap()); let n = values.len() as f64; let index = (p * (n - 1.0)).floor() as usize; - Value::from_float(values[index]) + Ok(Value::from_float(values[index])) } } diff --git a/extensions/series/src/lib.rs b/extensions/series/src/lib.rs index df2d67da1..6e86f0c93 100644 --- a/extensions/series/src/lib.rs +++ b/extensions/series/src/lib.rs @@ -1,6 +1,4 @@ -use limbo_ext::{ - register_extension, ExtensionApi, ResultCode, VTabCursor, VTabModule, VTabModuleDerive, Value, -}; +use limbo_ext::{register_extension, ResultCode, VTabCursor, VTabModule, VTabModuleDerive, Value}; register_extension! { vtabs: { GenerateSeriesVTab } @@ -21,26 +19,27 @@ struct GenerateSeriesVTab; impl VTabModule for GenerateSeriesVTab { type VCursor = GenerateSeriesCursor; + type Error = ResultCode; + const NAME: &'static str = "generate_series"; - fn connect(api: &ExtensionApi) -> ResultCode { + fn init_sql() -> &'static str { // Create table schema - let sql = "CREATE TABLE generate_series( + "CREATE TABLE generate_series( value INTEGER, start INTEGER HIDDEN, stop INTEGER HIDDEN, step INTEGER HIDDEN - )"; - api.declare_virtual_table(Self::NAME, sql) + )" } - fn open() -> Self::VCursor { - GenerateSeriesCursor { + fn open() -> Result { + Ok(GenerateSeriesCursor { start: 0, stop: 0, step: 0, current: 0, - } + }) } fn filter(cursor: &mut Self::VCursor, arg_count: i32, args: &[Value]) -> ResultCode { @@ -78,7 +77,7 @@ impl VTabModule for GenerateSeriesVTab { ResultCode::OK } - fn column(cursor: &Self::VCursor, idx: u32) -> Value { + fn column(cursor: &Self::VCursor, idx: u32) -> Result { cursor.column(idx) } @@ -163,14 +162,14 @@ impl VTabCursor for GenerateSeriesCursor { false } - fn column(&self, idx: u32) -> Value { - match idx { + fn column(&self, idx: u32) -> Result { + Ok(match idx { 0 => Value::from_integer(self.current), 1 => Value::from_integer(self.start), 2 => Value::from_integer(self.stop), 3 => Value::from_integer(self.step), _ => Value::null(), - } + }) } fn rowid(&self) -> i64 { @@ -227,7 +226,7 @@ mod tests { } // Helper function to collect all values from a cursor, returns Result with error code fn collect_series(series: Series) -> Result, ResultCode> { - let mut cursor = GenerateSeriesVTab::open(); + let mut cursor = GenerateSeriesVTab::open()?; // Create args array for filter let args = vec![ @@ -245,7 +244,7 @@ mod tests { let mut values = Vec::new(); loop { - values.push(cursor.column(0).to_integer().unwrap()); + values.push(cursor.column(0)?.to_integer().unwrap()); if values.len() > 1000 { panic!( "Generated more than 1000 values, expected this many: {:?}", @@ -544,7 +543,7 @@ mod tests { let stop = series.stop; let step = series.step; - let mut cursor = GenerateSeriesVTab::open(); + let mut cursor = GenerateSeriesVTab::open().unwrap(); let args = vec![ Value::from_integer(start), diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 632b95615..089579081 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -171,7 +171,7 @@ pub fn scalar(attr: TokenStream, input: TokenStream) -> TokenStream { let fn_body = &ast.block; let alias_check = if let Some(alias) = &scalar_info.alias { quote! { - let Ok(alias_c_name) = std::ffi::CString::new(#alias) else { + let Ok(alias_c_name) = ::std::ffi::CString::new(#alias) else { return ::limbo_ext::ResultCode::Error; }; (api.register_scalar_function)( @@ -193,7 +193,7 @@ pub fn scalar(attr: TokenStream, input: TokenStream) -> TokenStream { return ::limbo_ext::ResultCode::Error; } let api = unsafe { &*api }; - let Ok(c_name) = std::ffi::CString::new(#name) else { + let Ok(c_name) = ::std::ffi::CString::new(#name) else { return ::limbo_ext::ResultCode::Error; }; (api.register_scalar_function)( @@ -232,6 +232,7 @@ pub fn scalar(attr: TokenStream, input: TokenStream) -> TokenStream { /// ///impl AggFunc for SumPlusOne { /// type State = i64; +/// type Error = &'static str; /// const NAME: &'static str = "sum_plus_one"; /// const ARGS: i32 = 1; /// fn step(state: &mut Self::State, args: &[Value]) { @@ -240,8 +241,8 @@ pub fn scalar(attr: TokenStream, input: TokenStream) -> TokenStream { /// }; /// *state += val; /// } -/// fn finalize(state: Self::State) -> Value { -/// Value::from_integer(state + 1) +/// fn finalize(state: Self::State) -> Result { +/// Ok(Value::from_integer(state + 1)) /// } ///} /// ``` @@ -259,11 +260,11 @@ pub fn derive_agg_func(input: TokenStream) -> TokenStream { impl #struct_name { #[no_mangle] pub extern "C" fn #init_fn_name() -> *mut ::limbo_ext::AggCtx { - let state = Box::new(<#struct_name as ::limbo_ext::AggFunc>::State::default()); - let ctx = Box::new(::limbo_ext::AggCtx { - state: Box::into_raw(state) as *mut ::std::os::raw::c_void, + let state = ::std::boxed::Box::new(<#struct_name as ::limbo_ext::AggFunc>::State::default()); + let ctx = ::std::boxed::Box::new(::limbo_ext::AggCtx { + state: ::std::boxed::Box::into_raw(state) as *mut ::std::os::raw::c_void, }); - Box::into_raw(ctx) + ::std::boxed::Box::into_raw(ctx) } #[no_mangle] @@ -275,7 +276,7 @@ pub fn derive_agg_func(input: TokenStream) -> TokenStream { unsafe { let ctx = &mut *ctx; let state = &mut *(ctx.state as *mut <#struct_name as ::limbo_ext::AggFunc>::State); - let args = std::slice::from_raw_parts(argv, argc as usize); + let args = ::std::slice::from_raw_parts(argv, argc as usize); <#struct_name as ::limbo_ext::AggFunc>::step(state, args); } } @@ -286,8 +287,13 @@ pub fn derive_agg_func(input: TokenStream) -> TokenStream { ) -> ::limbo_ext::Value { unsafe { let ctx = &mut *ctx; - let state = Box::from_raw(ctx.state as *mut <#struct_name as ::limbo_ext::AggFunc>::State); - <#struct_name as ::limbo_ext::AggFunc>::finalize(*state) + let state = ::std::boxed::Box::from_raw(ctx.state as *mut <#struct_name as ::limbo_ext::AggFunc>::State); + match <#struct_name as ::limbo_ext::AggFunc>::finalize(*state) { + Ok(val) => val, + Err(e) => { + ::limbo_ext::Value::error_with_message(e.to_string()) + } + } } } @@ -301,7 +307,7 @@ pub fn derive_agg_func(input: TokenStream) -> TokenStream { let api = &*api; let name_str = #struct_name::NAME; - let c_name = match std::ffi::CString::new(name_str) { + let c_name = match ::std::ffi::CString::new(name_str) { Ok(cname) => cname, Err(_) => return ::limbo_ext::ResultCode::Error, }; @@ -335,13 +341,12 @@ pub fn derive_agg_func(input: TokenStream) -> TokenStream { /// const NAME: &'static str = "csv_data"; /// /// /// Declare the schema for your virtual table -/// fn connect(api: &ExtensionApi) -> ResultCode { +/// fn init_sql() -> &'static str { /// let sql = "CREATE TABLE csv_data( /// name TEXT, /// age TEXT, /// city TEXT -/// )"; -/// api.declare_virtual_table(Self::NAME, sql) +/// )" /// } /// /// Open the virtual table and return a cursor /// fn open() -> Self::VCursor { @@ -424,20 +429,21 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { impl #struct_name { #[no_mangle] unsafe extern "C" fn #connect_fn_name( - db: *const ::std::ffi::c_void, + db: *const ::std::ffi::c_void ) -> ::limbo_ext::ResultCode { - if db.is_null() { - return ::limbo_ext::ResultCode::Error; - } - let api = unsafe { &*(db as *const ExtensionApi) }; - <#struct_name as ::limbo_ext::VTabModule>::connect(api) + let api = &*(db as *const ::limbo_ext::ExtensionApi); + let sql = <#struct_name as ::limbo_ext::VTabModule>::init_sql(); + api.declare_virtual_table(<#struct_name as ::limbo_ext::VTabModule>::NAME, sql) } #[no_mangle] unsafe extern "C" fn #open_fn_name( ) -> *mut ::std::ffi::c_void { - let cursor = <#struct_name as ::limbo_ext::VTabModule>::open(); - Box::into_raw(Box::new(cursor)) as *mut ::std::ffi::c_void + if let Ok(cursor) = <#struct_name as ::limbo_ext::VTabModule>::open() { + ::std::boxed::Box::into_raw(::std::boxed::Box::new(cursor)) as *mut ::std::ffi::c_void + } else { + ::std::ptr::null_mut() + } } #[no_mangle] @@ -450,7 +456,7 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { return ::limbo_ext::ResultCode::Error; } let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; - let args = std::slice::from_raw_parts(argv, argc as usize); + let args = ::std::slice::from_raw_parts(argv, argc as usize); <#struct_name as ::limbo_ext::VTabModule>::filter(cursor, argc, args) } @@ -460,10 +466,13 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { idx: u32, ) -> ::limbo_ext::Value { if cursor.is_null() { - return ::limbo_ext::Value::error(ResultCode::Error); + return ::limbo_ext::Value::error(::limbo_ext::ResultCode::Error); } let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; - <#struct_name as ::limbo_ext::VTabModule>::column(cursor, idx) + match <#struct_name as ::limbo_ext::VTabModule>::column(cursor, idx) { + Ok(val) => val, + Err(e) => ::limbo_ext::Value::error_with_message(e.to_string()) + } } #[no_mangle] @@ -495,14 +504,13 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { if api.is_null() { return ::limbo_ext::ResultCode::Error; } - let api = &*api; let name = <#struct_name as ::limbo_ext::VTabModule>::NAME; // name needs to be a c str FFI compatible, NOT CString - let name_c = std::ffi::CString::new(name).unwrap(); + let name_c = ::std::ffi::CString::new(name).unwrap().into_raw() as *const ::std::ffi::c_char; let module = ::limbo_ext::VTabModuleImpl { - name: name_c.as_ptr(), + name: name_c, connect: Self::#connect_fn_name, open: Self::#open_fn_name, filter: Self::#filter_fn_name, @@ -511,7 +519,7 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { eof: Self::#eof_fn_name, }; - (api.register_module)(api.ctx, name_c.as_ptr(), module) + (api.register_module)(api.ctx, name_c, module) } } }; @@ -586,11 +594,13 @@ pub fn register_extension(input: TokenStream) -> TokenStream { }); let vtab_calls = vtabs.iter().map(|vtab_ident| { let register_fn = syn::Ident::new(&format!("register_{}", vtab_ident), vtab_ident.span()); + let connect_fn = syn::Ident::new(&format!("connect_{}", vtab_ident), vtab_ident.span()); quote! { { let result = unsafe{ #vtab_ident::#register_fn(api)}; if result == ::limbo_ext::ResultCode::OK { - let result = <#vtab_ident as ::limbo_ext::VTabModule>::connect(api); + let api = api as *const _ as *const ::std::ffi::c_void; + let result = #vtab_ident::#connect_fn(api); if !result.is_ok() { return result; } From 9c8083231c8b807853f963f27cff7ca96d3842ec Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Fri, 14 Feb 2025 09:34:30 -0500 Subject: [PATCH 02/13] Implement create virtual table and VUpdate opcode --- core/error.rs | 6 + core/ext/mod.rs | 92 ++++------ core/lib.rs | 140 ++++++++++++--- core/schema.rs | 29 +++- core/translate/delete.rs | 5 +- core/translate/insert.rs | 154 ++++++++++++++--- core/translate/main_loop.rs | 26 ++- core/translate/mod.rs | 163 ++++++++++++++++-- core/translate/planner.rs | 33 ++-- core/translate/pragma.rs | 2 +- core/util.rs | 22 ++- core/vdbe/builder.rs | 2 +- core/vdbe/explain.rs | 28 +++ core/vdbe/insn.rs | 16 ++ core/vdbe/mod.rs | 111 +++++++++++- extensions/core/src/lib.rs | 72 +++++--- extensions/core/src/types.rs | 4 + extensions/kvstore/Cargo.toml | 19 ++ extensions/kvstore/src/lib.rs | 103 +++++++++++ extensions/series/src/lib.rs | 17 +- macros/src/lib.rs | 98 ++++++++--- vendored/sqlite3-parser/src/parser/ast/mod.rs | 12 ++ 22 files changed, 940 insertions(+), 214 deletions(-) create mode 100644 extensions/kvstore/Cargo.toml create mode 100644 extensions/kvstore/src/lib.rs diff --git a/core/error.rs b/core/error.rs index 53308114c..7832747eb 100644 --- a/core/error.rs +++ b/core/error.rs @@ -76,5 +76,11 @@ macro_rules! bail_constraint_error { }; } +impl From for LimboError { + fn from(err: limbo_ext::ResultCode) -> Self { + LimboError::ExtensionError(err.to_string()) + } +} + pub const SQLITE_CONSTRAINT: usize = 19; pub const SQLITE_CONSTRAINT_PRIMARYKEY: usize = SQLITE_CONSTRAINT | (6 << 8); diff --git a/core/ext/mod.rs b/core/ext/mod.rs index a4f5d6cc3..3ea7d9692 100644 --- a/core/ext/mod.rs +++ b/core/ext/mod.rs @@ -1,17 +1,20 @@ -use crate::{function::ExternalFunc, util::columns_from_create_table_body, Database, VirtualTable}; -use fallible_iterator::FallibleIterator; -use limbo_ext::{ExtensionApi, InitAggFunction, ResultCode, ScalarFunction, VTabModuleImpl}; -pub use limbo_ext::{FinalizeFunction, StepFunction, Value as ExtValue, ValueType as ExtValueType}; -use limbo_sqlite3_parser::{ - ast::{Cmd, Stmt}, - lexer::sql::Parser, +use crate::{function::ExternalFunc, Database}; +use limbo_ext::{ + ExtensionApi, InitAggFunction, ResultCode, ScalarFunction, VTabKind, VTabModuleImpl, }; +pub use limbo_ext::{FinalizeFunction, StepFunction, Value as ExtValue, ValueType as ExtValueType}; use std::{ - ffi::{c_char, c_void, CStr}, + ffi::{c_char, c_void, CStr, CString}, rc::Rc, }; type ExternAggFunc = (InitAggFunction, StepFunction, FinalizeFunction); +#[derive(Clone)] +pub struct VTabImpl { + pub module_type: VTabKind, + pub implementation: Rc, +} + unsafe extern "C" fn register_scalar_function( ctx: *mut c_void, name: *const c_char, @@ -53,8 +56,12 @@ unsafe extern "C" fn register_module( ctx: *mut c_void, name: *const c_char, module: VTabModuleImpl, + kind: VTabKind, ) -> ResultCode { - let c_str = unsafe { CStr::from_ptr(name) }; + if name.is_null() || ctx.is_null() { + return ResultCode::Error; + } + let c_str = unsafe { CString::from_raw(name as *mut i8) }; let name_str = match c_str.to_str() { Ok(s) => s.to_string(), Err(_) => return ResultCode::Error, @@ -64,31 +71,7 @@ unsafe extern "C" fn register_module( } let db = unsafe { &mut *(ctx as *mut Database) }; - db.register_module_impl(&name_str, module) -} - -unsafe extern "C" fn declare_vtab( - ctx: *mut c_void, - name: *const c_char, - sql: *const c_char, -) -> ResultCode { - let c_str = unsafe { CStr::from_ptr(name) }; - let name_str = match c_str.to_str() { - Ok(s) => s.to_string(), - Err(_) => return ResultCode::Error, - }; - - let c_str = unsafe { CStr::from_ptr(sql) }; - let sql_str = match c_str.to_str() { - Ok(s) => s.to_string(), - Err(_) => return ResultCode::Error, - }; - - if ctx.is_null() { - return ResultCode::Error; - } - let db = unsafe { &mut *(ctx as *mut Database) }; - db.declare_vtab_impl(&name_str, &sql_str) + db.register_module_impl(&name_str, module, kind) } impl Database { @@ -113,32 +96,22 @@ impl Database { ResultCode::OK } - fn register_module_impl(&mut self, name: &str, module: VTabModuleImpl) -> ResultCode { - self.vtab_modules.insert(name.to_string(), Rc::new(module)); - ResultCode::OK - } - - fn declare_vtab_impl(&mut self, name: &str, sql: &str) -> ResultCode { - let mut parser = Parser::new(sql.as_bytes()); - let cmd = parser.next().unwrap().unwrap(); - let Cmd::Stmt(stmt) = cmd else { - return ResultCode::Error; + fn register_module_impl( + &mut self, + name: &str, + module: VTabModuleImpl, + kind: VTabKind, + ) -> ResultCode { + let module = Rc::new(module); + let vmodule = VTabImpl { + module_type: kind, + implementation: module, }; - let Stmt::CreateTable { body, .. } = stmt else { - return ResultCode::Error; - }; - let Ok(columns) = columns_from_create_table_body(*body) else { - return ResultCode::Error; - }; - let vtab_module = self.vtab_modules.get(name).unwrap().clone(); - - let vtab = VirtualTable { - name: name.to_string(), - implementation: vtab_module, - columns, - args: None, - }; - self.syms.borrow_mut().vtabs.insert(name.to_string(), vtab); + self.syms + .borrow_mut() + .vtab_modules + .insert(name.to_string(), vmodule.into()); + println!("Registered module: {}", name); ResultCode::OK } @@ -148,7 +121,6 @@ impl Database { register_scalar_function, register_aggregate_function, register_module, - declare_vtab, } } diff --git a/core/lib.rs b/core/lib.rs index 00889db98..d9727fd98 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -27,7 +27,7 @@ use fallible_iterator::FallibleIterator; use libloading::{Library, Symbol}; #[cfg(not(target_family = "wasm"))] use limbo_ext::{ExtensionApi, ExtensionEntryPoint}; -use limbo_ext::{ResultCode, VTabModuleImpl, Value as ExtValue}; +use limbo_ext::{ResultCode, VTabKind, VTabModuleImpl, Value as ExtValue}; use limbo_sqlite3_parser::{ast, ast::Cmd, lexer::sql::Parser}; use parking_lot::RwLock; use schema::{Column, Schema}; @@ -49,7 +49,7 @@ pub use storage::wal::WalFile; pub use storage::wal::WalFileShared; use types::OwnedValue; pub use types::Value; -use util::parse_schema_rows; +use util::{columns_from_create_table_body, parse_schema_rows}; use vdbe::builder::QueryMode; use vdbe::VTabOpaqueCursor; @@ -87,7 +87,6 @@ pub struct Database { schema: Rc>, header: Rc>, syms: Rc>, - vtab_modules: HashMap>, // Shared structures of a Database are the parts that are common to multiple threads that might // create DB connections. _shared_page_cache: Arc>, @@ -149,8 +148,7 @@ impl Database { header: header.clone(), _shared_page_cache: _shared_page_cache.clone(), _shared_wal: shared_wal.clone(), - syms, - vtab_modules: HashMap::new(), + syms: syms.clone(), }; if let Err(e) = db.register_builtins() { return Err(LimboError::ExtensionError(e)); @@ -169,7 +167,7 @@ impl Database { }); let rows = conn.query("SELECT * FROM sqlite_schema")?; let mut schema = schema.borrow_mut(); - parse_schema_rows(rows, &mut schema, io)?; + parse_schema_rows(rows, &mut schema, io, &syms.borrow())?; Ok(db) } @@ -276,10 +274,9 @@ impl Connection { pub fn prepare(self: &Rc, sql: impl AsRef) -> Result { let sql = sql.as_ref(); tracing::trace!("Preparing: {}", sql); - let db = &self.db; let mut parser = Parser::new(sql.as_bytes()); - let syms = &db.syms.borrow(); let cmd = parser.next()?; + let syms = self.db.syms.borrow(); if let Some(cmd) = cmd { match cmd { Cmd::Stmt(stmt) => { @@ -289,7 +286,7 @@ impl Connection { self.header.clone(), self.pager.clone(), Rc::downgrade(self), - syms, + &syms, QueryMode::Normal, )?); Ok(Statement::new(program, self.pager.clone())) @@ -315,7 +312,7 @@ impl Connection { pub(crate) fn run_cmd(self: &Rc, cmd: Cmd) -> Result> { let db = self.db.clone(); - let syms: &SymbolTable = &db.syms.borrow(); + let syms = db.syms.borrow(); match cmd { Cmd::Stmt(stmt) => { let program = Rc::new(translate::translate( @@ -324,7 +321,7 @@ impl Connection { self.header.clone(), self.pager.clone(), Rc::downgrade(self), - syms, + &syms, QueryMode::Normal, )?); let stmt = Statement::new(program, self.pager.clone()); @@ -337,7 +334,7 @@ impl Connection { self.header.clone(), self.pager.clone(), Rc::downgrade(self), - syms, + &syms, QueryMode::Explain, )?; program.explain(); @@ -346,12 +343,8 @@ impl Connection { Cmd::ExplainQueryPlan(stmt) => { match stmt { ast::Stmt::Select(select) => { - let mut plan = prepare_select_plan( - &self.schema.borrow(), - *select, - &self.db.syms.borrow(), - None, - )?; + let mut plan = + prepare_select_plan(&self.schema.borrow(), *select, &syms, None)?; optimize_plan(&mut plan, &self.schema.borrow())?; println!("{}", plan); } @@ -368,10 +361,9 @@ impl Connection { pub fn execute(self: &Rc, sql: impl AsRef) -> Result<()> { let sql = sql.as_ref(); - let db = &self.db; - let syms: &SymbolTable = &db.syms.borrow(); let mut parser = Parser::new(sql.as_bytes()); let cmd = parser.next()?; + let syms = self.db.syms.borrow(); if let Some(cmd) = cmd { match cmd { Cmd::Explain(stmt) => { @@ -381,7 +373,7 @@ impl Connection { self.header.clone(), self.pager.clone(), Rc::downgrade(self), - syms, + &syms, QueryMode::Explain, )?; program.explain(); @@ -394,7 +386,7 @@ impl Connection { self.header.clone(), self.pager.clone(), Rc::downgrade(self), - syms, + &syms, QueryMode::Normal, )?; @@ -524,14 +516,73 @@ pub type StepResult = vdbe::StepResult; #[derive(Clone, Debug)] pub struct VirtualTable { name: String, - args: Option>, + args: Option>, pub implementation: Rc, columns: Vec, } impl VirtualTable { + pub(crate) fn from_args( + tbl_name: Option<&str>, + module_name: &str, + args: &[String], + syms: &SymbolTable, + kind: VTabKind, + ) -> Result> { + let module = syms + .vtab_modules + .get(module_name) + .ok_or(LimboError::ExtensionError(format!( + "Virtual table module not found: {}", + module_name + )))?; + if let VTabKind::VirtualTable = kind { + if module.module_type != VTabKind::VirtualTable { + return Err(LimboError::ExtensionError(format!( + "Virtual table module {} is not a virtual table", + module_name + ))); + } + }; + let schema = module.implementation.as_ref().init_schema(args)?; + let mut parser = Parser::new(schema.as_bytes()); + parser.reset(schema.as_bytes()); + println!("Schema: {}", schema); + if let ast::Cmd::Stmt(ast::Stmt::CreateTable { body, .. }) = parser.next()?.ok_or( + LimboError::ParseError("Failed to parse schema from virtual table module".to_string()), + )? { + let columns = columns_from_create_table_body(&body)?; + let vtab = Rc::new(VirtualTable { + name: tbl_name.unwrap_or(module_name).to_owned(), + args: Some(args.to_vec()), + implementation: module.implementation.clone(), + columns, + }); + return Ok(vtab); + } + Err(crate::LimboError::ParseError( + "Failed to parse schema from virtual table module".to_string(), + )) + } + pub fn open(&self) -> VTabOpaqueCursor { - let cursor = unsafe { (self.implementation.open)() }; + let args = if let Some(args) = &self.args { + args.iter() + .map(|e| std::ffi::CString::new(e.to_string()).unwrap().into_raw()) + .collect() + } else { + Vec::new() + }; + let cursor = + unsafe { (self.implementation.open)(args.as_slice().as_ptr(), args.len() as i32) }; + // free the CString pointers + for arg in args { + unsafe { + if !arg.is_null() { + let _ = std::ffi::CString::from_raw(arg); + } + } + } VTabOpaqueCursor::new(cursor) } @@ -580,13 +631,51 @@ impl VirtualTable { _ => Err(LimboError::ExtensionError("Next failed".to_string())), } } + + pub fn update(&self, args: &[OwnedValue], rowid: Option) -> Result> { + let arg_count = args.len(); + let mut ext_args = Vec::with_capacity(arg_count); + for i in 0..arg_count { + let ownedvalue_arg = args.get(i).unwrap(); + let extvalue_arg: ExtValue = match ownedvalue_arg { + OwnedValue::Null => Ok(ExtValue::null()), + OwnedValue::Integer(i) => Ok(ExtValue::from_integer(*i)), + OwnedValue::Float(f) => Ok(ExtValue::from_float(*f)), + OwnedValue::Text(t) => Ok(ExtValue::from_text(t.as_str().to_string())), + OwnedValue::Blob(b) => Ok(ExtValue::from_blob((**b).clone())), + other => Err(LimboError::ExtensionError(format!( + "Unsupported value type: {:?}", + other + ))), + }?; + ext_args.push(extvalue_arg); + } + let rowid = rowid.unwrap_or(-1); + let newrowid = 0i64; + let implementation = self.implementation.as_ref(); + let rc = unsafe { + (self.implementation.update)( + implementation as *const VTabModuleImpl as *mut std::ffi::c_void, + arg_count as i32, + ext_args.as_ptr(), + rowid, + &newrowid as *const _ as *mut i64, + ) + }; + match rc { + ResultCode::OK => Ok(None), + ResultCode::RowID => Ok(Some(newrowid)), + _ => Err(LimboError::ExtensionError(rc.to_string())), + } + } } pub(crate) struct SymbolTable { pub functions: HashMap>, #[cfg(not(target_family = "wasm"))] extensions: Vec<(Library, *const ExtensionApi)>, - pub vtabs: HashMap, + pub vtabs: HashMap>, + pub vtab_modules: HashMap>, } impl std::fmt::Debug for SymbolTable { @@ -631,6 +720,7 @@ impl SymbolTable { vtabs: HashMap::new(), #[cfg(not(target_family = "wasm"))] extensions: Vec::new(), + vtab_modules: HashMap::new(), } } diff --git a/core/schema.rs b/core/schema.rs index 0395b2c28..884867066 100644 --- a/core/schema.rs +++ b/core/schema.rs @@ -12,29 +12,46 @@ use std::rc::Rc; use tracing::trace; pub struct Schema { - pub tables: HashMap>, + pub tables: HashMap>, // table_name to list of indexes for the table pub indexes: HashMap>>, } impl Schema { pub fn new() -> Self { - let mut tables: HashMap> = HashMap::new(); + let mut tables: HashMap> = HashMap::new(); let indexes: HashMap>> = HashMap::new(); - tables.insert("sqlite_schema".to_string(), Rc::new(sqlite_schema_table())); + tables.insert( + "sqlite_schema".to_string(), + Rc::new(Table::BTree(sqlite_schema_table().into())), + ); Self { tables, indexes } } - pub fn add_table(&mut self, table: Rc) { + pub fn add_btree_table(&mut self, table: Rc) { let name = normalize_ident(&table.name); - self.tables.insert(name, table); + self.tables.insert(name, Table::BTree(table).into()); } - pub fn get_table(&self, name: &str) -> Option> { + pub fn add_virtual_table(&mut self, table: Rc) { + let name = normalize_ident(&table.name); + self.tables.insert(name, Table::Virtual(table).into()); + } + + pub fn get_table(&self, name: &str) -> Option> { let name = normalize_ident(name); self.tables.get(&name).cloned() } + pub fn get_btree_table(&self, name: &str) -> Option> { + let name = normalize_ident(name); + if let Some(table) = self.tables.get(&name) { + table.btree() + } else { + None + } + } + pub fn add_index(&mut self, index: Rc) { let table_name = normalize_ident(&index.table_name); self.indexes diff --git a/core/translate/delete.rs b/core/translate/delete.rs index 6a55bfc03..81f8ba6ef 100644 --- a/core/translate/delete.rs +++ b/core/translate/delete.rs @@ -42,7 +42,10 @@ pub fn prepare_delete_plan( Some(table) => table, None => crate::bail_corrupt_error!("Parse error: no such table: {}", tbl_name), }; - + //if let Some(table) = table.virtual_table() { + // // TODO: emit VUpdate + //} + let table = table.btree().unwrap(); let table_references = vec![TableReference { table: Table::BTree(table.clone()), identifier: table.name.clone(), diff --git a/core/translate/insert.rs b/core/translate/insert.rs index eb36b7e75..5f933e93a 100644 --- a/core/translate/insert.rs +++ b/core/translate/insert.rs @@ -1,15 +1,15 @@ use std::ops::Deref; +use std::rc::Rc; use limbo_sqlite3_parser::ast::{ - DistinctNames, Expr, InsertBody, QualifiedName, ResolveType, ResultColumn, With, + DistinctNames, Expr, InsertBody, OneSelect, QualifiedName, ResolveType, ResultColumn, With, }; use crate::error::SQLITE_CONSTRAINT_PRIMARYKEY; -use crate::schema::BTreeTable; +use crate::schema::Table; use crate::util::normalize_ident; use crate::vdbe::builder::{ProgramBuilderOpts, QueryMode}; use crate::vdbe::BranchOffset; -use crate::Result; use crate::{ schema::{Column, Schema}, translate::expr::translate_expr, @@ -19,6 +19,7 @@ use crate::{ }, SymbolTable, }; +use crate::{Result, VirtualTable}; use super::emitter::Resolver; @@ -46,32 +47,45 @@ pub fn translate_insert( if on_conflict.is_some() { crate::bail_parse_error!("ON CONFLICT clause is not supported"); } + + let table_name = &tbl_name.name; + let table = match schema.get_table(table_name.0.as_str()) { + Some(table) => table, + None => crate::bail_corrupt_error!("Parse error: no such table: {}", table_name), + }; let resolver = Resolver::new(syms); + if let Some(virtual_table) = &table.virtual_table() { + translate_virtual_table_insert( + &mut program, + virtual_table.clone(), + columns, + body, + on_conflict, + &resolver, + ); + return Ok(program); + } let init_label = program.allocate_label(); program.emit_insn(Insn::Init { target_pc: init_label, }); let start_offset = program.offset(); - // open table - let table_name = &tbl_name.name; - - let table = match schema.get_table(table_name.0.as_str()) { - Some(table) => table, - None => crate::bail_corrupt_error!("Parse error: no such table: {}", table_name), + let Some(btree_table) = table.btree() else { + crate::bail_corrupt_error!("Parse error: no such table: {}", table_name); }; - if !table.has_rowid { + if !btree_table.has_rowid { crate::bail_parse_error!("INSERT into WITHOUT ROWID table is not supported"); } let cursor_id = program.alloc_cursor_id( Some(table_name.0.clone()), - CursorType::BTreeTable(table.clone()), + CursorType::BTreeTable(btree_table.clone()), ); - let root_page = table.root_page; + let root_page = btree_table.root_page; let values = match body { InsertBody::Select(select, None) => match &select.body.select.deref() { - limbo_sqlite3_parser::ast::OneSelect::Values(values) => values, + OneSelect::Values(values) => values, _ => todo!(), }, _ => todo!(), @@ -79,9 +93,9 @@ pub fn translate_insert( let column_mappings = resolve_columns_for_insert(&table, columns, values)?; // Check if rowid was provided (through INTEGER PRIMARY KEY as a rowid alias) - let rowid_alias_index = table.columns.iter().position(|c| c.is_rowid_alias); + let rowid_alias_index = btree_table.columns.iter().position(|c| c.is_rowid_alias); let has_user_provided_rowid = { - assert_eq!(column_mappings.len(), table.columns.len()); + assert_eq!(column_mappings.len(), btree_table.columns.len()); if let Some(index) = rowid_alias_index { column_mappings[index].value_index.is_some() } else { @@ -91,7 +105,7 @@ pub fn translate_insert( // allocate a register for each column in the table. if not provided by user, they will simply be set as null. // allocate an extra register for rowid regardless of whether user provided a rowid alias column. - let num_cols = table.columns.len(); + let num_cols = btree_table.columns.len(); let rowid_reg = program.alloc_registers(num_cols + 1); let column_registers_start = rowid_reg + 1; let rowid_alias_reg = { @@ -108,7 +122,7 @@ pub fn translate_insert( let inserting_multiple_rows = values.len() > 1; - // Multiple rows - use coroutine for value population + // multiple rows - use coroutine for value population if inserting_multiple_rows { let yield_reg = program.alloc_register(); let jump_on_definition_label = program.allocate_label(); @@ -217,7 +231,7 @@ pub fn translate_insert( target_pc: make_record_label, }); let rowid_column_name = if let Some(index) = rowid_alias_index { - &table + btree_table .columns .get(index) .unwrap() @@ -302,7 +316,7 @@ struct ColumnMapping<'a> { /// - Named columns map to their corresponding value index /// - Unspecified columns map to None fn resolve_columns_for_insert<'a>( - table: &'a BTreeTable, + table: &'a Table, columns: &Option, values: &[Vec], ) -> Result>> { @@ -310,7 +324,7 @@ fn resolve_columns_for_insert<'a>( crate::bail_parse_error!("no values to insert"); } - let table_columns = &table.columns; + let table_columns = &table.columns(); // Case 1: No columns specified - map values to columns in order if columns.is_none() { @@ -318,7 +332,7 @@ fn resolve_columns_for_insert<'a>( if num_values > table_columns.len() { crate::bail_parse_error!( "table {} has {} columns but {} values were supplied", - &table.name, + &table.get_name(), table_columns.len(), num_values ); @@ -361,7 +375,11 @@ fn resolve_columns_for_insert<'a>( }); if table_index.is_none() { - crate::bail_parse_error!("table {} has no column named {}", &table.name, column_name); + crate::bail_parse_error!( + "table {} has no column named {}", + &table.get_name(), + column_name + ); } mappings[table_index.unwrap()].value_index = Some(value_index); @@ -425,3 +443,95 @@ fn populate_column_registers( } Ok(()) } + +fn translate_virtual_table_insert( + program: &mut ProgramBuilder, + virtual_table: Rc, + columns: &Option, + body: &InsertBody, + on_conflict: &Option, + resolver: &Resolver, +) -> Result<()> { + let init_label = program.allocate_label(); + program.emit_insn(Insn::Init { + target_pc: init_label, + }); + let start_offset = program.offset(); + + let values = match body { + InsertBody::Select(select, None) => match &select.body.select.deref() { + OneSelect::Values(values) => values, + _ => crate::bail_parse_error!("Virtual tables only support VALUES clause in INSERT"), + }, + InsertBody::DefaultValues => &vec![], + _ => crate::bail_parse_error!("Unsupported INSERT body for virtual tables"), + }; + + let table = Table::Virtual(virtual_table.clone()); + let column_mappings = resolve_columns_for_insert(&table, columns, values)?; + + let value_registers_start = program.alloc_registers(values[0].len()); + for (i, expr) in values[0].iter().enumerate() { + translate_expr(program, None, expr, value_registers_start + i, resolver)?; + } + + let start_reg = program.alloc_registers(column_mappings.len() + 3); + let rowid_reg = start_reg; // argv[0] = rowid + let insert_rowid_reg = start_reg + 1; // argv[1] = insert_rowid + let data_start_reg = start_reg + 2; // argv[2..] = column values + + program.emit_insn(Insn::Null { + dest: rowid_reg, + dest_end: None, + }); + program.emit_insn(Insn::Null { + dest: insert_rowid_reg, + dest_end: None, + }); + + for (i, mapping) in column_mappings.iter().enumerate() { + let target_reg = data_start_reg + i; + if let Some(value_index) = mapping.value_index { + program.emit_insn(Insn::Copy { + src_reg: value_registers_start + value_index, + dst_reg: target_reg, + amount: 1, + }); + } else { + program.emit_insn(Insn::Null { + dest: target_reg, + dest_end: None, + }); + } + } + + let conflict_action = on_conflict.as_ref().map(|c| c.bit_value()).unwrap_or(0) as u16; + + let cursor_id = program.alloc_cursor_id( + Some(virtual_table.name.clone()), + CursorType::VirtualTable(virtual_table.clone()), + ); + + program.emit_insn(Insn::VUpdate { + cursor_id, + arg_count: column_mappings.len() + 2, + start_reg, + vtab_ptr: virtual_table.implementation.as_ref().ctx as usize, + conflict_action, + }); + + let halt_label = program.allocate_label(); + program.emit_insn(Insn::Halt { + err_code: 0, + description: String::new(), + }); + + program.resolve_label(halt_label, program.offset()); + program.resolve_label(init_label, program.offset()); + + program.emit_insn(Insn::Goto { + target_pc: start_offset, + }); + + Ok(()) +} diff --git a/core/translate/main_loop.rs b/core/translate/main_loop.rs index a9fa9a158..a0e4a13c4 100644 --- a/core/translate/main_loop.rs +++ b/core/translate/main_loop.rs @@ -293,10 +293,32 @@ pub fn open_loop( }; let start_reg = program.alloc_registers(args.len()); let mut cur_reg = start_reg; - for arg in args { + + for arg_str in args { let reg = cur_reg; cur_reg += 1; - translate_expr(program, Some(tables), &arg, reg, &t_ctx.resolver)?; + + if let Ok(i) = arg_str.parse::() { + program.emit_insn(Insn::Integer { + value: i, + dest: reg, + }); + } else if let Ok(f) = arg_str.parse::() { + program.emit_insn(Insn::Real { + value: f, + dest: reg, + }); + } else if arg_str.starts_with('"') && arg_str.ends_with('"') { + program.emit_insn(Insn::String8 { + value: arg_str.trim_matches('"').to_string(), + dest: reg, + }); + } else { + program.emit_insn(Insn::String8 { + value: arg_str.clone(), + dest: reg, + }); + } } program.emit_insn(Insn::VFilter { cursor_id, diff --git a/core/translate/mod.rs b/core/translate/mod.rs index 3ee8f4ce0..7df6258ec 100644 --- a/core/translate/mod.rs +++ b/core/translate/mod.rs @@ -33,8 +33,7 @@ use crate::vdbe::builder::{CursorType, ProgramBuilderOpts, QueryMode}; use crate::vdbe::{builder::ProgramBuilder, insn::Insn, Program}; use crate::{bail_parse_error, Connection, LimboError, Result, SymbolTable}; use insert::translate_insert; -use limbo_sqlite3_parser::ast::{self, fmt::ToTokens}; -use limbo_sqlite3_parser::ast::{Delete, Insert}; +use limbo_sqlite3_parser::ast::{self, fmt::ToTokens, CreateVirtualTable, Delete, Insert}; use select::translate_select; use std::cell::RefCell; use std::fmt::Display; @@ -74,8 +73,8 @@ pub fn translate( } ast::Stmt::CreateTrigger { .. } => bail_parse_error!("CREATE TRIGGER not supported yet"), ast::Stmt::CreateView { .. } => bail_parse_error!("CREATE VIEW not supported yet"), - ast::Stmt::CreateVirtualTable { .. } => { - bail_parse_error!("CREATE VIRTUAL TABLE not supported yet") + ast::Stmt::CreateVirtualTable(vtab) => { + translate_create_virtual_table(*vtab, schema, query_mode)? } ast::Stmt::Delete(delete) => { let Delete { @@ -94,7 +93,7 @@ pub fn translate( ast::Stmt::DropView { .. } => bail_parse_error!("DROP VIEW not supported yet"), ast::Stmt::Pragma(name, body) => pragma::translate_pragma( query_mode, - &schema, + schema, &name, body.map(|b| *b), database_header.clone(), @@ -177,6 +176,7 @@ addr opcode p1 p2 p3 p4 p5 comment enum SchemaEntryType { Table, Index, + Virtual, } impl SchemaEntryType { @@ -184,9 +184,11 @@ impl SchemaEntryType { match self { SchemaEntryType::Table => "table", SchemaEntryType::Index => "index", + SchemaEntryType::Virtual => "virtual", } } } +const SQLITE_TABLEID: &str = "sqlite_schema"; fn emit_schema_entry( program: &mut ProgramBuilder, @@ -209,11 +211,18 @@ fn emit_schema_entry( program.emit_string8_new_reg(tbl_name.to_string()); let rootpage_reg = program.alloc_register(); - program.emit_insn(Insn::Copy { - src_reg: root_page_reg, - dst_reg: rootpage_reg, - amount: 1, - }); + if matches!(entry_type, SchemaEntryType::Virtual) { + program.emit_insn(Insn::Integer { + dest: rootpage_reg, + value: 0, // virtual tables in sqlite always have rootpage=0 + }); + } else { + program.emit_insn(Insn::Copy { + src_reg: root_page_reg, + dst_reg: rootpage_reg, + amount: 1, + }); + } let sql_reg = program.alloc_register(); if let Some(sql) = sql { @@ -455,10 +464,9 @@ fn translate_create_table( }); } - let table_id = "sqlite_schema".to_string(); - let table = schema.get_table(&table_id).unwrap(); + let table = schema.get_btree_table(SQLITE_TABLEID).unwrap(); let sqlite_schema_cursor_id = program.alloc_cursor_id( - Some(table_id.to_owned()), + Some(SQLITE_TABLEID.to_owned()), CursorType::BTreeTable(table.clone()), ); program.emit_insn(Insn::OpenWriteAsync { @@ -546,3 +554,132 @@ fn create_table_body_to_str(tbl_name: &ast::QualifiedName, body: &ast::CreateTab } sql } + +fn create_vtable_body_to_str(vtab: &CreateVirtualTable) -> String { + let args = if let Some(args) = &vtab.args { + args.iter() + .map(|arg| arg.to_string()) + .collect::>() + .join(", ") + } else { + "".to_string() + }; + let if_not_exists = if vtab.if_not_exists { + "IF NOT EXISTS " + } else { + "" + }; + format!( + "CREATE VIRTUAL TABLE {} {} USING {}{}", + vtab.tbl_name.name.0, + vtab.module_name.0, + if_not_exists, + if args.is_empty() { + String::new() + } else { + format!("({})", args) + } + ) +} + +fn translate_create_virtual_table( + vtab: CreateVirtualTable, + schema: &Schema, + query_mode: QueryMode, +) -> Result { + let ast::CreateVirtualTable { + if_not_exists, + tbl_name, + module_name, + args, + } = &vtab; + + let table_name = tbl_name.name.0.clone(); + let module_name_str = module_name.0.clone(); + let args_vec = args.clone().unwrap_or_default(); + + if schema.get_table(&table_name).is_some() && *if_not_exists { + let mut program = ProgramBuilder::new(ProgramBuilderOpts { + query_mode, + num_cursors: 1, + approx_num_insns: 5, + approx_num_labels: 1, + }); + let init_label = program.emit_init(); + program.emit_halt(); + program.resolve_label(init_label, program.offset()); + program.emit_transaction(true); + program.emit_constant_insns(); + return Ok(program); + } + + let mut program = ProgramBuilder::new(ProgramBuilderOpts { + query_mode, + num_cursors: 2, + approx_num_insns: 40, + approx_num_labels: 2, + }); + + let module_name_reg = program.emit_string8_new_reg(module_name_str.clone()); + let table_name_reg = program.emit_string8_new_reg(table_name.clone()); + + let args_reg = if !args_vec.is_empty() { + let args_start = program.alloc_register(); + for (i, arg) in args_vec.iter().enumerate() { + program.emit_string8(arg.clone(), args_start + i); + } + let args_record_reg = program.alloc_register(); + program.emit_insn(Insn::MakeRecord { + start_reg: args_start, + count: args_vec.len(), + dest_reg: args_record_reg, + }); + Some(args_record_reg) + } else { + None + }; + + program.emit_insn(Insn::VCreate { + module_name: module_name_reg, + table_name: table_name_reg, + args_reg, + }); + + let table = schema.get_btree_table(SQLITE_TABLEID).unwrap(); + let sqlite_schema_cursor_id = program.alloc_cursor_id( + Some(SQLITE_TABLEID.to_owned()), + CursorType::BTreeTable(table.clone()), + ); + program.emit_insn(Insn::OpenWriteAsync { + cursor_id: sqlite_schema_cursor_id, + root_page: 1, + }); + program.emit_insn(Insn::OpenWriteAwait {}); + + let sql = create_vtable_body_to_str(&vtab); + emit_schema_entry( + &mut program, + sqlite_schema_cursor_id, + SchemaEntryType::Virtual, + &tbl_name.name.0, + &tbl_name.name.0, + 0, // virtual tables dont have a root page + Some(sql), + ); + + let parse_schema_where_clause = format!("tbl_name = '{}' AND type != 'trigger'", table_name); + program.emit_insn(Insn::ParseSchema { + db: sqlite_schema_cursor_id, + where_clause: parse_schema_where_clause, + }); + + let init_label = program.emit_init(); + let start_offset = program.offset(); + program.emit_halt(); + program.resolve_label(init_label, program.offset()); + program.emit_transaction(true); + program.emit_constant_insns(); + program.emit_goto(start_offset); + + Ok(program) +} diff --git a/core/translate/planner.rs b/core/translate/planner.rs index 138d1cbc0..c9c266a13 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -11,7 +11,7 @@ use crate::{ schema::{Schema, Table}, util::{exprs_are_equivalent, normalize_ident}, vdbe::BranchOffset, - Result, VirtualTable, + Result, }; use limbo_sqlite3_parser::ast::{ self, Expr, FromClause, JoinType, Limit, Materialized, UnaryOperator, With, @@ -303,7 +303,7 @@ fn parse_from_clause_table<'a>( return Ok(()); }; // Check if our top level schema has this table. - if let Some(table) = schema.get_table(&normalized_qualified_name) { + if let Some(table) = schema.get_btree_table(&normalized_qualified_name) { let alias = maybe_alias .map(|a| match a { ast::As::As(id) => id, @@ -369,9 +369,16 @@ fn parse_from_clause_table<'a>( } ast::SelectTable::TableCall(qualified_name, maybe_args, maybe_alias) => { let normalized_name = &normalize_ident(qualified_name.name.0.as_str()); - let Some(vtab) = syms.vtabs.get(normalized_name) else { - crate::bail_parse_error!("Virtual table {} not found", normalized_name); - }; + let vtab = crate::VirtualTable::from_args( + None, + normalized_name, + &maybe_args + .as_ref() + .map(|a| a.iter().map(|s| s.to_string()).collect::>()) + .unwrap_or_default(), + syms, + limbo_ext::VTabKind::TableValuedFunction, + )?; let alias = maybe_alias .as_ref() .map(|a| match a { @@ -383,18 +390,10 @@ fn parse_from_clause_table<'a>( scope.tables.push(TableReference { op: Operation::Scan { iter_dir: None }, join_info: None, - table: Table::Virtual( - VirtualTable { - name: normalized_name.clone(), - args: maybe_args, - implementation: vtab.implementation.clone(), - columns: vtab.columns.clone(), - } - .into(), - ) - .into(), - identifier: alias.clone(), + table: Table::Virtual(vtab), + identifier: alias, }); + Ok(()) } _ => todo!(), @@ -611,7 +610,7 @@ fn parse_join<'a>( constraint, } = join; - parse_from_clause_table(schema, table, scope, syms)?; + parse_from_clause_table(schema, table, scope, &syms)?; let (outer, natural) = match join_operator { ast::JoinOperator::TypedJoin(Some(join_type)) => { diff --git a/core/translate/pragma.rs b/core/translate/pragma.rs index 1fd738acd..b33f38011 100644 --- a/core/translate/pragma.rs +++ b/core/translate/pragma.rs @@ -218,7 +218,7 @@ fn query_pragma( program.alloc_register(); program.alloc_register(); if let Some(table) = table { - for (i, column) in table.columns.iter().enumerate() { + for (i, column) in table.columns().iter().enumerate() { // cid program.emit_int(i as i64, base_reg); // name diff --git a/core/util.rs b/core/util.rs index 2ee09189d..ddaf931a9 100644 --- a/core/util.rs +++ b/core/util.rs @@ -3,7 +3,7 @@ use std::{rc::Rc, sync::Arc}; use crate::{ schema::{self, Column, Schema, Type}, - Result, Statement, StepResult, IO, + Result, Statement, StepResult, SymbolTable, IO, }; // https://sqlite.org/lang_keywords.html @@ -28,6 +28,7 @@ pub fn parse_schema_rows( rows: Option, schema: &mut Schema, io: Arc, + syms: &SymbolTable, ) -> Result<()> { if let Some(mut rows) = rows { let mut automatic_indexes = Vec::new(); @@ -36,7 +37,7 @@ pub fn parse_schema_rows( StepResult::Row => { let row = rows.row().unwrap(); let ty = row.get::<&str>(0)?; - if ty != "table" && ty != "index" { + if !["table", "index", "virtual"].contains(&ty) { continue; } match ty { @@ -44,7 +45,12 @@ pub fn parse_schema_rows( let root_page: i64 = row.get::(3)?; let sql: &str = row.get::<&str>(4)?; let table = schema::BTreeTable::from_sql(sql, root_page as usize)?; - schema.add_table(Rc::new(table)); + schema.add_btree_table(Rc::new(table)); + } + "virtual" => { + let name: &str = row.get::<&str>(1)?; + let vtab = syms.vtabs.get(name).unwrap().clone(); + schema.add_virtual_table(vtab); } "index" => { let root_page: i64 = row.get::(3)?; @@ -83,7 +89,7 @@ pub fn parse_schema_rows( } for (index_name, table_name, root_page) in automatic_indexes { // We need to process these after all tables are loaded into memory due to the schema.get_table() call - let table = schema.get_table(&table_name).unwrap(); + let table = schema.get_btree_table(&table_name).unwrap(); let index = schema::Index::automatic_from_primary_key(&table, &index_name, root_page as usize)?; schema.add_index(Rc::new(index)); @@ -307,9 +313,11 @@ pub fn exprs_are_equivalent(expr1: &Expr, expr2: &Expr) -> bool { } } -pub fn columns_from_create_table_body(body: ast::CreateTableBody) -> Result, ()> { +pub fn columns_from_create_table_body(body: &ast::CreateTableBody) -> crate::Result> { let CreateTableBody::ColumnsAndConstraints { columns, .. } = body else { - return Err(()); + return Err(crate::LimboError::ParseError( + "CREATE TABLE body must contain columns and constraints".to_string(), + )); }; Ok(columns @@ -322,7 +330,7 @@ pub fn columns_from_create_table_body(body: ast::CreateTableBody) -> Result { // https://www.sqlite.org/datatype3.html diff --git a/core/vdbe/builder.rs b/core/vdbe/builder.rs index 716610b71..470997158 100644 --- a/core/vdbe/builder.rs +++ b/core/vdbe/builder.rs @@ -50,7 +50,7 @@ impl CursorType { } } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Copy)] pub enum QueryMode { Normal, Explain, diff --git a/core/vdbe/explain.rs b/core/vdbe/explain.rs index 2b61310d5..a609fc667 100644 --- a/core/vdbe/explain.rs +++ b/core/vdbe/explain.rs @@ -381,6 +381,19 @@ pub fn insn_to_str( 0, "".to_string(), ), + Insn::VCreate { + table_name, + module_name, + args_reg, + } => ( + "VCreate", + *table_name as i32, + *module_name as i32, + args_reg.unwrap_or(0) as i32, + OwnedValue::build_text(""), + 0, + format!("table={}, module={}", table_name, module_name), + ), Insn::VFilter { cursor_id, pc_if_empty, @@ -408,6 +421,21 @@ pub fn insn_to_str( 0, "".to_string(), ), + Insn::VUpdate { + cursor_id, + arg_count, // P2: Number of arguments in argv[] + start_reg, // P3: Start register for argv[] + vtab_ptr, // P4: vtab pointer + conflict_action, // P5: Conflict resolution flags + } => ( + "VUpdate", + *cursor_id as i32, + *arg_count as i32, + *start_reg as i32, + OwnedValue::build_text(&format!("vtab:{}", vtab_ptr)), + *conflict_action, + format!("args=r[{}..{}]", start_reg, start_reg + arg_count - 1), + ), Insn::VNext { cursor_id, pc_if_next, diff --git a/core/vdbe/insn.rs b/core/vdbe/insn.rs index 1fad2c479..d6b25046c 100644 --- a/core/vdbe/insn.rs +++ b/core/vdbe/insn.rs @@ -220,6 +220,13 @@ pub enum Insn { /// Await for the completion of open cursor for a virtual table. VOpenAwait, + /// Create a new virtual table. + VCreate { + module_name: usize, // P1: Name of the module that contains the virtual table implementation + table_name: usize, // P2: Name of the virtual table + args_reg: Option, + }, + /// Initialize the position of the virtual table cursor. VFilter { cursor_id: CursorID, @@ -235,6 +242,15 @@ pub enum Insn { dest: usize, }, + /// `VUpdate`: Virtual Table Insert/Update/Delete Instruction + VUpdate { + cursor_id: usize, // P1: Virtual table cursor number + arg_count: usize, // P2: Number of arguments in argv[] + start_reg: usize, // P3: Start register for argv[] + vtab_ptr: usize, // P4: vtab pointer + conflict_action: u16, // P5: Conflict resolution flags + }, + /// Advance the virtual table cursor to the next row. /// TODO: async VNext { diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index a81a1c687..a881c0d56 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -873,6 +873,46 @@ impl Program { .insert(*cursor_id, Some(Cursor::Virtual(cursor))); state.pc += 1; } + Insn::VCreate { + module_name, + table_name, + args_reg, + } => { + let module_name = state.registers[*module_name].to_string(); + let table_name = state.registers[*table_name].to_string(); + let args = if let Some(args_reg) = args_reg { + if let OwnedValue::Record(rec) = &state.registers[*args_reg] { + rec.get_values().iter().map(|v| v.to_string()).collect() + } else { + return Err(LimboError::InternalError( + "VCreate: args_reg is not a record".to_string(), + )); + } + } else { + vec![] + }; + let Some(conn) = self.connection.upgrade() else { + return Err(crate::LimboError::ExtensionError( + "Failed to upgrade Connection".to_string(), + )); + }; + let table = crate::VirtualTable::from_args( + Some(&table_name), + &module_name, + &args, + &conn.db.syms.borrow(), + limbo_ext::VTabKind::VirtualTable, + )?; + { + conn.db + .syms + .as_ref() + .borrow_mut() + .vtabs + .insert(table_name, table.clone()); + } + state.pc += 1; + } Insn::VOpenAwait => { state.pc += 1; } @@ -913,6 +953,68 @@ impl Program { state.registers[*dest] = virtual_table.column(cursor, *column)?; state.pc += 1; } + Insn::VUpdate { + cursor_id, + arg_count, + start_reg, + conflict_action, + .. + } => { + let (_, cursor_type) = self.cursor_ref.get(*cursor_id).unwrap(); + let CursorType::VirtualTable(virtual_table) = cursor_type else { + panic!("VUpdate on non-virtual table cursor"); + }; + + if *arg_count < 2 { + return Err(LimboError::InternalError( + "VUpdate: arg_count must be at least 2 (rowid and insert_rowid)" + .to_string(), + )); + } + + let mut argv = Vec::with_capacity(*arg_count); + for i in 0..*arg_count { + if let Some(value) = state.registers.get(*start_reg + i) { + argv.push(value.clone()); + } else { + return Err(LimboError::InternalError(format!( + "VUpdate: register out of bounds at {}", + *start_reg + i + ))); + } + } + + let current_rowid = match argv.first() { + Some(OwnedValue::Integer(rowid)) => Some(*rowid), + _ => None, + }; + let insert_rowid = match argv.get(1) { + Some(OwnedValue::Integer(rowid)) => Some(*rowid), + _ => None, + }; + + let result = virtual_table.update(&argv, insert_rowid); + + match result { + Ok(Some(new_rowid)) => { + if *conflict_action == 5 { + if let Some(conn) = self.connection.upgrade() { + conn.update_last_rowid(new_rowid as u64); + } + } + state.pc += 1; + } + Ok(None) => { + state.pc += 1; + } + Err(e) => { + return Err(LimboError::ExtensionError(format!( + "Virtual table update failed: {}", + e + ))); + } + } + } Insn::VNext { cursor_id, pc_if_next, @@ -2724,8 +2826,13 @@ impl Program { where_clause ))?; let mut schema = RefCell::borrow_mut(&conn.schema); - // TODO: This function below is synchronous, make it not async - parse_schema_rows(Some(stmt), &mut schema, conn.pager.io.clone())?; + // TODO: This function below is synchronous, make it async + parse_schema_rows( + Some(stmt), + &mut schema, + conn.pager.io.clone(), + &conn.db.syms.borrow(), + )?; state.pc += 1; } Insn::ReadCookie { db, dest, cookie } => { diff --git a/extensions/core/src/lib.rs b/extensions/core/src/lib.rs index d06340aa2..2b123a3a2 100644 --- a/extensions/core/src/lib.rs +++ b/extensions/core/src/lib.rs @@ -6,34 +6,20 @@ use std::{ }; pub use types::{ResultCode, Value, ValueType}; +pub type ExtResult = std::result::Result; + #[repr(C)] pub struct ExtensionApi { pub ctx: *mut c_void, pub register_scalar_function: RegisterScalarFn, pub register_aggregate_function: RegisterAggFn, pub register_module: RegisterModuleFn, - pub declare_vtab: DeclareVTabFn, -} - -impl ExtensionApi { - pub fn declare_virtual_table(&self, name: &str, sql: &str) -> ResultCode { - let Ok(name) = std::ffi::CString::new(name) else { - return ResultCode::Error; - }; - let Ok(sql) = std::ffi::CString::new(sql) else { - return ResultCode::Error; - }; - unsafe { (self.declare_vtab)(self.ctx, name.as_ptr(), sql.as_ptr()) } - } } pub type ExtensionEntryPoint = unsafe extern "C" fn(api: *const ExtensionApi) -> ResultCode; pub type ScalarFunction = unsafe extern "C" fn(argc: i32, *const Value) -> Value; -pub type DeclareVTabFn = - unsafe extern "C" fn(ctx: *mut c_void, name: *const c_char, sql: *const c_char) -> ResultCode; - pub type RegisterScalarFn = unsafe extern "C" fn(ctx: *mut c_void, name: *const c_char, func: ScalarFunction) -> ResultCode; @@ -50,6 +36,7 @@ pub type RegisterModuleFn = unsafe extern "C" fn( ctx: *mut c_void, name: *const c_char, module: VTabModuleImpl, + kind: VTabKind, ) -> ResultCode; pub type InitAggFunction = unsafe extern "C" fn() -> *mut AggCtx; @@ -74,18 +61,39 @@ pub trait AggFunc { #[repr(C)] #[derive(Clone, Debug)] pub struct VTabModuleImpl { + pub ctx: *mut c_void, pub name: *const c_char, - pub connect: VtabFnConnect, + pub create_schema: VtabFnCreateSchema, pub open: VtabFnOpen, pub filter: VtabFnFilter, pub column: VtabFnColumn, pub next: VtabFnNext, pub eof: VtabFnEof, + pub update: VtabFnUpdate, } -pub type VtabFnConnect = unsafe extern "C" fn(api: *const c_void) -> ResultCode; +impl VTabModuleImpl { + pub fn init_schema(&self, args: &[String]) -> ExtResult { + let c_args = args + .iter() + .map(|s| std::ffi::CString::new(s.as_bytes()).unwrap().into_raw()) + .collect::>(); + let schema = unsafe { (self.create_schema)(c_args.as_ptr(), c_args.len() as i32) }; + c_args.into_iter().for_each(|s| unsafe { + let _ = std::ffi::CString::from_raw(s); + }); + if schema.is_null() { + return Err(ResultCode::InvalidArgs); + } + let schema = unsafe { std::ffi::CString::from_raw(schema) }; + Ok(schema.to_string_lossy().to_string()) + } +} -pub type VtabFnOpen = unsafe extern "C" fn() -> *mut c_void; +pub type VtabFnCreateSchema = + unsafe extern "C" fn(args: *const *mut c_char, argc: i32) -> *mut c_char; + +pub type VtabFnOpen = unsafe extern "C" fn(args: *const *mut c_char, argc: i32) -> *mut c_void; pub type VtabFnFilter = unsafe extern "C" fn(cursor: *mut c_void, argc: i32, argv: *const Value) -> ResultCode; @@ -96,17 +104,34 @@ pub type VtabFnNext = unsafe extern "C" fn(cursor: *mut c_void) -> ResultCode; pub type VtabFnEof = unsafe extern "C" fn(cursor: *mut c_void) -> bool; +pub type VtabFnUpdate = unsafe extern "C" fn( + vtab: *mut c_void, + argc: i32, + argv: *const Value, + rowid: i64, + p_out_rowid: *mut i64, +) -> ResultCode; + +#[repr(C)] +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum VTabKind { + VirtualTable, + TableValuedFunction, +} + pub trait VTabModule: 'static { type VCursor: VTabCursor; + const VTAB_KIND: VTabKind; const NAME: &'static str; type Error: std::fmt::Display; - fn init_sql() -> &'static str; - fn open() -> Result; + fn create_schema(args: &[String]) -> String; + fn open(args: &[String]) -> Result; fn filter(cursor: &mut Self::VCursor, arg_count: i32, args: &[Value]) -> ResultCode; fn column(cursor: &Self::VCursor, idx: u32) -> Result; fn next(cursor: &mut Self::VCursor) -> ResultCode; fn eof(cursor: &Self::VCursor) -> bool; + fn update(&mut self, args: &[Value], rowid: Option) -> Result, Self::Error>; } pub trait VTabCursor: Sized { @@ -116,8 +141,3 @@ pub trait VTabCursor: Sized { fn eof(&self) -> bool; fn next(&mut self) -> ResultCode; } - -#[repr(C)] -pub struct VTabImpl { - pub module: VTabModuleImpl, -} diff --git a/extensions/core/src/types.rs b/extensions/core/src/types.rs index c29768d7f..55d39ac42 100644 --- a/extensions/core/src/types.rs +++ b/extensions/core/src/types.rs @@ -21,6 +21,8 @@ pub enum ResultCode { Unavailable = 13, CustomError = 14, EOF = 15, + ReadOnly = 16, + RowID = 17, } impl ResultCode { @@ -52,6 +54,8 @@ impl Display for ResultCode { ResultCode::Unavailable => write!(f, "Unavailable"), ResultCode::CustomError => write!(f, "Error "), ResultCode::EOF => write!(f, "EOF"), + ResultCode::ReadOnly => write!(f, "Read Only"), + ResultCode::RowID => write!(f, "RowID"), } } } diff --git a/extensions/kvstore/Cargo.toml b/extensions/kvstore/Cargo.toml new file mode 100644 index 000000000..81c1f804d --- /dev/null +++ b/extensions/kvstore/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "limbo_kv" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true + +[lib] +crate-type = ["cdylib", "lib"] + +[features] +static= [ "limbo_ext/static" ] + +[dependencies] +limbo_ext = { workspace = true, features = ["static"] } + +[target.'cfg(not(target_family = "wasm"))'.dependencies] +mimalloc = { version = "*", default-features = false } diff --git a/extensions/kvstore/src/lib.rs b/extensions/kvstore/src/lib.rs new file mode 100644 index 000000000..184467aca --- /dev/null +++ b/extensions/kvstore/src/lib.rs @@ -0,0 +1,103 @@ +use limbo_ext::{ + register_extension, ResultCode, VTabCursor, VTabKind, VTabModule, VTabModuleDerive, Value, +}; +use std::collections::HashMap; + +register_extension! { + vtabs: { KVStoreVTab }, +} + +#[derive(VTabModuleDerive, Default)] +pub struct KVStoreVTab { + store: HashMap, +} + +pub struct KVStoreCursor { + keys: Vec, + values: Vec, + index: usize, +} + +impl VTabModule for KVStoreVTab { + type VCursor = KVStoreCursor; + const VTAB_KIND: VTabKind = VTabKind::VirtualTable; + const NAME: &'static str = "kv_store"; + type Error = String; + + fn create_schema(_args: &[String]) -> String { + "CREATE TABLE x (key TEXT PRIMARY KEY, value TEXT);".to_string() + } + + fn open(_args: &[String]) -> Result { + Ok(KVStoreCursor { + keys: Vec::new(), + values: Vec::new(), + index: 0, + }) + } + + fn filter(cursor: &mut Self::VCursor, _arg_count: i32, _args: &[Value]) -> ResultCode { + cursor.index = 0; + ResultCode::OK + } + + fn column(cursor: &Self::VCursor, idx: u32) -> Result { + match idx { + 0 => Ok(Value::from_text(cursor.keys[cursor.index].clone())), + 1 => Ok(Value::from_text(cursor.values[cursor.index].clone())), + _ => Err("Invalid column".into()), + } + } + + fn next(cursor: &mut Self::VCursor) -> ResultCode { + cursor.index += 1; + ResultCode::OK + } + + fn eof(cursor: &Self::VCursor) -> bool { + cursor.index >= cursor.keys.len() + } + + fn update(&mut self, args: &[Value], rowid: Option) -> Result, Self::Error> { + match args.len() { + 1 => { + let key = args[0].to_text().ok_or("Invalid key")?; + // Handle DELETE + self.store.remove(key); + Ok(None) + } + 2 => { + let key = args[0].to_text().ok_or("Invalid key")?; + let value = args[1].to_text().ok_or("Invalid value")?; + // Handle INSERT / UPDATE + self.store.insert(key.to_string(), value.to_string()); + Ok(Some(rowid.unwrap_or(0))) + } + _ => { + println!("args: {:?}", args); + Err("Invalid arguments for update".into()) + } + } + } +} + +impl VTabCursor for KVStoreCursor { + type Error = String; + fn rowid(&self) -> i64 { + self.index as i64 + } + fn column(&self, idx: u32) -> Result { + match idx { + 0 => Ok(Value::from_text(self.keys[self.index].clone())), + 1 => Ok(Value::from_text(self.values[self.index].clone())), + _ => Err("Invalid column".into()), + } + } + fn eof(&self) -> bool { + self.index >= self.keys.len() + } + fn next(&mut self) -> ResultCode { + self.index += 1; + ResultCode::OK + } +} diff --git a/extensions/series/src/lib.rs b/extensions/series/src/lib.rs index 6e86f0c93..161bfe886 100644 --- a/extensions/series/src/lib.rs +++ b/extensions/series/src/lib.rs @@ -1,4 +1,6 @@ -use limbo_ext::{register_extension, ResultCode, VTabCursor, VTabModule, VTabModuleDerive, Value}; +use limbo_ext::{ + register_extension, ResultCode, VTabCursor, VTabKind, VTabModule, VTabModuleDerive, Value, +}; register_extension! { vtabs: { GenerateSeriesVTab } @@ -14,16 +16,16 @@ macro_rules! try_option { } /// A virtual table that generates a sequence of integers -#[derive(Debug, VTabModuleDerive)] +#[derive(Debug, VTabModuleDerive, Default)] struct GenerateSeriesVTab; impl VTabModule for GenerateSeriesVTab { type VCursor = GenerateSeriesCursor; type Error = ResultCode; - const NAME: &'static str = "generate_series"; + const VTAB_KIND: VTabKind = VTabKind::TableValuedFunction; - fn init_sql() -> &'static str { + fn create_schema(_args: &[String]) -> String { // Create table schema "CREATE TABLE generate_series( value INTEGER, @@ -31,9 +33,10 @@ impl VTabModule for GenerateSeriesVTab { stop INTEGER HIDDEN, step INTEGER HIDDEN )" + .into() } - fn open() -> Result { + fn open(_args: &[String]) -> Result { Ok(GenerateSeriesCursor { start: 0, stop: 0, @@ -88,6 +91,10 @@ impl VTabModule for GenerateSeriesVTab { fn eof(cursor: &Self::VCursor) -> bool { cursor.eof() } + + fn update(&mut self, _args: &[Value], _rowid: Option) -> Result, Self::Error> { + Ok(None) + } } /// The cursor for iterating over the generated sequence diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 089579081..1f221b03f 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -341,7 +341,7 @@ pub fn derive_agg_func(input: TokenStream) -> TokenStream { /// const NAME: &'static str = "csv_data"; /// /// /// Declare the schema for your virtual table -/// fn init_sql() -> &'static str { +/// fn create_schema(args: &[&str]) -> &'static str { /// let sql = "CREATE TABLE csv_data( /// name TEXT, /// age TEXT, @@ -382,6 +382,12 @@ pub fn derive_agg_func(input: TokenStream) -> TokenStream { /// fn eof(cursor: &Self::VCursor) -> bool { /// cursor.index >= cursor.rows.len() /// } +/// +/// /// Update the row with the provided values, return the new rowid if provided +/// fn update(&mut self, args: &[Value], rowid: Option) -> Result, Self::Error> { +/// Ok(None)// return Ok(None) for read-only +/// } +/// /// #[derive(Debug)] /// struct CsvCursor { /// rows: Vec>, @@ -389,7 +395,7 @@ pub fn derive_agg_func(input: TokenStream) -> TokenStream { /// /// impl CsvCursor { /// /// Returns the value for a given column index. -/// fn column(&self, idx: u32) -> Value { +/// fn column(&self, idx: u32) -> Result { /// let row = &self.rows[self.index]; /// if (idx as usize) < row.len() { /// Value::from_text(&row[idx as usize]) @@ -418,31 +424,45 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { let struct_name = &ast.ident; let register_fn_name = format_ident!("register_{}", struct_name); - let connect_fn_name = format_ident!("connect_{}", struct_name); + let create_schema_fn_name = format_ident!("create_schema_{}", struct_name); let open_fn_name = format_ident!("open_{}", struct_name); let filter_fn_name = format_ident!("filter_{}", struct_name); let column_fn_name = format_ident!("column_{}", struct_name); let next_fn_name = format_ident!("next_{}", struct_name); let eof_fn_name = format_ident!("eof_{}", struct_name); + let update_fn_name = format_ident!("update_{}", struct_name); let expanded = quote! { impl #struct_name { #[no_mangle] - unsafe extern "C" fn #connect_fn_name( - db: *const ::std::ffi::c_void - ) -> ::limbo_ext::ResultCode { - let api = &*(db as *const ::limbo_ext::ExtensionApi); - let sql = <#struct_name as ::limbo_ext::VTabModule>::init_sql(); - api.declare_virtual_table(<#struct_name as ::limbo_ext::VTabModule>::NAME, sql) + unsafe extern "C" fn #create_schema_fn_name( + argv: *const *mut ::std::ffi::c_char, argc: i32 + ) -> *mut ::std::ffi::c_char { + let args = if argv.is_null() { + Vec::new() + } else { + ::std::slice::from_raw_parts(argv, argc as usize).iter().map(|s| { + ::std::ffi::CStr::from_ptr(*s).to_string_lossy().to_string() + }).collect::>() + }; + let sql = <#struct_name as ::limbo_ext::VTabModule>::create_schema(&args); + ::std::ffi::CString::new(sql).unwrap().into_raw() } #[no_mangle] - unsafe extern "C" fn #open_fn_name( - ) -> *mut ::std::ffi::c_void { - if let Ok(cursor) = <#struct_name as ::limbo_ext::VTabModule>::open() { - ::std::boxed::Box::into_raw(::std::boxed::Box::new(cursor)) as *mut ::std::ffi::c_void + unsafe extern "C" fn #open_fn_name(argv: *const *mut ::std::ffi::c_char, argc: i32) -> *mut ::std::ffi::c_void { + let args = if argv.is_null() { + Vec::new() } else { - ::std::ptr::null_mut() + ::std::slice::from_raw_parts(argv, argc as usize).iter().map(|s| { + ::std::ffi::CStr::from_ptr(*s).to_string_lossy().to_string() + }).collect::>() + }; + let schema = <#struct_name as ::limbo_ext::VTabModule>::create_schema(&args); + if let Ok(cursor) = <#struct_name as ::limbo_ext::VTabModule>::open(&args) { + return ::std::boxed::Box::into_raw(::std::boxed::Box::new(cursor)) as *mut ::std::ffi::c_void; + } else { + return ::std::ptr::null_mut(); } } @@ -497,6 +517,37 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { <#struct_name as ::limbo_ext::VTabModule>::eof(cursor) } + #[no_mangle] + unsafe extern "C" fn #update_fn_name( + vtab: *mut ::std::ffi::c_void, + argc: i32, + argv: *const ::limbo_ext::Value, + rowid: i64, + p_out_rowid: *mut i64, + ) -> ::limbo_ext::ResultCode { + if vtab.is_null() { + return ::limbo_ext::ResultCode::Error; + } + let vtab = &mut *(vtab as *mut #struct_name); + let args = ::std::slice::from_raw_parts(argv, argc as usize); + let rowid = if rowid == -1 { + None + } else { + Some(rowid as i64) + }; + let result = <#struct_name as ::limbo_ext::VTabModule>::update(vtab, args, rowid); + match result { + Ok(Some(rowid)) => { + // set the output rowid if it was provided + *p_out_rowid = rowid; + ::limbo_ext::ResultCode::RowID + } + Ok(None) => ::limbo_ext::ResultCode::OK, + Err(_) => ::limbo_ext::ResultCode::Error, + } + } + + #[no_mangle] pub unsafe extern "C" fn #register_fn_name( api: *const ::limbo_ext::ExtensionApi @@ -506,20 +557,20 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { } let api = &*api; let name = <#struct_name as ::limbo_ext::VTabModule>::NAME; - // name needs to be a c str FFI compatible, NOT CString let name_c = ::std::ffi::CString::new(name).unwrap().into_raw() as *const ::std::ffi::c_char; - + let table_instance = ::std::boxed::Box::into_raw(::std::boxed::Box::new(#struct_name::default())); let module = ::limbo_ext::VTabModuleImpl { + ctx: table_instance as *mut ::std::ffi::c_void, name: name_c, - connect: Self::#connect_fn_name, + create_schema: Self::#create_schema_fn_name, open: Self::#open_fn_name, filter: Self::#filter_fn_name, column: Self::#column_fn_name, next: Self::#next_fn_name, eof: Self::#eof_fn_name, + update: Self::#update_fn_name, }; - - (api.register_module)(api.ctx, name_c, module) + (api.register_module)(api.ctx, name_c, module, <#struct_name as ::limbo_ext::VTabModule>::VTAB_KIND) } } }; @@ -594,16 +645,11 @@ pub fn register_extension(input: TokenStream) -> TokenStream { }); let vtab_calls = vtabs.iter().map(|vtab_ident| { let register_fn = syn::Ident::new(&format!("register_{}", vtab_ident), vtab_ident.span()); - let connect_fn = syn::Ident::new(&format!("connect_{}", vtab_ident), vtab_ident.span()); quote! { { let result = unsafe{ #vtab_ident::#register_fn(api)}; - if result == ::limbo_ext::ResultCode::OK { - let api = api as *const _ as *const ::std::ffi::c_void; - let result = #vtab_ident::#connect_fn(api); - if !result.is_ok() { - return result; - } + if !result.is_ok() { + return result; } } } diff --git a/vendored/sqlite3-parser/src/parser/ast/mod.rs b/vendored/sqlite3-parser/src/parser/ast/mod.rs index ad359cc0b..1aac9c2c4 100644 --- a/vendored/sqlite3-parser/src/parser/ast/mod.rs +++ b/vendored/sqlite3-parser/src/parser/ast/mod.rs @@ -1724,6 +1724,18 @@ pub enum ResolveType { /// `REPLACE` Replace, } +impl ResolveType { + /// Get the OE_XXX bit value + pub fn bit_value(&self) -> usize { + match self { + ResolveType::Rollback => 1, + ResolveType::Abort => 2, + ResolveType::Fail => 3, + ResolveType::Ignore => 4, + ResolveType::Replace => 5, + } + } +} /// `WITH` clause // https://sqlite.org/lang_with.html From 2fd2544f3e06696d8c5a3417b77954a0fb659464 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Sat, 15 Feb 2025 18:25:05 -0500 Subject: [PATCH 03/13] Update COMPAT.md --- COMPAT.md | 12 ++++++------ Cargo.lock | 8 ++++++++ Cargo.toml | 3 ++- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/COMPAT.md b/COMPAT.md index e265ba554..4f4161dbe 100644 --- a/COMPAT.md +++ b/COMPAT.md @@ -572,14 +572,14 @@ Modifiers: | Trace | No | | | Transaction | Yes | | | VBegin | No | | -| VColumn | No | | -| VCreate | No | | +| VColumn | Yes | | +| VCreate | Yes | | | VDestroy | No | | -| VFilter | No | | -| VNext | No | | -| VOpen | No | | +| VFilter | Yes | | +| VNext | Yes | | +| VOpen | Yes |VOpenAsync| | VRename | No | | -| VUpdate | No | | +| VUpdate | Yes | | | Vacuum | No | | | Variable | No | | | VerifyCookie | No | | diff --git a/Cargo.lock b/Cargo.lock index c65ad569a..3e55db38f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1683,6 +1683,14 @@ dependencies = [ "limbo_macros", ] +[[package]] +name = "limbo_kv" +version = "0.0.14" +dependencies = [ + "limbo_ext", + "mimalloc", +] + [[package]] name = "limbo_macros" version = "0.0.14" diff --git a/Cargo.toml b/Cargo.toml index bb2607e59..f43b0a0bd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,8 @@ members = [ "cli", "core", "extensions/core", - "extensions/crypto", + "extensions/crypto", + "extensions/kvstore", "extensions/percentile", "extensions/regexp", "extensions/series", From f2e3a6120462c33a85594f3763f342a1200daf4a Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Sun, 16 Feb 2025 19:30:51 -0500 Subject: [PATCH 04/13] Change error message in extension tests to match new behavior --- testing/extensions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/testing/extensions.py b/testing/extensions.py index 76af242dd..240abacb7 100755 --- a/testing/extensions.py +++ b/testing/extensions.py @@ -443,7 +443,7 @@ def test_series(pipe): run_test( pipe, "SELECT * FROM generate_series(1, 10);", - lambda res: "Virtual table generate_series not found" in res, + lambda res: "Virtual table module not found: generate_series" in res, ) run_test(pipe, f".load {ext_path}", returns_null) run_test( From 8b5772fe1cceccbdfe19aa4be333307ad7084c70 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Sun, 16 Feb 2025 20:21:09 -0500 Subject: [PATCH 05/13] Implement VUpdate (insert/delete for virtual tables --- core/ext/mod.rs | 1 - core/lib.rs | 54 +++++++++++------------ core/translate/delete.rs | 18 +++++--- core/translate/emitter.rs | 85 +++++++++++++++++++++++++++++++++++-- core/translate/insert.rs | 18 +++++--- core/translate/main_loop.rs | 49 +++++++-------------- core/translate/mod.rs | 2 +- core/translate/planner.rs | 26 ++++++++---- core/util.rs | 29 +++++++++++++ core/vdbe/mod.rs | 66 +++++++++++++++++----------- 10 files changed, 236 insertions(+), 112 deletions(-) diff --git a/core/ext/mod.rs b/core/ext/mod.rs index 3ea7d9692..a5cfd5909 100644 --- a/core/ext/mod.rs +++ b/core/ext/mod.rs @@ -111,7 +111,6 @@ impl Database { .borrow_mut() .vtab_modules .insert(name.to_string(), vmodule.into()); - println!("Registered module: {}", name); ResultCode::OK } diff --git a/core/lib.rs b/core/lib.rs index d9727fd98..8719f3c0a 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -516,18 +516,23 @@ pub type StepResult = vdbe::StepResult; #[derive(Clone, Debug)] pub struct VirtualTable { name: String, - args: Option>, + args: Option>, pub implementation: Rc, columns: Vec, } impl VirtualTable { + pub(crate) fn rowid(&self, cursor: &VTabOpaqueCursor) -> i64 { + unsafe { (self.implementation.rowid)(cursor.as_ptr()) } + } + /// takes ownership of the provided Args pub(crate) fn from_args( tbl_name: Option<&str>, module_name: &str, - args: &[String], + args: Vec, syms: &SymbolTable, kind: VTabKind, + exprs: &Option>, ) -> Result> { let module = syms .vtab_modules @@ -544,19 +549,23 @@ impl VirtualTable { ))); } }; - let schema = module.implementation.as_ref().init_schema(args)?; + let schema = module.implementation.as_ref().init_schema(&args)?; + for arg in args { + unsafe { + arg.free(); + } + } let mut parser = Parser::new(schema.as_bytes()); parser.reset(schema.as_bytes()); - println!("Schema: {}", schema); if let ast::Cmd::Stmt(ast::Stmt::CreateTable { body, .. }) = parser.next()?.ok_or( LimboError::ParseError("Failed to parse schema from virtual table module".to_string()), )? { let columns = columns_from_create_table_body(&body)?; let vtab = Rc::new(VirtualTable { name: tbl_name.unwrap_or(module_name).to_owned(), - args: Some(args.to_vec()), implementation: module.implementation.clone(), columns, + args: exprs.clone(), }); return Ok(vtab); } @@ -565,24 +574,8 @@ impl VirtualTable { )) } - pub fn open(&self) -> VTabOpaqueCursor { - let args = if let Some(args) = &self.args { - args.iter() - .map(|e| std::ffi::CString::new(e.to_string()).unwrap().into_raw()) - .collect() - } else { - Vec::new() - }; - let cursor = - unsafe { (self.implementation.open)(args.as_slice().as_ptr(), args.len() as i32) }; - // free the CString pointers - for arg in args { - unsafe { - if !arg.is_null() { - let _ = std::ffi::CString::from_raw(arg); - } - } - } + pub fn open(&self) -> crate::Result { + let cursor = unsafe { (self.implementation.open)(self.implementation.ctx) }; VTabOpaqueCursor::new(cursor) } @@ -620,7 +613,11 @@ impl VirtualTable { pub fn column(&self, cursor: &VTabOpaqueCursor, column: usize) -> Result { let val = unsafe { (self.implementation.column)(cursor.as_ptr(), column as u32) }; - OwnedValue::from_ffi(&val) + let res = OwnedValue::from_ffi(&val)?; + unsafe { + val.free(); + } + Ok(res) } pub fn next(&self, cursor: &VTabOpaqueCursor) -> Result { @@ -632,7 +629,7 @@ impl VirtualTable { } } - pub fn update(&self, args: &[OwnedValue], rowid: Option) -> Result> { + pub fn update(&self, args: &[OwnedValue]) -> Result> { let arg_count = args.len(); let mut ext_args = Vec::with_capacity(arg_count); for i in 0..arg_count { @@ -650,7 +647,6 @@ impl VirtualTable { }?; ext_args.push(extvalue_arg); } - let rowid = rowid.unwrap_or(-1); let newrowid = 0i64; let implementation = self.implementation.as_ref(); let rc = unsafe { @@ -658,10 +654,14 @@ impl VirtualTable { implementation as *const VTabModuleImpl as *mut std::ffi::c_void, arg_count as i32, ext_args.as_ptr(), - rowid, &newrowid as *const _ as *mut i64, ) }; + for arg in ext_args { + unsafe { + arg.free(); + } + } match rc { ResultCode::OK => Ok(None), ResultCode::RowID => Ok(Some(newrowid)), diff --git a/core/translate/delete.rs b/core/translate/delete.rs index 81f8ba6ef..1e0d64a98 100644 --- a/core/translate/delete.rs +++ b/core/translate/delete.rs @@ -25,7 +25,7 @@ pub fn translate_delete( let mut program = ProgramBuilder::new(ProgramBuilderOpts { query_mode, num_cursors: 1, - approx_num_insns: estimate_num_instructions(&delete), + approx_num_insns: estimate_num_instructions(delete), approx_num_labels: 0, }); emit_program(&mut program, delete_plan, syms)?; @@ -42,13 +42,17 @@ pub fn prepare_delete_plan( Some(table) => table, None => crate::bail_corrupt_error!("Parse error: no such table: {}", tbl_name), }; - //if let Some(table) = table.virtual_table() { - // // TODO: emit VUpdate - //} - let table = table.btree().unwrap(); + let table = if let Some(table) = table.virtual_table() { + Table::Virtual(table.clone()) + } else if let Some(table) = table.btree() { + Table::BTree(table.clone()) + } else { + crate::bail_corrupt_error!("Table is neither a virtual table nor a btree table"); + }; + let name = tbl_name.name.0.as_str().to_string(); let table_references = vec![TableReference { - table: Table::BTree(table.clone()), - identifier: table.name.clone(), + table, + identifier: name, op: Operation::Scan { iter_dir: None }, join_info: None, }]; diff --git a/core/translate/emitter.rs b/core/translate/emitter.rs index 7786c90a9..8d095eb9b 100644 --- a/core/translate/emitter.rs +++ b/core/translate/emitter.rs @@ -274,7 +274,7 @@ pub fn emit_query<'a>( fn emit_program_for_delete( program: &mut ProgramBuilder, - mut plan: DeletePlan, + plan: DeletePlan, syms: &SymbolTable, ) -> Result<()> { let (mut t_ctx, init_label, start_offset) = prologue( @@ -286,6 +286,7 @@ fn emit_program_for_delete( // No rows will be read from source table loops if there is a constant false condition eg. WHERE 0 let after_main_loop_label = program.allocate_label(); + t_ctx.label_main_loop_end = Some(after_main_loop_label); if plan.contains_constant_false_condition { program.emit_insn(Insn::Goto { target_pc: after_main_loop_label, @@ -304,11 +305,16 @@ fn emit_program_for_delete( open_loop( program, &mut t_ctx, - &mut plan.table_references, + &plan.table_references, &plan.where_clause, )?; - - emit_delete_insns(program, &mut t_ctx, &plan.table_references, &plan.limit)?; + if let Some(table) = plan.table_references.first() { + if table.virtual_table().is_some() { + emit_delete_vtable_insns(program, &mut t_ctx, &plan.table_references, &plan.limit)?; + } else { + emit_delete_insns(program, &mut t_ctx, &plan.table_references, &plan.limit)?; + } + } // Clean up and close the main execution loop close_loop(program, &mut t_ctx, &plan.table_references)?; @@ -322,6 +328,77 @@ fn emit_program_for_delete( Ok(()) } +fn emit_delete_vtable_insns( + program: &mut ProgramBuilder, + t_ctx: &mut TranslateCtx, + table_references: &[TableReference], + limit: &Option, +) -> Result<()> { + let table_reference = table_references.first().unwrap(); + + let cursor_id = match &table_reference.op { + Operation::Scan { .. } => program.resolve_cursor_id(&table_reference.identifier), + Operation::Search(search) => match search { + Search::RowidEq { .. } | Search::RowidSearch { .. } => { + program.resolve_cursor_id(&table_reference.identifier) + } + Search::IndexSearch { index, .. } => program.resolve_cursor_id(&index.name), + }, + _ => return Ok(()), + }; + + let rowid_reg = program.alloc_register(); + program.emit_insn(Insn::RowId { + cursor_id, + dest: rowid_reg, + }); + // if we have a limit, decrement and check zero + if let Some(limit) = limit { + let limit_reg = program.alloc_register(); + program.emit_insn(Insn::Integer { + value: *limit as i64, + dest: limit_reg, + }); + program.mark_last_insn_constant(); + + program.emit_insn(Insn::DecrJumpZero { + reg: limit_reg, + target_pc: t_ctx.label_main_loop_end.unwrap(), + }); + } + + // we want old_rowid= rowid_reg, new_rowid= NULL, so we pass 2 arguments to VUpdate + // we need a second register for the new rowid = NULL + let new_rowid_reg = program.alloc_register(); + + program.emit_insn(Insn::Null { + dest: new_rowid_reg, + dest_end: None, + }); + + // we'll do VUpdate with arg_count=2: + // argv[0] => old_rowid = rowid_reg + // argv[1] => new_rowid = new_rowid_reg (NULL) + + let Some(virtual_table) = table_reference.virtual_table() else { + return Err(crate::LimboError::ParseError( + "Table is not a virtual table".to_string(), + )); + }; + let conflict_action = 0u16; + let start_reg = rowid_reg; + + program.emit_insn(Insn::VUpdate { + cursor_id, + arg_count: 2, + start_reg, + vtab_ptr: virtual_table.implementation.as_ref().ctx as usize, + conflict_action, + }); + + Ok(()) +} + fn emit_delete_insns( program: &mut ProgramBuilder, t_ctx: &mut TranslateCtx, diff --git a/core/translate/insert.rs b/core/translate/insert.rs index 5f933e93a..f77171621 100644 --- a/core/translate/insert.rs +++ b/core/translate/insert.rs @@ -62,7 +62,7 @@ pub fn translate_insert( body, on_conflict, &resolver, - ); + )?; return Ok(program); } let init_label = program.allocate_label(); @@ -474,11 +474,17 @@ fn translate_virtual_table_insert( for (i, expr) in values[0].iter().enumerate() { translate_expr(program, None, expr, value_registers_start + i, resolver)?; } + /* * + * Inserts for virtual tables are done in a single step. The rowid is not provided by the user, but is generated by the + * vtable implementation. + * argv[0] = current_rowid (NULL for insert) + * argv[1] = insert_rowid (NULL for insert) + * argv[2..] = column values + * */ - let start_reg = program.alloc_registers(column_mappings.len() + 3); - let rowid_reg = start_reg; // argv[0] = rowid - let insert_rowid_reg = start_reg + 1; // argv[1] = insert_rowid - let data_start_reg = start_reg + 2; // argv[2..] = column values + let rowid_reg = program.alloc_registers(column_mappings.len() + 3); + let insert_rowid_reg = rowid_reg + 1; // argv[1] = insert_rowid + let data_start_reg = rowid_reg + 2; // argv[2..] = column values program.emit_insn(Insn::Null { dest: rowid_reg, @@ -515,7 +521,7 @@ fn translate_virtual_table_insert( program.emit_insn(Insn::VUpdate { cursor_id, arg_count: column_mappings.len() + 2, - start_reg, + start_reg: rowid_reg, vtab_ptr: virtual_table.implementation.as_ref().ctx as usize, conflict_action, }); diff --git a/core/translate/main_loop.rs b/core/translate/main_loop.rs index a0e4a13c4..953e39242 100644 --- a/core/translate/main_loop.rs +++ b/core/translate/main_loop.rs @@ -110,6 +110,10 @@ pub fn init_loop( program.emit_insn(Insn::VOpenAsync { cursor_id }); program.emit_insn(Insn::VOpenAwait {}); } + (OperationMode::DELETE, Table::Virtual(_)) => { + program.emit_insn(Insn::VOpenAsync { cursor_id }); + program.emit_insn(Insn::VOpenAwait {}); + } _ => { unimplemented!() } @@ -286,44 +290,23 @@ pub fn open_loop( }, ), Table::Virtual(ref table) => { - let args = if let Some(args) = table.args.as_ref() { - args - } else { - &vec![] - }; - let start_reg = program.alloc_registers(args.len()); + let start_reg = program + .alloc_registers(table.args.as_ref().map(|a| a.len()).unwrap_or(0)); let mut cur_reg = start_reg; - - for arg_str in args { + let args = match table.args.as_ref() { + Some(args) => args, + None => &vec![], + }; + for arg in args { let reg = cur_reg; cur_reg += 1; - - if let Ok(i) = arg_str.parse::() { - program.emit_insn(Insn::Integer { - value: i, - dest: reg, - }); - } else if let Ok(f) = arg_str.parse::() { - program.emit_insn(Insn::Real { - value: f, - dest: reg, - }); - } else if arg_str.starts_with('"') && arg_str.ends_with('"') { - program.emit_insn(Insn::String8 { - value: arg_str.trim_matches('"').to_string(), - dest: reg, - }); - } else { - program.emit_insn(Insn::String8 { - value: arg_str.clone(), - dest: reg, - }); - } + let _ = + translate_expr(program, Some(tables), arg, reg, &t_ctx.resolver)?; } program.emit_insn(Insn::VFilter { cursor_id, pc_if_empty: loop_end, - arg_count: args.len(), + arg_count: table.args.as_ref().map_or(0, |args| args.len()), args_reg: start_reg, }); } @@ -697,9 +680,9 @@ fn emit_loop_source( ); let offset_jump_to = t_ctx .labels_main_loop - .get(0) + .first() .map(|l| l.next) - .or_else(|| t_ctx.label_main_loop_end); + .or(t_ctx.label_main_loop_end); emit_select_result( program, t_ctx, diff --git a/core/translate/mod.rs b/core/translate/mod.rs index 7df6258ec..f485fb7ee 100644 --- a/core/translate/mod.rs +++ b/core/translate/mod.rs @@ -572,8 +572,8 @@ fn create_vtable_body_to_str(vtab: &CreateVirtualTable) -> String { format!( "CREATE VIRTUAL TABLE {} {} USING {}{}", vtab.tbl_name.name.0, - vtab.module_name.0, if_not_exists, + vtab.module_name.0, if args.is_empty() { String::new() } else { diff --git a/core/translate/planner.rs b/core/translate/planner.rs index c9c266a13..440744426 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -9,7 +9,7 @@ use super::{ use crate::{ function::Func, schema::{Schema, Table}, - util::{exprs_are_equivalent, normalize_ident}, + util::{exprs_are_equivalent, normalize_ident, vtable_args}, vdbe::BranchOffset, Result, }; @@ -303,16 +303,25 @@ fn parse_from_clause_table<'a>( return Ok(()); }; // Check if our top level schema has this table. - if let Some(table) = schema.get_btree_table(&normalized_qualified_name) { + if let Some(table) = schema.get_table(&normalized_qualified_name) { let alias = maybe_alias .map(|a| match a { ast::As::As(id) => id, ast::As::Elided(id) => id, }) .map(|a| a.0); + let tbl_ref = if let Table::Virtual(tbl) = table.as_ref() { + Table::Virtual(tbl.clone()) + } else if let Table::BTree(table) = table.as_ref() { + Table::BTree(table.clone()) + } else { + return Err(crate::LimboError::InvalidArgument( + "Table type not supported".to_string(), + )); + }; scope.tables.push(TableReference { op: Operation::Scan { iter_dir: None }, - table: Table::BTree(table.clone()), + table: tbl_ref, identifier: alias.unwrap_or(normalized_qualified_name), join_info: None, }); @@ -367,17 +376,16 @@ fn parse_from_clause_table<'a>( .push(TableReference::new_subquery(identifier, subplan, None)); Ok(()) } - ast::SelectTable::TableCall(qualified_name, maybe_args, maybe_alias) => { + ast::SelectTable::TableCall(qualified_name, ref maybe_args, maybe_alias) => { let normalized_name = &normalize_ident(qualified_name.name.0.as_str()); + let args = vtable_args(maybe_args.as_ref().unwrap_or(&vec![]).as_slice()); let vtab = crate::VirtualTable::from_args( None, normalized_name, - &maybe_args - .as_ref() - .map(|a| a.iter().map(|s| s.to_string()).collect::>()) - .unwrap_or_default(), + args, syms, limbo_ext::VTabKind::TableValuedFunction, + maybe_args, )?; let alias = maybe_alias .as_ref() @@ -610,7 +618,7 @@ fn parse_join<'a>( constraint, } = join; - parse_from_clause_table(schema, table, scope, &syms)?; + parse_from_clause_table(schema, table, scope, syms)?; let (outer, natural) = match join_operator { ast::JoinOperator::TypedJoin(Some(join_type)) => { diff --git a/core/util.rs b/core/util.rs index ddaf931a9..17c12fbb3 100644 --- a/core/util.rs +++ b/core/util.rs @@ -388,6 +388,35 @@ pub fn columns_from_create_table_body(body: &ast::CreateTableBody) -> crate::Res .collect::>()) } +// for TVF's we need these at planning time so we cannot emit translate_expr +pub fn vtable_args(args: &[ast::Expr]) -> Vec { + let mut vtable_args = Vec::new(); + for arg in args { + match arg { + Expr::Literal(lit) => match lit { + Literal::Numeric(i) => { + if i.contains('.') { + vtable_args.push(limbo_ext::Value::from_float(i.parse().unwrap())); + } else { + vtable_args.push(limbo_ext::Value::from_integer(i.parse().unwrap())); + } + } + Literal::String(s) => { + vtable_args.push(limbo_ext::Value::from_text(s.clone())); + } + Literal::Blob(b) => { + vtable_args.push(limbo_ext::Value::from_blob(b.as_bytes().into())); + } + _ => { + vtable_args.push(limbo_ext::Value::null()); + } + }, + _ => vtable_args.push(limbo_ext::Value::null()), + } + } + vtable_args +} + #[cfg(test)] pub mod tests { use super::*; diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index a881c0d56..f4a2725c2 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -304,14 +304,19 @@ impl Bitfield { } } -pub struct VTabOpaqueCursor(*mut c_void); +pub struct VTabOpaqueCursor(*const c_void); impl VTabOpaqueCursor { - pub fn new(cursor: *mut c_void) -> Self { - Self(cursor) + pub fn new(cursor: *const c_void) -> Result { + if cursor.is_null() { + return Err(LimboError::InternalError( + "VTabOpaqueCursor: cursor is null".into(), + )); + } + Ok(Self(cursor)) } - pub fn as_ptr(&self) -> *mut c_void { + pub fn as_ptr(&self) -> *const c_void { self.0 } } @@ -866,7 +871,7 @@ impl Program { let CursorType::VirtualTable(virtual_table) = cursor_type else { panic!("VOpenAsync on non-virtual table cursor"); }; - let cursor = virtual_table.open(); + let cursor = virtual_table.open()?; state .cursors .borrow_mut() @@ -882,7 +887,7 @@ impl Program { let table_name = state.registers[*table_name].to_string(); let args = if let Some(args_reg) = args_reg { if let OwnedValue::Record(rec) = &state.registers[*args_reg] { - rec.get_values().iter().map(|v| v.to_string()).collect() + rec.get_values().iter().map(|v| v.to_ffi()).collect() } else { return Err(LimboError::InternalError( "VCreate: args_reg is not a record".to_string(), @@ -899,9 +904,10 @@ impl Program { let table = crate::VirtualTable::from_args( Some(&table_name), &module_name, - &args, + args, &conn.db.syms.borrow(), limbo_ext::VTabKind::VirtualTable, + &None, )?; { conn.db @@ -971,7 +977,6 @@ impl Program { .to_string(), )); } - let mut argv = Vec::with_capacity(*arg_count); for i in 0..*arg_count { if let Some(value) = state.registers.get(*start_reg + i) { @@ -983,18 +988,10 @@ impl Program { ))); } } - - let current_rowid = match argv.first() { - Some(OwnedValue::Integer(rowid)) => Some(*rowid), - _ => None, - }; - let insert_rowid = match argv.get(1) { - Some(OwnedValue::Integer(rowid)) => Some(*rowid), - _ => None, - }; - - let result = virtual_table.update(&argv, insert_rowid); - + // argv[0] = current_rowid (for DELETE if applicable) + // argv[1] = insert_rowid (for INSERT if applicable) + // argv[2..] = column values + let result = virtual_table.update(&argv); match result { Ok(Some(new_rowid)) => { if *conflict_action == 5 { @@ -1005,9 +1002,11 @@ impl Program { state.pc += 1; } Ok(None) => { + // no-op or successful update without rowid return state.pc += 1; } Err(e) => { + // virtual table update failed return Err(LimboError::ExtensionError(format!( "Virtual table update failed: {}", e @@ -1355,11 +1354,30 @@ impl Program { } } - let cursor = get_cursor_as_table_mut(&mut cursors, *cursor_id); - if let Some(ref rowid) = cursor.rowid()? { - state.registers[*dest] = OwnedValue::Integer(*rowid as i64); + if let Some(Cursor::Table(btree_cursor)) = cursors.get_mut(*cursor_id).unwrap() + { + if let Some(ref rowid) = btree_cursor.rowid()? { + state.registers[*dest] = OwnedValue::Integer(*rowid as i64); + } else { + state.registers[*dest] = OwnedValue::Null; + } + } else if let Some(Cursor::Virtual(virtual_cursor)) = + cursors.get_mut(*cursor_id).unwrap() + { + let (_, cursor_type) = self.cursor_ref.get(*cursor_id).unwrap(); + let CursorType::VirtualTable(virtual_table) = cursor_type else { + panic!("VUpdate on non-virtual table cursor"); + }; + let rowid = virtual_table.rowid(virtual_cursor); + if rowid != 0 { + state.registers[*dest] = OwnedValue::Integer(rowid); + } else { + state.registers[*dest] = OwnedValue::Null; + } } else { - state.registers[*dest] = OwnedValue::Null; + return Err(LimboError::InternalError( + "RowId: cursor is not a table or virtual cursor".to_string(), + )); } state.pc += 1; } From 813e7e57d85284423665bcd94b38e814ee1ddcf0 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Sun, 16 Feb 2025 20:22:16 -0500 Subject: [PATCH 06/13] Add simple demo vtable extension `kvstore` --- extensions/kvstore/Cargo.toml | 1 + extensions/kvstore/src/lib.rs | 139 ++++++++++++++++++++++------------ testing/extensions.py | 42 ++++++++++ 3 files changed, 132 insertions(+), 50 deletions(-) diff --git a/extensions/kvstore/Cargo.toml b/extensions/kvstore/Cargo.toml index 81c1f804d..cac010bb6 100644 --- a/extensions/kvstore/Cargo.toml +++ b/extensions/kvstore/Cargo.toml @@ -13,6 +13,7 @@ crate-type = ["cdylib", "lib"] static= [ "limbo_ext/static" ] [dependencies] +lazy_static = "1.5.0" limbo_ext = { workspace = true, features = ["static"] } [target.'cfg(not(target_family = "wasm"))'.dependencies] diff --git a/extensions/kvstore/src/lib.rs b/extensions/kvstore/src/lib.rs index 184467aca..f18f7851b 100644 --- a/extensions/kvstore/src/lib.rs +++ b/extensions/kvstore/src/lib.rs @@ -1,103 +1,142 @@ +use lazy_static::lazy_static; use limbo_ext::{ register_extension, ResultCode, VTabCursor, VTabKind, VTabModule, VTabModuleDerive, Value, }; -use std::collections::HashMap; +use std::collections::BTreeMap; +use std::sync::Mutex; + +lazy_static! { + static ref GLOBAL_STORE: Mutex> = Mutex::new(BTreeMap::new()); +} register_extension! { vtabs: { KVStoreVTab }, } #[derive(VTabModuleDerive, Default)] -pub struct KVStoreVTab { - store: HashMap, -} +pub struct KVStoreVTab; +/// The cursor holds a snapshot of (rowid, key, value) in memory. pub struct KVStoreCursor { - keys: Vec, - values: Vec, + rows: Vec<(i64, String, String)>, index: usize, } +/// Implementing the VTabModule trait for KVStoreVTab impl VTabModule for KVStoreVTab { type VCursor = KVStoreCursor; const VTAB_KIND: VTabKind = VTabKind::VirtualTable; const NAME: &'static str = "kv_store"; type Error = String; - fn create_schema(_args: &[String]) -> String { + fn create_schema(_args: &[Value]) -> String { "CREATE TABLE x (key TEXT PRIMARY KEY, value TEXT);".to_string() } - fn open(_args: &[String]) -> Result { + fn open(&self) -> Result { Ok(KVStoreCursor { - keys: Vec::new(), - values: Vec::new(), + rows: Vec::new(), index: 0, }) } - fn filter(cursor: &mut Self::VCursor, _arg_count: i32, _args: &[Value]) -> ResultCode { + fn filter(cursor: &mut Self::VCursor, _args: &[Value]) -> ResultCode { + let store = GLOBAL_STORE.lock().unwrap(); + cursor.rows = store + .iter() + .map(|(&rowid, (ref key, ref val))| (rowid, key.clone(), val.clone())) + .collect(); + cursor.rows.sort_by_key(|(rowid, _, _)| *rowid); cursor.index = 0; ResultCode::OK } - fn column(cursor: &Self::VCursor, idx: u32) -> Result { - match idx { - 0 => Ok(Value::from_text(cursor.keys[cursor.index].clone())), - 1 => Ok(Value::from_text(cursor.values[cursor.index].clone())), - _ => Err("Invalid column".into()), + fn insert(&mut self, values: &[Value]) -> Result { + let key = values + .first() + .and_then(|v| v.to_text()) + .ok_or("Missing key")? + .to_string(); + let val = values + .get(1) + .and_then(|v| v.to_text()) + .ok_or("Missing value")? + .to_string(); + let rowid = hash_key(&key); + { + let mut store = GLOBAL_STORE.lock().unwrap(); + store.insert(rowid, (key, val)); } + Ok(rowid) + } + + fn delete(&mut self, rowid: i64) -> Result<(), Self::Error> { + let mut store = GLOBAL_STORE.lock().unwrap(); + store.remove(&rowid); + Ok(()) + } + + fn update(&mut self, rowid: i64, values: &[Value]) -> Result<(), Self::Error> { + { + let mut store = GLOBAL_STORE.lock().unwrap(); + store.remove(&rowid); + } + let _ = self.insert(values)?; + Ok(()) + } + fn eof(cursor: &Self::VCursor) -> bool { + cursor.index >= cursor.rows.len() } fn next(cursor: &mut Self::VCursor) -> ResultCode { cursor.index += 1; + if cursor.index >= cursor.rows.len() { + return ResultCode::EOF; + } ResultCode::OK } - fn eof(cursor: &Self::VCursor) -> bool { - cursor.index >= cursor.keys.len() - } - - fn update(&mut self, args: &[Value], rowid: Option) -> Result, Self::Error> { - match args.len() { - 1 => { - let key = args[0].to_text().ok_or("Invalid key")?; - // Handle DELETE - self.store.remove(key); - Ok(None) - } - 2 => { - let key = args[0].to_text().ok_or("Invalid key")?; - let value = args[1].to_text().ok_or("Invalid value")?; - // Handle INSERT / UPDATE - self.store.insert(key.to_string(), value.to_string()); - Ok(Some(rowid.unwrap_or(0))) - } - _ => { - println!("args: {:?}", args); - Err("Invalid arguments for update".into()) - } + fn column(cursor: &Self::VCursor, idx: u32) -> Result { + if cursor.index >= cursor.rows.len() { + return Err("cursor out of range".into()); + } + let (_, ref key, ref val) = cursor.rows[cursor.index]; + match idx { + 0 => Ok(Value::from_text(key.clone())), // key + 1 => Ok(Value::from_text(val.clone())), // value + _ => Err("Invalid column".into()), } } } +fn hash_key(key: &str) -> i64 { + use std::hash::{Hash, Hasher}; + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + key.hash(&mut hasher); + hasher.finish() as i64 +} + impl VTabCursor for KVStoreCursor { type Error = String; + fn rowid(&self) -> i64 { - self.index as i64 - } - fn column(&self, idx: u32) -> Result { - match idx { - 0 => Ok(Value::from_text(self.keys[self.index].clone())), - 1 => Ok(Value::from_text(self.values[self.index].clone())), - _ => Err("Invalid column".into()), + if self.index < self.rows.len() { + self.rows[self.index].0 + } else { + println!("rowid: -1"); + -1 } } - fn eof(&self) -> bool { - self.index >= self.keys.len() + + fn column(&self, idx: u32) -> Result { + ::column(self, idx) } + + fn eof(&self) -> bool { + ::eof(self) + } + fn next(&mut self) -> ResultCode { - self.index += 1; - ResultCode::OK + ::next(self) } } diff --git a/testing/extensions.py b/testing/extensions.py index 240abacb7..5910fe754 100755 --- a/testing/extensions.py +++ b/testing/extensions.py @@ -468,6 +468,47 @@ def test_series(pipe): ) +def test_kv(pipe): + ext_path = "./target/debug/liblimbo_kv" + run_test( + pipe, + "create virtual table t using kv_store;", + lambda res: "Virtual table module not found: kv_store" in res, + ) + run_test(pipe, f".load {ext_path}", returns_null) + run_test( + pipe, + "create virtual table t using kv_store;", + returns_null, + "can create kv_store vtable", + ) + run_test( + pipe, + "insert into t values ('hello', 'world');", + returns_null, + "can insert into kv_store vtable", + ) + run_test( + pipe, + "select value from t where key = 'hello';", + lambda res: "world" == res, + "can select from kv_store", + ) + run_test( + pipe, + "delete from t where key = 'hello';", + returns_null, + "can delete from kv_store", + ) + run_test(pipe, "insert into t values ('other', 'value');", returns_null) + run_test( + pipe, + "select value from t where key = 'hello';", + lambda res: "" == res, + "proper data is deleted", + ) + + def main(): pipe = init_limbo() try: @@ -476,6 +517,7 @@ def main(): test_aggregates(pipe) test_crypto(pipe) test_series(pipe) + test_kv(pipe) except Exception as e: print(f"Test FAILED: {e}") From 0547d397b1e23063615d0e9c393a85ce0d387b17 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Sun, 16 Feb 2025 20:23:19 -0500 Subject: [PATCH 07/13] Update extension api for vtable interface --- extensions/core/src/lib.rs | 48 +++++++------- extensions/core/src/types.rs | 1 + extensions/series/src/lib.rs | 23 +++---- macros/src/lib.rs | 121 +++++++++++++++++++++++------------ 4 files changed, 115 insertions(+), 78 deletions(-) diff --git a/extensions/core/src/lib.rs b/extensions/core/src/lib.rs index 2b123a3a2..de82705b0 100644 --- a/extensions/core/src/lib.rs +++ b/extensions/core/src/lib.rs @@ -61,7 +61,7 @@ pub trait AggFunc { #[repr(C)] #[derive(Clone, Debug)] pub struct VTabModuleImpl { - pub ctx: *mut c_void, + pub ctx: *const c_void, pub name: *const c_char, pub create_schema: VtabFnCreateSchema, pub open: VtabFnOpen, @@ -70,18 +70,12 @@ pub struct VTabModuleImpl { pub next: VtabFnNext, pub eof: VtabFnEof, pub update: VtabFnUpdate, + pub rowid: VtabRowIDFn, } impl VTabModuleImpl { - pub fn init_schema(&self, args: &[String]) -> ExtResult { - let c_args = args - .iter() - .map(|s| std::ffi::CString::new(s.as_bytes()).unwrap().into_raw()) - .collect::>(); - let schema = unsafe { (self.create_schema)(c_args.as_ptr(), c_args.len() as i32) }; - c_args.into_iter().for_each(|s| unsafe { - let _ = std::ffi::CString::from_raw(s); - }); + pub fn init_schema(&self, args: &[Value]) -> ExtResult { + let schema = unsafe { (self.create_schema)(args.as_ptr(), args.len() as i32) }; if schema.is_null() { return Err(ResultCode::InvalidArgs); } @@ -90,25 +84,25 @@ impl VTabModuleImpl { } } -pub type VtabFnCreateSchema = - unsafe extern "C" fn(args: *const *mut c_char, argc: i32) -> *mut c_char; +pub type VtabFnCreateSchema = unsafe extern "C" fn(args: *const Value, argc: i32) -> *mut c_char; -pub type VtabFnOpen = unsafe extern "C" fn(args: *const *mut c_char, argc: i32) -> *mut c_void; +pub type VtabFnOpen = unsafe extern "C" fn(*const c_void) -> *const c_void; pub type VtabFnFilter = - unsafe extern "C" fn(cursor: *mut c_void, argc: i32, argv: *const Value) -> ResultCode; + unsafe extern "C" fn(cursor: *const c_void, argc: i32, argv: *const Value) -> ResultCode; -pub type VtabFnColumn = unsafe extern "C" fn(cursor: *mut c_void, idx: u32) -> Value; +pub type VtabFnColumn = unsafe extern "C" fn(cursor: *const c_void, idx: u32) -> Value; -pub type VtabFnNext = unsafe extern "C" fn(cursor: *mut c_void) -> ResultCode; +pub type VtabFnNext = unsafe extern "C" fn(cursor: *const c_void) -> ResultCode; -pub type VtabFnEof = unsafe extern "C" fn(cursor: *mut c_void) -> bool; +pub type VtabFnEof = unsafe extern "C" fn(cursor: *const c_void) -> bool; + +pub type VtabRowIDFn = unsafe extern "C" fn(cursor: *const c_void) -> i64; pub type VtabFnUpdate = unsafe extern "C" fn( - vtab: *mut c_void, + vtab: *const c_void, argc: i32, argv: *const Value, - rowid: i64, p_out_rowid: *mut i64, ) -> ResultCode; @@ -125,13 +119,21 @@ pub trait VTabModule: 'static { const NAME: &'static str; type Error: std::fmt::Display; - fn create_schema(args: &[String]) -> String; - fn open(args: &[String]) -> Result; - fn filter(cursor: &mut Self::VCursor, arg_count: i32, args: &[Value]) -> ResultCode; + fn create_schema(args: &[Value]) -> String; + fn open(&self) -> Result; + fn filter(cursor: &mut Self::VCursor, args: &[Value]) -> ResultCode; fn column(cursor: &Self::VCursor, idx: u32) -> Result; fn next(cursor: &mut Self::VCursor) -> ResultCode; fn eof(cursor: &Self::VCursor) -> bool; - fn update(&mut self, args: &[Value], rowid: Option) -> Result, Self::Error>; + fn update(&mut self, _rowid: i64, _args: &[Value]) -> Result<(), Self::Error> { + Ok(()) + } + fn insert(&mut self, _args: &[Value]) -> Result { + Ok(0) + } + fn delete(&mut self, _rowid: i64) -> Result<(), Self::Error> { + Ok(()) + } } pub trait VTabCursor: Sized { diff --git a/extensions/core/src/types.rs b/extensions/core/src/types.rs index 55d39ac42..f08fe099e 100644 --- a/extensions/core/src/types.rs +++ b/extensions/core/src/types.rs @@ -407,6 +407,7 @@ impl Value { } } + /// Extension authors should __not__ use this function. /// # Safety /// consumes the value while freeing the underlying memory with null check. /// however this does assume that the type was properly constructed with diff --git a/extensions/series/src/lib.rs b/extensions/series/src/lib.rs index 161bfe886..43028eed5 100644 --- a/extensions/series/src/lib.rs +++ b/extensions/series/src/lib.rs @@ -25,7 +25,7 @@ impl VTabModule for GenerateSeriesVTab { const NAME: &'static str = "generate_series"; const VTAB_KIND: VTabKind = VTabKind::TableValuedFunction; - fn create_schema(_args: &[String]) -> String { + fn create_schema(_args: &[Value]) -> String { // Create table schema "CREATE TABLE generate_series( value INTEGER, @@ -36,7 +36,7 @@ impl VTabModule for GenerateSeriesVTab { .into() } - fn open(_args: &[String]) -> Result { + fn open(&self) -> Result { Ok(GenerateSeriesCursor { start: 0, stop: 0, @@ -45,9 +45,9 @@ impl VTabModule for GenerateSeriesVTab { }) } - fn filter(cursor: &mut Self::VCursor, arg_count: i32, args: &[Value]) -> ResultCode { + fn filter(cursor: &mut Self::VCursor, args: &[Value]) -> ResultCode { // args are the start, stop, and step - if arg_count == 0 || arg_count > 3 { + if args.is_empty() || args.len() > 3 { return ResultCode::InvalidArgs; } let start = try_option!(args[0].to_integer(), ResultCode::InvalidArgs); @@ -91,10 +91,6 @@ impl VTabModule for GenerateSeriesVTab { fn eof(cursor: &Self::VCursor) -> bool { cursor.eof() } - - fn update(&mut self, _args: &[Value], _rowid: Option) -> Result, Self::Error> { - Ok(None) - } } /// The cursor for iterating over the generated sequence @@ -233,7 +229,8 @@ mod tests { } // Helper function to collect all values from a cursor, returns Result with error code fn collect_series(series: Series) -> Result, ResultCode> { - let mut cursor = GenerateSeriesVTab::open()?; + let tbl = GenerateSeriesVTab; + let mut cursor = tbl.open()?; // Create args array for filter let args = vec![ @@ -243,7 +240,7 @@ mod tests { ]; // Initialize cursor through filter - match GenerateSeriesVTab::filter(&mut cursor, 3, &args) { + match GenerateSeriesVTab::filter(&mut cursor, &args) { ResultCode::OK => (), ResultCode::EOF => return Ok(vec![]), err => return Err(err), @@ -549,8 +546,8 @@ mod tests { let start = series.start; let stop = series.stop; let step = series.step; - - let mut cursor = GenerateSeriesVTab::open().unwrap(); + let tbl = GenerateSeriesVTab::default(); + let mut cursor = tbl.open().unwrap(); let args = vec![ Value::from_integer(start), @@ -559,7 +556,7 @@ mod tests { ]; // Initialize cursor through filter - GenerateSeriesVTab::filter(&mut cursor, 3, &args); + GenerateSeriesVTab::filter(&mut cursor, &args); let mut rowids = vec![]; while !GenerateSeriesVTab::eof(&cursor) { diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 1f221b03f..e85b144e5 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -383,10 +383,20 @@ pub fn derive_agg_func(input: TokenStream) -> TokenStream { /// cursor.index >= cursor.rows.len() /// } /// -/// /// Update the row with the provided values, return the new rowid if provided -/// fn update(&mut self, args: &[Value], rowid: Option) -> Result, Self::Error> { +/// /// **Optional** methods for non-readonly tables: +/// +/// /// Update the row with the provided values, return the new rowid +/// fn update(&mut self, rowid: i64, args: &[Value]) -> Result, Self::Error> { /// Ok(None)// return Ok(None) for read-only /// } +/// /// Insert a new row with the provided values, return the new rowid +/// fn insert(&mut self, args: &[Value]) -> Result<(), Self::Error> { +/// Ok(()) // +/// } +/// /// Delete the row with the provided rowid +/// fn delete(&mut self, rowid: i64) -> Result<(), Self::Error> { +/// Ok(()) +/// } /// /// #[derive(Debug)] /// struct CsvCursor { @@ -431,44 +441,40 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { let next_fn_name = format_ident!("next_{}", struct_name); let eof_fn_name = format_ident!("eof_{}", struct_name); let update_fn_name = format_ident!("update_{}", struct_name); + let rowid_fn_name = format_ident!("rowid_{}", struct_name); let expanded = quote! { impl #struct_name { #[no_mangle] unsafe extern "C" fn #create_schema_fn_name( - argv: *const *mut ::std::ffi::c_char, argc: i32 + argv: *const ::limbo_ext::Value, argc: i32 ) -> *mut ::std::ffi::c_char { let args = if argv.is_null() { - Vec::new() + &Vec::new() } else { - ::std::slice::from_raw_parts(argv, argc as usize).iter().map(|s| { - ::std::ffi::CStr::from_ptr(*s).to_string_lossy().to_string() - }).collect::>() + ::std::slice::from_raw_parts(argv, argc as usize) }; let sql = <#struct_name as ::limbo_ext::VTabModule>::create_schema(&args); ::std::ffi::CString::new(sql).unwrap().into_raw() } #[no_mangle] - unsafe extern "C" fn #open_fn_name(argv: *const *mut ::std::ffi::c_char, argc: i32) -> *mut ::std::ffi::c_void { - let args = if argv.is_null() { - Vec::new() + unsafe extern "C" fn #open_fn_name(ctx: *const ::std::ffi::c_void) -> *const ::std::ffi::c_void { + if ctx.is_null() { + return ::std::ptr::null(); + } + let ctx = ctx as *const #struct_name; + let ctx: &#struct_name = &*ctx; + if let Ok(cursor) = <#struct_name as ::limbo_ext::VTabModule>::open(ctx) { + return ::std::boxed::Box::into_raw(::std::boxed::Box::new(cursor)) as *const ::std::ffi::c_void; } else { - ::std::slice::from_raw_parts(argv, argc as usize).iter().map(|s| { - ::std::ffi::CStr::from_ptr(*s).to_string_lossy().to_string() - }).collect::>() - }; - let schema = <#struct_name as ::limbo_ext::VTabModule>::create_schema(&args); - if let Ok(cursor) = <#struct_name as ::limbo_ext::VTabModule>::open(&args) { - return ::std::boxed::Box::into_raw(::std::boxed::Box::new(cursor)) as *mut ::std::ffi::c_void; - } else { - return ::std::ptr::null_mut(); + return ::std::ptr::null(); } } #[no_mangle] unsafe extern "C" fn #filter_fn_name( - cursor: *mut ::std::ffi::c_void, + cursor: *const ::std::ffi::c_void, argc: i32, argv: *const ::limbo_ext::Value, ) -> ::limbo_ext::ResultCode { @@ -477,12 +483,12 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { } let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; let args = ::std::slice::from_raw_parts(argv, argc as usize); - <#struct_name as ::limbo_ext::VTabModule>::filter(cursor, argc, args) + <#struct_name as ::limbo_ext::VTabModule>::filter(cursor, args) } #[no_mangle] unsafe extern "C" fn #column_fn_name( - cursor: *mut ::std::ffi::c_void, + cursor: *const ::std::ffi::c_void, idx: u32, ) -> ::limbo_ext::Value { if cursor.is_null() { @@ -497,56 +503,86 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { #[no_mangle] unsafe extern "C" fn #next_fn_name( - cursor: *mut ::std::ffi::c_void, + cursor: *const ::std::ffi::c_void, ) -> ::limbo_ext::ResultCode { if cursor.is_null() { return ::limbo_ext::ResultCode::Error; } - let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; + let cursor = &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor); <#struct_name as ::limbo_ext::VTabModule>::next(cursor) } #[no_mangle] unsafe extern "C" fn #eof_fn_name( - cursor: *mut ::std::ffi::c_void, + cursor: *const ::std::ffi::c_void, ) -> bool { if cursor.is_null() { return true; } - let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; + let cursor = &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor); <#struct_name as ::limbo_ext::VTabModule>::eof(cursor) } #[no_mangle] unsafe extern "C" fn #update_fn_name( - vtab: *mut ::std::ffi::c_void, + vtab: *const ::std::ffi::c_void, argc: i32, argv: *const ::limbo_ext::Value, - rowid: i64, p_out_rowid: *mut i64, - ) -> ::limbo_ext::ResultCode { + ) -> ::limbo_ext::ResultCode { if vtab.is_null() { return ::limbo_ext::ResultCode::Error; } + let vtab = &mut *(vtab as *mut #struct_name); let args = ::std::slice::from_raw_parts(argv, argc as usize); - let rowid = if rowid == -1 { - None - } else { - Some(rowid as i64) + + let old_rowid = match args.get(0).map(|v| v.value_type()) { + Some(::limbo_ext::ValueType::Integer) => args.get(0).unwrap().to_integer(), + _ => None, }; - let result = <#struct_name as ::limbo_ext::VTabModule>::update(vtab, args, rowid); - match result { - Ok(Some(rowid)) => { - // set the output rowid if it was provided - *p_out_rowid = rowid; - ::limbo_ext::ResultCode::RowID + let new_rowid = match args.get(1).map(|v| v.value_type()) { + Some(::limbo_ext::ValueType::Integer) => args.get(1).unwrap().to_integer(), + _ => None, + }; + let columns = &args[2..]; + match (old_rowid, new_rowid) { + // DELETE: old_rowid provided, no new_rowid + (Some(old), None) => { + if <#struct_name as VTabModule>::delete(vtab, old).is_err() { + return ::limbo_ext::ResultCode::Error; + } + return ::limbo_ext::ResultCode::OK; + } + // UPDATE: old_rowid provided and new_rowid may exist + (Some(old), Some(new)) => { + if <#struct_name as VTabModule>::update(vtab, old, &columns).is_err() { + return ::limbo_ext::ResultCode::Error; + } + return ::limbo_ext::ResultCode::OK; + } + // INSERT: no old_rowid (old_rowid = None) + (None, _) => { + if let Ok(rowid) = <#struct_name as VTabModule>::insert(vtab, &columns) { + if !p_out_rowid.is_null() { + *p_out_rowid = rowid; + return ::limbo_ext::ResultCode::RowID; + } + return ::limbo_ext::ResultCode::OK; + } } - Ok(None) => ::limbo_ext::ResultCode::OK, - Err(_) => ::limbo_ext::ResultCode::Error, } + return ::limbo_ext::ResultCode::Error; } + #[no_mangle] + pub unsafe extern "C" fn #rowid_fn_name(ctx: *const ::std::ffi::c_void) -> i64 { + if ctx.is_null() { + return -1; + } + let cursor = &*(ctx as *const <#struct_name as ::limbo_ext::VTabModule>::VCursor); + <<#struct_name as ::limbo_ext::VTabModule>::VCursor as ::limbo_ext::VTabCursor>::rowid(cursor) + } #[no_mangle] pub unsafe extern "C" fn #register_fn_name( @@ -560,7 +596,7 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { let name_c = ::std::ffi::CString::new(name).unwrap().into_raw() as *const ::std::ffi::c_char; let table_instance = ::std::boxed::Box::into_raw(::std::boxed::Box::new(#struct_name::default())); let module = ::limbo_ext::VTabModuleImpl { - ctx: table_instance as *mut ::std::ffi::c_void, + ctx: table_instance as *const ::std::ffi::c_void, name: name_c, create_schema: Self::#create_schema_fn_name, open: Self::#open_fn_name, @@ -569,6 +605,7 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { next: Self::#next_fn_name, eof: Self::#eof_fn_name, update: Self::#update_fn_name, + rowid: Self::#rowid_fn_name, }; (api.register_module)(api.ctx, name_c, module, <#struct_name as ::limbo_ext::VTabModule>::VTAB_KIND) } From 38e54ca85e5853f5a4da5d242714bf9d8b8c5320 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Sun, 16 Feb 2025 20:23:41 -0500 Subject: [PATCH 08/13] Update schema dot command to show virtual tables --- Cargo.lock | 1 + cli/app.rs | 4 ++-- core/lib.rs | 1 - core/translate/insert.rs | 9 ++++----- core/translate/mod.rs | 4 ++++ core/vdbe/mod.rs | 4 +--- extensions/core/README.md | 26 +++++++++++++++++++++++--- 7 files changed, 35 insertions(+), 14 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3e55db38f..2cc79446c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1687,6 +1687,7 @@ dependencies = [ name = "limbo_kv" version = "0.0.14" dependencies = [ + "lazy_static", "limbo_ext", "mimalloc", ] diff --git a/cli/app.rs b/cli/app.rs index 5a93da77f..9c09fe20a 100644 --- a/cli/app.rs +++ b/cli/app.rs @@ -765,11 +765,11 @@ impl<'a> Limbo<'a> { fn display_schema(&mut self, table: Option<&str>) -> anyhow::Result<()> { let sql = match table { Some(table_name) => format!( - "SELECT sql FROM sqlite_schema WHERE type IN ('table', 'index') AND tbl_name = '{}' AND name NOT LIKE 'sqlite_%'", + "SELECT sql FROM sqlite_schema WHERE type IN ('table', 'index', 'virtual') AND tbl_name = '{}' AND name NOT LIKE 'sqlite_%'", table_name ), None => String::from( - "SELECT sql FROM sqlite_schema WHERE type IN ('table', 'index') AND name NOT LIKE 'sqlite_%'" + "SELECT sql FROM sqlite_schema WHERE type IN ('table', 'index', 'virtual') AND name NOT LIKE 'sqlite_%'" ), }; diff --git a/core/lib.rs b/core/lib.rs index 8719f3c0a..ed45d344e 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -556,7 +556,6 @@ impl VirtualTable { } } let mut parser = Parser::new(schema.as_bytes()); - parser.reset(schema.as_bytes()); if let ast::Cmd::Stmt(ast::Stmt::CreateTable { body, .. }) = parser.next()?.ok_or( LimboError::ParseError("Failed to parse schema from virtual table module".to_string()), )? { diff --git a/core/translate/insert.rs b/core/translate/insert.rs index f77171621..53368d30b 100644 --- a/core/translate/insert.rs +++ b/core/translate/insert.rs @@ -122,7 +122,7 @@ pub fn translate_insert( let inserting_multiple_rows = values.len() > 1; - // multiple rows - use coroutine for value population + // Multiple rows - use coroutine for value population if inserting_multiple_rows { let yield_reg = program.alloc_register(); let jump_on_definition_label = program.allocate_label(); @@ -475,10 +475,9 @@ fn translate_virtual_table_insert( translate_expr(program, None, expr, value_registers_start + i, resolver)?; } /* * - * Inserts for virtual tables are done in a single step. The rowid is not provided by the user, but is generated by the - * vtable implementation. - * argv[0] = current_rowid (NULL for insert) - * argv[1] = insert_rowid (NULL for insert) + * Inserts for virtual tables are done in a single step. + * argv[0] = (NULL for insert) + * argv[1] = (NULL for insert) * argv[2..] = column values * */ diff --git a/core/translate/mod.rs b/core/translate/mod.rs index f485fb7ee..827a925a1 100644 --- a/core/translate/mod.rs +++ b/core/translate/mod.rs @@ -625,10 +625,14 @@ fn translate_create_virtual_table( let args_reg = if !args_vec.is_empty() { let args_start = program.alloc_register(); + + // Emit string8 instructions for each arg for (i, arg) in args_vec.iter().enumerate() { program.emit_string8(arg.clone(), args_start + i); } let args_record_reg = program.alloc_register(); + + // VCreate expects an array of args as a record program.emit_insn(Insn::MakeRecord { start_reg: args_start, count: args_vec.len(), diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index f4a2725c2..fc980e85b 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -988,13 +988,11 @@ impl Program { ))); } } - // argv[0] = current_rowid (for DELETE if applicable) - // argv[1] = insert_rowid (for INSERT if applicable) - // argv[2..] = column values let result = virtual_table.update(&argv); match result { Ok(Some(new_rowid)) => { if *conflict_action == 5 { + // ResolveType::Replace if let Some(conn) = self.connection.upgrade() { conn.update_last_rowid(new_rowid as u64); } diff --git a/extensions/core/README.md b/extensions/core/README.md index ddcfe413d..33236ef40 100644 --- a/extensions/core/README.md +++ b/extensions/core/README.md @@ -169,7 +169,11 @@ impl VTabModule for CsvVTable { /// Declare the name for your virtual table const NAME: &'static str = "csv_data"; - fn init_sql() -> &'static str { + /// Declare the type of vtable (TableValuedFunction or VirtualTable) + const VTAB_KIND: VTabKind = VTabKind::VirtualTable; + + /// Function to initialize the schema of your vtable + fn create_schema(_args: &[Value]) -> &'static str { "CREATE TABLE csv_data( name TEXT, age TEXT, @@ -178,7 +182,7 @@ impl VTabModule for CsvVTable { } /// Open to return a new cursor: In this simple example, the CSV file is read completely into memory on connect. - fn open() -> Result { + fn open(&self) -> Result { // Read CSV file contents from "data.csv" let csv_content = fs::read_to_string("data.csv").unwrap_or_default(); // For simplicity, we'll ignore the header row. @@ -195,7 +199,7 @@ impl VTabModule for CsvVTable { } /// Filter through result columns. (not used in this simple example) - fn filter(_cursor: &mut Self::VCursor, _arg_count: i32, _args: &[Value]) -> ResultCode { + fn filter(_cursor: &mut Self::VCursor, _args: &[Value]) -> ResultCode { ResultCode::OK } @@ -218,6 +222,22 @@ impl VTabModule for CsvVTable { fn eof(cursor: &Self::VCursor) -> bool { cursor.index >= cursor.rows.len() } + + /// *Optional* methods for non-readonly tables + + /// Update the value at rowid + fn update(&mut self, _rowid: i64, _args: &[Value]) -> Result<(), Self::Error> { + Ok(()) + } + + /// Insert the value(s) + fn insert(&mut self, _args: &[Value]) -> Result { + Ok(0) + } + /// Delete the value at rowid + fn delete(&mut self, _rowid: i64) -> Result<(), Self::Error> { + Ok(()) + } } /// The cursor for iterating over CSV rows. From 4d2044b0109ac1cfa77e6bd9b7f3e287fc7fcacc Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Mon, 17 Feb 2025 08:15:57 -0500 Subject: [PATCH 09/13] Fix ownership semantics in extention value conversions --- core/ext/mod.rs | 4 +- core/lib.rs | 7 +-- core/translate/emitter.rs | 102 ++++++++------------------------------ core/types.rs | 11 ++-- core/vdbe/mod.rs | 14 +++--- 5 files changed, 41 insertions(+), 97 deletions(-) diff --git a/core/ext/mod.rs b/core/ext/mod.rs index a5cfd5909..1402c4098 100644 --- a/core/ext/mod.rs +++ b/core/ext/mod.rs @@ -11,7 +11,7 @@ type ExternAggFunc = (InitAggFunction, StepFunction, FinalizeFunction); #[derive(Clone)] pub struct VTabImpl { - pub module_type: VTabKind, + pub module_kind: VTabKind, pub implementation: Rc, } @@ -104,7 +104,7 @@ impl Database { ) -> ResultCode { let module = Rc::new(module); let vmodule = VTabImpl { - module_type: kind, + module_kind: kind, implementation: module, }; self.syms diff --git a/core/lib.rs b/core/lib.rs index ed45d344e..3583ad5d5 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -542,7 +542,7 @@ impl VirtualTable { module_name )))?; if let VTabKind::VirtualTable = kind { - if module.module_type != VTabKind::VirtualTable { + if module.module_kind != VTabKind::VirtualTable { return Err(LimboError::ExtensionError(format!( "Virtual table module {} is not a virtual table", module_name @@ -612,10 +612,7 @@ impl VirtualTable { pub fn column(&self, cursor: &VTabOpaqueCursor, column: usize) -> Result { let val = unsafe { (self.implementation.column)(cursor.as_ptr(), column as u32) }; - let res = OwnedValue::from_ffi(&val)?; - unsafe { - val.free(); - } + let res = OwnedValue::from_ffi(val)?; Ok(res) } diff --git a/core/translate/emitter.rs b/core/translate/emitter.rs index 8d095eb9b..ca7162af1 100644 --- a/core/translate/emitter.rs +++ b/core/translate/emitter.rs @@ -308,13 +308,7 @@ fn emit_program_for_delete( &plan.table_references, &plan.where_clause, )?; - if let Some(table) = plan.table_references.first() { - if table.virtual_table().is_some() { - emit_delete_vtable_insns(program, &mut t_ctx, &plan.table_references, &plan.limit)?; - } else { - emit_delete_insns(program, &mut t_ctx, &plan.table_references, &plan.limit)?; - } - } + emit_delete_insns(program, &mut t_ctx, &plan.table_references, &plan.limit)?; // Clean up and close the main execution loop close_loop(program, &mut t_ctx, &plan.table_references)?; @@ -328,77 +322,6 @@ fn emit_program_for_delete( Ok(()) } -fn emit_delete_vtable_insns( - program: &mut ProgramBuilder, - t_ctx: &mut TranslateCtx, - table_references: &[TableReference], - limit: &Option, -) -> Result<()> { - let table_reference = table_references.first().unwrap(); - - let cursor_id = match &table_reference.op { - Operation::Scan { .. } => program.resolve_cursor_id(&table_reference.identifier), - Operation::Search(search) => match search { - Search::RowidEq { .. } | Search::RowidSearch { .. } => { - program.resolve_cursor_id(&table_reference.identifier) - } - Search::IndexSearch { index, .. } => program.resolve_cursor_id(&index.name), - }, - _ => return Ok(()), - }; - - let rowid_reg = program.alloc_register(); - program.emit_insn(Insn::RowId { - cursor_id, - dest: rowid_reg, - }); - // if we have a limit, decrement and check zero - if let Some(limit) = limit { - let limit_reg = program.alloc_register(); - program.emit_insn(Insn::Integer { - value: *limit as i64, - dest: limit_reg, - }); - program.mark_last_insn_constant(); - - program.emit_insn(Insn::DecrJumpZero { - reg: limit_reg, - target_pc: t_ctx.label_main_loop_end.unwrap(), - }); - } - - // we want old_rowid= rowid_reg, new_rowid= NULL, so we pass 2 arguments to VUpdate - // we need a second register for the new rowid = NULL - let new_rowid_reg = program.alloc_register(); - - program.emit_insn(Insn::Null { - dest: new_rowid_reg, - dest_end: None, - }); - - // we'll do VUpdate with arg_count=2: - // argv[0] => old_rowid = rowid_reg - // argv[1] => new_rowid = new_rowid_reg (NULL) - - let Some(virtual_table) = table_reference.virtual_table() else { - return Err(crate::LimboError::ParseError( - "Table is not a virtual table".to_string(), - )); - }; - let conflict_action = 0u16; - let start_reg = rowid_reg; - - program.emit_insn(Insn::VUpdate { - cursor_id, - arg_count: 2, - start_reg, - vtab_ptr: virtual_table.implementation.as_ref().ctx as usize, - conflict_action, - }); - - Ok(()) -} - fn emit_delete_insns( program: &mut ProgramBuilder, t_ctx: &mut TranslateCtx, @@ -423,8 +346,27 @@ fn emit_delete_insns( cursor_id, dest: key_reg, }); - program.emit_insn(Insn::DeleteAsync { cursor_id }); - program.emit_insn(Insn::DeleteAwait { cursor_id }); + + if let Some(vtab) = table_reference.virtual_table() { + let conflict_action = 0u16; + let start_reg = key_reg; + + let new_rowid_reg = program.alloc_register(); + program.emit_insn(Insn::Null { + dest: new_rowid_reg, + dest_end: None, + }); + program.emit_insn(Insn::VUpdate { + cursor_id, + arg_count: 2, + start_reg, + vtab_ptr: vtab.implementation.as_ref().ctx as usize, + conflict_action, + }); + } else { + program.emit_insn(Insn::DeleteAsync { cursor_id }); + program.emit_insn(Insn::DeleteAwait { cursor_id }); + } if let Some(limit) = limit { let limit_reg = program.alloc_register(); program.emit_insn(Insn::Integer { diff --git a/core/types.rs b/core/types.rs index 7eee76e94..f1dafd31e 100644 --- a/core/types.rs +++ b/core/types.rs @@ -223,8 +223,8 @@ impl OwnedValue { } } - pub fn from_ffi(v: &ExtValue) -> Result { - match v.value_type() { + pub fn from_ffi(v: ExtValue) -> Result { + let res = match v.value_type() { ExtValueType::Null => Ok(OwnedValue::Null), ExtValueType::Integer => { let Some(int) = v.to_integer() else { @@ -259,7 +259,11 @@ impl OwnedValue { (code, None) => Err(LimboError::ExtensionError(code.to_string())), } } + }; + unsafe { + v.free(); } + res } } @@ -281,8 +285,7 @@ impl AggContext { if let Self::External(ext_state) = self { if ext_state.finalized_value.is_none() { let final_value = unsafe { (ext_state.finalize_fn)(ext_state.state) }; - ext_state.cache_final_value(OwnedValue::from_ffi(&final_value)?); - unsafe { final_value.free() }; + ext_state.cache_final_value(OwnedValue::from_ffi(final_value)?); } } Ok(()) diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index fc980e85b..a45ea199c 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -169,13 +169,11 @@ macro_rules! call_external_function { ) => {{ if $arg_count == 0 { let result_c_value: ExtValue = unsafe { ($func_ptr)(0, std::ptr::null()) }; - match OwnedValue::from_ffi(&result_c_value) { + match OwnedValue::from_ffi(result_c_value) { Ok(result_ov) => { $state.registers[$dest_register] = result_ov; - unsafe { result_c_value.free() }; } Err(e) => { - unsafe { result_c_value.free() }; return Err(e); } } @@ -188,13 +186,14 @@ macro_rules! call_external_function { } let argv_ptr = ext_values.as_ptr(); let result_c_value: ExtValue = unsafe { ($func_ptr)($arg_count as i32, argv_ptr) }; - match OwnedValue::from_ffi(&result_c_value) { + for arg in ext_values { + unsafe { arg.free() }; + } + match OwnedValue::from_ffi(result_c_value) { Ok(result_ov) => { $state.registers[$dest_register] = result_ov; - unsafe { result_c_value.free() }; } Err(e) => { - unsafe { result_c_value.free() }; return Err(e); } } @@ -1858,6 +1857,9 @@ impl Program { } let argv_ptr = ext_values.as_ptr(); unsafe { step_fn(state_ptr, argc as i32, argv_ptr) }; + for ext_value in ext_values { + unsafe { ext_value.free() }; + } } } }; From e63436dc47a281235ad042aa0d6dbf323d3b10e3 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Mon, 17 Feb 2025 09:02:45 -0500 Subject: [PATCH 10/13] Fix sqlite_schema and remove explicit vtables --- cli/app.rs | 4 ++-- core/lib.rs | 31 +++++-------------------------- core/translate/mod.rs | 6 ++---- core/util.rs | 17 +++++++++-------- extensions/core/src/lib.rs | 5 ++++- 5 files changed, 22 insertions(+), 41 deletions(-) diff --git a/cli/app.rs b/cli/app.rs index 9c09fe20a..5a93da77f 100644 --- a/cli/app.rs +++ b/cli/app.rs @@ -765,11 +765,11 @@ impl<'a> Limbo<'a> { fn display_schema(&mut self, table: Option<&str>) -> anyhow::Result<()> { let sql = match table { Some(table_name) => format!( - "SELECT sql FROM sqlite_schema WHERE type IN ('table', 'index', 'virtual') AND tbl_name = '{}' AND name NOT LIKE 'sqlite_%'", + "SELECT sql FROM sqlite_schema WHERE type IN ('table', 'index') AND tbl_name = '{}' AND name NOT LIKE 'sqlite_%'", table_name ), None => String::from( - "SELECT sql FROM sqlite_schema WHERE type IN ('table', 'index', 'virtual') AND name NOT LIKE 'sqlite_%'" + "SELECT sql FROM sqlite_schema WHERE type IN ('table', 'index') AND name NOT LIKE 'sqlite_%'" ), }; diff --git a/core/lib.rs b/core/lib.rs index 3583ad5d5..f1ddcc08c 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -544,17 +544,12 @@ impl VirtualTable { if let VTabKind::VirtualTable = kind { if module.module_kind != VTabKind::VirtualTable { return Err(LimboError::ExtensionError(format!( - "Virtual table module {} is not a virtual table", + "{} is not a virtual table module", module_name ))); } }; - let schema = module.implementation.as_ref().init_schema(&args)?; - for arg in args { - unsafe { - arg.free(); - } - } + let schema = module.implementation.as_ref().init_schema(args)?; let mut parser = Parser::new(schema.as_bytes()); if let ast::Cmd::Stmt(ast::Stmt::CreateTable { body, .. }) = parser.next()?.ok_or( LimboError::ParseError("Failed to parse schema from virtual table module".to_string()), @@ -612,8 +607,7 @@ impl VirtualTable { pub fn column(&self, cursor: &VTabOpaqueCursor, column: usize) -> Result { let val = unsafe { (self.implementation.column)(cursor.as_ptr(), column as u32) }; - let res = OwnedValue::from_ffi(val)?; - Ok(res) + OwnedValue::from_ffi(val) } pub fn next(&self, cursor: &VTabOpaqueCursor) -> Result { @@ -627,27 +621,12 @@ impl VirtualTable { pub fn update(&self, args: &[OwnedValue]) -> Result> { let arg_count = args.len(); - let mut ext_args = Vec::with_capacity(arg_count); - for i in 0..arg_count { - let ownedvalue_arg = args.get(i).unwrap(); - let extvalue_arg: ExtValue = match ownedvalue_arg { - OwnedValue::Null => Ok(ExtValue::null()), - OwnedValue::Integer(i) => Ok(ExtValue::from_integer(*i)), - OwnedValue::Float(f) => Ok(ExtValue::from_float(*f)), - OwnedValue::Text(t) => Ok(ExtValue::from_text(t.as_str().to_string())), - OwnedValue::Blob(b) => Ok(ExtValue::from_blob((**b).clone())), - other => Err(LimboError::ExtensionError(format!( - "Unsupported value type: {:?}", - other - ))), - }?; - ext_args.push(extvalue_arg); - } + let ext_args = args.iter().map(|arg| arg.to_ffi()).collect::>(); let newrowid = 0i64; let implementation = self.implementation.as_ref(); let rc = unsafe { (self.implementation.update)( - implementation as *const VTabModuleImpl as *mut std::ffi::c_void, + implementation as *const VTabModuleImpl as *const std::ffi::c_void, arg_count as i32, ext_args.as_ptr(), &newrowid as *const _ as *mut i64, diff --git a/core/translate/mod.rs b/core/translate/mod.rs index 827a925a1..fe49d05ce 100644 --- a/core/translate/mod.rs +++ b/core/translate/mod.rs @@ -176,7 +176,6 @@ addr opcode p1 p2 p3 p4 p5 comment enum SchemaEntryType { Table, Index, - Virtual, } impl SchemaEntryType { @@ -184,7 +183,6 @@ impl SchemaEntryType { match self { SchemaEntryType::Table => "table", SchemaEntryType::Index => "index", - SchemaEntryType::Virtual => "virtual", } } } @@ -211,7 +209,7 @@ fn emit_schema_entry( program.emit_string8_new_reg(tbl_name.to_string()); let rootpage_reg = program.alloc_register(); - if matches!(entry_type, SchemaEntryType::Virtual) { + if root_page_reg == 0 { program.emit_insn(Insn::Integer { dest: rootpage_reg, value: 0, // virtual tables in sqlite always have rootpage=0 @@ -664,7 +662,7 @@ fn translate_create_virtual_table( emit_schema_entry( &mut program, sqlite_schema_cursor_id, - SchemaEntryType::Virtual, + SchemaEntryType::Table, &tbl_name.name.0, &tbl_name.name.0, 0, // virtual tables dont have a root page diff --git a/core/util.rs b/core/util.rs index 17c12fbb3..90a878ceb 100644 --- a/core/util.rs +++ b/core/util.rs @@ -37,20 +37,21 @@ pub fn parse_schema_rows( StepResult::Row => { let row = rows.row().unwrap(); let ty = row.get::<&str>(0)?; - if !["table", "index", "virtual"].contains(&ty) { + if !["table", "index"].contains(&ty) { continue; } match ty { "table" => { let root_page: i64 = row.get::(3)?; let sql: &str = row.get::<&str>(4)?; - let table = schema::BTreeTable::from_sql(sql, root_page as usize)?; - schema.add_btree_table(Rc::new(table)); - } - "virtual" => { - let name: &str = row.get::<&str>(1)?; - let vtab = syms.vtabs.get(name).unwrap().clone(); - schema.add_virtual_table(vtab); + if root_page == 0 && sql.to_lowercase().contains("virtual") { + let name: &str = row.get::<&str>(1)?; + let vtab = syms.vtabs.get(name).unwrap().clone(); + schema.add_virtual_table(vtab); + } else { + let table = schema::BTreeTable::from_sql(sql, root_page as usize)?; + schema.add_btree_table(Rc::new(table)); + } } "index" => { let root_page: i64 = row.get::(3)?; diff --git a/extensions/core/src/lib.rs b/extensions/core/src/lib.rs index de82705b0..2951d591e 100644 --- a/extensions/core/src/lib.rs +++ b/extensions/core/src/lib.rs @@ -74,11 +74,14 @@ pub struct VTabModuleImpl { } impl VTabModuleImpl { - pub fn init_schema(&self, args: &[Value]) -> ExtResult { + pub fn init_schema(&self, args: Vec) -> ExtResult { let schema = unsafe { (self.create_schema)(args.as_ptr(), args.len() as i32) }; if schema.is_null() { return Err(ResultCode::InvalidArgs); } + for arg in args { + unsafe { arg.free() }; + } let schema = unsafe { std::ffi::CString::from_raw(schema) }; Ok(schema.to_string_lossy().to_string()) } From 9b742e1a767198e7657ca689308c64c99afda6fc Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Mon, 17 Feb 2025 21:31:27 -0500 Subject: [PATCH 11/13] Remove clone in vtab from_args --- core/lib.rs | 4 ++-- core/translate/planner.rs | 7 +++++-- core/vdbe/mod.rs | 2 +- testing/cli_tests/cli_test_cases.py | 2 +- 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/core/lib.rs b/core/lib.rs index f1ddcc08c..267520d20 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -532,7 +532,7 @@ impl VirtualTable { args: Vec, syms: &SymbolTable, kind: VTabKind, - exprs: &Option>, + exprs: Option>, ) -> Result> { let module = syms .vtab_modules @@ -559,7 +559,7 @@ impl VirtualTable { name: tbl_name.unwrap_or(module_name).to_owned(), implementation: module.implementation.clone(), columns, - args: exprs.clone(), + args: exprs, }); return Ok(vtab); } diff --git a/core/translate/planner.rs b/core/translate/planner.rs index 440744426..84057d1e6 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -376,9 +376,12 @@ fn parse_from_clause_table<'a>( .push(TableReference::new_subquery(identifier, subplan, None)); Ok(()) } - ast::SelectTable::TableCall(qualified_name, ref maybe_args, maybe_alias) => { + ast::SelectTable::TableCall(qualified_name, maybe_args, maybe_alias) => { let normalized_name = &normalize_ident(qualified_name.name.0.as_str()); - let args = vtable_args(maybe_args.as_ref().unwrap_or(&vec![]).as_slice()); + let args = match maybe_args { + Some(ref args) => vtable_args(args), + None => vec![], + }; let vtab = crate::VirtualTable::from_args( None, normalized_name, diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index a45ea199c..2c7566628 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -906,7 +906,7 @@ impl Program { args, &conn.db.syms.borrow(), limbo_ext::VTabKind::VirtualTable, - &None, + None, )?; { conn.db diff --git a/testing/cli_tests/cli_test_cases.py b/testing/cli_tests/cli_test_cases.py index 6eb1106f2..caaf730a4 100755 --- a/testing/cli_tests/cli_test_cases.py +++ b/testing/cli_tests/cli_test_cases.py @@ -142,7 +142,7 @@ def test_output_file(): expected_lines = { f"Output: {output_filename}": "Can direct output to a file", - "Output mode: raw": "Output mode remains raw when output is redirected", + "Output mode: list": "Output mode remains list when output is redirected", "Error: pretty output can only be written to a tty": "Error message for pretty mode", "SELECT 'TEST_ECHO'": "Echoed command", "TEST_ECHO": "Echoed result", From c4f42549f484e052441d678ac208c6064be539f8 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Mon, 17 Feb 2025 22:53:51 -0500 Subject: [PATCH 12/13] Fix selecting from empty table error in kv extension --- extensions/kvstore/src/lib.rs | 28 +++++++++++++++++----------- testing/extensions.py | 24 ++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 11 deletions(-) diff --git a/extensions/kvstore/src/lib.rs b/extensions/kvstore/src/lib.rs index f18f7851b..45f8d0cad 100644 --- a/extensions/kvstore/src/lib.rs +++ b/extensions/kvstore/src/lib.rs @@ -19,7 +19,7 @@ pub struct KVStoreVTab; /// The cursor holds a snapshot of (rowid, key, value) in memory. pub struct KVStoreCursor { rows: Vec<(i64, String, String)>, - index: usize, + index: Option, } /// Implementing the VTabModule trait for KVStoreVTab @@ -36,7 +36,7 @@ impl VTabModule for KVStoreVTab { fn open(&self) -> Result { Ok(KVStoreCursor { rows: Vec::new(), - index: 0, + index: None, }) } @@ -44,10 +44,16 @@ impl VTabModule for KVStoreVTab { let store = GLOBAL_STORE.lock().unwrap(); cursor.rows = store .iter() - .map(|(&rowid, (ref key, ref val))| (rowid, key.clone(), val.clone())) + .map(|(&rowid, (k, v))| (rowid, k.clone(), v.clone())) .collect(); cursor.rows.sort_by_key(|(rowid, _, _)| *rowid); - cursor.index = 0; + + if cursor.rows.is_empty() { + cursor.index = None; + return ResultCode::EOF; + } else { + cursor.index = Some(0); + } ResultCode::OK } @@ -85,22 +91,22 @@ impl VTabModule for KVStoreVTab { Ok(()) } fn eof(cursor: &Self::VCursor) -> bool { - cursor.index >= cursor.rows.len() + cursor.index.is_some_and(|s| s >= cursor.rows.len()) || cursor.index.is_none() } fn next(cursor: &mut Self::VCursor) -> ResultCode { - cursor.index += 1; - if cursor.index >= cursor.rows.len() { + cursor.index = Some(cursor.index.unwrap_or(0) + 1); + if cursor.index.is_some_and(|c| c >= cursor.rows.len()) { return ResultCode::EOF; } ResultCode::OK } fn column(cursor: &Self::VCursor, idx: u32) -> Result { - if cursor.index >= cursor.rows.len() { + if cursor.index.is_some_and(|c| c >= cursor.rows.len()) { return Err("cursor out of range".into()); } - let (_, ref key, ref val) = cursor.rows[cursor.index]; + let (_, ref key, ref val) = cursor.rows[cursor.index.unwrap_or(0)]; match idx { 0 => Ok(Value::from_text(key.clone())), // key 1 => Ok(Value::from_text(val.clone())), // value @@ -120,8 +126,8 @@ impl VTabCursor for KVStoreCursor { type Error = String; fn rowid(&self) -> i64 { - if self.index < self.rows.len() { - self.rows[self.index].0 + if self.index.is_some_and(|c| c < self.rows.len()) { + self.rows[self.index.unwrap_or(0)].0 } else { println!("rowid: -1"); -1 diff --git a/testing/extensions.py b/testing/extensions.py index 5910fe754..6e203d416 100755 --- a/testing/extensions.py +++ b/testing/extensions.py @@ -507,6 +507,30 @@ def test_kv(pipe): lambda res: "" == res, "proper data is deleted", ) + run_test( + pipe, + "select * from t;", + lambda res: "other|value" == res, + "can select after deletion", + ) + run_test( + pipe, + "delete from t where key = 'other';", + returns_null, + "can delete from kv_store", + ) + run_test( + pipe, + "select * from t;", + lambda res: "" == res, + "can select empty table without error", + ) + run_test( + pipe, + "delete from t;", + returns_null, + "can delete from empty table without error", + ) def main(): From 7b05f5333513ab9d3744a891fc8bb1fefd37d323 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Mon, 17 Feb 2025 23:22:02 -0500 Subject: [PATCH 13/13] Add more tests for vtab impl --- extensions/kvstore/src/lib.rs | 3 +-- testing/extensions.py | 9 +++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/extensions/kvstore/src/lib.rs b/extensions/kvstore/src/lib.rs index 45f8d0cad..a9de7c71d 100644 --- a/extensions/kvstore/src/lib.rs +++ b/extensions/kvstore/src/lib.rs @@ -16,13 +16,12 @@ register_extension! { #[derive(VTabModuleDerive, Default)] pub struct KVStoreVTab; -/// The cursor holds a snapshot of (rowid, key, value) in memory. +/// the cursor holds a snapshot of (rowid, key, value) in memory. pub struct KVStoreCursor { rows: Vec<(i64, String, String)>, index: Option, } -/// Implementing the VTabModule trait for KVStoreVTab impl VTabModule for KVStoreVTab { type VCursor = KVStoreCursor; const VTAB_KIND: VTabKind = VTabKind::VirtualTable; diff --git a/testing/extensions.py b/testing/extensions.py index 6e203d416..8bff11bc2 100755 --- a/testing/extensions.py +++ b/testing/extensions.py @@ -531,6 +531,15 @@ def test_kv(pipe): returns_null, "can delete from empty table without error", ) + for i in range(100): + write_to_pipe(pipe, f"insert into t values ('key{i}', 'val{i}');") + run_test( + pipe, "select count(*) from t;", lambda res: "100" == res, "can insert 100 rows" + ) + run_test(pipe, "delete from t limit 96;", returns_null, "can delete 96 rows") + run_test( + pipe, "select count(*) from t;", lambda res: "4" == res, "four rows remain" + ) def main():