use proc_macro::TokenStream; use quote::quote; use syn::parse::ParseStream; use syn::punctuated::Punctuated; use syn::token::Eq; use syn::{parse_macro_input, Ident, LitStr, Token}; mod agg_derive; mod scalars; mod vfs_derive; mod vtab_derive; pub use agg_derive::derive_agg_func; pub use scalars::scalar; pub use vfs_derive::derive_vfs_module; pub use vtab_derive::derive_vtab_module; pub fn register_extension(input: TokenStream) -> TokenStream { let input_ast = parse_macro_input!(input as RegisterExtensionInput); let RegisterExtensionInput { aggregates, scalars, vtabs, vfs, } = input_ast; let scalar_calls = scalars.iter().map(|scalar_ident| { let register_fn = syn::Ident::new(&format!("register_{}", scalar_ident), scalar_ident.span()); quote! { { let result = unsafe { #register_fn(api)}; if !result.is_ok() { return result; } } } }); let aggregate_calls = aggregates.iter().map(|agg_ident| { let register_fn = syn::Ident::new(&format!("register_{}", agg_ident), agg_ident.span()); quote! { { let result = unsafe{ #agg_ident::#register_fn(api)}; if !result.is_ok() { return result; } } } }); let vtab_calls = vtabs.iter().map(|vtab_ident| { let register_fn = syn::Ident::new(&format!("register_{}", vtab_ident), vtab_ident.span()); quote! { { let result = unsafe{ #vtab_ident::#register_fn(api)}; if !result.is_ok() { return result; } } } }); let vfs_calls = vfs.iter().map(|vfs_ident| { let register_fn = syn::Ident::new(&format!("register_{}", vfs_ident), vfs_ident.span()); quote! { { let result = unsafe { #register_fn(api) }; if !result.is_ok() { return result; } } } }); let static_vfs = vfs.iter().map(|vfs_ident| { let static_register = syn::Ident::new(&format!("register_static_{}", vfs_ident), vfs_ident.span()); quote! { { let result = api.add_builtin_vfs(unsafe { #static_register()}); if !result.is_ok() { return result; } } } }); let static_aggregates = aggregate_calls.clone(); let static_scalars = scalar_calls.clone(); let static_vtabs = vtab_calls.clone(); let expanded = quote! { #[cfg(not(target_family = "wasm"))] #[cfg(not(feature = "static"))] #[global_allocator] static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; #[cfg(feature = "static")] pub unsafe extern "C" fn register_extension_static(api: &mut ::limbo_ext::ExtensionApi) -> ::limbo_ext::ResultCode { #(#static_scalars)* #(#static_aggregates)* #(#static_vtabs)* #[cfg(not(target_family = "wasm"))] #(#static_vfs)* ::limbo_ext::ResultCode::OK } #[cfg(not(feature = "static"))] #[no_mangle] pub unsafe extern "C" fn register_extension(api: &::limbo_ext::ExtensionApi) -> ::limbo_ext::ResultCode { #(#scalar_calls)* #(#aggregate_calls)* #(#vtab_calls)* #(#vfs_calls)* ::limbo_ext::ResultCode::OK } }; TokenStream::from(expanded) } pub(crate) struct RegisterExtensionInput { pub aggregates: Vec, pub scalars: Vec, pub vtabs: Vec, pub vfs: Vec, } impl syn::parse::Parse for RegisterExtensionInput { fn parse(input: syn::parse::ParseStream) -> syn::Result { let mut aggregates = Vec::new(); let mut scalars = Vec::new(); let mut vtabs = Vec::new(); let mut vfs = Vec::new(); while !input.is_empty() { if input.peek(syn::Ident) && input.peek2(Token![:]) { let section_name: Ident = input.parse()?; input.parse::()?; let names = ["aggregates", "scalars", "vtabs", "vfs"]; if names.contains(§ion_name.to_string().as_str()) { let content; syn::braced!(content in input); let parsed_items = Punctuated::::parse_terminated(&content)? .into_iter() .collect(); match section_name.to_string().as_str() { "aggregates" => aggregates = parsed_items, "scalars" => scalars = parsed_items, "vtabs" => vtabs = parsed_items, "vfs" => vfs = parsed_items, _ => unreachable!(), }; if input.peek(Token![,]) { input.parse::()?; } } else { return Err(syn::Error::new(section_name.span(), "Unknown section")); } } else { return Err(input.error("Expected aggregates:, scalars:, or vtabs: section")); } } Ok(Self { aggregates, scalars, vtabs, vfs, }) } } pub(crate) struct ScalarInfo { pub name: String, pub alias: Option, } impl ScalarInfo { pub fn new(name: String, alias: Option) -> Self { Self { name, alias } } } impl syn::parse::Parse for ScalarInfo { fn parse(input: ParseStream) -> syn::parse::Result { let mut name = None; let mut alias = None; while !input.is_empty() { if let Ok(ident) = input.parse::() { if ident.to_string().as_str() == "name" { let _ = input.parse::(); name = Some(input.parse::()?); } else if ident.to_string().as_str() == "alias" { let _ = input.parse::(); alias = Some(input.parse::()?); } } if input.peek(Token![,]) { input.parse::()?; } } let Some(name) = name else { return Err(input.error("Expected name")); }; Ok(Self::new(name.value(), alias.map(|i| i.value()))) } }