Get name of argument for scalars in extensions to allow for less rigid naming

This commit is contained in:
PThorpe92
2025-02-23 17:48:46 -05:00
parent e7713e87ec
commit 04abb200a7
2 changed files with 13 additions and 5 deletions

View File

@@ -140,12 +140,11 @@ fn generate_get_description(
/// Declare a scalar function for your extension. This requires the name:
/// #[scalar(name = "example")] of what you wish to call your function with.
/// Your function __must__ use the signature: `fn (args: &[Value]) -> Value`
/// with proper spelling.
/// ```ignore
/// use limbo_ext::{scalar, Value};
/// #[scalar(name = "double", alias = "twice")] // you can provide an <optional> alias
/// fn double(args: &[Value]) -> Value {
/// let arg = args.get(0).unwrap();
/// match arg.value_type() {
/// ValueType::Float => {
/// let val = arg.to_float().unwrap();
@@ -165,9 +164,18 @@ fn generate_get_description(
pub fn scalar(attr: TokenStream, input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as ItemFn);
let fn_name = &ast.sig.ident;
let args_variable = &ast.sig.inputs.first();
let mut args_variable_name = None;
if let Some(syn::FnArg::Typed(syn::PatType { pat, .. })) = args_variable {
if let syn::Pat::Ident(ident) = &**pat {
args_variable_name = Some(ident.ident.clone());
}
}
let scalar_info = parse_macro_input!(attr as ScalarInfo);
let name = &scalar_info.name;
let register_fn_name = format_ident!("register_{}", fn_name);
let args_variable_name =
format_ident!("{}", args_variable_name.unwrap_or(format_ident!("args")));
let fn_body = &ast.block;
let alias_check = if let Some(alias) = &scalar_info.alias {
quote! {
@@ -210,7 +218,7 @@ pub fn scalar(attr: TokenStream, input: TokenStream) -> TokenStream {
argc: i32,
argv: *const ::limbo_ext::Value
) -> ::limbo_ext::Value {
let args = if argv.is_null() || argc <= 0 {
let #args_variable_name = if argv.is_null() || argc <= 0 {
&[]
} else {
unsafe { std::slice::from_raw_parts(argv, argc as usize) }