Merge 'bindings/go Support blob types in query arguments, free non-gc allocations' from Preston Thorpe

This PR fixes/adds support for the Blob type and adds the appropriate
tests.
Types created on the Go side will be cleaned up rather quickly if
nothing is referencing them, so this approach uses `runtime.Pinner` to
pin the bytes in memory so the pointers will be valid when Rust uses
`from_raw_parts` and then owns a new vec. They are then cleaned up after
the FFI call with `pinner.Unpin`.

Closes #822
This commit is contained in:
Pekka Enberg
2025-01-30 21:12:22 +02:00
6 changed files with 93 additions and 74 deletions

View File

@@ -2,6 +2,7 @@ package limbo_test
import (
"database/sql"
"fmt"
"testing"
_ "limbo"
@@ -76,7 +77,7 @@ func TestQuery(t *testing.T) {
}
defer rows.Close()
expectedCols := []string{"foo", "bar"}
expectedCols := []string{"foo", "bar", "baz"}
cols, err := rows.Columns()
if err != nil {
t.Fatalf("Error getting columns: %v", err)
@@ -93,13 +94,15 @@ func TestQuery(t *testing.T) {
for rows.Next() {
var a int
var b string
err = rows.Scan(&a, &b)
var c []byte
err = rows.Scan(&a, &b, &c)
if err != nil {
t.Fatalf("Error scanning row: %v", err)
}
if a != i || b != rowsMap[i] {
t.Fatalf("Expected %d, %s, got %d, %s", i, rowsMap[i], a, b)
if a != i || b != rowsMap[i] || string(c) != rowsMap[i] {
t.Fatalf("Expected %d, %s, %s, got %d, %s, %b", i, rowsMap[i], rowsMap[i], a, b, c)
}
fmt.Println("RESULTS: ", a, b, string(c))
i++
}
@@ -111,7 +114,7 @@ func TestQuery(t *testing.T) {
var rowsMap = map[int]string{1: "hello", 2: "world", 3: "foo", 4: "bar", 5: "baz"}
func createTable(conn *sql.DB) error {
insert := "CREATE TABLE test (foo INT, bar TEXT);"
insert := "CREATE TABLE test (foo INT, bar TEXT, baz BLOB);"
stmt, err := conn.Prepare(insert)
if err != nil {
return err
@@ -123,13 +126,13 @@ func createTable(conn *sql.DB) error {
func insertData(conn *sql.DB) error {
for i := 1; i <= 5; i++ {
insert := "INSERT INTO test (foo, bar) VALUES (?, ?);"
insert := "INSERT INTO test (foo, bar, baz) VALUES (?, ?, ?);"
stmt, err := conn.Prepare(insert)
if err != nil {
return err
}
defer stmt.Close()
if _, err = stmt.Exec(i, rowsMap[i]); err != nil {
if _, err = stmt.Exec(i, rowsMap[i], []byte(rowsMap[i])); err != nil {
return err
}
}

View File

@@ -35,7 +35,7 @@ func loadLibrary() error {
for _, path := range paths {
libPath := filepath.Join(path, libraryName)
if _, err := os.Stat(libPath); err == nil {
slib, dlerr := purego.Dlopen(libPath, purego.RTLD_LAZY)
slib, dlerr := purego.Dlopen(libPath, purego.RTLD_NOW|purego.RTLD_GLOBAL)
if dlerr != nil {
return fmt.Errorf("failed to load library at %s: %w", libPath, dlerr)
}

View File

@@ -65,8 +65,7 @@ pub extern "C" fn rows_get_value(ctx: *mut c_void, col_idx: usize) -> *const c_v
if let Some(ref cursor) = ctx.cursor {
if let Some(value) = cursor.values.get(col_idx) {
let val = LimboValue::from_value(value);
return val.to_ptr();
return LimboValue::from_value(value).to_ptr();
}
}
std::ptr::null()

View File

@@ -30,28 +30,22 @@ pub enum ValueType {
#[repr(C)]
pub struct LimboValue {
pub value_type: ValueType,
pub value: ValueUnion,
value_type: ValueType,
value: ValueUnion,
}
#[repr(C)]
pub union ValueUnion {
pub int_val: i64,
pub real_val: f64,
pub text_ptr: *const c_char,
pub blob_ptr: *const c_void,
union ValueUnion {
int_val: i64,
real_val: f64,
text_ptr: *const c_char,
blob_ptr: *const c_void,
}
#[repr(C)]
pub struct Blob {
pub data: *const u8,
pub len: usize,
}
impl Blob {
pub fn to_ptr(&self) -> *const c_void {
self as *const Blob as *const c_void
}
struct Blob {
data: *const u8,
len: i64,
}
pub struct AllocPool {
@@ -97,12 +91,12 @@ impl ValueUnion {
}
fn from_bytes(b: &[u8]) -> Self {
let blob = Box::new(Blob {
data: b.as_ptr(),
len: b.len() as i64,
});
ValueUnion {
blob_ptr: Blob {
data: b.as_ptr(),
len: b.len(),
}
.to_ptr(),
blob_ptr: Box::into_raw(blob) as *const c_void,
}
}
@@ -140,12 +134,12 @@ impl ValueUnion {
pub fn to_bytes(&self) -> &[u8] {
let blob = unsafe { self.blob_ptr as *const Blob };
let blob = unsafe { &*blob };
unsafe { std::slice::from_raw_parts(blob.data, blob.len) }
unsafe { std::slice::from_raw_parts(blob.data, blob.len as usize) }
}
}
impl LimboValue {
pub fn new(value_type: ValueType, value: ValueUnion) -> Self {
fn new(value_type: ValueType, value: ValueUnion) -> Self {
LimboValue { value_type, value }
}
@@ -204,15 +198,9 @@ impl LimboValue {
if unsafe { self.value.blob_ptr.is_null() } {
return limbo_core::Value::Null;
}
let blob_ptr = unsafe { self.value.blob_ptr as *const Blob };
if blob_ptr.is_null() {
limbo_core::Value::Null
} else {
let blob = unsafe { &*blob_ptr };
let data = unsafe { std::slice::from_raw_parts(blob.data, blob.len) };
let borrowed = pool.add_blob(data.to_vec());
limbo_core::Value::Blob(borrowed)
}
let bytes = self.value.to_bytes();
let borrowed = pool.add_blob(bytes.to_vec());
limbo_core::Value::Blob(borrowed)
}
ValueType::Null => limbo_core::Value::Null,
}

View File

@@ -59,7 +59,8 @@ func (ls *limboStmt) Close() error {
}
func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) {
argArray, err := buildArgs(args)
argArray, cleanup, err := buildArgs(args)
defer cleanup()
if err != nil {
return nil, err
}
@@ -87,7 +88,8 @@ func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) {
}
func (st *limboStmt) Query(args []driver.Value) (driver.Rows, error) {
queryArgs, err := buildArgs(args)
queryArgs, cleanup, err := buildArgs(args)
defer cleanup()
if err != nil {
return nil, err
}
@@ -105,7 +107,8 @@ func (st *limboStmt) Query(args []driver.Value) (driver.Rows, error) {
func (ls *limboStmt) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
stripped := namedValueToValue(args)
argArray, err := getArgsPtr(stripped)
argArray, cleanup, err := getArgsPtr(stripped)
defer cleanup()
if err != nil {
return nil, err
}
@@ -132,7 +135,8 @@ func (ls *limboStmt) ExecContext(ctx context.Context, query string, args []drive
}
func (ls *limboStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
queryArgs, err := buildNamedArgs(args)
queryArgs, allocs, err := buildNamedArgs(args)
defer allocs()
if err != nil {
return nil, err
}

View File

@@ -3,6 +3,7 @@ package limbo
import (
"database/sql/driver"
"fmt"
"runtime"
"unsafe"
)
@@ -77,6 +78,7 @@ const (
FfiRowsGetValue string = "rows_get_value"
FfiFreeColumns string = "free_columns"
FfiFreeCString string = "free_string"
FfiFreeBlob string = "free_blob"
)
// convert a namedValue slice into normal values until named parameters are supported
@@ -88,7 +90,7 @@ func namedValueToValue(named []driver.NamedValue) []driver.Value {
return out
}
func buildNamedArgs(named []driver.NamedValue) ([]limboValue, error) {
func buildNamedArgs(named []driver.NamedValue) ([]limboValue, func(), error) {
args := namedValueToValue(named)
return buildArgs(args)
}
@@ -123,14 +125,14 @@ func (vt valueType) String() string {
// struct to pass Go values over FFI
type limboValue struct {
Type valueType
_ [4]byte // padding to align Value to 8 bytes
_ [4]byte
Value [8]byte
}
// struct to pass byte slices over FFI
type Blob struct {
Data uintptr
Len uint
Len int64
}
// convert a limboValue to a native Go value
@@ -146,9 +148,11 @@ func toGoValue(valPtr uintptr) interface{} {
return *(*float64)(unsafe.Pointer(&val.Value))
case textVal:
textPtr := *(*uintptr)(unsafe.Pointer(&val.Value))
defer freeCString(textPtr)
return GoString(textPtr)
case blobVal:
blobPtr := *(*uintptr)(unsafe.Pointer(&val.Value))
defer freeBlob(blobPtr)
return toGoBlob(blobPtr)
case nullVal:
return nil
@@ -157,15 +161,15 @@ func toGoValue(valPtr uintptr) interface{} {
}
}
func getArgsPtr(args []driver.Value) (uintptr, error) {
func getArgsPtr(args []driver.Value) (uintptr, func(), error) {
if len(args) == 0 {
return 0, nil
return 0, nil, nil
}
argSlice, err := buildArgs(args)
argSlice, allocs, err := buildArgs(args)
if err != nil {
return 0, err
return 0, allocs, err
}
return uintptr(unsafe.Pointer(&argSlice[0])), nil
return uintptr(unsafe.Pointer(&argSlice[0])), allocs, nil
}
// convert a byte slice to a Blob type that can be sent over FFI
@@ -173,11 +177,10 @@ func makeBlob(b []byte) *Blob {
if len(b) == 0 {
return nil
}
blob := &Blob{
return &Blob{
Data: uintptr(unsafe.Pointer(&b[0])),
Len: uint(len(b)),
Len: int64(len(b)),
}
return blob
}
// converts a blob received via FFI to a native Go byte slice
@@ -186,7 +189,37 @@ func toGoBlob(blobPtr uintptr) []byte {
return nil
}
blob := (*Blob)(unsafe.Pointer(blobPtr))
return unsafe.Slice((*byte)(unsafe.Pointer(blob.Data)), blob.Len)
if blob.Data == 0 || blob.Len == 0 {
return nil
}
data := unsafe.Slice((*byte)(unsafe.Pointer(blob.Data)), blob.Len)
copied := make([]byte, len(data))
copy(copied, data)
return copied
}
var freeBlobFunc func(uintptr)
func freeBlob(blobPtr uintptr) {
if blobPtr == 0 {
return
}
if freeBlobFunc == nil {
getFfiFunc(&freeBlobFunc, FfiFreeBlob)
}
freeBlobFunc(blobPtr)
}
var freeStringFunc func(uintptr)
func freeCString(cstrPtr uintptr) {
if cstrPtr == 0 {
return
}
if freeStringFunc == nil {
getFfiFunc(&freeStringFunc, FfiFreeCString)
}
freeStringFunc(cstrPtr)
}
func cArrayToGoStrings(arrayPtr uintptr, length uint) []string {
@@ -207,7 +240,10 @@ func cArrayToGoStrings(arrayPtr uintptr, length uint) []string {
}
// convert a Go slice of driver.Value to a slice of limboValue that can be sent over FFI
func buildArgs(args []driver.Value) ([]limboValue, error) {
// for Blob types, we have to pin them so they are not garbage collected before they can be copied
// into a buffer on the Rust side, so we return a function to unpin them that can be deferred after this call
func buildArgs(args []driver.Value) ([]limboValue, func(), error) {
pinner := new(runtime.Pinner)
argSlice := make([]limboValue, len(args))
for i, v := range args {
limboVal := limboValue{}
@@ -225,27 +261,16 @@ func buildArgs(args []driver.Value) ([]limboValue, error) {
cstr := CString(val)
*(*uintptr)(unsafe.Pointer(&limboVal.Value)) = uintptr(unsafe.Pointer(cstr))
case []byte:
argSlice[i].Type = blobVal
limboVal.Type = blobVal
blob := makeBlob(val)
pinner.Pin(blob)
*(*uintptr)(unsafe.Pointer(&limboVal.Value)) = uintptr(unsafe.Pointer(blob))
default:
return nil, fmt.Errorf("unsupported type: %T", v)
return nil, pinner.Unpin, fmt.Errorf("unsupported type: %T", v)
}
argSlice[i] = limboVal
}
return argSlice, nil
}
func storeInt64(data *[8]byte, val int64) {
*(*int64)(unsafe.Pointer(data)) = val
}
func storeFloat64(data *[8]byte, val float64) {
*(*float64)(unsafe.Pointer(data)) = val
}
func storePointer(data *[8]byte, ptr *byte) {
*(*uintptr)(unsafe.Pointer(data)) = uintptr(unsafe.Pointer(ptr))
return argSlice, pinner.Unpin, nil
}
/* Credit below (Apache2 License) to: