diff --git a/.gitignore b/.gitignore index 1f1406ceb..c7c56a7ee 100644 --- a/.gitignore +++ b/.gitignore @@ -28,4 +28,7 @@ dist/ .DS_Store # Javascript -**/node_modules/ \ No newline at end of file +**/node_modules/ + +# testing +testing/limbo_output.txt diff --git a/bindings/go/README.md b/bindings/go/README.md index 3dbffdb45..ab50140bb 100644 --- a/bindings/go/README.md +++ b/bindings/go/README.md @@ -1,41 +1,71 @@ -## Limbo driver for Go's `database/sql` library +# Limbo driver for Go's `database/sql` library -**NOTE:** this is currently __heavily__ W.I.P and is not yet in a usable state. This is merged in only for the purposes of incremental progress and not because the existing code here proper. Expect many and frequent changes. +**NOTE:** this is currently __heavily__ W.I.P and is not yet in a usable state. -This uses the [purego](https://github.com/ebitengine/purego) library to call C (in this case Rust with C ABI) functions from Go without the use of `CGO`. +This driver uses the awesome [purego](https://github.com/ebitengine/purego) library to call C (in this case Rust with C ABI) functions from Go without the use of `CGO`. +## To use: (_UNSTABLE_ testing or development purposes only) - -### To test - - -## Linux | MacOS +### Linux | MacOS _All commands listed are relative to the bindings/go directory in the limbo repository_ ``` cargo build --package limbo-go - # Your LD_LIBRARY_PATH environment variable must include limbo's `target/debug` directory -LD_LIBRARY_PATH="../../target/debug:$LD_LIBRARY_PATH" go test +export LD_LIBRARY_PATH="/path/to/limbo/target/debug:$LD_LIBRARY_PATH" ``` - ## Windows ``` cargo build --package limbo-go -# Copy the lib_limbo_go.dll into the current working directory (bindings/go) -# Alternatively, you could add the .dll to a location in your PATH +# You must add limbo's `target/debug` directory to your PATH +# or you could built + copy the .dll to a location in your PATH +# or just the CWD of your go module -cp ../../target/debug/lib_limbo_go.dll . +cp path\to\limbo\target\debug\lib_limbo_go.dll . go test ``` +**Temporarily** you may have to clone the limbo repository and run: + +`go mod edit -replace github.com/tursodatabase/limbo=/path/to/limbo/bindings/go` + +```go +import ( + "fmt" + "database/sql" + _"github.com/tursodatabase/limbo" +) + +func main() { + conn, err := sql.Open("sqlite3", ":memory:") + if err != nil { + fmt.Printf("Error: %v\n", err) + os.Exit(1) + } + sql := "CREATE table go_limbo (foo INTEGER, bar TEXT)" + _ = conn.Exec(sql) + + sql = "INSERT INTO go_limbo (foo, bar) values (?, ?)" + stmt, _ := conn.Prepare(sql) + defer stmt.Close() + _ = stmt.Exec(42, "limbo") + rows, _ := conn.Query("SELECT * from go_limbo") + defer rows.Close() + for rows.Next() { + var a int + var b string + _ = rows.Scan(&a, &b) + fmt.Printf("%d, %s", a, b) + } +} +``` diff --git a/bindings/go/connection.go b/bindings/go/connection.go index 8c45824e8..672f2e4e3 100644 --- a/bindings/go/connection.go +++ b/bindings/go/connection.go @@ -1,71 +1,105 @@ package limbo import ( + "database/sql" "database/sql/driver" "errors" "fmt" - "unsafe" + "sync" "github.com/ebitengine/purego" ) -const ( - driverName = "sqlite3" - libName = "lib_limbo_go" +func init() { + err := ensureLibLoaded() + if err != nil { + panic(err) + } + sql.Register(driverName, &limboDriver{}) +} + +type limboDriver struct { + sync.Mutex +} + +var ( + libOnce sync.Once + limboLib uintptr + loadErr error + dbOpen func(string) uintptr + dbClose func(uintptr) uintptr + connPrepare func(uintptr, string) uintptr + freeBlobFunc func(uintptr) + freeStringFunc func(uintptr) + rowsGetColumns func(uintptr) int32 + rowsGetColumnName func(uintptr, int32) uintptr + rowsGetValue func(uintptr, int32) uintptr + closeRows func(uintptr) uintptr + rowsNext func(uintptr) uintptr + stmtQuery func(stmtPtr uintptr, argsPtr uintptr, argCount uint64) uintptr + stmtExec func(stmtPtr uintptr, argsPtr uintptr, argCount uint64, changes uintptr) int32 + stmtParamCount func(uintptr) int32 + closeStmt func(uintptr) int32 ) -var limboLib uintptr - -type limboDriver struct{} - -func (d limboDriver) Open(name string) (driver.Conn, error) { - return openConn(name) +// Register all the symbols on library load +func ensureLibLoaded() error { + libOnce.Do(func() { + limboLib, loadErr = loadLibrary() + if loadErr != nil { + return + } + purego.RegisterLibFunc(&dbOpen, limboLib, FfiDbOpen) + purego.RegisterLibFunc(&dbClose, limboLib, FfiDbClose) + purego.RegisterLibFunc(&connPrepare, limboLib, FfiDbPrepare) + purego.RegisterLibFunc(&freeBlobFunc, limboLib, FfiFreeBlob) + purego.RegisterLibFunc(&freeStringFunc, limboLib, FfiFreeCString) + purego.RegisterLibFunc(&rowsGetColumns, limboLib, FfiRowsGetColumns) + purego.RegisterLibFunc(&rowsGetColumnName, limboLib, FfiRowsGetColumnName) + purego.RegisterLibFunc(&rowsGetValue, limboLib, FfiRowsGetValue) + purego.RegisterLibFunc(&closeRows, limboLib, FfiRowsClose) + purego.RegisterLibFunc(&rowsNext, limboLib, FfiRowsNext) + purego.RegisterLibFunc(&stmtQuery, limboLib, FfiStmtQuery) + purego.RegisterLibFunc(&stmtExec, limboLib, FfiStmtExec) + purego.RegisterLibFunc(&stmtParamCount, limboLib, FfiStmtParameterCount) + purego.RegisterLibFunc(&closeStmt, limboLib, FfiStmtClose) + }) + return loadErr } -func toCString(s string) uintptr { - b := append([]byte(s), 0) - return uintptr(unsafe.Pointer(&b[0])) -} - -// helper to register an FFI function in the lib_limbo_go library -func getFfiFunc(ptr interface{}, name string) { - purego.RegisterLibFunc(ptr, limboLib, name) -} - -// TODO: sync primitives -type limboConn struct { - ctx uintptr - prepare func(uintptr, string) uintptr -} - -func newConn(ctx uintptr) *limboConn { - var prepare func(uintptr, string) uintptr - getFfiFunc(&prepare, FfiDbPrepare) - return &limboConn{ - ctx, - prepare, +func (d *limboDriver) Open(name string) (driver.Conn, error) { + d.Lock() + conn, err := openConn(name) + d.Unlock() + if err != nil { + return nil, err } + return conn, nil +} + +type limboConn struct { + sync.Mutex + ctx uintptr } func openConn(dsn string) (*limboConn, error) { - var dbOpen func(string) uintptr - getFfiFunc(&dbOpen, FfiDbOpen) - ctx := dbOpen(dsn) if ctx == 0 { return nil, fmt.Errorf("failed to open database for dsn=%q", dsn) } - return newConn(ctx), nil + return &limboConn{ + sync.Mutex{}, + ctx, + }, loadErr } func (c *limboConn) Close() error { if c.ctx == 0 { return nil } - var dbClose func(uintptr) uintptr - getFfiFunc(&dbClose, FfiDbClose) - + c.Lock() dbClose(c.ctx) + c.Unlock() c.ctx = 0 return nil } @@ -74,14 +108,13 @@ func (c *limboConn) Prepare(query string) (driver.Stmt, error) { if c.ctx == 0 { return nil, errors.New("connection closed") } - if c.prepare == nil { - panic("prepare function not set") - } - stmtPtr := c.prepare(c.ctx, query) + c.Lock() + defer c.Unlock() + stmtPtr := connPrepare(c.ctx, query) if stmtPtr == 0 { return nil, fmt.Errorf("failed to prepare query=%q", query) } - return initStmt(stmtPtr, query), nil + return newStmt(stmtPtr, query), nil } // begin is needed to implement driver.Conn.. for now not implemented diff --git a/bindings/go/go.mod b/bindings/go/go.mod index e49ba4c96..a9145591b 100644 --- a/bindings/go/go.mod +++ b/bindings/go/go.mod @@ -1,4 +1,4 @@ -module limbo +module github.com/tursodatabase/limbo go 1.23.4 diff --git a/bindings/go/limbo_test.go b/bindings/go/limbo_test.go index 1a787a149..45d1dc786 100644 --- a/bindings/go/limbo_test.go +++ b/bindings/go/limbo_test.go @@ -3,67 +3,36 @@ package limbo_test import ( "database/sql" "fmt" + "log" "testing" - _ "limbo" + _ "github.com/tursodatabase/limbo" ) -func TestConnection(t *testing.T) { - conn, err := sql.Open("sqlite3", ":memory:") - if err != nil { - t.Fatalf("Error opening database: %v", err) +var conn *sql.DB +var connErr error + +func TestMain(m *testing.M) { + conn, connErr = sql.Open("sqlite3", ":memory:") + if connErr != nil { + panic(connErr) } defer conn.Close() -} - -func TestCreateTable(t *testing.T) { - conn, err := sql.Open("sqlite3", ":memory:") + err := createTable(conn) if err != nil { - t.Fatalf("Error opening database: %v", err) - } - defer conn.Close() - - err = createTable(conn) - if err != nil { - t.Fatalf("Error creating table: %v", err) + log.Fatalf("Error creating table: %v", err) } + m.Run() } func TestInsertData(t *testing.T) { - conn, err := sql.Open("sqlite3", ":memory:") - if err != nil { - t.Fatalf("Error opening database: %v", err) - } - defer conn.Close() - - err = createTable(conn) - if err != nil { - t.Fatalf("Error creating table: %v", err) - } - - err = insertData(conn) + err := insertData(conn) if err != nil { t.Fatalf("Error inserting data: %v", err) } } func TestQuery(t *testing.T) { - conn, err := sql.Open("sqlite3", ":memory:") - if err != nil { - t.Fatalf("Error opening database: %v", err) - } - defer conn.Close() - - err = createTable(conn) - if err != nil { - t.Fatalf("Error creating table: %v", err) - } - - err = insertData(conn) - if err != nil { - t.Fatalf("Error inserting data: %v", err) - } - query := "SELECT * FROM test;" stmt, err := conn.Prepare(query) if err != nil { @@ -99,8 +68,8 @@ func TestQuery(t *testing.T) { if err != nil { t.Fatalf("Error scanning row: %v", err) } - if a != i || b != rowsMap[i] || string(c) != rowsMap[i] { - t.Fatalf("Expected %d, %s, %s, got %d, %s, %b", i, rowsMap[i], rowsMap[i], a, b, c) + if a != i || b != rowsMap[i] || !slicesAreEq(c, []byte(rowsMap[i])) { + t.Fatalf("Expected %d, %s, %s, got %d, %s, %s", i, rowsMap[i], rowsMap[i], a, b, string(c)) } fmt.Println("RESULTS: ", a, b, string(c)) i++ @@ -109,6 +78,145 @@ func TestQuery(t *testing.T) { if err = rows.Err(); err != nil { t.Fatalf("Row iteration error: %v", err) } + +} + +func TestFunctions(t *testing.T) { + insert := "INSERT INTO test (foo, bar, baz) VALUES (?, ?, zeroblob(?));" + stmt, err := conn.Prepare(insert) + if err != nil { + t.Fatalf("Error preparing statement: %v", err) + } + _, err = stmt.Exec(60, "TestFunction", 400) + if err != nil { + t.Fatalf("Error executing statment with arguments: %v", err) + } + stmt.Close() + stmt, err = conn.Prepare("SELECT baz FROM test where foo = ?") + if err != nil { + t.Fatalf("Error preparing select stmt: %v", err) + } + defer stmt.Close() + rows, err := stmt.Query(60) + if err != nil { + t.Fatalf("Error executing select stmt: %v", err) + } + defer rows.Close() + for rows.Next() { + var b []byte + err = rows.Scan(&b) + if err != nil { + t.Fatalf("Error scanning row: %v", err) + } + if len(b) != 400 { + t.Fatalf("Expected 100 bytes, got %d", len(b)) + } + } + sql := "SELECT uuid4_str();" + stmt, err = conn.Prepare(sql) + if err != nil { + t.Fatalf("Error preparing statement: %v", err) + } + defer stmt.Close() + rows, err = stmt.Query() + if err != nil { + t.Fatalf("Error executing query: %v", err) + } + defer rows.Close() + var i int + for rows.Next() { + var b string + err = rows.Scan(&b) + if err != nil { + t.Fatalf("Error scanning row: %v", err) + } + if len(b) != 36 { + t.Fatalf("Expected 36 bytes, got %d", len(b)) + } + i++ + fmt.Printf("uuid: %s\n", b) + } + if i != 1 { + t.Fatalf("Expected 1 row, got %d", i) + } + fmt.Println("zeroblob + uuid functions passed") +} + +func TestDuplicateConnection(t *testing.T) { + newConn, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Error opening new connection: %v", err) + } + err = createTable(newConn) + if err != nil { + t.Fatalf("Error creating table: %v", err) + } + err = insertData(newConn) + if err != nil { + t.Fatalf("Error inserting data: %v", err) + } + query := "SELECT * FROM test;" + rows, err := newConn.Query(query) + if err != nil { + t.Fatalf("Error executing query: %v", err) + } + defer rows.Close() + for rows.Next() { + var a int + var b string + var c []byte + err = rows.Scan(&a, &b, &c) + if err != nil { + t.Fatalf("Error scanning row: %v", err) + } + fmt.Println("RESULTS: ", a, b, string(c)) + } +} + +func TestDuplicateConnection2(t *testing.T) { + newConn, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Error opening new connection: %v", err) + } + sql := "CREATE TABLE test (foo INTEGER, bar INTEGER, baz BLOB);" + newConn.Exec(sql) + sql = "INSERT INTO test (foo, bar, baz) VALUES (?, ?, uuid4());" + stmt, err := newConn.Prepare(sql) + stmt.Exec(242345, 2342434) + defer stmt.Close() + query := "SELECT * FROM test;" + rows, err := newConn.Query(query) + if err != nil { + t.Fatalf("Error executing query: %v", err) + } + defer rows.Close() + for rows.Next() { + var a int + var b int + var c []byte + err = rows.Scan(&a, &b, &c) + if err != nil { + t.Fatalf("Error scanning row: %v", err) + } + fmt.Println("RESULTS: ", a, b, string(c)) + if len(c) != 16 { + t.Fatalf("Expected 16 bytes, got %d", len(c)) + } + } +} + +func slicesAreEq(a, b []byte) bool { + if len(a) != len(b) { + fmt.Printf("LENGTHS NOT EQUAL: %d != %d\n", len(a), len(b)) + return false + } + for i := range a { + if a[i] != b[i] { + fmt.Printf("SLICES NOT EQUAL: %v != %v\n", a, b) + return false + } + } + return true } var rowsMap = map[int]string{1: "hello", 2: "world", 3: "foo", 4: "bar", 5: "baz"} diff --git a/bindings/go/limbo_unix.go b/bindings/go/limbo_unix.go index ecafa5a85..8ab911416 100644 --- a/bindings/go/limbo_unix.go +++ b/bindings/go/limbo_unix.go @@ -3,7 +3,6 @@ package limbo import ( - "database/sql" "fmt" "os" "path/filepath" @@ -13,7 +12,7 @@ import ( "github.com/ebitengine/purego" ) -func loadLibrary() error { +func loadLibrary() (uintptr, error) { var libraryName string switch runtime.GOOS { case "darwin": @@ -21,14 +20,14 @@ func loadLibrary() error { case "linux": libraryName = fmt.Sprintf("%s.so", libName) default: - return fmt.Errorf("GOOS=%s is not supported", runtime.GOOS) + return 0, fmt.Errorf("GOOS=%s is not supported", runtime.GOOS) } libPath := os.Getenv("LD_LIBRARY_PATH") paths := strings.Split(libPath, ":") cwd, err := os.Getwd() if err != nil { - return err + return 0, err } paths = append(paths, cwd) @@ -37,20 +36,10 @@ func loadLibrary() error { if _, err := os.Stat(libPath); err == nil { slib, dlerr := purego.Dlopen(libPath, purego.RTLD_NOW|purego.RTLD_GLOBAL) if dlerr != nil { - return fmt.Errorf("failed to load library at %s: %w", libPath, dlerr) + return 0, fmt.Errorf("failed to load library at %s: %w", libPath, dlerr) } - limboLib = slib - return nil + return slib, nil } } - return fmt.Errorf("%s library not found in LD_LIBRARY_PATH or CWD", libName) -} - -func init() { - err := loadLibrary() - if err != nil { - fmt.Println(err) - os.Exit(1) - } - sql.Register("sqlite3", &limboDriver{}) + return 0, fmt.Errorf("%s library not found in LD_LIBRARY_PATH or CWD", libName) } diff --git a/bindings/go/limbo_windows.go b/bindings/go/limbo_windows.go index 433ddd051..d56381176 100644 --- a/bindings/go/limbo_windows.go +++ b/bindings/go/limbo_windows.go @@ -3,7 +3,6 @@ package limbo import ( - "database/sql" "fmt" "os" "path/filepath" @@ -12,14 +11,14 @@ import ( "golang.org/x/sys/windows" ) -func loadLibrary() error { +func loadLibrary() (uintptr, error) { libName := fmt.Sprintf("%s.dll", libName) pathEnv := os.Getenv("PATH") paths := strings.Split(pathEnv, ";") cwd, err := os.Getwd() if err != nil { - return err + return 0, err } paths = append(paths, cwd) for _, path := range paths { @@ -27,21 +26,11 @@ func loadLibrary() error { if _, err := os.Stat(dllPath); err == nil { slib, loadErr := windows.LoadLibrary(dllPath) if loadErr != nil { - return fmt.Errorf("failed to load library at %s: %w", dllPath, loadErr) + return 0, fmt.Errorf("failed to load library at %s: %w", dllPath, loadErr) } - limboLib = uintptr(slib) - return nil + return uintptr(slib), nil } } - return fmt.Errorf("library %s not found in PATH or CWD", libName) -} - -func init() { - err := loadLibrary() - if err != nil { - fmt.Println("Error opening limbo library: ", err) - os.Exit(1) - } - sql.Register("sqlite3", &limboDriver{}) + return 0, fmt.Errorf("library %s not found in PATH or CWD", libName) } diff --git a/bindings/go/rows.go b/bindings/go/rows.go new file mode 100644 index 000000000..2f20bb93f --- /dev/null +++ b/bindings/go/rows.go @@ -0,0 +1,83 @@ +package limbo + +import ( + "database/sql/driver" + "fmt" + "io" + "sync" +) + +type limboRows struct { + mu sync.Mutex + ctx uintptr + columns []string + closed bool +} + +func newRows(ctx uintptr) *limboRows { + return &limboRows{ + mu: sync.Mutex{}, + ctx: ctx, + closed: false, + columns: nil, + } +} + +func (r *limboRows) Columns() []string { + if r.ctx == 0 || r.closed { + return nil + } + if r.columns == nil { + r.mu.Lock() + count := rowsGetColumns(r.ctx) + if count > 0 { + columns := make([]string, 0, count) + for i := 0; i < int(count); i++ { + cstr := rowsGetColumnName(r.ctx, int32(i)) + columns = append(columns, fmt.Sprintf("%s", GoString(cstr))) + freeCString(cstr) + } + r.mu.Unlock() + r.columns = columns + } + } + return r.columns +} + +func (r *limboRows) Close() error { + if r.closed { + return nil + } + r.mu.Lock() + r.closed = true + closeRows(r.ctx) + r.ctx = 0 + r.mu.Unlock() + return nil +} + +func (r *limboRows) Next(dest []driver.Value) error { + if r.ctx == 0 || r.closed { + return io.EOF + } + r.mu.Lock() + defer r.mu.Unlock() + for { + status := rowsNext(r.ctx) + switch ResultCode(status) { + case Row: + for i := range dest { + valPtr := rowsGetValue(r.ctx, int32(i)) + val := toGoValue(valPtr) + dest[i] = val + } + return nil + case Io: + continue + case Done: + return io.EOF + default: + return fmt.Errorf("unexpected status: %d", status) + } + } +} diff --git a/bindings/go/rs_src/lib.rs b/bindings/go/rs_src/lib.rs index 199ed10c0..fd0172cdf 100644 --- a/bindings/go/rs_src/lib.rs +++ b/bindings/go/rs_src/lib.rs @@ -26,7 +26,6 @@ pub unsafe extern "C" fn db_open(path: *const c_char) -> *mut c_void { let db = Database::open_file(io.clone(), &db_options.path.to_string()); match db { Ok(db) => { - println!("Opened database: {}", path); let conn = db.connect(); return LimboConn::new(conn, io).to_ptr(); } @@ -45,16 +44,17 @@ struct LimboConn { io: Arc, } -impl LimboConn { +impl<'conn> LimboConn { fn new(conn: Rc, io: Arc) -> Self { LimboConn { conn, io } } + #[allow(clippy::wrong_self_convention)] fn to_ptr(self) -> *mut c_void { Box::into_raw(Box::new(self)) as *mut c_void } - fn from_ptr(ptr: *mut c_void) -> &'static mut LimboConn { + fn from_ptr(ptr: *mut c_void) -> &'conn mut LimboConn { if ptr.is_null() { panic!("Null pointer"); } diff --git a/bindings/go/rs_src/rows.rs b/bindings/go/rs_src/rows.rs index 19b526c74..e089c9f4c 100644 --- a/bindings/go/rs_src/rows.rs +++ b/bindings/go/rs_src/rows.rs @@ -78,30 +78,33 @@ pub extern "C" fn free_string(s: *mut c_char) { } } +/// Function to get the number of expected ResultColumns in the prepared statement. +/// to avoid the needless complexity of returning an array of strings, this instead +/// works like rows_next/rows_get_value #[no_mangle] -pub extern "C" fn rows_get_columns( - rows_ptr: *mut c_void, - out_length: *mut usize, -) -> *mut *const c_char { - if rows_ptr.is_null() || out_length.is_null() { +pub extern "C" fn rows_get_columns(rows_ptr: *mut c_void) -> i32 { + if rows_ptr.is_null() { + return -1; + } + let rows = LimboRows::from_ptr(rows_ptr); + rows.stmt.columns().len() as i32 +} + +/// Returns a pointer to a string with the name of the column at the given index. +/// The caller is responsible for freeing the memory, it should be copied on the Go side +/// immediately and 'free_string' called +#[no_mangle] +pub extern "C" fn rows_get_column_name(rows_ptr: *mut c_void, idx: i32) -> *const c_char { + if rows_ptr.is_null() { return std::ptr::null_mut(); } let rows = LimboRows::from_ptr(rows_ptr); - let c_strings: Vec = rows - .stmt - .columns() - .iter() - .map(|name| std::ffi::CString::new(name.as_str()).unwrap()) - .collect(); - - let c_ptrs: Vec<*const c_char> = c_strings.iter().map(|s| s.as_ptr()).collect(); - unsafe { - *out_length = c_ptrs.len(); + if idx < 0 || idx as usize >= rows.stmt.columns().len() { + return std::ptr::null_mut(); } - let ptr = c_ptrs.as_ptr(); - std::mem::forget(c_strings); - std::mem::forget(c_ptrs); - ptr as *mut *const c_char + let name = &rows.stmt.columns()[idx as usize]; + let cstr = std::ffi::CString::new(name.as_bytes()).expect("Failed to create CString"); + cstr.into_raw() as *const c_char } #[no_mangle] @@ -111,21 +114,6 @@ pub extern "C" fn rows_close(rows_ptr: *mut c_void) { } } -#[no_mangle] -pub extern "C" fn free_columns(columns: *mut *const c_char) { - if columns.is_null() { - return; - } - unsafe { - let mut idx = 0; - while !(*columns.add(idx)).is_null() { - let _ = std::ffi::CString::from_raw(*columns.add(idx) as *mut c_char); - idx += 1; - } - let _ = Box::from_raw(columns); - } -} - #[no_mangle] pub extern "C" fn free_rows(rows: *mut c_void) { if rows.is_null() { diff --git a/bindings/go/rs_src/statement.rs b/bindings/go/rs_src/statement.rs index 329f9b59c..7d5c6c92a 100644 --- a/bindings/go/rs_src/statement.rs +++ b/bindings/go/rs_src/statement.rs @@ -13,10 +13,9 @@ pub extern "C" fn db_prepare(ctx: *mut c_void, query: *const c_char) -> *mut c_v let query_str = unsafe { std::ffi::CStr::from_ptr(query) }.to_str().unwrap(); let db = LimboConn::from_ptr(ctx); - let stmt = db.conn.prepare(query_str); match stmt { - Ok(stmt) => LimboStatement::new(Some(stmt), db).to_ptr(), + Ok(stmt) => LimboStatement::new(Some(stmt), LimboConn::from_ptr(ctx)).to_ptr(), Err(_) => std::ptr::null_mut(), } } @@ -53,10 +52,10 @@ pub extern "C" fn stmt_execute( return ResultCode::Error; } Ok(StepResult::Done) => { - stmt.conn.conn.total_changes(); + let total_changes = stmt.conn.conn.total_changes(); if !changes.is_null() { unsafe { - *changes = stmt.conn.conn.total_changes(); + *changes = total_changes; } } return ResultCode::Done; @@ -127,13 +126,9 @@ pub struct LimboStatement<'conn> { #[no_mangle] pub extern "C" fn stmt_close(ctx: *mut c_void) -> ResultCode { if !ctx.is_null() { - let stmt = LimboStatement::from_ptr(ctx); - if stmt.statement.is_none() { - return ResultCode::Error; - } else { - let _ = unsafe { Box::from_raw(ctx as *mut LimboStatement) }; - return ResultCode::Ok; - } + let stmt = unsafe { Box::from_raw(ctx as *mut LimboStatement) }; + drop(stmt); + return ResultCode::Ok; } ResultCode::Invalid } @@ -148,7 +143,7 @@ impl<'conn> LimboStatement<'conn> { Box::into_raw(Box::new(self)) as *mut c_void } - fn from_ptr(ptr: *mut c_void) -> &'static mut LimboStatement<'conn> { + fn from_ptr(ptr: *mut c_void) -> &'conn mut LimboStatement<'conn> { if ptr.is_null() { panic!("Null pointer"); } diff --git a/bindings/go/stmt.go b/bindings/go/stmt.go index 02ad3c3eb..0f3c038fd 100644 --- a/bindings/go/stmt.go +++ b/bindings/go/stmt.go @@ -5,54 +5,35 @@ import ( "database/sql/driver" "errors" "fmt" - "io" + "sync" "unsafe" ) -// only construct limboStmt with initStmt function to ensure proper initialization -// inUse tracks whether or not `query` has been called. if inUse > 0, stmt no longer -// owns the underlying data and `rows` is responsible for cleaning it up on close. type limboStmt struct { - ctx uintptr - sql string - inUse int - query func(stmtPtr uintptr, argsPtr uintptr, argCount uint64) uintptr - execute func(stmtPtr uintptr, argsPtr uintptr, argCount uint64, changes uintptr) int32 - getParamCount func(uintptr) int32 - closeStmt func(uintptr) int32 + mu sync.Mutex + ctx uintptr + sql string } -// Initialize/register the FFI function pointers for the statement methods -func initStmt(ctx uintptr, sql string) *limboStmt { - var query func(stmtPtr uintptr, argsPtr uintptr, argCount uint64) uintptr - getFfiFunc(&query, FfiStmtQuery) - var execute func(stmtPtr uintptr, argsPtr uintptr, argCount uint64, changes uintptr) int32 - getFfiFunc(&execute, FfiStmtExec) - var getParamCount func(uintptr) int32 - getFfiFunc(&getParamCount, FfiStmtParameterCount) - var closeStmt func(uintptr) int32 - getFfiFunc(&closeStmt, FfiStmtClose) +func newStmt(ctx uintptr, sql string) *limboStmt { return &limboStmt{ - ctx: uintptr(ctx), - sql: sql, - inUse: 0, - execute: execute, - query: query, - getParamCount: getParamCount, - closeStmt: closeStmt, + ctx: uintptr(ctx), + sql: sql, } } func (ls *limboStmt) NumInput() int { - return int(ls.getParamCount(ls.ctx)) + ls.mu.Lock() + defer ls.mu.Unlock() + return int(stmtParamCount(ls.ctx)) } func (ls *limboStmt) Close() error { - if ls.inUse == 0 { - res := ls.closeStmt(ls.ctx) - if ResultCode(res) != Ok { - return fmt.Errorf("error closing statement: %s", ResultCode(res).String()) - } + ls.mu.Lock() + res := closeStmt(ls.ctx) + ls.mu.Unlock() + if ResultCode(res) != Ok { + return fmt.Errorf("error closing statement: %s", ResultCode(res).String()) } ls.ctx = 0 return nil @@ -70,7 +51,9 @@ func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) { argPtr = uintptr(unsafe.Pointer(&argArray[0])) } var changes uint64 - rc := ls.execute(ls.ctx, argPtr, argCount, uintptr(unsafe.Pointer(&changes))) + ls.mu.Lock() + rc := stmtExec(ls.ctx, argPtr, argCount, uintptr(unsafe.Pointer(&changes))) + ls.mu.Unlock() switch ResultCode(rc) { case Ok, Done: return driver.RowsAffected(changes), nil @@ -87,7 +70,7 @@ func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) { } } -func (st *limboStmt) Query(args []driver.Value) (driver.Rows, error) { +func (ls *limboStmt) Query(args []driver.Value) (driver.Rows, error) { queryArgs, cleanup, err := buildArgs(args) defer cleanup() if err != nil { @@ -97,12 +80,13 @@ func (st *limboStmt) Query(args []driver.Value) (driver.Rows, error) { if len(args) > 0 { argPtr = uintptr(unsafe.Pointer(&queryArgs[0])) } - rowsPtr := st.query(st.ctx, argPtr, uint64(len(queryArgs))) + ls.mu.Lock() + rowsPtr := stmtQuery(ls.ctx, argPtr, uint64(len(queryArgs))) + ls.mu.Unlock() if rowsPtr == 0 { - return nil, fmt.Errorf("query failed for: %q", st.sql) + return nil, fmt.Errorf("query failed for: %q", ls.sql) } - st.inUse++ - return initRows(rowsPtr), nil + return newRows(rowsPtr), nil } func (ls *limboStmt) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { @@ -118,7 +102,9 @@ func (ls *limboStmt) ExecContext(ctx context.Context, query string, args []drive default: } var changes uint64 - res := ls.execute(ls.ctx, argArray, uint64(len(args)), uintptr(unsafe.Pointer(&changes))) + ls.mu.Lock() + res := stmtExec(ls.ctx, argArray, uint64(len(args)), uintptr(unsafe.Pointer(&changes))) + ls.mu.Unlock() switch ResultCode(res) { case Ok, Done: changes := uint64(changes) @@ -149,89 +135,11 @@ func (ls *limboStmt) QueryContext(ctx context.Context, args []driver.NamedValue) return nil, ctx.Err() default: } - rowsPtr := ls.query(ls.ctx, argsPtr, uint64(len(queryArgs))) + ls.mu.Lock() + rowsPtr := stmtQuery(ls.ctx, argsPtr, uint64(len(queryArgs))) + ls.mu.Unlock() if rowsPtr == 0 { return nil, fmt.Errorf("query failed for: %q", ls.sql) } - ls.inUse++ - return initRows(rowsPtr), nil -} - -// only construct limboRows with initRows function to ensure proper initialization -type limboRows struct { - ctx uintptr - columns []string - closed bool - getCols func(uintptr, *uint) uintptr - next func(uintptr) uintptr - getValue func(uintptr, int32) uintptr - closeRows func(uintptr) uintptr - freeCols func(uintptr) uintptr -} - -// Initialize/register the FFI function pointers for the rows methods -// DO NOT construct 'limboRows' without this function -func initRows(ctx uintptr) *limboRows { - var getCols func(uintptr, *uint) uintptr - getFfiFunc(&getCols, FfiRowsGetColumns) - var getValue func(uintptr, int32) uintptr - getFfiFunc(&getValue, FfiRowsGetValue) - var closeRows func(uintptr) uintptr - getFfiFunc(&closeRows, FfiRowsClose) - var freeCols func(uintptr) uintptr - getFfiFunc(&freeCols, FfiFreeColumns) - var next func(uintptr) uintptr - getFfiFunc(&next, FfiRowsNext) - - return &limboRows{ - ctx: ctx, - getCols: getCols, - getValue: getValue, - closeRows: closeRows, - freeCols: freeCols, - next: next, - } -} - -func (r *limboRows) Columns() []string { - if r.columns == nil { - var columnCount uint - colArrayPtr := r.getCols(r.ctx, &columnCount) - if colArrayPtr != 0 && columnCount > 0 { - r.columns = cArrayToGoStrings(colArrayPtr, columnCount) - defer r.freeCols(colArrayPtr) - } - } - return r.columns -} - -func (r *limboRows) Close() error { - if r.closed { - return nil - } - r.closed = true - r.closeRows(r.ctx) - r.ctx = 0 - return nil -} - -func (r *limboRows) Next(dest []driver.Value) error { - for { - status := r.next(r.ctx) - switch ResultCode(status) { - case Row: - for i := range dest { - valPtr := r.getValue(r.ctx, int32(i)) - val := toGoValue(valPtr) - dest[i] = val - } - return nil - case Io: - continue - case Done: - return io.EOF - default: - return fmt.Errorf("unexpected status: %d", status) - } - } + return newRows(rowsPtr), nil } diff --git a/bindings/go/types.go b/bindings/go/types.go index 78fb96153..bcb023108 100644 --- a/bindings/go/types.go +++ b/bindings/go/types.go @@ -65,20 +65,23 @@ func (rc ResultCode) String() string { } const ( - FfiDbOpen string = "db_open" - FfiDbClose string = "db_close" - FfiDbPrepare string = "db_prepare" - FfiStmtExec string = "stmt_execute" - FfiStmtQuery string = "stmt_query" - FfiStmtParameterCount string = "stmt_parameter_count" - FfiStmtClose string = "stmt_close" - FfiRowsClose string = "rows_close" - FfiRowsGetColumns string = "rows_get_columns" - FfiRowsNext string = "rows_next" - FfiRowsGetValue string = "rows_get_value" - FfiFreeColumns string = "free_columns" - FfiFreeCString string = "free_string" - FfiFreeBlob string = "free_blob" + driverName = "sqlite3" + libName = "lib_limbo_go" + FfiDbOpen = "db_open" + FfiDbClose = "db_close" + FfiDbPrepare = "db_prepare" + FfiStmtExec = "stmt_execute" + FfiStmtQuery = "stmt_query" + FfiStmtParameterCount = "stmt_parameter_count" + FfiStmtClose = "stmt_close" + FfiRowsClose = "rows_close" + FfiRowsGetColumns = "rows_get_columns" + FfiRowsGetColumnName = "rows_get_column_name" + FfiRowsNext = "rows_next" + FfiRowsGetValue = "rows_get_value" + FfiFreeColumns = "free_columns" + FfiFreeCString = "free_string" + FfiFreeBlob = "free_blob" ) // convert a namedValue slice into normal values until named parameters are supported @@ -198,47 +201,20 @@ func toGoBlob(blobPtr uintptr) []byte { return copied } -var freeBlobFunc func(uintptr) - func freeBlob(blobPtr uintptr) { if blobPtr == 0 { return } - if freeBlobFunc == nil { - getFfiFunc(&freeBlobFunc, FfiFreeBlob) - } freeBlobFunc(blobPtr) } -var freeStringFunc func(uintptr) - func freeCString(cstrPtr uintptr) { if cstrPtr == 0 { return } - if freeStringFunc == nil { - getFfiFunc(&freeStringFunc, FfiFreeCString) - } freeStringFunc(cstrPtr) } -func cArrayToGoStrings(arrayPtr uintptr, length uint) []string { - if arrayPtr == 0 || length == 0 { - return nil - } - - ptrSlice := unsafe.Slice( - (**byte)(unsafe.Pointer(arrayPtr)), - length, - ) - - out := make([]string, 0, length) - for _, cstr := range ptrSlice { - out = append(out, GoString(uintptr(unsafe.Pointer(cstr)))) - } - return out -} - // convert a Go slice of driver.Value to a slice of limboValue that can be sent over FFI // for Blob types, we have to pin them so they are not garbage collected before they can be copied // into a buffer on the Rust side, so we return a function to unpin them that can be deferred after this call @@ -259,6 +235,7 @@ func buildArgs(args []driver.Value) ([]limboValue, func(), error) { case string: limboVal.Type = textVal cstr := CString(val) + pinner.Pin(cstr) *(*uintptr)(unsafe.Pointer(&limboVal.Value)) = uintptr(unsafe.Pointer(cstr)) case []byte: limboVal.Type = blobVal