From 8d93130809b5e3553a03c286f64a9486bcc7e2ba Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Fri, 31 Jan 2025 13:22:48 -0500 Subject: [PATCH] bindings/go: enable multiple connections, register all symbols at library load --- bindings/go/connection.go | 104 ++++++++++---------- bindings/go/limbo_test.go | 168 ++++++++++++++++++++++++-------- bindings/go/limbo_unix.go | 9 -- bindings/go/limbo_windows.go | 21 +--- bindings/go/rows.go | 31 +++--- bindings/go/rs_src/rows.rs | 53 ++++------ bindings/go/rs_src/statement.rs | 10 +- bindings/go/stmt.go | 154 ++++++++++++++--------------- bindings/go/types.go | 47 ++++----- 9 files changed, 313 insertions(+), 284 deletions(-) diff --git a/bindings/go/connection.go b/bindings/go/connection.go index a1ab83fae..672f2e4e3 100644 --- a/bindings/go/connection.go +++ b/bindings/go/connection.go @@ -1,6 +1,7 @@ package limbo import ( + "database/sql" "database/sql/driver" "errors" "fmt" @@ -9,66 +10,67 @@ import ( "github.com/ebitengine/purego" ) -const ( - driverName = "sqlite3" - libName = "lib_limbo_go" -) +func init() { + err := ensureLibLoaded() + if err != nil { + panic(err) + } + sql.Register(driverName, &limboDriver{}) +} type limboDriver struct { sync.Mutex } -var library = sync.OnceValue(func() uintptr { - lib, err := loadLibrary() - if err != nil { - panic(err) - } - return lib -}) - var ( - libOnce sync.Once - loadErr error - dbOpen func(string) uintptr - dbClose func(uintptr) uintptr - connPrepare func(uintptr, string) uintptr - freeBlobFunc func(uintptr) - freeStringFunc func(uintptr) - rowsGetColumns func(uintptr, *uint) uintptr - rowsGetValue func(uintptr, int32) uintptr - closeRows func(uintptr) uintptr - freeCols func(uintptr) uintptr - rowsNext func(uintptr) uintptr - stmtQuery func(stmtPtr uintptr, argsPtr uintptr, argCount uint64) uintptr - stmtExec func(stmtPtr uintptr, argsPtr uintptr, argCount uint64, changes uintptr) int32 - stmtParamCount func(uintptr) int32 - closeStmt func(uintptr) int32 + libOnce sync.Once + limboLib uintptr + loadErr error + dbOpen func(string) uintptr + dbClose func(uintptr) uintptr + connPrepare func(uintptr, string) uintptr + freeBlobFunc func(uintptr) + freeStringFunc func(uintptr) + rowsGetColumns func(uintptr) int32 + rowsGetColumnName func(uintptr, int32) uintptr + rowsGetValue func(uintptr, int32) uintptr + closeRows func(uintptr) uintptr + rowsNext func(uintptr) uintptr + stmtQuery func(stmtPtr uintptr, argsPtr uintptr, argCount uint64) uintptr + stmtExec func(stmtPtr uintptr, argsPtr uintptr, argCount uint64, changes uintptr) int32 + stmtParamCount func(uintptr) int32 + closeStmt func(uintptr) int32 ) +// Register all the symbols on library load func ensureLibLoaded() error { libOnce.Do(func() { - purego.RegisterLibFunc(&dbOpen, library(), FfiDbOpen) - purego.RegisterLibFunc(&dbClose, library(), FfiDbClose) - purego.RegisterLibFunc(&connPrepare, library(), FfiDbPrepare) - purego.RegisterLibFunc(&freeBlobFunc, library(), FfiFreeBlob) - purego.RegisterLibFunc(&freeStringFunc, library(), FfiFreeCString) - purego.RegisterLibFunc(&rowsGetColumns, library(), FfiRowsGetColumns) - purego.RegisterLibFunc(&rowsGetValue, library(), FfiRowsGetValue) - purego.RegisterLibFunc(&closeRows, library(), FfiRowsClose) - purego.RegisterLibFunc(&freeCols, library(), FfiFreeColumns) - purego.RegisterLibFunc(&rowsNext, library(), FfiRowsNext) - purego.RegisterLibFunc(&stmtQuery, library(), FfiStmtQuery) - purego.RegisterLibFunc(&stmtExec, library(), FfiStmtExec) - purego.RegisterLibFunc(&stmtParamCount, library(), FfiStmtParameterCount) - purego.RegisterLibFunc(&closeStmt, library(), FfiStmtClose) + limboLib, loadErr = loadLibrary() + if loadErr != nil { + return + } + purego.RegisterLibFunc(&dbOpen, limboLib, FfiDbOpen) + purego.RegisterLibFunc(&dbClose, limboLib, FfiDbClose) + purego.RegisterLibFunc(&connPrepare, limboLib, FfiDbPrepare) + purego.RegisterLibFunc(&freeBlobFunc, limboLib, FfiFreeBlob) + purego.RegisterLibFunc(&freeStringFunc, limboLib, FfiFreeCString) + purego.RegisterLibFunc(&rowsGetColumns, limboLib, FfiRowsGetColumns) + purego.RegisterLibFunc(&rowsGetColumnName, limboLib, FfiRowsGetColumnName) + purego.RegisterLibFunc(&rowsGetValue, limboLib, FfiRowsGetValue) + purego.RegisterLibFunc(&closeRows, limboLib, FfiRowsClose) + purego.RegisterLibFunc(&rowsNext, limboLib, FfiRowsNext) + purego.RegisterLibFunc(&stmtQuery, limboLib, FfiStmtQuery) + purego.RegisterLibFunc(&stmtExec, limboLib, FfiStmtExec) + purego.RegisterLibFunc(&stmtParamCount, limboLib, FfiStmtParameterCount) + purego.RegisterLibFunc(&closeStmt, limboLib, FfiStmtClose) }) return loadErr } func (d *limboDriver) Open(name string) (driver.Conn, error) { d.Lock() - defer d.Unlock() conn, err := openConn(name) + d.Unlock() if err != nil { return nil, err } @@ -80,26 +82,24 @@ type limboConn struct { ctx uintptr } -func newConn(ctx uintptr) *limboConn { - return &limboConn{ - sync.Mutex{}, - ctx, - } -} - func openConn(dsn string) (*limboConn, error) { ctx := dbOpen(dsn) if ctx == 0 { return nil, fmt.Errorf("failed to open database for dsn=%q", dsn) } - return newConn(ctx), loadErr + return &limboConn{ + sync.Mutex{}, + ctx, + }, loadErr } func (c *limboConn) Close() error { if c.ctx == 0 { return nil } + c.Lock() dbClose(c.ctx) + c.Unlock() c.ctx = 0 return nil } @@ -114,7 +114,7 @@ func (c *limboConn) Prepare(query string) (driver.Stmt, error) { if stmtPtr == 0 { return nil, fmt.Errorf("failed to prepare query=%q", query) } - return initStmt(stmtPtr, query), nil + return newStmt(stmtPtr, query), nil } // begin is needed to implement driver.Conn.. for now not implemented diff --git a/bindings/go/limbo_test.go b/bindings/go/limbo_test.go index da46d9964..a688cb34e 100644 --- a/bindings/go/limbo_test.go +++ b/bindings/go/limbo_test.go @@ -18,7 +18,7 @@ func TestMain(m *testing.M) { panic(connErr) } defer conn.Close() - err := createTable() + err := createTable(conn) if err != nil { log.Fatalf("Error creating table: %v", err) } @@ -26,49 +26,13 @@ func TestMain(m *testing.M) { } func TestInsertData(t *testing.T) { - err := insertData() + err := insertData(conn) if err != nil { t.Fatalf("Error inserting data: %v", err) } } -func TestFunction(t *testing.T) { - insert := "INSERT INTO test (foo, bar, baz) VALUES (?, ?, zeroblob(?));" - stmt, err := conn.Prepare(insert) - if err != nil { - t.Fatalf("Error preparing statement: %v", err) - } - _, err = stmt.Exec(1, "hello", 100) - if err != nil { - t.Fatalf("Error executing statment with arguments: %v", err) - } - stmt.Close() - stmt, err = conn.Prepare("SELECT baz FROM test where foo = ?") - if err != nil { - t.Fatalf("Error preparing select stmt: %v", err) - } - defer stmt.Close() - rows, err := stmt.Query(1) - if err != nil { - t.Fatalf("Error executing select stmt: %v", err) - } - defer rows.Close() - for rows.Next() { - var b []byte - err = rows.Scan(&b) - if err != nil { - t.Fatalf("Error scanning row: %v", err) - } - fmt.Println("RESULTS: ", string(b)) - } -} - func TestQuery(t *testing.T) { - err := insertData() - if err != nil { - t.Fatalf("Error inserting data: %v", err) - } - query := "SELECT * FROM test;" stmt, err := conn.Prepare(query) if err != nil { @@ -117,6 +81,130 @@ func TestQuery(t *testing.T) { } +func TestFunctions(t *testing.T) { + insert := "INSERT INTO test (foo, bar, baz) VALUES (?, ?, zeroblob(?));" + stmt, err := conn.Prepare(insert) + if err != nil { + t.Fatalf("Error preparing statement: %v", err) + } + _, err = stmt.Exec(60, "TestFunction", 400) + if err != nil { + t.Fatalf("Error executing statment with arguments: %v", err) + } + stmt.Close() + stmt, err = conn.Prepare("SELECT baz FROM test where foo = ?") + if err != nil { + t.Fatalf("Error preparing select stmt: %v", err) + } + defer stmt.Close() + rows, err := stmt.Query(60) + if err != nil { + t.Fatalf("Error executing select stmt: %v", err) + } + defer rows.Close() + for rows.Next() { + var b []byte + err = rows.Scan(&b) + if err != nil { + t.Fatalf("Error scanning row: %v", err) + } + if len(b) != 400 { + t.Fatalf("Expected 100 bytes, got %d", len(b)) + } + } + sql := "SELECT uuid4_str();" + stmt, err = conn.Prepare(sql) + if err != nil { + t.Fatalf("Error preparing statement: %v", err) + } + defer stmt.Close() + rows, err = stmt.Query() + if err != nil { + t.Fatalf("Error executing query: %v", err) + } + defer rows.Close() + var i int + for rows.Next() { + var b string + err = rows.Scan(&b) + if err != nil { + t.Fatalf("Error scanning row: %v", err) + } + if len(b) != 36 { + t.Fatalf("Expected 36 bytes, got %d", len(b)) + } + i++ + fmt.Printf("uuid: %s\n", b) + } + if i != 1 { + t.Fatalf("Expected 1 row, got %d", i) + } + fmt.Println("zeroblob + uuid functions passed") +} + +func TestDuplicateConnection(t *testing.T) { + newConn, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Error opening new connection: %v", err) + } + err = createTable(newConn) + if err != nil { + t.Fatalf("Error creating table: %v", err) + } + err = insertData(newConn) + if err != nil { + t.Fatalf("Error inserting data: %v", err) + } + query := "SELECT * FROM test;" + rows, err := newConn.Query(query) + if err != nil { + t.Fatalf("Error executing query: %v", err) + } + defer rows.Close() + for rows.Next() { + var a int + var b string + var c []byte + err = rows.Scan(&a, &b, &c) + if err != nil { + t.Fatalf("Error scanning row: %v", err) + } + fmt.Println("RESULTS: ", a, b, string(c)) + } +} + +func TestDuplicateConnection2(t *testing.T) { + newConn, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Error opening new connection: %v", err) + } + sql := "CREATE TABLE test (foo INTEGER, bar INTEGER, baz BLOB);" + newConn.Exec(sql) + sql = "INSERT INTO test (foo, bar, baz) VALUES (?, ?, uuid4());" + stmt, err := newConn.Prepare(sql) + stmt.Exec(242345, 2342434) + defer stmt.Close() + query := "SELECT * FROM test;" + rows, err := newConn.Query(query) + if err != nil { + t.Fatalf("Error executing query: %v", err) + } + defer rows.Close() + for rows.Next() { + var a int + var b int + var c []byte + err = rows.Scan(&a, &b, &c) + if err != nil { + t.Fatalf("Error scanning row: %v", err) + } + fmt.Println("RESULTS: ", a, b, string(c)) + if len(c) != 16 { + t.Fatalf("Expected 16 bytes, got %d", len(c)) + } + } +} + func slicesAreEq(a, b []byte) bool { if len(a) != len(b) { fmt.Printf("LENGTHS NOT EQUAL: %d != %d\n", len(a), len(b)) @@ -133,7 +221,7 @@ func slicesAreEq(a, b []byte) bool { var rowsMap = map[int]string{1: "hello", 2: "world", 3: "foo", 4: "bar", 5: "baz"} -func createTable() error { +func createTable(conn *sql.DB) error { insert := "CREATE TABLE test (foo INT, bar TEXT, baz BLOB);" stmt, err := conn.Prepare(insert) if err != nil { @@ -144,7 +232,7 @@ func createTable() error { return err } -func insertData() error { +func insertData(conn *sql.DB) error { for i := 1; i <= 5; i++ { insert := "INSERT INTO test (foo, bar, baz) VALUES (?, ?, ?);" stmt, err := conn.Prepare(insert) diff --git a/bindings/go/limbo_unix.go b/bindings/go/limbo_unix.go index cb6976594..8ab911416 100644 --- a/bindings/go/limbo_unix.go +++ b/bindings/go/limbo_unix.go @@ -3,7 +3,6 @@ package limbo import ( - "database/sql" "fmt" "os" "path/filepath" @@ -44,11 +43,3 @@ func loadLibrary() (uintptr, error) { } return 0, fmt.Errorf("%s library not found in LD_LIBRARY_PATH or CWD", libName) } - -func init() { - err := ensureLibLoaded() - if err != nil { - panic(err) - } - sql.Register("sqlite3", &limboDriver{}) -} diff --git a/bindings/go/limbo_windows.go b/bindings/go/limbo_windows.go index 433ddd051..d56381176 100644 --- a/bindings/go/limbo_windows.go +++ b/bindings/go/limbo_windows.go @@ -3,7 +3,6 @@ package limbo import ( - "database/sql" "fmt" "os" "path/filepath" @@ -12,14 +11,14 @@ import ( "golang.org/x/sys/windows" ) -func loadLibrary() error { +func loadLibrary() (uintptr, error) { libName := fmt.Sprintf("%s.dll", libName) pathEnv := os.Getenv("PATH") paths := strings.Split(pathEnv, ";") cwd, err := os.Getwd() if err != nil { - return err + return 0, err } paths = append(paths, cwd) for _, path := range paths { @@ -27,21 +26,11 @@ func loadLibrary() error { if _, err := os.Stat(dllPath); err == nil { slib, loadErr := windows.LoadLibrary(dllPath) if loadErr != nil { - return fmt.Errorf("failed to load library at %s: %w", dllPath, loadErr) + return 0, fmt.Errorf("failed to load library at %s: %w", dllPath, loadErr) } - limboLib = uintptr(slib) - return nil + return uintptr(slib), nil } } - return fmt.Errorf("library %s not found in PATH or CWD", libName) -} - -func init() { - err := loadLibrary() - if err != nil { - fmt.Println("Error opening limbo library: ", err) - os.Exit(1) - } - sql.Register("sqlite3", &limboDriver{}) + return 0, fmt.Errorf("library %s not found in PATH or CWD", libName) } diff --git a/bindings/go/rows.go b/bindings/go/rows.go index 276f0fd1a..2f20bb93f 100644 --- a/bindings/go/rows.go +++ b/bindings/go/rows.go @@ -7,7 +7,6 @@ import ( "sync" ) -// only construct limboRows with initRows function to ensure proper initialization type limboRows struct { mu sync.Mutex ctx uintptr @@ -15,12 +14,12 @@ type limboRows struct { closed bool } -// Initialize/register the FFI function pointers for the rows methods -// DO NOT construct 'limboRows' without this function -func initRows(ctx uintptr) *limboRows { +func newRows(ctx uintptr) *limboRows { return &limboRows{ - mu: sync.Mutex{}, - ctx: ctx, + mu: sync.Mutex{}, + ctx: ctx, + closed: false, + columns: nil, } } @@ -29,15 +28,17 @@ func (r *limboRows) Columns() []string { return nil } if r.columns == nil { - var columnCount uint r.mu.Lock() - defer r.mu.Unlock() - colArrayPtr := rowsGetColumns(r.ctx, &columnCount) - if colArrayPtr != 0 && columnCount > 0 { - r.columns = cArrayToGoStrings(colArrayPtr, columnCount) - if freeCols != nil { - defer freeCols(colArrayPtr) + count := rowsGetColumns(r.ctx) + if count > 0 { + columns := make([]string, 0, count) + for i := 0; i < int(count); i++ { + cstr := rowsGetColumnName(r.ctx, int32(i)) + columns = append(columns, fmt.Sprintf("%s", GoString(cstr))) + freeCString(cstr) } + r.mu.Unlock() + r.columns = columns } } return r.columns @@ -59,14 +60,14 @@ func (r *limboRows) Next(dest []driver.Value) error { if r.ctx == 0 || r.closed { return io.EOF } + r.mu.Lock() + defer r.mu.Unlock() for { status := rowsNext(r.ctx) switch ResultCode(status) { case Row: for i := range dest { - r.mu.Lock() valPtr := rowsGetValue(r.ctx, int32(i)) - r.mu.Unlock() val := toGoValue(valPtr) dest[i] = val } diff --git a/bindings/go/rs_src/rows.rs b/bindings/go/rs_src/rows.rs index 19b526c74..c8ec7642a 100644 --- a/bindings/go/rs_src/rows.rs +++ b/bindings/go/rs_src/rows.rs @@ -78,30 +78,30 @@ pub extern "C" fn free_string(s: *mut c_char) { } } +/// Function to get the number of expected ResultColumns in the prepared statement. +/// to avoid the needless complexity of returning an array of strings, this instead +/// works like rows_next/rows_get_value #[no_mangle] -pub extern "C" fn rows_get_columns( - rows_ptr: *mut c_void, - out_length: *mut usize, -) -> *mut *const c_char { - if rows_ptr.is_null() || out_length.is_null() { +pub extern "C" fn rows_get_columns(rows_ptr: *mut c_void) -> i32 { + if rows_ptr.is_null() { + return -1; + } + let rows = LimboRows::from_ptr(rows_ptr); + rows.stmt.columns().len() as i32 +} + +#[no_mangle] +pub extern "C" fn rows_get_column_name(rows_ptr: *mut c_void, idx: i32) -> *const c_char { + if rows_ptr.is_null() { return std::ptr::null_mut(); } let rows = LimboRows::from_ptr(rows_ptr); - let c_strings: Vec = rows - .stmt - .columns() - .iter() - .map(|name| std::ffi::CString::new(name.as_str()).unwrap()) - .collect(); - - let c_ptrs: Vec<*const c_char> = c_strings.iter().map(|s| s.as_ptr()).collect(); - unsafe { - *out_length = c_ptrs.len(); + if idx < 0 || idx as usize >= rows.stmt.columns().len() { + return std::ptr::null_mut(); } - let ptr = c_ptrs.as_ptr(); - std::mem::forget(c_strings); - std::mem::forget(c_ptrs); - ptr as *mut *const c_char + let name = &rows.stmt.columns()[idx as usize]; + let cstr = std::ffi::CString::new(name.as_bytes()).expect("Failed to create CString"); + cstr.into_raw() as *const c_char } #[no_mangle] @@ -111,21 +111,6 @@ pub extern "C" fn rows_close(rows_ptr: *mut c_void) { } } -#[no_mangle] -pub extern "C" fn free_columns(columns: *mut *const c_char) { - if columns.is_null() { - return; - } - unsafe { - let mut idx = 0; - while !(*columns.add(idx)).is_null() { - let _ = std::ffi::CString::from_raw(*columns.add(idx) as *mut c_char); - idx += 1; - } - let _ = Box::from_raw(columns); - } -} - #[no_mangle] pub extern "C" fn free_rows(rows: *mut c_void) { if rows.is_null() { diff --git a/bindings/go/rs_src/statement.rs b/bindings/go/rs_src/statement.rs index 6b0dfc80e..8ad015ded 100644 --- a/bindings/go/rs_src/statement.rs +++ b/bindings/go/rs_src/statement.rs @@ -132,13 +132,9 @@ pub struct LimboStatement<'conn> { #[no_mangle] pub extern "C" fn stmt_close(ctx: *mut c_void) -> ResultCode { if !ctx.is_null() { - let stmt = LimboStatement::from_ptr(ctx); - if stmt.statement.is_none() { - return ResultCode::Error; - } else { - let _ = unsafe { Box::from_raw(ctx as *mut LimboStatement) }; - return ResultCode::Ok; - } + let stmt = unsafe { Box::from_raw(ctx as *mut LimboStatement) }; + drop(stmt); + return ResultCode::Ok; } ResultCode::Invalid } diff --git a/bindings/go/stmt.go b/bindings/go/stmt.go index 97e92149d..0f3c038fd 100644 --- a/bindings/go/stmt.go +++ b/bindings/go/stmt.go @@ -1,6 +1,7 @@ package limbo import ( + "context" "database/sql/driver" "errors" "fmt" @@ -8,22 +9,16 @@ import ( "unsafe" ) -// only construct limboStmt with initStmt function to ensure proper initialization -// inUse tracks whether or not `query` has been called. if inUse > 0, stmt no longer -// owns the underlying data and `rows` is responsible for cleaning it up on close. type limboStmt struct { - mu sync.Mutex - ctx uintptr - sql string - inUse int + mu sync.Mutex + ctx uintptr + sql string } -// Initialize/register the FFI function pointers for the statement methods -func initStmt(ctx uintptr, sql string) *limboStmt { +func newStmt(ctx uintptr, sql string) *limboStmt { return &limboStmt{ - ctx: uintptr(ctx), - sql: sql, - inUse: 0, + ctx: uintptr(ctx), + sql: sql, } } @@ -34,25 +29,19 @@ func (ls *limboStmt) NumInput() int { } func (ls *limboStmt) Close() error { - if ls.inUse == 0 { - ls.mu.Lock() - res := closeStmt(ls.ctx) - ls.mu.Unlock() - if ResultCode(res) != Ok { - return fmt.Errorf("error closing statement: %s", ResultCode(res).String()) - } + ls.mu.Lock() + res := closeStmt(ls.ctx) + ls.mu.Unlock() + if ResultCode(res) != Ok { + return fmt.Errorf("error closing statement: %s", ResultCode(res).String()) } ls.ctx = 0 return nil } func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) { - ls.mu.Lock() argArray, cleanup, err := buildArgs(args) - defer func() { - cleanup() - ls.mu.Unlock() - }() + defer cleanup() if err != nil { return nil, err } @@ -62,7 +51,9 @@ func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) { argPtr = uintptr(unsafe.Pointer(&argArray[0])) } var changes uint64 + ls.mu.Lock() rc := stmtExec(ls.ctx, argPtr, argCount, uintptr(unsafe.Pointer(&changes))) + ls.mu.Unlock() switch ResultCode(rc) { case Ok, Done: return driver.RowsAffected(changes), nil @@ -80,9 +71,8 @@ func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) { } func (ls *limboStmt) Query(args []driver.Value) (driver.Rows, error) { - ls.mu.Lock() queryArgs, cleanup, err := buildArgs(args) - defer func() { cleanup(); ls.mu.Unlock() }() + defer cleanup() if err != nil { return nil, err } @@ -90,64 +80,66 @@ func (ls *limboStmt) Query(args []driver.Value) (driver.Rows, error) { if len(args) > 0 { argPtr = uintptr(unsafe.Pointer(&queryArgs[0])) } + ls.mu.Lock() rowsPtr := stmtQuery(ls.ctx, argPtr, uint64(len(queryArgs))) + ls.mu.Unlock() if rowsPtr == 0 { return nil, fmt.Errorf("query failed for: %q", ls.sql) } - ls.inUse++ - return initRows(rowsPtr), nil + return newRows(rowsPtr), nil } -// func (ls *limboStmt) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { -// ls.mu.Lock() -// stripped := namedValueToValue(args) -// argArray, cleanup, err := getArgsPtr(stripped) -// defer func() { cleanup(); ls.mu.Unlock() }() -// if err != nil { -// return nil, err -// } -// select { -// case <-ctx.Done(): -// return nil, ctx.Err() -// default: -// } -// var changes uint64 -// res := stmtExec(ls.ctx, argArray, uint64(len(args)), uintptr(unsafe.Pointer(&changes))) -// switch ResultCode(res) { -// case Ok, Done: -// changes := uint64(changes) -// return driver.RowsAffected(changes), nil -// case Error: -// return nil, errors.New("error executing statement") -// case Busy: -// return nil, errors.New("busy") -// case Interrupt: -// return nil, errors.New("interrupted") -// default: -// return nil, fmt.Errorf("unexpected status: %d", res) -// } -// } -// -// func (ls *limboStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { -// ls.mu.Lock() -// queryArgs, allocs, err := buildNamedArgs(args) -// defer func() { allocs(); ls.mu.Unlock() }() -// if err != nil { -// return nil, err -// } -// argsPtr := uintptr(0) -// if len(queryArgs) > 0 { -// argsPtr = uintptr(unsafe.Pointer(&queryArgs[0])) -// } -// select { -// case <-ctx.Done(): -// return nil, ctx.Err() -// default: -// } -// rowsPtr := stmtQuery(ls.ctx, argsPtr, uint64(len(queryArgs))) -// if rowsPtr == 0 { -// return nil, fmt.Errorf("query failed for: %q", ls.sql) -// } -// ls.inUse++ -// return initRows(rowsPtr), nil -// } +func (ls *limboStmt) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + stripped := namedValueToValue(args) + argArray, cleanup, err := getArgsPtr(stripped) + defer cleanup() + if err != nil { + return nil, err + } + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + var changes uint64 + ls.mu.Lock() + res := stmtExec(ls.ctx, argArray, uint64(len(args)), uintptr(unsafe.Pointer(&changes))) + ls.mu.Unlock() + switch ResultCode(res) { + case Ok, Done: + changes := uint64(changes) + return driver.RowsAffected(changes), nil + case Error: + return nil, errors.New("error executing statement") + case Busy: + return nil, errors.New("busy") + case Interrupt: + return nil, errors.New("interrupted") + default: + return nil, fmt.Errorf("unexpected status: %d", res) + } +} + +func (ls *limboStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + queryArgs, allocs, err := buildNamedArgs(args) + defer allocs() + if err != nil { + return nil, err + } + argsPtr := uintptr(0) + if len(queryArgs) > 0 { + argsPtr = uintptr(unsafe.Pointer(&queryArgs[0])) + } + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + ls.mu.Lock() + rowsPtr := stmtQuery(ls.ctx, argsPtr, uint64(len(queryArgs))) + ls.mu.Unlock() + if rowsPtr == 0 { + return nil, fmt.Errorf("query failed for: %q", ls.sql) + } + return newRows(rowsPtr), nil +} diff --git a/bindings/go/types.go b/bindings/go/types.go index d54ff27fd..bcb023108 100644 --- a/bindings/go/types.go +++ b/bindings/go/types.go @@ -65,20 +65,23 @@ func (rc ResultCode) String() string { } const ( - FfiDbOpen string = "db_open" - FfiDbClose string = "db_close" - FfiDbPrepare string = "db_prepare" - FfiStmtExec string = "stmt_execute" - FfiStmtQuery string = "stmt_query" - FfiStmtParameterCount string = "stmt_parameter_count" - FfiStmtClose string = "stmt_close" - FfiRowsClose string = "rows_close" - FfiRowsGetColumns string = "rows_get_columns" - FfiRowsNext string = "rows_next" - FfiRowsGetValue string = "rows_get_value" - FfiFreeColumns string = "free_columns" - FfiFreeCString string = "free_string" - FfiFreeBlob string = "free_blob" + driverName = "sqlite3" + libName = "lib_limbo_go" + FfiDbOpen = "db_open" + FfiDbClose = "db_close" + FfiDbPrepare = "db_prepare" + FfiStmtExec = "stmt_execute" + FfiStmtQuery = "stmt_query" + FfiStmtParameterCount = "stmt_parameter_count" + FfiStmtClose = "stmt_close" + FfiRowsClose = "rows_close" + FfiRowsGetColumns = "rows_get_columns" + FfiRowsGetColumnName = "rows_get_column_name" + FfiRowsNext = "rows_next" + FfiRowsGetValue = "rows_get_value" + FfiFreeColumns = "free_columns" + FfiFreeCString = "free_string" + FfiFreeBlob = "free_blob" ) // convert a namedValue slice into normal values until named parameters are supported @@ -212,22 +215,6 @@ func freeCString(cstrPtr uintptr) { freeStringFunc(cstrPtr) } -func cArrayToGoStrings(arrayPtr uintptr, length uint) []string { - if arrayPtr == 0 || length == 0 { - return nil - } - ptrSlice := unsafe.Slice( - (**byte)(unsafe.Pointer(arrayPtr)), - length, - ) - - out := make([]string, 0, length) - for _, cstr := range ptrSlice { - out = append(out, GoString(uintptr(unsafe.Pointer(cstr)))) - } - return out -} - // convert a Go slice of driver.Value to a slice of limboValue that can be sent over FFI // for Blob types, we have to pin them so they are not garbage collected before they can be copied // into a buffer on the Rust side, so we return a function to unpin them that can be deferred after this call