mirror of
https://github.com/aljazceru/turso.git
synced 2026-02-16 05:24:22 +01:00
bindings/go: Adjust tests for multiple concurrent connections
This commit is contained in:
5
.gitignore
vendored
5
.gitignore
vendored
@@ -28,4 +28,7 @@ dist/
|
||||
.DS_Store
|
||||
|
||||
# Javascript
|
||||
**/node_modules/
|
||||
**/node_modules/
|
||||
|
||||
# testing
|
||||
testing/limbo_output.txt
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"unsafe"
|
||||
"sync"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
)
|
||||
@@ -14,57 +14,91 @@ const (
|
||||
libName = "lib_limbo_go"
|
||||
)
|
||||
|
||||
var limboLib uintptr
|
||||
|
||||
type limboDriver struct{}
|
||||
|
||||
func (d limboDriver) Open(name string) (driver.Conn, error) {
|
||||
return openConn(name)
|
||||
type limboDriver struct {
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func toCString(s string) uintptr {
|
||||
b := append([]byte(s), 0)
|
||||
return uintptr(unsafe.Pointer(&b[0]))
|
||||
var library = sync.OnceValue(func() uintptr {
|
||||
lib, err := loadLibrary()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return lib
|
||||
})
|
||||
|
||||
var (
|
||||
libOnce sync.Once
|
||||
loadErr error
|
||||
dbOpen func(string) uintptr
|
||||
dbClose func(uintptr) uintptr
|
||||
connPrepare func(uintptr, string) uintptr
|
||||
freeBlobFunc func(uintptr)
|
||||
freeStringFunc func(uintptr)
|
||||
rowsGetColumns func(uintptr, *uint) uintptr
|
||||
rowsGetValue func(uintptr, int32) uintptr
|
||||
closeRows func(uintptr) uintptr
|
||||
freeCols func(uintptr) uintptr
|
||||
rowsNext func(uintptr) uintptr
|
||||
stmtQuery func(stmtPtr uintptr, argsPtr uintptr, argCount uint64) uintptr
|
||||
stmtExec func(stmtPtr uintptr, argsPtr uintptr, argCount uint64, changes uintptr) int32
|
||||
stmtParamCount func(uintptr) int32
|
||||
closeStmt func(uintptr) int32
|
||||
)
|
||||
|
||||
func ensureLibLoaded() error {
|
||||
libOnce.Do(func() {
|
||||
purego.RegisterLibFunc(&dbOpen, library(), FfiDbOpen)
|
||||
purego.RegisterLibFunc(&dbClose, library(), FfiDbClose)
|
||||
purego.RegisterLibFunc(&connPrepare, library(), FfiDbPrepare)
|
||||
purego.RegisterLibFunc(&freeBlobFunc, library(), FfiFreeBlob)
|
||||
purego.RegisterLibFunc(&freeStringFunc, library(), FfiFreeCString)
|
||||
purego.RegisterLibFunc(&rowsGetColumns, library(), FfiRowsGetColumns)
|
||||
purego.RegisterLibFunc(&rowsGetValue, library(), FfiRowsGetValue)
|
||||
purego.RegisterLibFunc(&closeRows, library(), FfiRowsClose)
|
||||
purego.RegisterLibFunc(&freeCols, library(), FfiFreeColumns)
|
||||
purego.RegisterLibFunc(&rowsNext, library(), FfiRowsNext)
|
||||
purego.RegisterLibFunc(&stmtQuery, library(), FfiStmtQuery)
|
||||
purego.RegisterLibFunc(&stmtExec, library(), FfiStmtExec)
|
||||
purego.RegisterLibFunc(&stmtParamCount, library(), FfiStmtParameterCount)
|
||||
purego.RegisterLibFunc(&closeStmt, library(), FfiStmtClose)
|
||||
})
|
||||
return loadErr
|
||||
}
|
||||
|
||||
// helper to register an FFI function in the lib_limbo_go library
|
||||
func getFfiFunc(ptr interface{}, name string) {
|
||||
purego.RegisterLibFunc(ptr, limboLib, name)
|
||||
func (d *limboDriver) Open(name string) (driver.Conn, error) {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
conn, err := openConn(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// TODO: sync primitives
|
||||
type limboConn struct {
|
||||
ctx uintptr
|
||||
prepare func(uintptr, string) uintptr
|
||||
sync.Mutex
|
||||
ctx uintptr
|
||||
}
|
||||
|
||||
func newConn(ctx uintptr) *limboConn {
|
||||
var prepare func(uintptr, string) uintptr
|
||||
getFfiFunc(&prepare, FfiDbPrepare)
|
||||
return &limboConn{
|
||||
sync.Mutex{},
|
||||
ctx,
|
||||
prepare,
|
||||
}
|
||||
}
|
||||
|
||||
func openConn(dsn string) (*limboConn, error) {
|
||||
var dbOpen func(string) uintptr
|
||||
getFfiFunc(&dbOpen, FfiDbOpen)
|
||||
|
||||
ctx := dbOpen(dsn)
|
||||
if ctx == 0 {
|
||||
return nil, fmt.Errorf("failed to open database for dsn=%q", dsn)
|
||||
}
|
||||
return newConn(ctx), nil
|
||||
return newConn(ctx), loadErr
|
||||
}
|
||||
|
||||
func (c *limboConn) Close() error {
|
||||
if c.ctx == 0 {
|
||||
return nil
|
||||
}
|
||||
var dbClose func(uintptr) uintptr
|
||||
getFfiFunc(&dbClose, FfiDbClose)
|
||||
|
||||
dbClose(c.ctx)
|
||||
c.ctx = 0
|
||||
return nil
|
||||
@@ -74,10 +108,9 @@ func (c *limboConn) Prepare(query string) (driver.Stmt, error) {
|
||||
if c.ctx == 0 {
|
||||
return nil, errors.New("connection closed")
|
||||
}
|
||||
if c.prepare == nil {
|
||||
panic("prepare function not set")
|
||||
}
|
||||
stmtPtr := c.prepare(c.ctx, query)
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
stmtPtr := connPrepare(c.ctx, query)
|
||||
if stmtPtr == 0 {
|
||||
return nil, fmt.Errorf("failed to prepare query=%q", query)
|
||||
}
|
||||
|
||||
@@ -3,63 +3,68 @@ package limbo_test
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"testing"
|
||||
|
||||
_ "limbo"
|
||||
)
|
||||
|
||||
func TestConnection(t *testing.T) {
|
||||
conn, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Error opening database: %v", err)
|
||||
var conn *sql.DB
|
||||
var connErr error
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
conn, connErr = sql.Open("sqlite3", ":memory:")
|
||||
if connErr != nil {
|
||||
panic(connErr)
|
||||
}
|
||||
defer conn.Close()
|
||||
}
|
||||
|
||||
func TestCreateTable(t *testing.T) {
|
||||
conn, err := sql.Open("sqlite3", ":memory:")
|
||||
err := createTable()
|
||||
if err != nil {
|
||||
t.Fatalf("Error opening database: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
err = createTable(conn)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating table: %v", err)
|
||||
log.Fatalf("Error creating table: %v", err)
|
||||
}
|
||||
m.Run()
|
||||
}
|
||||
|
||||
func TestInsertData(t *testing.T) {
|
||||
conn, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Error opening database: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
err = createTable(conn)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating table: %v", err)
|
||||
}
|
||||
|
||||
err = insertData(conn)
|
||||
err := insertData()
|
||||
if err != nil {
|
||||
t.Fatalf("Error inserting data: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFunction(t *testing.T) {
|
||||
insert := "INSERT INTO test (foo, bar, baz) VALUES (?, ?, zeroblob(?));"
|
||||
stmt, err := conn.Prepare(insert)
|
||||
if err != nil {
|
||||
t.Fatalf("Error preparing statement: %v", err)
|
||||
}
|
||||
_, err = stmt.Exec(1, "hello", 100)
|
||||
if err != nil {
|
||||
t.Fatalf("Error executing statment with arguments: %v", err)
|
||||
}
|
||||
stmt.Close()
|
||||
stmt, err = conn.Prepare("SELECT baz FROM test where foo = ?")
|
||||
if err != nil {
|
||||
t.Fatalf("Error preparing select stmt: %v", err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
rows, err := stmt.Query(1)
|
||||
if err != nil {
|
||||
t.Fatalf("Error executing select stmt: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var b []byte
|
||||
err = rows.Scan(&b)
|
||||
if err != nil {
|
||||
t.Fatalf("Error scanning row: %v", err)
|
||||
}
|
||||
fmt.Println("RESULTS: ", string(b))
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuery(t *testing.T) {
|
||||
conn, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Error opening database: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
err = createTable(conn)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating table: %v", err)
|
||||
}
|
||||
|
||||
err = insertData(conn)
|
||||
err := insertData()
|
||||
if err != nil {
|
||||
t.Fatalf("Error inserting data: %v", err)
|
||||
}
|
||||
@@ -99,8 +104,8 @@ func TestQuery(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Error scanning row: %v", err)
|
||||
}
|
||||
if a != i || b != rowsMap[i] || string(c) != rowsMap[i] {
|
||||
t.Fatalf("Expected %d, %s, %s, got %d, %s, %b", i, rowsMap[i], rowsMap[i], a, b, c)
|
||||
if a != i || b != rowsMap[i] || !slicesAreEq(c, []byte(rowsMap[i])) {
|
||||
t.Fatalf("Expected %d, %s, %s, got %d, %s, %s", i, rowsMap[i], rowsMap[i], a, b, string(c))
|
||||
}
|
||||
fmt.Println("RESULTS: ", a, b, string(c))
|
||||
i++
|
||||
@@ -109,11 +114,26 @@ func TestQuery(t *testing.T) {
|
||||
if err = rows.Err(); err != nil {
|
||||
t.Fatalf("Row iteration error: %v", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func slicesAreEq(a, b []byte) bool {
|
||||
if len(a) != len(b) {
|
||||
fmt.Printf("LENGTHS NOT EQUAL: %d != %d\n", len(a), len(b))
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
fmt.Printf("SLICES NOT EQUAL: %v != %v\n", a, b)
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
var rowsMap = map[int]string{1: "hello", 2: "world", 3: "foo", 4: "bar", 5: "baz"}
|
||||
|
||||
func createTable(conn *sql.DB) error {
|
||||
func createTable() error {
|
||||
insert := "CREATE TABLE test (foo INT, bar TEXT, baz BLOB);"
|
||||
stmt, err := conn.Prepare(insert)
|
||||
if err != nil {
|
||||
@@ -124,7 +144,7 @@ func createTable(conn *sql.DB) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func insertData(conn *sql.DB) error {
|
||||
func insertData() error {
|
||||
for i := 1; i <= 5; i++ {
|
||||
insert := "INSERT INTO test (foo, bar, baz) VALUES (?, ?, ?);"
|
||||
stmt, err := conn.Prepare(insert)
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
"github.com/ebitengine/purego"
|
||||
)
|
||||
|
||||
func loadLibrary() error {
|
||||
func loadLibrary() (uintptr, error) {
|
||||
var libraryName string
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
@@ -21,14 +21,14 @@ func loadLibrary() error {
|
||||
case "linux":
|
||||
libraryName = fmt.Sprintf("%s.so", libName)
|
||||
default:
|
||||
return fmt.Errorf("GOOS=%s is not supported", runtime.GOOS)
|
||||
return 0, fmt.Errorf("GOOS=%s is not supported", runtime.GOOS)
|
||||
}
|
||||
|
||||
libPath := os.Getenv("LD_LIBRARY_PATH")
|
||||
paths := strings.Split(libPath, ":")
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return err
|
||||
return 0, err
|
||||
}
|
||||
paths = append(paths, cwd)
|
||||
|
||||
@@ -37,20 +37,18 @@ func loadLibrary() error {
|
||||
if _, err := os.Stat(libPath); err == nil {
|
||||
slib, dlerr := purego.Dlopen(libPath, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if dlerr != nil {
|
||||
return fmt.Errorf("failed to load library at %s: %w", libPath, dlerr)
|
||||
return 0, fmt.Errorf("failed to load library at %s: %w", libPath, dlerr)
|
||||
}
|
||||
limboLib = slib
|
||||
return nil
|
||||
return slib, nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("%s library not found in LD_LIBRARY_PATH or CWD", libName)
|
||||
return 0, fmt.Errorf("%s library not found in LD_LIBRARY_PATH or CWD", libName)
|
||||
}
|
||||
|
||||
func init() {
|
||||
err := loadLibrary()
|
||||
err := ensureLibLoaded()
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
os.Exit(1)
|
||||
panic(err)
|
||||
}
|
||||
sql.Register("sqlite3", &limboDriver{})
|
||||
}
|
||||
|
||||
82
bindings/go/rows.go
Normal file
82
bindings/go/rows.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package limbo
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// only construct limboRows with initRows function to ensure proper initialization
|
||||
type limboRows struct {
|
||||
mu sync.Mutex
|
||||
ctx uintptr
|
||||
columns []string
|
||||
closed bool
|
||||
}
|
||||
|
||||
// Initialize/register the FFI function pointers for the rows methods
|
||||
// DO NOT construct 'limboRows' without this function
|
||||
func initRows(ctx uintptr) *limboRows {
|
||||
return &limboRows{
|
||||
mu: sync.Mutex{},
|
||||
ctx: ctx,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *limboRows) Columns() []string {
|
||||
if r.ctx == 0 || r.closed {
|
||||
return nil
|
||||
}
|
||||
if r.columns == nil {
|
||||
var columnCount uint
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
colArrayPtr := rowsGetColumns(r.ctx, &columnCount)
|
||||
if colArrayPtr != 0 && columnCount > 0 {
|
||||
r.columns = cArrayToGoStrings(colArrayPtr, columnCount)
|
||||
if freeCols != nil {
|
||||
defer freeCols(colArrayPtr)
|
||||
}
|
||||
}
|
||||
}
|
||||
return r.columns
|
||||
}
|
||||
|
||||
func (r *limboRows) Close() error {
|
||||
if r.closed {
|
||||
return nil
|
||||
}
|
||||
r.mu.Lock()
|
||||
r.closed = true
|
||||
closeRows(r.ctx)
|
||||
r.ctx = 0
|
||||
r.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *limboRows) Next(dest []driver.Value) error {
|
||||
if r.ctx == 0 || r.closed {
|
||||
return io.EOF
|
||||
}
|
||||
for {
|
||||
status := rowsNext(r.ctx)
|
||||
switch ResultCode(status) {
|
||||
case Row:
|
||||
for i := range dest {
|
||||
r.mu.Lock()
|
||||
valPtr := rowsGetValue(r.ctx, int32(i))
|
||||
r.mu.Unlock()
|
||||
val := toGoValue(valPtr)
|
||||
dest[i] = val
|
||||
}
|
||||
return nil
|
||||
case Io:
|
||||
continue
|
||||
case Done:
|
||||
return io.EOF
|
||||
default:
|
||||
return fmt.Errorf("unexpected status: %d", status)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -7,7 +7,7 @@ use std::{
|
||||
ffi::{c_char, c_void},
|
||||
rc::Rc,
|
||||
str::FromStr,
|
||||
sync::Arc,
|
||||
sync::{Arc, RwLock},
|
||||
};
|
||||
|
||||
/// # Safety
|
||||
@@ -26,7 +26,6 @@ pub unsafe extern "C" fn db_open(path: *const c_char) -> *mut c_void {
|
||||
let db = Database::open_file(io.clone(), &db_options.path.to_string());
|
||||
match db {
|
||||
Ok(db) => {
|
||||
println!("Opened database: {}", path);
|
||||
let conn = db.connect();
|
||||
return LimboConn::new(conn, io).to_ptr();
|
||||
}
|
||||
@@ -41,20 +40,24 @@ pub unsafe extern "C" fn db_open(path: *const c_char) -> *mut c_void {
|
||||
|
||||
#[allow(dead_code)]
|
||||
struct LimboConn {
|
||||
conn: Rc<Connection>,
|
||||
conn: RwLock<Rc<Connection>>,
|
||||
io: Arc<dyn limbo_core::IO>,
|
||||
}
|
||||
|
||||
impl LimboConn {
|
||||
impl<'conn> LimboConn {
|
||||
fn new(conn: Rc<Connection>, io: Arc<dyn limbo_core::IO>) -> Self {
|
||||
LimboConn { conn, io }
|
||||
}
|
||||
#[allow(clippy::wrong_self_convention)]
|
||||
fn to_ptr(self) -> *mut c_void {
|
||||
Box::into_raw(Box::new(self)) as *mut c_void
|
||||
LimboConn {
|
||||
conn: conn.into(),
|
||||
io,
|
||||
}
|
||||
}
|
||||
|
||||
fn from_ptr(ptr: *mut c_void) -> &'static mut LimboConn {
|
||||
#[allow(clippy::wrong_self_convention)]
|
||||
fn to_ptr(self) -> *mut c_void {
|
||||
Arc::into_raw(Arc::new(self)) as *mut c_void
|
||||
}
|
||||
|
||||
fn from_ptr(ptr: *mut c_void) -> &'conn mut LimboConn {
|
||||
if ptr.is_null() {
|
||||
panic!("Null pointer");
|
||||
}
|
||||
|
||||
@@ -13,10 +13,12 @@ pub extern "C" fn db_prepare(ctx: *mut c_void, query: *const c_char) -> *mut c_v
|
||||
let query_str = unsafe { std::ffi::CStr::from_ptr(query) }.to_str().unwrap();
|
||||
|
||||
let db = LimboConn::from_ptr(ctx);
|
||||
|
||||
let stmt = db.conn.prepare(query_str);
|
||||
let Ok(conn) = db.conn.read() else {
|
||||
return std::ptr::null_mut();
|
||||
};
|
||||
let stmt = conn.prepare(query_str);
|
||||
match stmt {
|
||||
Ok(stmt) => LimboStatement::new(Some(stmt), db).to_ptr(),
|
||||
Ok(stmt) => LimboStatement::new(Some(stmt), LimboConn::from_ptr(ctx)).to_ptr(),
|
||||
Err(_) => std::ptr::null_mut(),
|
||||
}
|
||||
}
|
||||
@@ -53,10 +55,13 @@ pub extern "C" fn stmt_execute(
|
||||
return ResultCode::Error;
|
||||
}
|
||||
Ok(StepResult::Done) => {
|
||||
stmt.conn.conn.total_changes();
|
||||
let Ok(conn) = stmt.conn.conn.read() else {
|
||||
return ResultCode::Done;
|
||||
};
|
||||
let total_changes = conn.total_changes();
|
||||
if !changes.is_null() {
|
||||
unsafe {
|
||||
*changes = stmt.conn.conn.total_changes();
|
||||
*changes = total_changes;
|
||||
}
|
||||
}
|
||||
return ResultCode::Done;
|
||||
@@ -148,7 +153,7 @@ impl<'conn> LimboStatement<'conn> {
|
||||
Box::into_raw(Box::new(self)) as *mut c_void
|
||||
}
|
||||
|
||||
fn from_ptr(ptr: *mut c_void) -> &'static mut LimboStatement<'conn> {
|
||||
fn from_ptr(ptr: *mut c_void) -> &'conn mut LimboStatement<'conn> {
|
||||
if ptr.is_null() {
|
||||
panic!("Null pointer");
|
||||
}
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
package limbo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
@@ -13,43 +12,32 @@ import (
|
||||
// inUse tracks whether or not `query` has been called. if inUse > 0, stmt no longer
|
||||
// owns the underlying data and `rows` is responsible for cleaning it up on close.
|
||||
type limboStmt struct {
|
||||
ctx uintptr
|
||||
sql string
|
||||
inUse int
|
||||
query func(stmtPtr uintptr, argsPtr uintptr, argCount uint64) uintptr
|
||||
execute func(stmtPtr uintptr, argsPtr uintptr, argCount uint64, changes uintptr) int32
|
||||
getParamCount func(uintptr) int32
|
||||
closeStmt func(uintptr) int32
|
||||
mu sync.Mutex
|
||||
ctx uintptr
|
||||
sql string
|
||||
inUse int
|
||||
}
|
||||
|
||||
// Initialize/register the FFI function pointers for the statement methods
|
||||
func initStmt(ctx uintptr, sql string) *limboStmt {
|
||||
var query func(stmtPtr uintptr, argsPtr uintptr, argCount uint64) uintptr
|
||||
getFfiFunc(&query, FfiStmtQuery)
|
||||
var execute func(stmtPtr uintptr, argsPtr uintptr, argCount uint64, changes uintptr) int32
|
||||
getFfiFunc(&execute, FfiStmtExec)
|
||||
var getParamCount func(uintptr) int32
|
||||
getFfiFunc(&getParamCount, FfiStmtParameterCount)
|
||||
var closeStmt func(uintptr) int32
|
||||
getFfiFunc(&closeStmt, FfiStmtClose)
|
||||
return &limboStmt{
|
||||
ctx: uintptr(ctx),
|
||||
sql: sql,
|
||||
inUse: 0,
|
||||
execute: execute,
|
||||
query: query,
|
||||
getParamCount: getParamCount,
|
||||
closeStmt: closeStmt,
|
||||
ctx: uintptr(ctx),
|
||||
sql: sql,
|
||||
inUse: 0,
|
||||
}
|
||||
}
|
||||
|
||||
func (ls *limboStmt) NumInput() int {
|
||||
return int(ls.getParamCount(ls.ctx))
|
||||
ls.mu.Lock()
|
||||
defer ls.mu.Unlock()
|
||||
return int(stmtParamCount(ls.ctx))
|
||||
}
|
||||
|
||||
func (ls *limboStmt) Close() error {
|
||||
if ls.inUse == 0 {
|
||||
res := ls.closeStmt(ls.ctx)
|
||||
ls.mu.Lock()
|
||||
res := closeStmt(ls.ctx)
|
||||
ls.mu.Unlock()
|
||||
if ResultCode(res) != Ok {
|
||||
return fmt.Errorf("error closing statement: %s", ResultCode(res).String())
|
||||
}
|
||||
@@ -59,8 +47,12 @@ func (ls *limboStmt) Close() error {
|
||||
}
|
||||
|
||||
func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
ls.mu.Lock()
|
||||
argArray, cleanup, err := buildArgs(args)
|
||||
defer cleanup()
|
||||
defer func() {
|
||||
cleanup()
|
||||
ls.mu.Unlock()
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -70,7 +62,7 @@ func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
argPtr = uintptr(unsafe.Pointer(&argArray[0]))
|
||||
}
|
||||
var changes uint64
|
||||
rc := ls.execute(ls.ctx, argPtr, argCount, uintptr(unsafe.Pointer(&changes)))
|
||||
rc := stmtExec(ls.ctx, argPtr, argCount, uintptr(unsafe.Pointer(&changes)))
|
||||
switch ResultCode(rc) {
|
||||
case Ok, Done:
|
||||
return driver.RowsAffected(changes), nil
|
||||
@@ -87,9 +79,10 @@ func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (st *limboStmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||
func (ls *limboStmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||
ls.mu.Lock()
|
||||
queryArgs, cleanup, err := buildArgs(args)
|
||||
defer cleanup()
|
||||
defer func() { cleanup(); ls.mu.Unlock() }()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -97,59 +90,7 @@ func (st *limboStmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||
if len(args) > 0 {
|
||||
argPtr = uintptr(unsafe.Pointer(&queryArgs[0]))
|
||||
}
|
||||
rowsPtr := st.query(st.ctx, argPtr, uint64(len(queryArgs)))
|
||||
if rowsPtr == 0 {
|
||||
return nil, fmt.Errorf("query failed for: %q", st.sql)
|
||||
}
|
||||
st.inUse++
|
||||
return initRows(rowsPtr), nil
|
||||
}
|
||||
|
||||
func (ls *limboStmt) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
||||
stripped := namedValueToValue(args)
|
||||
argArray, cleanup, err := getArgsPtr(stripped)
|
||||
defer cleanup()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
var changes uint64
|
||||
res := ls.execute(ls.ctx, argArray, uint64(len(args)), uintptr(unsafe.Pointer(&changes)))
|
||||
switch ResultCode(res) {
|
||||
case Ok, Done:
|
||||
changes := uint64(changes)
|
||||
return driver.RowsAffected(changes), nil
|
||||
case Error:
|
||||
return nil, errors.New("error executing statement")
|
||||
case Busy:
|
||||
return nil, errors.New("busy")
|
||||
case Interrupt:
|
||||
return nil, errors.New("interrupted")
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected status: %d", res)
|
||||
}
|
||||
}
|
||||
|
||||
func (ls *limboStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||
queryArgs, allocs, err := buildNamedArgs(args)
|
||||
defer allocs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
argsPtr := uintptr(0)
|
||||
if len(queryArgs) > 0 {
|
||||
argsPtr = uintptr(unsafe.Pointer(&queryArgs[0]))
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
rowsPtr := ls.query(ls.ctx, argsPtr, uint64(len(queryArgs)))
|
||||
rowsPtr := stmtQuery(ls.ctx, argPtr, uint64(len(queryArgs)))
|
||||
if rowsPtr == 0 {
|
||||
return nil, fmt.Errorf("query failed for: %q", ls.sql)
|
||||
}
|
||||
@@ -157,81 +98,56 @@ func (ls *limboStmt) QueryContext(ctx context.Context, args []driver.NamedValue)
|
||||
return initRows(rowsPtr), nil
|
||||
}
|
||||
|
||||
// only construct limboRows with initRows function to ensure proper initialization
|
||||
type limboRows struct {
|
||||
ctx uintptr
|
||||
columns []string
|
||||
closed bool
|
||||
getCols func(uintptr, *uint) uintptr
|
||||
next func(uintptr) uintptr
|
||||
getValue func(uintptr, int32) uintptr
|
||||
closeRows func(uintptr) uintptr
|
||||
freeCols func(uintptr) uintptr
|
||||
}
|
||||
|
||||
// Initialize/register the FFI function pointers for the rows methods
|
||||
// DO NOT construct 'limboRows' without this function
|
||||
func initRows(ctx uintptr) *limboRows {
|
||||
var getCols func(uintptr, *uint) uintptr
|
||||
getFfiFunc(&getCols, FfiRowsGetColumns)
|
||||
var getValue func(uintptr, int32) uintptr
|
||||
getFfiFunc(&getValue, FfiRowsGetValue)
|
||||
var closeRows func(uintptr) uintptr
|
||||
getFfiFunc(&closeRows, FfiRowsClose)
|
||||
var freeCols func(uintptr) uintptr
|
||||
getFfiFunc(&freeCols, FfiFreeColumns)
|
||||
var next func(uintptr) uintptr
|
||||
getFfiFunc(&next, FfiRowsNext)
|
||||
|
||||
return &limboRows{
|
||||
ctx: ctx,
|
||||
getCols: getCols,
|
||||
getValue: getValue,
|
||||
closeRows: closeRows,
|
||||
freeCols: freeCols,
|
||||
next: next,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *limboRows) Columns() []string {
|
||||
if r.columns == nil {
|
||||
var columnCount uint
|
||||
colArrayPtr := r.getCols(r.ctx, &columnCount)
|
||||
if colArrayPtr != 0 && columnCount > 0 {
|
||||
r.columns = cArrayToGoStrings(colArrayPtr, columnCount)
|
||||
defer r.freeCols(colArrayPtr)
|
||||
}
|
||||
}
|
||||
return r.columns
|
||||
}
|
||||
|
||||
func (r *limboRows) Close() error {
|
||||
if r.closed {
|
||||
return nil
|
||||
}
|
||||
r.closed = true
|
||||
r.closeRows(r.ctx)
|
||||
r.ctx = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *limboRows) Next(dest []driver.Value) error {
|
||||
for {
|
||||
status := r.next(r.ctx)
|
||||
switch ResultCode(status) {
|
||||
case Row:
|
||||
for i := range dest {
|
||||
valPtr := r.getValue(r.ctx, int32(i))
|
||||
val := toGoValue(valPtr)
|
||||
dest[i] = val
|
||||
}
|
||||
return nil
|
||||
case Io:
|
||||
continue
|
||||
case Done:
|
||||
return io.EOF
|
||||
default:
|
||||
return fmt.Errorf("unexpected status: %d", status)
|
||||
}
|
||||
}
|
||||
}
|
||||
// func (ls *limboStmt) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
||||
// ls.mu.Lock()
|
||||
// stripped := namedValueToValue(args)
|
||||
// argArray, cleanup, err := getArgsPtr(stripped)
|
||||
// defer func() { cleanup(); ls.mu.Unlock() }()
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// select {
|
||||
// case <-ctx.Done():
|
||||
// return nil, ctx.Err()
|
||||
// default:
|
||||
// }
|
||||
// var changes uint64
|
||||
// res := stmtExec(ls.ctx, argArray, uint64(len(args)), uintptr(unsafe.Pointer(&changes)))
|
||||
// switch ResultCode(res) {
|
||||
// case Ok, Done:
|
||||
// changes := uint64(changes)
|
||||
// return driver.RowsAffected(changes), nil
|
||||
// case Error:
|
||||
// return nil, errors.New("error executing statement")
|
||||
// case Busy:
|
||||
// return nil, errors.New("busy")
|
||||
// case Interrupt:
|
||||
// return nil, errors.New("interrupted")
|
||||
// default:
|
||||
// return nil, fmt.Errorf("unexpected status: %d", res)
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// func (ls *limboStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||
// ls.mu.Lock()
|
||||
// queryArgs, allocs, err := buildNamedArgs(args)
|
||||
// defer func() { allocs(); ls.mu.Unlock() }()
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// argsPtr := uintptr(0)
|
||||
// if len(queryArgs) > 0 {
|
||||
// argsPtr = uintptr(unsafe.Pointer(&queryArgs[0]))
|
||||
// }
|
||||
// select {
|
||||
// case <-ctx.Done():
|
||||
// return nil, ctx.Err()
|
||||
// default:
|
||||
// }
|
||||
// rowsPtr := stmtQuery(ls.ctx, argsPtr, uint64(len(queryArgs)))
|
||||
// if rowsPtr == 0 {
|
||||
// return nil, fmt.Errorf("query failed for: %q", ls.sql)
|
||||
// }
|
||||
// ls.inUse++
|
||||
// return initRows(rowsPtr), nil
|
||||
// }
|
||||
|
||||
@@ -198,27 +198,17 @@ func toGoBlob(blobPtr uintptr) []byte {
|
||||
return copied
|
||||
}
|
||||
|
||||
var freeBlobFunc func(uintptr)
|
||||
|
||||
func freeBlob(blobPtr uintptr) {
|
||||
if blobPtr == 0 {
|
||||
return
|
||||
}
|
||||
if freeBlobFunc == nil {
|
||||
getFfiFunc(&freeBlobFunc, FfiFreeBlob)
|
||||
}
|
||||
freeBlobFunc(blobPtr)
|
||||
}
|
||||
|
||||
var freeStringFunc func(uintptr)
|
||||
|
||||
func freeCString(cstrPtr uintptr) {
|
||||
if cstrPtr == 0 {
|
||||
return
|
||||
}
|
||||
if freeStringFunc == nil {
|
||||
getFfiFunc(&freeStringFunc, FfiFreeCString)
|
||||
}
|
||||
freeStringFunc(cstrPtr)
|
||||
}
|
||||
|
||||
@@ -226,7 +216,6 @@ func cArrayToGoStrings(arrayPtr uintptr, length uint) []string {
|
||||
if arrayPtr == 0 || length == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
ptrSlice := unsafe.Slice(
|
||||
(**byte)(unsafe.Pointer(arrayPtr)),
|
||||
length,
|
||||
@@ -259,6 +248,7 @@ func buildArgs(args []driver.Value) ([]limboValue, func(), error) {
|
||||
case string:
|
||||
limboVal.Type = textVal
|
||||
cstr := CString(val)
|
||||
pinner.Pin(cstr)
|
||||
*(*uintptr)(unsafe.Pointer(&limboVal.Value)) = uintptr(unsafe.Pointer(cstr))
|
||||
case []byte:
|
||||
limboVal.Type = blobVal
|
||||
|
||||
Reference in New Issue
Block a user