diff --git a/extensions/completion/src/lib.rs b/extensions/completion/src/lib.rs index bfd03756c..ae7a9d50e 100644 --- a/extensions/completion/src/lib.rs +++ b/extensions/completion/src/lib.rs @@ -3,9 +3,12 @@ mod keywords; +use std::rc::Rc; + use keywords::KEYWORDS; use limbo_ext::{ - register_extension, ResultCode, VTabCursor, VTabModule, VTabModuleDerive, VTable, Value, + register_extension, Connection, ResultCode, VTabCursor, VTabModule, VTabModuleDerive, VTable, + Value, }; register_extension! { @@ -84,7 +87,7 @@ impl VTable for CompletionTable { type Cursor = CompletionCursor; type Error = ResultCode; - fn open(&self) -> Result { + fn open(&self, _conn: Option>) -> Result { Ok(CompletionCursor::default()) } } diff --git a/extensions/core/src/vtabs.rs b/extensions/core/src/vtabs.rs index 316e0a5c0..47b7fbf02 100644 --- a/extensions/core/src/vtabs.rs +++ b/extensions/core/src/vtabs.rs @@ -1,5 +1,9 @@ -use crate::{ResultCode, Value}; -use std::ffi::{c_char, c_void}; +use crate::{types::StepResult, ExtResult, ResultCode, Value}; +use std::{ + ffi::{c_char, c_void, CStr, CString}, + num::NonZeroUsize, + rc::Rc, +}; pub type RegisterModuleFn = unsafe extern "C" fn( ctx: *mut c_void, @@ -68,7 +72,7 @@ impl VTabModuleImpl { pub type VtabFnCreate = unsafe extern "C" fn(args: *const Value, argc: i32) -> VTabCreateResult; -pub type VtabFnOpen = unsafe extern "C" fn(table: *const c_void) -> *const c_void; +pub type VtabFnOpen = unsafe extern "C" fn(table: *const c_void, conn: *mut Conn) -> *const c_void; pub type VtabFnClose = unsafe extern "C" fn(cursor: *const c_void) -> ResultCode; @@ -125,7 +129,9 @@ pub trait VTable { type Cursor: VTabCursor; type Error: std::fmt::Display; - fn open(&self) -> Result; + /// 'conn' is an Option to allow for testing. Otherwise a valid connection to the core database + /// that created the virtual table will be available to use in your extension here. + fn open(&self, _conn: Option>) -> Result; fn update(&mut self, _rowid: i64, _args: &[Value]) -> Result<(), Self::Error> { Ok(()) } @@ -336,3 +342,268 @@ impl ConstraintInfo { ((self.plan_info >> 1) as usize, (self.plan_info & 1) != 0) } } + +pub type ConnectFn = unsafe extern "C" fn(ctx: *mut c_void) -> *mut Conn; +pub type PrepareStmtFn = unsafe extern "C" fn(api: *mut Conn, sql: *const c_char) -> *const Stmt; +pub type GetColumnNamesFn = + unsafe extern "C" fn(ctx: *mut Stmt, count: *mut i32) -> *mut *mut c_char; +pub type BindArgsFn = + unsafe extern "C" fn(ctx: *mut Stmt, idx: i32, arg: *const Value) -> ResultCode; +pub type StmtStepFn = unsafe extern "C" fn(ctx: *mut Stmt) -> ResultCode; +pub type StmtGetRowValuesFn = unsafe extern "C" fn(ctx: *mut Stmt); +pub type FreeCurrentRowFn = unsafe extern "C" fn(ctx: *mut Stmt); +pub type CloseConnectionFn = unsafe extern "C" fn(ctx: *mut c_void); +pub type CloseStmtFn = unsafe extern "C" fn(ctx: *mut Stmt); + +/// core database connection +/// public fields for core only +#[repr(C)] +#[derive(Debug, Clone)] +pub struct Conn { + // Rc::Weak from core::Connection + pub _ctx: *mut c_void, + pub _prepare_stmt: PrepareStmtFn, + pub _close: CloseConnectionFn, +} + +impl Conn { + pub fn new(ctx: *mut c_void, prepare_stmt: PrepareStmtFn, close: CloseConnectionFn) -> Self { + Conn { + _ctx: ctx, + _prepare_stmt: prepare_stmt, + _close: close, + } + } + /// # Safety + pub unsafe fn from_ptr(ptr: *mut Conn) -> crate::ExtResult<&'static mut Self> { + if ptr.is_null() { + return Err(ResultCode::Error); + } + Ok(unsafe { &mut *(ptr) }) + } + + pub fn close(&self) { + unsafe { (self._close)(self._ctx) }; + } + + pub fn prepare_stmt(&self, sql: &str) -> *const Stmt { + let Ok(sql) = CString::new(sql) else { + return std::ptr::null(); + }; + unsafe { (self._prepare_stmt)(self as *const Conn as *mut Conn, sql.as_ptr()) } + } +} + +/// Prepared statement for querying a core database connection +/// public API with wrapper methods for extensions +#[derive(Debug)] +#[repr(C)] +pub struct Statement { + _ctx: *const Stmt, +} + +/// The Database connection that opened the VTable: +/// Public API to expose methods for extensions +#[derive(Debug)] +#[repr(C)] +pub struct Connection { + _ctx: *mut Conn, +} + +impl Connection { + pub fn new(ctx: *mut Conn) -> Self { + Connection { _ctx: ctx } + } + + /// From the included SQL string, prepare a statement for execution. + pub fn prepare(self: &Rc, sql: &str) -> ExtResult { + let stmt = unsafe { (*self._ctx).prepare_stmt(sql) }; + if stmt.is_null() { + return Err(ResultCode::Error); + } + Ok(Statement { _ctx: stmt }) + } + + /// Close the connection to the database. + pub fn close(self) { + unsafe { ((*self._ctx)._close)(self._ctx as *mut c_void) }; + } +} + +impl Statement { + /// Bind a value to a parameter in the prepared statement + ///```ignore + /// let stmt = conn.prepare_stmt("select * from users where name = ?"); + /// stmt.bind(1, Value::from_text("test".into())); + pub fn bind(&self, idx: NonZeroUsize, arg: &Value) { + let arg = arg as *const Value; + unsafe { (*self._ctx).bind_args(idx, arg) } + } + + /// Execute the statement and return the next row + ///```ignore + /// while stmt.step() == StepResult::Row { + /// let row = stmt.get_row(); + /// println!("row: {:?}", row); + /// } + /// ``` + pub fn step(&self) -> StepResult { + unsafe { (*self._ctx).step() } + } + + // Get the current row values + ///```ignore + /// while stmt.step() == StepResult::Row { + /// let row = stmt.get_row(); + /// println!("row: {:?}", row); + ///``` + pub fn get_row(&mut self) -> &[Value] { + unsafe { (*self._ctx).get_row() } + } + + /// Get the result column names for the prepared statement + pub fn get_column_names(&self) -> Vec { + unsafe { (*self._ctx).get_column_names() } + } + + /// Close the statement + pub fn close(&self) { + unsafe { (*self._ctx).close() } + } +} + +/// Internal/core use _only_ +/// Extensions should not import or use this type directly +#[repr(C)] +pub struct Stmt { + // Rc::into_raw from core::Connection + pub _conn: *mut c_void, + // Rc::into_raw from core::Statement + pub _ctx: *mut c_void, + pub _bind_args_fn: BindArgsFn, + pub _step: StmtStepFn, + pub _get_row_values: StmtGetRowValuesFn, + pub _get_column_names: GetColumnNamesFn, + pub _free_current_row: FreeCurrentRowFn, + pub _close: CloseStmtFn, + pub current_row: *mut Value, + pub current_row_len: i32, +} + +impl Stmt { + #[allow(clippy::too_many_arguments)] + pub fn new( + conn: *mut c_void, + ctx: *mut c_void, + bind: BindArgsFn, + step: StmtStepFn, + rows: StmtGetRowValuesFn, + names: GetColumnNamesFn, + free_row: FreeCurrentRowFn, + close: CloseStmtFn, + ) -> Self { + Stmt { + _conn: conn, + _ctx: ctx, + _bind_args_fn: bind, + _step: step, + _get_row_values: rows, + _get_column_names: names, + _free_current_row: free_row, + _close: close, + current_row: std::ptr::null_mut(), + current_row_len: -1, + } + } + + /// Close the statement + pub fn close(&self) { + unsafe { (self._close)(self as *const Stmt as *mut Stmt) }; + } + + /// # Safety + /// Derefs a null ptr, does a null check first + pub unsafe fn from_ptr(ptr: *mut Stmt) -> ExtResult<&'static mut Self> { + if ptr.is_null() { + return Err(ResultCode::Error); + } + Ok(unsafe { &mut *(ptr) }) + } + + /// Returns the pointer to the statement. + pub fn to_ptr(&self) -> *const Stmt { + self + } + + /// Bind a value to a parameter in the prepared statement + fn bind_args(&self, idx: NonZeroUsize, arg: *const Value) { + unsafe { (self._bind_args_fn)(self.to_ptr() as *mut Stmt, idx.get() as i32, arg) }; + } + + /// Execute the statement to attempt to retrieve the next result row. + fn step(&self) -> StepResult { + unsafe { (self._step)(self.to_ptr() as *mut Stmt) }.into() + } + + /// Free the memory for the values obtained from the `get_row` method. + /// # Safety + /// This fn is unsafe because it derefs a raw pointer after null and + /// length checks. This fn should only be called with the pointer returned from get_row. + pub unsafe fn free_current_row(&mut self) { + if self.current_row.is_null() || self.current_row_len <= 0 { + return; + } + // free from the core side so we don't have to expose `__free_internal_type` + (self._free_current_row)(self.to_ptr() as *mut Stmt); + self.current_row = std::ptr::null_mut(); + self.current_row_len = -1; + } + + /// Returns the values from the current row in the prepared statement, should + /// be called after the step() method returns `StepResult::Row` + pub fn get_row(&self) -> &[Value] { + unsafe { (self._get_row_values)(self.to_ptr() as *mut Stmt) }; + if self.current_row.is_null() || self.current_row_len < 1 { + return &[]; + } + let col_count = self.current_row_len; + unsafe { std::slice::from_raw_parts(self.current_row, col_count as usize) } + } + + /// Returns the names of the result columns for the prepared statement. + pub fn get_column_names(&self) -> Vec { + let mut count_value: i32 = 0; + let count: *mut i32 = &mut count_value; + let col_names = unsafe { (self._get_column_names)(self.to_ptr() as *mut Stmt, count) }; + if col_names.is_null() || count_value == 0 { + return Vec::new(); + } + let mut names = Vec::new(); + let slice = unsafe { std::slice::from_raw_parts(col_names, count_value as usize) }; + for x in slice { + let name = unsafe { CStr::from_ptr(*x) }; + names.push(name.to_str().unwrap().to_string()); + } + unsafe { free_column_names(col_names, count_value) }; + names + } +} + +/// Free the column names returned from get_column_names +/// # Safety +/// This function is unsafe because it derefs a raw pointer, this fn +/// should only be called with the pointer returned from get_column_names +/// only when they will no longer be used. +pub unsafe fn free_column_names(names: *mut *mut c_char, count: i32) { + if names.is_null() || count < 1 { + return; + } + let slice = std::slice::from_raw_parts_mut(names, count as usize); + + for name in slice { + if !name.is_null() { + let _ = CString::from_raw(*name); + } + } + let _ = Box::from_raw(names); +} diff --git a/extensions/csv/src/lib.rs b/extensions/csv/src/lib.rs index 2f8b2a7aa..340966e8d 100644 --- a/extensions/csv/src/lib.rs +++ b/extensions/csv/src/lib.rs @@ -21,11 +21,12 @@ //! - `columns` — number of columns //! - `schema` — optional custom SQL `CREATE TABLE` schema use limbo_ext::{ - register_extension, ConstraintInfo, IndexInfo, OrderByInfo, ResultCode, VTabCursor, VTabKind, - VTabModule, VTabModuleDerive, VTable, Value, + register_extension, Connection, ConstraintInfo, IndexInfo, OrderByInfo, ResultCode, VTabCursor, + VTabKind, VTabModule, VTabModuleDerive, VTable, Value, }; use std::fs::File; use std::io::{Read, Seek, SeekFrom}; +use std::rc::Rc; register_extension! { vtabs: { CsvVTabModule } @@ -259,7 +260,7 @@ impl VTable for CsvTable { type Cursor = CsvCursor; type Error = ResultCode; - fn open(&self) -> Result { + fn open(&self, _conn: Option>) -> Result { match self.new_reader() { Ok(reader) => Ok(CsvCursor::new(reader, self)), Err(_) => Err(ResultCode::Error), diff --git a/extensions/series/src/lib.rs b/extensions/series/src/lib.rs index f609ded91..437940b23 100644 --- a/extensions/series/src/lib.rs +++ b/extensions/series/src/lib.rs @@ -1,6 +1,8 @@ +use std::rc::Rc; + use limbo_ext::{ - register_extension, ResultCode, VTabCursor, VTabKind, VTabModule, VTabModuleDerive, VTable, - Value, + register_extension, Connection, ResultCode, VTabCursor, VTabKind, VTabModule, VTabModuleDerive, + VTable, Value, }; register_extension! { @@ -43,7 +45,7 @@ impl VTable for GenerateSeriesTable { type Cursor = GenerateSeriesCursor; type Error = ResultCode; - fn open(&self) -> Result { + fn open(&self, _conn: Option>) -> Result { Ok(GenerateSeriesCursor { start: 0, stop: 0, @@ -225,7 +227,7 @@ mod tests { // Helper function to collect all values from a cursor, returns Result with error code fn collect_series(series: Series) -> Result, ResultCode> { let tbl = GenerateSeriesTable {}; - let mut cursor = tbl.open()?; + let mut cursor = tbl.open(None)?; // Create args array for filter let args = vec![ @@ -542,7 +544,7 @@ mod tests { let stop = series.stop; let step = series.step; let tbl = GenerateSeriesTable {}; - let mut cursor = tbl.open().unwrap(); + let mut cursor = tbl.open(None).unwrap(); let args = vec![ Value::from_integer(start), diff --git a/extensions/tests/src/lib.rs b/extensions/tests/src/lib.rs index b70b5497b..273547c7d 100644 --- a/extensions/tests/src/lib.rs +++ b/extensions/tests/src/lib.rs @@ -1,14 +1,15 @@ use lazy_static::lazy_static; use limbo_ext::{ - register_extension, scalar, ConstraintInfo, ConstraintOp, ConstraintUsage, ExtResult, - IndexInfo, OrderByInfo, ResultCode, VTabCursor, VTabKind, VTabModule, VTabModuleDerive, VTable, - Value, + register_extension, scalar, Connection, ConstraintInfo, ConstraintOp, ConstraintUsage, + ExtResult, IndexInfo, OrderByInfo, ResultCode, VTabCursor, VTabKind, VTabModule, + VTabModuleDerive, VTable, Value, }; #[cfg(not(target_family = "wasm"))] use limbo_ext::{VfsDerive, VfsExtension, VfsFile}; use std::collections::BTreeMap; use std::fs::{File, OpenOptions}; use std::io::{Read, Seek, SeekFrom, Write}; +use std::rc::Rc; use std::sync::Mutex; register_extension! { @@ -137,7 +138,7 @@ impl VTable for KVStoreTable { type Cursor = KVStoreCursor; type Error = String; - fn open(&self) -> Result { + fn open(&self, _conn: Option>) -> Result { let _ = env_logger::try_init(); Ok(KVStoreCursor { rows: Vec::new(), diff --git a/macros/src/ext/vtab_derive.rs b/macros/src/ext/vtab_derive.rs index 0bcb3c53a..eb5112a90 100644 --- a/macros/src/ext/vtab_derive.rs +++ b/macros/src/ext/vtab_derive.rs @@ -49,13 +49,14 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { } #[no_mangle] - unsafe extern "C" fn #open_fn_name(table: *const ::std::ffi::c_void) -> *const ::std::ffi::c_void { + unsafe extern "C" fn #open_fn_name(table: *const ::std::ffi::c_void, conn: *mut ::limbo_ext::Conn) -> *const ::std::ffi::c_void { if table.is_null() { return ::std::ptr::null(); } let table = table as *const <#struct_name as ::limbo_ext::VTabModule>::Table; let table: &<#struct_name as ::limbo_ext::VTabModule>::Table = &*table; - if let Ok(cursor) = <#struct_name as ::limbo_ext::VTabModule>::Table::open(table) { + let conn = if conn.is_null() { None } else { Some(::std::rc::Rc::new(::limbo_ext::Connection::new(conn)))}; + if let Ok(cursor) = <#struct_name as ::limbo_ext::VTabModule>::Table::open(table, conn) { return ::std::boxed::Box::into_raw(::std::boxed::Box::new(cursor)) as *const ::std::ffi::c_void; } else { return ::std::ptr::null();