From d51614a4fde873bdca418294fa0db91b5d2fbba7 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Fri, 23 May 2025 19:59:47 -0400 Subject: [PATCH] Create extern functions to support vtab xConnect in core/ext --- core/ext/vtab_xconnect.rs | 179 +++++++++++++++++++++++++++++++++++ extensions/core/src/types.rs | 31 ++++++ 2 files changed, 210 insertions(+) create mode 100644 core/ext/vtab_xconnect.rs diff --git a/core/ext/vtab_xconnect.rs b/core/ext/vtab_xconnect.rs new file mode 100644 index 000000000..26dfa3662 --- /dev/null +++ b/core/ext/vtab_xconnect.rs @@ -0,0 +1,179 @@ +use crate::{types::Value, Connection, Statement, StepResult}; +use limbo_ext::{Conn as ExtConn, ResultCode, Stmt, Value as ExtValue}; +use std::{ + boxed::Box, + ffi::{c_char, c_void, CStr, CString}, + num::NonZeroUsize, + ptr, + rc::Weak, +}; + +pub unsafe extern "C" fn close(ctx: *mut c_void) { + if ctx.is_null() { + return; + } + let weak_box: Box> = Box::from_raw(ctx as *mut Weak); + if let Some(conn) = weak_box.upgrade() { + let _ = conn.close(); + } +} + +pub unsafe extern "C" fn prepare_stmt(ctx: *mut ExtConn, sql: *const c_char) -> *const Stmt { + let c_str = unsafe { CStr::from_ptr(sql as *mut c_char) }; + let sql_str = match c_str.to_str() { + Ok(s) => s.to_string(), + Err(_) => return ptr::null_mut(), + }; + if ctx.is_null() { + return ptr::null_mut(); + } + let Ok(extcon) = ExtConn::from_ptr(ctx) else { + return ptr::null_mut(); + }; + let weak_ptr = extcon._ctx as *const Weak; + let weak = &*weak_ptr; + let Some(conn) = weak.upgrade() else { + return ptr::null_mut(); + }; + match conn.prepare(&sql_str) { + Ok(stmt) => { + let raw_stmt = Box::into_raw(Box::new(stmt)) as *mut c_void; + Box::into_raw(Box::new(Stmt::new( + extcon._ctx, + raw_stmt, + stmt_bind_args_fn, + stmt_step, + stmt_get_row, + stmt_get_column_names, + stmt_free_current_row, + stmt_close, + ))) as *const Stmt + } + Err(_) => ptr::null_mut(), + } +} + +pub unsafe extern "C" fn stmt_bind_args_fn( + ctx: *mut Stmt, + idx: i32, + arg: *const ExtValue, +) -> ResultCode { + let Ok(stmt) = Stmt::from_ptr(ctx) else { + return ResultCode::Error; + }; + let stmt_ctx: &mut Statement = unsafe { &mut *(stmt._ctx as *mut Statement) }; + let Ok(owned_val) = Value::from_ffi_ptr(arg) else { + tracing::error!("stmt_bind_args_fn: failed to convert arg to Value"); + return ResultCode::Error; + }; + let Some(idx) = NonZeroUsize::new(idx as usize) else { + tracing::error!("stmt_bind_args_fn: invalid index"); + return ResultCode::Error; + }; + stmt_ctx.bind_at(idx, owned_val); + ResultCode::OK +} + +pub unsafe extern "C" fn stmt_step(stmt: *mut Stmt) -> ResultCode { + let Ok(stmt) = Stmt::from_ptr(stmt) else { + tracing::error!("stmt_step: failed to convert stmt to Stmt"); + return ResultCode::Error; + }; + if stmt._conn.is_null() || stmt._ctx.is_null() { + tracing::error!("stmt_step: null connection or context"); + return ResultCode::Error; + } + let conn: &Connection = unsafe { &*(stmt._conn as *const Connection) }; + let stmt_ctx: &mut Statement = unsafe { &mut *(stmt._ctx as *mut Statement) }; + while let Ok(res) = stmt_ctx.step() { + match res { + StepResult::Row => return ResultCode::Row, + StepResult::Done => return ResultCode::EOF, + StepResult::IO => { + // always handle IO step result internally. + let _ = conn.pager.io.run_once(); + continue; + } + StepResult::Interrupt => return ResultCode::Interrupt, + StepResult::Busy => return ResultCode::Busy, + } + } + ResultCode::Error +} + +pub unsafe extern "C" fn stmt_get_row(ctx: *mut Stmt) { + let Ok(stmt) = Stmt::from_ptr(ctx) else { + return; + }; + if !stmt.current_row.is_null() { + stmt.free_current_row(); + } + let stmt_ctx: &mut Statement = unsafe { &mut *(stmt._ctx as *mut Statement) }; + if let Some(row) = stmt_ctx.row() { + let values = row.get_values(); + let mut owned_values = Vec::with_capacity(row.len()); + for value in values { + owned_values.push(Value::to_ffi(value)); + } + stmt.current_row = Box::into_raw(owned_values.into_boxed_slice()) as *mut ExtValue; + stmt.current_row_len = row.len() as i32; + } else { + stmt.current_row_len = 0; + } +} + +pub unsafe extern "C" fn stmt_free_current_row(ctx: *mut Stmt) { + let Ok(stmt) = Stmt::from_ptr(ctx) else { + return; + }; + if !stmt.current_row.is_null() { + let values: &mut [ExtValue] = + std::slice::from_raw_parts_mut(stmt.current_row, stmt.current_row_len as usize); + for value in values.iter_mut() { + let owned_value = std::mem::take(value); + owned_value.__free_internal_type(); + } + let _ = Box::from_raw(stmt.current_row); + } +} + +pub unsafe extern "C" fn stmt_get_column_names( + ctx: *mut Stmt, + count: *mut i32, +) -> *mut *mut c_char { + let Ok(stmt) = Stmt::from_ptr(ctx) else { + *count = 0; + return ptr::null_mut(); + }; + let stmt_ctx: &mut Statement = unsafe { &mut *(stmt._ctx as *mut Statement) }; + let num_cols = stmt_ctx.num_columns(); + if num_cols == 0 { + *count = 0; + return ptr::null_mut(); + } + let mut c_names: Vec<*mut c_char> = Vec::with_capacity(num_cols); + for i in 0..num_cols { + let name = stmt_ctx.get_column_name(i); + let c_str = CString::new(name.as_bytes()).unwrap(); + c_names.push(c_str.into_raw()); + } + + *count = c_names.len() as i32; + let names_array = c_names.into_boxed_slice(); + Box::into_raw(names_array) as *mut *mut c_char +} + +pub unsafe extern "C" fn stmt_close(ctx: *mut Stmt) { + let Ok(stmt) = Stmt::from_ptr(ctx) else { + return; + }; + if !stmt.current_row.is_null() { + stmt.free_current_row(); + } + // take ownership of internal statement + let wrapper = Box::from_raw(stmt as *mut Stmt); + if !wrapper._ctx.is_null() { + let mut _stmt: Box = Box::from_raw(wrapper._ctx as *mut Statement); + _stmt.reset() + } +} diff --git a/extensions/core/src/types.rs b/extensions/core/src/types.rs index 90adb3863..5a4a94480 100644 --- a/extensions/core/src/types.rs +++ b/extensions/core/src/types.rs @@ -23,6 +23,9 @@ pub enum ResultCode { EOF = 15, ReadOnly = 16, RowID = 17, + Row = 18, + Interrupt = 19, + Busy = 20, } impl ResultCode { @@ -60,6 +63,34 @@ impl Display for ResultCode { ResultCode::EOF => write!(f, "EOF"), ResultCode::ReadOnly => write!(f, "Read Only"), ResultCode::RowID => write!(f, "RowID"), + ResultCode::Row => write!(f, "Row"), + ResultCode::Interrupt => write!(f, "Interrupt"), + ResultCode::Busy => write!(f, "Busy"), + } + } +} +#[repr(C)] +#[derive(PartialEq, Debug, Eq, Clone, Copy)] +/// StepResult is used to represent the state of a query as it is exposed +/// to the public API of a connection in a virtual table extension. +/// the IO variant is always handled internally and therefore is not included here. +pub enum StepResult { + Error, + Row, + Done, + Interrupt, + Busy, +} + +impl From for StepResult { + fn from(code: ResultCode) -> Self { + match code { + ResultCode::Error => StepResult::Error, + ResultCode::Row => StepResult::Row, + ResultCode::EOF => StepResult::Done, + ResultCode::Interrupt => StepResult::Interrupt, + ResultCode::Busy => StepResult::Busy, + _ => StepResult::Error, } } }