From 950f29daab83c9bf20602ea3ae576823694f876e Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Thu, 30 Jan 2025 13:09:39 -0500 Subject: [PATCH 1/3] bindings/go: Adjust tests for multiple concurrent connections --- .gitignore | 5 +- bindings/go/connection.go | 93 +++++++++---- bindings/go/limbo_test.go | 106 ++++++++------ bindings/go/limbo_unix.go | 18 ++- bindings/go/rows.go | 82 +++++++++++ bindings/go/rs_src/lib.rs | 23 +-- bindings/go/rs_src/statement.rs | 17 ++- bindings/go/stmt.go | 238 +++++++++++--------------------- bindings/go/types.go | 12 +- 9 files changed, 322 insertions(+), 272 deletions(-) create mode 100644 bindings/go/rows.go diff --git a/.gitignore b/.gitignore index 1f1406ceb..c7c56a7ee 100644 --- a/.gitignore +++ b/.gitignore @@ -28,4 +28,7 @@ dist/ .DS_Store # Javascript -**/node_modules/ \ No newline at end of file +**/node_modules/ + +# testing +testing/limbo_output.txt diff --git a/bindings/go/connection.go b/bindings/go/connection.go index 8c45824e8..a1ab83fae 100644 --- a/bindings/go/connection.go +++ b/bindings/go/connection.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "errors" "fmt" - "unsafe" + "sync" "github.com/ebitengine/purego" ) @@ -14,57 +14,91 @@ const ( libName = "lib_limbo_go" ) -var limboLib uintptr - -type limboDriver struct{} - -func (d limboDriver) Open(name string) (driver.Conn, error) { - return openConn(name) +type limboDriver struct { + sync.Mutex } -func toCString(s string) uintptr { - b := append([]byte(s), 0) - return uintptr(unsafe.Pointer(&b[0])) +var library = sync.OnceValue(func() uintptr { + lib, err := loadLibrary() + if err != nil { + panic(err) + } + return lib +}) + +var ( + libOnce sync.Once + loadErr error + dbOpen func(string) uintptr + dbClose func(uintptr) uintptr + connPrepare func(uintptr, string) uintptr + freeBlobFunc func(uintptr) + freeStringFunc func(uintptr) + rowsGetColumns func(uintptr, *uint) uintptr + rowsGetValue func(uintptr, int32) uintptr + closeRows func(uintptr) uintptr + freeCols func(uintptr) uintptr + rowsNext func(uintptr) uintptr + stmtQuery func(stmtPtr uintptr, argsPtr uintptr, argCount uint64) uintptr + stmtExec func(stmtPtr uintptr, argsPtr uintptr, argCount uint64, changes uintptr) int32 + stmtParamCount func(uintptr) int32 + closeStmt func(uintptr) int32 +) + +func ensureLibLoaded() error { + libOnce.Do(func() { + purego.RegisterLibFunc(&dbOpen, library(), FfiDbOpen) + purego.RegisterLibFunc(&dbClose, library(), FfiDbClose) + purego.RegisterLibFunc(&connPrepare, library(), FfiDbPrepare) + purego.RegisterLibFunc(&freeBlobFunc, library(), FfiFreeBlob) + purego.RegisterLibFunc(&freeStringFunc, library(), FfiFreeCString) + purego.RegisterLibFunc(&rowsGetColumns, library(), FfiRowsGetColumns) + purego.RegisterLibFunc(&rowsGetValue, library(), FfiRowsGetValue) + purego.RegisterLibFunc(&closeRows, library(), FfiRowsClose) + purego.RegisterLibFunc(&freeCols, library(), FfiFreeColumns) + purego.RegisterLibFunc(&rowsNext, library(), FfiRowsNext) + purego.RegisterLibFunc(&stmtQuery, library(), FfiStmtQuery) + purego.RegisterLibFunc(&stmtExec, library(), FfiStmtExec) + purego.RegisterLibFunc(&stmtParamCount, library(), FfiStmtParameterCount) + purego.RegisterLibFunc(&closeStmt, library(), FfiStmtClose) + }) + return loadErr } -// helper to register an FFI function in the lib_limbo_go library -func getFfiFunc(ptr interface{}, name string) { - purego.RegisterLibFunc(ptr, limboLib, name) +func (d *limboDriver) Open(name string) (driver.Conn, error) { + d.Lock() + defer d.Unlock() + conn, err := openConn(name) + if err != nil { + return nil, err + } + return conn, nil } -// TODO: sync primitives type limboConn struct { - ctx uintptr - prepare func(uintptr, string) uintptr + sync.Mutex + ctx uintptr } func newConn(ctx uintptr) *limboConn { - var prepare func(uintptr, string) uintptr - getFfiFunc(&prepare, FfiDbPrepare) return &limboConn{ + sync.Mutex{}, ctx, - prepare, } } func openConn(dsn string) (*limboConn, error) { - var dbOpen func(string) uintptr - getFfiFunc(&dbOpen, FfiDbOpen) - ctx := dbOpen(dsn) if ctx == 0 { return nil, fmt.Errorf("failed to open database for dsn=%q", dsn) } - return newConn(ctx), nil + return newConn(ctx), loadErr } func (c *limboConn) Close() error { if c.ctx == 0 { return nil } - var dbClose func(uintptr) uintptr - getFfiFunc(&dbClose, FfiDbClose) - dbClose(c.ctx) c.ctx = 0 return nil @@ -74,10 +108,9 @@ func (c *limboConn) Prepare(query string) (driver.Stmt, error) { if c.ctx == 0 { return nil, errors.New("connection closed") } - if c.prepare == nil { - panic("prepare function not set") - } - stmtPtr := c.prepare(c.ctx, query) + c.Lock() + defer c.Unlock() + stmtPtr := connPrepare(c.ctx, query) if stmtPtr == 0 { return nil, fmt.Errorf("failed to prepare query=%q", query) } diff --git a/bindings/go/limbo_test.go b/bindings/go/limbo_test.go index 1a787a149..da46d9964 100644 --- a/bindings/go/limbo_test.go +++ b/bindings/go/limbo_test.go @@ -3,63 +3,68 @@ package limbo_test import ( "database/sql" "fmt" + "log" "testing" _ "limbo" ) -func TestConnection(t *testing.T) { - conn, err := sql.Open("sqlite3", ":memory:") - if err != nil { - t.Fatalf("Error opening database: %v", err) +var conn *sql.DB +var connErr error + +func TestMain(m *testing.M) { + conn, connErr = sql.Open("sqlite3", ":memory:") + if connErr != nil { + panic(connErr) } defer conn.Close() -} - -func TestCreateTable(t *testing.T) { - conn, err := sql.Open("sqlite3", ":memory:") + err := createTable() if err != nil { - t.Fatalf("Error opening database: %v", err) - } - defer conn.Close() - - err = createTable(conn) - if err != nil { - t.Fatalf("Error creating table: %v", err) + log.Fatalf("Error creating table: %v", err) } + m.Run() } func TestInsertData(t *testing.T) { - conn, err := sql.Open("sqlite3", ":memory:") - if err != nil { - t.Fatalf("Error opening database: %v", err) - } - defer conn.Close() - - err = createTable(conn) - if err != nil { - t.Fatalf("Error creating table: %v", err) - } - - err = insertData(conn) + err := insertData() if err != nil { t.Fatalf("Error inserting data: %v", err) } } +func TestFunction(t *testing.T) { + insert := "INSERT INTO test (foo, bar, baz) VALUES (?, ?, zeroblob(?));" + stmt, err := conn.Prepare(insert) + if err != nil { + t.Fatalf("Error preparing statement: %v", err) + } + _, err = stmt.Exec(1, "hello", 100) + if err != nil { + t.Fatalf("Error executing statment with arguments: %v", err) + } + stmt.Close() + stmt, err = conn.Prepare("SELECT baz FROM test where foo = ?") + if err != nil { + t.Fatalf("Error preparing select stmt: %v", err) + } + defer stmt.Close() + rows, err := stmt.Query(1) + if err != nil { + t.Fatalf("Error executing select stmt: %v", err) + } + defer rows.Close() + for rows.Next() { + var b []byte + err = rows.Scan(&b) + if err != nil { + t.Fatalf("Error scanning row: %v", err) + } + fmt.Println("RESULTS: ", string(b)) + } +} + func TestQuery(t *testing.T) { - conn, err := sql.Open("sqlite3", ":memory:") - if err != nil { - t.Fatalf("Error opening database: %v", err) - } - defer conn.Close() - - err = createTable(conn) - if err != nil { - t.Fatalf("Error creating table: %v", err) - } - - err = insertData(conn) + err := insertData() if err != nil { t.Fatalf("Error inserting data: %v", err) } @@ -99,8 +104,8 @@ func TestQuery(t *testing.T) { if err != nil { t.Fatalf("Error scanning row: %v", err) } - if a != i || b != rowsMap[i] || string(c) != rowsMap[i] { - t.Fatalf("Expected %d, %s, %s, got %d, %s, %b", i, rowsMap[i], rowsMap[i], a, b, c) + if a != i || b != rowsMap[i] || !slicesAreEq(c, []byte(rowsMap[i])) { + t.Fatalf("Expected %d, %s, %s, got %d, %s, %s", i, rowsMap[i], rowsMap[i], a, b, string(c)) } fmt.Println("RESULTS: ", a, b, string(c)) i++ @@ -109,11 +114,26 @@ func TestQuery(t *testing.T) { if err = rows.Err(); err != nil { t.Fatalf("Row iteration error: %v", err) } + +} + +func slicesAreEq(a, b []byte) bool { + if len(a) != len(b) { + fmt.Printf("LENGTHS NOT EQUAL: %d != %d\n", len(a), len(b)) + return false + } + for i := range a { + if a[i] != b[i] { + fmt.Printf("SLICES NOT EQUAL: %v != %v\n", a, b) + return false + } + } + return true } var rowsMap = map[int]string{1: "hello", 2: "world", 3: "foo", 4: "bar", 5: "baz"} -func createTable(conn *sql.DB) error { +func createTable() error { insert := "CREATE TABLE test (foo INT, bar TEXT, baz BLOB);" stmt, err := conn.Prepare(insert) if err != nil { @@ -124,7 +144,7 @@ func createTable(conn *sql.DB) error { return err } -func insertData(conn *sql.DB) error { +func insertData() error { for i := 1; i <= 5; i++ { insert := "INSERT INTO test (foo, bar, baz) VALUES (?, ?, ?);" stmt, err := conn.Prepare(insert) diff --git a/bindings/go/limbo_unix.go b/bindings/go/limbo_unix.go index ecafa5a85..cb6976594 100644 --- a/bindings/go/limbo_unix.go +++ b/bindings/go/limbo_unix.go @@ -13,7 +13,7 @@ import ( "github.com/ebitengine/purego" ) -func loadLibrary() error { +func loadLibrary() (uintptr, error) { var libraryName string switch runtime.GOOS { case "darwin": @@ -21,14 +21,14 @@ func loadLibrary() error { case "linux": libraryName = fmt.Sprintf("%s.so", libName) default: - return fmt.Errorf("GOOS=%s is not supported", runtime.GOOS) + return 0, fmt.Errorf("GOOS=%s is not supported", runtime.GOOS) } libPath := os.Getenv("LD_LIBRARY_PATH") paths := strings.Split(libPath, ":") cwd, err := os.Getwd() if err != nil { - return err + return 0, err } paths = append(paths, cwd) @@ -37,20 +37,18 @@ func loadLibrary() error { if _, err := os.Stat(libPath); err == nil { slib, dlerr := purego.Dlopen(libPath, purego.RTLD_NOW|purego.RTLD_GLOBAL) if dlerr != nil { - return fmt.Errorf("failed to load library at %s: %w", libPath, dlerr) + return 0, fmt.Errorf("failed to load library at %s: %w", libPath, dlerr) } - limboLib = slib - return nil + return slib, nil } } - return fmt.Errorf("%s library not found in LD_LIBRARY_PATH or CWD", libName) + return 0, fmt.Errorf("%s library not found in LD_LIBRARY_PATH or CWD", libName) } func init() { - err := loadLibrary() + err := ensureLibLoaded() if err != nil { - fmt.Println(err) - os.Exit(1) + panic(err) } sql.Register("sqlite3", &limboDriver{}) } diff --git a/bindings/go/rows.go b/bindings/go/rows.go new file mode 100644 index 000000000..276f0fd1a --- /dev/null +++ b/bindings/go/rows.go @@ -0,0 +1,82 @@ +package limbo + +import ( + "database/sql/driver" + "fmt" + "io" + "sync" +) + +// only construct limboRows with initRows function to ensure proper initialization +type limboRows struct { + mu sync.Mutex + ctx uintptr + columns []string + closed bool +} + +// Initialize/register the FFI function pointers for the rows methods +// DO NOT construct 'limboRows' without this function +func initRows(ctx uintptr) *limboRows { + return &limboRows{ + mu: sync.Mutex{}, + ctx: ctx, + } +} + +func (r *limboRows) Columns() []string { + if r.ctx == 0 || r.closed { + return nil + } + if r.columns == nil { + var columnCount uint + r.mu.Lock() + defer r.mu.Unlock() + colArrayPtr := rowsGetColumns(r.ctx, &columnCount) + if colArrayPtr != 0 && columnCount > 0 { + r.columns = cArrayToGoStrings(colArrayPtr, columnCount) + if freeCols != nil { + defer freeCols(colArrayPtr) + } + } + } + return r.columns +} + +func (r *limboRows) Close() error { + if r.closed { + return nil + } + r.mu.Lock() + r.closed = true + closeRows(r.ctx) + r.ctx = 0 + r.mu.Unlock() + return nil +} + +func (r *limboRows) Next(dest []driver.Value) error { + if r.ctx == 0 || r.closed { + return io.EOF + } + for { + status := rowsNext(r.ctx) + switch ResultCode(status) { + case Row: + for i := range dest { + r.mu.Lock() + valPtr := rowsGetValue(r.ctx, int32(i)) + r.mu.Unlock() + val := toGoValue(valPtr) + dest[i] = val + } + return nil + case Io: + continue + case Done: + return io.EOF + default: + return fmt.Errorf("unexpected status: %d", status) + } + } +} diff --git a/bindings/go/rs_src/lib.rs b/bindings/go/rs_src/lib.rs index 199ed10c0..70dbce89a 100644 --- a/bindings/go/rs_src/lib.rs +++ b/bindings/go/rs_src/lib.rs @@ -7,7 +7,7 @@ use std::{ ffi::{c_char, c_void}, rc::Rc, str::FromStr, - sync::Arc, + sync::{Arc, RwLock}, }; /// # Safety @@ -26,7 +26,6 @@ pub unsafe extern "C" fn db_open(path: *const c_char) -> *mut c_void { let db = Database::open_file(io.clone(), &db_options.path.to_string()); match db { Ok(db) => { - println!("Opened database: {}", path); let conn = db.connect(); return LimboConn::new(conn, io).to_ptr(); } @@ -41,20 +40,24 @@ pub unsafe extern "C" fn db_open(path: *const c_char) -> *mut c_void { #[allow(dead_code)] struct LimboConn { - conn: Rc, + conn: RwLock>, io: Arc, } -impl LimboConn { +impl<'conn> LimboConn { fn new(conn: Rc, io: Arc) -> Self { - LimboConn { conn, io } - } - #[allow(clippy::wrong_self_convention)] - fn to_ptr(self) -> *mut c_void { - Box::into_raw(Box::new(self)) as *mut c_void + LimboConn { + conn: conn.into(), + io, + } } - fn from_ptr(ptr: *mut c_void) -> &'static mut LimboConn { + #[allow(clippy::wrong_self_convention)] + fn to_ptr(self) -> *mut c_void { + Arc::into_raw(Arc::new(self)) as *mut c_void + } + + fn from_ptr(ptr: *mut c_void) -> &'conn mut LimboConn { if ptr.is_null() { panic!("Null pointer"); } diff --git a/bindings/go/rs_src/statement.rs b/bindings/go/rs_src/statement.rs index 329f9b59c..6b0dfc80e 100644 --- a/bindings/go/rs_src/statement.rs +++ b/bindings/go/rs_src/statement.rs @@ -13,10 +13,12 @@ pub extern "C" fn db_prepare(ctx: *mut c_void, query: *const c_char) -> *mut c_v let query_str = unsafe { std::ffi::CStr::from_ptr(query) }.to_str().unwrap(); let db = LimboConn::from_ptr(ctx); - - let stmt = db.conn.prepare(query_str); + let Ok(conn) = db.conn.read() else { + return std::ptr::null_mut(); + }; + let stmt = conn.prepare(query_str); match stmt { - Ok(stmt) => LimboStatement::new(Some(stmt), db).to_ptr(), + Ok(stmt) => LimboStatement::new(Some(stmt), LimboConn::from_ptr(ctx)).to_ptr(), Err(_) => std::ptr::null_mut(), } } @@ -53,10 +55,13 @@ pub extern "C" fn stmt_execute( return ResultCode::Error; } Ok(StepResult::Done) => { - stmt.conn.conn.total_changes(); + let Ok(conn) = stmt.conn.conn.read() else { + return ResultCode::Done; + }; + let total_changes = conn.total_changes(); if !changes.is_null() { unsafe { - *changes = stmt.conn.conn.total_changes(); + *changes = total_changes; } } return ResultCode::Done; @@ -148,7 +153,7 @@ impl<'conn> LimboStatement<'conn> { Box::into_raw(Box::new(self)) as *mut c_void } - fn from_ptr(ptr: *mut c_void) -> &'static mut LimboStatement<'conn> { + fn from_ptr(ptr: *mut c_void) -> &'conn mut LimboStatement<'conn> { if ptr.is_null() { panic!("Null pointer"); } diff --git a/bindings/go/stmt.go b/bindings/go/stmt.go index 02ad3c3eb..97e92149d 100644 --- a/bindings/go/stmt.go +++ b/bindings/go/stmt.go @@ -1,11 +1,10 @@ package limbo import ( - "context" "database/sql/driver" "errors" "fmt" - "io" + "sync" "unsafe" ) @@ -13,43 +12,32 @@ import ( // inUse tracks whether or not `query` has been called. if inUse > 0, stmt no longer // owns the underlying data and `rows` is responsible for cleaning it up on close. type limboStmt struct { - ctx uintptr - sql string - inUse int - query func(stmtPtr uintptr, argsPtr uintptr, argCount uint64) uintptr - execute func(stmtPtr uintptr, argsPtr uintptr, argCount uint64, changes uintptr) int32 - getParamCount func(uintptr) int32 - closeStmt func(uintptr) int32 + mu sync.Mutex + ctx uintptr + sql string + inUse int } // Initialize/register the FFI function pointers for the statement methods func initStmt(ctx uintptr, sql string) *limboStmt { - var query func(stmtPtr uintptr, argsPtr uintptr, argCount uint64) uintptr - getFfiFunc(&query, FfiStmtQuery) - var execute func(stmtPtr uintptr, argsPtr uintptr, argCount uint64, changes uintptr) int32 - getFfiFunc(&execute, FfiStmtExec) - var getParamCount func(uintptr) int32 - getFfiFunc(&getParamCount, FfiStmtParameterCount) - var closeStmt func(uintptr) int32 - getFfiFunc(&closeStmt, FfiStmtClose) return &limboStmt{ - ctx: uintptr(ctx), - sql: sql, - inUse: 0, - execute: execute, - query: query, - getParamCount: getParamCount, - closeStmt: closeStmt, + ctx: uintptr(ctx), + sql: sql, + inUse: 0, } } func (ls *limboStmt) NumInput() int { - return int(ls.getParamCount(ls.ctx)) + ls.mu.Lock() + defer ls.mu.Unlock() + return int(stmtParamCount(ls.ctx)) } func (ls *limboStmt) Close() error { if ls.inUse == 0 { - res := ls.closeStmt(ls.ctx) + ls.mu.Lock() + res := closeStmt(ls.ctx) + ls.mu.Unlock() if ResultCode(res) != Ok { return fmt.Errorf("error closing statement: %s", ResultCode(res).String()) } @@ -59,8 +47,12 @@ func (ls *limboStmt) Close() error { } func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) { + ls.mu.Lock() argArray, cleanup, err := buildArgs(args) - defer cleanup() + defer func() { + cleanup() + ls.mu.Unlock() + }() if err != nil { return nil, err } @@ -70,7 +62,7 @@ func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) { argPtr = uintptr(unsafe.Pointer(&argArray[0])) } var changes uint64 - rc := ls.execute(ls.ctx, argPtr, argCount, uintptr(unsafe.Pointer(&changes))) + rc := stmtExec(ls.ctx, argPtr, argCount, uintptr(unsafe.Pointer(&changes))) switch ResultCode(rc) { case Ok, Done: return driver.RowsAffected(changes), nil @@ -87,9 +79,10 @@ func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) { } } -func (st *limboStmt) Query(args []driver.Value) (driver.Rows, error) { +func (ls *limboStmt) Query(args []driver.Value) (driver.Rows, error) { + ls.mu.Lock() queryArgs, cleanup, err := buildArgs(args) - defer cleanup() + defer func() { cleanup(); ls.mu.Unlock() }() if err != nil { return nil, err } @@ -97,59 +90,7 @@ func (st *limboStmt) Query(args []driver.Value) (driver.Rows, error) { if len(args) > 0 { argPtr = uintptr(unsafe.Pointer(&queryArgs[0])) } - rowsPtr := st.query(st.ctx, argPtr, uint64(len(queryArgs))) - if rowsPtr == 0 { - return nil, fmt.Errorf("query failed for: %q", st.sql) - } - st.inUse++ - return initRows(rowsPtr), nil -} - -func (ls *limboStmt) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { - stripped := namedValueToValue(args) - argArray, cleanup, err := getArgsPtr(stripped) - defer cleanup() - if err != nil { - return nil, err - } - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - var changes uint64 - res := ls.execute(ls.ctx, argArray, uint64(len(args)), uintptr(unsafe.Pointer(&changes))) - switch ResultCode(res) { - case Ok, Done: - changes := uint64(changes) - return driver.RowsAffected(changes), nil - case Error: - return nil, errors.New("error executing statement") - case Busy: - return nil, errors.New("busy") - case Interrupt: - return nil, errors.New("interrupted") - default: - return nil, fmt.Errorf("unexpected status: %d", res) - } -} - -func (ls *limboStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { - queryArgs, allocs, err := buildNamedArgs(args) - defer allocs() - if err != nil { - return nil, err - } - argsPtr := uintptr(0) - if len(queryArgs) > 0 { - argsPtr = uintptr(unsafe.Pointer(&queryArgs[0])) - } - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - rowsPtr := ls.query(ls.ctx, argsPtr, uint64(len(queryArgs))) + rowsPtr := stmtQuery(ls.ctx, argPtr, uint64(len(queryArgs))) if rowsPtr == 0 { return nil, fmt.Errorf("query failed for: %q", ls.sql) } @@ -157,81 +98,56 @@ func (ls *limboStmt) QueryContext(ctx context.Context, args []driver.NamedValue) return initRows(rowsPtr), nil } -// only construct limboRows with initRows function to ensure proper initialization -type limboRows struct { - ctx uintptr - columns []string - closed bool - getCols func(uintptr, *uint) uintptr - next func(uintptr) uintptr - getValue func(uintptr, int32) uintptr - closeRows func(uintptr) uintptr - freeCols func(uintptr) uintptr -} - -// Initialize/register the FFI function pointers for the rows methods -// DO NOT construct 'limboRows' without this function -func initRows(ctx uintptr) *limboRows { - var getCols func(uintptr, *uint) uintptr - getFfiFunc(&getCols, FfiRowsGetColumns) - var getValue func(uintptr, int32) uintptr - getFfiFunc(&getValue, FfiRowsGetValue) - var closeRows func(uintptr) uintptr - getFfiFunc(&closeRows, FfiRowsClose) - var freeCols func(uintptr) uintptr - getFfiFunc(&freeCols, FfiFreeColumns) - var next func(uintptr) uintptr - getFfiFunc(&next, FfiRowsNext) - - return &limboRows{ - ctx: ctx, - getCols: getCols, - getValue: getValue, - closeRows: closeRows, - freeCols: freeCols, - next: next, - } -} - -func (r *limboRows) Columns() []string { - if r.columns == nil { - var columnCount uint - colArrayPtr := r.getCols(r.ctx, &columnCount) - if colArrayPtr != 0 && columnCount > 0 { - r.columns = cArrayToGoStrings(colArrayPtr, columnCount) - defer r.freeCols(colArrayPtr) - } - } - return r.columns -} - -func (r *limboRows) Close() error { - if r.closed { - return nil - } - r.closed = true - r.closeRows(r.ctx) - r.ctx = 0 - return nil -} - -func (r *limboRows) Next(dest []driver.Value) error { - for { - status := r.next(r.ctx) - switch ResultCode(status) { - case Row: - for i := range dest { - valPtr := r.getValue(r.ctx, int32(i)) - val := toGoValue(valPtr) - dest[i] = val - } - return nil - case Io: - continue - case Done: - return io.EOF - default: - return fmt.Errorf("unexpected status: %d", status) - } - } -} +// func (ls *limboStmt) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { +// ls.mu.Lock() +// stripped := namedValueToValue(args) +// argArray, cleanup, err := getArgsPtr(stripped) +// defer func() { cleanup(); ls.mu.Unlock() }() +// if err != nil { +// return nil, err +// } +// select { +// case <-ctx.Done(): +// return nil, ctx.Err() +// default: +// } +// var changes uint64 +// res := stmtExec(ls.ctx, argArray, uint64(len(args)), uintptr(unsafe.Pointer(&changes))) +// switch ResultCode(res) { +// case Ok, Done: +// changes := uint64(changes) +// return driver.RowsAffected(changes), nil +// case Error: +// return nil, errors.New("error executing statement") +// case Busy: +// return nil, errors.New("busy") +// case Interrupt: +// return nil, errors.New("interrupted") +// default: +// return nil, fmt.Errorf("unexpected status: %d", res) +// } +// } +// +// func (ls *limboStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { +// ls.mu.Lock() +// queryArgs, allocs, err := buildNamedArgs(args) +// defer func() { allocs(); ls.mu.Unlock() }() +// if err != nil { +// return nil, err +// } +// argsPtr := uintptr(0) +// if len(queryArgs) > 0 { +// argsPtr = uintptr(unsafe.Pointer(&queryArgs[0])) +// } +// select { +// case <-ctx.Done(): +// return nil, ctx.Err() +// default: +// } +// rowsPtr := stmtQuery(ls.ctx, argsPtr, uint64(len(queryArgs))) +// if rowsPtr == 0 { +// return nil, fmt.Errorf("query failed for: %q", ls.sql) +// } +// ls.inUse++ +// return initRows(rowsPtr), nil +// } diff --git a/bindings/go/types.go b/bindings/go/types.go index 78fb96153..d54ff27fd 100644 --- a/bindings/go/types.go +++ b/bindings/go/types.go @@ -198,27 +198,17 @@ func toGoBlob(blobPtr uintptr) []byte { return copied } -var freeBlobFunc func(uintptr) - func freeBlob(blobPtr uintptr) { if blobPtr == 0 { return } - if freeBlobFunc == nil { - getFfiFunc(&freeBlobFunc, FfiFreeBlob) - } freeBlobFunc(blobPtr) } -var freeStringFunc func(uintptr) - func freeCString(cstrPtr uintptr) { if cstrPtr == 0 { return } - if freeStringFunc == nil { - getFfiFunc(&freeStringFunc, FfiFreeCString) - } freeStringFunc(cstrPtr) } @@ -226,7 +216,6 @@ func cArrayToGoStrings(arrayPtr uintptr, length uint) []string { if arrayPtr == 0 || length == 0 { return nil } - ptrSlice := unsafe.Slice( (**byte)(unsafe.Pointer(arrayPtr)), length, @@ -259,6 +248,7 @@ func buildArgs(args []driver.Value) ([]limboValue, func(), error) { case string: limboVal.Type = textVal cstr := CString(val) + pinner.Pin(cstr) *(*uintptr)(unsafe.Pointer(&limboVal.Value)) = uintptr(unsafe.Pointer(cstr)) case []byte: limboVal.Type = blobVal From 8d93130809b5e3553a03c286f64a9486bcc7e2ba Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Fri, 31 Jan 2025 13:22:48 -0500 Subject: [PATCH 2/3] bindings/go: enable multiple connections, register all symbols at library load --- bindings/go/connection.go | 104 ++++++++++---------- bindings/go/limbo_test.go | 168 ++++++++++++++++++++++++-------- bindings/go/limbo_unix.go | 9 -- bindings/go/limbo_windows.go | 21 +--- bindings/go/rows.go | 31 +++--- bindings/go/rs_src/rows.rs | 53 ++++------ bindings/go/rs_src/statement.rs | 10 +- bindings/go/stmt.go | 154 ++++++++++++++--------------- bindings/go/types.go | 47 ++++----- 9 files changed, 313 insertions(+), 284 deletions(-) diff --git a/bindings/go/connection.go b/bindings/go/connection.go index a1ab83fae..672f2e4e3 100644 --- a/bindings/go/connection.go +++ b/bindings/go/connection.go @@ -1,6 +1,7 @@ package limbo import ( + "database/sql" "database/sql/driver" "errors" "fmt" @@ -9,66 +10,67 @@ import ( "github.com/ebitengine/purego" ) -const ( - driverName = "sqlite3" - libName = "lib_limbo_go" -) +func init() { + err := ensureLibLoaded() + if err != nil { + panic(err) + } + sql.Register(driverName, &limboDriver{}) +} type limboDriver struct { sync.Mutex } -var library = sync.OnceValue(func() uintptr { - lib, err := loadLibrary() - if err != nil { - panic(err) - } - return lib -}) - var ( - libOnce sync.Once - loadErr error - dbOpen func(string) uintptr - dbClose func(uintptr) uintptr - connPrepare func(uintptr, string) uintptr - freeBlobFunc func(uintptr) - freeStringFunc func(uintptr) - rowsGetColumns func(uintptr, *uint) uintptr - rowsGetValue func(uintptr, int32) uintptr - closeRows func(uintptr) uintptr - freeCols func(uintptr) uintptr - rowsNext func(uintptr) uintptr - stmtQuery func(stmtPtr uintptr, argsPtr uintptr, argCount uint64) uintptr - stmtExec func(stmtPtr uintptr, argsPtr uintptr, argCount uint64, changes uintptr) int32 - stmtParamCount func(uintptr) int32 - closeStmt func(uintptr) int32 + libOnce sync.Once + limboLib uintptr + loadErr error + dbOpen func(string) uintptr + dbClose func(uintptr) uintptr + connPrepare func(uintptr, string) uintptr + freeBlobFunc func(uintptr) + freeStringFunc func(uintptr) + rowsGetColumns func(uintptr) int32 + rowsGetColumnName func(uintptr, int32) uintptr + rowsGetValue func(uintptr, int32) uintptr + closeRows func(uintptr) uintptr + rowsNext func(uintptr) uintptr + stmtQuery func(stmtPtr uintptr, argsPtr uintptr, argCount uint64) uintptr + stmtExec func(stmtPtr uintptr, argsPtr uintptr, argCount uint64, changes uintptr) int32 + stmtParamCount func(uintptr) int32 + closeStmt func(uintptr) int32 ) +// Register all the symbols on library load func ensureLibLoaded() error { libOnce.Do(func() { - purego.RegisterLibFunc(&dbOpen, library(), FfiDbOpen) - purego.RegisterLibFunc(&dbClose, library(), FfiDbClose) - purego.RegisterLibFunc(&connPrepare, library(), FfiDbPrepare) - purego.RegisterLibFunc(&freeBlobFunc, library(), FfiFreeBlob) - purego.RegisterLibFunc(&freeStringFunc, library(), FfiFreeCString) - purego.RegisterLibFunc(&rowsGetColumns, library(), FfiRowsGetColumns) - purego.RegisterLibFunc(&rowsGetValue, library(), FfiRowsGetValue) - purego.RegisterLibFunc(&closeRows, library(), FfiRowsClose) - purego.RegisterLibFunc(&freeCols, library(), FfiFreeColumns) - purego.RegisterLibFunc(&rowsNext, library(), FfiRowsNext) - purego.RegisterLibFunc(&stmtQuery, library(), FfiStmtQuery) - purego.RegisterLibFunc(&stmtExec, library(), FfiStmtExec) - purego.RegisterLibFunc(&stmtParamCount, library(), FfiStmtParameterCount) - purego.RegisterLibFunc(&closeStmt, library(), FfiStmtClose) + limboLib, loadErr = loadLibrary() + if loadErr != nil { + return + } + purego.RegisterLibFunc(&dbOpen, limboLib, FfiDbOpen) + purego.RegisterLibFunc(&dbClose, limboLib, FfiDbClose) + purego.RegisterLibFunc(&connPrepare, limboLib, FfiDbPrepare) + purego.RegisterLibFunc(&freeBlobFunc, limboLib, FfiFreeBlob) + purego.RegisterLibFunc(&freeStringFunc, limboLib, FfiFreeCString) + purego.RegisterLibFunc(&rowsGetColumns, limboLib, FfiRowsGetColumns) + purego.RegisterLibFunc(&rowsGetColumnName, limboLib, FfiRowsGetColumnName) + purego.RegisterLibFunc(&rowsGetValue, limboLib, FfiRowsGetValue) + purego.RegisterLibFunc(&closeRows, limboLib, FfiRowsClose) + purego.RegisterLibFunc(&rowsNext, limboLib, FfiRowsNext) + purego.RegisterLibFunc(&stmtQuery, limboLib, FfiStmtQuery) + purego.RegisterLibFunc(&stmtExec, limboLib, FfiStmtExec) + purego.RegisterLibFunc(&stmtParamCount, limboLib, FfiStmtParameterCount) + purego.RegisterLibFunc(&closeStmt, limboLib, FfiStmtClose) }) return loadErr } func (d *limboDriver) Open(name string) (driver.Conn, error) { d.Lock() - defer d.Unlock() conn, err := openConn(name) + d.Unlock() if err != nil { return nil, err } @@ -80,26 +82,24 @@ type limboConn struct { ctx uintptr } -func newConn(ctx uintptr) *limboConn { - return &limboConn{ - sync.Mutex{}, - ctx, - } -} - func openConn(dsn string) (*limboConn, error) { ctx := dbOpen(dsn) if ctx == 0 { return nil, fmt.Errorf("failed to open database for dsn=%q", dsn) } - return newConn(ctx), loadErr + return &limboConn{ + sync.Mutex{}, + ctx, + }, loadErr } func (c *limboConn) Close() error { if c.ctx == 0 { return nil } + c.Lock() dbClose(c.ctx) + c.Unlock() c.ctx = 0 return nil } @@ -114,7 +114,7 @@ func (c *limboConn) Prepare(query string) (driver.Stmt, error) { if stmtPtr == 0 { return nil, fmt.Errorf("failed to prepare query=%q", query) } - return initStmt(stmtPtr, query), nil + return newStmt(stmtPtr, query), nil } // begin is needed to implement driver.Conn.. for now not implemented diff --git a/bindings/go/limbo_test.go b/bindings/go/limbo_test.go index da46d9964..a688cb34e 100644 --- a/bindings/go/limbo_test.go +++ b/bindings/go/limbo_test.go @@ -18,7 +18,7 @@ func TestMain(m *testing.M) { panic(connErr) } defer conn.Close() - err := createTable() + err := createTable(conn) if err != nil { log.Fatalf("Error creating table: %v", err) } @@ -26,49 +26,13 @@ func TestMain(m *testing.M) { } func TestInsertData(t *testing.T) { - err := insertData() + err := insertData(conn) if err != nil { t.Fatalf("Error inserting data: %v", err) } } -func TestFunction(t *testing.T) { - insert := "INSERT INTO test (foo, bar, baz) VALUES (?, ?, zeroblob(?));" - stmt, err := conn.Prepare(insert) - if err != nil { - t.Fatalf("Error preparing statement: %v", err) - } - _, err = stmt.Exec(1, "hello", 100) - if err != nil { - t.Fatalf("Error executing statment with arguments: %v", err) - } - stmt.Close() - stmt, err = conn.Prepare("SELECT baz FROM test where foo = ?") - if err != nil { - t.Fatalf("Error preparing select stmt: %v", err) - } - defer stmt.Close() - rows, err := stmt.Query(1) - if err != nil { - t.Fatalf("Error executing select stmt: %v", err) - } - defer rows.Close() - for rows.Next() { - var b []byte - err = rows.Scan(&b) - if err != nil { - t.Fatalf("Error scanning row: %v", err) - } - fmt.Println("RESULTS: ", string(b)) - } -} - func TestQuery(t *testing.T) { - err := insertData() - if err != nil { - t.Fatalf("Error inserting data: %v", err) - } - query := "SELECT * FROM test;" stmt, err := conn.Prepare(query) if err != nil { @@ -117,6 +81,130 @@ func TestQuery(t *testing.T) { } +func TestFunctions(t *testing.T) { + insert := "INSERT INTO test (foo, bar, baz) VALUES (?, ?, zeroblob(?));" + stmt, err := conn.Prepare(insert) + if err != nil { + t.Fatalf("Error preparing statement: %v", err) + } + _, err = stmt.Exec(60, "TestFunction", 400) + if err != nil { + t.Fatalf("Error executing statment with arguments: %v", err) + } + stmt.Close() + stmt, err = conn.Prepare("SELECT baz FROM test where foo = ?") + if err != nil { + t.Fatalf("Error preparing select stmt: %v", err) + } + defer stmt.Close() + rows, err := stmt.Query(60) + if err != nil { + t.Fatalf("Error executing select stmt: %v", err) + } + defer rows.Close() + for rows.Next() { + var b []byte + err = rows.Scan(&b) + if err != nil { + t.Fatalf("Error scanning row: %v", err) + } + if len(b) != 400 { + t.Fatalf("Expected 100 bytes, got %d", len(b)) + } + } + sql := "SELECT uuid4_str();" + stmt, err = conn.Prepare(sql) + if err != nil { + t.Fatalf("Error preparing statement: %v", err) + } + defer stmt.Close() + rows, err = stmt.Query() + if err != nil { + t.Fatalf("Error executing query: %v", err) + } + defer rows.Close() + var i int + for rows.Next() { + var b string + err = rows.Scan(&b) + if err != nil { + t.Fatalf("Error scanning row: %v", err) + } + if len(b) != 36 { + t.Fatalf("Expected 36 bytes, got %d", len(b)) + } + i++ + fmt.Printf("uuid: %s\n", b) + } + if i != 1 { + t.Fatalf("Expected 1 row, got %d", i) + } + fmt.Println("zeroblob + uuid functions passed") +} + +func TestDuplicateConnection(t *testing.T) { + newConn, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Error opening new connection: %v", err) + } + err = createTable(newConn) + if err != nil { + t.Fatalf("Error creating table: %v", err) + } + err = insertData(newConn) + if err != nil { + t.Fatalf("Error inserting data: %v", err) + } + query := "SELECT * FROM test;" + rows, err := newConn.Query(query) + if err != nil { + t.Fatalf("Error executing query: %v", err) + } + defer rows.Close() + for rows.Next() { + var a int + var b string + var c []byte + err = rows.Scan(&a, &b, &c) + if err != nil { + t.Fatalf("Error scanning row: %v", err) + } + fmt.Println("RESULTS: ", a, b, string(c)) + } +} + +func TestDuplicateConnection2(t *testing.T) { + newConn, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Error opening new connection: %v", err) + } + sql := "CREATE TABLE test (foo INTEGER, bar INTEGER, baz BLOB);" + newConn.Exec(sql) + sql = "INSERT INTO test (foo, bar, baz) VALUES (?, ?, uuid4());" + stmt, err := newConn.Prepare(sql) + stmt.Exec(242345, 2342434) + defer stmt.Close() + query := "SELECT * FROM test;" + rows, err := newConn.Query(query) + if err != nil { + t.Fatalf("Error executing query: %v", err) + } + defer rows.Close() + for rows.Next() { + var a int + var b int + var c []byte + err = rows.Scan(&a, &b, &c) + if err != nil { + t.Fatalf("Error scanning row: %v", err) + } + fmt.Println("RESULTS: ", a, b, string(c)) + if len(c) != 16 { + t.Fatalf("Expected 16 bytes, got %d", len(c)) + } + } +} + func slicesAreEq(a, b []byte) bool { if len(a) != len(b) { fmt.Printf("LENGTHS NOT EQUAL: %d != %d\n", len(a), len(b)) @@ -133,7 +221,7 @@ func slicesAreEq(a, b []byte) bool { var rowsMap = map[int]string{1: "hello", 2: "world", 3: "foo", 4: "bar", 5: "baz"} -func createTable() error { +func createTable(conn *sql.DB) error { insert := "CREATE TABLE test (foo INT, bar TEXT, baz BLOB);" stmt, err := conn.Prepare(insert) if err != nil { @@ -144,7 +232,7 @@ func createTable() error { return err } -func insertData() error { +func insertData(conn *sql.DB) error { for i := 1; i <= 5; i++ { insert := "INSERT INTO test (foo, bar, baz) VALUES (?, ?, ?);" stmt, err := conn.Prepare(insert) diff --git a/bindings/go/limbo_unix.go b/bindings/go/limbo_unix.go index cb6976594..8ab911416 100644 --- a/bindings/go/limbo_unix.go +++ b/bindings/go/limbo_unix.go @@ -3,7 +3,6 @@ package limbo import ( - "database/sql" "fmt" "os" "path/filepath" @@ -44,11 +43,3 @@ func loadLibrary() (uintptr, error) { } return 0, fmt.Errorf("%s library not found in LD_LIBRARY_PATH or CWD", libName) } - -func init() { - err := ensureLibLoaded() - if err != nil { - panic(err) - } - sql.Register("sqlite3", &limboDriver{}) -} diff --git a/bindings/go/limbo_windows.go b/bindings/go/limbo_windows.go index 433ddd051..d56381176 100644 --- a/bindings/go/limbo_windows.go +++ b/bindings/go/limbo_windows.go @@ -3,7 +3,6 @@ package limbo import ( - "database/sql" "fmt" "os" "path/filepath" @@ -12,14 +11,14 @@ import ( "golang.org/x/sys/windows" ) -func loadLibrary() error { +func loadLibrary() (uintptr, error) { libName := fmt.Sprintf("%s.dll", libName) pathEnv := os.Getenv("PATH") paths := strings.Split(pathEnv, ";") cwd, err := os.Getwd() if err != nil { - return err + return 0, err } paths = append(paths, cwd) for _, path := range paths { @@ -27,21 +26,11 @@ func loadLibrary() error { if _, err := os.Stat(dllPath); err == nil { slib, loadErr := windows.LoadLibrary(dllPath) if loadErr != nil { - return fmt.Errorf("failed to load library at %s: %w", dllPath, loadErr) + return 0, fmt.Errorf("failed to load library at %s: %w", dllPath, loadErr) } - limboLib = uintptr(slib) - return nil + return uintptr(slib), nil } } - return fmt.Errorf("library %s not found in PATH or CWD", libName) -} - -func init() { - err := loadLibrary() - if err != nil { - fmt.Println("Error opening limbo library: ", err) - os.Exit(1) - } - sql.Register("sqlite3", &limboDriver{}) + return 0, fmt.Errorf("library %s not found in PATH or CWD", libName) } diff --git a/bindings/go/rows.go b/bindings/go/rows.go index 276f0fd1a..2f20bb93f 100644 --- a/bindings/go/rows.go +++ b/bindings/go/rows.go @@ -7,7 +7,6 @@ import ( "sync" ) -// only construct limboRows with initRows function to ensure proper initialization type limboRows struct { mu sync.Mutex ctx uintptr @@ -15,12 +14,12 @@ type limboRows struct { closed bool } -// Initialize/register the FFI function pointers for the rows methods -// DO NOT construct 'limboRows' without this function -func initRows(ctx uintptr) *limboRows { +func newRows(ctx uintptr) *limboRows { return &limboRows{ - mu: sync.Mutex{}, - ctx: ctx, + mu: sync.Mutex{}, + ctx: ctx, + closed: false, + columns: nil, } } @@ -29,15 +28,17 @@ func (r *limboRows) Columns() []string { return nil } if r.columns == nil { - var columnCount uint r.mu.Lock() - defer r.mu.Unlock() - colArrayPtr := rowsGetColumns(r.ctx, &columnCount) - if colArrayPtr != 0 && columnCount > 0 { - r.columns = cArrayToGoStrings(colArrayPtr, columnCount) - if freeCols != nil { - defer freeCols(colArrayPtr) + count := rowsGetColumns(r.ctx) + if count > 0 { + columns := make([]string, 0, count) + for i := 0; i < int(count); i++ { + cstr := rowsGetColumnName(r.ctx, int32(i)) + columns = append(columns, fmt.Sprintf("%s", GoString(cstr))) + freeCString(cstr) } + r.mu.Unlock() + r.columns = columns } } return r.columns @@ -59,14 +60,14 @@ func (r *limboRows) Next(dest []driver.Value) error { if r.ctx == 0 || r.closed { return io.EOF } + r.mu.Lock() + defer r.mu.Unlock() for { status := rowsNext(r.ctx) switch ResultCode(status) { case Row: for i := range dest { - r.mu.Lock() valPtr := rowsGetValue(r.ctx, int32(i)) - r.mu.Unlock() val := toGoValue(valPtr) dest[i] = val } diff --git a/bindings/go/rs_src/rows.rs b/bindings/go/rs_src/rows.rs index 19b526c74..c8ec7642a 100644 --- a/bindings/go/rs_src/rows.rs +++ b/bindings/go/rs_src/rows.rs @@ -78,30 +78,30 @@ pub extern "C" fn free_string(s: *mut c_char) { } } +/// Function to get the number of expected ResultColumns in the prepared statement. +/// to avoid the needless complexity of returning an array of strings, this instead +/// works like rows_next/rows_get_value #[no_mangle] -pub extern "C" fn rows_get_columns( - rows_ptr: *mut c_void, - out_length: *mut usize, -) -> *mut *const c_char { - if rows_ptr.is_null() || out_length.is_null() { +pub extern "C" fn rows_get_columns(rows_ptr: *mut c_void) -> i32 { + if rows_ptr.is_null() { + return -1; + } + let rows = LimboRows::from_ptr(rows_ptr); + rows.stmt.columns().len() as i32 +} + +#[no_mangle] +pub extern "C" fn rows_get_column_name(rows_ptr: *mut c_void, idx: i32) -> *const c_char { + if rows_ptr.is_null() { return std::ptr::null_mut(); } let rows = LimboRows::from_ptr(rows_ptr); - let c_strings: Vec = rows - .stmt - .columns() - .iter() - .map(|name| std::ffi::CString::new(name.as_str()).unwrap()) - .collect(); - - let c_ptrs: Vec<*const c_char> = c_strings.iter().map(|s| s.as_ptr()).collect(); - unsafe { - *out_length = c_ptrs.len(); + if idx < 0 || idx as usize >= rows.stmt.columns().len() { + return std::ptr::null_mut(); } - let ptr = c_ptrs.as_ptr(); - std::mem::forget(c_strings); - std::mem::forget(c_ptrs); - ptr as *mut *const c_char + let name = &rows.stmt.columns()[idx as usize]; + let cstr = std::ffi::CString::new(name.as_bytes()).expect("Failed to create CString"); + cstr.into_raw() as *const c_char } #[no_mangle] @@ -111,21 +111,6 @@ pub extern "C" fn rows_close(rows_ptr: *mut c_void) { } } -#[no_mangle] -pub extern "C" fn free_columns(columns: *mut *const c_char) { - if columns.is_null() { - return; - } - unsafe { - let mut idx = 0; - while !(*columns.add(idx)).is_null() { - let _ = std::ffi::CString::from_raw(*columns.add(idx) as *mut c_char); - idx += 1; - } - let _ = Box::from_raw(columns); - } -} - #[no_mangle] pub extern "C" fn free_rows(rows: *mut c_void) { if rows.is_null() { diff --git a/bindings/go/rs_src/statement.rs b/bindings/go/rs_src/statement.rs index 6b0dfc80e..8ad015ded 100644 --- a/bindings/go/rs_src/statement.rs +++ b/bindings/go/rs_src/statement.rs @@ -132,13 +132,9 @@ pub struct LimboStatement<'conn> { #[no_mangle] pub extern "C" fn stmt_close(ctx: *mut c_void) -> ResultCode { if !ctx.is_null() { - let stmt = LimboStatement::from_ptr(ctx); - if stmt.statement.is_none() { - return ResultCode::Error; - } else { - let _ = unsafe { Box::from_raw(ctx as *mut LimboStatement) }; - return ResultCode::Ok; - } + let stmt = unsafe { Box::from_raw(ctx as *mut LimboStatement) }; + drop(stmt); + return ResultCode::Ok; } ResultCode::Invalid } diff --git a/bindings/go/stmt.go b/bindings/go/stmt.go index 97e92149d..0f3c038fd 100644 --- a/bindings/go/stmt.go +++ b/bindings/go/stmt.go @@ -1,6 +1,7 @@ package limbo import ( + "context" "database/sql/driver" "errors" "fmt" @@ -8,22 +9,16 @@ import ( "unsafe" ) -// only construct limboStmt with initStmt function to ensure proper initialization -// inUse tracks whether or not `query` has been called. if inUse > 0, stmt no longer -// owns the underlying data and `rows` is responsible for cleaning it up on close. type limboStmt struct { - mu sync.Mutex - ctx uintptr - sql string - inUse int + mu sync.Mutex + ctx uintptr + sql string } -// Initialize/register the FFI function pointers for the statement methods -func initStmt(ctx uintptr, sql string) *limboStmt { +func newStmt(ctx uintptr, sql string) *limboStmt { return &limboStmt{ - ctx: uintptr(ctx), - sql: sql, - inUse: 0, + ctx: uintptr(ctx), + sql: sql, } } @@ -34,25 +29,19 @@ func (ls *limboStmt) NumInput() int { } func (ls *limboStmt) Close() error { - if ls.inUse == 0 { - ls.mu.Lock() - res := closeStmt(ls.ctx) - ls.mu.Unlock() - if ResultCode(res) != Ok { - return fmt.Errorf("error closing statement: %s", ResultCode(res).String()) - } + ls.mu.Lock() + res := closeStmt(ls.ctx) + ls.mu.Unlock() + if ResultCode(res) != Ok { + return fmt.Errorf("error closing statement: %s", ResultCode(res).String()) } ls.ctx = 0 return nil } func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) { - ls.mu.Lock() argArray, cleanup, err := buildArgs(args) - defer func() { - cleanup() - ls.mu.Unlock() - }() + defer cleanup() if err != nil { return nil, err } @@ -62,7 +51,9 @@ func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) { argPtr = uintptr(unsafe.Pointer(&argArray[0])) } var changes uint64 + ls.mu.Lock() rc := stmtExec(ls.ctx, argPtr, argCount, uintptr(unsafe.Pointer(&changes))) + ls.mu.Unlock() switch ResultCode(rc) { case Ok, Done: return driver.RowsAffected(changes), nil @@ -80,9 +71,8 @@ func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) { } func (ls *limboStmt) Query(args []driver.Value) (driver.Rows, error) { - ls.mu.Lock() queryArgs, cleanup, err := buildArgs(args) - defer func() { cleanup(); ls.mu.Unlock() }() + defer cleanup() if err != nil { return nil, err } @@ -90,64 +80,66 @@ func (ls *limboStmt) Query(args []driver.Value) (driver.Rows, error) { if len(args) > 0 { argPtr = uintptr(unsafe.Pointer(&queryArgs[0])) } + ls.mu.Lock() rowsPtr := stmtQuery(ls.ctx, argPtr, uint64(len(queryArgs))) + ls.mu.Unlock() if rowsPtr == 0 { return nil, fmt.Errorf("query failed for: %q", ls.sql) } - ls.inUse++ - return initRows(rowsPtr), nil + return newRows(rowsPtr), nil } -// func (ls *limboStmt) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { -// ls.mu.Lock() -// stripped := namedValueToValue(args) -// argArray, cleanup, err := getArgsPtr(stripped) -// defer func() { cleanup(); ls.mu.Unlock() }() -// if err != nil { -// return nil, err -// } -// select { -// case <-ctx.Done(): -// return nil, ctx.Err() -// default: -// } -// var changes uint64 -// res := stmtExec(ls.ctx, argArray, uint64(len(args)), uintptr(unsafe.Pointer(&changes))) -// switch ResultCode(res) { -// case Ok, Done: -// changes := uint64(changes) -// return driver.RowsAffected(changes), nil -// case Error: -// return nil, errors.New("error executing statement") -// case Busy: -// return nil, errors.New("busy") -// case Interrupt: -// return nil, errors.New("interrupted") -// default: -// return nil, fmt.Errorf("unexpected status: %d", res) -// } -// } -// -// func (ls *limboStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { -// ls.mu.Lock() -// queryArgs, allocs, err := buildNamedArgs(args) -// defer func() { allocs(); ls.mu.Unlock() }() -// if err != nil { -// return nil, err -// } -// argsPtr := uintptr(0) -// if len(queryArgs) > 0 { -// argsPtr = uintptr(unsafe.Pointer(&queryArgs[0])) -// } -// select { -// case <-ctx.Done(): -// return nil, ctx.Err() -// default: -// } -// rowsPtr := stmtQuery(ls.ctx, argsPtr, uint64(len(queryArgs))) -// if rowsPtr == 0 { -// return nil, fmt.Errorf("query failed for: %q", ls.sql) -// } -// ls.inUse++ -// return initRows(rowsPtr), nil -// } +func (ls *limboStmt) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + stripped := namedValueToValue(args) + argArray, cleanup, err := getArgsPtr(stripped) + defer cleanup() + if err != nil { + return nil, err + } + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + var changes uint64 + ls.mu.Lock() + res := stmtExec(ls.ctx, argArray, uint64(len(args)), uintptr(unsafe.Pointer(&changes))) + ls.mu.Unlock() + switch ResultCode(res) { + case Ok, Done: + changes := uint64(changes) + return driver.RowsAffected(changes), nil + case Error: + return nil, errors.New("error executing statement") + case Busy: + return nil, errors.New("busy") + case Interrupt: + return nil, errors.New("interrupted") + default: + return nil, fmt.Errorf("unexpected status: %d", res) + } +} + +func (ls *limboStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + queryArgs, allocs, err := buildNamedArgs(args) + defer allocs() + if err != nil { + return nil, err + } + argsPtr := uintptr(0) + if len(queryArgs) > 0 { + argsPtr = uintptr(unsafe.Pointer(&queryArgs[0])) + } + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + ls.mu.Lock() + rowsPtr := stmtQuery(ls.ctx, argsPtr, uint64(len(queryArgs))) + ls.mu.Unlock() + if rowsPtr == 0 { + return nil, fmt.Errorf("query failed for: %q", ls.sql) + } + return newRows(rowsPtr), nil +} diff --git a/bindings/go/types.go b/bindings/go/types.go index d54ff27fd..bcb023108 100644 --- a/bindings/go/types.go +++ b/bindings/go/types.go @@ -65,20 +65,23 @@ func (rc ResultCode) String() string { } const ( - FfiDbOpen string = "db_open" - FfiDbClose string = "db_close" - FfiDbPrepare string = "db_prepare" - FfiStmtExec string = "stmt_execute" - FfiStmtQuery string = "stmt_query" - FfiStmtParameterCount string = "stmt_parameter_count" - FfiStmtClose string = "stmt_close" - FfiRowsClose string = "rows_close" - FfiRowsGetColumns string = "rows_get_columns" - FfiRowsNext string = "rows_next" - FfiRowsGetValue string = "rows_get_value" - FfiFreeColumns string = "free_columns" - FfiFreeCString string = "free_string" - FfiFreeBlob string = "free_blob" + driverName = "sqlite3" + libName = "lib_limbo_go" + FfiDbOpen = "db_open" + FfiDbClose = "db_close" + FfiDbPrepare = "db_prepare" + FfiStmtExec = "stmt_execute" + FfiStmtQuery = "stmt_query" + FfiStmtParameterCount = "stmt_parameter_count" + FfiStmtClose = "stmt_close" + FfiRowsClose = "rows_close" + FfiRowsGetColumns = "rows_get_columns" + FfiRowsGetColumnName = "rows_get_column_name" + FfiRowsNext = "rows_next" + FfiRowsGetValue = "rows_get_value" + FfiFreeColumns = "free_columns" + FfiFreeCString = "free_string" + FfiFreeBlob = "free_blob" ) // convert a namedValue slice into normal values until named parameters are supported @@ -212,22 +215,6 @@ func freeCString(cstrPtr uintptr) { freeStringFunc(cstrPtr) } -func cArrayToGoStrings(arrayPtr uintptr, length uint) []string { - if arrayPtr == 0 || length == 0 { - return nil - } - ptrSlice := unsafe.Slice( - (**byte)(unsafe.Pointer(arrayPtr)), - length, - ) - - out := make([]string, 0, length) - for _, cstr := range ptrSlice { - out = append(out, GoString(uintptr(unsafe.Pointer(cstr)))) - } - return out -} - // convert a Go slice of driver.Value to a slice of limboValue that can be sent over FFI // for Blob types, we have to pin them so they are not garbage collected before they can be copied // into a buffer on the Rust side, so we return a function to unpin them that can be deferred after this call From 7ee52fca4dd37dbc7a678a895d4c484c0a820338 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Fri, 31 Jan 2025 15:03:14 -0500 Subject: [PATCH 3/3] bindings/go: update readme with example, change module name --- bindings/go/README.md | 58 +++++++++++++++++++++++++-------- bindings/go/go.mod | 2 +- bindings/go/limbo_test.go | 2 +- bindings/go/rs_src/lib.rs | 11 +++---- bindings/go/rs_src/rows.rs | 3 ++ bindings/go/rs_src/statement.rs | 10 ++---- 6 files changed, 55 insertions(+), 31 deletions(-) diff --git a/bindings/go/README.md b/bindings/go/README.md index 3dbffdb45..ab50140bb 100644 --- a/bindings/go/README.md +++ b/bindings/go/README.md @@ -1,41 +1,71 @@ -## Limbo driver for Go's `database/sql` library +# Limbo driver for Go's `database/sql` library -**NOTE:** this is currently __heavily__ W.I.P and is not yet in a usable state. This is merged in only for the purposes of incremental progress and not because the existing code here proper. Expect many and frequent changes. +**NOTE:** this is currently __heavily__ W.I.P and is not yet in a usable state. -This uses the [purego](https://github.com/ebitengine/purego) library to call C (in this case Rust with C ABI) functions from Go without the use of `CGO`. +This driver uses the awesome [purego](https://github.com/ebitengine/purego) library to call C (in this case Rust with C ABI) functions from Go without the use of `CGO`. +## To use: (_UNSTABLE_ testing or development purposes only) - -### To test - - -## Linux | MacOS +### Linux | MacOS _All commands listed are relative to the bindings/go directory in the limbo repository_ ``` cargo build --package limbo-go - # Your LD_LIBRARY_PATH environment variable must include limbo's `target/debug` directory -LD_LIBRARY_PATH="../../target/debug:$LD_LIBRARY_PATH" go test +export LD_LIBRARY_PATH="/path/to/limbo/target/debug:$LD_LIBRARY_PATH" ``` - ## Windows ``` cargo build --package limbo-go -# Copy the lib_limbo_go.dll into the current working directory (bindings/go) -# Alternatively, you could add the .dll to a location in your PATH +# You must add limbo's `target/debug` directory to your PATH +# or you could built + copy the .dll to a location in your PATH +# or just the CWD of your go module -cp ../../target/debug/lib_limbo_go.dll . +cp path\to\limbo\target\debug\lib_limbo_go.dll . go test ``` +**Temporarily** you may have to clone the limbo repository and run: + +`go mod edit -replace github.com/tursodatabase/limbo=/path/to/limbo/bindings/go` + +```go +import ( + "fmt" + "database/sql" + _"github.com/tursodatabase/limbo" +) + +func main() { + conn, err := sql.Open("sqlite3", ":memory:") + if err != nil { + fmt.Printf("Error: %v\n", err) + os.Exit(1) + } + sql := "CREATE table go_limbo (foo INTEGER, bar TEXT)" + _ = conn.Exec(sql) + + sql = "INSERT INTO go_limbo (foo, bar) values (?, ?)" + stmt, _ := conn.Prepare(sql) + defer stmt.Close() + _ = stmt.Exec(42, "limbo") + rows, _ := conn.Query("SELECT * from go_limbo") + defer rows.Close() + for rows.Next() { + var a int + var b string + _ = rows.Scan(&a, &b) + fmt.Printf("%d, %s", a, b) + } +} +``` diff --git a/bindings/go/go.mod b/bindings/go/go.mod index e49ba4c96..a9145591b 100644 --- a/bindings/go/go.mod +++ b/bindings/go/go.mod @@ -1,4 +1,4 @@ -module limbo +module github.com/tursodatabase/limbo go 1.23.4 diff --git a/bindings/go/limbo_test.go b/bindings/go/limbo_test.go index a688cb34e..45d1dc786 100644 --- a/bindings/go/limbo_test.go +++ b/bindings/go/limbo_test.go @@ -6,7 +6,7 @@ import ( "log" "testing" - _ "limbo" + _ "github.com/tursodatabase/limbo" ) var conn *sql.DB diff --git a/bindings/go/rs_src/lib.rs b/bindings/go/rs_src/lib.rs index 70dbce89a..fd0172cdf 100644 --- a/bindings/go/rs_src/lib.rs +++ b/bindings/go/rs_src/lib.rs @@ -7,7 +7,7 @@ use std::{ ffi::{c_char, c_void}, rc::Rc, str::FromStr, - sync::{Arc, RwLock}, + sync::Arc, }; /// # Safety @@ -40,21 +40,18 @@ pub unsafe extern "C" fn db_open(path: *const c_char) -> *mut c_void { #[allow(dead_code)] struct LimboConn { - conn: RwLock>, + conn: Rc, io: Arc, } impl<'conn> LimboConn { fn new(conn: Rc, io: Arc) -> Self { - LimboConn { - conn: conn.into(), - io, - } + LimboConn { conn, io } } #[allow(clippy::wrong_self_convention)] fn to_ptr(self) -> *mut c_void { - Arc::into_raw(Arc::new(self)) as *mut c_void + Box::into_raw(Box::new(self)) as *mut c_void } fn from_ptr(ptr: *mut c_void) -> &'conn mut LimboConn { diff --git a/bindings/go/rs_src/rows.rs b/bindings/go/rs_src/rows.rs index c8ec7642a..e089c9f4c 100644 --- a/bindings/go/rs_src/rows.rs +++ b/bindings/go/rs_src/rows.rs @@ -90,6 +90,9 @@ pub extern "C" fn rows_get_columns(rows_ptr: *mut c_void) -> i32 { rows.stmt.columns().len() as i32 } +/// Returns a pointer to a string with the name of the column at the given index. +/// The caller is responsible for freeing the memory, it should be copied on the Go side +/// immediately and 'free_string' called #[no_mangle] pub extern "C" fn rows_get_column_name(rows_ptr: *mut c_void, idx: i32) -> *const c_char { if rows_ptr.is_null() { diff --git a/bindings/go/rs_src/statement.rs b/bindings/go/rs_src/statement.rs index 8ad015ded..7d5c6c92a 100644 --- a/bindings/go/rs_src/statement.rs +++ b/bindings/go/rs_src/statement.rs @@ -13,10 +13,7 @@ pub extern "C" fn db_prepare(ctx: *mut c_void, query: *const c_char) -> *mut c_v let query_str = unsafe { std::ffi::CStr::from_ptr(query) }.to_str().unwrap(); let db = LimboConn::from_ptr(ctx); - let Ok(conn) = db.conn.read() else { - return std::ptr::null_mut(); - }; - let stmt = conn.prepare(query_str); + let stmt = db.conn.prepare(query_str); match stmt { Ok(stmt) => LimboStatement::new(Some(stmt), LimboConn::from_ptr(ctx)).to_ptr(), Err(_) => std::ptr::null_mut(), @@ -55,10 +52,7 @@ pub extern "C" fn stmt_execute( return ResultCode::Error; } Ok(StepResult::Done) => { - let Ok(conn) = stmt.conn.conn.read() else { - return ResultCode::Done; - }; - let total_changes = conn.total_changes(); + let total_changes = stmt.conn.conn.total_changes(); if !changes.is_null() { unsafe { *changes = total_changes;