From bf6b80edab34e0c3d1a1e2e4ef71fd4e633c5913 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Tue, 28 Jan 2025 10:58:26 -0500 Subject: [PATCH] Continue progress go database/sql driver, add tests and CI --- .github/workflows/go.yml | 43 +++++++++ bindings/go/README.md | 41 +++++++++ bindings/go/connection.go | 90 +++++++++++++++++++ bindings/go/go.mod | 2 +- bindings/go/limbo.go | 141 ----------------------------- bindings/go/limbo_test.go | 137 +++++++++++++++++++++++++++++ bindings/go/limbo_unix.go | 56 ++++++++++++ bindings/go/limbo_windows.go | 47 ++++++++++ bindings/go/rs_src/rows.rs | 26 +++--- bindings/go/rs_src/statement.rs | 62 ++++++++----- bindings/go/rs_src/types.rs | 40 +++++++-- bindings/go/stmt.go | 135 ++++++++++++++++++---------- bindings/go/types.go | 151 +++++++++++++++++++------------- 13 files changed, 682 insertions(+), 289 deletions(-) create mode 100644 .github/workflows/go.yml create mode 100644 bindings/go/README.md create mode 100644 bindings/go/connection.go delete mode 100644 bindings/go/limbo.go create mode 100644 bindings/go/limbo_test.go create mode 100644 bindings/go/limbo_unix.go create mode 100644 bindings/go/limbo_windows.go diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml new file mode 100644 index 000000000..151ee791c --- /dev/null +++ b/.github/workflows/go.yml @@ -0,0 +1,43 @@ +name: Go Tests + +on: + push: + branches: + - main + tags: + - v* + pull_request: + branches: + - main + +env: + working-directory: bindings/go + +jobs: + test: + runs-on: ubuntu-latest + + defaults: + run: + working-directory: ${{ env.working-directory }} + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Install Rust(stable) + uses: dtolnay/rust-toolchain@stable + + - name: Set up go + uses: actions/setup-go@v4 + with: + go-version: "1.23" + + - name: build Go bindings library + run: cargo build --package limbo-go + + - name: run Go tests + env: + LD_LIBRARY_PATH: ${{ github.workspace }}/target/debug:$LD_LIBRARY_PATH + run: go test + diff --git a/bindings/go/README.md b/bindings/go/README.md new file mode 100644 index 000000000..3dbffdb45 --- /dev/null +++ b/bindings/go/README.md @@ -0,0 +1,41 @@ +## Limbo driver for Go's `database/sql` library + + +**NOTE:** this is currently __heavily__ W.I.P and is not yet in a usable state. This is merged in only for the purposes of incremental progress and not because the existing code here proper. Expect many and frequent changes. + +This uses the [purego](https://github.com/ebitengine/purego) library to call C (in this case Rust with C ABI) functions from Go without the use of `CGO`. + + + + +### To test + + +## Linux | MacOS + +_All commands listed are relative to the bindings/go directory in the limbo repository_ + +``` +cargo build --package limbo-go + + +# Your LD_LIBRARY_PATH environment variable must include limbo's `target/debug` directory + +LD_LIBRARY_PATH="../../target/debug:$LD_LIBRARY_PATH" go test + +``` + + +## Windows + +``` +cargo build --package limbo-go + +# Copy the lib_limbo_go.dll into the current working directory (bindings/go) +# Alternatively, you could add the .dll to a location in your PATH + +cp ../../target/debug/lib_limbo_go.dll . + +go test + +``` diff --git a/bindings/go/connection.go b/bindings/go/connection.go new file mode 100644 index 000000000..8c45824e8 --- /dev/null +++ b/bindings/go/connection.go @@ -0,0 +1,90 @@ +package limbo + +import ( + "database/sql/driver" + "errors" + "fmt" + "unsafe" + + "github.com/ebitengine/purego" +) + +const ( + driverName = "sqlite3" + libName = "lib_limbo_go" +) + +var limboLib uintptr + +type limboDriver struct{} + +func (d limboDriver) Open(name string) (driver.Conn, error) { + return openConn(name) +} + +func toCString(s string) uintptr { + b := append([]byte(s), 0) + return uintptr(unsafe.Pointer(&b[0])) +} + +// helper to register an FFI function in the lib_limbo_go library +func getFfiFunc(ptr interface{}, name string) { + purego.RegisterLibFunc(ptr, limboLib, name) +} + +// TODO: sync primitives +type limboConn struct { + ctx uintptr + prepare func(uintptr, string) uintptr +} + +func newConn(ctx uintptr) *limboConn { + var prepare func(uintptr, string) uintptr + getFfiFunc(&prepare, FfiDbPrepare) + return &limboConn{ + ctx, + prepare, + } +} + +func openConn(dsn string) (*limboConn, error) { + var dbOpen func(string) uintptr + getFfiFunc(&dbOpen, FfiDbOpen) + + ctx := dbOpen(dsn) + if ctx == 0 { + return nil, fmt.Errorf("failed to open database for dsn=%q", dsn) + } + return newConn(ctx), nil +} + +func (c *limboConn) Close() error { + if c.ctx == 0 { + return nil + } + var dbClose func(uintptr) uintptr + getFfiFunc(&dbClose, FfiDbClose) + + dbClose(c.ctx) + c.ctx = 0 + return nil +} + +func (c *limboConn) Prepare(query string) (driver.Stmt, error) { + if c.ctx == 0 { + return nil, errors.New("connection closed") + } + if c.prepare == nil { + panic("prepare function not set") + } + stmtPtr := c.prepare(c.ctx, query) + if stmtPtr == 0 { + return nil, fmt.Errorf("failed to prepare query=%q", query) + } + return initStmt(stmtPtr, query), nil +} + +// begin is needed to implement driver.Conn.. for now not implemented +func (c *limboConn) Begin() (driver.Tx, error) { + return nil, errors.New("transactions not implemented") +} diff --git a/bindings/go/go.mod b/bindings/go/go.mod index 589b9a0e3..e49ba4c96 100644 --- a/bindings/go/go.mod +++ b/bindings/go/go.mod @@ -4,5 +4,5 @@ go 1.23.4 require ( github.com/ebitengine/purego v0.8.2 - golang.org/x/sys/windows v0.29.0 + golang.org/x/sys v0.29.0 ) diff --git a/bindings/go/limbo.go b/bindings/go/limbo.go deleted file mode 100644 index 4011fb1ac..000000000 --- a/bindings/go/limbo.go +++ /dev/null @@ -1,141 +0,0 @@ -package limbo - -import ( - "database/sql" - "database/sql/driver" - "errors" - "fmt" - "log/slog" - "os" - "runtime" - "sync" - "unsafe" - - "github.com/ebitengine/purego" - "golang.org/x/sys/windows" -) - -const limbo = "../../target/debug/lib_limbo_go" -const driverName = "limbo" - -var limboLib uintptr - -func getSystemLibrary() error { - switch runtime.GOOS { - case "darwin": - slib, err := purego.Dlopen(fmt.Sprintf("%s.dylib", limbo), purego.RTLD_LAZY) - if err != nil { - return err - } - limboLib = slib - case "linux": - slib, err := purego.Dlopen(fmt.Sprintf("%s.so", limbo), purego.RTLD_LAZY) - if err != nil { - return err - } - limboLib = slib - case "windows": - slib, err := windows.LoadLibrary(fmt.Sprintf("%s.dll", limbo)) - if err != nil { - return err - } - limboLib = slib - default: - panic(fmt.Errorf("GOOS=%s is not supported", runtime.GOOS)) - } - return nil -} - -func init() { - err := getSystemLibrary() - if err != nil { - slog.Error("Error opening limbo library: ", err) - os.Exit(1) - } - sql.Register(driverName, &limboDriver{}) -} - -type limboDriver struct{} - -func (d limboDriver) Open(name string) (driver.Conn, error) { - return openConn(name) -} - -func toCString(s string) uintptr { - b := append([]byte(s), 0) - return uintptr(unsafe.Pointer(&b[0])) -} - -// helper to register an FFI function in the lib_limbo_go library -func getFfiFunc(ptr interface{}, name string) { - purego.RegisterLibFunc(&ptr, limboLib, name) -} - -type limboConn struct { - ctx uintptr - sync.Mutex - prepare func(uintptr, uintptr) uintptr -} - -func newConn(ctx uintptr) *limboConn { - var prepare func(uintptr, uintptr) uintptr - getFfiFunc(&prepare, FfiDbPrepare) - return &limboConn{ - ctx, - sync.Mutex{}, - prepare, - } -} - -func openConn(dsn string) (*limboConn, error) { - var dbOpen func(uintptr) uintptr - getFfiFunc(&dbOpen, FfiDbOpen) - - cStr := toCString(dsn) - defer freeCString(cStr) - - ctx := dbOpen(cStr) - if ctx == 0 { - return nil, fmt.Errorf("failed to open database for dsn=%q", dsn) - } - return &limboConn{ctx: ctx}, nil -} - -func (c *limboConn) Close() error { - if c.ctx == 0 { - return nil - } - var dbClose func(uintptr) uintptr - getFfiFunc(&dbClose, FfiDbClose) - - dbClose(c.ctx) - c.ctx = 0 - return nil -} - -func (c *limboConn) Prepare(query string) (driver.Stmt, error) { - if c.ctx == 0 { - return nil, errors.New("connection closed") - } - if c.prepare == nil { - var dbPrepare func(uintptr, uintptr) uintptr - getFfiFunc(&dbPrepare, FfiDbPrepare) - c.prepare = dbPrepare - } - qPtr := toCString(query) - stmtPtr := c.prepare(c.ctx, qPtr) - freeCString(qPtr) - - if stmtPtr == 0 { - return nil, fmt.Errorf("prepare failed: %q", query) - } - return &limboStmt{ - ctx: stmtPtr, - sql: query, - }, nil -} - -// begin is needed to implement driver.Conn.. for now not implemented -func (c *limboConn) Begin() (driver.Tx, error) { - return nil, errors.New("transactions not implemented") -} diff --git a/bindings/go/limbo_test.go b/bindings/go/limbo_test.go new file mode 100644 index 000000000..6aaf3f95c --- /dev/null +++ b/bindings/go/limbo_test.go @@ -0,0 +1,137 @@ +package limbo_test + +import ( + "database/sql" + "testing" + + _ "limbo" +) + +func TestConnection(t *testing.T) { + conn, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Error opening database: %v", err) + } + defer conn.Close() +} + +func TestCreateTable(t *testing.T) { + conn, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Error opening database: %v", err) + } + defer conn.Close() + + err = createTable(conn) + if err != nil { + t.Fatalf("Error creating table: %v", err) + } +} + +func TestInsertData(t *testing.T) { + conn, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Error opening database: %v", err) + } + defer conn.Close() + + err = createTable(conn) + if err != nil { + t.Fatalf("Error creating table: %v", err) + } + + err = insertData(conn) + if err != nil { + t.Fatalf("Error inserting data: %v", err) + } +} + +func TestQuery(t *testing.T) { + conn, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Error opening database: %v", err) + } + defer conn.Close() + + err = createTable(conn) + if err != nil { + t.Fatalf("Error creating table: %v", err) + } + + err = insertData(conn) + if err != nil { + t.Fatalf("Error inserting data: %v", err) + } + + query := "SELECT * FROM test;" + stmt, err := conn.Prepare(query) + if err != nil { + t.Fatalf("Error preparing query: %v", err) + } + defer stmt.Close() + + rows, err := stmt.Query() + if err != nil { + t.Fatalf("Error executing query: %v", err) + } + defer rows.Close() + + expectedCols := []string{"foo", "bar"} + cols, err := rows.Columns() + if err != nil { + t.Fatalf("Error getting columns: %v", err) + } + if len(cols) != len(expectedCols) { + t.Fatalf("Expected %d columns, got %d", len(expectedCols), len(cols)) + } + for i, col := range cols { + if col != expectedCols[i] { + t.Errorf("Expected column %d to be %s, got %s", i, expectedCols[i], col) + } + } + var i = 1 + for rows.Next() { + var a int + var b string + err = rows.Scan(&a, &b) + if err != nil { + t.Fatalf("Error scanning row: %v", err) + } + if a != i || b != rowsMap[i] { + t.Fatalf("Expected %d, %s, got %d, %s", i, rowsMap[i], a, b) + } + i++ + } + + if err = rows.Err(); err != nil { + t.Fatalf("Row iteration error: %v", err) + } +} + +var rowsMap = map[int]string{1: "hello", 2: "world", 3: "foo", 4: "bar", 5: "baz"} + +func createTable(conn *sql.DB) error { + insert := "CREATE TABLE test (foo INT, bar TEXT);" + stmt, err := conn.Prepare(insert) + if err != nil { + return err + } + defer stmt.Close() + _, err = stmt.Exec() + return err +} + +func insertData(conn *sql.DB) error { + for i := 1; i <= 5; i++ { + insert := "INSERT INTO test (foo, bar) VALUES (?, ?);" + stmt, err := conn.Prepare(insert) + if err != nil { + return err + } + defer stmt.Close() + if _, err = stmt.Exec(i, rowsMap[i]); err != nil { + return err + } + } + return nil +} diff --git a/bindings/go/limbo_unix.go b/bindings/go/limbo_unix.go new file mode 100644 index 000000000..69464fc2d --- /dev/null +++ b/bindings/go/limbo_unix.go @@ -0,0 +1,56 @@ +//go:build linux || darwin + +package limbo + +import ( + "database/sql" + "fmt" + "os" + "path/filepath" + "runtime" + "strings" + + "github.com/ebitengine/purego" +) + +func loadLibrary() error { + var libraryName string + switch runtime.GOOS { + case "darwin": + libraryName = fmt.Sprintf("%s.dylib", libName) + case "linux": + libraryName = fmt.Sprintf("%s.so", libName) + default: + return fmt.Errorf("GOOS=%s is not supported", runtime.GOOS) + } + + libPath := os.Getenv("LD_LIBRARY_PATH") + paths := strings.Split(libPath, ":") + cwd, err := os.Getwd() + if err != nil { + return err + } + paths = append(paths, cwd) + + for _, path := range paths { + libPath := filepath.Join(path, libraryName) + if _, err := os.Stat(libPath); err == nil { + slib, dlerr := purego.Dlopen(libPath, purego.RTLD_LAZY) + if dlerr != nil { + return fmt.Errorf("failed to load library at %s: %w", libPath, dlerr) + } + limboLib = slib + return nil + } + } + return fmt.Errorf("%s library not found in LD_LIBRARY_PATH or CWD", libName) +} + +func init() { + err := loadLibrary() + if err != nil { + fmt.Println(err) + os.Exit(1) + } + sql.Register("sqlite3", &limboDriver{}) +} diff --git a/bindings/go/limbo_windows.go b/bindings/go/limbo_windows.go new file mode 100644 index 000000000..433ddd051 --- /dev/null +++ b/bindings/go/limbo_windows.go @@ -0,0 +1,47 @@ +//go:build windows + +package limbo + +import ( + "database/sql" + "fmt" + "os" + "path/filepath" + "strings" + + "golang.org/x/sys/windows" +) + +func loadLibrary() error { + libName := fmt.Sprintf("%s.dll", libName) + pathEnv := os.Getenv("PATH") + paths := strings.Split(pathEnv, ";") + + cwd, err := os.Getwd() + if err != nil { + return err + } + paths = append(paths, cwd) + for _, path := range paths { + dllPath := filepath.Join(path, libName) + if _, err := os.Stat(dllPath); err == nil { + slib, loadErr := windows.LoadLibrary(dllPath) + if loadErr != nil { + return fmt.Errorf("failed to load library at %s: %w", dllPath, loadErr) + } + limboLib = uintptr(slib) + return nil + } + } + + return fmt.Errorf("library %s not found in PATH or CWD", libName) +} + +func init() { + err := loadLibrary() + if err != nil { + fmt.Println("Error opening limbo library: ", err) + os.Exit(1) + } + sql.Register("sqlite3", &limboDriver{}) +} diff --git a/bindings/go/rs_src/rows.rs b/bindings/go/rs_src/rows.rs index 456d57bdc..189ff84f8 100644 --- a/bindings/go/rs_src/rows.rs +++ b/bindings/go/rs_src/rows.rs @@ -1,22 +1,22 @@ use crate::{ - statement::LimboStatement, types::{LimboValue, ResultCode}, + LimboConn, }; -use limbo_core::{Statement, StepResult, Value}; +use limbo_core::{Row, Statement, StepResult}; use std::ffi::{c_char, c_void}; pub struct LimboRows<'a> { - rows: Statement, - cursor: Option>>, - stmt: Box>, + stmt: Box, + conn: &'a LimboConn, + cursor: Option>, } impl<'a> LimboRows<'a> { - pub fn new(rows: Statement, stmt: Box>) -> Self { + pub fn new(stmt: Statement, conn: &'a LimboConn) -> Self { LimboRows { - rows, - stmt, + stmt: Box::new(stmt), cursor: None, + conn, } } @@ -40,14 +40,14 @@ pub extern "C" fn rows_next(ctx: *mut c_void) -> ResultCode { } let ctx = LimboRows::from_ptr(ctx); - match ctx.rows.step() { + match ctx.stmt.step() { Ok(StepResult::Row(row)) => { - ctx.cursor = Some(row.values); + ctx.cursor = Some(row); ResultCode::Row } Ok(StepResult::Done) => ResultCode::Done, Ok(StepResult::IO) => { - let _ = ctx.stmt.conn.io.run_once(); + let _ = ctx.conn.io.run_once(); ResultCode::Io } Ok(StepResult::Busy) => ResultCode::Busy, @@ -64,7 +64,7 @@ pub extern "C" fn rows_get_value(ctx: *mut c_void, col_idx: usize) -> *const c_v let ctx = LimboRows::from_ptr(ctx); if let Some(ref cursor) = ctx.cursor { - if let Some(value) = cursor.get(col_idx) { + if let Some(value) = cursor.values.get(col_idx) { let val = LimboValue::from_value(value); return val.to_ptr(); } @@ -89,7 +89,7 @@ pub extern "C" fn rows_get_columns( } let rows = LimboRows::from_ptr(rows_ptr); let c_strings: Vec = rows - .rows + .stmt .columns() .iter() .map(|name| std::ffi::CString::new(name.as_str()).unwrap()) diff --git a/bindings/go/rs_src/statement.rs b/bindings/go/rs_src/statement.rs index 82fb55648..5337f5fb0 100644 --- a/bindings/go/rs_src/statement.rs +++ b/bindings/go/rs_src/statement.rs @@ -16,7 +16,7 @@ pub extern "C" fn db_prepare(ctx: *mut c_void, query: *const c_char) -> *mut c_v let stmt = db.conn.prepare(query_str.to_string()); match stmt { - Ok(stmt) => LimboStatement::new(stmt, db).to_ptr(), + Ok(stmt) => LimboStatement::new(Some(stmt), db).to_ptr(), Err(_) => std::ptr::null_mut(), } } @@ -38,12 +38,16 @@ pub extern "C" fn stmt_execute( } else { &[] }; + let mut pool = AllocPool::new(); + let Some(statement) = stmt.statement.as_mut() else { + return ResultCode::Error; + }; for (i, arg) in args.iter().enumerate() { - let val = arg.to_value(&mut stmt.pool); - stmt.statement.bind_at(NonZero::new(i + 1).unwrap(), val); + let val = arg.to_value(&mut pool); + statement.bind_at(NonZero::new(i + 1).unwrap(), val); } loop { - match stmt.statement.step() { + match statement.step() { Ok(StepResult::Row(_)) => { // unexpected row during execution, error out. return ResultCode::Error; @@ -79,7 +83,10 @@ pub extern "C" fn stmt_parameter_count(ctx: *mut c_void) -> i32 { return -1; } let stmt = LimboStatement::from_ptr(ctx); - stmt.statement.parameters_count() as i32 + let Some(statement) = stmt.statement.as_ref() else { + return -1; + }; + statement.parameters_count() as i32 } #[no_mangle] @@ -97,32 +104,43 @@ pub extern "C" fn stmt_query( } else { &[] }; + let mut pool = AllocPool::new(); + let Some(mut statement) = stmt.statement.take() else { + return std::ptr::null_mut(); + }; for (i, arg) in args.iter().enumerate() { - let val = arg.to_value(&mut stmt.pool); - stmt.statement.bind_at(NonZero::new(i + 1).unwrap(), val); - } - match stmt.statement.query() { - Ok(rows) => { - let stmt = unsafe { Box::from_raw(stmt) }; - LimboRows::new(rows, stmt).to_ptr() - } - Err(_) => std::ptr::null_mut(), + let val = arg.to_value(&mut pool); + statement.bind_at(NonZero::new(i + 1).unwrap(), val); } + // ownership of the statement is transfered to the LimboRows object. + LimboRows::new(statement, stmt.conn).to_ptr() } pub struct LimboStatement<'conn> { - pub statement: Statement, + /// If 'query' is ran on the statement, ownership is transfered to the LimboRows object, + /// and this is set to true. `stmt_close` should never be called on a statement that has + /// been used to create a LimboRows object. + pub statement: Option, pub conn: &'conn mut LimboConn, - pub pool: AllocPool, +} + +#[no_mangle] +pub extern "C" fn stmt_close(ctx: *mut c_void) -> ResultCode { + if !ctx.is_null() { + let stmt = LimboStatement::from_ptr(ctx); + if stmt.statement.is_none() { + return ResultCode::Error; + } else { + let _ = unsafe { Box::from_raw(ctx as *mut LimboStatement) }; + return ResultCode::Ok; + } + } + ResultCode::Invalid } impl<'conn> LimboStatement<'conn> { - pub fn new(statement: Statement, conn: &'conn mut LimboConn) -> Self { - LimboStatement { - statement, - conn, - pool: AllocPool::new(), - } + pub fn new(statement: Option, conn: &'conn mut LimboConn) -> Self { + LimboStatement { statement, conn } } #[allow(clippy::wrong_self_convention)] diff --git a/bindings/go/rs_src/types.rs b/bindings/go/rs_src/types.rs index 851212c65..334aa6e77 100644 --- a/bindings/go/rs_src/types.rs +++ b/bindings/go/rs_src/types.rs @@ -14,6 +14,9 @@ pub enum ResultCode { ReadOnly = 8, NoData = 9, Done = 10, + SyntaxErr = 11, + ConstraintViolation = 12, + NoSuchEntity = 13, } #[repr(C)] @@ -55,6 +58,7 @@ pub struct AllocPool { strings: Vec, blobs: Vec>, } + impl AllocPool { pub fn new() -> Self { AllocPool { @@ -82,11 +86,13 @@ pub extern "C" fn free_blob(blob_ptr: *mut c_void) { let _ = Box::from_raw(blob_ptr as *mut Blob); } } + #[allow(dead_code)] impl ValueUnion { fn from_str(s: &str) -> Self { + let cstr = std::ffi::CString::new(s).expect("Failed to create CString"); ValueUnion { - text_ptr: s.as_ptr() as *const c_char, + text_ptr: cstr.into_raw(), } } @@ -121,7 +127,14 @@ impl ValueUnion { } pub fn to_str(&self) -> &str { - unsafe { std::ffi::CStr::from_ptr(self.text_ptr).to_str().unwrap() } + unsafe { + if self.text_ptr.is_null() { + return ""; + } + std::ffi::CStr::from_ptr(self.text_ptr) + .to_str() + .unwrap_or("") + } } pub fn to_bytes(&self) -> &[u8] { @@ -157,16 +170,30 @@ impl LimboValue { } } + // The values we get from Go need to be temporarily owned by the statement until they are bound + // then they can be cleaned up immediately afterwards pub fn to_value<'pool>(&self, pool: &'pool mut AllocPool) -> limbo_core::Value<'pool> { match self.value_type { - ValueType::Integer => limbo_core::Value::Integer(unsafe { self.value.int_val }), - ValueType::Real => limbo_core::Value::Float(unsafe { self.value.real_val }), + ValueType::Integer => { + if unsafe { self.value.int_val == 0 } { + return limbo_core::Value::Null; + } + limbo_core::Value::Integer(unsafe { self.value.int_val }) + } + ValueType::Real => { + if unsafe { self.value.real_val == 0.0 } { + return limbo_core::Value::Null; + } + limbo_core::Value::Float(unsafe { self.value.real_val }) + } ValueType::Text => { + if unsafe { self.value.text_ptr.is_null() } { + return limbo_core::Value::Null; + } let cstr = unsafe { std::ffi::CStr::from_ptr(self.value.text_ptr) }; match cstr.to_str() { Ok(utf8_str) => { let owned = utf8_str.to_owned(); - // statement needs to own these strings, will free when closed let borrowed = pool.add_string(owned); limbo_core::Value::Text(borrowed) } @@ -174,6 +201,9 @@ impl LimboValue { } } ValueType::Blob => { + if unsafe { self.value.blob_ptr.is_null() } { + return limbo_core::Value::Null; + } let blob_ptr = unsafe { self.value.blob_ptr as *const Blob }; if blob_ptr.is_null() { limbo_core::Value::Null diff --git a/bindings/go/stmt.go b/bindings/go/stmt.go index 30bceefac..5f3632810 100644 --- a/bindings/go/stmt.go +++ b/bindings/go/stmt.go @@ -10,34 +10,55 @@ import ( ) // only construct limboStmt with initStmt function to ensure proper initialization +// inUse tracks whether or not `query` has been called. if inUse > 0, stmt no longer +// owns the underlying data and `rows` is responsible for cleaning it up on close. type limboStmt struct { ctx uintptr sql string - query stmtQueryFn - execute stmtExecuteFn + inUse int + query func(stmtPtr uintptr, argsPtr uintptr, argCount uint64) uintptr + execute func(stmtPtr uintptr, argsPtr uintptr, argCount uint64, changes uintptr) int32 getParamCount func(uintptr) int32 + closeStmt func(uintptr) int32 } // Initialize/register the FFI function pointers for the statement methods func initStmt(ctx uintptr, sql string) *limboStmt { - var query stmtQueryFn - var execute stmtExecuteFn + var query func(stmtPtr uintptr, argsPtr uintptr, argCount uint64) uintptr + getFfiFunc(&query, FfiStmtQuery) + var execute func(stmtPtr uintptr, argsPtr uintptr, argCount uint64, changes uintptr) int32 + getFfiFunc(&execute, FfiStmtExec) var getParamCount func(uintptr) int32 - methods := []ExtFunc{{query, FfiStmtQuery}, {execute, FfiStmtExec}, {getParamCount, FfiStmtParameterCount}} - for i := range methods { - methods[i].initFunc() - } + getFfiFunc(&getParamCount, FfiStmtParameterCount) + var closeStmt func(uintptr) int32 + getFfiFunc(&closeStmt, FfiStmtClose) return &limboStmt{ - ctx: uintptr(ctx), - sql: sql, + ctx: uintptr(ctx), + sql: sql, + inUse: 0, + execute: execute, + query: query, + getParamCount: getParamCount, + closeStmt: closeStmt, } } -func (st *limboStmt) NumInput() int { - return int(st.getParamCount(st.ctx)) +func (ls *limboStmt) NumInput() int { + return int(ls.getParamCount(ls.ctx)) } -func (st *limboStmt) Exec(args []driver.Value) (driver.Result, error) { +func (ls *limboStmt) Close() error { + if ls.inUse == 0 { + res := ls.closeStmt(ls.ctx) + if ResultCode(res) != Ok { + return fmt.Errorf("error closing statement: %s", ResultCode(res).String()) + } + } + ls.ctx = 0 + return nil +} + +func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) { argArray, err := buildArgs(args) if err != nil { return nil, err @@ -48,9 +69,9 @@ func (st *limboStmt) Exec(args []driver.Value) (driver.Result, error) { argPtr = uintptr(unsafe.Pointer(&argArray[0])) } var changes uint64 - rc := st.execute(st.ctx, argPtr, argCount, uintptr(unsafe.Pointer(&changes))) + rc := ls.execute(ls.ctx, argPtr, argCount, uintptr(unsafe.Pointer(&changes))) switch ResultCode(rc) { - case Ok: + case Ok, Done: return driver.RowsAffected(changes), nil case Error: return nil, errors.New("error executing statement") @@ -70,23 +91,34 @@ func (st *limboStmt) Query(args []driver.Value) (driver.Rows, error) { if err != nil { return nil, err } - rowsPtr := st.query(st.ctx, uintptr(unsafe.Pointer(&queryArgs[0])), uint64(len(queryArgs))) + argPtr := uintptr(0) + if len(args) > 0 { + argPtr = uintptr(unsafe.Pointer(&queryArgs[0])) + } + rowsPtr := st.query(st.ctx, argPtr, uint64(len(queryArgs))) if rowsPtr == 0 { return nil, fmt.Errorf("query failed for: %q", st.sql) } + st.inUse++ return initRows(rowsPtr), nil } -func (ts *limboStmt) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { +func (ls *limboStmt) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { stripped := namedValueToValue(args) argArray, err := getArgsPtr(stripped) if err != nil { return nil, err } - var changes uintptr - res := ts.execute(ts.ctx, argArray, uint64(len(args)), changes) + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + var changes uint64 + res := ls.execute(ls.ctx, argArray, uint64(len(args)), uintptr(unsafe.Pointer(&changes))) switch ResultCode(res) { - case Ok: + case Ok, Done: + changes := uint64(changes) return driver.RowsAffected(changes), nil case Error: return nil, errors.New("error executing statement") @@ -99,15 +131,25 @@ func (ts *limboStmt) ExecContext(ctx context.Context, query string, args []drive } } -func (st *limboStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { +func (ls *limboStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { queryArgs, err := buildNamedArgs(args) if err != nil { return nil, err } - rowsPtr := st.query(st.ctx, uintptr(unsafe.Pointer(&queryArgs[0])), uint64(len(queryArgs))) - if rowsPtr == 0 { - return nil, fmt.Errorf("query failed for: %q", st.sql) + argsPtr := uintptr(0) + if len(queryArgs) > 0 { + argsPtr = uintptr(unsafe.Pointer(&queryArgs[0])) } + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + rowsPtr := ls.query(ls.ctx, argsPtr, uint64(len(queryArgs))) + if rowsPtr == 0 { + return nil, fmt.Errorf("query failed for: %q", ls.sql) + } + ls.inUse++ return initRows(rowsPtr), nil } @@ -127,19 +169,15 @@ type limboRows struct { // DO NOT construct 'limboRows' without this function func initRows(ctx uintptr) *limboRows { var getCols func(uintptr, *uint) uintptr + getFfiFunc(&getCols, FfiRowsGetColumns) var getValue func(uintptr, int32) uintptr + getFfiFunc(&getValue, FfiRowsGetValue) var closeRows func(uintptr) uintptr + getFfiFunc(&closeRows, FfiRowsClose) var freeCols func(uintptr) uintptr + getFfiFunc(&freeCols, FfiFreeColumns) var next func(uintptr) uintptr - methods := []ExtFunc{ - {getCols, FfiRowsGetColumns}, - {getValue, FfiRowsGetValue}, - {closeRows, FfiRowsClose}, - {freeCols, FfiFreeColumns}, - {next, FfiRowsNext}} - for i := range methods { - methods[i].initFunc() - } + getFfiFunc(&next, FfiRowsNext) return &limboRows{ ctx: ctx, @@ -157,9 +195,6 @@ func (r *limboRows) Columns() []string { colArrayPtr := r.getCols(r.ctx, &columnCount) if colArrayPtr != 0 && columnCount > 0 { r.columns = cArrayToGoStrings(colArrayPtr, columnCount) - if r.freeCols == nil { - getFfiFunc(&r.freeCols, FfiFreeColumns) - } defer r.freeCols(colArrayPtr) } } @@ -177,18 +212,22 @@ func (r *limboRows) Close() error { } func (r *limboRows) Next(dest []driver.Value) error { - status := r.next(r.ctx) - switch ResultCode(status) { - case Row: - for i := range dest { - valPtr := r.getValue(r.ctx, int32(i)) - val := toGoValue(valPtr) - dest[i] = val + for { + status := r.next(r.ctx) + switch ResultCode(status) { + case Row: + for i := range dest { + valPtr := r.getValue(r.ctx, int32(i)) + val := toGoValue(valPtr) + dest[i] = val + } + return nil + case Io: + continue + case Done: + return io.EOF + default: + return fmt.Errorf("unexpected status: %d", status) } - return nil - case Done: - return io.EOF - default: - return fmt.Errorf("unexpected status: %d", status) } } diff --git a/bindings/go/types.go b/bindings/go/types.go index c27832f43..b391d6f0a 100644 --- a/bindings/go/types.go +++ b/bindings/go/types.go @@ -6,23 +6,63 @@ import ( "unsafe" ) -type ResultCode int +type ResultCode int32 const ( - Error ResultCode = -1 - Ok ResultCode = 0 - Row ResultCode = 1 - Busy ResultCode = 2 - Io ResultCode = 3 - Interrupt ResultCode = 4 - Invalid ResultCode = 5 - Null ResultCode = 6 - NoMem ResultCode = 7 - ReadOnly ResultCode = 8 - NoData ResultCode = 9 - Done ResultCode = 10 + Error ResultCode = -1 + Ok ResultCode = 0 + Row ResultCode = 1 + Busy ResultCode = 2 + Io ResultCode = 3 + Interrupt ResultCode = 4 + Invalid ResultCode = 5 + Null ResultCode = 6 + NoMem ResultCode = 7 + ReadOnly ResultCode = 8 + NoData ResultCode = 9 + Done ResultCode = 10 + SyntaxErr ResultCode = 11 + ConstraintViolation ResultCode = 12 + NoSuchEntity ResultCode = 13 ) +func (rc ResultCode) String() string { + switch rc { + case Error: + return "Error" + case Ok: + return "Ok" + case Row: + return "Row" + case Busy: + return "Busy" + case Io: + return "Io" + case Interrupt: + return "Query was interrupted" + case Invalid: + return "Invalid" + case Null: + return "Null" + case NoMem: + return "Out of memory" + case ReadOnly: + return "Read Only" + case NoData: + return "No Data" + case Done: + return "Done" + case SyntaxErr: + return "Syntax Error" + case ConstraintViolation: + return "Constraint Violation" + case NoSuchEntity: + return "No such entity" + default: + return "Unknown response code" + } +} + const ( FfiDbOpen string = "db_open" FfiDbClose string = "db_close" @@ -30,6 +70,7 @@ const ( FfiStmtExec string = "stmt_execute" FfiStmtQuery string = "stmt_query" FfiStmtParameterCount string = "stmt_parameter_count" + FfiStmtClose string = "stmt_close" FfiRowsClose string = "rows_close" FfiRowsGetColumns string = "rows_get_columns" FfiRowsNext string = "rows_next" @@ -48,35 +89,41 @@ func namedValueToValue(named []driver.NamedValue) []driver.Value { } func buildNamedArgs(named []driver.NamedValue) ([]limboValue, error) { - args := make([]driver.Value, len(named)) - for i, nv := range named { - args[i] = nv.Value - } + args := namedValueToValue(named) return buildArgs(args) } -type ExtFunc struct { - funcPtr interface{} - funcName string -} - -func (ef *ExtFunc) initFunc() { - getFfiFunc(&ef.funcPtr, ef.funcName) -} - -type valueType int +type valueType int32 const ( - intVal valueType = iota - textVal - blobVal - realVal - nullVal + intVal valueType = 0 + textVal valueType = 1 + blobVal valueType = 2 + realVal valueType = 3 + nullVal valueType = 4 ) +func (vt valueType) String() string { + switch vt { + case intVal: + return "int" + case textVal: + return "text" + case blobVal: + return "blob" + case realVal: + return "real" + case nullVal: + return "null" + default: + return "unknown" + } +} + // struct to pass Go values over FFI type limboValue struct { Type valueType + _ [4]byte // padding to align Value to 8 bytes Value [8]byte } @@ -88,6 +135,9 @@ type Blob struct { // convert a limboValue to a native Go value func toGoValue(valPtr uintptr) interface{} { + if valPtr == 0 { + return nil + } val := (*limboValue)(unsafe.Pointer(valPtr)) switch val.Type { case intVal: @@ -139,19 +189,6 @@ func toGoBlob(blobPtr uintptr) []byte { return unsafe.Slice((*byte)(unsafe.Pointer(blob.Data)), blob.Len) } -var freeString func(*byte) - -// free a C style string allocated via FFI -func freeCString(cstr uintptr) { - if cstr == 0 { - return - } - if freeString == nil { - getFfiFunc(&freeString, FfiFreeCString) - } - freeString((*byte)(unsafe.Pointer(cstr))) -} - func cArrayToGoStrings(arrayPtr uintptr, length uint) []string { if arrayPtr == 0 || length == 0 { return nil @@ -172,30 +209,29 @@ func cArrayToGoStrings(arrayPtr uintptr, length uint) []string { // convert a Go slice of driver.Value to a slice of limboValue that can be sent over FFI func buildArgs(args []driver.Value) ([]limboValue, error) { argSlice := make([]limboValue, len(args)) - for i, v := range args { + limboVal := limboValue{} switch val := v.(type) { case nil: - argSlice[i].Type = nullVal - + limboVal.Type = nullVal case int64: - argSlice[i].Type = intVal - storeInt64(&argSlice[i].Value, val) - + limboVal.Type = intVal + limboVal.Value = *(*[8]byte)(unsafe.Pointer(&val)) case float64: - argSlice[i].Type = realVal - storeFloat64(&argSlice[i].Value, val) + limboVal.Type = realVal + limboVal.Value = *(*[8]byte)(unsafe.Pointer(&val)) case string: - argSlice[i].Type = textVal + limboVal.Type = textVal cstr := CString(val) - storePointer(&argSlice[i].Value, cstr) + *(*uintptr)(unsafe.Pointer(&limboVal.Value)) = uintptr(unsafe.Pointer(cstr)) case []byte: argSlice[i].Type = blobVal blob := makeBlob(val) - *(*uintptr)(unsafe.Pointer(&argSlice[i].Value)) = uintptr(unsafe.Pointer(blob)) + *(*uintptr)(unsafe.Pointer(&limboVal.Value)) = uintptr(unsafe.Pointer(blob)) default: return nil, fmt.Errorf("unsupported type: %T", v) } + argSlice[i] = limboVal } return argSlice, nil } @@ -212,9 +248,6 @@ func storePointer(data *[8]byte, ptr *byte) { *(*uintptr)(unsafe.Pointer(data)) = uintptr(unsafe.Pointer(ptr)) } -type stmtExecuteFn func(stmtPtr uintptr, argsPtr uintptr, argCount uint64, changes uintptr) int32 -type stmtQueryFn func(stmtPtr uintptr, argsPtr uintptr, argCount uint64) uintptr - /* Credit below (Apache2 License) to: https://github.com/ebitengine/purego/blob/main/internal/strings/strings.go */