Get name of argument for scalars in extensions to allow for less rigid naming

This commit is contained in:
PThorpe92
2025-02-23 17:48:46 -05:00
parent e7713e87ec
commit 04abb200a7
2 changed files with 13 additions and 5 deletions

View File

@@ -89,7 +89,7 @@ fn uuid7_ts(args: &[Value]) -> Value {
let Some(text) = args[0].to_text() else {
return Value::null();
};
let Ok(uuid) = uuid::Uuid::parse_str(&text) else {
let Ok(uuid) = uuid::Uuid::parse_str(text) else {
return Value::null();
};
let unix = uuid_to_unix(uuid.as_bytes());
@@ -118,7 +118,7 @@ fn uuid_blob(&self, args: &[Value]) -> Value {
let Some(text) = args[0].to_text() else {
return Value::null();
};
match uuid::Uuid::parse_str(&text) {
match uuid::Uuid::parse_str(text) {
Ok(uuid) => Value::from_blob(uuid.as_bytes().to_vec()),
Err(_) => Value::null(),
}

View File

@@ -140,12 +140,11 @@ fn generate_get_description(
/// Declare a scalar function for your extension. This requires the name:
/// #[scalar(name = "example")] of what you wish to call your function with.
/// Your function __must__ use the signature: `fn (args: &[Value]) -> Value`
/// with proper spelling.
/// ```ignore
/// use limbo_ext::{scalar, Value};
/// #[scalar(name = "double", alias = "twice")] // you can provide an <optional> alias
/// fn double(args: &[Value]) -> Value {
/// let arg = args.get(0).unwrap();
/// match arg.value_type() {
/// ValueType::Float => {
/// let val = arg.to_float().unwrap();
@@ -165,9 +164,18 @@ fn generate_get_description(
pub fn scalar(attr: TokenStream, input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as ItemFn);
let fn_name = &ast.sig.ident;
let args_variable = &ast.sig.inputs.first();
let mut args_variable_name = None;
if let Some(syn::FnArg::Typed(syn::PatType { pat, .. })) = args_variable {
if let syn::Pat::Ident(ident) = &**pat {
args_variable_name = Some(ident.ident.clone());
}
}
let scalar_info = parse_macro_input!(attr as ScalarInfo);
let name = &scalar_info.name;
let register_fn_name = format_ident!("register_{}", fn_name);
let args_variable_name =
format_ident!("{}", args_variable_name.unwrap_or(format_ident!("args")));
let fn_body = &ast.block;
let alias_check = if let Some(alias) = &scalar_info.alias {
quote! {
@@ -210,7 +218,7 @@ pub fn scalar(attr: TokenStream, input: TokenStream) -> TokenStream {
argc: i32,
argv: *const ::limbo_ext::Value
) -> ::limbo_ext::Value {
let args = if argv.is_null() || argc <= 0 {
let #args_variable_name = if argv.is_null() || argc <= 0 {
&[]
} else {
unsafe { std::slice::from_raw_parts(argv, argc as usize) }