mirror of
https://github.com/aljazceru/turso.git
synced 2026-02-23 00:45:37 +01:00
bindings/go: enable multiple connections, register all symbols at library load
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package limbo
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -9,66 +10,67 @@ import (
|
||||
"github.com/ebitengine/purego"
|
||||
)
|
||||
|
||||
const (
|
||||
driverName = "sqlite3"
|
||||
libName = "lib_limbo_go"
|
||||
)
|
||||
func init() {
|
||||
err := ensureLibLoaded()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
sql.Register(driverName, &limboDriver{})
|
||||
}
|
||||
|
||||
type limboDriver struct {
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
var library = sync.OnceValue(func() uintptr {
|
||||
lib, err := loadLibrary()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return lib
|
||||
})
|
||||
|
||||
var (
|
||||
libOnce sync.Once
|
||||
loadErr error
|
||||
dbOpen func(string) uintptr
|
||||
dbClose func(uintptr) uintptr
|
||||
connPrepare func(uintptr, string) uintptr
|
||||
freeBlobFunc func(uintptr)
|
||||
freeStringFunc func(uintptr)
|
||||
rowsGetColumns func(uintptr, *uint) uintptr
|
||||
rowsGetValue func(uintptr, int32) uintptr
|
||||
closeRows func(uintptr) uintptr
|
||||
freeCols func(uintptr) uintptr
|
||||
rowsNext func(uintptr) uintptr
|
||||
stmtQuery func(stmtPtr uintptr, argsPtr uintptr, argCount uint64) uintptr
|
||||
stmtExec func(stmtPtr uintptr, argsPtr uintptr, argCount uint64, changes uintptr) int32
|
||||
stmtParamCount func(uintptr) int32
|
||||
closeStmt func(uintptr) int32
|
||||
libOnce sync.Once
|
||||
limboLib uintptr
|
||||
loadErr error
|
||||
dbOpen func(string) uintptr
|
||||
dbClose func(uintptr) uintptr
|
||||
connPrepare func(uintptr, string) uintptr
|
||||
freeBlobFunc func(uintptr)
|
||||
freeStringFunc func(uintptr)
|
||||
rowsGetColumns func(uintptr) int32
|
||||
rowsGetColumnName func(uintptr, int32) uintptr
|
||||
rowsGetValue func(uintptr, int32) uintptr
|
||||
closeRows func(uintptr) uintptr
|
||||
rowsNext func(uintptr) uintptr
|
||||
stmtQuery func(stmtPtr uintptr, argsPtr uintptr, argCount uint64) uintptr
|
||||
stmtExec func(stmtPtr uintptr, argsPtr uintptr, argCount uint64, changes uintptr) int32
|
||||
stmtParamCount func(uintptr) int32
|
||||
closeStmt func(uintptr) int32
|
||||
)
|
||||
|
||||
// Register all the symbols on library load
|
||||
func ensureLibLoaded() error {
|
||||
libOnce.Do(func() {
|
||||
purego.RegisterLibFunc(&dbOpen, library(), FfiDbOpen)
|
||||
purego.RegisterLibFunc(&dbClose, library(), FfiDbClose)
|
||||
purego.RegisterLibFunc(&connPrepare, library(), FfiDbPrepare)
|
||||
purego.RegisterLibFunc(&freeBlobFunc, library(), FfiFreeBlob)
|
||||
purego.RegisterLibFunc(&freeStringFunc, library(), FfiFreeCString)
|
||||
purego.RegisterLibFunc(&rowsGetColumns, library(), FfiRowsGetColumns)
|
||||
purego.RegisterLibFunc(&rowsGetValue, library(), FfiRowsGetValue)
|
||||
purego.RegisterLibFunc(&closeRows, library(), FfiRowsClose)
|
||||
purego.RegisterLibFunc(&freeCols, library(), FfiFreeColumns)
|
||||
purego.RegisterLibFunc(&rowsNext, library(), FfiRowsNext)
|
||||
purego.RegisterLibFunc(&stmtQuery, library(), FfiStmtQuery)
|
||||
purego.RegisterLibFunc(&stmtExec, library(), FfiStmtExec)
|
||||
purego.RegisterLibFunc(&stmtParamCount, library(), FfiStmtParameterCount)
|
||||
purego.RegisterLibFunc(&closeStmt, library(), FfiStmtClose)
|
||||
limboLib, loadErr = loadLibrary()
|
||||
if loadErr != nil {
|
||||
return
|
||||
}
|
||||
purego.RegisterLibFunc(&dbOpen, limboLib, FfiDbOpen)
|
||||
purego.RegisterLibFunc(&dbClose, limboLib, FfiDbClose)
|
||||
purego.RegisterLibFunc(&connPrepare, limboLib, FfiDbPrepare)
|
||||
purego.RegisterLibFunc(&freeBlobFunc, limboLib, FfiFreeBlob)
|
||||
purego.RegisterLibFunc(&freeStringFunc, limboLib, FfiFreeCString)
|
||||
purego.RegisterLibFunc(&rowsGetColumns, limboLib, FfiRowsGetColumns)
|
||||
purego.RegisterLibFunc(&rowsGetColumnName, limboLib, FfiRowsGetColumnName)
|
||||
purego.RegisterLibFunc(&rowsGetValue, limboLib, FfiRowsGetValue)
|
||||
purego.RegisterLibFunc(&closeRows, limboLib, FfiRowsClose)
|
||||
purego.RegisterLibFunc(&rowsNext, limboLib, FfiRowsNext)
|
||||
purego.RegisterLibFunc(&stmtQuery, limboLib, FfiStmtQuery)
|
||||
purego.RegisterLibFunc(&stmtExec, limboLib, FfiStmtExec)
|
||||
purego.RegisterLibFunc(&stmtParamCount, limboLib, FfiStmtParameterCount)
|
||||
purego.RegisterLibFunc(&closeStmt, limboLib, FfiStmtClose)
|
||||
})
|
||||
return loadErr
|
||||
}
|
||||
|
||||
func (d *limboDriver) Open(name string) (driver.Conn, error) {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
conn, err := openConn(name)
|
||||
d.Unlock()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -80,26 +82,24 @@ type limboConn struct {
|
||||
ctx uintptr
|
||||
}
|
||||
|
||||
func newConn(ctx uintptr) *limboConn {
|
||||
return &limboConn{
|
||||
sync.Mutex{},
|
||||
ctx,
|
||||
}
|
||||
}
|
||||
|
||||
func openConn(dsn string) (*limboConn, error) {
|
||||
ctx := dbOpen(dsn)
|
||||
if ctx == 0 {
|
||||
return nil, fmt.Errorf("failed to open database for dsn=%q", dsn)
|
||||
}
|
||||
return newConn(ctx), loadErr
|
||||
return &limboConn{
|
||||
sync.Mutex{},
|
||||
ctx,
|
||||
}, loadErr
|
||||
}
|
||||
|
||||
func (c *limboConn) Close() error {
|
||||
if c.ctx == 0 {
|
||||
return nil
|
||||
}
|
||||
c.Lock()
|
||||
dbClose(c.ctx)
|
||||
c.Unlock()
|
||||
c.ctx = 0
|
||||
return nil
|
||||
}
|
||||
@@ -114,7 +114,7 @@ func (c *limboConn) Prepare(query string) (driver.Stmt, error) {
|
||||
if stmtPtr == 0 {
|
||||
return nil, fmt.Errorf("failed to prepare query=%q", query)
|
||||
}
|
||||
return initStmt(stmtPtr, query), nil
|
||||
return newStmt(stmtPtr, query), nil
|
||||
}
|
||||
|
||||
// begin is needed to implement driver.Conn.. for now not implemented
|
||||
|
||||
@@ -18,7 +18,7 @@ func TestMain(m *testing.M) {
|
||||
panic(connErr)
|
||||
}
|
||||
defer conn.Close()
|
||||
err := createTable()
|
||||
err := createTable(conn)
|
||||
if err != nil {
|
||||
log.Fatalf("Error creating table: %v", err)
|
||||
}
|
||||
@@ -26,49 +26,13 @@ func TestMain(m *testing.M) {
|
||||
}
|
||||
|
||||
func TestInsertData(t *testing.T) {
|
||||
err := insertData()
|
||||
err := insertData(conn)
|
||||
if err != nil {
|
||||
t.Fatalf("Error inserting data: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFunction(t *testing.T) {
|
||||
insert := "INSERT INTO test (foo, bar, baz) VALUES (?, ?, zeroblob(?));"
|
||||
stmt, err := conn.Prepare(insert)
|
||||
if err != nil {
|
||||
t.Fatalf("Error preparing statement: %v", err)
|
||||
}
|
||||
_, err = stmt.Exec(1, "hello", 100)
|
||||
if err != nil {
|
||||
t.Fatalf("Error executing statment with arguments: %v", err)
|
||||
}
|
||||
stmt.Close()
|
||||
stmt, err = conn.Prepare("SELECT baz FROM test where foo = ?")
|
||||
if err != nil {
|
||||
t.Fatalf("Error preparing select stmt: %v", err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
rows, err := stmt.Query(1)
|
||||
if err != nil {
|
||||
t.Fatalf("Error executing select stmt: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var b []byte
|
||||
err = rows.Scan(&b)
|
||||
if err != nil {
|
||||
t.Fatalf("Error scanning row: %v", err)
|
||||
}
|
||||
fmt.Println("RESULTS: ", string(b))
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuery(t *testing.T) {
|
||||
err := insertData()
|
||||
if err != nil {
|
||||
t.Fatalf("Error inserting data: %v", err)
|
||||
}
|
||||
|
||||
query := "SELECT * FROM test;"
|
||||
stmt, err := conn.Prepare(query)
|
||||
if err != nil {
|
||||
@@ -117,6 +81,130 @@ func TestQuery(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
func TestFunctions(t *testing.T) {
|
||||
insert := "INSERT INTO test (foo, bar, baz) VALUES (?, ?, zeroblob(?));"
|
||||
stmt, err := conn.Prepare(insert)
|
||||
if err != nil {
|
||||
t.Fatalf("Error preparing statement: %v", err)
|
||||
}
|
||||
_, err = stmt.Exec(60, "TestFunction", 400)
|
||||
if err != nil {
|
||||
t.Fatalf("Error executing statment with arguments: %v", err)
|
||||
}
|
||||
stmt.Close()
|
||||
stmt, err = conn.Prepare("SELECT baz FROM test where foo = ?")
|
||||
if err != nil {
|
||||
t.Fatalf("Error preparing select stmt: %v", err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
rows, err := stmt.Query(60)
|
||||
if err != nil {
|
||||
t.Fatalf("Error executing select stmt: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var b []byte
|
||||
err = rows.Scan(&b)
|
||||
if err != nil {
|
||||
t.Fatalf("Error scanning row: %v", err)
|
||||
}
|
||||
if len(b) != 400 {
|
||||
t.Fatalf("Expected 100 bytes, got %d", len(b))
|
||||
}
|
||||
}
|
||||
sql := "SELECT uuid4_str();"
|
||||
stmt, err = conn.Prepare(sql)
|
||||
if err != nil {
|
||||
t.Fatalf("Error preparing statement: %v", err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
rows, err = stmt.Query()
|
||||
if err != nil {
|
||||
t.Fatalf("Error executing query: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
var i int
|
||||
for rows.Next() {
|
||||
var b string
|
||||
err = rows.Scan(&b)
|
||||
if err != nil {
|
||||
t.Fatalf("Error scanning row: %v", err)
|
||||
}
|
||||
if len(b) != 36 {
|
||||
t.Fatalf("Expected 36 bytes, got %d", len(b))
|
||||
}
|
||||
i++
|
||||
fmt.Printf("uuid: %s\n", b)
|
||||
}
|
||||
if i != 1 {
|
||||
t.Fatalf("Expected 1 row, got %d", i)
|
||||
}
|
||||
fmt.Println("zeroblob + uuid functions passed")
|
||||
}
|
||||
|
||||
func TestDuplicateConnection(t *testing.T) {
|
||||
newConn, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Error opening new connection: %v", err)
|
||||
}
|
||||
err = createTable(newConn)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating table: %v", err)
|
||||
}
|
||||
err = insertData(newConn)
|
||||
if err != nil {
|
||||
t.Fatalf("Error inserting data: %v", err)
|
||||
}
|
||||
query := "SELECT * FROM test;"
|
||||
rows, err := newConn.Query(query)
|
||||
if err != nil {
|
||||
t.Fatalf("Error executing query: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var a int
|
||||
var b string
|
||||
var c []byte
|
||||
err = rows.Scan(&a, &b, &c)
|
||||
if err != nil {
|
||||
t.Fatalf("Error scanning row: %v", err)
|
||||
}
|
||||
fmt.Println("RESULTS: ", a, b, string(c))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDuplicateConnection2(t *testing.T) {
|
||||
newConn, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Error opening new connection: %v", err)
|
||||
}
|
||||
sql := "CREATE TABLE test (foo INTEGER, bar INTEGER, baz BLOB);"
|
||||
newConn.Exec(sql)
|
||||
sql = "INSERT INTO test (foo, bar, baz) VALUES (?, ?, uuid4());"
|
||||
stmt, err := newConn.Prepare(sql)
|
||||
stmt.Exec(242345, 2342434)
|
||||
defer stmt.Close()
|
||||
query := "SELECT * FROM test;"
|
||||
rows, err := newConn.Query(query)
|
||||
if err != nil {
|
||||
t.Fatalf("Error executing query: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var a int
|
||||
var b int
|
||||
var c []byte
|
||||
err = rows.Scan(&a, &b, &c)
|
||||
if err != nil {
|
||||
t.Fatalf("Error scanning row: %v", err)
|
||||
}
|
||||
fmt.Println("RESULTS: ", a, b, string(c))
|
||||
if len(c) != 16 {
|
||||
t.Fatalf("Expected 16 bytes, got %d", len(c))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func slicesAreEq(a, b []byte) bool {
|
||||
if len(a) != len(b) {
|
||||
fmt.Printf("LENGTHS NOT EQUAL: %d != %d\n", len(a), len(b))
|
||||
@@ -133,7 +221,7 @@ func slicesAreEq(a, b []byte) bool {
|
||||
|
||||
var rowsMap = map[int]string{1: "hello", 2: "world", 3: "foo", 4: "bar", 5: "baz"}
|
||||
|
||||
func createTable() error {
|
||||
func createTable(conn *sql.DB) error {
|
||||
insert := "CREATE TABLE test (foo INT, bar TEXT, baz BLOB);"
|
||||
stmt, err := conn.Prepare(insert)
|
||||
if err != nil {
|
||||
@@ -144,7 +232,7 @@ func createTable() error {
|
||||
return err
|
||||
}
|
||||
|
||||
func insertData() error {
|
||||
func insertData(conn *sql.DB) error {
|
||||
for i := 1; i <= 5; i++ {
|
||||
insert := "INSERT INTO test (foo, bar, baz) VALUES (?, ?, ?);"
|
||||
stmt, err := conn.Prepare(insert)
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
package limbo
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -44,11 +43,3 @@ func loadLibrary() (uintptr, error) {
|
||||
}
|
||||
return 0, fmt.Errorf("%s library not found in LD_LIBRARY_PATH or CWD", libName)
|
||||
}
|
||||
|
||||
func init() {
|
||||
err := ensureLibLoaded()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
sql.Register("sqlite3", &limboDriver{})
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
package limbo
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -12,14 +11,14 @@ import (
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func loadLibrary() error {
|
||||
func loadLibrary() (uintptr, error) {
|
||||
libName := fmt.Sprintf("%s.dll", libName)
|
||||
pathEnv := os.Getenv("PATH")
|
||||
paths := strings.Split(pathEnv, ";")
|
||||
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return err
|
||||
return 0, err
|
||||
}
|
||||
paths = append(paths, cwd)
|
||||
for _, path := range paths {
|
||||
@@ -27,21 +26,11 @@ func loadLibrary() error {
|
||||
if _, err := os.Stat(dllPath); err == nil {
|
||||
slib, loadErr := windows.LoadLibrary(dllPath)
|
||||
if loadErr != nil {
|
||||
return fmt.Errorf("failed to load library at %s: %w", dllPath, loadErr)
|
||||
return 0, fmt.Errorf("failed to load library at %s: %w", dllPath, loadErr)
|
||||
}
|
||||
limboLib = uintptr(slib)
|
||||
return nil
|
||||
return uintptr(slib), nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("library %s not found in PATH or CWD", libName)
|
||||
}
|
||||
|
||||
func init() {
|
||||
err := loadLibrary()
|
||||
if err != nil {
|
||||
fmt.Println("Error opening limbo library: ", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
sql.Register("sqlite3", &limboDriver{})
|
||||
return 0, fmt.Errorf("library %s not found in PATH or CWD", libName)
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// only construct limboRows with initRows function to ensure proper initialization
|
||||
type limboRows struct {
|
||||
mu sync.Mutex
|
||||
ctx uintptr
|
||||
@@ -15,12 +14,12 @@ type limboRows struct {
|
||||
closed bool
|
||||
}
|
||||
|
||||
// Initialize/register the FFI function pointers for the rows methods
|
||||
// DO NOT construct 'limboRows' without this function
|
||||
func initRows(ctx uintptr) *limboRows {
|
||||
func newRows(ctx uintptr) *limboRows {
|
||||
return &limboRows{
|
||||
mu: sync.Mutex{},
|
||||
ctx: ctx,
|
||||
mu: sync.Mutex{},
|
||||
ctx: ctx,
|
||||
closed: false,
|
||||
columns: nil,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,15 +28,17 @@ func (r *limboRows) Columns() []string {
|
||||
return nil
|
||||
}
|
||||
if r.columns == nil {
|
||||
var columnCount uint
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
colArrayPtr := rowsGetColumns(r.ctx, &columnCount)
|
||||
if colArrayPtr != 0 && columnCount > 0 {
|
||||
r.columns = cArrayToGoStrings(colArrayPtr, columnCount)
|
||||
if freeCols != nil {
|
||||
defer freeCols(colArrayPtr)
|
||||
count := rowsGetColumns(r.ctx)
|
||||
if count > 0 {
|
||||
columns := make([]string, 0, count)
|
||||
for i := 0; i < int(count); i++ {
|
||||
cstr := rowsGetColumnName(r.ctx, int32(i))
|
||||
columns = append(columns, fmt.Sprintf("%s", GoString(cstr)))
|
||||
freeCString(cstr)
|
||||
}
|
||||
r.mu.Unlock()
|
||||
r.columns = columns
|
||||
}
|
||||
}
|
||||
return r.columns
|
||||
@@ -59,14 +60,14 @@ func (r *limboRows) Next(dest []driver.Value) error {
|
||||
if r.ctx == 0 || r.closed {
|
||||
return io.EOF
|
||||
}
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for {
|
||||
status := rowsNext(r.ctx)
|
||||
switch ResultCode(status) {
|
||||
case Row:
|
||||
for i := range dest {
|
||||
r.mu.Lock()
|
||||
valPtr := rowsGetValue(r.ctx, int32(i))
|
||||
r.mu.Unlock()
|
||||
val := toGoValue(valPtr)
|
||||
dest[i] = val
|
||||
}
|
||||
|
||||
@@ -78,30 +78,30 @@ pub extern "C" fn free_string(s: *mut c_char) {
|
||||
}
|
||||
}
|
||||
|
||||
/// Function to get the number of expected ResultColumns in the prepared statement.
|
||||
/// to avoid the needless complexity of returning an array of strings, this instead
|
||||
/// works like rows_next/rows_get_value
|
||||
#[no_mangle]
|
||||
pub extern "C" fn rows_get_columns(
|
||||
rows_ptr: *mut c_void,
|
||||
out_length: *mut usize,
|
||||
) -> *mut *const c_char {
|
||||
if rows_ptr.is_null() || out_length.is_null() {
|
||||
pub extern "C" fn rows_get_columns(rows_ptr: *mut c_void) -> i32 {
|
||||
if rows_ptr.is_null() {
|
||||
return -1;
|
||||
}
|
||||
let rows = LimboRows::from_ptr(rows_ptr);
|
||||
rows.stmt.columns().len() as i32
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn rows_get_column_name(rows_ptr: *mut c_void, idx: i32) -> *const c_char {
|
||||
if rows_ptr.is_null() {
|
||||
return std::ptr::null_mut();
|
||||
}
|
||||
let rows = LimboRows::from_ptr(rows_ptr);
|
||||
let c_strings: Vec<std::ffi::CString> = rows
|
||||
.stmt
|
||||
.columns()
|
||||
.iter()
|
||||
.map(|name| std::ffi::CString::new(name.as_str()).unwrap())
|
||||
.collect();
|
||||
|
||||
let c_ptrs: Vec<*const c_char> = c_strings.iter().map(|s| s.as_ptr()).collect();
|
||||
unsafe {
|
||||
*out_length = c_ptrs.len();
|
||||
if idx < 0 || idx as usize >= rows.stmt.columns().len() {
|
||||
return std::ptr::null_mut();
|
||||
}
|
||||
let ptr = c_ptrs.as_ptr();
|
||||
std::mem::forget(c_strings);
|
||||
std::mem::forget(c_ptrs);
|
||||
ptr as *mut *const c_char
|
||||
let name = &rows.stmt.columns()[idx as usize];
|
||||
let cstr = std::ffi::CString::new(name.as_bytes()).expect("Failed to create CString");
|
||||
cstr.into_raw() as *const c_char
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
@@ -111,21 +111,6 @@ pub extern "C" fn rows_close(rows_ptr: *mut c_void) {
|
||||
}
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn free_columns(columns: *mut *const c_char) {
|
||||
if columns.is_null() {
|
||||
return;
|
||||
}
|
||||
unsafe {
|
||||
let mut idx = 0;
|
||||
while !(*columns.add(idx)).is_null() {
|
||||
let _ = std::ffi::CString::from_raw(*columns.add(idx) as *mut c_char);
|
||||
idx += 1;
|
||||
}
|
||||
let _ = Box::from_raw(columns);
|
||||
}
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn free_rows(rows: *mut c_void) {
|
||||
if rows.is_null() {
|
||||
|
||||
@@ -132,13 +132,9 @@ pub struct LimboStatement<'conn> {
|
||||
#[no_mangle]
|
||||
pub extern "C" fn stmt_close(ctx: *mut c_void) -> ResultCode {
|
||||
if !ctx.is_null() {
|
||||
let stmt = LimboStatement::from_ptr(ctx);
|
||||
if stmt.statement.is_none() {
|
||||
return ResultCode::Error;
|
||||
} else {
|
||||
let _ = unsafe { Box::from_raw(ctx as *mut LimboStatement) };
|
||||
return ResultCode::Ok;
|
||||
}
|
||||
let stmt = unsafe { Box::from_raw(ctx as *mut LimboStatement) };
|
||||
drop(stmt);
|
||||
return ResultCode::Ok;
|
||||
}
|
||||
ResultCode::Invalid
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package limbo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -8,22 +9,16 @@ import (
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// only construct limboStmt with initStmt function to ensure proper initialization
|
||||
// inUse tracks whether or not `query` has been called. if inUse > 0, stmt no longer
|
||||
// owns the underlying data and `rows` is responsible for cleaning it up on close.
|
||||
type limboStmt struct {
|
||||
mu sync.Mutex
|
||||
ctx uintptr
|
||||
sql string
|
||||
inUse int
|
||||
mu sync.Mutex
|
||||
ctx uintptr
|
||||
sql string
|
||||
}
|
||||
|
||||
// Initialize/register the FFI function pointers for the statement methods
|
||||
func initStmt(ctx uintptr, sql string) *limboStmt {
|
||||
func newStmt(ctx uintptr, sql string) *limboStmt {
|
||||
return &limboStmt{
|
||||
ctx: uintptr(ctx),
|
||||
sql: sql,
|
||||
inUse: 0,
|
||||
ctx: uintptr(ctx),
|
||||
sql: sql,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,25 +29,19 @@ func (ls *limboStmt) NumInput() int {
|
||||
}
|
||||
|
||||
func (ls *limboStmt) Close() error {
|
||||
if ls.inUse == 0 {
|
||||
ls.mu.Lock()
|
||||
res := closeStmt(ls.ctx)
|
||||
ls.mu.Unlock()
|
||||
if ResultCode(res) != Ok {
|
||||
return fmt.Errorf("error closing statement: %s", ResultCode(res).String())
|
||||
}
|
||||
ls.mu.Lock()
|
||||
res := closeStmt(ls.ctx)
|
||||
ls.mu.Unlock()
|
||||
if ResultCode(res) != Ok {
|
||||
return fmt.Errorf("error closing statement: %s", ResultCode(res).String())
|
||||
}
|
||||
ls.ctx = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
ls.mu.Lock()
|
||||
argArray, cleanup, err := buildArgs(args)
|
||||
defer func() {
|
||||
cleanup()
|
||||
ls.mu.Unlock()
|
||||
}()
|
||||
defer cleanup()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -62,7 +51,9 @@ func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
argPtr = uintptr(unsafe.Pointer(&argArray[0]))
|
||||
}
|
||||
var changes uint64
|
||||
ls.mu.Lock()
|
||||
rc := stmtExec(ls.ctx, argPtr, argCount, uintptr(unsafe.Pointer(&changes)))
|
||||
ls.mu.Unlock()
|
||||
switch ResultCode(rc) {
|
||||
case Ok, Done:
|
||||
return driver.RowsAffected(changes), nil
|
||||
@@ -80,9 +71,8 @@ func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
}
|
||||
|
||||
func (ls *limboStmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||
ls.mu.Lock()
|
||||
queryArgs, cleanup, err := buildArgs(args)
|
||||
defer func() { cleanup(); ls.mu.Unlock() }()
|
||||
defer cleanup()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -90,64 +80,66 @@ func (ls *limboStmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||
if len(args) > 0 {
|
||||
argPtr = uintptr(unsafe.Pointer(&queryArgs[0]))
|
||||
}
|
||||
ls.mu.Lock()
|
||||
rowsPtr := stmtQuery(ls.ctx, argPtr, uint64(len(queryArgs)))
|
||||
ls.mu.Unlock()
|
||||
if rowsPtr == 0 {
|
||||
return nil, fmt.Errorf("query failed for: %q", ls.sql)
|
||||
}
|
||||
ls.inUse++
|
||||
return initRows(rowsPtr), nil
|
||||
return newRows(rowsPtr), nil
|
||||
}
|
||||
|
||||
// func (ls *limboStmt) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
||||
// ls.mu.Lock()
|
||||
// stripped := namedValueToValue(args)
|
||||
// argArray, cleanup, err := getArgsPtr(stripped)
|
||||
// defer func() { cleanup(); ls.mu.Unlock() }()
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// select {
|
||||
// case <-ctx.Done():
|
||||
// return nil, ctx.Err()
|
||||
// default:
|
||||
// }
|
||||
// var changes uint64
|
||||
// res := stmtExec(ls.ctx, argArray, uint64(len(args)), uintptr(unsafe.Pointer(&changes)))
|
||||
// switch ResultCode(res) {
|
||||
// case Ok, Done:
|
||||
// changes := uint64(changes)
|
||||
// return driver.RowsAffected(changes), nil
|
||||
// case Error:
|
||||
// return nil, errors.New("error executing statement")
|
||||
// case Busy:
|
||||
// return nil, errors.New("busy")
|
||||
// case Interrupt:
|
||||
// return nil, errors.New("interrupted")
|
||||
// default:
|
||||
// return nil, fmt.Errorf("unexpected status: %d", res)
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// func (ls *limboStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||
// ls.mu.Lock()
|
||||
// queryArgs, allocs, err := buildNamedArgs(args)
|
||||
// defer func() { allocs(); ls.mu.Unlock() }()
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// argsPtr := uintptr(0)
|
||||
// if len(queryArgs) > 0 {
|
||||
// argsPtr = uintptr(unsafe.Pointer(&queryArgs[0]))
|
||||
// }
|
||||
// select {
|
||||
// case <-ctx.Done():
|
||||
// return nil, ctx.Err()
|
||||
// default:
|
||||
// }
|
||||
// rowsPtr := stmtQuery(ls.ctx, argsPtr, uint64(len(queryArgs)))
|
||||
// if rowsPtr == 0 {
|
||||
// return nil, fmt.Errorf("query failed for: %q", ls.sql)
|
||||
// }
|
||||
// ls.inUse++
|
||||
// return initRows(rowsPtr), nil
|
||||
// }
|
||||
func (ls *limboStmt) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
||||
stripped := namedValueToValue(args)
|
||||
argArray, cleanup, err := getArgsPtr(stripped)
|
||||
defer cleanup()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
var changes uint64
|
||||
ls.mu.Lock()
|
||||
res := stmtExec(ls.ctx, argArray, uint64(len(args)), uintptr(unsafe.Pointer(&changes)))
|
||||
ls.mu.Unlock()
|
||||
switch ResultCode(res) {
|
||||
case Ok, Done:
|
||||
changes := uint64(changes)
|
||||
return driver.RowsAffected(changes), nil
|
||||
case Error:
|
||||
return nil, errors.New("error executing statement")
|
||||
case Busy:
|
||||
return nil, errors.New("busy")
|
||||
case Interrupt:
|
||||
return nil, errors.New("interrupted")
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected status: %d", res)
|
||||
}
|
||||
}
|
||||
|
||||
func (ls *limboStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||
queryArgs, allocs, err := buildNamedArgs(args)
|
||||
defer allocs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
argsPtr := uintptr(0)
|
||||
if len(queryArgs) > 0 {
|
||||
argsPtr = uintptr(unsafe.Pointer(&queryArgs[0]))
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
ls.mu.Lock()
|
||||
rowsPtr := stmtQuery(ls.ctx, argsPtr, uint64(len(queryArgs)))
|
||||
ls.mu.Unlock()
|
||||
if rowsPtr == 0 {
|
||||
return nil, fmt.Errorf("query failed for: %q", ls.sql)
|
||||
}
|
||||
return newRows(rowsPtr), nil
|
||||
}
|
||||
|
||||
@@ -65,20 +65,23 @@ func (rc ResultCode) String() string {
|
||||
}
|
||||
|
||||
const (
|
||||
FfiDbOpen string = "db_open"
|
||||
FfiDbClose string = "db_close"
|
||||
FfiDbPrepare string = "db_prepare"
|
||||
FfiStmtExec string = "stmt_execute"
|
||||
FfiStmtQuery string = "stmt_query"
|
||||
FfiStmtParameterCount string = "stmt_parameter_count"
|
||||
FfiStmtClose string = "stmt_close"
|
||||
FfiRowsClose string = "rows_close"
|
||||
FfiRowsGetColumns string = "rows_get_columns"
|
||||
FfiRowsNext string = "rows_next"
|
||||
FfiRowsGetValue string = "rows_get_value"
|
||||
FfiFreeColumns string = "free_columns"
|
||||
FfiFreeCString string = "free_string"
|
||||
FfiFreeBlob string = "free_blob"
|
||||
driverName = "sqlite3"
|
||||
libName = "lib_limbo_go"
|
||||
FfiDbOpen = "db_open"
|
||||
FfiDbClose = "db_close"
|
||||
FfiDbPrepare = "db_prepare"
|
||||
FfiStmtExec = "stmt_execute"
|
||||
FfiStmtQuery = "stmt_query"
|
||||
FfiStmtParameterCount = "stmt_parameter_count"
|
||||
FfiStmtClose = "stmt_close"
|
||||
FfiRowsClose = "rows_close"
|
||||
FfiRowsGetColumns = "rows_get_columns"
|
||||
FfiRowsGetColumnName = "rows_get_column_name"
|
||||
FfiRowsNext = "rows_next"
|
||||
FfiRowsGetValue = "rows_get_value"
|
||||
FfiFreeColumns = "free_columns"
|
||||
FfiFreeCString = "free_string"
|
||||
FfiFreeBlob = "free_blob"
|
||||
)
|
||||
|
||||
// convert a namedValue slice into normal values until named parameters are supported
|
||||
@@ -212,22 +215,6 @@ func freeCString(cstrPtr uintptr) {
|
||||
freeStringFunc(cstrPtr)
|
||||
}
|
||||
|
||||
func cArrayToGoStrings(arrayPtr uintptr, length uint) []string {
|
||||
if arrayPtr == 0 || length == 0 {
|
||||
return nil
|
||||
}
|
||||
ptrSlice := unsafe.Slice(
|
||||
(**byte)(unsafe.Pointer(arrayPtr)),
|
||||
length,
|
||||
)
|
||||
|
||||
out := make([]string, 0, length)
|
||||
for _, cstr := range ptrSlice {
|
||||
out = append(out, GoString(uintptr(unsafe.Pointer(cstr))))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// convert a Go slice of driver.Value to a slice of limboValue that can be sent over FFI
|
||||
// for Blob types, we have to pin them so they are not garbage collected before they can be copied
|
||||
// into a buffer on the Rust side, so we return a function to unpin them that can be deferred after this call
|
||||
|
||||
Reference in New Issue
Block a user