Enable passing arguments to external functions

This commit is contained in:
PThorpe92
2025-01-12 15:24:50 -05:00
parent 852817c9ff
commit 98eff6cf7a
5 changed files with 132 additions and 45 deletions

View File

@@ -762,13 +762,23 @@ pub fn translate_expr(
crate::bail_parse_error!("aggregation function in non-aggregation context")
}
Func::External(_) => {
let regs = program.alloc_register();
let regs = program.alloc_registers(args_count);
for (i, arg_expr) in args.iter().enumerate() {
translate_expr(
program,
referenced_tables,
&arg_expr[i],
regs + i,
resolver,
)?;
}
program.emit_insn(Insn::Function {
constant_mask: 0,
start_reg: regs,
dest: target_register,
func: func_ctx,
});
Ok(target_register)
}
#[cfg(feature = "json")]

View File

@@ -148,10 +148,9 @@ pub fn prepare_select_plan(
}
Err(_) => {
if syms.functions.contains_key(&name.0) {
// TODO: future extensions can be aggregate functions
log::debug!(
"Resolving {} function from symbol table",
name.0
let contains_aggregates = resolve_aggregates(
expr,
&mut aggregate_expressions,
);
plan.result_columns.push(ResultSetColumn {
name: get_name(
@@ -161,7 +160,7 @@ pub fn prepare_select_plan(
|| format!("expr_{}", result_column_idx),
),
expr: expr.clone(),
contains_aggregates: false,
contains_aggregates,
});
}
}
@@ -202,7 +201,7 @@ pub fn prepare_select_plan(
}
expr => {
let contains_aggregates =
resolve_aggregates(&expr, &mut aggregate_expressions);
resolve_aggregates(expr, &mut aggregate_expressions);
plan.result_columns.push(ResultSetColumn {
name: get_name(
maybe_alias.as_ref(),

View File

@@ -55,7 +55,6 @@ use sorter::Sorter;
use std::borrow::BorrowMut;
use std::cell::RefCell;
use std::collections::{BTreeMap, HashMap};
use std::os::raw::c_void;
use std::rc::{Rc, Weak};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
@@ -147,6 +146,33 @@ macro_rules! return_if_io {
};
}
macro_rules! call_external_function {
(
$func_ptr:expr,
$dest_register:expr,
$state:expr,
$arg_count:expr,
$start_reg:expr
) => {{
if $arg_count == 0 {
let result_c_value: ExtValue = ($func_ptr)(0, std::ptr::null());
let result_ov = OwnedValue::from_ffi(&result_c_value);
$state.registers[$dest_register] = result_ov;
} else {
let register_slice = &$state.registers[$start_reg..$start_reg + $arg_count];
let mut ext_values: Vec<ExtValue> = Vec::with_capacity($arg_count);
for ov in register_slice.iter() {
let val = ov.to_ffi();
ext_values.push(val);
}
let argv_ptr = ext_values.as_ptr();
let result_c_value: ExtValue = ($func_ptr)($arg_count as i32, argv_ptr);
let result_ov = OwnedValue::from_ffi(&result_c_value);
$state.registers[$dest_register] = result_ov;
}
}};
}
struct RegexCache {
like: HashMap<String, Regex>,
glob: HashMap<String, Regex>,
@@ -1839,19 +1865,7 @@ impl Program {
}
},
crate::function::Func::External(f) => {
let values = &state.registers[*start_reg..*start_reg + arg_count];
let c_values: Vec<*const c_void> = values
.iter()
.map(|ov| &ov.to_ffi() as *const _ as *const c_void)
.collect();
let argv_ptr = if c_values.is_empty() {
std::ptr::null()
} else {
c_values.as_ptr()
};
let result_c_value: ExtValue = (f.func)(arg_count as i32, argv_ptr);
let result_ov = OwnedValue::from_ffi(&result_c_value);
state.registers[*dest] = result_ov;
call_external_function! {f.func, *dest, state, arg_count, *start_reg };
}
crate::function::Func::Math(math_func) => match math_func.arity() {
MathFuncArity::Nullary => match math_func {

View File

@@ -1,5 +1,6 @@
use limbo_extension::{
declare_scalar_functions, register_extension, register_scalar_functions, Blob, TextValue, Value,
declare_scalar_functions, register_extension, register_scalar_functions, Blob, TextValue,
Value, ValueType,
};
register_extension! {
@@ -10,6 +11,7 @@ register_extension! {
"uuid7" => uuid7_blob,
"uuid_str" => uuid_str,
"uuid_blob" => uuid_blob,
"exec_ts_from_uuid7" => exec_ts_from_uuid7,
},
}
@@ -60,6 +62,26 @@ declare_scalar_functions! {
Value::from_blob(bytes.to_vec())
}
#[args(1)]
fn exec_ts_from_uuid7(args: &[Value]) -> Value {
match args[0].value_type {
ValueType::Blob => {
let blob = Blob::from_value(&args[0]).unwrap();
let slice = unsafe{ std::slice::from_raw_parts(blob.data, blob.size as usize)};
let uuid = uuid::Uuid::from_slice(slice).unwrap();
let unix = uuid_to_unix(uuid.as_bytes());
Value::from_integer(unix as i64)
}
ValueType::Text => {
let text = TextValue::from_value(&args[0]).unwrap();
let uuid = uuid::Uuid::parse_str(unsafe {text.as_str()}).unwrap();
let unix = uuid_to_unix(uuid.as_bytes());
Value::from_integer(unix as i64)
}
_ => Value::null(),
}
}
#[args(1)]
fn uuid_str(args: &[Value]) -> Value {
if args[0].value_type != limbo_extension::ValueType::Blob {
@@ -96,3 +118,13 @@ declare_scalar_functions! {
}
}
}
#[inline(always)]
fn uuid_to_unix(uuid: &[u8; 16]) -> u64 {
((uuid[0] as u64) << 40)
| ((uuid[1] as u64) << 32)
| ((uuid[2] as u64) << 24)
| ((uuid[3] as u64) << 16)
| ((uuid[4] as u64) << 8)
| (uuid[5] as u64)
}

View File

@@ -7,7 +7,7 @@ pub const RESULT_ERROR: ResultCode = 1;
// TODO: more error types
pub type ExtensionEntryPoint = extern "C" fn(api: *const ExtensionApi) -> ResultCode;
pub type ScalarFunction = extern "C" fn(argc: i32, *const *const c_void) -> Value;
pub type ScalarFunction = extern "C" fn(argc: i32, *const Value) -> Value;
#[repr(C)]
pub struct ExtensionApi {
@@ -54,12 +54,13 @@ macro_rules! register_scalar_functions {
/// . e.g.
/// ```
/// #[args(1)]
/// fn scalar_func(args: &[Value]) -> Value {
/// if args.len() != 1 {
/// return Value::null();
/// }
/// fn scalar_double(args: &[Value]) -> Value {
/// Value::from_integer(args[0].integer * 2)
/// }
///
/// #[args(0..=2)]
/// fn scalar_sum(args: &[Value]) -> Value {
/// Value::from_integer(args.iter().map(|v| v.integer).sum())
/// ```
///
#[macro_export]
@@ -73,7 +74,7 @@ macro_rules! declare_scalar_functions {
$(
extern "C" fn $func_name(
argc: i32,
argv: *const *const std::os::raw::c_void
argv: *const $crate::Value
) -> $crate::Value {
let valid_args = {
match argc {
@@ -85,22 +86,14 @@ macro_rules! declare_scalar_functions {
return $crate::Value::null();
}
if argc == 0 || argv.is_null() {
log::debug!("{} was called with no arguments", stringify!($func_name));
let $args: &[$crate::Value] = &[];
$body
} else {
let ptr_slice = unsafe{ std::slice::from_raw_parts(argv, argc as usize)};
let mut values = Vec::with_capacity(argc as usize);
for &ptr in ptr_slice {
let val_ptr = ptr as *const $crate::Value;
if val_ptr.is_null() {
values.push($crate::Value::null());
} else {
unsafe{values.push(std::ptr::read(val_ptr))};
}
}
let $args: &[$crate::Value] = &values[..];
$body
}
let ptr_slice = unsafe{ std::slice::from_raw_parts(argv, argc as usize)};
let $args: &[$crate::Value] = ptr_slice;
$body
}
}
)*
};
@@ -122,12 +115,42 @@ pub struct Value {
pub value: *mut c_void,
}
impl std::fmt::Debug for Value {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.value_type {
ValueType::Null => write!(f, "Value {{ Null }}"),
ValueType::Integer => write!(f, "Value {{ Integer: {} }}", unsafe {
*(self.value as *const i64)
}),
ValueType::Float => write!(f, "Value {{ Float: {} }}", unsafe {
*(self.value as *const f64)
}),
ValueType::Text => write!(f, "Value {{ Text: {:?} }}", unsafe {
&*(self.value as *const TextValue)
}),
ValueType::Blob => write!(f, "Value {{ Blob: {:?} }}", unsafe {
&*(self.value as *const Blob)
}),
}
}
}
#[repr(C)]
pub struct TextValue {
pub text: *const u8,
pub len: u32,
}
impl std::fmt::Debug for TextValue {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"TextValue {{ text: {:?}, len: {} }}",
self.text, self.len
)
}
}
impl Default for TextValue {
fn default() -> Self {
Self {
@@ -170,6 +193,12 @@ pub struct Blob {
pub size: u64,
}
impl std::fmt::Debug for Blob {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Blob {{ data: {:?}, size: {} }}", self.data, self.size)
}
}
impl Blob {
pub fn new(data: *const u8, size: u64) -> Self {
Self { data, size }
@@ -208,12 +237,15 @@ impl Value {
}
pub fn from_text(s: String) -> Self {
let text_value = TextValue::new(s.as_ptr(), s.len());
let boxed_text = Box::new(text_value);
std::mem::forget(s);
let buffer = s.into_boxed_str();
let ptr = buffer.as_ptr();
let len = buffer.len();
std::mem::forget(buffer);
let text_value = TextValue::new(ptr, len);
let text_box = Box::new(text_value);
Self {
value_type: ValueType::Text,
value: Box::into_raw(boxed_text) as *mut c_void,
value: Box::into_raw(text_box) as *mut c_void,
}
}