mirror of
https://github.com/aljazceru/turso.git
synced 2026-01-04 17:04:18 +01:00
Merge 'bindings/go Support blob types in query arguments, free non-gc allocations' from Preston Thorpe
This PR fixes/adds support for the Blob type and adds the appropriate tests. Types created on the Go side will be cleaned up rather quickly if nothing is referencing them, so this approach uses `runtime.Pinner` to pin the bytes in memory so the pointers will be valid when Rust uses `from_raw_parts` and then owns a new vec. They are then cleaned up after the FFI call with `pinner.Unpin`. Closes #822
This commit is contained in:
@@ -2,6 +2,7 @@ package limbo_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
_ "limbo"
|
||||
@@ -76,7 +77,7 @@ func TestQuery(t *testing.T) {
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
expectedCols := []string{"foo", "bar"}
|
||||
expectedCols := []string{"foo", "bar", "baz"}
|
||||
cols, err := rows.Columns()
|
||||
if err != nil {
|
||||
t.Fatalf("Error getting columns: %v", err)
|
||||
@@ -93,13 +94,15 @@ func TestQuery(t *testing.T) {
|
||||
for rows.Next() {
|
||||
var a int
|
||||
var b string
|
||||
err = rows.Scan(&a, &b)
|
||||
var c []byte
|
||||
err = rows.Scan(&a, &b, &c)
|
||||
if err != nil {
|
||||
t.Fatalf("Error scanning row: %v", err)
|
||||
}
|
||||
if a != i || b != rowsMap[i] {
|
||||
t.Fatalf("Expected %d, %s, got %d, %s", i, rowsMap[i], a, b)
|
||||
if a != i || b != rowsMap[i] || string(c) != rowsMap[i] {
|
||||
t.Fatalf("Expected %d, %s, %s, got %d, %s, %b", i, rowsMap[i], rowsMap[i], a, b, c)
|
||||
}
|
||||
fmt.Println("RESULTS: ", a, b, string(c))
|
||||
i++
|
||||
}
|
||||
|
||||
@@ -111,7 +114,7 @@ func TestQuery(t *testing.T) {
|
||||
var rowsMap = map[int]string{1: "hello", 2: "world", 3: "foo", 4: "bar", 5: "baz"}
|
||||
|
||||
func createTable(conn *sql.DB) error {
|
||||
insert := "CREATE TABLE test (foo INT, bar TEXT);"
|
||||
insert := "CREATE TABLE test (foo INT, bar TEXT, baz BLOB);"
|
||||
stmt, err := conn.Prepare(insert)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -123,13 +126,13 @@ func createTable(conn *sql.DB) error {
|
||||
|
||||
func insertData(conn *sql.DB) error {
|
||||
for i := 1; i <= 5; i++ {
|
||||
insert := "INSERT INTO test (foo, bar) VALUES (?, ?);"
|
||||
insert := "INSERT INTO test (foo, bar, baz) VALUES (?, ?, ?);"
|
||||
stmt, err := conn.Prepare(insert)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err = stmt.Exec(i, rowsMap[i]); err != nil {
|
||||
if _, err = stmt.Exec(i, rowsMap[i], []byte(rowsMap[i])); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,7 +35,7 @@ func loadLibrary() error {
|
||||
for _, path := range paths {
|
||||
libPath := filepath.Join(path, libraryName)
|
||||
if _, err := os.Stat(libPath); err == nil {
|
||||
slib, dlerr := purego.Dlopen(libPath, purego.RTLD_LAZY)
|
||||
slib, dlerr := purego.Dlopen(libPath, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if dlerr != nil {
|
||||
return fmt.Errorf("failed to load library at %s: %w", libPath, dlerr)
|
||||
}
|
||||
|
||||
@@ -65,8 +65,7 @@ pub extern "C" fn rows_get_value(ctx: *mut c_void, col_idx: usize) -> *const c_v
|
||||
|
||||
if let Some(ref cursor) = ctx.cursor {
|
||||
if let Some(value) = cursor.values.get(col_idx) {
|
||||
let val = LimboValue::from_value(value);
|
||||
return val.to_ptr();
|
||||
return LimboValue::from_value(value).to_ptr();
|
||||
}
|
||||
}
|
||||
std::ptr::null()
|
||||
|
||||
@@ -30,28 +30,22 @@ pub enum ValueType {
|
||||
|
||||
#[repr(C)]
|
||||
pub struct LimboValue {
|
||||
pub value_type: ValueType,
|
||||
pub value: ValueUnion,
|
||||
value_type: ValueType,
|
||||
value: ValueUnion,
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
pub union ValueUnion {
|
||||
pub int_val: i64,
|
||||
pub real_val: f64,
|
||||
pub text_ptr: *const c_char,
|
||||
pub blob_ptr: *const c_void,
|
||||
union ValueUnion {
|
||||
int_val: i64,
|
||||
real_val: f64,
|
||||
text_ptr: *const c_char,
|
||||
blob_ptr: *const c_void,
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
pub struct Blob {
|
||||
pub data: *const u8,
|
||||
pub len: usize,
|
||||
}
|
||||
|
||||
impl Blob {
|
||||
pub fn to_ptr(&self) -> *const c_void {
|
||||
self as *const Blob as *const c_void
|
||||
}
|
||||
struct Blob {
|
||||
data: *const u8,
|
||||
len: i64,
|
||||
}
|
||||
|
||||
pub struct AllocPool {
|
||||
@@ -97,12 +91,12 @@ impl ValueUnion {
|
||||
}
|
||||
|
||||
fn from_bytes(b: &[u8]) -> Self {
|
||||
let blob = Box::new(Blob {
|
||||
data: b.as_ptr(),
|
||||
len: b.len() as i64,
|
||||
});
|
||||
ValueUnion {
|
||||
blob_ptr: Blob {
|
||||
data: b.as_ptr(),
|
||||
len: b.len(),
|
||||
}
|
||||
.to_ptr(),
|
||||
blob_ptr: Box::into_raw(blob) as *const c_void,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -140,12 +134,12 @@ impl ValueUnion {
|
||||
pub fn to_bytes(&self) -> &[u8] {
|
||||
let blob = unsafe { self.blob_ptr as *const Blob };
|
||||
let blob = unsafe { &*blob };
|
||||
unsafe { std::slice::from_raw_parts(blob.data, blob.len) }
|
||||
unsafe { std::slice::from_raw_parts(blob.data, blob.len as usize) }
|
||||
}
|
||||
}
|
||||
|
||||
impl LimboValue {
|
||||
pub fn new(value_type: ValueType, value: ValueUnion) -> Self {
|
||||
fn new(value_type: ValueType, value: ValueUnion) -> Self {
|
||||
LimboValue { value_type, value }
|
||||
}
|
||||
|
||||
@@ -204,15 +198,9 @@ impl LimboValue {
|
||||
if unsafe { self.value.blob_ptr.is_null() } {
|
||||
return limbo_core::Value::Null;
|
||||
}
|
||||
let blob_ptr = unsafe { self.value.blob_ptr as *const Blob };
|
||||
if blob_ptr.is_null() {
|
||||
limbo_core::Value::Null
|
||||
} else {
|
||||
let blob = unsafe { &*blob_ptr };
|
||||
let data = unsafe { std::slice::from_raw_parts(blob.data, blob.len) };
|
||||
let borrowed = pool.add_blob(data.to_vec());
|
||||
limbo_core::Value::Blob(borrowed)
|
||||
}
|
||||
let bytes = self.value.to_bytes();
|
||||
let borrowed = pool.add_blob(bytes.to_vec());
|
||||
limbo_core::Value::Blob(borrowed)
|
||||
}
|
||||
ValueType::Null => limbo_core::Value::Null,
|
||||
}
|
||||
|
||||
@@ -59,7 +59,8 @@ func (ls *limboStmt) Close() error {
|
||||
}
|
||||
|
||||
func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
argArray, err := buildArgs(args)
|
||||
argArray, cleanup, err := buildArgs(args)
|
||||
defer cleanup()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -87,7 +88,8 @@ func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
}
|
||||
|
||||
func (st *limboStmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||
queryArgs, err := buildArgs(args)
|
||||
queryArgs, cleanup, err := buildArgs(args)
|
||||
defer cleanup()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -105,7 +107,8 @@ func (st *limboStmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||
|
||||
func (ls *limboStmt) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
||||
stripped := namedValueToValue(args)
|
||||
argArray, err := getArgsPtr(stripped)
|
||||
argArray, cleanup, err := getArgsPtr(stripped)
|
||||
defer cleanup()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -132,7 +135,8 @@ func (ls *limboStmt) ExecContext(ctx context.Context, query string, args []drive
|
||||
}
|
||||
|
||||
func (ls *limboStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||
queryArgs, err := buildNamedArgs(args)
|
||||
queryArgs, allocs, err := buildNamedArgs(args)
|
||||
defer allocs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package limbo
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
@@ -77,6 +78,7 @@ const (
|
||||
FfiRowsGetValue string = "rows_get_value"
|
||||
FfiFreeColumns string = "free_columns"
|
||||
FfiFreeCString string = "free_string"
|
||||
FfiFreeBlob string = "free_blob"
|
||||
)
|
||||
|
||||
// convert a namedValue slice into normal values until named parameters are supported
|
||||
@@ -88,7 +90,7 @@ func namedValueToValue(named []driver.NamedValue) []driver.Value {
|
||||
return out
|
||||
}
|
||||
|
||||
func buildNamedArgs(named []driver.NamedValue) ([]limboValue, error) {
|
||||
func buildNamedArgs(named []driver.NamedValue) ([]limboValue, func(), error) {
|
||||
args := namedValueToValue(named)
|
||||
return buildArgs(args)
|
||||
}
|
||||
@@ -123,14 +125,14 @@ func (vt valueType) String() string {
|
||||
// struct to pass Go values over FFI
|
||||
type limboValue struct {
|
||||
Type valueType
|
||||
_ [4]byte // padding to align Value to 8 bytes
|
||||
_ [4]byte
|
||||
Value [8]byte
|
||||
}
|
||||
|
||||
// struct to pass byte slices over FFI
|
||||
type Blob struct {
|
||||
Data uintptr
|
||||
Len uint
|
||||
Len int64
|
||||
}
|
||||
|
||||
// convert a limboValue to a native Go value
|
||||
@@ -146,9 +148,11 @@ func toGoValue(valPtr uintptr) interface{} {
|
||||
return *(*float64)(unsafe.Pointer(&val.Value))
|
||||
case textVal:
|
||||
textPtr := *(*uintptr)(unsafe.Pointer(&val.Value))
|
||||
defer freeCString(textPtr)
|
||||
return GoString(textPtr)
|
||||
case blobVal:
|
||||
blobPtr := *(*uintptr)(unsafe.Pointer(&val.Value))
|
||||
defer freeBlob(blobPtr)
|
||||
return toGoBlob(blobPtr)
|
||||
case nullVal:
|
||||
return nil
|
||||
@@ -157,15 +161,15 @@ func toGoValue(valPtr uintptr) interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func getArgsPtr(args []driver.Value) (uintptr, error) {
|
||||
func getArgsPtr(args []driver.Value) (uintptr, func(), error) {
|
||||
if len(args) == 0 {
|
||||
return 0, nil
|
||||
return 0, nil, nil
|
||||
}
|
||||
argSlice, err := buildArgs(args)
|
||||
argSlice, allocs, err := buildArgs(args)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return 0, allocs, err
|
||||
}
|
||||
return uintptr(unsafe.Pointer(&argSlice[0])), nil
|
||||
return uintptr(unsafe.Pointer(&argSlice[0])), allocs, nil
|
||||
}
|
||||
|
||||
// convert a byte slice to a Blob type that can be sent over FFI
|
||||
@@ -173,11 +177,10 @@ func makeBlob(b []byte) *Blob {
|
||||
if len(b) == 0 {
|
||||
return nil
|
||||
}
|
||||
blob := &Blob{
|
||||
return &Blob{
|
||||
Data: uintptr(unsafe.Pointer(&b[0])),
|
||||
Len: uint(len(b)),
|
||||
Len: int64(len(b)),
|
||||
}
|
||||
return blob
|
||||
}
|
||||
|
||||
// converts a blob received via FFI to a native Go byte slice
|
||||
@@ -186,7 +189,37 @@ func toGoBlob(blobPtr uintptr) []byte {
|
||||
return nil
|
||||
}
|
||||
blob := (*Blob)(unsafe.Pointer(blobPtr))
|
||||
return unsafe.Slice((*byte)(unsafe.Pointer(blob.Data)), blob.Len)
|
||||
if blob.Data == 0 || blob.Len == 0 {
|
||||
return nil
|
||||
}
|
||||
data := unsafe.Slice((*byte)(unsafe.Pointer(blob.Data)), blob.Len)
|
||||
copied := make([]byte, len(data))
|
||||
copy(copied, data)
|
||||
return copied
|
||||
}
|
||||
|
||||
var freeBlobFunc func(uintptr)
|
||||
|
||||
func freeBlob(blobPtr uintptr) {
|
||||
if blobPtr == 0 {
|
||||
return
|
||||
}
|
||||
if freeBlobFunc == nil {
|
||||
getFfiFunc(&freeBlobFunc, FfiFreeBlob)
|
||||
}
|
||||
freeBlobFunc(blobPtr)
|
||||
}
|
||||
|
||||
var freeStringFunc func(uintptr)
|
||||
|
||||
func freeCString(cstrPtr uintptr) {
|
||||
if cstrPtr == 0 {
|
||||
return
|
||||
}
|
||||
if freeStringFunc == nil {
|
||||
getFfiFunc(&freeStringFunc, FfiFreeCString)
|
||||
}
|
||||
freeStringFunc(cstrPtr)
|
||||
}
|
||||
|
||||
func cArrayToGoStrings(arrayPtr uintptr, length uint) []string {
|
||||
@@ -207,7 +240,10 @@ func cArrayToGoStrings(arrayPtr uintptr, length uint) []string {
|
||||
}
|
||||
|
||||
// convert a Go slice of driver.Value to a slice of limboValue that can be sent over FFI
|
||||
func buildArgs(args []driver.Value) ([]limboValue, error) {
|
||||
// for Blob types, we have to pin them so they are not garbage collected before they can be copied
|
||||
// into a buffer on the Rust side, so we return a function to unpin them that can be deferred after this call
|
||||
func buildArgs(args []driver.Value) ([]limboValue, func(), error) {
|
||||
pinner := new(runtime.Pinner)
|
||||
argSlice := make([]limboValue, len(args))
|
||||
for i, v := range args {
|
||||
limboVal := limboValue{}
|
||||
@@ -225,27 +261,16 @@ func buildArgs(args []driver.Value) ([]limboValue, error) {
|
||||
cstr := CString(val)
|
||||
*(*uintptr)(unsafe.Pointer(&limboVal.Value)) = uintptr(unsafe.Pointer(cstr))
|
||||
case []byte:
|
||||
argSlice[i].Type = blobVal
|
||||
limboVal.Type = blobVal
|
||||
blob := makeBlob(val)
|
||||
pinner.Pin(blob)
|
||||
*(*uintptr)(unsafe.Pointer(&limboVal.Value)) = uintptr(unsafe.Pointer(blob))
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported type: %T", v)
|
||||
return nil, pinner.Unpin, fmt.Errorf("unsupported type: %T", v)
|
||||
}
|
||||
argSlice[i] = limboVal
|
||||
}
|
||||
return argSlice, nil
|
||||
}
|
||||
|
||||
func storeInt64(data *[8]byte, val int64) {
|
||||
*(*int64)(unsafe.Pointer(data)) = val
|
||||
}
|
||||
|
||||
func storeFloat64(data *[8]byte, val float64) {
|
||||
*(*float64)(unsafe.Pointer(data)) = val
|
||||
}
|
||||
|
||||
func storePointer(data *[8]byte, ptr *byte) {
|
||||
*(*uintptr)(unsafe.Pointer(data)) = uintptr(unsafe.Pointer(ptr))
|
||||
return argSlice, pinner.Unpin, nil
|
||||
}
|
||||
|
||||
/* Credit below (Apache2 License) to:
|
||||
|
||||
Reference in New Issue
Block a user