mirror of
https://github.com/aljazceru/turso.git
synced 2026-02-15 21:14:21 +01:00
Support blob types in query arguments for Go bindings
This commit is contained in:
@@ -2,6 +2,7 @@ package limbo_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
_ "limbo"
|
||||
@@ -76,7 +77,7 @@ func TestQuery(t *testing.T) {
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
expectedCols := []string{"foo", "bar"}
|
||||
expectedCols := []string{"foo", "bar", "baz"}
|
||||
cols, err := rows.Columns()
|
||||
if err != nil {
|
||||
t.Fatalf("Error getting columns: %v", err)
|
||||
@@ -93,13 +94,15 @@ func TestQuery(t *testing.T) {
|
||||
for rows.Next() {
|
||||
var a int
|
||||
var b string
|
||||
err = rows.Scan(&a, &b)
|
||||
var c []byte
|
||||
err = rows.Scan(&a, &b, &c)
|
||||
if err != nil {
|
||||
t.Fatalf("Error scanning row: %v", err)
|
||||
}
|
||||
if a != i || b != rowsMap[i] {
|
||||
t.Fatalf("Expected %d, %s, got %d, %s", i, rowsMap[i], a, b)
|
||||
if a != i || b != rowsMap[i] || string(c) != rowsMap[i] {
|
||||
t.Fatalf("Expected %d, %s, got %d, %s, %b", i, rowsMap[i], a, b, c)
|
||||
}
|
||||
fmt.Println("RESULTS: ", a, b, string(c))
|
||||
i++
|
||||
}
|
||||
|
||||
@@ -111,7 +114,7 @@ func TestQuery(t *testing.T) {
|
||||
var rowsMap = map[int]string{1: "hello", 2: "world", 3: "foo", 4: "bar", 5: "baz"}
|
||||
|
||||
func createTable(conn *sql.DB) error {
|
||||
insert := "CREATE TABLE test (foo INT, bar TEXT);"
|
||||
insert := "CREATE TABLE test (foo INT, bar TEXT, baz BLOB);"
|
||||
stmt, err := conn.Prepare(insert)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -123,13 +126,13 @@ func createTable(conn *sql.DB) error {
|
||||
|
||||
func insertData(conn *sql.DB) error {
|
||||
for i := 1; i <= 5; i++ {
|
||||
insert := "INSERT INTO test (foo, bar) VALUES (?, ?);"
|
||||
insert := "INSERT INTO test (foo, bar, baz) VALUES (?, ?, ?);"
|
||||
stmt, err := conn.Prepare(insert)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err = stmt.Exec(i, rowsMap[i]); err != nil {
|
||||
if _, err = stmt.Exec(i, rowsMap[i], []byte(rowsMap[i])); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,7 +35,7 @@ func loadLibrary() error {
|
||||
for _, path := range paths {
|
||||
libPath := filepath.Join(path, libraryName)
|
||||
if _, err := os.Stat(libPath); err == nil {
|
||||
slib, dlerr := purego.Dlopen(libPath, purego.RTLD_LAZY)
|
||||
slib, dlerr := purego.Dlopen(libPath, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if dlerr != nil {
|
||||
return fmt.Errorf("failed to load library at %s: %w", libPath, dlerr)
|
||||
}
|
||||
|
||||
@@ -65,8 +65,7 @@ pub extern "C" fn rows_get_value(ctx: *mut c_void, col_idx: usize) -> *const c_v
|
||||
|
||||
if let Some(ref cursor) = ctx.cursor {
|
||||
if let Some(value) = cursor.values.get(col_idx) {
|
||||
let val = LimboValue::from_value(value);
|
||||
return val.to_ptr();
|
||||
return LimboValue::from_value(value).to_ptr();
|
||||
}
|
||||
}
|
||||
std::ptr::null()
|
||||
|
||||
@@ -30,28 +30,22 @@ pub enum ValueType {
|
||||
|
||||
#[repr(C)]
|
||||
pub struct LimboValue {
|
||||
pub value_type: ValueType,
|
||||
pub value: ValueUnion,
|
||||
value_type: ValueType,
|
||||
value: ValueUnion,
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
pub union ValueUnion {
|
||||
pub int_val: i64,
|
||||
pub real_val: f64,
|
||||
pub text_ptr: *const c_char,
|
||||
pub blob_ptr: *const c_void,
|
||||
union ValueUnion {
|
||||
int_val: i64,
|
||||
real_val: f64,
|
||||
text_ptr: *const c_char,
|
||||
blob_ptr: *const c_void,
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
pub struct Blob {
|
||||
pub data: *const u8,
|
||||
pub len: usize,
|
||||
}
|
||||
|
||||
impl Blob {
|
||||
pub fn to_ptr(&self) -> *const c_void {
|
||||
self as *const Blob as *const c_void
|
||||
}
|
||||
struct Blob {
|
||||
data: *const u8,
|
||||
len: i64,
|
||||
}
|
||||
|
||||
pub struct AllocPool {
|
||||
@@ -97,12 +91,12 @@ impl ValueUnion {
|
||||
}
|
||||
|
||||
fn from_bytes(b: &[u8]) -> Self {
|
||||
let blob = Box::new(Blob {
|
||||
data: b.as_ptr(),
|
||||
len: b.len() as i64,
|
||||
});
|
||||
ValueUnion {
|
||||
blob_ptr: Blob {
|
||||
data: b.as_ptr(),
|
||||
len: b.len(),
|
||||
}
|
||||
.to_ptr(),
|
||||
blob_ptr: Box::into_raw(blob) as *const c_void,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -140,12 +134,12 @@ impl ValueUnion {
|
||||
pub fn to_bytes(&self) -> &[u8] {
|
||||
let blob = unsafe { self.blob_ptr as *const Blob };
|
||||
let blob = unsafe { &*blob };
|
||||
unsafe { std::slice::from_raw_parts(blob.data, blob.len) }
|
||||
unsafe { std::slice::from_raw_parts(blob.data, blob.len as usize) }
|
||||
}
|
||||
}
|
||||
|
||||
impl LimboValue {
|
||||
pub fn new(value_type: ValueType, value: ValueUnion) -> Self {
|
||||
fn new(value_type: ValueType, value: ValueUnion) -> Self {
|
||||
LimboValue { value_type, value }
|
||||
}
|
||||
|
||||
@@ -204,15 +198,9 @@ impl LimboValue {
|
||||
if unsafe { self.value.blob_ptr.is_null() } {
|
||||
return limbo_core::Value::Null;
|
||||
}
|
||||
let blob_ptr = unsafe { self.value.blob_ptr as *const Blob };
|
||||
if blob_ptr.is_null() {
|
||||
limbo_core::Value::Null
|
||||
} else {
|
||||
let blob = unsafe { &*blob_ptr };
|
||||
let data = unsafe { std::slice::from_raw_parts(blob.data, blob.len) };
|
||||
let borrowed = pool.add_blob(data.to_vec());
|
||||
limbo_core::Value::Blob(borrowed)
|
||||
}
|
||||
let bytes = self.value.to_bytes();
|
||||
let borrowed = pool.add_blob(bytes.to_vec());
|
||||
limbo_core::Value::Blob(borrowed)
|
||||
}
|
||||
ValueType::Null => limbo_core::Value::Null,
|
||||
}
|
||||
|
||||
@@ -59,7 +59,8 @@ func (ls *limboStmt) Close() error {
|
||||
}
|
||||
|
||||
func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
argArray, err := buildArgs(args)
|
||||
argArray, cleanup, err := buildArgs(args)
|
||||
defer cleanup()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -87,7 +88,8 @@ func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
}
|
||||
|
||||
func (st *limboStmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||
queryArgs, err := buildArgs(args)
|
||||
queryArgs, cleanup, err := buildArgs(args)
|
||||
defer cleanup()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -105,7 +107,8 @@ func (st *limboStmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||
|
||||
func (ls *limboStmt) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
||||
stripped := namedValueToValue(args)
|
||||
argArray, err := getArgsPtr(stripped)
|
||||
argArray, cleanup, err := getArgsPtr(stripped)
|
||||
defer cleanup()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -132,7 +135,8 @@ func (ls *limboStmt) ExecContext(ctx context.Context, query string, args []drive
|
||||
}
|
||||
|
||||
func (ls *limboStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||
queryArgs, err := buildNamedArgs(args)
|
||||
queryArgs, allocs, err := buildNamedArgs(args)
|
||||
defer allocs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package limbo
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
@@ -88,7 +89,7 @@ func namedValueToValue(named []driver.NamedValue) []driver.Value {
|
||||
return out
|
||||
}
|
||||
|
||||
func buildNamedArgs(named []driver.NamedValue) ([]limboValue, error) {
|
||||
func buildNamedArgs(named []driver.NamedValue) ([]limboValue, func(), error) {
|
||||
args := namedValueToValue(named)
|
||||
return buildArgs(args)
|
||||
}
|
||||
@@ -123,14 +124,14 @@ func (vt valueType) String() string {
|
||||
// struct to pass Go values over FFI
|
||||
type limboValue struct {
|
||||
Type valueType
|
||||
_ [4]byte // padding to align Value to 8 bytes
|
||||
_ [4]byte
|
||||
Value [8]byte
|
||||
}
|
||||
|
||||
// struct to pass byte slices over FFI
|
||||
type Blob struct {
|
||||
Data uintptr
|
||||
Len uint
|
||||
Len int64
|
||||
}
|
||||
|
||||
// convert a limboValue to a native Go value
|
||||
@@ -157,15 +158,15 @@ func toGoValue(valPtr uintptr) interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func getArgsPtr(args []driver.Value) (uintptr, error) {
|
||||
func getArgsPtr(args []driver.Value) (uintptr, func(), error) {
|
||||
if len(args) == 0 {
|
||||
return 0, nil
|
||||
return 0, nil, nil
|
||||
}
|
||||
argSlice, err := buildArgs(args)
|
||||
argSlice, allocs, err := buildArgs(args)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return 0, allocs, err
|
||||
}
|
||||
return uintptr(unsafe.Pointer(&argSlice[0])), nil
|
||||
return uintptr(unsafe.Pointer(&argSlice[0])), allocs, nil
|
||||
}
|
||||
|
||||
// convert a byte slice to a Blob type that can be sent over FFI
|
||||
@@ -173,11 +174,10 @@ func makeBlob(b []byte) *Blob {
|
||||
if len(b) == 0 {
|
||||
return nil
|
||||
}
|
||||
blob := &Blob{
|
||||
return &Blob{
|
||||
Data: uintptr(unsafe.Pointer(&b[0])),
|
||||
Len: uint(len(b)),
|
||||
Len: int64(len(b)),
|
||||
}
|
||||
return blob
|
||||
}
|
||||
|
||||
// converts a blob received via FFI to a native Go byte slice
|
||||
@@ -186,6 +186,9 @@ func toGoBlob(blobPtr uintptr) []byte {
|
||||
return nil
|
||||
}
|
||||
blob := (*Blob)(unsafe.Pointer(blobPtr))
|
||||
if blob.Data == 0 || blob.Len == 0 {
|
||||
return nil
|
||||
}
|
||||
return unsafe.Slice((*byte)(unsafe.Pointer(blob.Data)), blob.Len)
|
||||
}
|
||||
|
||||
@@ -207,7 +210,8 @@ func cArrayToGoStrings(arrayPtr uintptr, length uint) []string {
|
||||
}
|
||||
|
||||
// convert a Go slice of driver.Value to a slice of limboValue that can be sent over FFI
|
||||
func buildArgs(args []driver.Value) ([]limboValue, error) {
|
||||
func buildArgs(args []driver.Value) ([]limboValue, func(), error) {
|
||||
pinner := new(runtime.Pinner)
|
||||
argSlice := make([]limboValue, len(args))
|
||||
for i, v := range args {
|
||||
limboVal := limboValue{}
|
||||
@@ -225,27 +229,16 @@ func buildArgs(args []driver.Value) ([]limboValue, error) {
|
||||
cstr := CString(val)
|
||||
*(*uintptr)(unsafe.Pointer(&limboVal.Value)) = uintptr(unsafe.Pointer(cstr))
|
||||
case []byte:
|
||||
argSlice[i].Type = blobVal
|
||||
limboVal.Type = blobVal
|
||||
blob := makeBlob(val)
|
||||
pinner.Pin(blob)
|
||||
*(*uintptr)(unsafe.Pointer(&limboVal.Value)) = uintptr(unsafe.Pointer(blob))
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported type: %T", v)
|
||||
return nil, pinner.Unpin, fmt.Errorf("unsupported type: %T", v)
|
||||
}
|
||||
argSlice[i] = limboVal
|
||||
}
|
||||
return argSlice, nil
|
||||
}
|
||||
|
||||
func storeInt64(data *[8]byte, val int64) {
|
||||
*(*int64)(unsafe.Pointer(data)) = val
|
||||
}
|
||||
|
||||
func storeFloat64(data *[8]byte, val float64) {
|
||||
*(*float64)(unsafe.Pointer(data)) = val
|
||||
}
|
||||
|
||||
func storePointer(data *[8]byte, ptr *byte) {
|
||||
*(*uintptr)(unsafe.Pointer(data)) = uintptr(unsafe.Pointer(ptr))
|
||||
return argSlice, pinner.Unpin, nil
|
||||
}
|
||||
|
||||
/* Credit below (Apache2 License) to:
|
||||
|
||||
Reference in New Issue
Block a user