diff --git a/bindings/go/cmd/main.go b/bindings/go/cmd/main.go deleted file mode 100644 index 32fcbea23..000000000 --- a/bindings/go/cmd/main.go +++ /dev/null @@ -1,18 +0,0 @@ -// package main -// -// import ( -// "fmt" -// ) -// -// func main() { -// conn, err := lc.Open("new.db") -// if err != nil { -// panic(err) -// } -// fmt.Println("Connected to database") -// sql := "select c from t;" -// conn.Query(sql) -// -// conn.Close() -// fmt.Println("Connection closed") -// } diff --git a/bindings/go/go.mod b/bindings/go/go.mod index fa1d99d3e..c108e721d 100644 --- a/bindings/go/go.mod +++ b/bindings/go/go.mod @@ -2,4 +2,7 @@ module turso go 1.23.4 -require github.com/ebitengine/purego v0.8.2 // indirect +require ( + github.com/ebitengine/purego v0.8.2 + golang.org/x/sys/windows v0.29.0 +) diff --git a/bindings/go/go.sum b/bindings/go/go.sum index 38eca3dfd..16a0ba53f 100644 --- a/bindings/go/go.sum +++ b/bindings/go/go.sum @@ -1,2 +1,4 @@ github.com/ebitengine/purego v0.8.2 h1:jPPGWs2sZ1UgOSgD2bClL0MJIqu58nOmIcBuXr62z1I= github.com/ebitengine/purego v0.8.2/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= +golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= +golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= diff --git a/bindings/go/rs_src/lib.rs b/bindings/go/rs_src/lib.rs index 862c8c191..36b5a6db1 100644 --- a/bindings/go/rs_src/lib.rs +++ b/bindings/go/rs_src/lib.rs @@ -1,6 +1,8 @@ +mod rows; +#[allow(dead_code)] mod statement; mod types; -use limbo_core::{Connection, Database, LimboError, Value}; +use limbo_core::{Connection, Database, LimboError}; use std::{ ffi::{c_char, c_void}, rc::Rc, @@ -8,6 +10,9 @@ use std::{ sync::Arc, }; +/// # Safety +/// Safe to be called from Go with null terminated DSN string. +/// performs null check on the path. #[no_mangle] pub unsafe extern "C" fn db_open(path: *const c_char) -> *mut c_void { if path.is_null() { @@ -34,27 +39,22 @@ pub unsafe extern "C" fn db_open(path: *const c_char) -> *mut c_void { std::ptr::null_mut() } -struct TursoConn<'a> { +#[allow(dead_code)] +struct TursoConn { conn: Rc, io: Arc, - cursor_idx: usize, - cursor: Option>>, } -impl<'a> TursoConn<'_> { +impl TursoConn { fn new(conn: Rc, io: Arc) -> Self { - TursoConn { - conn, - io, - cursor_idx: 0, - cursor: None, - } + TursoConn { conn, io } } + #[allow(clippy::wrong_self_convention)] fn to_ptr(self) -> *mut c_void { Box::into_raw(Box::new(self)) as *mut c_void } - fn from_ptr(ptr: *mut c_void) -> &'static mut TursoConn<'a> { + fn from_ptr(ptr: *mut c_void) -> &'static mut TursoConn { if ptr.is_null() { panic!("Null pointer"); } @@ -68,7 +68,7 @@ impl<'a> TursoConn<'_> { #[no_mangle] pub unsafe extern "C" fn db_close(db: *mut c_void) { if !db.is_null() { - let _ = unsafe { Box::from_raw(db) }; + let _ = unsafe { Box::from_raw(db as *mut TursoConn) }; } } @@ -77,19 +77,12 @@ fn get_io(db_location: &DbType) -> Result, LimboError> { Ok(match db_location { DbType::Memory => Arc::new(limbo_core::MemoryIO::new()?), _ => { - #[cfg(target_family = "unix")] - if cfg!(all(target_os = "linux", feature = "io_uring")) { - Arc::new(limbo_core::UringIO::new()?) - } else { - Arc::new(limbo_core::UnixIO::new()?) - } - - #[cfg(target_family = "windows")] - Arc::new(limbo_core::WindowsIO::new()?); + return Ok(Arc::new(limbo_core::PlatformIO::new()?)); } }) } +#[allow(dead_code)] struct DbOptions { path: DbType, params: Parameters, diff --git a/bindings/go/rs_src/rows.rs b/bindings/go/rs_src/rows.rs new file mode 100644 index 000000000..c505924bc --- /dev/null +++ b/bindings/go/rs_src/rows.rs @@ -0,0 +1,138 @@ +use crate::{ + statement::TursoStatement, + types::{ResultCode, TursoValue}, +}; +use limbo_core::{Rows, StepResult, Value}; +use std::ffi::{c_char, c_void}; + +pub struct TursoRows<'a> { + rows: Rows, + cursor: Option>>, + stmt: Box>, +} + +impl<'a> TursoRows<'a> { + pub fn new(rows: Rows, stmt: Box>) -> Self { + TursoRows { + rows, + stmt, + cursor: None, + } + } + + #[allow(clippy::wrong_self_convention)] + pub fn to_ptr(self) -> *mut c_void { + Box::into_raw(Box::new(self)) as *mut c_void + } + + pub fn from_ptr(ptr: *mut c_void) -> &'static mut TursoRows<'a> { + if ptr.is_null() { + panic!("Null pointer"); + } + unsafe { &mut *(ptr as *mut TursoRows) } + } +} + +#[no_mangle] +pub extern "C" fn rows_next(ctx: *mut c_void) -> ResultCode { + if ctx.is_null() { + return ResultCode::Error; + } + let ctx = TursoRows::from_ptr(ctx); + + match ctx.rows.next_row() { + Ok(StepResult::Row(row)) => { + ctx.cursor = Some(row.values); + ResultCode::Row + } + Ok(StepResult::Done) => ResultCode::Done, + Ok(StepResult::IO) => { + let _ = ctx.stmt.conn.io.run_once(); + ResultCode::Io + } + Ok(StepResult::Busy) => ResultCode::Busy, + Ok(StepResult::Interrupt) => ResultCode::Interrupt, + Err(_) => ResultCode::Error, + } +} + +#[no_mangle] +pub extern "C" fn rows_get_value(ctx: *mut c_void, col_idx: usize) -> *const c_void { + if ctx.is_null() { + return std::ptr::null(); + } + let ctx = TursoRows::from_ptr(ctx); + + if let Some(ref cursor) = ctx.cursor { + if let Some(value) = cursor.get(col_idx) { + let val = TursoValue::from_value(value); + return val.to_ptr(); + } + } + std::ptr::null() +} + +#[no_mangle] +pub extern "C" fn free_string(s: *mut c_char) { + if !s.is_null() { + unsafe { drop(std::ffi::CString::from_raw(s)) }; + } +} + +#[no_mangle] +pub extern "C" fn rows_get_columns( + rows_ptr: *mut c_void, + out_length: *mut usize, +) -> *mut *const c_char { + if rows_ptr.is_null() || out_length.is_null() { + return std::ptr::null_mut(); + } + let rows = TursoRows::from_ptr(rows_ptr); + let c_strings: Vec = rows + .rows + .columns() + .iter() + .map(|name| std::ffi::CString::new(name.as_str()).unwrap()) + .collect(); + + let c_ptrs: Vec<*const c_char> = c_strings.iter().map(|s| s.as_ptr()).collect(); + unsafe { + *out_length = c_ptrs.len(); + } + let ptr = c_ptrs.as_ptr(); + std::mem::forget(c_strings); + std::mem::forget(c_ptrs); + ptr as *mut *const c_char +} + +#[no_mangle] +pub extern "C" fn rows_close(rows_ptr: *mut c_void) { + if !rows_ptr.is_null() { + let _ = unsafe { Box::from_raw(rows_ptr as *mut TursoRows) }; + } +} + +#[no_mangle] +pub extern "C" fn free_columns(columns: *mut *const c_char) { + if columns.is_null() { + return; + } + unsafe { + let mut idx = 0; + while !(*columns.add(idx)).is_null() { + let _ = std::ffi::CString::from_raw(*columns.add(idx) as *mut c_char); + idx += 1; + } + let _ = Box::from_raw(columns); + } +} + +#[no_mangle] +pub extern "C" fn free_rows(rows: *mut c_void) { + if rows.is_null() { + return; + } + unsafe { + let _ = Box::from_raw(rows as *mut Rows); + } +} diff --git a/bindings/go/rs_src/statement.rs b/bindings/go/rs_src/statement.rs index 99a45d692..4a4e29e34 100644 --- a/bindings/go/rs_src/statement.rs +++ b/bindings/go/rs_src/statement.rs @@ -1,7 +1,9 @@ -use crate::types::ResultCode; +use crate::rows::TursoRows; +use crate::types::{AllocPool, ResultCode, TursoValue}; use crate::TursoConn; -use limbo_core::{Rows, Statement, StepResult, Value}; +use limbo_core::{Statement, StepResult}; use std::ffi::{c_char, c_void}; +use std::num::NonZero; #[no_mangle] pub extern "C" fn db_prepare(ctx: *mut c_void, query: *const c_char) -> *mut c_void { @@ -19,142 +21,119 @@ pub extern "C" fn db_prepare(ctx: *mut c_void, query: *const c_char) -> *mut c_v } } -struct TursoStatement<'a> { - statement: Statement, - conn: &'a TursoConn<'a>, +#[no_mangle] +pub extern "C" fn stmt_execute( + ctx: *mut c_void, + args_ptr: *mut TursoValue, + arg_count: usize, + changes: *mut i64, +) -> ResultCode { + if ctx.is_null() { + return ResultCode::Error; + } + let stmt = TursoStatement::from_ptr(ctx); + + let args = if !args_ptr.is_null() && arg_count > 0 { + unsafe { std::slice::from_raw_parts(args_ptr, arg_count) } + } else { + &[] + }; + for (i, arg) in args.iter().enumerate() { + let val = arg.to_value(&mut stmt.pool); + stmt.statement.bind_at(NonZero::new(i + 1).unwrap(), val); + } + loop { + match stmt.statement.step() { + Ok(StepResult::Row(_)) => { + // unexpected row during execution, error out. + return ResultCode::Error; + } + Ok(StepResult::Done) => { + stmt.conn.conn.total_changes(); + if !changes.is_null() { + unsafe { + *changes = stmt.conn.conn.total_changes(); + } + } + return ResultCode::Done; + } + Ok(StepResult::IO) => { + let _ = stmt.conn.io.run_once(); + } + Ok(StepResult::Busy) => { + return ResultCode::Busy; + } + Ok(StepResult::Interrupt) => { + return ResultCode::Interrupt; + } + Err(_) => { + return ResultCode::Error; + } + } + } } -impl<'a> TursoStatement<'a> { - fn new(statement: Statement, conn: &'a TursoConn<'a>) -> Self { - TursoStatement { statement, conn } +#[no_mangle] +pub extern "C" fn stmt_parameter_count(ctx: *mut c_void) -> i32 { + if ctx.is_null() { + return -1; } + let stmt = TursoStatement::from_ptr(ctx); + stmt.statement.parameters_count() as i32 +} + +#[no_mangle] +pub extern "C" fn stmt_query( + ctx: *mut c_void, + args_ptr: *mut TursoValue, + args_count: usize, +) -> *mut c_void { + if ctx.is_null() { + return std::ptr::null_mut(); + } + let stmt = TursoStatement::from_ptr(ctx); + let args = if !args_ptr.is_null() && args_count > 0 { + unsafe { std::slice::from_raw_parts(args_ptr, args_count) } + } else { + &[] + }; + for (i, arg) in args.iter().enumerate() { + let val = arg.to_value(&mut stmt.pool); + stmt.statement.bind_at(NonZero::new(i + 1).unwrap(), val); + } + match stmt.statement.query() { + Ok(rows) => { + let stmt = unsafe { Box::from_raw(stmt) }; + TursoRows::new(rows, stmt).to_ptr() + } + Err(_) => std::ptr::null_mut(), + } +} + +pub struct TursoStatement<'conn> { + pub statement: Statement, + pub conn: &'conn mut TursoConn, + pub pool: AllocPool, +} + +impl<'conn> TursoStatement<'conn> { + pub fn new(statement: Statement, conn: &'conn mut TursoConn) -> Self { + TursoStatement { + statement, + conn, + pool: AllocPool::new(), + } + } + + #[allow(clippy::wrong_self_convention)] fn to_ptr(self) -> *mut c_void { Box::into_raw(Box::new(self)) as *mut c_void } - fn from_ptr(ptr: *mut c_void) -> &'static mut TursoStatement<'a> { + + fn from_ptr(ptr: *mut c_void) -> &'static mut TursoStatement<'conn> { if ptr.is_null() { panic!("Null pointer"); } unsafe { &mut *(ptr as *mut TursoStatement) } } } - -#[no_mangle] -pub extern "C" fn db_get_columns(ctx: *mut c_void) -> *const c_void { - if ctx.is_null() { - return std::ptr::null(); - } - let stmt = TursoStatement::from_ptr(ctx); - let columns = stmt.statement.columns(); - let mut column_names = Vec::new(); - for column in columns { - column_names.push(column.name().to_string()); - } - let c_string = std::ffi::CString::new(column_names.join(",")).unwrap(); - c_string.into_raw() as *const c_void -} - -struct TursoRows<'a> { - rows: Rows<'a>, - conn: &'a mut TursoConn<'a>, -} - -impl<'a> TursoRows<'a> { - fn new(rows: Rows<'a>, conn: &'a mut TursoConn<'a>) -> Self { - TursoRows { rows, conn } - } - - fn to_ptr(self) -> *mut c_void { - Box::into_raw(Box::new(self)) as *mut c_void - } - - fn from_ptr(ptr: *mut c_void) -> &'static mut TursoRows<'a> { - if ptr.is_null() { - panic!("Null pointer"); - } - unsafe { &mut *(ptr as *mut TursoRows) } - } -} - -#[no_mangle] -pub extern "C" fn rows_next(ctx: *mut c_void, rows_ptr: *mut c_void) -> ResultCode { - if rows_ptr.is_null() || ctx.is_null() { - return ResultCode::Error; - } - let rows = unsafe { &mut *(rows_ptr as *mut Rows) }; - let conn = TursoConn::from_ptr(ctx); - - match rows.next_row() { - Ok(StepResult::Row(row)) => { - conn.cursor = Some(row.values); - ResultCode::Row - } - Ok(StepResult::Done) => { - // No more rows - ResultCode::Done - } - Ok(StepResult::IO) => { - let _ = conn.io.run_once(); - ResultCode::Io - } - Ok(StepResult::Busy) => ResultCode::Busy, - Ok(StepResult::Interrupt) => ResultCode::Interrupt, - Err(_) => ResultCode::Error, - } -} - -#[no_mangle] -pub extern "C" fn rows_get_value(ctx: *mut c_void, col_idx: usize) -> *const c_char { - if ctx.is_null() { - return std::ptr::null(); - } - let conn = TursoConn::from_ptr(ctx); - - if let Some(ref cursor) = conn.cursor { - if let Some(value) = cursor.get(col_idx) { - let c_string = std::ffi::CString::new(value.to_string()).unwrap(); - return c_string.into_raw(); // Caller must free this pointer - } - } - std::ptr::null() // No data or invalid index -} - -// Free the returned string -#[no_mangle] -pub extern "C" fn free_c_string(s: *mut c_char) { - if !s.is_null() { - unsafe { drop(std::ffi::CString::from_raw(s)) }; - } -} -#[no_mangle] -pub extern "C" fn rows_get_string( - ctx: *mut c_void, - rows_ptr: *mut c_void, - col_idx: i32, -) -> *const c_char { - if rows_ptr.is_null() || ctx.is_null() { - return std::ptr::null(); - } - let _rows = unsafe { &mut *(rows_ptr as *mut Rows) }; - let conn = TursoConn::from_ptr(ctx); - if col_idx > conn.cursor_idx as i32 || conn.cursor.is_none() { - return std::ptr::null(); - } - if let Some(values) = &conn.cursor { - let value = &values[col_idx as usize]; - match value { - Value::Text(s) => { - return s.as_ptr() as *const i8; - } - _ => return std::ptr::null(), - } - }; - std::ptr::null() -} - -#[no_mangle] -pub extern "C" fn rows_close(rows_ptr: *mut c_void) { - if !rows_ptr.is_null() { - let _ = unsafe { Box::from_raw(rows_ptr as *mut Rows) }; - } -} diff --git a/bindings/go/rs_src/types.rs b/bindings/go/rs_src/types.rs index 711887229..b8fc3ac75 100644 --- a/bindings/go/rs_src/types.rs +++ b/bindings/go/rs_src/types.rs @@ -1,14 +1,190 @@ +use std::ffi::{c_char, c_void}; +#[allow(dead_code)] #[repr(C)] pub enum ResultCode { Error = -1, Ok = 0, Row = 1, Busy = 2, - Done = 3, - Io = 4, - Interrupt = 5, - Invalid = 6, - Null = 7, - NoMem = 8, - ReadOnly = 9, + Io = 3, + Interrupt = 4, + Invalid = 5, + Null = 6, + NoMem = 7, + ReadOnly = 8, + NoData = 9, + Done = 10, +} + +#[repr(C)] +pub enum ValueType { + Integer = 0, + Text = 1, + Blob = 2, + Real = 3, + Null = 4, +} + +#[repr(C)] +pub struct TursoValue { + pub value_type: ValueType, + pub value: ValueUnion, +} + +#[repr(C)] +pub union ValueUnion { + pub int_val: i64, + pub real_val: f64, + pub text_ptr: *const c_char, + pub blob_ptr: *const c_void, +} + +#[repr(C)] +pub struct Blob { + pub data: *const u8, + pub len: usize, +} + +impl Blob { + pub fn to_ptr(&self) -> *const c_void { + self as *const Blob as *const c_void + } +} + +pub struct AllocPool { + strings: Vec, + blobs: Vec>, +} +impl AllocPool { + pub fn new() -> Self { + AllocPool { + strings: Vec::new(), + blobs: Vec::new(), + } + } + pub fn add_string(&mut self, s: String) -> &String { + self.strings.push(s); + self.strings.last().unwrap() + } + + pub fn add_blob(&mut self, b: Vec) -> &Vec { + self.blobs.push(b); + self.blobs.last().unwrap() + } +} + +#[no_mangle] +pub extern "C" fn free_blob(blob_ptr: *mut c_void) { + if blob_ptr.is_null() { + return; + } + unsafe { + let _ = Box::from_raw(blob_ptr as *mut Blob); + } +} +#[allow(dead_code)] +impl ValueUnion { + fn from_str(s: &str) -> Self { + ValueUnion { + text_ptr: s.as_ptr() as *const c_char, + } + } + + fn from_bytes(b: &[u8]) -> Self { + ValueUnion { + blob_ptr: Blob { + data: b.as_ptr(), + len: b.len(), + } + .to_ptr(), + } + } + + fn from_int(i: i64) -> Self { + ValueUnion { int_val: i } + } + + fn from_real(r: f64) -> Self { + ValueUnion { real_val: r } + } + + fn from_null() -> Self { + ValueUnion { int_val: 0 } + } + + pub fn to_int(&self) -> i64 { + unsafe { self.int_val } + } + + pub fn to_real(&self) -> f64 { + unsafe { self.real_val } + } + + pub fn to_str(&self) -> &str { + unsafe { std::ffi::CStr::from_ptr(self.text_ptr).to_str().unwrap() } + } + + pub fn to_bytes(&self) -> &[u8] { + let blob = unsafe { self.blob_ptr as *const Blob }; + let blob = unsafe { &*blob }; + unsafe { std::slice::from_raw_parts(blob.data, blob.len) } + } +} + +impl TursoValue { + pub fn new(value_type: ValueType, value: ValueUnion) -> Self { + TursoValue { value_type, value } + } + + #[allow(clippy::wrong_self_convention)] + pub fn to_ptr(self) -> *const c_void { + Box::into_raw(Box::new(self)) as *const c_void + } + + pub fn from_value(value: &limbo_core::Value<'_>) -> Self { + match value { + limbo_core::Value::Integer(i) => { + TursoValue::new(ValueType::Integer, ValueUnion::from_int(*i)) + } + limbo_core::Value::Float(r) => { + TursoValue::new(ValueType::Real, ValueUnion::from_real(*r)) + } + limbo_core::Value::Text(s) => TursoValue::new(ValueType::Text, ValueUnion::from_str(s)), + limbo_core::Value::Blob(b) => { + TursoValue::new(ValueType::Blob, ValueUnion::from_bytes(b)) + } + limbo_core::Value::Null => TursoValue::new(ValueType::Null, ValueUnion::from_null()), + } + } + + pub fn to_value<'pool>(&self, pool: &'pool mut AllocPool) -> limbo_core::Value<'pool> { + match self.value_type { + ValueType::Integer => limbo_core::Value::Integer(unsafe { self.value.int_val }), + ValueType::Real => limbo_core::Value::Float(unsafe { self.value.real_val }), + ValueType::Text => { + let cstr = unsafe { std::ffi::CStr::from_ptr(self.value.text_ptr) }; + match cstr.to_str() { + Ok(utf8_str) => { + let owned = utf8_str.to_owned(); + // statement needs to own these strings, will free when closed + let borrowed = pool.add_string(owned); + limbo_core::Value::Text(borrowed) + } + Err(_) => limbo_core::Value::Null, + } + } + ValueType::Blob => { + let blob_ptr = unsafe { self.value.blob_ptr as *const Blob }; + if blob_ptr.is_null() { + limbo_core::Value::Null + } else { + let blob = unsafe { &*blob_ptr }; + let data = unsafe { std::slice::from_raw_parts(blob.data, blob.len) }; + let borrowed = pool.add_blob(data.to_vec()); + limbo_core::Value::Blob(borrowed) + } + } + ValueType::Null => limbo_core::Value::Null, + } + } } diff --git a/bindings/go/stmt.go b/bindings/go/stmt.go index 20e1a5774..2b7895fe2 100644 --- a/bindings/go/stmt.go +++ b/bindings/go/stmt.go @@ -1,77 +1,192 @@ package turso import ( + "context" "database/sql/driver" + "errors" "fmt" "io" + "unsafe" ) -type stmt struct { - ctx uintptr - sql string +// only construct tursoStmt with initStmt function to ensure proper initialization +type tursoStmt struct { + ctx uintptr + sql string + query stmtQueryFn + execute stmtExecuteFn + getParamCount func(uintptr) int32 } -type rows struct { - ctx uintptr - rowsPtr uintptr - columns []string - err error +// Initialize/register the FFI function pointers for the statement methods +func initStmt(ctx uintptr, sql string) *tursoStmt { + var query stmtQueryFn + var execute stmtExecuteFn + var getParamCount func(uintptr) int32 + methods := []ExtFunc{{query, FfiStmtQuery}, {execute, FfiStmtExec}, {getParamCount, FfiStmtParameterCount}} + for i := range methods { + methods[i].initFunc() + } + return &tursoStmt{ + ctx: uintptr(ctx), + sql: sql, + } } -func (ls *stmt) Query(args []driver.Value) (driver.Rows, error) { - var dbPrepare func(uintptr, uintptr) uintptr - getExtFunc(&dbPrepare, "db_prepare") +func (st *tursoStmt) NumInput() int { + return int(st.getParamCount(st.ctx)) +} - queryPtr := toCString(ls.sql) - defer freeCString(queryPtr) +func (st *tursoStmt) Exec(args []driver.Value) (driver.Result, error) { + argArray, err := buildArgs(args) + if err != nil { + return nil, err + } + argPtr := uintptr(0) + argCount := uint64(len(argArray)) + if argCount > 0 { + argPtr = uintptr(unsafe.Pointer(&argArray[0])) + } + var changes uint64 + rc := st.execute(st.ctx, argPtr, argCount, uintptr(unsafe.Pointer(&changes))) + switch ResultCode(rc) { + case Ok: + return driver.RowsAffected(changes), nil + case Error: + return nil, errors.New("error executing statement") + case Busy: + return nil, errors.New("busy") + case Interrupt: + return nil, errors.New("interrupted") + case Invalid: + return nil, errors.New("invalid statement") + default: + return nil, fmt.Errorf("unexpected status: %d", rc) + } +} - rowsPtr := dbPrepare(ls.ctx, queryPtr) +func (st *tursoStmt) Query(args []driver.Value) (driver.Rows, error) { + queryArgs, err := buildArgs(args) + if err != nil { + return nil, err + } + rowsPtr := st.query(st.ctx, uintptr(unsafe.Pointer(&queryArgs[0])), uint64(len(queryArgs))) if rowsPtr == 0 { - return nil, fmt.Errorf("failed to prepare query") + return nil, fmt.Errorf("query failed for: %q", st.sql) } - var colFunc func(uintptr, uintptr) uintptr - - getExtFunc(&colFunc, "columns") - - rows := &rows{ - ctx: ls.ctx, - rowsPtr: rowsPtr, - } - return rows, nil + return initRows(rowsPtr), nil } -func (lr *rows) Columns() []string { - return lr.columns +func (ts *tursoStmt) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + stripped := namedValueToValue(args) + argArray, err := getArgsPtr(stripped) + if err != nil { + return nil, err + } + var changes uintptr + res := ts.execute(ts.ctx, argArray, uint64(len(args)), changes) + switch ResultCode(res) { + case Ok: + return driver.RowsAffected(changes), nil + case Error: + return nil, errors.New("error executing statement") + case Busy: + return nil, errors.New("busy") + case Interrupt: + return nil, errors.New("interrupted") + default: + return nil, fmt.Errorf("unexpected status: %d", res) + } } -func (lr *rows) Close() error { - var rowsClose func(uintptr) - getExtFunc(&rowsClose, "rows_close") - rowsClose(lr.rowsPtr) +func (st *tursoStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + queryArgs, err := buildNamedArgs(args) + if err != nil { + return nil, err + } + rowsPtr := st.query(st.ctx, uintptr(unsafe.Pointer(&queryArgs[0])), uint64(len(queryArgs))) + if rowsPtr == 0 { + return nil, fmt.Errorf("query failed for: %q", st.sql) + } + return initRows(rowsPtr), nil +} + +// only construct tursoRows with initRows function to ensure proper initialization +type tursoRows struct { + ctx uintptr + columns []string + closed bool + getCols func(uintptr, *uint) uintptr + next func(uintptr) uintptr + getValue func(uintptr, int32) uintptr + closeRows func(uintptr) uintptr + freeCols func(uintptr) uintptr +} + +// Initialize/register the FFI function pointers for the rows methods +// DO NOT construct 'tursoRows' without this function +func initRows(ctx uintptr) *tursoRows { + var getCols func(uintptr, *uint) uintptr + var getValue func(uintptr, int32) uintptr + var closeRows func(uintptr) uintptr + var freeCols func(uintptr) uintptr + var next func(uintptr) uintptr + methods := []ExtFunc{ + {getCols, FfiRowsGetColumns}, + {getValue, FfiRowsGetValue}, + {closeRows, FfiRowsClose}, + {freeCols, FfiFreeColumns}, + {next, FfiRowsNext}} + for i := range methods { + methods[i].initFunc() + } + + return &tursoRows{ + ctx: ctx, + getCols: getCols, + getValue: getValue, + closeRows: closeRows, + freeCols: freeCols, + next: next, + } +} + +func (r *tursoRows) Columns() []string { + if r.columns == nil { + var columnCount uint + colArrayPtr := r.getCols(r.ctx, &columnCount) + if colArrayPtr != 0 && columnCount > 0 { + r.columns = cArrayToGoStrings(colArrayPtr, columnCount) + if r.freeCols == nil { + getFfiFunc(&r.freeCols, FfiFreeColumns) + } + defer r.freeCols(colArrayPtr) + } + } + return r.columns +} + +func (r *tursoRows) Close() error { + if r.closed { + return nil + } + r.closed = true + r.closeRows(r.ctx) + r.ctx = 0 return nil } -func (lr *rows) Next(dest []driver.Value) error { - var rowsNext func(uintptr, uintptr) int32 - getExtFunc(&rowsNext, "rows_next") - - status := rowsNext(lr.ctx, lr.rowsPtr) +func (r *tursoRows) Next(dest []driver.Value) error { + status := r.next(r.ctx) switch ResultCode(status) { case Row: for i := range dest { - getExtFunc(&rowsGetValue, "rows_get_value") - - valPtr := rowsGetValue(lr.ctx, int32(i)) - if valPtr != 0 { - val := cStringToGoString(valPtr) - dest[i] = val - freeCString(valPtr) - } else { - dest[i] = nil - } + valPtr := r.getValue(r.ctx, int32(i)) + val := toGoValue(valPtr) + dest[i] = val } return nil - case 0: // No more rows + case Done: return io.EOF default: return fmt.Errorf("unexpected status: %d", status) diff --git a/bindings/go/turso.go b/bindings/go/turso.go index 0c095ac80..dcafb3a64 100644 --- a/bindings/go/turso.go +++ b/bindings/go/turso.go @@ -4,43 +4,61 @@ import ( "database/sql" "database/sql/driver" "errors" + "fmt" "log/slog" "os" + "runtime" "sync" "unsafe" "github.com/ebitengine/purego" + "golang.org/x/sys/windows" ) -const ( - turso = "../../target/debug/lib_turso_go.so" -) +const turso = "../../target/debug/lib_turso_go" +const driverName = "turso" -func toGoStr(ptr uintptr, length int) string { - if ptr == 0 { - return "" +var tursoLib uintptr + +func getSystemLibrary() error { + switch runtime.GOOS { + case "darwin": + slib, err := purego.Dlopen(fmt.Sprintf("%s.dylib", turso), purego.RTLD_LAZY) + if err != nil { + return err + } + tursoLib = slib + case "linux": + slib, err := purego.Dlopen(fmt.Sprintf("%s.so", turso), purego.RTLD_LAZY) + if err != nil { + return err + } + tursoLib = slib + case "windows": + slib, err := windows.LoadLibrary(fmt.Sprintf("%s.dll", turso)) + if err != nil { + return err + } + tursoLib = slib + default: + panic(fmt.Errorf("GOOS=%s is not supported", runtime.GOOS)) } - uptr := unsafe.Pointer(ptr) - s := (*string)(uptr) - if s == nil { - // redundant - return "" - } - return *s + return nil } func init() { - slib, err := purego.Dlopen(turso, purego.RTLD_LAZY) + err := getSystemLibrary() if err != nil { slog.Error("Error opening turso library: ", err) os.Exit(1) } - lib = slib - sql.Register("turso", &tursoDriver{}) + sql.Register(driverName, &tursoDriver{}) } -type tursoDriver struct { - tursoCtx +type tursoDriver struct{} + +func (d tursoDriver) Open(name string) (driver.Conn, error) { + return openConn(name) } func toCString(s string) uintptr { @@ -48,80 +66,76 @@ func toCString(s string) uintptr { return uintptr(unsafe.Pointer(&b[0])) } -func getExtFunc(ptr interface{}, name string) { - purego.RegisterLibFunc(ptr, lib, name) +// helper to register an FFI function in the lib_turso_go library +func getFfiFunc(ptr interface{}, name string) { + purego.RegisterLibFunc(&ptr, tursoLib, name) } -type conn struct { +type tursoConn struct { ctx uintptr sync.Mutex - writeTimeFmt string - lastInsertID int64 - lastAffected int64 + prepare func(uintptr, uintptr) uintptr } -func newConn() *conn { - return &conn{ - 0, +func newConn(ctx uintptr) *tursoConn { + var prepare func(uintptr, uintptr) uintptr + getFfiFunc(&prepare, FfiDbPrepare) + return &tursoConn{ + ctx, sync.Mutex{}, - "2006-01-02 15:04:05", - 0, - 0, + prepare, } } -func open(dsn string) (*conn, error) { - var open func(uintptr) uintptr - getExtFunc(&open, ExtDBOpen) - c := newConn() - path := toCString(dsn) - ctx := open(path) - c.ctx = ctx - return c, nil -} +func openConn(dsn string) (*tursoConn, error) { + var dbOpen func(uintptr) uintptr + getFfiFunc(&dbOpen, FfiDbOpen) -type tursoCtx struct { - conn *conn - tx *sql.Tx - err error - rows *sql.Rows - stmt *sql.Stmt -} + cStr := toCString(dsn) + defer freeCString(cStr) -func (lc tursoCtx) Open(dsn string) (driver.Conn, error) { - conn, err := open(dsn) - if err != nil { - return nil, err + ctx := dbOpen(cStr) + if ctx == 0 { + return nil, fmt.Errorf("failed to open database for dsn=%q", dsn) } - nc := tursoCtx{conn: conn} - return nc, nil + return &tursoConn{ctx: ctx}, nil } -func (lc tursoCtx) Close() error { - var closedb func(uintptr) uintptr - getExtFunc(&closedb, ExtDBClose) - closedb(lc.conn.ctx) +func (c *tursoConn) Close() error { + if c.ctx == 0 { + return nil + } + var dbClose func(uintptr) uintptr + getFfiFunc(&dbClose, FfiDbClose) + + dbClose(c.ctx) + c.ctx = 0 return nil } -// TODO: Begin not implemented -func (lc tursoCtx) Begin() (driver.Tx, error) { - return nil, nil +func (c *tursoConn) Prepare(query string) (driver.Stmt, error) { + if c.ctx == 0 { + return nil, errors.New("connection closed") + } + if c.prepare == nil { + var dbPrepare func(uintptr, uintptr) uintptr + getFfiFunc(&dbPrepare, FfiDbPrepare) + c.prepare = dbPrepare + } + qPtr := toCString(query) + stmtPtr := c.prepare(c.ctx, qPtr) + freeCString(qPtr) + + if stmtPtr == 0 { + return nil, fmt.Errorf("prepare failed: %q", query) + } + return &tursoStmt{ + ctx: stmtPtr, + sql: query, + }, nil } -func (ls tursoCtx) Prepare(sql string) (driver.Stmt, error) { - var prepare func(uintptr, uintptr) uintptr - getExtFunc(&prepare, ExtDBPrepare) - s := toCString(sql) - statement := prepare(ls.conn.ctx, s) - if statement == 0 { - return nil, errors.New("no rows") - } - ls.stmt = stmt{ - ctx: statement, - - } - - } - return nil, nil +// begin is needed to implement driver.Conn.. for now not implemented +func (c *tursoConn) Begin() (driver.Tx, error) { + return nil, errors.New("transactions not implemented") } diff --git a/bindings/go/types.go b/bindings/go/types.go index 0569b317a..e24b2f168 100644 --- a/bindings/go/types.go +++ b/bindings/go/types.go @@ -1,28 +1,248 @@ package turso +import ( + "database/sql/driver" + "fmt" + "unsafe" +) + type ResultCode int const ( - Error ResultCode = -1 - Ok ResultCode = 0 - Row ResultCode = 1 - Busy ResultCode = 2 - Done ResultCode = 3 - Io ResultCode = 4 - Interrupt ResultCode = 5 - Invalid ResultCode = 6 - Null ResultCode = 7 - NoMem ResultCode = 8 - ReadOnly ResultCode = 9 - ExtDBOpen string = "db_open" - ExtDBClose string = "db_close" - ExtDBPrepare string = "db_prepare" + Error ResultCode = -1 + Ok ResultCode = 0 + Row ResultCode = 1 + Busy ResultCode = 2 + Io ResultCode = 3 + Interrupt ResultCode = 4 + Invalid ResultCode = 5 + Null ResultCode = 6 + NoMem ResultCode = 7 + ReadOnly ResultCode = 8 + NoData ResultCode = 9 + Done ResultCode = 10 ) -var ( - lib uintptr - dbPrepare func(uintptr, uintptr) uintptr - rowsNext func(rowsPtr uintptr) int32 - rowsGetValue func(rowsPtr uintptr, colIdx uint) uintptr - freeCString func(strPtr uintptr) +const ( + FfiDbOpen string = "db_open" + FfiDbClose string = "db_close" + FfiDbPrepare string = "db_prepare" + FfiStmtExec string = "stmt_execute" + FfiStmtQuery string = "stmt_query" + FfiStmtParameterCount string = "stmt_parameter_count" + FfiRowsClose string = "rows_close" + FfiRowsGetColumns string = "rows_get_columns" + FfiRowsNext string = "rows_next" + FfiRowsGetValue string = "rows_get_value" + FfiFreeColumns string = "free_columns" + FfiFreeCString string = "free_string" ) + +// convert a namedValue slice into normal values until named parameters are supported +func namedValueToValue(named []driver.NamedValue) []driver.Value { + out := make([]driver.Value, len(named)) + for i, nv := range named { + out[i] = nv.Value + } + return out +} + +func buildNamedArgs(named []driver.NamedValue) ([]tursoValue, error) { + args := make([]driver.Value, len(named)) + for i, nv := range named { + args[i] = nv.Value + } + return buildArgs(args) +} + +type ExtFunc struct { + funcPtr interface{} + funcName string +} + +func (ef *ExtFunc) initFunc() { + getFfiFunc(&ef.funcPtr, ef.funcName) +} + +type valueType int + +const ( + intVal valueType = iota + textVal + blobVal + realVal + nullVal +) + +// struct to pass Go values over FFI +type tursoValue struct { + Type valueType + Value [8]byte +} + +// struct to pass byte slices over FFI +type Blob struct { + Data uintptr + Len uint +} + +// convert a tursoValue to a native Go value +func toGoValue(valPtr uintptr) interface{} { + val := (*tursoValue)(unsafe.Pointer(valPtr)) + switch val.Type { + case intVal: + return *(*int64)(unsafe.Pointer(&val.Value)) + case realVal: + return *(*float64)(unsafe.Pointer(&val.Value)) + case textVal: + textPtr := *(*uintptr)(unsafe.Pointer(&val.Value)) + return GoString(textPtr) + case blobVal: + blobPtr := *(*uintptr)(unsafe.Pointer(&val.Value)) + return toGoBlob(blobPtr) + case nullVal: + return nil + default: + return nil + } +} + +func getArgsPtr(args []driver.Value) (uintptr, error) { + if len(args) == 0 { + return 0, nil + } + argSlice, err := buildArgs(args) + if err != nil { + return 0, err + } + return uintptr(unsafe.Pointer(&argSlice[0])), nil +} + +// convert a byte slice to a Blob type that can be sent over FFI +func makeBlob(b []byte) *Blob { + if len(b) == 0 { + return nil + } + blob := &Blob{ + Data: uintptr(unsafe.Pointer(&b[0])), + Len: uint(len(b)), + } + return blob +} + +// converts a blob received via FFI to a native Go byte slice +func toGoBlob(blobPtr uintptr) []byte { + if blobPtr == 0 { + return nil + } + blob := (*Blob)(unsafe.Pointer(blobPtr)) + return unsafe.Slice((*byte)(unsafe.Pointer(blob.Data)), blob.Len) +} + +var freeString func(*byte) + +// free a C style string allocated via FFI +func freeCString(cstr uintptr) { + if cstr == 0 { + return + } + if freeString == nil { + getFfiFunc(&freeString, FfiFreeCString) + } + freeString((*byte)(unsafe.Pointer(cstr))) +} + +func cArrayToGoStrings(arrayPtr uintptr, length uint) []string { + if arrayPtr == 0 || length == 0 { + return nil + } + + ptrSlice := unsafe.Slice( + (**byte)(unsafe.Pointer(arrayPtr)), + length, + ) + + out := make([]string, 0, length) + for _, cstr := range ptrSlice { + out = append(out, GoString(uintptr(unsafe.Pointer(cstr)))) + } + return out +} + +// convert a Go slice of driver.Value to a slice of tursoValue that can be sent over FFI +func buildArgs(args []driver.Value) ([]tursoValue, error) { + argSlice := make([]tursoValue, len(args)) + + for i, v := range args { + switch val := v.(type) { + case nil: + argSlice[i].Type = nullVal + + case int64: + argSlice[i].Type = intVal + storeInt64(&argSlice[i].Value, val) + + case float64: + argSlice[i].Type = realVal + storeFloat64(&argSlice[i].Value, val) + case string: + argSlice[i].Type = textVal + cstr := CString(val) + storePointer(&argSlice[i].Value, cstr) + case []byte: + argSlice[i].Type = blobVal + blob := makeBlob(val) + *(*uintptr)(unsafe.Pointer(&argSlice[i].Value)) = uintptr(unsafe.Pointer(blob)) + default: + return nil, fmt.Errorf("unsupported type: %T", v) + } + } + return argSlice, nil +} + +func storeInt64(data *[8]byte, val int64) { + *(*int64)(unsafe.Pointer(data)) = val +} + +func storeFloat64(data *[8]byte, val float64) { + *(*float64)(unsafe.Pointer(data)) = val +} + +func storePointer(data *[8]byte, ptr *byte) { + *(*uintptr)(unsafe.Pointer(data)) = uintptr(unsafe.Pointer(ptr)) +} + +type stmtExecuteFn func(stmtPtr uintptr, argsPtr uintptr, argCount uint64, changes uintptr) int32 +type stmtQueryFn func(stmtPtr uintptr, argsPtr uintptr, argCount uint64) uintptr + +/* Credit below (Apache2 License) to: +https://github.com/ebitengine/purego/blob/main/internal/strings/strings.go +*/ + +func hasSuffix(s, suffix string) bool { + return len(s) >= len(suffix) && s[len(s)-len(suffix):] == suffix +} + +func CString(name string) *byte { + if hasSuffix(name, "\x00") { + return &(*(*[]byte)(unsafe.Pointer(&name)))[0] + } + b := make([]byte, len(name)+1) + copy(b, name) + return &b[0] +} + +func GoString(c uintptr) string { + ptr := *(*unsafe.Pointer)(unsafe.Pointer(&c)) + if ptr == nil { + return "" + } + var length int + for { + if *(*byte)(unsafe.Add(ptr, uintptr(length))) == '\x00' { + break + } + length++ + } + return string(unsafe.Slice((*byte)(ptr), length)) +} diff --git a/core/lib.rs b/core/lib.rs index d75510f7c..f03c73e00 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -427,6 +427,10 @@ impl Connection { let prev_total_changes = self.total_changes.get(); self.total_changes.set(prev_total_changes + nchange); } + + pub fn total_changes(&self) -> i64 { + self.total_changes.get() + } } pub struct Statement { @@ -473,6 +477,10 @@ impl Statement { &self.program.parameters } + pub fn parameters_count(&self) -> usize { + self.program.parameters.count() + } + pub fn bind_at(&mut self, index: NonZero, value: Value) { self.state.bind_at(index, value.into()); }