bindings/go: enable multiple connections, register all symbols at library load

This commit is contained in:
PThorpe92
2025-01-31 13:22:48 -05:00
parent 950f29daab
commit 8d93130809
9 changed files with 313 additions and 284 deletions

View File

@@ -1,6 +1,7 @@
package limbo
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
@@ -9,66 +10,67 @@ import (
"github.com/ebitengine/purego"
)
const (
driverName = "sqlite3"
libName = "lib_limbo_go"
)
func init() {
err := ensureLibLoaded()
if err != nil {
panic(err)
}
sql.Register(driverName, &limboDriver{})
}
type limboDriver struct {
sync.Mutex
}
var library = sync.OnceValue(func() uintptr {
lib, err := loadLibrary()
if err != nil {
panic(err)
}
return lib
})
var (
libOnce sync.Once
loadErr error
dbOpen func(string) uintptr
dbClose func(uintptr) uintptr
connPrepare func(uintptr, string) uintptr
freeBlobFunc func(uintptr)
freeStringFunc func(uintptr)
rowsGetColumns func(uintptr, *uint) uintptr
rowsGetValue func(uintptr, int32) uintptr
closeRows func(uintptr) uintptr
freeCols func(uintptr) uintptr
rowsNext func(uintptr) uintptr
stmtQuery func(stmtPtr uintptr, argsPtr uintptr, argCount uint64) uintptr
stmtExec func(stmtPtr uintptr, argsPtr uintptr, argCount uint64, changes uintptr) int32
stmtParamCount func(uintptr) int32
closeStmt func(uintptr) int32
libOnce sync.Once
limboLib uintptr
loadErr error
dbOpen func(string) uintptr
dbClose func(uintptr) uintptr
connPrepare func(uintptr, string) uintptr
freeBlobFunc func(uintptr)
freeStringFunc func(uintptr)
rowsGetColumns func(uintptr) int32
rowsGetColumnName func(uintptr, int32) uintptr
rowsGetValue func(uintptr, int32) uintptr
closeRows func(uintptr) uintptr
rowsNext func(uintptr) uintptr
stmtQuery func(stmtPtr uintptr, argsPtr uintptr, argCount uint64) uintptr
stmtExec func(stmtPtr uintptr, argsPtr uintptr, argCount uint64, changes uintptr) int32
stmtParamCount func(uintptr) int32
closeStmt func(uintptr) int32
)
// Register all the symbols on library load
func ensureLibLoaded() error {
libOnce.Do(func() {
purego.RegisterLibFunc(&dbOpen, library(), FfiDbOpen)
purego.RegisterLibFunc(&dbClose, library(), FfiDbClose)
purego.RegisterLibFunc(&connPrepare, library(), FfiDbPrepare)
purego.RegisterLibFunc(&freeBlobFunc, library(), FfiFreeBlob)
purego.RegisterLibFunc(&freeStringFunc, library(), FfiFreeCString)
purego.RegisterLibFunc(&rowsGetColumns, library(), FfiRowsGetColumns)
purego.RegisterLibFunc(&rowsGetValue, library(), FfiRowsGetValue)
purego.RegisterLibFunc(&closeRows, library(), FfiRowsClose)
purego.RegisterLibFunc(&freeCols, library(), FfiFreeColumns)
purego.RegisterLibFunc(&rowsNext, library(), FfiRowsNext)
purego.RegisterLibFunc(&stmtQuery, library(), FfiStmtQuery)
purego.RegisterLibFunc(&stmtExec, library(), FfiStmtExec)
purego.RegisterLibFunc(&stmtParamCount, library(), FfiStmtParameterCount)
purego.RegisterLibFunc(&closeStmt, library(), FfiStmtClose)
limboLib, loadErr = loadLibrary()
if loadErr != nil {
return
}
purego.RegisterLibFunc(&dbOpen, limboLib, FfiDbOpen)
purego.RegisterLibFunc(&dbClose, limboLib, FfiDbClose)
purego.RegisterLibFunc(&connPrepare, limboLib, FfiDbPrepare)
purego.RegisterLibFunc(&freeBlobFunc, limboLib, FfiFreeBlob)
purego.RegisterLibFunc(&freeStringFunc, limboLib, FfiFreeCString)
purego.RegisterLibFunc(&rowsGetColumns, limboLib, FfiRowsGetColumns)
purego.RegisterLibFunc(&rowsGetColumnName, limboLib, FfiRowsGetColumnName)
purego.RegisterLibFunc(&rowsGetValue, limboLib, FfiRowsGetValue)
purego.RegisterLibFunc(&closeRows, limboLib, FfiRowsClose)
purego.RegisterLibFunc(&rowsNext, limboLib, FfiRowsNext)
purego.RegisterLibFunc(&stmtQuery, limboLib, FfiStmtQuery)
purego.RegisterLibFunc(&stmtExec, limboLib, FfiStmtExec)
purego.RegisterLibFunc(&stmtParamCount, limboLib, FfiStmtParameterCount)
purego.RegisterLibFunc(&closeStmt, limboLib, FfiStmtClose)
})
return loadErr
}
func (d *limboDriver) Open(name string) (driver.Conn, error) {
d.Lock()
defer d.Unlock()
conn, err := openConn(name)
d.Unlock()
if err != nil {
return nil, err
}
@@ -80,26 +82,24 @@ type limboConn struct {
ctx uintptr
}
func newConn(ctx uintptr) *limboConn {
return &limboConn{
sync.Mutex{},
ctx,
}
}
func openConn(dsn string) (*limboConn, error) {
ctx := dbOpen(dsn)
if ctx == 0 {
return nil, fmt.Errorf("failed to open database for dsn=%q", dsn)
}
return newConn(ctx), loadErr
return &limboConn{
sync.Mutex{},
ctx,
}, loadErr
}
func (c *limboConn) Close() error {
if c.ctx == 0 {
return nil
}
c.Lock()
dbClose(c.ctx)
c.Unlock()
c.ctx = 0
return nil
}
@@ -114,7 +114,7 @@ func (c *limboConn) Prepare(query string) (driver.Stmt, error) {
if stmtPtr == 0 {
return nil, fmt.Errorf("failed to prepare query=%q", query)
}
return initStmt(stmtPtr, query), nil
return newStmt(stmtPtr, query), nil
}
// begin is needed to implement driver.Conn.. for now not implemented

View File

@@ -18,7 +18,7 @@ func TestMain(m *testing.M) {
panic(connErr)
}
defer conn.Close()
err := createTable()
err := createTable(conn)
if err != nil {
log.Fatalf("Error creating table: %v", err)
}
@@ -26,49 +26,13 @@ func TestMain(m *testing.M) {
}
func TestInsertData(t *testing.T) {
err := insertData()
err := insertData(conn)
if err != nil {
t.Fatalf("Error inserting data: %v", err)
}
}
func TestFunction(t *testing.T) {
insert := "INSERT INTO test (foo, bar, baz) VALUES (?, ?, zeroblob(?));"
stmt, err := conn.Prepare(insert)
if err != nil {
t.Fatalf("Error preparing statement: %v", err)
}
_, err = stmt.Exec(1, "hello", 100)
if err != nil {
t.Fatalf("Error executing statment with arguments: %v", err)
}
stmt.Close()
stmt, err = conn.Prepare("SELECT baz FROM test where foo = ?")
if err != nil {
t.Fatalf("Error preparing select stmt: %v", err)
}
defer stmt.Close()
rows, err := stmt.Query(1)
if err != nil {
t.Fatalf("Error executing select stmt: %v", err)
}
defer rows.Close()
for rows.Next() {
var b []byte
err = rows.Scan(&b)
if err != nil {
t.Fatalf("Error scanning row: %v", err)
}
fmt.Println("RESULTS: ", string(b))
}
}
func TestQuery(t *testing.T) {
err := insertData()
if err != nil {
t.Fatalf("Error inserting data: %v", err)
}
query := "SELECT * FROM test;"
stmt, err := conn.Prepare(query)
if err != nil {
@@ -117,6 +81,130 @@ func TestQuery(t *testing.T) {
}
func TestFunctions(t *testing.T) {
insert := "INSERT INTO test (foo, bar, baz) VALUES (?, ?, zeroblob(?));"
stmt, err := conn.Prepare(insert)
if err != nil {
t.Fatalf("Error preparing statement: %v", err)
}
_, err = stmt.Exec(60, "TestFunction", 400)
if err != nil {
t.Fatalf("Error executing statment with arguments: %v", err)
}
stmt.Close()
stmt, err = conn.Prepare("SELECT baz FROM test where foo = ?")
if err != nil {
t.Fatalf("Error preparing select stmt: %v", err)
}
defer stmt.Close()
rows, err := stmt.Query(60)
if err != nil {
t.Fatalf("Error executing select stmt: %v", err)
}
defer rows.Close()
for rows.Next() {
var b []byte
err = rows.Scan(&b)
if err != nil {
t.Fatalf("Error scanning row: %v", err)
}
if len(b) != 400 {
t.Fatalf("Expected 100 bytes, got %d", len(b))
}
}
sql := "SELECT uuid4_str();"
stmt, err = conn.Prepare(sql)
if err != nil {
t.Fatalf("Error preparing statement: %v", err)
}
defer stmt.Close()
rows, err = stmt.Query()
if err != nil {
t.Fatalf("Error executing query: %v", err)
}
defer rows.Close()
var i int
for rows.Next() {
var b string
err = rows.Scan(&b)
if err != nil {
t.Fatalf("Error scanning row: %v", err)
}
if len(b) != 36 {
t.Fatalf("Expected 36 bytes, got %d", len(b))
}
i++
fmt.Printf("uuid: %s\n", b)
}
if i != 1 {
t.Fatalf("Expected 1 row, got %d", i)
}
fmt.Println("zeroblob + uuid functions passed")
}
func TestDuplicateConnection(t *testing.T) {
newConn, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("Error opening new connection: %v", err)
}
err = createTable(newConn)
if err != nil {
t.Fatalf("Error creating table: %v", err)
}
err = insertData(newConn)
if err != nil {
t.Fatalf("Error inserting data: %v", err)
}
query := "SELECT * FROM test;"
rows, err := newConn.Query(query)
if err != nil {
t.Fatalf("Error executing query: %v", err)
}
defer rows.Close()
for rows.Next() {
var a int
var b string
var c []byte
err = rows.Scan(&a, &b, &c)
if err != nil {
t.Fatalf("Error scanning row: %v", err)
}
fmt.Println("RESULTS: ", a, b, string(c))
}
}
func TestDuplicateConnection2(t *testing.T) {
newConn, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("Error opening new connection: %v", err)
}
sql := "CREATE TABLE test (foo INTEGER, bar INTEGER, baz BLOB);"
newConn.Exec(sql)
sql = "INSERT INTO test (foo, bar, baz) VALUES (?, ?, uuid4());"
stmt, err := newConn.Prepare(sql)
stmt.Exec(242345, 2342434)
defer stmt.Close()
query := "SELECT * FROM test;"
rows, err := newConn.Query(query)
if err != nil {
t.Fatalf("Error executing query: %v", err)
}
defer rows.Close()
for rows.Next() {
var a int
var b int
var c []byte
err = rows.Scan(&a, &b, &c)
if err != nil {
t.Fatalf("Error scanning row: %v", err)
}
fmt.Println("RESULTS: ", a, b, string(c))
if len(c) != 16 {
t.Fatalf("Expected 16 bytes, got %d", len(c))
}
}
}
func slicesAreEq(a, b []byte) bool {
if len(a) != len(b) {
fmt.Printf("LENGTHS NOT EQUAL: %d != %d\n", len(a), len(b))
@@ -133,7 +221,7 @@ func slicesAreEq(a, b []byte) bool {
var rowsMap = map[int]string{1: "hello", 2: "world", 3: "foo", 4: "bar", 5: "baz"}
func createTable() error {
func createTable(conn *sql.DB) error {
insert := "CREATE TABLE test (foo INT, bar TEXT, baz BLOB);"
stmt, err := conn.Prepare(insert)
if err != nil {
@@ -144,7 +232,7 @@ func createTable() error {
return err
}
func insertData() error {
func insertData(conn *sql.DB) error {
for i := 1; i <= 5; i++ {
insert := "INSERT INTO test (foo, bar, baz) VALUES (?, ?, ?);"
stmt, err := conn.Prepare(insert)

View File

@@ -3,7 +3,6 @@
package limbo
import (
"database/sql"
"fmt"
"os"
"path/filepath"
@@ -44,11 +43,3 @@ func loadLibrary() (uintptr, error) {
}
return 0, fmt.Errorf("%s library not found in LD_LIBRARY_PATH or CWD", libName)
}
func init() {
err := ensureLibLoaded()
if err != nil {
panic(err)
}
sql.Register("sqlite3", &limboDriver{})
}

View File

@@ -3,7 +3,6 @@
package limbo
import (
"database/sql"
"fmt"
"os"
"path/filepath"
@@ -12,14 +11,14 @@ import (
"golang.org/x/sys/windows"
)
func loadLibrary() error {
func loadLibrary() (uintptr, error) {
libName := fmt.Sprintf("%s.dll", libName)
pathEnv := os.Getenv("PATH")
paths := strings.Split(pathEnv, ";")
cwd, err := os.Getwd()
if err != nil {
return err
return 0, err
}
paths = append(paths, cwd)
for _, path := range paths {
@@ -27,21 +26,11 @@ func loadLibrary() error {
if _, err := os.Stat(dllPath); err == nil {
slib, loadErr := windows.LoadLibrary(dllPath)
if loadErr != nil {
return fmt.Errorf("failed to load library at %s: %w", dllPath, loadErr)
return 0, fmt.Errorf("failed to load library at %s: %w", dllPath, loadErr)
}
limboLib = uintptr(slib)
return nil
return uintptr(slib), nil
}
}
return fmt.Errorf("library %s not found in PATH or CWD", libName)
}
func init() {
err := loadLibrary()
if err != nil {
fmt.Println("Error opening limbo library: ", err)
os.Exit(1)
}
sql.Register("sqlite3", &limboDriver{})
return 0, fmt.Errorf("library %s not found in PATH or CWD", libName)
}

View File

@@ -7,7 +7,6 @@ import (
"sync"
)
// only construct limboRows with initRows function to ensure proper initialization
type limboRows struct {
mu sync.Mutex
ctx uintptr
@@ -15,12 +14,12 @@ type limboRows struct {
closed bool
}
// Initialize/register the FFI function pointers for the rows methods
// DO NOT construct 'limboRows' without this function
func initRows(ctx uintptr) *limboRows {
func newRows(ctx uintptr) *limboRows {
return &limboRows{
mu: sync.Mutex{},
ctx: ctx,
mu: sync.Mutex{},
ctx: ctx,
closed: false,
columns: nil,
}
}
@@ -29,15 +28,17 @@ func (r *limboRows) Columns() []string {
return nil
}
if r.columns == nil {
var columnCount uint
r.mu.Lock()
defer r.mu.Unlock()
colArrayPtr := rowsGetColumns(r.ctx, &columnCount)
if colArrayPtr != 0 && columnCount > 0 {
r.columns = cArrayToGoStrings(colArrayPtr, columnCount)
if freeCols != nil {
defer freeCols(colArrayPtr)
count := rowsGetColumns(r.ctx)
if count > 0 {
columns := make([]string, 0, count)
for i := 0; i < int(count); i++ {
cstr := rowsGetColumnName(r.ctx, int32(i))
columns = append(columns, fmt.Sprintf("%s", GoString(cstr)))
freeCString(cstr)
}
r.mu.Unlock()
r.columns = columns
}
}
return r.columns
@@ -59,14 +60,14 @@ func (r *limboRows) Next(dest []driver.Value) error {
if r.ctx == 0 || r.closed {
return io.EOF
}
r.mu.Lock()
defer r.mu.Unlock()
for {
status := rowsNext(r.ctx)
switch ResultCode(status) {
case Row:
for i := range dest {
r.mu.Lock()
valPtr := rowsGetValue(r.ctx, int32(i))
r.mu.Unlock()
val := toGoValue(valPtr)
dest[i] = val
}

View File

@@ -78,30 +78,30 @@ pub extern "C" fn free_string(s: *mut c_char) {
}
}
/// Function to get the number of expected ResultColumns in the prepared statement.
/// to avoid the needless complexity of returning an array of strings, this instead
/// works like rows_next/rows_get_value
#[no_mangle]
pub extern "C" fn rows_get_columns(
rows_ptr: *mut c_void,
out_length: *mut usize,
) -> *mut *const c_char {
if rows_ptr.is_null() || out_length.is_null() {
pub extern "C" fn rows_get_columns(rows_ptr: *mut c_void) -> i32 {
if rows_ptr.is_null() {
return -1;
}
let rows = LimboRows::from_ptr(rows_ptr);
rows.stmt.columns().len() as i32
}
#[no_mangle]
pub extern "C" fn rows_get_column_name(rows_ptr: *mut c_void, idx: i32) -> *const c_char {
if rows_ptr.is_null() {
return std::ptr::null_mut();
}
let rows = LimboRows::from_ptr(rows_ptr);
let c_strings: Vec<std::ffi::CString> = rows
.stmt
.columns()
.iter()
.map(|name| std::ffi::CString::new(name.as_str()).unwrap())
.collect();
let c_ptrs: Vec<*const c_char> = c_strings.iter().map(|s| s.as_ptr()).collect();
unsafe {
*out_length = c_ptrs.len();
if idx < 0 || idx as usize >= rows.stmt.columns().len() {
return std::ptr::null_mut();
}
let ptr = c_ptrs.as_ptr();
std::mem::forget(c_strings);
std::mem::forget(c_ptrs);
ptr as *mut *const c_char
let name = &rows.stmt.columns()[idx as usize];
let cstr = std::ffi::CString::new(name.as_bytes()).expect("Failed to create CString");
cstr.into_raw() as *const c_char
}
#[no_mangle]
@@ -111,21 +111,6 @@ pub extern "C" fn rows_close(rows_ptr: *mut c_void) {
}
}
#[no_mangle]
pub extern "C" fn free_columns(columns: *mut *const c_char) {
if columns.is_null() {
return;
}
unsafe {
let mut idx = 0;
while !(*columns.add(idx)).is_null() {
let _ = std::ffi::CString::from_raw(*columns.add(idx) as *mut c_char);
idx += 1;
}
let _ = Box::from_raw(columns);
}
}
#[no_mangle]
pub extern "C" fn free_rows(rows: *mut c_void) {
if rows.is_null() {

View File

@@ -132,13 +132,9 @@ pub struct LimboStatement<'conn> {
#[no_mangle]
pub extern "C" fn stmt_close(ctx: *mut c_void) -> ResultCode {
if !ctx.is_null() {
let stmt = LimboStatement::from_ptr(ctx);
if stmt.statement.is_none() {
return ResultCode::Error;
} else {
let _ = unsafe { Box::from_raw(ctx as *mut LimboStatement) };
return ResultCode::Ok;
}
let stmt = unsafe { Box::from_raw(ctx as *mut LimboStatement) };
drop(stmt);
return ResultCode::Ok;
}
ResultCode::Invalid
}

View File

@@ -1,6 +1,7 @@
package limbo
import (
"context"
"database/sql/driver"
"errors"
"fmt"
@@ -8,22 +9,16 @@ import (
"unsafe"
)
// only construct limboStmt with initStmt function to ensure proper initialization
// inUse tracks whether or not `query` has been called. if inUse > 0, stmt no longer
// owns the underlying data and `rows` is responsible for cleaning it up on close.
type limboStmt struct {
mu sync.Mutex
ctx uintptr
sql string
inUse int
mu sync.Mutex
ctx uintptr
sql string
}
// Initialize/register the FFI function pointers for the statement methods
func initStmt(ctx uintptr, sql string) *limboStmt {
func newStmt(ctx uintptr, sql string) *limboStmt {
return &limboStmt{
ctx: uintptr(ctx),
sql: sql,
inUse: 0,
ctx: uintptr(ctx),
sql: sql,
}
}
@@ -34,25 +29,19 @@ func (ls *limboStmt) NumInput() int {
}
func (ls *limboStmt) Close() error {
if ls.inUse == 0 {
ls.mu.Lock()
res := closeStmt(ls.ctx)
ls.mu.Unlock()
if ResultCode(res) != Ok {
return fmt.Errorf("error closing statement: %s", ResultCode(res).String())
}
ls.mu.Lock()
res := closeStmt(ls.ctx)
ls.mu.Unlock()
if ResultCode(res) != Ok {
return fmt.Errorf("error closing statement: %s", ResultCode(res).String())
}
ls.ctx = 0
return nil
}
func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) {
ls.mu.Lock()
argArray, cleanup, err := buildArgs(args)
defer func() {
cleanup()
ls.mu.Unlock()
}()
defer cleanup()
if err != nil {
return nil, err
}
@@ -62,7 +51,9 @@ func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) {
argPtr = uintptr(unsafe.Pointer(&argArray[0]))
}
var changes uint64
ls.mu.Lock()
rc := stmtExec(ls.ctx, argPtr, argCount, uintptr(unsafe.Pointer(&changes)))
ls.mu.Unlock()
switch ResultCode(rc) {
case Ok, Done:
return driver.RowsAffected(changes), nil
@@ -80,9 +71,8 @@ func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) {
}
func (ls *limboStmt) Query(args []driver.Value) (driver.Rows, error) {
ls.mu.Lock()
queryArgs, cleanup, err := buildArgs(args)
defer func() { cleanup(); ls.mu.Unlock() }()
defer cleanup()
if err != nil {
return nil, err
}
@@ -90,64 +80,66 @@ func (ls *limboStmt) Query(args []driver.Value) (driver.Rows, error) {
if len(args) > 0 {
argPtr = uintptr(unsafe.Pointer(&queryArgs[0]))
}
ls.mu.Lock()
rowsPtr := stmtQuery(ls.ctx, argPtr, uint64(len(queryArgs)))
ls.mu.Unlock()
if rowsPtr == 0 {
return nil, fmt.Errorf("query failed for: %q", ls.sql)
}
ls.inUse++
return initRows(rowsPtr), nil
return newRows(rowsPtr), nil
}
// func (ls *limboStmt) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
// ls.mu.Lock()
// stripped := namedValueToValue(args)
// argArray, cleanup, err := getArgsPtr(stripped)
// defer func() { cleanup(); ls.mu.Unlock() }()
// if err != nil {
// return nil, err
// }
// select {
// case <-ctx.Done():
// return nil, ctx.Err()
// default:
// }
// var changes uint64
// res := stmtExec(ls.ctx, argArray, uint64(len(args)), uintptr(unsafe.Pointer(&changes)))
// switch ResultCode(res) {
// case Ok, Done:
// changes := uint64(changes)
// return driver.RowsAffected(changes), nil
// case Error:
// return nil, errors.New("error executing statement")
// case Busy:
// return nil, errors.New("busy")
// case Interrupt:
// return nil, errors.New("interrupted")
// default:
// return nil, fmt.Errorf("unexpected status: %d", res)
// }
// }
//
// func (ls *limboStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
// ls.mu.Lock()
// queryArgs, allocs, err := buildNamedArgs(args)
// defer func() { allocs(); ls.mu.Unlock() }()
// if err != nil {
// return nil, err
// }
// argsPtr := uintptr(0)
// if len(queryArgs) > 0 {
// argsPtr = uintptr(unsafe.Pointer(&queryArgs[0]))
// }
// select {
// case <-ctx.Done():
// return nil, ctx.Err()
// default:
// }
// rowsPtr := stmtQuery(ls.ctx, argsPtr, uint64(len(queryArgs)))
// if rowsPtr == 0 {
// return nil, fmt.Errorf("query failed for: %q", ls.sql)
// }
// ls.inUse++
// return initRows(rowsPtr), nil
// }
func (ls *limboStmt) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
stripped := namedValueToValue(args)
argArray, cleanup, err := getArgsPtr(stripped)
defer cleanup()
if err != nil {
return nil, err
}
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
var changes uint64
ls.mu.Lock()
res := stmtExec(ls.ctx, argArray, uint64(len(args)), uintptr(unsafe.Pointer(&changes)))
ls.mu.Unlock()
switch ResultCode(res) {
case Ok, Done:
changes := uint64(changes)
return driver.RowsAffected(changes), nil
case Error:
return nil, errors.New("error executing statement")
case Busy:
return nil, errors.New("busy")
case Interrupt:
return nil, errors.New("interrupted")
default:
return nil, fmt.Errorf("unexpected status: %d", res)
}
}
func (ls *limboStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
queryArgs, allocs, err := buildNamedArgs(args)
defer allocs()
if err != nil {
return nil, err
}
argsPtr := uintptr(0)
if len(queryArgs) > 0 {
argsPtr = uintptr(unsafe.Pointer(&queryArgs[0]))
}
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
ls.mu.Lock()
rowsPtr := stmtQuery(ls.ctx, argsPtr, uint64(len(queryArgs)))
ls.mu.Unlock()
if rowsPtr == 0 {
return nil, fmt.Errorf("query failed for: %q", ls.sql)
}
return newRows(rowsPtr), nil
}

View File

@@ -65,20 +65,23 @@ func (rc ResultCode) String() string {
}
const (
FfiDbOpen string = "db_open"
FfiDbClose string = "db_close"
FfiDbPrepare string = "db_prepare"
FfiStmtExec string = "stmt_execute"
FfiStmtQuery string = "stmt_query"
FfiStmtParameterCount string = "stmt_parameter_count"
FfiStmtClose string = "stmt_close"
FfiRowsClose string = "rows_close"
FfiRowsGetColumns string = "rows_get_columns"
FfiRowsNext string = "rows_next"
FfiRowsGetValue string = "rows_get_value"
FfiFreeColumns string = "free_columns"
FfiFreeCString string = "free_string"
FfiFreeBlob string = "free_blob"
driverName = "sqlite3"
libName = "lib_limbo_go"
FfiDbOpen = "db_open"
FfiDbClose = "db_close"
FfiDbPrepare = "db_prepare"
FfiStmtExec = "stmt_execute"
FfiStmtQuery = "stmt_query"
FfiStmtParameterCount = "stmt_parameter_count"
FfiStmtClose = "stmt_close"
FfiRowsClose = "rows_close"
FfiRowsGetColumns = "rows_get_columns"
FfiRowsGetColumnName = "rows_get_column_name"
FfiRowsNext = "rows_next"
FfiRowsGetValue = "rows_get_value"
FfiFreeColumns = "free_columns"
FfiFreeCString = "free_string"
FfiFreeBlob = "free_blob"
)
// convert a namedValue slice into normal values until named parameters are supported
@@ -212,22 +215,6 @@ func freeCString(cstrPtr uintptr) {
freeStringFunc(cstrPtr)
}
func cArrayToGoStrings(arrayPtr uintptr, length uint) []string {
if arrayPtr == 0 || length == 0 {
return nil
}
ptrSlice := unsafe.Slice(
(**byte)(unsafe.Pointer(arrayPtr)),
length,
)
out := make([]string, 0, length)
for _, cstr := range ptrSlice {
out = append(out, GoString(uintptr(unsafe.Pointer(cstr))))
}
return out
}
// convert a Go slice of driver.Value to a slice of limboValue that can be sent over FFI
// for Blob types, we have to pin them so they are not garbage collected before they can be copied
// into a buffer on the Rust side, so we return a function to unpin them that can be deferred after this call