Merge 'bindings/go: Progress on Go driver, add sync primitives, prevent crashing on concurrent connections' from Preston Thorpe

This PR continues work on the Go bindings.
- Register all symbols from the library at load time to prevent any
repeated `dlsym` calls.
- Add locks to prevent multiple concurrent FFI calls to functions that
act on the same state.
- Adds documentation/example in the go module `README`.
- Fixes memory access issue causing segfault due to passing pointer to
array of strings, that is difficult to work with in Go without the right
primitives. In place, simply return the amount of ResultColumns and Go
can provide the index to receive the column name, similar to
`rowsGetValue`
On next limbo release, I'll add the example to the main `README` next to
the other language examples. Until then, `go get
github.com/tursodatabase/limbo` will not work so the example will remain
in the bindings readme.

Closes #845
This commit is contained in:
Pekka Enberg
2025-02-01 09:25:52 +02:00
13 changed files with 455 additions and 352 deletions

5
.gitignore vendored
View File

@@ -28,4 +28,7 @@ dist/
.DS_Store
# Javascript
**/node_modules/
**/node_modules/
# testing
testing/limbo_output.txt

View File

@@ -1,41 +1,71 @@
## Limbo driver for Go's `database/sql` library
# Limbo driver for Go's `database/sql` library
**NOTE:** this is currently __heavily__ W.I.P and is not yet in a usable state. This is merged in only for the purposes of incremental progress and not because the existing code here proper. Expect many and frequent changes.
**NOTE:** this is currently __heavily__ W.I.P and is not yet in a usable state.
This uses the [purego](https://github.com/ebitengine/purego) library to call C (in this case Rust with C ABI) functions from Go without the use of `CGO`.
This driver uses the awesome [purego](https://github.com/ebitengine/purego) library to call C (in this case Rust with C ABI) functions from Go without the use of `CGO`.
## To use: (_UNSTABLE_ testing or development purposes only)
### To test
## Linux | MacOS
### Linux | MacOS
_All commands listed are relative to the bindings/go directory in the limbo repository_
```
cargo build --package limbo-go
# Your LD_LIBRARY_PATH environment variable must include limbo's `target/debug` directory
LD_LIBRARY_PATH="../../target/debug:$LD_LIBRARY_PATH" go test
export LD_LIBRARY_PATH="/path/to/limbo/target/debug:$LD_LIBRARY_PATH"
```
## Windows
```
cargo build --package limbo-go
# Copy the lib_limbo_go.dll into the current working directory (bindings/go)
# Alternatively, you could add the .dll to a location in your PATH
# You must add limbo's `target/debug` directory to your PATH
# or you could built + copy the .dll to a location in your PATH
# or just the CWD of your go module
cp ../../target/debug/lib_limbo_go.dll .
cp path\to\limbo\target\debug\lib_limbo_go.dll .
go test
```
**Temporarily** you may have to clone the limbo repository and run:
`go mod edit -replace github.com/tursodatabase/limbo=/path/to/limbo/bindings/go`
```go
import (
"fmt"
"database/sql"
_"github.com/tursodatabase/limbo"
)
func main() {
conn, err := sql.Open("sqlite3", ":memory:")
if err != nil {
fmt.Printf("Error: %v\n", err)
os.Exit(1)
}
sql := "CREATE table go_limbo (foo INTEGER, bar TEXT)"
_ = conn.Exec(sql)
sql = "INSERT INTO go_limbo (foo, bar) values (?, ?)"
stmt, _ := conn.Prepare(sql)
defer stmt.Close()
_ = stmt.Exec(42, "limbo")
rows, _ := conn.Query("SELECT * from go_limbo")
defer rows.Close()
for rows.Next() {
var a int
var b string
_ = rows.Scan(&a, &b)
fmt.Printf("%d, %s", a, b)
}
}
```

View File

@@ -1,71 +1,105 @@
package limbo
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"unsafe"
"sync"
"github.com/ebitengine/purego"
)
const (
driverName = "sqlite3"
libName = "lib_limbo_go"
func init() {
err := ensureLibLoaded()
if err != nil {
panic(err)
}
sql.Register(driverName, &limboDriver{})
}
type limboDriver struct {
sync.Mutex
}
var (
libOnce sync.Once
limboLib uintptr
loadErr error
dbOpen func(string) uintptr
dbClose func(uintptr) uintptr
connPrepare func(uintptr, string) uintptr
freeBlobFunc func(uintptr)
freeStringFunc func(uintptr)
rowsGetColumns func(uintptr) int32
rowsGetColumnName func(uintptr, int32) uintptr
rowsGetValue func(uintptr, int32) uintptr
closeRows func(uintptr) uintptr
rowsNext func(uintptr) uintptr
stmtQuery func(stmtPtr uintptr, argsPtr uintptr, argCount uint64) uintptr
stmtExec func(stmtPtr uintptr, argsPtr uintptr, argCount uint64, changes uintptr) int32
stmtParamCount func(uintptr) int32
closeStmt func(uintptr) int32
)
var limboLib uintptr
type limboDriver struct{}
func (d limboDriver) Open(name string) (driver.Conn, error) {
return openConn(name)
// Register all the symbols on library load
func ensureLibLoaded() error {
libOnce.Do(func() {
limboLib, loadErr = loadLibrary()
if loadErr != nil {
return
}
purego.RegisterLibFunc(&dbOpen, limboLib, FfiDbOpen)
purego.RegisterLibFunc(&dbClose, limboLib, FfiDbClose)
purego.RegisterLibFunc(&connPrepare, limboLib, FfiDbPrepare)
purego.RegisterLibFunc(&freeBlobFunc, limboLib, FfiFreeBlob)
purego.RegisterLibFunc(&freeStringFunc, limboLib, FfiFreeCString)
purego.RegisterLibFunc(&rowsGetColumns, limboLib, FfiRowsGetColumns)
purego.RegisterLibFunc(&rowsGetColumnName, limboLib, FfiRowsGetColumnName)
purego.RegisterLibFunc(&rowsGetValue, limboLib, FfiRowsGetValue)
purego.RegisterLibFunc(&closeRows, limboLib, FfiRowsClose)
purego.RegisterLibFunc(&rowsNext, limboLib, FfiRowsNext)
purego.RegisterLibFunc(&stmtQuery, limboLib, FfiStmtQuery)
purego.RegisterLibFunc(&stmtExec, limboLib, FfiStmtExec)
purego.RegisterLibFunc(&stmtParamCount, limboLib, FfiStmtParameterCount)
purego.RegisterLibFunc(&closeStmt, limboLib, FfiStmtClose)
})
return loadErr
}
func toCString(s string) uintptr {
b := append([]byte(s), 0)
return uintptr(unsafe.Pointer(&b[0]))
}
// helper to register an FFI function in the lib_limbo_go library
func getFfiFunc(ptr interface{}, name string) {
purego.RegisterLibFunc(ptr, limboLib, name)
}
// TODO: sync primitives
type limboConn struct {
ctx uintptr
prepare func(uintptr, string) uintptr
}
func newConn(ctx uintptr) *limboConn {
var prepare func(uintptr, string) uintptr
getFfiFunc(&prepare, FfiDbPrepare)
return &limboConn{
ctx,
prepare,
func (d *limboDriver) Open(name string) (driver.Conn, error) {
d.Lock()
conn, err := openConn(name)
d.Unlock()
if err != nil {
return nil, err
}
return conn, nil
}
type limboConn struct {
sync.Mutex
ctx uintptr
}
func openConn(dsn string) (*limboConn, error) {
var dbOpen func(string) uintptr
getFfiFunc(&dbOpen, FfiDbOpen)
ctx := dbOpen(dsn)
if ctx == 0 {
return nil, fmt.Errorf("failed to open database for dsn=%q", dsn)
}
return newConn(ctx), nil
return &limboConn{
sync.Mutex{},
ctx,
}, loadErr
}
func (c *limboConn) Close() error {
if c.ctx == 0 {
return nil
}
var dbClose func(uintptr) uintptr
getFfiFunc(&dbClose, FfiDbClose)
c.Lock()
dbClose(c.ctx)
c.Unlock()
c.ctx = 0
return nil
}
@@ -74,14 +108,13 @@ func (c *limboConn) Prepare(query string) (driver.Stmt, error) {
if c.ctx == 0 {
return nil, errors.New("connection closed")
}
if c.prepare == nil {
panic("prepare function not set")
}
stmtPtr := c.prepare(c.ctx, query)
c.Lock()
defer c.Unlock()
stmtPtr := connPrepare(c.ctx, query)
if stmtPtr == 0 {
return nil, fmt.Errorf("failed to prepare query=%q", query)
}
return initStmt(stmtPtr, query), nil
return newStmt(stmtPtr, query), nil
}
// begin is needed to implement driver.Conn.. for now not implemented

View File

@@ -1,4 +1,4 @@
module limbo
module github.com/tursodatabase/limbo
go 1.23.4

View File

@@ -3,67 +3,36 @@ package limbo_test
import (
"database/sql"
"fmt"
"log"
"testing"
_ "limbo"
_ "github.com/tursodatabase/limbo"
)
func TestConnection(t *testing.T) {
conn, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("Error opening database: %v", err)
var conn *sql.DB
var connErr error
func TestMain(m *testing.M) {
conn, connErr = sql.Open("sqlite3", ":memory:")
if connErr != nil {
panic(connErr)
}
defer conn.Close()
}
func TestCreateTable(t *testing.T) {
conn, err := sql.Open("sqlite3", ":memory:")
err := createTable(conn)
if err != nil {
t.Fatalf("Error opening database: %v", err)
}
defer conn.Close()
err = createTable(conn)
if err != nil {
t.Fatalf("Error creating table: %v", err)
log.Fatalf("Error creating table: %v", err)
}
m.Run()
}
func TestInsertData(t *testing.T) {
conn, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("Error opening database: %v", err)
}
defer conn.Close()
err = createTable(conn)
if err != nil {
t.Fatalf("Error creating table: %v", err)
}
err = insertData(conn)
err := insertData(conn)
if err != nil {
t.Fatalf("Error inserting data: %v", err)
}
}
func TestQuery(t *testing.T) {
conn, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("Error opening database: %v", err)
}
defer conn.Close()
err = createTable(conn)
if err != nil {
t.Fatalf("Error creating table: %v", err)
}
err = insertData(conn)
if err != nil {
t.Fatalf("Error inserting data: %v", err)
}
query := "SELECT * FROM test;"
stmt, err := conn.Prepare(query)
if err != nil {
@@ -99,8 +68,8 @@ func TestQuery(t *testing.T) {
if err != nil {
t.Fatalf("Error scanning row: %v", err)
}
if a != i || b != rowsMap[i] || string(c) != rowsMap[i] {
t.Fatalf("Expected %d, %s, %s, got %d, %s, %b", i, rowsMap[i], rowsMap[i], a, b, c)
if a != i || b != rowsMap[i] || !slicesAreEq(c, []byte(rowsMap[i])) {
t.Fatalf("Expected %d, %s, %s, got %d, %s, %s", i, rowsMap[i], rowsMap[i], a, b, string(c))
}
fmt.Println("RESULTS: ", a, b, string(c))
i++
@@ -109,6 +78,145 @@ func TestQuery(t *testing.T) {
if err = rows.Err(); err != nil {
t.Fatalf("Row iteration error: %v", err)
}
}
func TestFunctions(t *testing.T) {
insert := "INSERT INTO test (foo, bar, baz) VALUES (?, ?, zeroblob(?));"
stmt, err := conn.Prepare(insert)
if err != nil {
t.Fatalf("Error preparing statement: %v", err)
}
_, err = stmt.Exec(60, "TestFunction", 400)
if err != nil {
t.Fatalf("Error executing statment with arguments: %v", err)
}
stmt.Close()
stmt, err = conn.Prepare("SELECT baz FROM test where foo = ?")
if err != nil {
t.Fatalf("Error preparing select stmt: %v", err)
}
defer stmt.Close()
rows, err := stmt.Query(60)
if err != nil {
t.Fatalf("Error executing select stmt: %v", err)
}
defer rows.Close()
for rows.Next() {
var b []byte
err = rows.Scan(&b)
if err != nil {
t.Fatalf("Error scanning row: %v", err)
}
if len(b) != 400 {
t.Fatalf("Expected 100 bytes, got %d", len(b))
}
}
sql := "SELECT uuid4_str();"
stmt, err = conn.Prepare(sql)
if err != nil {
t.Fatalf("Error preparing statement: %v", err)
}
defer stmt.Close()
rows, err = stmt.Query()
if err != nil {
t.Fatalf("Error executing query: %v", err)
}
defer rows.Close()
var i int
for rows.Next() {
var b string
err = rows.Scan(&b)
if err != nil {
t.Fatalf("Error scanning row: %v", err)
}
if len(b) != 36 {
t.Fatalf("Expected 36 bytes, got %d", len(b))
}
i++
fmt.Printf("uuid: %s\n", b)
}
if i != 1 {
t.Fatalf("Expected 1 row, got %d", i)
}
fmt.Println("zeroblob + uuid functions passed")
}
func TestDuplicateConnection(t *testing.T) {
newConn, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("Error opening new connection: %v", err)
}
err = createTable(newConn)
if err != nil {
t.Fatalf("Error creating table: %v", err)
}
err = insertData(newConn)
if err != nil {
t.Fatalf("Error inserting data: %v", err)
}
query := "SELECT * FROM test;"
rows, err := newConn.Query(query)
if err != nil {
t.Fatalf("Error executing query: %v", err)
}
defer rows.Close()
for rows.Next() {
var a int
var b string
var c []byte
err = rows.Scan(&a, &b, &c)
if err != nil {
t.Fatalf("Error scanning row: %v", err)
}
fmt.Println("RESULTS: ", a, b, string(c))
}
}
func TestDuplicateConnection2(t *testing.T) {
newConn, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("Error opening new connection: %v", err)
}
sql := "CREATE TABLE test (foo INTEGER, bar INTEGER, baz BLOB);"
newConn.Exec(sql)
sql = "INSERT INTO test (foo, bar, baz) VALUES (?, ?, uuid4());"
stmt, err := newConn.Prepare(sql)
stmt.Exec(242345, 2342434)
defer stmt.Close()
query := "SELECT * FROM test;"
rows, err := newConn.Query(query)
if err != nil {
t.Fatalf("Error executing query: %v", err)
}
defer rows.Close()
for rows.Next() {
var a int
var b int
var c []byte
err = rows.Scan(&a, &b, &c)
if err != nil {
t.Fatalf("Error scanning row: %v", err)
}
fmt.Println("RESULTS: ", a, b, string(c))
if len(c) != 16 {
t.Fatalf("Expected 16 bytes, got %d", len(c))
}
}
}
func slicesAreEq(a, b []byte) bool {
if len(a) != len(b) {
fmt.Printf("LENGTHS NOT EQUAL: %d != %d\n", len(a), len(b))
return false
}
for i := range a {
if a[i] != b[i] {
fmt.Printf("SLICES NOT EQUAL: %v != %v\n", a, b)
return false
}
}
return true
}
var rowsMap = map[int]string{1: "hello", 2: "world", 3: "foo", 4: "bar", 5: "baz"}

View File

@@ -3,7 +3,6 @@
package limbo
import (
"database/sql"
"fmt"
"os"
"path/filepath"
@@ -13,7 +12,7 @@ import (
"github.com/ebitengine/purego"
)
func loadLibrary() error {
func loadLibrary() (uintptr, error) {
var libraryName string
switch runtime.GOOS {
case "darwin":
@@ -21,14 +20,14 @@ func loadLibrary() error {
case "linux":
libraryName = fmt.Sprintf("%s.so", libName)
default:
return fmt.Errorf("GOOS=%s is not supported", runtime.GOOS)
return 0, fmt.Errorf("GOOS=%s is not supported", runtime.GOOS)
}
libPath := os.Getenv("LD_LIBRARY_PATH")
paths := strings.Split(libPath, ":")
cwd, err := os.Getwd()
if err != nil {
return err
return 0, err
}
paths = append(paths, cwd)
@@ -37,20 +36,10 @@ func loadLibrary() error {
if _, err := os.Stat(libPath); err == nil {
slib, dlerr := purego.Dlopen(libPath, purego.RTLD_NOW|purego.RTLD_GLOBAL)
if dlerr != nil {
return fmt.Errorf("failed to load library at %s: %w", libPath, dlerr)
return 0, fmt.Errorf("failed to load library at %s: %w", libPath, dlerr)
}
limboLib = slib
return nil
return slib, nil
}
}
return fmt.Errorf("%s library not found in LD_LIBRARY_PATH or CWD", libName)
}
func init() {
err := loadLibrary()
if err != nil {
fmt.Println(err)
os.Exit(1)
}
sql.Register("sqlite3", &limboDriver{})
return 0, fmt.Errorf("%s library not found in LD_LIBRARY_PATH or CWD", libName)
}

View File

@@ -3,7 +3,6 @@
package limbo
import (
"database/sql"
"fmt"
"os"
"path/filepath"
@@ -12,14 +11,14 @@ import (
"golang.org/x/sys/windows"
)
func loadLibrary() error {
func loadLibrary() (uintptr, error) {
libName := fmt.Sprintf("%s.dll", libName)
pathEnv := os.Getenv("PATH")
paths := strings.Split(pathEnv, ";")
cwd, err := os.Getwd()
if err != nil {
return err
return 0, err
}
paths = append(paths, cwd)
for _, path := range paths {
@@ -27,21 +26,11 @@ func loadLibrary() error {
if _, err := os.Stat(dllPath); err == nil {
slib, loadErr := windows.LoadLibrary(dllPath)
if loadErr != nil {
return fmt.Errorf("failed to load library at %s: %w", dllPath, loadErr)
return 0, fmt.Errorf("failed to load library at %s: %w", dllPath, loadErr)
}
limboLib = uintptr(slib)
return nil
return uintptr(slib), nil
}
}
return fmt.Errorf("library %s not found in PATH or CWD", libName)
}
func init() {
err := loadLibrary()
if err != nil {
fmt.Println("Error opening limbo library: ", err)
os.Exit(1)
}
sql.Register("sqlite3", &limboDriver{})
return 0, fmt.Errorf("library %s not found in PATH or CWD", libName)
}

83
bindings/go/rows.go Normal file
View File

@@ -0,0 +1,83 @@
package limbo
import (
"database/sql/driver"
"fmt"
"io"
"sync"
)
type limboRows struct {
mu sync.Mutex
ctx uintptr
columns []string
closed bool
}
func newRows(ctx uintptr) *limboRows {
return &limboRows{
mu: sync.Mutex{},
ctx: ctx,
closed: false,
columns: nil,
}
}
func (r *limboRows) Columns() []string {
if r.ctx == 0 || r.closed {
return nil
}
if r.columns == nil {
r.mu.Lock()
count := rowsGetColumns(r.ctx)
if count > 0 {
columns := make([]string, 0, count)
for i := 0; i < int(count); i++ {
cstr := rowsGetColumnName(r.ctx, int32(i))
columns = append(columns, fmt.Sprintf("%s", GoString(cstr)))
freeCString(cstr)
}
r.mu.Unlock()
r.columns = columns
}
}
return r.columns
}
func (r *limboRows) Close() error {
if r.closed {
return nil
}
r.mu.Lock()
r.closed = true
closeRows(r.ctx)
r.ctx = 0
r.mu.Unlock()
return nil
}
func (r *limboRows) Next(dest []driver.Value) error {
if r.ctx == 0 || r.closed {
return io.EOF
}
r.mu.Lock()
defer r.mu.Unlock()
for {
status := rowsNext(r.ctx)
switch ResultCode(status) {
case Row:
for i := range dest {
valPtr := rowsGetValue(r.ctx, int32(i))
val := toGoValue(valPtr)
dest[i] = val
}
return nil
case Io:
continue
case Done:
return io.EOF
default:
return fmt.Errorf("unexpected status: %d", status)
}
}
}

View File

@@ -26,7 +26,6 @@ pub unsafe extern "C" fn db_open(path: *const c_char) -> *mut c_void {
let db = Database::open_file(io.clone(), &db_options.path.to_string());
match db {
Ok(db) => {
println!("Opened database: {}", path);
let conn = db.connect();
return LimboConn::new(conn, io).to_ptr();
}
@@ -45,16 +44,17 @@ struct LimboConn {
io: Arc<dyn limbo_core::IO>,
}
impl LimboConn {
impl<'conn> LimboConn {
fn new(conn: Rc<Connection>, io: Arc<dyn limbo_core::IO>) -> Self {
LimboConn { conn, io }
}
#[allow(clippy::wrong_self_convention)]
fn to_ptr(self) -> *mut c_void {
Box::into_raw(Box::new(self)) as *mut c_void
}
fn from_ptr(ptr: *mut c_void) -> &'static mut LimboConn {
fn from_ptr(ptr: *mut c_void) -> &'conn mut LimboConn {
if ptr.is_null() {
panic!("Null pointer");
}

View File

@@ -78,30 +78,33 @@ pub extern "C" fn free_string(s: *mut c_char) {
}
}
/// Function to get the number of expected ResultColumns in the prepared statement.
/// to avoid the needless complexity of returning an array of strings, this instead
/// works like rows_next/rows_get_value
#[no_mangle]
pub extern "C" fn rows_get_columns(
rows_ptr: *mut c_void,
out_length: *mut usize,
) -> *mut *const c_char {
if rows_ptr.is_null() || out_length.is_null() {
pub extern "C" fn rows_get_columns(rows_ptr: *mut c_void) -> i32 {
if rows_ptr.is_null() {
return -1;
}
let rows = LimboRows::from_ptr(rows_ptr);
rows.stmt.columns().len() as i32
}
/// Returns a pointer to a string with the name of the column at the given index.
/// The caller is responsible for freeing the memory, it should be copied on the Go side
/// immediately and 'free_string' called
#[no_mangle]
pub extern "C" fn rows_get_column_name(rows_ptr: *mut c_void, idx: i32) -> *const c_char {
if rows_ptr.is_null() {
return std::ptr::null_mut();
}
let rows = LimboRows::from_ptr(rows_ptr);
let c_strings: Vec<std::ffi::CString> = rows
.stmt
.columns()
.iter()
.map(|name| std::ffi::CString::new(name.as_str()).unwrap())
.collect();
let c_ptrs: Vec<*const c_char> = c_strings.iter().map(|s| s.as_ptr()).collect();
unsafe {
*out_length = c_ptrs.len();
if idx < 0 || idx as usize >= rows.stmt.columns().len() {
return std::ptr::null_mut();
}
let ptr = c_ptrs.as_ptr();
std::mem::forget(c_strings);
std::mem::forget(c_ptrs);
ptr as *mut *const c_char
let name = &rows.stmt.columns()[idx as usize];
let cstr = std::ffi::CString::new(name.as_bytes()).expect("Failed to create CString");
cstr.into_raw() as *const c_char
}
#[no_mangle]
@@ -111,21 +114,6 @@ pub extern "C" fn rows_close(rows_ptr: *mut c_void) {
}
}
#[no_mangle]
pub extern "C" fn free_columns(columns: *mut *const c_char) {
if columns.is_null() {
return;
}
unsafe {
let mut idx = 0;
while !(*columns.add(idx)).is_null() {
let _ = std::ffi::CString::from_raw(*columns.add(idx) as *mut c_char);
idx += 1;
}
let _ = Box::from_raw(columns);
}
}
#[no_mangle]
pub extern "C" fn free_rows(rows: *mut c_void) {
if rows.is_null() {

View File

@@ -13,10 +13,9 @@ pub extern "C" fn db_prepare(ctx: *mut c_void, query: *const c_char) -> *mut c_v
let query_str = unsafe { std::ffi::CStr::from_ptr(query) }.to_str().unwrap();
let db = LimboConn::from_ptr(ctx);
let stmt = db.conn.prepare(query_str);
match stmt {
Ok(stmt) => LimboStatement::new(Some(stmt), db).to_ptr(),
Ok(stmt) => LimboStatement::new(Some(stmt), LimboConn::from_ptr(ctx)).to_ptr(),
Err(_) => std::ptr::null_mut(),
}
}
@@ -53,10 +52,10 @@ pub extern "C" fn stmt_execute(
return ResultCode::Error;
}
Ok(StepResult::Done) => {
stmt.conn.conn.total_changes();
let total_changes = stmt.conn.conn.total_changes();
if !changes.is_null() {
unsafe {
*changes = stmt.conn.conn.total_changes();
*changes = total_changes;
}
}
return ResultCode::Done;
@@ -127,13 +126,9 @@ pub struct LimboStatement<'conn> {
#[no_mangle]
pub extern "C" fn stmt_close(ctx: *mut c_void) -> ResultCode {
if !ctx.is_null() {
let stmt = LimboStatement::from_ptr(ctx);
if stmt.statement.is_none() {
return ResultCode::Error;
} else {
let _ = unsafe { Box::from_raw(ctx as *mut LimboStatement) };
return ResultCode::Ok;
}
let stmt = unsafe { Box::from_raw(ctx as *mut LimboStatement) };
drop(stmt);
return ResultCode::Ok;
}
ResultCode::Invalid
}
@@ -148,7 +143,7 @@ impl<'conn> LimboStatement<'conn> {
Box::into_raw(Box::new(self)) as *mut c_void
}
fn from_ptr(ptr: *mut c_void) -> &'static mut LimboStatement<'conn> {
fn from_ptr(ptr: *mut c_void) -> &'conn mut LimboStatement<'conn> {
if ptr.is_null() {
panic!("Null pointer");
}

View File

@@ -5,54 +5,35 @@ import (
"database/sql/driver"
"errors"
"fmt"
"io"
"sync"
"unsafe"
)
// only construct limboStmt with initStmt function to ensure proper initialization
// inUse tracks whether or not `query` has been called. if inUse > 0, stmt no longer
// owns the underlying data and `rows` is responsible for cleaning it up on close.
type limboStmt struct {
ctx uintptr
sql string
inUse int
query func(stmtPtr uintptr, argsPtr uintptr, argCount uint64) uintptr
execute func(stmtPtr uintptr, argsPtr uintptr, argCount uint64, changes uintptr) int32
getParamCount func(uintptr) int32
closeStmt func(uintptr) int32
mu sync.Mutex
ctx uintptr
sql string
}
// Initialize/register the FFI function pointers for the statement methods
func initStmt(ctx uintptr, sql string) *limboStmt {
var query func(stmtPtr uintptr, argsPtr uintptr, argCount uint64) uintptr
getFfiFunc(&query, FfiStmtQuery)
var execute func(stmtPtr uintptr, argsPtr uintptr, argCount uint64, changes uintptr) int32
getFfiFunc(&execute, FfiStmtExec)
var getParamCount func(uintptr) int32
getFfiFunc(&getParamCount, FfiStmtParameterCount)
var closeStmt func(uintptr) int32
getFfiFunc(&closeStmt, FfiStmtClose)
func newStmt(ctx uintptr, sql string) *limboStmt {
return &limboStmt{
ctx: uintptr(ctx),
sql: sql,
inUse: 0,
execute: execute,
query: query,
getParamCount: getParamCount,
closeStmt: closeStmt,
ctx: uintptr(ctx),
sql: sql,
}
}
func (ls *limboStmt) NumInput() int {
return int(ls.getParamCount(ls.ctx))
ls.mu.Lock()
defer ls.mu.Unlock()
return int(stmtParamCount(ls.ctx))
}
func (ls *limboStmt) Close() error {
if ls.inUse == 0 {
res := ls.closeStmt(ls.ctx)
if ResultCode(res) != Ok {
return fmt.Errorf("error closing statement: %s", ResultCode(res).String())
}
ls.mu.Lock()
res := closeStmt(ls.ctx)
ls.mu.Unlock()
if ResultCode(res) != Ok {
return fmt.Errorf("error closing statement: %s", ResultCode(res).String())
}
ls.ctx = 0
return nil
@@ -70,7 +51,9 @@ func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) {
argPtr = uintptr(unsafe.Pointer(&argArray[0]))
}
var changes uint64
rc := ls.execute(ls.ctx, argPtr, argCount, uintptr(unsafe.Pointer(&changes)))
ls.mu.Lock()
rc := stmtExec(ls.ctx, argPtr, argCount, uintptr(unsafe.Pointer(&changes)))
ls.mu.Unlock()
switch ResultCode(rc) {
case Ok, Done:
return driver.RowsAffected(changes), nil
@@ -87,7 +70,7 @@ func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) {
}
}
func (st *limboStmt) Query(args []driver.Value) (driver.Rows, error) {
func (ls *limboStmt) Query(args []driver.Value) (driver.Rows, error) {
queryArgs, cleanup, err := buildArgs(args)
defer cleanup()
if err != nil {
@@ -97,12 +80,13 @@ func (st *limboStmt) Query(args []driver.Value) (driver.Rows, error) {
if len(args) > 0 {
argPtr = uintptr(unsafe.Pointer(&queryArgs[0]))
}
rowsPtr := st.query(st.ctx, argPtr, uint64(len(queryArgs)))
ls.mu.Lock()
rowsPtr := stmtQuery(ls.ctx, argPtr, uint64(len(queryArgs)))
ls.mu.Unlock()
if rowsPtr == 0 {
return nil, fmt.Errorf("query failed for: %q", st.sql)
return nil, fmt.Errorf("query failed for: %q", ls.sql)
}
st.inUse++
return initRows(rowsPtr), nil
return newRows(rowsPtr), nil
}
func (ls *limboStmt) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
@@ -118,7 +102,9 @@ func (ls *limboStmt) ExecContext(ctx context.Context, query string, args []drive
default:
}
var changes uint64
res := ls.execute(ls.ctx, argArray, uint64(len(args)), uintptr(unsafe.Pointer(&changes)))
ls.mu.Lock()
res := stmtExec(ls.ctx, argArray, uint64(len(args)), uintptr(unsafe.Pointer(&changes)))
ls.mu.Unlock()
switch ResultCode(res) {
case Ok, Done:
changes := uint64(changes)
@@ -149,89 +135,11 @@ func (ls *limboStmt) QueryContext(ctx context.Context, args []driver.NamedValue)
return nil, ctx.Err()
default:
}
rowsPtr := ls.query(ls.ctx, argsPtr, uint64(len(queryArgs)))
ls.mu.Lock()
rowsPtr := stmtQuery(ls.ctx, argsPtr, uint64(len(queryArgs)))
ls.mu.Unlock()
if rowsPtr == 0 {
return nil, fmt.Errorf("query failed for: %q", ls.sql)
}
ls.inUse++
return initRows(rowsPtr), nil
}
// only construct limboRows with initRows function to ensure proper initialization
type limboRows struct {
ctx uintptr
columns []string
closed bool
getCols func(uintptr, *uint) uintptr
next func(uintptr) uintptr
getValue func(uintptr, int32) uintptr
closeRows func(uintptr) uintptr
freeCols func(uintptr) uintptr
}
// Initialize/register the FFI function pointers for the rows methods
// DO NOT construct 'limboRows' without this function
func initRows(ctx uintptr) *limboRows {
var getCols func(uintptr, *uint) uintptr
getFfiFunc(&getCols, FfiRowsGetColumns)
var getValue func(uintptr, int32) uintptr
getFfiFunc(&getValue, FfiRowsGetValue)
var closeRows func(uintptr) uintptr
getFfiFunc(&closeRows, FfiRowsClose)
var freeCols func(uintptr) uintptr
getFfiFunc(&freeCols, FfiFreeColumns)
var next func(uintptr) uintptr
getFfiFunc(&next, FfiRowsNext)
return &limboRows{
ctx: ctx,
getCols: getCols,
getValue: getValue,
closeRows: closeRows,
freeCols: freeCols,
next: next,
}
}
func (r *limboRows) Columns() []string {
if r.columns == nil {
var columnCount uint
colArrayPtr := r.getCols(r.ctx, &columnCount)
if colArrayPtr != 0 && columnCount > 0 {
r.columns = cArrayToGoStrings(colArrayPtr, columnCount)
defer r.freeCols(colArrayPtr)
}
}
return r.columns
}
func (r *limboRows) Close() error {
if r.closed {
return nil
}
r.closed = true
r.closeRows(r.ctx)
r.ctx = 0
return nil
}
func (r *limboRows) Next(dest []driver.Value) error {
for {
status := r.next(r.ctx)
switch ResultCode(status) {
case Row:
for i := range dest {
valPtr := r.getValue(r.ctx, int32(i))
val := toGoValue(valPtr)
dest[i] = val
}
return nil
case Io:
continue
case Done:
return io.EOF
default:
return fmt.Errorf("unexpected status: %d", status)
}
}
return newRows(rowsPtr), nil
}

View File

@@ -65,20 +65,23 @@ func (rc ResultCode) String() string {
}
const (
FfiDbOpen string = "db_open"
FfiDbClose string = "db_close"
FfiDbPrepare string = "db_prepare"
FfiStmtExec string = "stmt_execute"
FfiStmtQuery string = "stmt_query"
FfiStmtParameterCount string = "stmt_parameter_count"
FfiStmtClose string = "stmt_close"
FfiRowsClose string = "rows_close"
FfiRowsGetColumns string = "rows_get_columns"
FfiRowsNext string = "rows_next"
FfiRowsGetValue string = "rows_get_value"
FfiFreeColumns string = "free_columns"
FfiFreeCString string = "free_string"
FfiFreeBlob string = "free_blob"
driverName = "sqlite3"
libName = "lib_limbo_go"
FfiDbOpen = "db_open"
FfiDbClose = "db_close"
FfiDbPrepare = "db_prepare"
FfiStmtExec = "stmt_execute"
FfiStmtQuery = "stmt_query"
FfiStmtParameterCount = "stmt_parameter_count"
FfiStmtClose = "stmt_close"
FfiRowsClose = "rows_close"
FfiRowsGetColumns = "rows_get_columns"
FfiRowsGetColumnName = "rows_get_column_name"
FfiRowsNext = "rows_next"
FfiRowsGetValue = "rows_get_value"
FfiFreeColumns = "free_columns"
FfiFreeCString = "free_string"
FfiFreeBlob = "free_blob"
)
// convert a namedValue slice into normal values until named parameters are supported
@@ -198,47 +201,20 @@ func toGoBlob(blobPtr uintptr) []byte {
return copied
}
var freeBlobFunc func(uintptr)
func freeBlob(blobPtr uintptr) {
if blobPtr == 0 {
return
}
if freeBlobFunc == nil {
getFfiFunc(&freeBlobFunc, FfiFreeBlob)
}
freeBlobFunc(blobPtr)
}
var freeStringFunc func(uintptr)
func freeCString(cstrPtr uintptr) {
if cstrPtr == 0 {
return
}
if freeStringFunc == nil {
getFfiFunc(&freeStringFunc, FfiFreeCString)
}
freeStringFunc(cstrPtr)
}
func cArrayToGoStrings(arrayPtr uintptr, length uint) []string {
if arrayPtr == 0 || length == 0 {
return nil
}
ptrSlice := unsafe.Slice(
(**byte)(unsafe.Pointer(arrayPtr)),
length,
)
out := make([]string, 0, length)
for _, cstr := range ptrSlice {
out = append(out, GoString(uintptr(unsafe.Pointer(cstr))))
}
return out
}
// convert a Go slice of driver.Value to a slice of limboValue that can be sent over FFI
// for Blob types, we have to pin them so they are not garbage collected before they can be copied
// into a buffer on the Rust side, so we return a function to unpin them that can be deferred after this call
@@ -259,6 +235,7 @@ func buildArgs(args []driver.Value) ([]limboValue, func(), error) {
case string:
limboVal.Type = textVal
cstr := CString(val)
pinner.Pin(cstr)
*(*uintptr)(unsafe.Pointer(&limboVal.Value)) = uintptr(unsafe.Pointer(cstr))
case []byte:
limboVal.Type = blobVal