Update Vtable open method to accept core db connection

This commit is contained in:
PThorpe92
2025-05-23 23:09:45 -04:00
parent 2c784070f1
commit cbd7245677
6 changed files with 299 additions and 20 deletions

View File

@@ -3,9 +3,12 @@
mod keywords;
use std::rc::Rc;
use keywords::KEYWORDS;
use limbo_ext::{
register_extension, ResultCode, VTabCursor, VTabModule, VTabModuleDerive, VTable, Value,
register_extension, Connection, ResultCode, VTabCursor, VTabModule, VTabModuleDerive, VTable,
Value,
};
register_extension! {
@@ -84,7 +87,7 @@ impl VTable for CompletionTable {
type Cursor = CompletionCursor;
type Error = ResultCode;
fn open(&self) -> Result<Self::Cursor, Self::Error> {
fn open(&self, _conn: Option<Rc<Connection>>) -> Result<Self::Cursor, Self::Error> {
Ok(CompletionCursor::default())
}
}

View File

@@ -1,5 +1,9 @@
use crate::{ResultCode, Value};
use std::ffi::{c_char, c_void};
use crate::{types::StepResult, ExtResult, ResultCode, Value};
use std::{
ffi::{c_char, c_void, CStr, CString},
num::NonZeroUsize,
rc::Rc,
};
pub type RegisterModuleFn = unsafe extern "C" fn(
ctx: *mut c_void,
@@ -68,7 +72,7 @@ impl VTabModuleImpl {
pub type VtabFnCreate = unsafe extern "C" fn(args: *const Value, argc: i32) -> VTabCreateResult;
pub type VtabFnOpen = unsafe extern "C" fn(table: *const c_void) -> *const c_void;
pub type VtabFnOpen = unsafe extern "C" fn(table: *const c_void, conn: *mut Conn) -> *const c_void;
pub type VtabFnClose = unsafe extern "C" fn(cursor: *const c_void) -> ResultCode;
@@ -125,7 +129,9 @@ pub trait VTable {
type Cursor: VTabCursor<Error = Self::Error>;
type Error: std::fmt::Display;
fn open(&self) -> Result<Self::Cursor, Self::Error>;
/// 'conn' is an Option to allow for testing. Otherwise a valid connection to the core database
/// that created the virtual table will be available to use in your extension here.
fn open(&self, _conn: Option<Rc<Connection>>) -> Result<Self::Cursor, Self::Error>;
fn update(&mut self, _rowid: i64, _args: &[Value]) -> Result<(), Self::Error> {
Ok(())
}
@@ -336,3 +342,268 @@ impl ConstraintInfo {
((self.plan_info >> 1) as usize, (self.plan_info & 1) != 0)
}
}
pub type ConnectFn = unsafe extern "C" fn(ctx: *mut c_void) -> *mut Conn;
pub type PrepareStmtFn = unsafe extern "C" fn(api: *mut Conn, sql: *const c_char) -> *const Stmt;
pub type GetColumnNamesFn =
unsafe extern "C" fn(ctx: *mut Stmt, count: *mut i32) -> *mut *mut c_char;
pub type BindArgsFn =
unsafe extern "C" fn(ctx: *mut Stmt, idx: i32, arg: *const Value) -> ResultCode;
pub type StmtStepFn = unsafe extern "C" fn(ctx: *mut Stmt) -> ResultCode;
pub type StmtGetRowValuesFn = unsafe extern "C" fn(ctx: *mut Stmt);
pub type FreeCurrentRowFn = unsafe extern "C" fn(ctx: *mut Stmt);
pub type CloseConnectionFn = unsafe extern "C" fn(ctx: *mut c_void);
pub type CloseStmtFn = unsafe extern "C" fn(ctx: *mut Stmt);
/// core database connection
/// public fields for core only
#[repr(C)]
#[derive(Debug, Clone)]
pub struct Conn {
// Rc::Weak from core::Connection
pub _ctx: *mut c_void,
pub _prepare_stmt: PrepareStmtFn,
pub _close: CloseConnectionFn,
}
impl Conn {
pub fn new(ctx: *mut c_void, prepare_stmt: PrepareStmtFn, close: CloseConnectionFn) -> Self {
Conn {
_ctx: ctx,
_prepare_stmt: prepare_stmt,
_close: close,
}
}
/// # Safety
pub unsafe fn from_ptr(ptr: *mut Conn) -> crate::ExtResult<&'static mut Self> {
if ptr.is_null() {
return Err(ResultCode::Error);
}
Ok(unsafe { &mut *(ptr) })
}
pub fn close(&self) {
unsafe { (self._close)(self._ctx) };
}
pub fn prepare_stmt(&self, sql: &str) -> *const Stmt {
let Ok(sql) = CString::new(sql) else {
return std::ptr::null();
};
unsafe { (self._prepare_stmt)(self as *const Conn as *mut Conn, sql.as_ptr()) }
}
}
/// Prepared statement for querying a core database connection
/// public API with wrapper methods for extensions
#[derive(Debug)]
#[repr(C)]
pub struct Statement {
_ctx: *const Stmt,
}
/// The Database connection that opened the VTable:
/// Public API to expose methods for extensions
#[derive(Debug)]
#[repr(C)]
pub struct Connection {
_ctx: *mut Conn,
}
impl Connection {
pub fn new(ctx: *mut Conn) -> Self {
Connection { _ctx: ctx }
}
/// From the included SQL string, prepare a statement for execution.
pub fn prepare(self: &Rc<Self>, sql: &str) -> ExtResult<Statement> {
let stmt = unsafe { (*self._ctx).prepare_stmt(sql) };
if stmt.is_null() {
return Err(ResultCode::Error);
}
Ok(Statement { _ctx: stmt })
}
/// Close the connection to the database.
pub fn close(self) {
unsafe { ((*self._ctx)._close)(self._ctx as *mut c_void) };
}
}
impl Statement {
/// Bind a value to a parameter in the prepared statement
///```ignore
/// let stmt = conn.prepare_stmt("select * from users where name = ?");
/// stmt.bind(1, Value::from_text("test".into()));
pub fn bind(&self, idx: NonZeroUsize, arg: &Value) {
let arg = arg as *const Value;
unsafe { (*self._ctx).bind_args(idx, arg) }
}
/// Execute the statement and return the next row
///```ignore
/// while stmt.step() == StepResult::Row {
/// let row = stmt.get_row();
/// println!("row: {:?}", row);
/// }
/// ```
pub fn step(&self) -> StepResult {
unsafe { (*self._ctx).step() }
}
// Get the current row values
///```ignore
/// while stmt.step() == StepResult::Row {
/// let row = stmt.get_row();
/// println!("row: {:?}", row);
///```
pub fn get_row(&mut self) -> &[Value] {
unsafe { (*self._ctx).get_row() }
}
/// Get the result column names for the prepared statement
pub fn get_column_names(&self) -> Vec<String> {
unsafe { (*self._ctx).get_column_names() }
}
/// Close the statement
pub fn close(&self) {
unsafe { (*self._ctx).close() }
}
}
/// Internal/core use _only_
/// Extensions should not import or use this type directly
#[repr(C)]
pub struct Stmt {
// Rc::into_raw from core::Connection
pub _conn: *mut c_void,
// Rc::into_raw from core::Statement
pub _ctx: *mut c_void,
pub _bind_args_fn: BindArgsFn,
pub _step: StmtStepFn,
pub _get_row_values: StmtGetRowValuesFn,
pub _get_column_names: GetColumnNamesFn,
pub _free_current_row: FreeCurrentRowFn,
pub _close: CloseStmtFn,
pub current_row: *mut Value,
pub current_row_len: i32,
}
impl Stmt {
#[allow(clippy::too_many_arguments)]
pub fn new(
conn: *mut c_void,
ctx: *mut c_void,
bind: BindArgsFn,
step: StmtStepFn,
rows: StmtGetRowValuesFn,
names: GetColumnNamesFn,
free_row: FreeCurrentRowFn,
close: CloseStmtFn,
) -> Self {
Stmt {
_conn: conn,
_ctx: ctx,
_bind_args_fn: bind,
_step: step,
_get_row_values: rows,
_get_column_names: names,
_free_current_row: free_row,
_close: close,
current_row: std::ptr::null_mut(),
current_row_len: -1,
}
}
/// Close the statement
pub fn close(&self) {
unsafe { (self._close)(self as *const Stmt as *mut Stmt) };
}
/// # Safety
/// Derefs a null ptr, does a null check first
pub unsafe fn from_ptr(ptr: *mut Stmt) -> ExtResult<&'static mut Self> {
if ptr.is_null() {
return Err(ResultCode::Error);
}
Ok(unsafe { &mut *(ptr) })
}
/// Returns the pointer to the statement.
pub fn to_ptr(&self) -> *const Stmt {
self
}
/// Bind a value to a parameter in the prepared statement
fn bind_args(&self, idx: NonZeroUsize, arg: *const Value) {
unsafe { (self._bind_args_fn)(self.to_ptr() as *mut Stmt, idx.get() as i32, arg) };
}
/// Execute the statement to attempt to retrieve the next result row.
fn step(&self) -> StepResult {
unsafe { (self._step)(self.to_ptr() as *mut Stmt) }.into()
}
/// Free the memory for the values obtained from the `get_row` method.
/// # Safety
/// This fn is unsafe because it derefs a raw pointer after null and
/// length checks. This fn should only be called with the pointer returned from get_row.
pub unsafe fn free_current_row(&mut self) {
if self.current_row.is_null() || self.current_row_len <= 0 {
return;
}
// free from the core side so we don't have to expose `__free_internal_type`
(self._free_current_row)(self.to_ptr() as *mut Stmt);
self.current_row = std::ptr::null_mut();
self.current_row_len = -1;
}
/// Returns the values from the current row in the prepared statement, should
/// be called after the step() method returns `StepResult::Row`
pub fn get_row(&self) -> &[Value] {
unsafe { (self._get_row_values)(self.to_ptr() as *mut Stmt) };
if self.current_row.is_null() || self.current_row_len < 1 {
return &[];
}
let col_count = self.current_row_len;
unsafe { std::slice::from_raw_parts(self.current_row, col_count as usize) }
}
/// Returns the names of the result columns for the prepared statement.
pub fn get_column_names(&self) -> Vec<String> {
let mut count_value: i32 = 0;
let count: *mut i32 = &mut count_value;
let col_names = unsafe { (self._get_column_names)(self.to_ptr() as *mut Stmt, count) };
if col_names.is_null() || count_value == 0 {
return Vec::new();
}
let mut names = Vec::new();
let slice = unsafe { std::slice::from_raw_parts(col_names, count_value as usize) };
for x in slice {
let name = unsafe { CStr::from_ptr(*x) };
names.push(name.to_str().unwrap().to_string());
}
unsafe { free_column_names(col_names, count_value) };
names
}
}
/// Free the column names returned from get_column_names
/// # Safety
/// This function is unsafe because it derefs a raw pointer, this fn
/// should only be called with the pointer returned from get_column_names
/// only when they will no longer be used.
pub unsafe fn free_column_names(names: *mut *mut c_char, count: i32) {
if names.is_null() || count < 1 {
return;
}
let slice = std::slice::from_raw_parts_mut(names, count as usize);
for name in slice {
if !name.is_null() {
let _ = CString::from_raw(*name);
}
}
let _ = Box::from_raw(names);
}

View File

@@ -21,11 +21,12 @@
//! - `columns` — number of columns
//! - `schema` — optional custom SQL `CREATE TABLE` schema
use limbo_ext::{
register_extension, ConstraintInfo, IndexInfo, OrderByInfo, ResultCode, VTabCursor, VTabKind,
VTabModule, VTabModuleDerive, VTable, Value,
register_extension, Connection, ConstraintInfo, IndexInfo, OrderByInfo, ResultCode, VTabCursor,
VTabKind, VTabModule, VTabModuleDerive, VTable, Value,
};
use std::fs::File;
use std::io::{Read, Seek, SeekFrom};
use std::rc::Rc;
register_extension! {
vtabs: { CsvVTabModule }
@@ -259,7 +260,7 @@ impl VTable for CsvTable {
type Cursor = CsvCursor;
type Error = ResultCode;
fn open(&self) -> Result<Self::Cursor, Self::Error> {
fn open(&self, _conn: Option<Rc<Connection>>) -> Result<Self::Cursor, Self::Error> {
match self.new_reader() {
Ok(reader) => Ok(CsvCursor::new(reader, self)),
Err(_) => Err(ResultCode::Error),

View File

@@ -1,6 +1,8 @@
use std::rc::Rc;
use limbo_ext::{
register_extension, ResultCode, VTabCursor, VTabKind, VTabModule, VTabModuleDerive, VTable,
Value,
register_extension, Connection, ResultCode, VTabCursor, VTabKind, VTabModule, VTabModuleDerive,
VTable, Value,
};
register_extension! {
@@ -43,7 +45,7 @@ impl VTable for GenerateSeriesTable {
type Cursor = GenerateSeriesCursor;
type Error = ResultCode;
fn open(&self) -> Result<Self::Cursor, Self::Error> {
fn open(&self, _conn: Option<Rc<Connection>>) -> Result<Self::Cursor, Self::Error> {
Ok(GenerateSeriesCursor {
start: 0,
stop: 0,
@@ -225,7 +227,7 @@ mod tests {
// Helper function to collect all values from a cursor, returns Result with error code
fn collect_series(series: Series) -> Result<Vec<i64>, ResultCode> {
let tbl = GenerateSeriesTable {};
let mut cursor = tbl.open()?;
let mut cursor = tbl.open(None)?;
// Create args array for filter
let args = vec![
@@ -542,7 +544,7 @@ mod tests {
let stop = series.stop;
let step = series.step;
let tbl = GenerateSeriesTable {};
let mut cursor = tbl.open().unwrap();
let mut cursor = tbl.open(None).unwrap();
let args = vec![
Value::from_integer(start),

View File

@@ -1,14 +1,15 @@
use lazy_static::lazy_static;
use limbo_ext::{
register_extension, scalar, ConstraintInfo, ConstraintOp, ConstraintUsage, ExtResult,
IndexInfo, OrderByInfo, ResultCode, VTabCursor, VTabKind, VTabModule, VTabModuleDerive, VTable,
Value,
register_extension, scalar, Connection, ConstraintInfo, ConstraintOp, ConstraintUsage,
ExtResult, IndexInfo, OrderByInfo, ResultCode, VTabCursor, VTabKind, VTabModule,
VTabModuleDerive, VTable, Value,
};
#[cfg(not(target_family = "wasm"))]
use limbo_ext::{VfsDerive, VfsExtension, VfsFile};
use std::collections::BTreeMap;
use std::fs::{File, OpenOptions};
use std::io::{Read, Seek, SeekFrom, Write};
use std::rc::Rc;
use std::sync::Mutex;
register_extension! {
@@ -137,7 +138,7 @@ impl VTable for KVStoreTable {
type Cursor = KVStoreCursor;
type Error = String;
fn open(&self) -> Result<Self::Cursor, Self::Error> {
fn open(&self, _conn: Option<Rc<Connection>>) -> Result<Self::Cursor, Self::Error> {
let _ = env_logger::try_init();
Ok(KVStoreCursor {
rows: Vec::new(),

View File

@@ -49,13 +49,14 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream {
}
#[no_mangle]
unsafe extern "C" fn #open_fn_name(table: *const ::std::ffi::c_void) -> *const ::std::ffi::c_void {
unsafe extern "C" fn #open_fn_name(table: *const ::std::ffi::c_void, conn: *mut ::limbo_ext::Conn) -> *const ::std::ffi::c_void {
if table.is_null() {
return ::std::ptr::null();
}
let table = table as *const <#struct_name as ::limbo_ext::VTabModule>::Table;
let table: &<#struct_name as ::limbo_ext::VTabModule>::Table = &*table;
if let Ok(cursor) = <#struct_name as ::limbo_ext::VTabModule>::Table::open(table) {
let conn = if conn.is_null() { None } else { Some(::std::rc::Rc::new(::limbo_ext::Connection::new(conn)))};
if let Ok(cursor) = <#struct_name as ::limbo_ext::VTabModule>::Table::open(table, conn) {
return ::std::boxed::Box::into_raw(::std::boxed::Box::new(cursor)) as *const ::std::ffi::c_void;
} else {
return ::std::ptr::null();