Support blob types in query arguments for Go bindings

This commit is contained in:
PThorpe92
2025-01-29 11:16:31 -05:00
parent 4af6eb2f71
commit d9966d2dc8
6 changed files with 60 additions and 73 deletions

View File

@@ -2,6 +2,7 @@ package limbo_test
import (
"database/sql"
"fmt"
"testing"
_ "limbo"
@@ -76,7 +77,7 @@ func TestQuery(t *testing.T) {
}
defer rows.Close()
expectedCols := []string{"foo", "bar"}
expectedCols := []string{"foo", "bar", "baz"}
cols, err := rows.Columns()
if err != nil {
t.Fatalf("Error getting columns: %v", err)
@@ -93,13 +94,15 @@ func TestQuery(t *testing.T) {
for rows.Next() {
var a int
var b string
err = rows.Scan(&a, &b)
var c []byte
err = rows.Scan(&a, &b, &c)
if err != nil {
t.Fatalf("Error scanning row: %v", err)
}
if a != i || b != rowsMap[i] {
t.Fatalf("Expected %d, %s, got %d, %s", i, rowsMap[i], a, b)
if a != i || b != rowsMap[i] || string(c) != rowsMap[i] {
t.Fatalf("Expected %d, %s, got %d, %s, %b", i, rowsMap[i], a, b, c)
}
fmt.Println("RESULTS: ", a, b, string(c))
i++
}
@@ -111,7 +114,7 @@ func TestQuery(t *testing.T) {
var rowsMap = map[int]string{1: "hello", 2: "world", 3: "foo", 4: "bar", 5: "baz"}
func createTable(conn *sql.DB) error {
insert := "CREATE TABLE test (foo INT, bar TEXT);"
insert := "CREATE TABLE test (foo INT, bar TEXT, baz BLOB);"
stmt, err := conn.Prepare(insert)
if err != nil {
return err
@@ -123,13 +126,13 @@ func createTable(conn *sql.DB) error {
func insertData(conn *sql.DB) error {
for i := 1; i <= 5; i++ {
insert := "INSERT INTO test (foo, bar) VALUES (?, ?);"
insert := "INSERT INTO test (foo, bar, baz) VALUES (?, ?, ?);"
stmt, err := conn.Prepare(insert)
if err != nil {
return err
}
defer stmt.Close()
if _, err = stmt.Exec(i, rowsMap[i]); err != nil {
if _, err = stmt.Exec(i, rowsMap[i], []byte(rowsMap[i])); err != nil {
return err
}
}

View File

@@ -35,7 +35,7 @@ func loadLibrary() error {
for _, path := range paths {
libPath := filepath.Join(path, libraryName)
if _, err := os.Stat(libPath); err == nil {
slib, dlerr := purego.Dlopen(libPath, purego.RTLD_LAZY)
slib, dlerr := purego.Dlopen(libPath, purego.RTLD_NOW|purego.RTLD_GLOBAL)
if dlerr != nil {
return fmt.Errorf("failed to load library at %s: %w", libPath, dlerr)
}

View File

@@ -65,8 +65,7 @@ pub extern "C" fn rows_get_value(ctx: *mut c_void, col_idx: usize) -> *const c_v
if let Some(ref cursor) = ctx.cursor {
if let Some(value) = cursor.values.get(col_idx) {
let val = LimboValue::from_value(value);
return val.to_ptr();
return LimboValue::from_value(value).to_ptr();
}
}
std::ptr::null()

View File

@@ -30,28 +30,22 @@ pub enum ValueType {
#[repr(C)]
pub struct LimboValue {
pub value_type: ValueType,
pub value: ValueUnion,
value_type: ValueType,
value: ValueUnion,
}
#[repr(C)]
pub union ValueUnion {
pub int_val: i64,
pub real_val: f64,
pub text_ptr: *const c_char,
pub blob_ptr: *const c_void,
union ValueUnion {
int_val: i64,
real_val: f64,
text_ptr: *const c_char,
blob_ptr: *const c_void,
}
#[repr(C)]
pub struct Blob {
pub data: *const u8,
pub len: usize,
}
impl Blob {
pub fn to_ptr(&self) -> *const c_void {
self as *const Blob as *const c_void
}
struct Blob {
data: *const u8,
len: i64,
}
pub struct AllocPool {
@@ -97,12 +91,12 @@ impl ValueUnion {
}
fn from_bytes(b: &[u8]) -> Self {
let blob = Box::new(Blob {
data: b.as_ptr(),
len: b.len() as i64,
});
ValueUnion {
blob_ptr: Blob {
data: b.as_ptr(),
len: b.len(),
}
.to_ptr(),
blob_ptr: Box::into_raw(blob) as *const c_void,
}
}
@@ -140,12 +134,12 @@ impl ValueUnion {
pub fn to_bytes(&self) -> &[u8] {
let blob = unsafe { self.blob_ptr as *const Blob };
let blob = unsafe { &*blob };
unsafe { std::slice::from_raw_parts(blob.data, blob.len) }
unsafe { std::slice::from_raw_parts(blob.data, blob.len as usize) }
}
}
impl LimboValue {
pub fn new(value_type: ValueType, value: ValueUnion) -> Self {
fn new(value_type: ValueType, value: ValueUnion) -> Self {
LimboValue { value_type, value }
}
@@ -204,15 +198,9 @@ impl LimboValue {
if unsafe { self.value.blob_ptr.is_null() } {
return limbo_core::Value::Null;
}
let blob_ptr = unsafe { self.value.blob_ptr as *const Blob };
if blob_ptr.is_null() {
limbo_core::Value::Null
} else {
let blob = unsafe { &*blob_ptr };
let data = unsafe { std::slice::from_raw_parts(blob.data, blob.len) };
let borrowed = pool.add_blob(data.to_vec());
limbo_core::Value::Blob(borrowed)
}
let bytes = self.value.to_bytes();
let borrowed = pool.add_blob(bytes.to_vec());
limbo_core::Value::Blob(borrowed)
}
ValueType::Null => limbo_core::Value::Null,
}

View File

@@ -59,7 +59,8 @@ func (ls *limboStmt) Close() error {
}
func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) {
argArray, err := buildArgs(args)
argArray, cleanup, err := buildArgs(args)
defer cleanup()
if err != nil {
return nil, err
}
@@ -87,7 +88,8 @@ func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) {
}
func (st *limboStmt) Query(args []driver.Value) (driver.Rows, error) {
queryArgs, err := buildArgs(args)
queryArgs, cleanup, err := buildArgs(args)
defer cleanup()
if err != nil {
return nil, err
}
@@ -105,7 +107,8 @@ func (st *limboStmt) Query(args []driver.Value) (driver.Rows, error) {
func (ls *limboStmt) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
stripped := namedValueToValue(args)
argArray, err := getArgsPtr(stripped)
argArray, cleanup, err := getArgsPtr(stripped)
defer cleanup()
if err != nil {
return nil, err
}
@@ -132,7 +135,8 @@ func (ls *limboStmt) ExecContext(ctx context.Context, query string, args []drive
}
func (ls *limboStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
queryArgs, err := buildNamedArgs(args)
queryArgs, allocs, err := buildNamedArgs(args)
defer allocs()
if err != nil {
return nil, err
}

View File

@@ -3,6 +3,7 @@ package limbo
import (
"database/sql/driver"
"fmt"
"runtime"
"unsafe"
)
@@ -88,7 +89,7 @@ func namedValueToValue(named []driver.NamedValue) []driver.Value {
return out
}
func buildNamedArgs(named []driver.NamedValue) ([]limboValue, error) {
func buildNamedArgs(named []driver.NamedValue) ([]limboValue, func(), error) {
args := namedValueToValue(named)
return buildArgs(args)
}
@@ -123,14 +124,14 @@ func (vt valueType) String() string {
// struct to pass Go values over FFI
type limboValue struct {
Type valueType
_ [4]byte // padding to align Value to 8 bytes
_ [4]byte
Value [8]byte
}
// struct to pass byte slices over FFI
type Blob struct {
Data uintptr
Len uint
Len int64
}
// convert a limboValue to a native Go value
@@ -157,15 +158,15 @@ func toGoValue(valPtr uintptr) interface{} {
}
}
func getArgsPtr(args []driver.Value) (uintptr, error) {
func getArgsPtr(args []driver.Value) (uintptr, func(), error) {
if len(args) == 0 {
return 0, nil
return 0, nil, nil
}
argSlice, err := buildArgs(args)
argSlice, allocs, err := buildArgs(args)
if err != nil {
return 0, err
return 0, allocs, err
}
return uintptr(unsafe.Pointer(&argSlice[0])), nil
return uintptr(unsafe.Pointer(&argSlice[0])), allocs, nil
}
// convert a byte slice to a Blob type that can be sent over FFI
@@ -173,11 +174,10 @@ func makeBlob(b []byte) *Blob {
if len(b) == 0 {
return nil
}
blob := &Blob{
return &Blob{
Data: uintptr(unsafe.Pointer(&b[0])),
Len: uint(len(b)),
Len: int64(len(b)),
}
return blob
}
// converts a blob received via FFI to a native Go byte slice
@@ -186,6 +186,9 @@ func toGoBlob(blobPtr uintptr) []byte {
return nil
}
blob := (*Blob)(unsafe.Pointer(blobPtr))
if blob.Data == 0 || blob.Len == 0 {
return nil
}
return unsafe.Slice((*byte)(unsafe.Pointer(blob.Data)), blob.Len)
}
@@ -207,7 +210,8 @@ func cArrayToGoStrings(arrayPtr uintptr, length uint) []string {
}
// convert a Go slice of driver.Value to a slice of limboValue that can be sent over FFI
func buildArgs(args []driver.Value) ([]limboValue, error) {
func buildArgs(args []driver.Value) ([]limboValue, func(), error) {
pinner := new(runtime.Pinner)
argSlice := make([]limboValue, len(args))
for i, v := range args {
limboVal := limboValue{}
@@ -225,27 +229,16 @@ func buildArgs(args []driver.Value) ([]limboValue, error) {
cstr := CString(val)
*(*uintptr)(unsafe.Pointer(&limboVal.Value)) = uintptr(unsafe.Pointer(cstr))
case []byte:
argSlice[i].Type = blobVal
limboVal.Type = blobVal
blob := makeBlob(val)
pinner.Pin(blob)
*(*uintptr)(unsafe.Pointer(&limboVal.Value)) = uintptr(unsafe.Pointer(blob))
default:
return nil, fmt.Errorf("unsupported type: %T", v)
return nil, pinner.Unpin, fmt.Errorf("unsupported type: %T", v)
}
argSlice[i] = limboVal
}
return argSlice, nil
}
func storeInt64(data *[8]byte, val int64) {
*(*int64)(unsafe.Pointer(data)) = val
}
func storeFloat64(data *[8]byte, val float64) {
*(*float64)(unsafe.Pointer(data)) = val
}
func storePointer(data *[8]byte, ptr *byte) {
*(*uintptr)(unsafe.Pointer(data)) = uintptr(unsafe.Pointer(ptr))
return argSlice, pinner.Unpin, nil
}
/* Credit below (Apache2 License) to: