diff --git a/macros/src/args.rs b/macros/src/args.rs deleted file mode 100644 index b0d45d20a..000000000 --- a/macros/src/args.rs +++ /dev/null @@ -1,93 +0,0 @@ -use syn::parse::ParseStream; -use syn::punctuated::Punctuated; -use syn::token::Eq; -use syn::{Ident, LitStr, Token}; - -pub(crate) struct RegisterExtensionInput { - pub aggregates: Vec, - pub scalars: Vec, - pub vtabs: Vec, - pub vfs: Vec, -} - -impl syn::parse::Parse for RegisterExtensionInput { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - let mut aggregates = Vec::new(); - let mut scalars = Vec::new(); - let mut vtabs = Vec::new(); - let mut vfs = Vec::new(); - while !input.is_empty() { - if input.peek(syn::Ident) && input.peek2(Token![:]) { - let section_name: Ident = input.parse()?; - input.parse::()?; - let names = ["aggregates", "scalars", "vtabs", "vfs"]; - if names.contains(§ion_name.to_string().as_str()) { - let content; - syn::braced!(content in input); - let parsed_items = Punctuated::::parse_terminated(&content)? - .into_iter() - .collect(); - - match section_name.to_string().as_str() { - "aggregates" => aggregates = parsed_items, - "scalars" => scalars = parsed_items, - "vtabs" => vtabs = parsed_items, - "vfs" => vfs = parsed_items, - _ => unreachable!(), - }; - - if input.peek(Token![,]) { - input.parse::()?; - } - } else { - return Err(syn::Error::new(section_name.span(), "Unknown section")); - } - } else { - return Err(input.error("Expected aggregates:, scalars:, or vtabs: section")); - } - } - - Ok(Self { - aggregates, - scalars, - vtabs, - vfs, - }) - } -} - -pub(crate) struct ScalarInfo { - pub name: String, - pub alias: Option, -} - -impl ScalarInfo { - pub fn new(name: String, alias: Option) -> Self { - Self { name, alias } - } -} - -impl syn::parse::Parse for ScalarInfo { - fn parse(input: ParseStream) -> syn::parse::Result { - let mut name = None; - let mut alias = None; - while !input.is_empty() { - if let Ok(ident) = input.parse::() { - if ident.to_string().as_str() == "name" { - let _ = input.parse::(); - name = Some(input.parse::()?); - } else if ident.to_string().as_str() == "alias" { - let _ = input.parse::(); - alias = Some(input.parse::()?); - } - } - if input.peek(Token![,]) { - input.parse::()?; - } - } - let Some(name) = name else { - return Err(input.error("Expected name")); - }; - Ok(Self::new(name.value(), alias.map(|i| i.value()))) - } -} diff --git a/macros/src/ext/agg_derive.rs b/macros/src/ext/agg_derive.rs new file mode 100644 index 000000000..5d82b9bbc --- /dev/null +++ b/macros/src/ext/agg_derive.rs @@ -0,0 +1,87 @@ +use proc_macro::TokenStream; +use quote::{format_ident, quote}; +use syn::parse_macro_input; +use syn::DeriveInput; + +pub fn derive_agg_func(input: TokenStream) -> TokenStream { + let ast = parse_macro_input!(input as DeriveInput); + let struct_name = &ast.ident; + + let step_fn_name = format_ident!("{}_step", struct_name); + let finalize_fn_name = format_ident!("{}_finalize", struct_name); + let init_fn_name = format_ident!("{}_init", struct_name); + let register_fn_name = format_ident!("register_{}", struct_name); + + let expanded = quote! { + impl #struct_name { + #[no_mangle] + pub extern "C" fn #init_fn_name() -> *mut ::limbo_ext::AggCtx { + let state = ::std::boxed::Box::new(<#struct_name as ::limbo_ext::AggFunc>::State::default()); + let ctx = ::std::boxed::Box::new(::limbo_ext::AggCtx { + state: ::std::boxed::Box::into_raw(state) as *mut ::std::os::raw::c_void, + }); + ::std::boxed::Box::into_raw(ctx) + } + + #[no_mangle] + pub extern "C" fn #step_fn_name( + ctx: *mut ::limbo_ext::AggCtx, + argc: i32, + argv: *const ::limbo_ext::Value, + ) { + unsafe { + let ctx = &mut *ctx; + let state = &mut *(ctx.state as *mut <#struct_name as ::limbo_ext::AggFunc>::State); + let args = ::std::slice::from_raw_parts(argv, argc as usize); + <#struct_name as ::limbo_ext::AggFunc>::step(state, args); + } + } + + #[no_mangle] + pub extern "C" fn #finalize_fn_name( + ctx: *mut ::limbo_ext::AggCtx + ) -> ::limbo_ext::Value { + unsafe { + let ctx = &mut *ctx; + let state = ::std::boxed::Box::from_raw(ctx.state as *mut <#struct_name as ::limbo_ext::AggFunc>::State); + match <#struct_name as ::limbo_ext::AggFunc>::finalize(*state) { + Ok(val) => val, + Err(e) => { + ::limbo_ext::Value::error_with_message(e.to_string()) + } + } + } + } + + #[no_mangle] + pub unsafe extern "C" fn #register_fn_name( + api: *const ::limbo_ext::ExtensionApi + ) -> ::limbo_ext::ResultCode { + if api.is_null() { + return ::limbo_ext::ResultCode::Error; + } + + let api = &*api; + let name_str = #struct_name::NAME; + let c_name = match ::std::ffi::CString::new(name_str) { + Ok(cname) => cname, + Err(_) => return ::limbo_ext::ResultCode::Error, + }; + + (api.register_aggregate_function)( + api.ctx, + c_name.as_ptr(), + #struct_name::ARGS, + #struct_name::#init_fn_name + as ::limbo_ext::InitAggFunction, + #struct_name::#step_fn_name + as ::limbo_ext::StepFunction, + #struct_name::#finalize_fn_name + as ::limbo_ext::FinalizeFunction, + ) + } + } + }; + + TokenStream::from(expanded) +} diff --git a/macros/src/ext/mod.rs b/macros/src/ext/mod.rs new file mode 100644 index 000000000..32bda4a70 --- /dev/null +++ b/macros/src/ext/mod.rs @@ -0,0 +1,212 @@ +use proc_macro::TokenStream; +use quote::quote; +use syn::parse::ParseStream; +use syn::punctuated::Punctuated; +use syn::token::Eq; +use syn::{parse_macro_input, Ident, LitStr, Token}; +mod agg_derive; +mod scalars; +mod vfs_derive; +mod vtab_derive; +pub use agg_derive::derive_agg_func; +pub use scalars::scalar; +pub use vfs_derive::derive_vfs_module; +pub use vtab_derive::derive_vtab_module; + +pub fn register_extension(input: TokenStream) -> TokenStream { + let input_ast = parse_macro_input!(input as RegisterExtensionInput); + let RegisterExtensionInput { + aggregates, + scalars, + vtabs, + vfs, + } = input_ast; + + let scalar_calls = scalars.iter().map(|scalar_ident| { + let register_fn = + syn::Ident::new(&format!("register_{}", scalar_ident), scalar_ident.span()); + quote! { + { + let result = unsafe { #register_fn(api)}; + if !result.is_ok() { + return result; + } + } + } + }); + + let aggregate_calls = aggregates.iter().map(|agg_ident| { + let register_fn = syn::Ident::new(&format!("register_{}", agg_ident), agg_ident.span()); + quote! { + { + let result = unsafe{ #agg_ident::#register_fn(api)}; + if !result.is_ok() { + return result; + } + } + } + }); + let vtab_calls = vtabs.iter().map(|vtab_ident| { + let register_fn = syn::Ident::new(&format!("register_{}", vtab_ident), vtab_ident.span()); + quote! { + { + let result = unsafe{ #vtab_ident::#register_fn(api)}; + if !result.is_ok() { + return result; + } + } + } + }); + let vfs_calls = vfs.iter().map(|vfs_ident| { + let register_fn = syn::Ident::new(&format!("register_{}", vfs_ident), vfs_ident.span()); + quote! { + { + let result = unsafe { #register_fn(api) }; + if !result.is_ok() { + return result; + } + } + } + }); + let static_vfs = vfs.iter().map(|vfs_ident| { + let static_register = + syn::Ident::new(&format!("register_static_{}", vfs_ident), vfs_ident.span()); + quote! { + { + let result = api.add_builtin_vfs(unsafe { #static_register()}); + if !result.is_ok() { + return result; + } + } + } + }); + let static_aggregates = aggregate_calls.clone(); + let static_scalars = scalar_calls.clone(); + let static_vtabs = vtab_calls.clone(); + + let expanded = quote! { + #[cfg(not(target_family = "wasm"))] + #[cfg(not(feature = "static"))] + #[global_allocator] + static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; + + #[cfg(feature = "static")] + pub unsafe extern "C" fn register_extension_static(api: &mut ::limbo_ext::ExtensionApi) -> ::limbo_ext::ResultCode { + #(#static_scalars)* + + #(#static_aggregates)* + + #(#static_vtabs)* + + #[cfg(not(target_family = "wasm"))] + #(#static_vfs)* + + ::limbo_ext::ResultCode::OK + } + + #[cfg(not(feature = "static"))] + #[no_mangle] + pub unsafe extern "C" fn register_extension(api: &::limbo_ext::ExtensionApi) -> ::limbo_ext::ResultCode { + #(#scalar_calls)* + + #(#aggregate_calls)* + + #(#vtab_calls)* + + #(#vfs_calls)* + + ::limbo_ext::ResultCode::OK + } + }; + + TokenStream::from(expanded) +} + +pub(crate) struct RegisterExtensionInput { + pub aggregates: Vec, + pub scalars: Vec, + pub vtabs: Vec, + pub vfs: Vec, +} + +impl syn::parse::Parse for RegisterExtensionInput { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let mut aggregates = Vec::new(); + let mut scalars = Vec::new(); + let mut vtabs = Vec::new(); + let mut vfs = Vec::new(); + while !input.is_empty() { + if input.peek(syn::Ident) && input.peek2(Token![:]) { + let section_name: Ident = input.parse()?; + input.parse::()?; + let names = ["aggregates", "scalars", "vtabs", "vfs"]; + if names.contains(§ion_name.to_string().as_str()) { + let content; + syn::braced!(content in input); + let parsed_items = Punctuated::::parse_terminated(&content)? + .into_iter() + .collect(); + + match section_name.to_string().as_str() { + "aggregates" => aggregates = parsed_items, + "scalars" => scalars = parsed_items, + "vtabs" => vtabs = parsed_items, + "vfs" => vfs = parsed_items, + _ => unreachable!(), + }; + + if input.peek(Token![,]) { + input.parse::()?; + } + } else { + return Err(syn::Error::new(section_name.span(), "Unknown section")); + } + } else { + return Err(input.error("Expected aggregates:, scalars:, or vtabs: section")); + } + } + + Ok(Self { + aggregates, + scalars, + vtabs, + vfs, + }) + } +} + +pub(crate) struct ScalarInfo { + pub name: String, + pub alias: Option, +} + +impl ScalarInfo { + pub fn new(name: String, alias: Option) -> Self { + Self { name, alias } + } +} + +impl syn::parse::Parse for ScalarInfo { + fn parse(input: ParseStream) -> syn::parse::Result { + let mut name = None; + let mut alias = None; + while !input.is_empty() { + if let Ok(ident) = input.parse::() { + if ident.to_string().as_str() == "name" { + let _ = input.parse::(); + name = Some(input.parse::()?); + } else if ident.to_string().as_str() == "alias" { + let _ = input.parse::(); + alias = Some(input.parse::()?); + } + } + if input.peek(Token![,]) { + input.parse::()?; + } + } + let Some(name) = name else { + return Err(input.error("Expected name")); + }; + Ok(Self::new(name.value(), alias.map(|i| i.value()))) + } +} diff --git a/macros/src/ext/scalars.rs b/macros/src/ext/scalars.rs new file mode 100644 index 000000000..a49d055ad --- /dev/null +++ b/macros/src/ext/scalars.rs @@ -0,0 +1,74 @@ +use proc_macro::TokenStream; +use quote::{format_ident, quote}; +use syn::{parse_macro_input, ItemFn}; + +use super::ScalarInfo; + +pub fn scalar(attr: TokenStream, input: TokenStream) -> TokenStream { + let ast = parse_macro_input!(input as ItemFn); + let fn_name = &ast.sig.ident; + let args_variable = &ast.sig.inputs.first(); + let mut args_variable_name = None; + if let Some(syn::FnArg::Typed(syn::PatType { pat, .. })) = args_variable { + if let syn::Pat::Ident(ident) = &**pat { + args_variable_name = Some(ident.ident.clone()); + } + } + let scalar_info = parse_macro_input!(attr as ScalarInfo); + let name = &scalar_info.name; + let register_fn_name = format_ident!("register_{}", fn_name); + let args_variable_name = + format_ident!("{}", args_variable_name.unwrap_or(format_ident!("args"))); + let fn_body = &ast.block; + let alias_check = if let Some(alias) = &scalar_info.alias { + quote! { + let Ok(alias_c_name) = ::std::ffi::CString::new(#alias) else { + return ::limbo_ext::ResultCode::Error; + }; + (api.register_scalar_function)( + api.ctx, + alias_c_name.as_ptr(), + #fn_name, + ); + } + } else { + quote! {} + }; + + let expanded = quote! { + #[no_mangle] + pub unsafe extern "C" fn #register_fn_name( + api: *const ::limbo_ext::ExtensionApi + ) -> ::limbo_ext::ResultCode { + if api.is_null() { + return ::limbo_ext::ResultCode::Error; + } + let api = unsafe { &*api }; + let Ok(c_name) = ::std::ffi::CString::new(#name) else { + return ::limbo_ext::ResultCode::Error; + }; + (api.register_scalar_function)( + api.ctx, + c_name.as_ptr(), + #fn_name, + ); + #alias_check + ::limbo_ext::ResultCode::OK + } + + #[no_mangle] + pub unsafe extern "C" fn #fn_name( + argc: i32, + argv: *const ::limbo_ext::Value + ) -> ::limbo_ext::Value { + let #args_variable_name = if argv.is_null() || argc <= 0 { + &[] + } else { + unsafe { std::slice::from_raw_parts(argv, argc as usize) } + }; + #fn_body + } + }; + + TokenStream::from(expanded) +} diff --git a/macros/src/ext/vfs_derive.rs b/macros/src/ext/vfs_derive.rs new file mode 100644 index 000000000..814e80d1a --- /dev/null +++ b/macros/src/ext/vfs_derive.rs @@ -0,0 +1,218 @@ +use proc_macro::TokenStream; +use quote::{format_ident, quote}; +use syn::{parse_macro_input, DeriveInput}; + +pub fn derive_vfs_module(input: TokenStream) -> TokenStream { + let derive_input = parse_macro_input!(input as DeriveInput); + let struct_name = &derive_input.ident; + let register_fn_name = format_ident!("register_{}", struct_name); + let register_static = format_ident!("register_static_{}", struct_name); + let open_fn_name = format_ident!("{}_open", struct_name); + let close_fn_name = format_ident!("{}_close", struct_name); + let read_fn_name = format_ident!("{}_read", struct_name); + let write_fn_name = format_ident!("{}_write", struct_name); + let lock_fn_name = format_ident!("{}_lock", struct_name); + let unlock_fn_name = format_ident!("{}_unlock", struct_name); + let sync_fn_name = format_ident!("{}_sync", struct_name); + let size_fn_name = format_ident!("{}_size", struct_name); + let run_once_fn_name = format_ident!("{}_run_once", struct_name); + let generate_random_number_fn_name = format_ident!("{}_generate_random_number", struct_name); + let get_current_time_fn_name = format_ident!("{}_get_current_time", struct_name); + + let expanded = quote! { + #[allow(non_snake_case)] + pub unsafe extern "C" fn #register_static() -> *const ::limbo_ext::VfsImpl { + let ctx = #struct_name::default(); + let ctx = ::std::boxed::Box::into_raw(::std::boxed::Box::new(ctx)) as *const ::std::ffi::c_void; + let name = ::std::ffi::CString::new(<#struct_name as ::limbo_ext::VfsExtension>::NAME).unwrap().into_raw(); + let vfs_mod = ::limbo_ext::VfsImpl { + vfs: ctx, + name, + open: #open_fn_name, + close: #close_fn_name, + read: #read_fn_name, + write: #write_fn_name, + lock: #lock_fn_name, + unlock: #unlock_fn_name, + sync: #sync_fn_name, + size: #size_fn_name, + run_once: #run_once_fn_name, + gen_random_number: #generate_random_number_fn_name, + current_time: #get_current_time_fn_name, + }; + ::std::boxed::Box::into_raw(::std::boxed::Box::new(vfs_mod)) as *const ::limbo_ext::VfsImpl + } + + #[no_mangle] + pub unsafe extern "C" fn #register_fn_name(api: &::limbo_ext::ExtensionApi) -> ::limbo_ext::ResultCode { + let ctx = #struct_name::default(); + let ctx = ::std::boxed::Box::into_raw(::std::boxed::Box::new(ctx)) as *const ::std::ffi::c_void; + let name = ::std::ffi::CString::new(<#struct_name as ::limbo_ext::VfsExtension>::NAME).unwrap().into_raw(); + let vfs_mod = ::limbo_ext::VfsImpl { + vfs: ctx, + name, + open: #open_fn_name, + close: #close_fn_name, + read: #read_fn_name, + write: #write_fn_name, + lock: #lock_fn_name, + unlock: #unlock_fn_name, + sync: #sync_fn_name, + size: #size_fn_name, + run_once: #run_once_fn_name, + gen_random_number: #generate_random_number_fn_name, + current_time: #get_current_time_fn_name, + }; + let vfsimpl = ::std::boxed::Box::into_raw(::std::boxed::Box::new(vfs_mod)) as *const ::limbo_ext::VfsImpl; + (api.vfs_interface.register_vfs)(name, vfsimpl) + } + + #[no_mangle] + pub unsafe extern "C" fn #open_fn_name( + ctx: *const ::std::ffi::c_void, + path: *const ::std::ffi::c_char, + flags: i32, + direct: bool, + ) -> *const ::std::ffi::c_void { + let ctx = &*(ctx as *const ::limbo_ext::VfsImpl); + let Ok(path_str) = ::std::ffi::CStr::from_ptr(path).to_str() else { + return ::std::ptr::null_mut(); + }; + let vfs = &*(ctx.vfs as *const #struct_name); + let Ok(file_handle) = <#struct_name as ::limbo_ext::VfsExtension>::open_file(vfs, path_str, flags, direct) else { + return ::std::ptr::null(); + }; + let boxed = ::std::boxed::Box::into_raw(::std::boxed::Box::new(file_handle)) as *const ::std::ffi::c_void; + let Ok(vfs_file) = ::limbo_ext::VfsFileImpl::new(boxed, ctx) else { + return ::std::ptr::null(); + }; + ::std::boxed::Box::into_raw(::std::boxed::Box::new(vfs_file)) as *const ::std::ffi::c_void + } + + #[no_mangle] + pub unsafe extern "C" fn #close_fn_name(file_ptr: *const ::std::ffi::c_void) -> ::limbo_ext::ResultCode { + if file_ptr.is_null() { + return ::limbo_ext::ResultCode::Error; + } + let vfs_file: &mut ::limbo_ext::VfsFileImpl = &mut *(file_ptr as *mut ::limbo_ext::VfsFileImpl); + let vfs_instance = &*(vfs_file.vfs as *const #struct_name); + + // this time we need to own it so we can drop it + let file: ::std::boxed::Box<<#struct_name as ::limbo_ext::VfsExtension>::File> = + ::std::boxed::Box::from_raw(vfs_file.file as *mut <#struct_name as ::limbo_ext::VfsExtension>::File); + if let Err(e) = <#struct_name as ::limbo_ext::VfsExtension>::close(vfs_instance, *file) { + return e; + } + ::limbo_ext::ResultCode::OK + } + + #[no_mangle] + pub unsafe extern "C" fn #read_fn_name(file_ptr: *const ::std::ffi::c_void, buf: *mut u8, count: usize, offset: i64) -> i32 { + if file_ptr.is_null() { + return -1; + } + let vfs_file: &mut ::limbo_ext::VfsFileImpl = &mut *(file_ptr as *mut ::limbo_ext::VfsFileImpl); + let file: &mut <#struct_name as ::limbo_ext::VfsExtension>::File = + &mut *(vfs_file.file as *mut <#struct_name as ::limbo_ext::VfsExtension>::File); + match <#struct_name as ::limbo_ext::VfsExtension>::File::read(file, ::std::slice::from_raw_parts_mut(buf, count), count, offset) { + Ok(n) => n, + Err(_) => -1, + } + } + + #[no_mangle] + pub unsafe extern "C" fn #run_once_fn_name(ctx: *const ::std::ffi::c_void) -> ::limbo_ext::ResultCode { + if ctx.is_null() { + return ::limbo_ext::ResultCode::Error; + } + let ctx = &mut *(ctx as *mut #struct_name); + if let Err(e) = <#struct_name as ::limbo_ext::VfsExtension>::run_once(ctx) { + return e; + } + ::limbo_ext::ResultCode::OK + } + + #[no_mangle] + pub unsafe extern "C" fn #write_fn_name(file_ptr: *const ::std::ffi::c_void, buf: *const u8, count: usize, offset: i64) -> i32 { + if file_ptr.is_null() { + return -1; + } + let vfs_file: &mut ::limbo_ext::VfsFileImpl = &mut *(file_ptr as *mut ::limbo_ext::VfsFileImpl); + let file: &mut <#struct_name as ::limbo_ext::VfsExtension>::File = + &mut *(vfs_file.file as *mut <#struct_name as ::limbo_ext::VfsExtension>::File); + match <#struct_name as ::limbo_ext::VfsExtension>::File::write(file, ::std::slice::from_raw_parts(buf, count), count, offset) { + Ok(n) => n, + Err(_) => -1, + } + } + + #[no_mangle] + pub unsafe extern "C" fn #lock_fn_name(file_ptr: *const ::std::ffi::c_void, exclusive: bool) -> ::limbo_ext::ResultCode { + if file_ptr.is_null() { + return ::limbo_ext::ResultCode::Error; + } + let vfs_file: &mut ::limbo_ext::VfsFileImpl = &mut *(file_ptr as *mut ::limbo_ext::VfsFileImpl); + let file: &mut <#struct_name as ::limbo_ext::VfsExtension>::File = + &mut *(vfs_file.file as *mut <#struct_name as ::limbo_ext::VfsExtension>::File); + if let Err(e) = <#struct_name as ::limbo_ext::VfsExtension>::File::lock(file, exclusive) { + return e; + } + ::limbo_ext::ResultCode::OK + } + + #[no_mangle] + pub unsafe extern "C" fn #unlock_fn_name(file_ptr: *const ::std::ffi::c_void) -> ::limbo_ext::ResultCode { + if file_ptr.is_null() { + return ::limbo_ext::ResultCode::Error; + } + let vfs_file: &mut ::limbo_ext::VfsFileImpl = &mut *(file_ptr as *mut ::limbo_ext::VfsFileImpl); + let file: &mut <#struct_name as ::limbo_ext::VfsExtension>::File = + &mut *(vfs_file.file as *mut <#struct_name as ::limbo_ext::VfsExtension>::File); + if let Err(e) = <#struct_name as ::limbo_ext::VfsExtension>::File::unlock(file) { + return e; + } + ::limbo_ext::ResultCode::OK + } + + #[no_mangle] + pub unsafe extern "C" fn #sync_fn_name(file_ptr: *const ::std::ffi::c_void) -> i32 { + if file_ptr.is_null() { + return -1; + } + let vfs_file: &mut ::limbo_ext::VfsFileImpl = &mut *(file_ptr as *mut ::limbo_ext::VfsFileImpl); + let file: &mut <#struct_name as ::limbo_ext::VfsExtension>::File = + &mut *(vfs_file.file as *mut <#struct_name as ::limbo_ext::VfsExtension>::File); + if <#struct_name as ::limbo_ext::VfsExtension>::File::sync(file).is_err() { + return -1; + } + 0 + } + + #[no_mangle] + pub unsafe extern "C" fn #size_fn_name(file_ptr: *const ::std::ffi::c_void) -> i64 { + if file_ptr.is_null() { + return -1; + } + let vfs_file: &mut ::limbo_ext::VfsFileImpl = &mut *(file_ptr as *mut ::limbo_ext::VfsFileImpl); + let file: &mut <#struct_name as ::limbo_ext::VfsExtension>::File = + &mut *(vfs_file.file as *mut <#struct_name as ::limbo_ext::VfsExtension>::File); + <#struct_name as ::limbo_ext::VfsExtension>::File::size(file) + } + + #[no_mangle] + pub unsafe extern "C" fn #generate_random_number_fn_name() -> i64 { + let obj = #struct_name::default(); + <#struct_name as ::limbo_ext::VfsExtension>::generate_random_number(&obj) + } + + #[no_mangle] + pub unsafe extern "C" fn #get_current_time_fn_name() -> *const ::std::ffi::c_char { + let obj = #struct_name::default(); + let time = <#struct_name as ::limbo_ext::VfsExtension>::get_current_time(&obj); + // release ownership of the string to core + ::std::ffi::CString::new(time).unwrap().into_raw() as *const ::std::ffi::c_char + } + }; + + TokenStream::from(expanded) +} diff --git a/macros/src/ext/vtab_derive.rs b/macros/src/ext/vtab_derive.rs new file mode 100644 index 000000000..0dd67ca0d --- /dev/null +++ b/macros/src/ext/vtab_derive.rs @@ -0,0 +1,241 @@ +use proc_macro::TokenStream; +use quote::{format_ident, quote}; +use syn::{parse_macro_input, DeriveInput}; + +pub fn derive_vtab_module(input: TokenStream) -> TokenStream { + let ast = parse_macro_input!(input as DeriveInput); + let struct_name = &ast.ident; + + let register_fn_name = format_ident!("register_{}", struct_name); + let create_schema_fn_name = format_ident!("create_schema_{}", struct_name); + let open_fn_name = format_ident!("open_{}", struct_name); + let close_fn_name = format_ident!("close_{}", struct_name); + let filter_fn_name = format_ident!("filter_{}", struct_name); + let column_fn_name = format_ident!("column_{}", struct_name); + let next_fn_name = format_ident!("next_{}", struct_name); + let eof_fn_name = format_ident!("eof_{}", struct_name); + let update_fn_name = format_ident!("update_{}", struct_name); + let rowid_fn_name = format_ident!("rowid_{}", struct_name); + let destroy_fn_name = format_ident!("destroy_{}", struct_name); + let best_idx_fn_name = format_ident!("best_idx_{}", struct_name); + + let expanded = quote! { + impl #struct_name { + #[no_mangle] + unsafe extern "C" fn #create_schema_fn_name( + argv: *const ::limbo_ext::Value, argc: i32 + ) -> *mut ::std::ffi::c_char { + let args = if argv.is_null() { + &Vec::new() + } else { + ::std::slice::from_raw_parts(argv, argc as usize) + }; + let sql = <#struct_name as ::limbo_ext::VTabModule>::create_schema(&args); + ::std::ffi::CString::new(sql).unwrap().into_raw() + } + + #[no_mangle] + unsafe extern "C" fn #open_fn_name(ctx: *const ::std::ffi::c_void) -> *const ::std::ffi::c_void { + if ctx.is_null() { + return ::std::ptr::null(); + } + let ctx = ctx as *const #struct_name; + let ctx: &#struct_name = &*ctx; + if let Ok(cursor) = <#struct_name as ::limbo_ext::VTabModule>::open(ctx) { + return ::std::boxed::Box::into_raw(::std::boxed::Box::new(cursor)) as *const ::std::ffi::c_void; + } else { + return ::std::ptr::null(); + } + } + + #[no_mangle] + unsafe extern "C" fn #close_fn_name( + cursor: *const ::std::ffi::c_void + ) -> ::limbo_ext::ResultCode { + if cursor.is_null() { + return ::limbo_ext::ResultCode::Error; + } + let boxed_cursor = ::std::boxed::Box::from_raw(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor); + boxed_cursor.close() + } + + #[no_mangle] + unsafe extern "C" fn #filter_fn_name( + cursor: *const ::std::ffi::c_void, + argc: i32, + argv: *const ::limbo_ext::Value, + idx_str: *const ::std::ffi::c_char, + idx_num: i32, + ) -> ::limbo_ext::ResultCode { + if cursor.is_null() { + return ::limbo_ext::ResultCode::Error; + } + let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; + let args = ::std::slice::from_raw_parts(argv, argc as usize); + let idx_str = if idx_str.is_null() { + None + } else { + Some((unsafe { ::std::ffi::CStr::from_ptr(idx_str).to_str().unwrap() }, idx_num)) + }; + <#struct_name as ::limbo_ext::VTabModule>::filter(cursor, args, idx_str) + } + + #[no_mangle] + unsafe extern "C" fn #column_fn_name( + cursor: *const ::std::ffi::c_void, + idx: u32, + ) -> ::limbo_ext::Value { + if cursor.is_null() { + return ::limbo_ext::Value::error(::limbo_ext::ResultCode::Error); + } + let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; + match <#struct_name as ::limbo_ext::VTabModule>::column(cursor, idx) { + Ok(val) => val, + Err(e) => ::limbo_ext::Value::error_with_message(e.to_string()) + } + } + + #[no_mangle] + unsafe extern "C" fn #next_fn_name( + cursor: *const ::std::ffi::c_void, + ) -> ::limbo_ext::ResultCode { + if cursor.is_null() { + return ::limbo_ext::ResultCode::Error; + } + let cursor = &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor); + <#struct_name as ::limbo_ext::VTabModule>::next(cursor) + } + + #[no_mangle] + unsafe extern "C" fn #eof_fn_name( + cursor: *const ::std::ffi::c_void, + ) -> bool { + if cursor.is_null() { + return true; + } + let cursor = &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor); + <#struct_name as ::limbo_ext::VTabModule>::eof(cursor) + } + + #[no_mangle] + unsafe extern "C" fn #update_fn_name( + vtab: *const ::std::ffi::c_void, + argc: i32, + argv: *const ::limbo_ext::Value, + p_out_rowid: *mut i64, + ) -> ::limbo_ext::ResultCode { + if vtab.is_null() { + return ::limbo_ext::ResultCode::Error; + } + + let vtab = &mut *(vtab as *mut #struct_name); + let args = ::std::slice::from_raw_parts(argv, argc as usize); + + let old_rowid = match args.get(0).map(|v| v.value_type()) { + Some(::limbo_ext::ValueType::Integer) => args.get(0).unwrap().to_integer(), + _ => None, + }; + let new_rowid = match args.get(1).map(|v| v.value_type()) { + Some(::limbo_ext::ValueType::Integer) => args.get(1).unwrap().to_integer(), + _ => None, + }; + let columns = &args[2..]; + match (old_rowid, new_rowid) { + // DELETE: old_rowid provided, no new_rowid + (Some(old), None) => { + if <#struct_name as VTabModule>::delete(vtab, old).is_err() { + return ::limbo_ext::ResultCode::Error; + } + return ::limbo_ext::ResultCode::OK; + } + // UPDATE: old_rowid provided and new_rowid may exist + (Some(old), Some(new)) => { + if <#struct_name as VTabModule>::update(vtab, old, &columns).is_err() { + return ::limbo_ext::ResultCode::Error; + } + return ::limbo_ext::ResultCode::OK; + } + // INSERT: no old_rowid (old_rowid = None) + (None, _) => { + if let Ok(rowid) = <#struct_name as VTabModule>::insert(vtab, &columns) { + if !p_out_rowid.is_null() { + *p_out_rowid = rowid; + return ::limbo_ext::ResultCode::RowID; + } + return ::limbo_ext::ResultCode::OK; + } + } + } + return ::limbo_ext::ResultCode::Error; + } + + #[no_mangle] + pub unsafe extern "C" fn #rowid_fn_name(ctx: *const ::std::ffi::c_void) -> i64 { + if ctx.is_null() { + return -1; + } + let cursor = &*(ctx as *const <#struct_name as ::limbo_ext::VTabModule>::VCursor); + <<#struct_name as ::limbo_ext::VTabModule>::VCursor as ::limbo_ext::VTabCursor>::rowid(cursor) + } + + #[no_mangle] + unsafe extern "C" fn #destroy_fn_name( + vtab: *const ::std::ffi::c_void, + ) -> ::limbo_ext::ResultCode { + if vtab.is_null() { + return ::limbo_ext::ResultCode::Error; + } + + let vtab = &mut *(vtab as *mut #struct_name); + if <#struct_name as VTabModule>::destroy(vtab).is_err() { + return ::limbo_ext::ResultCode::Error; + } + + return ::limbo_ext::ResultCode::OK; + } + + #[no_mangle] + pub unsafe extern "C" fn #best_idx_fn_name( + constraints: *const ::limbo_ext::ConstraintInfo, + n_constraints: i32, + order_by: *const ::limbo_ext::OrderByInfo, + n_order_by: i32, + ) -> ::limbo_ext::ExtIndexInfo { + let constraints = if n_constraints > 0 { std::slice::from_raw_parts(constraints, n_constraints as usize) } else { &[] }; + let order_by = if n_order_by > 0 { std::slice::from_raw_parts(order_by, n_order_by as usize) } else { &[] }; + <#struct_name as ::limbo_ext::VTabModule>::best_index(constraints, order_by).to_ffi() + } + + #[no_mangle] + pub unsafe extern "C" fn #register_fn_name( + api: *const ::limbo_ext::ExtensionApi + ) -> ::limbo_ext::ResultCode { + if api.is_null() { + return ::limbo_ext::ResultCode::Error; + } + let api = &*api; + let name = <#struct_name as ::limbo_ext::VTabModule>::NAME; + let name_c = ::std::ffi::CString::new(name).unwrap().into_raw() as *const ::std::ffi::c_char; + let table_instance = ::std::boxed::Box::into_raw(::std::boxed::Box::new(#struct_name::default())); + let module = ::limbo_ext::VTabModuleImpl { + ctx: table_instance as *const ::std::ffi::c_void, + name: name_c, + create_schema: Self::#create_schema_fn_name, + open: Self::#open_fn_name, + close: Self::#close_fn_name, + filter: Self::#filter_fn_name, + column: Self::#column_fn_name, + next: Self::#next_fn_name, + eof: Self::#eof_fn_name, + update: Self::#update_fn_name, + rowid: Self::#rowid_fn_name, + destroy: Self::#destroy_fn_name, + best_idx: Self::#best_idx_fn_name, + }; + (api.register_vtab_module)(api.ctx, name_c, module, <#struct_name as ::limbo_ext::VTabModule>::VTAB_KIND) + } + } + }; + + TokenStream::from(expanded) +} diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 2d6694e10..e173d47ba 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -1,7 +1,4 @@ -mod args; -use args::{RegisterExtensionInput, ScalarInfo}; -use quote::{format_ident, quote}; -use syn::{parse_macro_input, DeriveInput, ItemFn}; +mod ext; extern crate proc_macro; use proc_macro::{token_stream::IntoIter, Group, TokenStream, TokenTree}; use std::collections::HashMap; @@ -138,6 +135,43 @@ fn generate_get_description( enum_impl.parse().unwrap() } +/// Register your extension with 'core' by providing the relevant functions +///```ignore +///use limbo_ext::{register_extension, scalar, Value, AggregateDerive, AggFunc}; +/// +/// register_extension!{ scalars: { return_one }, aggregates: { SumPlusOne } } +/// +///#[scalar(name = "one")] +///fn return_one(args: &[Value]) -> Value { +/// return Value::from_integer(1); +///} +/// +///#[derive(AggregateDerive)] +///struct SumPlusOne; +/// +///impl AggFunc for SumPlusOne { +/// type State = i64; +/// const NAME: &'static str = "sum_plus_one"; +/// const ARGS: i32 = 1; +/// +/// fn step(state: &mut Self::State, args: &[Value]) { +/// let Some(val) = args[0].to_integer() else { +/// return; +/// }; +/// *state += val; +/// } +/// +/// fn finalize(state: Self::State) -> Value { +/// Value::from_integer(state + 1) +/// } +///} +/// +/// ``` +#[proc_macro] +pub fn register_extension(input: TokenStream) -> TokenStream { + ext::register_extension(input) +} + /// Declare a scalar function for your extension. This requires the name: /// #[scalar(name = "example")] of what you wish to call your function with. /// ```ignore @@ -162,72 +196,7 @@ fn generate_get_description( /// ``` #[proc_macro_attribute] pub fn scalar(attr: TokenStream, input: TokenStream) -> TokenStream { - let ast = parse_macro_input!(input as ItemFn); - let fn_name = &ast.sig.ident; - let args_variable = &ast.sig.inputs.first(); - let mut args_variable_name = None; - if let Some(syn::FnArg::Typed(syn::PatType { pat, .. })) = args_variable { - if let syn::Pat::Ident(ident) = &**pat { - args_variable_name = Some(ident.ident.clone()); - } - } - let scalar_info = parse_macro_input!(attr as ScalarInfo); - let name = &scalar_info.name; - let register_fn_name = format_ident!("register_{}", fn_name); - let args_variable_name = - format_ident!("{}", args_variable_name.unwrap_or(format_ident!("args"))); - let fn_body = &ast.block; - let alias_check = if let Some(alias) = &scalar_info.alias { - quote! { - let Ok(alias_c_name) = ::std::ffi::CString::new(#alias) else { - return ::limbo_ext::ResultCode::Error; - }; - (api.register_scalar_function)( - api.ctx, - alias_c_name.as_ptr(), - #fn_name, - ); - } - } else { - quote! {} - }; - - let expanded = quote! { - #[no_mangle] - pub unsafe extern "C" fn #register_fn_name( - api: *const ::limbo_ext::ExtensionApi - ) -> ::limbo_ext::ResultCode { - if api.is_null() { - return ::limbo_ext::ResultCode::Error; - } - let api = unsafe { &*api }; - let Ok(c_name) = ::std::ffi::CString::new(#name) else { - return ::limbo_ext::ResultCode::Error; - }; - (api.register_scalar_function)( - api.ctx, - c_name.as_ptr(), - #fn_name, - ); - #alias_check - ::limbo_ext::ResultCode::OK - } - - #[no_mangle] - pub unsafe extern "C" fn #fn_name( - argc: i32, - argv: *const ::limbo_ext::Value - ) -> ::limbo_ext::Value { - let #args_variable_name = if argv.is_null() || argc <= 0 { - &[] - } else { - unsafe { std::slice::from_raw_parts(argv, argc as usize) } - }; - #fn_body - } - }; - - TokenStream::from(expanded) + ext::scalar(attr, input) } /// Define an aggregate function for your extension by deriving @@ -256,86 +225,7 @@ pub fn scalar(attr: TokenStream, input: TokenStream) -> TokenStream { /// ``` #[proc_macro_derive(AggregateDerive)] pub fn derive_agg_func(input: TokenStream) -> TokenStream { - let ast = parse_macro_input!(input as DeriveInput); - let struct_name = &ast.ident; - - let step_fn_name = format_ident!("{}_step", struct_name); - let finalize_fn_name = format_ident!("{}_finalize", struct_name); - let init_fn_name = format_ident!("{}_init", struct_name); - let register_fn_name = format_ident!("register_{}", struct_name); - - let expanded = quote! { - impl #struct_name { - #[no_mangle] - pub extern "C" fn #init_fn_name() -> *mut ::limbo_ext::AggCtx { - let state = ::std::boxed::Box::new(<#struct_name as ::limbo_ext::AggFunc>::State::default()); - let ctx = ::std::boxed::Box::new(::limbo_ext::AggCtx { - state: ::std::boxed::Box::into_raw(state) as *mut ::std::os::raw::c_void, - }); - ::std::boxed::Box::into_raw(ctx) - } - - #[no_mangle] - pub extern "C" fn #step_fn_name( - ctx: *mut ::limbo_ext::AggCtx, - argc: i32, - argv: *const ::limbo_ext::Value, - ) { - unsafe { - let ctx = &mut *ctx; - let state = &mut *(ctx.state as *mut <#struct_name as ::limbo_ext::AggFunc>::State); - let args = ::std::slice::from_raw_parts(argv, argc as usize); - <#struct_name as ::limbo_ext::AggFunc>::step(state, args); - } - } - - #[no_mangle] - pub extern "C" fn #finalize_fn_name( - ctx: *mut ::limbo_ext::AggCtx - ) -> ::limbo_ext::Value { - unsafe { - let ctx = &mut *ctx; - let state = ::std::boxed::Box::from_raw(ctx.state as *mut <#struct_name as ::limbo_ext::AggFunc>::State); - match <#struct_name as ::limbo_ext::AggFunc>::finalize(*state) { - Ok(val) => val, - Err(e) => { - ::limbo_ext::Value::error_with_message(e.to_string()) - } - } - } - } - - #[no_mangle] - pub unsafe extern "C" fn #register_fn_name( - api: *const ::limbo_ext::ExtensionApi - ) -> ::limbo_ext::ResultCode { - if api.is_null() { - return ::limbo_ext::ResultCode::Error; - } - - let api = &*api; - let name_str = #struct_name::NAME; - let c_name = match ::std::ffi::CString::new(name_str) { - Ok(cname) => cname, - Err(_) => return ::limbo_ext::ResultCode::Error, - }; - - (api.register_aggregate_function)( - api.ctx, - c_name.as_ptr(), - #struct_name::ARGS, - #struct_name::#init_fn_name - as ::limbo_ext::InitAggFunction, - #struct_name::#step_fn_name - as ::limbo_ext::StepFunction, - #struct_name::#finalize_fn_name - as ::limbo_ext::FinalizeFunction, - ) - } - } - }; - - TokenStream::from(expanded) + ext::derive_agg_func(input) } /// Macro to derive a VTabModule for your extension. This macro will generate @@ -442,597 +332,97 @@ pub fn derive_agg_func(input: TokenStream) -> TokenStream { /// #[proc_macro_derive(VTabModuleDerive)] pub fn derive_vtab_module(input: TokenStream) -> TokenStream { - let ast = parse_macro_input!(input as DeriveInput); - let struct_name = &ast.ident; - - let register_fn_name = format_ident!("register_{}", struct_name); - let create_schema_fn_name = format_ident!("create_schema_{}", struct_name); - let open_fn_name = format_ident!("open_{}", struct_name); - let close_fn_name = format_ident!("close_{}", struct_name); - let filter_fn_name = format_ident!("filter_{}", struct_name); - let column_fn_name = format_ident!("column_{}", struct_name); - let next_fn_name = format_ident!("next_{}", struct_name); - let eof_fn_name = format_ident!("eof_{}", struct_name); - let update_fn_name = format_ident!("update_{}", struct_name); - let rowid_fn_name = format_ident!("rowid_{}", struct_name); - let destroy_fn_name = format_ident!("destroy_{}", struct_name); - let best_idx_fn_name = format_ident!("best_idx_{}", struct_name); - - let expanded = quote! { - impl #struct_name { - #[no_mangle] - unsafe extern "C" fn #create_schema_fn_name( - argv: *const ::limbo_ext::Value, argc: i32 - ) -> *mut ::std::ffi::c_char { - let args = if argv.is_null() { - &Vec::new() - } else { - ::std::slice::from_raw_parts(argv, argc as usize) - }; - let sql = <#struct_name as ::limbo_ext::VTabModule>::create_schema(&args); - ::std::ffi::CString::new(sql).unwrap().into_raw() - } - - #[no_mangle] - unsafe extern "C" fn #open_fn_name(ctx: *const ::std::ffi::c_void) -> *const ::std::ffi::c_void { - if ctx.is_null() { - return ::std::ptr::null(); - } - let ctx = ctx as *const #struct_name; - let ctx: &#struct_name = &*ctx; - if let Ok(cursor) = <#struct_name as ::limbo_ext::VTabModule>::open(ctx) { - return ::std::boxed::Box::into_raw(::std::boxed::Box::new(cursor)) as *const ::std::ffi::c_void; - } else { - return ::std::ptr::null(); - } - } - - #[no_mangle] - unsafe extern "C" fn #close_fn_name( - cursor: *const ::std::ffi::c_void - ) -> ::limbo_ext::ResultCode { - if cursor.is_null() { - return ::limbo_ext::ResultCode::Error; - } - let boxed_cursor = ::std::boxed::Box::from_raw(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor); - boxed_cursor.close() - } - - #[no_mangle] - unsafe extern "C" fn #filter_fn_name( - cursor: *const ::std::ffi::c_void, - argc: i32, - argv: *const ::limbo_ext::Value, - idx_str: *const ::std::ffi::c_char, - idx_num: i32, - ) -> ::limbo_ext::ResultCode { - if cursor.is_null() { - return ::limbo_ext::ResultCode::Error; - } - let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; - let args = ::std::slice::from_raw_parts(argv, argc as usize); - let idx_str = if idx_str.is_null() { - None - } else { - Some((unsafe { ::std::ffi::CStr::from_ptr(idx_str).to_str().unwrap() }, idx_num)) - }; - <#struct_name as ::limbo_ext::VTabModule>::filter(cursor, args, idx_str) - } - - #[no_mangle] - unsafe extern "C" fn #column_fn_name( - cursor: *const ::std::ffi::c_void, - idx: u32, - ) -> ::limbo_ext::Value { - if cursor.is_null() { - return ::limbo_ext::Value::error(::limbo_ext::ResultCode::Error); - } - let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; - match <#struct_name as ::limbo_ext::VTabModule>::column(cursor, idx) { - Ok(val) => val, - Err(e) => ::limbo_ext::Value::error_with_message(e.to_string()) - } - } - - #[no_mangle] - unsafe extern "C" fn #next_fn_name( - cursor: *const ::std::ffi::c_void, - ) -> ::limbo_ext::ResultCode { - if cursor.is_null() { - return ::limbo_ext::ResultCode::Error; - } - let cursor = &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor); - <#struct_name as ::limbo_ext::VTabModule>::next(cursor) - } - - #[no_mangle] - unsafe extern "C" fn #eof_fn_name( - cursor: *const ::std::ffi::c_void, - ) -> bool { - if cursor.is_null() { - return true; - } - let cursor = &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor); - <#struct_name as ::limbo_ext::VTabModule>::eof(cursor) - } - - #[no_mangle] - unsafe extern "C" fn #update_fn_name( - vtab: *const ::std::ffi::c_void, - argc: i32, - argv: *const ::limbo_ext::Value, - p_out_rowid: *mut i64, - ) -> ::limbo_ext::ResultCode { - if vtab.is_null() { - return ::limbo_ext::ResultCode::Error; - } - - let vtab = &mut *(vtab as *mut #struct_name); - let args = ::std::slice::from_raw_parts(argv, argc as usize); - - let old_rowid = match args.get(0).map(|v| v.value_type()) { - Some(::limbo_ext::ValueType::Integer) => args.get(0).unwrap().to_integer(), - _ => None, - }; - let new_rowid = match args.get(1).map(|v| v.value_type()) { - Some(::limbo_ext::ValueType::Integer) => args.get(1).unwrap().to_integer(), - _ => None, - }; - let columns = &args[2..]; - match (old_rowid, new_rowid) { - // DELETE: old_rowid provided, no new_rowid - (Some(old), None) => { - if <#struct_name as VTabModule>::delete(vtab, old).is_err() { - return ::limbo_ext::ResultCode::Error; - } - return ::limbo_ext::ResultCode::OK; - } - // UPDATE: old_rowid provided and new_rowid may exist - (Some(old), Some(new)) => { - if <#struct_name as VTabModule>::update(vtab, old, &columns).is_err() { - return ::limbo_ext::ResultCode::Error; - } - return ::limbo_ext::ResultCode::OK; - } - // INSERT: no old_rowid (old_rowid = None) - (None, _) => { - if let Ok(rowid) = <#struct_name as VTabModule>::insert(vtab, &columns) { - if !p_out_rowid.is_null() { - *p_out_rowid = rowid; - return ::limbo_ext::ResultCode::RowID; - } - return ::limbo_ext::ResultCode::OK; - } - } - } - return ::limbo_ext::ResultCode::Error; - } - - #[no_mangle] - pub unsafe extern "C" fn #rowid_fn_name(ctx: *const ::std::ffi::c_void) -> i64 { - if ctx.is_null() { - return -1; - } - let cursor = &*(ctx as *const <#struct_name as ::limbo_ext::VTabModule>::VCursor); - <<#struct_name as ::limbo_ext::VTabModule>::VCursor as ::limbo_ext::VTabCursor>::rowid(cursor) - } - - #[no_mangle] - unsafe extern "C" fn #destroy_fn_name( - vtab: *const ::std::ffi::c_void, - ) -> ::limbo_ext::ResultCode { - if vtab.is_null() { - return ::limbo_ext::ResultCode::Error; - } - - let vtab = &mut *(vtab as *mut #struct_name); - if <#struct_name as VTabModule>::destroy(vtab).is_err() { - return ::limbo_ext::ResultCode::Error; - } - - return ::limbo_ext::ResultCode::OK; - } - - #[no_mangle] - pub unsafe extern "C" fn #best_idx_fn_name( - constraints: *const ::limbo_ext::ConstraintInfo, - n_constraints: i32, - order_by: *const ::limbo_ext::OrderByInfo, - n_order_by: i32, - ) -> ::limbo_ext::ExtIndexInfo { - let constraints = if n_constraints > 0 { std::slice::from_raw_parts(constraints, n_constraints as usize) } else { &[] }; - let order_by = if n_order_by > 0 { std::slice::from_raw_parts(order_by, n_order_by as usize) } else { &[] }; - <#struct_name as ::limbo_ext::VTabModule>::best_index(constraints, order_by).to_ffi() - } - - #[no_mangle] - pub unsafe extern "C" fn #register_fn_name( - api: *const ::limbo_ext::ExtensionApi - ) -> ::limbo_ext::ResultCode { - if api.is_null() { - return ::limbo_ext::ResultCode::Error; - } - let api = &*api; - let name = <#struct_name as ::limbo_ext::VTabModule>::NAME; - let name_c = ::std::ffi::CString::new(name).unwrap().into_raw() as *const ::std::ffi::c_char; - let table_instance = ::std::boxed::Box::into_raw(::std::boxed::Box::new(#struct_name::default())); - let module = ::limbo_ext::VTabModuleImpl { - ctx: table_instance as *const ::std::ffi::c_void, - name: name_c, - create_schema: Self::#create_schema_fn_name, - open: Self::#open_fn_name, - close: Self::#close_fn_name, - filter: Self::#filter_fn_name, - column: Self::#column_fn_name, - next: Self::#next_fn_name, - eof: Self::#eof_fn_name, - update: Self::#update_fn_name, - rowid: Self::#rowid_fn_name, - destroy: Self::#destroy_fn_name, - best_idx: Self::#best_idx_fn_name, - }; - (api.register_vtab_module)(api.ctx, name_c, module, <#struct_name as ::limbo_ext::VTabModule>::VTAB_KIND) - } - } - }; - - TokenStream::from(expanded) + ext::derive_vtab_module(input) } +/// ```ignore +/// use limbo_ext::{ExtResult as Result, VfsDerive, VfsExtension, VfsFile}; +/// +/// // Your struct must also impl Default +/// #[derive(VfsDerive, Default)] +/// struct ExampleFS; +/// +/// +/// struct ExampleFile { +/// file: std::fs::File, +/// +/// +/// impl VfsExtension for ExampleFS { +/// /// The name of your vfs module +/// const NAME: &'static str = "example"; +/// +/// type File = ExampleFile; +/// +/// fn open(&self, path: &str, flags: i32, _direct: bool) -> Result { +/// let file = OpenOptions::new() +/// .read(true) +/// .write(true) +/// .create(flags & 1 != 0) +/// .open(path) +/// .map_err(|_| ResultCode::Error)?; +/// Ok(TestFile { file }) +/// } +/// +/// fn run_once(&self) -> Result<()> { +/// // (optional) method to cycle/advance IO, if your extension is asynchronous +/// Ok(()) +/// } +/// +/// fn close(&self, file: Self::File) -> Result<()> { +/// // (optional) method to close or drop the file +/// Ok(()) +/// } +/// +/// fn generate_random_number(&self) -> i64 { +/// // (optional) method to generate random number. Used for testing +/// let mut buf = [0u8; 8]; +/// getrandom::fill(&mut buf).unwrap(); +/// i64::from_ne_bytes(buf) +/// } +/// +/// fn get_current_time(&self) -> String { +/// // (optional) method to generate random number. Used for testing +/// chrono::Local::now().format("%Y-%m-%d %H:%M:%S").to_string() +/// } +/// +/// +/// impl VfsFile for ExampleFile { +/// fn read( +/// &mut self, +/// buf: &mut [u8], +/// count: usize, +/// offset: i64, +/// ) -> Result { +/// if file.file.seek(SeekFrom::Start(offset as u64)).is_err() { +/// return Err(ResultCode::Error); +/// } +/// file.file +/// .read(&mut buf[..count]) +/// .map_err(|_| ResultCode::Error) +/// .map(|n| n as i32) +/// } +/// +/// fn write(&mut self, buf: &[u8], count: usize, offset: i64) -> Result { +/// if self.file.seek(SeekFrom::Start(offset as u64)).is_err() { +/// return Err(ResultCode::Error); +/// } +/// self.file +/// .write(&buf[..count]) +/// .map_err(|_| ResultCode::Error) +/// .map(|n| n as i32) +/// } +/// +/// fn sync(&self) -> Result<()> { +/// self.file.sync_all().map_err(|_| ResultCode::Error) +/// } +/// +/// fn size(&self) -> i64 { +/// self.file.metadata().map(|m| m.len() as i64).unwrap_or(-1) +/// } +///} +/// +///``` #[proc_macro_derive(VfsDerive)] pub fn derive_vfs_module(input: TokenStream) -> TokenStream { - let derive_input = parse_macro_input!(input as DeriveInput); - let struct_name = &derive_input.ident; - let register_fn_name = format_ident!("register_{}", struct_name); - let register_static = format_ident!("register_static_{}", struct_name); - let open_fn_name = format_ident!("{}_open", struct_name); - let close_fn_name = format_ident!("{}_close", struct_name); - let read_fn_name = format_ident!("{}_read", struct_name); - let write_fn_name = format_ident!("{}_write", struct_name); - let lock_fn_name = format_ident!("{}_lock", struct_name); - let unlock_fn_name = format_ident!("{}_unlock", struct_name); - let sync_fn_name = format_ident!("{}_sync", struct_name); - let size_fn_name = format_ident!("{}_size", struct_name); - let run_once_fn_name = format_ident!("{}_run_once", struct_name); - let generate_random_number_fn_name = format_ident!("{}_generate_random_number", struct_name); - let get_current_time_fn_name = format_ident!("{}_get_current_time", struct_name); - - let expanded = quote! { - #[allow(non_snake_case)] - pub unsafe extern "C" fn #register_static() -> *const ::limbo_ext::VfsImpl { - let ctx = #struct_name::default(); - let ctx = ::std::boxed::Box::into_raw(::std::boxed::Box::new(ctx)) as *const ::std::ffi::c_void; - let name = ::std::ffi::CString::new(<#struct_name as ::limbo_ext::VfsExtension>::NAME).unwrap().into_raw(); - let vfs_mod = ::limbo_ext::VfsImpl { - vfs: ctx, - name, - open: #open_fn_name, - close: #close_fn_name, - read: #read_fn_name, - write: #write_fn_name, - lock: #lock_fn_name, - unlock: #unlock_fn_name, - sync: #sync_fn_name, - size: #size_fn_name, - run_once: #run_once_fn_name, - gen_random_number: #generate_random_number_fn_name, - current_time: #get_current_time_fn_name, - }; - ::std::boxed::Box::into_raw(::std::boxed::Box::new(vfs_mod)) as *const ::limbo_ext::VfsImpl - } - - #[no_mangle] - pub unsafe extern "C" fn #register_fn_name(api: &::limbo_ext::ExtensionApi) -> ::limbo_ext::ResultCode { - let ctx = #struct_name::default(); - let ctx = ::std::boxed::Box::into_raw(::std::boxed::Box::new(ctx)) as *const ::std::ffi::c_void; - let name = ::std::ffi::CString::new(<#struct_name as ::limbo_ext::VfsExtension>::NAME).unwrap().into_raw(); - let vfs_mod = ::limbo_ext::VfsImpl { - vfs: ctx, - name, - open: #open_fn_name, - close: #close_fn_name, - read: #read_fn_name, - write: #write_fn_name, - lock: #lock_fn_name, - unlock: #unlock_fn_name, - sync: #sync_fn_name, - size: #size_fn_name, - run_once: #run_once_fn_name, - gen_random_number: #generate_random_number_fn_name, - current_time: #get_current_time_fn_name, - }; - let vfsimpl = ::std::boxed::Box::into_raw(::std::boxed::Box::new(vfs_mod)) as *const ::limbo_ext::VfsImpl; - (api.vfs_interface.register_vfs)(name, vfsimpl) - } - - #[no_mangle] - pub unsafe extern "C" fn #open_fn_name( - ctx: *const ::std::ffi::c_void, - path: *const ::std::ffi::c_char, - flags: i32, - direct: bool, - ) -> *const ::std::ffi::c_void { - let ctx = &*(ctx as *const ::limbo_ext::VfsImpl); - let Ok(path_str) = ::std::ffi::CStr::from_ptr(path).to_str() else { - return ::std::ptr::null_mut(); - }; - let vfs = &*(ctx.vfs as *const #struct_name); - let Ok(file_handle) = <#struct_name as ::limbo_ext::VfsExtension>::open_file(vfs, path_str, flags, direct) else { - return ::std::ptr::null(); - }; - let boxed = ::std::boxed::Box::into_raw(::std::boxed::Box::new(file_handle)) as *const ::std::ffi::c_void; - let Ok(vfs_file) = ::limbo_ext::VfsFileImpl::new(boxed, ctx) else { - return ::std::ptr::null(); - }; - ::std::boxed::Box::into_raw(::std::boxed::Box::new(vfs_file)) as *const ::std::ffi::c_void - } - - #[no_mangle] - pub unsafe extern "C" fn #close_fn_name(file_ptr: *const ::std::ffi::c_void) -> ::limbo_ext::ResultCode { - if file_ptr.is_null() { - return ::limbo_ext::ResultCode::Error; - } - let vfs_file: &mut ::limbo_ext::VfsFileImpl = &mut *(file_ptr as *mut ::limbo_ext::VfsFileImpl); - let vfs_instance = &*(vfs_file.vfs as *const #struct_name); - - // this time we need to own it so we can drop it - let file: ::std::boxed::Box<<#struct_name as ::limbo_ext::VfsExtension>::File> = - ::std::boxed::Box::from_raw(vfs_file.file as *mut <#struct_name as ::limbo_ext::VfsExtension>::File); - if let Err(e) = <#struct_name as ::limbo_ext::VfsExtension>::close(vfs_instance, *file) { - return e; - } - ::limbo_ext::ResultCode::OK - } - - #[no_mangle] - pub unsafe extern "C" fn #read_fn_name(file_ptr: *const ::std::ffi::c_void, buf: *mut u8, count: usize, offset: i64) -> i32 { - if file_ptr.is_null() { - return -1; - } - let vfs_file: &mut ::limbo_ext::VfsFileImpl = &mut *(file_ptr as *mut ::limbo_ext::VfsFileImpl); - let file: &mut <#struct_name as ::limbo_ext::VfsExtension>::File = - &mut *(vfs_file.file as *mut <#struct_name as ::limbo_ext::VfsExtension>::File); - match <#struct_name as ::limbo_ext::VfsExtension>::File::read(file, ::std::slice::from_raw_parts_mut(buf, count), count, offset) { - Ok(n) => n, - Err(_) => -1, - } - } - - #[no_mangle] - pub unsafe extern "C" fn #run_once_fn_name(ctx: *const ::std::ffi::c_void) -> ::limbo_ext::ResultCode { - if ctx.is_null() { - return ::limbo_ext::ResultCode::Error; - } - let ctx = &mut *(ctx as *mut #struct_name); - if let Err(e) = <#struct_name as ::limbo_ext::VfsExtension>::run_once(ctx) { - return e; - } - ::limbo_ext::ResultCode::OK - } - - #[no_mangle] - pub unsafe extern "C" fn #write_fn_name(file_ptr: *const ::std::ffi::c_void, buf: *const u8, count: usize, offset: i64) -> i32 { - if file_ptr.is_null() { - return -1; - } - let vfs_file: &mut ::limbo_ext::VfsFileImpl = &mut *(file_ptr as *mut ::limbo_ext::VfsFileImpl); - let file: &mut <#struct_name as ::limbo_ext::VfsExtension>::File = - &mut *(vfs_file.file as *mut <#struct_name as ::limbo_ext::VfsExtension>::File); - match <#struct_name as ::limbo_ext::VfsExtension>::File::write(file, ::std::slice::from_raw_parts(buf, count), count, offset) { - Ok(n) => n, - Err(_) => -1, - } - } - - #[no_mangle] - pub unsafe extern "C" fn #lock_fn_name(file_ptr: *const ::std::ffi::c_void, exclusive: bool) -> ::limbo_ext::ResultCode { - if file_ptr.is_null() { - return ::limbo_ext::ResultCode::Error; - } - let vfs_file: &mut ::limbo_ext::VfsFileImpl = &mut *(file_ptr as *mut ::limbo_ext::VfsFileImpl); - let file: &mut <#struct_name as ::limbo_ext::VfsExtension>::File = - &mut *(vfs_file.file as *mut <#struct_name as ::limbo_ext::VfsExtension>::File); - if let Err(e) = <#struct_name as ::limbo_ext::VfsExtension>::File::lock(file, exclusive) { - return e; - } - ::limbo_ext::ResultCode::OK - } - - #[no_mangle] - pub unsafe extern "C" fn #unlock_fn_name(file_ptr: *const ::std::ffi::c_void) -> ::limbo_ext::ResultCode { - if file_ptr.is_null() { - return ::limbo_ext::ResultCode::Error; - } - let vfs_file: &mut ::limbo_ext::VfsFileImpl = &mut *(file_ptr as *mut ::limbo_ext::VfsFileImpl); - let file: &mut <#struct_name as ::limbo_ext::VfsExtension>::File = - &mut *(vfs_file.file as *mut <#struct_name as ::limbo_ext::VfsExtension>::File); - if let Err(e) = <#struct_name as ::limbo_ext::VfsExtension>::File::unlock(file) { - return e; - } - ::limbo_ext::ResultCode::OK - } - - #[no_mangle] - pub unsafe extern "C" fn #sync_fn_name(file_ptr: *const ::std::ffi::c_void) -> i32 { - if file_ptr.is_null() { - return -1; - } - let vfs_file: &mut ::limbo_ext::VfsFileImpl = &mut *(file_ptr as *mut ::limbo_ext::VfsFileImpl); - let file: &mut <#struct_name as ::limbo_ext::VfsExtension>::File = - &mut *(vfs_file.file as *mut <#struct_name as ::limbo_ext::VfsExtension>::File); - if <#struct_name as ::limbo_ext::VfsExtension>::File::sync(file).is_err() { - return -1; - } - 0 - } - - #[no_mangle] - pub unsafe extern "C" fn #size_fn_name(file_ptr: *const ::std::ffi::c_void) -> i64 { - if file_ptr.is_null() { - return -1; - } - let vfs_file: &mut ::limbo_ext::VfsFileImpl = &mut *(file_ptr as *mut ::limbo_ext::VfsFileImpl); - let file: &mut <#struct_name as ::limbo_ext::VfsExtension>::File = - &mut *(vfs_file.file as *mut <#struct_name as ::limbo_ext::VfsExtension>::File); - <#struct_name as ::limbo_ext::VfsExtension>::File::size(file) - } - - #[no_mangle] - pub unsafe extern "C" fn #generate_random_number_fn_name() -> i64 { - let obj = #struct_name::default(); - <#struct_name as ::limbo_ext::VfsExtension>::generate_random_number(&obj) - } - - #[no_mangle] - pub unsafe extern "C" fn #get_current_time_fn_name() -> *const ::std::ffi::c_char { - let obj = #struct_name::default(); - let time = <#struct_name as ::limbo_ext::VfsExtension>::get_current_time(&obj); - // release ownership of the string to core - ::std::ffi::CString::new(time).unwrap().into_raw() as *const ::std::ffi::c_char - } - }; - - TokenStream::from(expanded) -} - -/// Register your extension with 'core' by providing the relevant functions -///```ignore -///use limbo_ext::{register_extension, scalar, Value, AggregateDerive, AggFunc}; -/// -/// register_extension!{ scalars: { return_one }, aggregates: { SumPlusOne } } -/// -///#[scalar(name = "one")] -///fn return_one(args: &[Value]) -> Value { -/// return Value::from_integer(1); -///} -/// -///#[derive(AggregateDerive)] -///struct SumPlusOne; -/// -///impl AggFunc for SumPlusOne { -/// type State = i64; -/// const NAME: &'static str = "sum_plus_one"; -/// const ARGS: i32 = 1; -/// -/// fn step(state: &mut Self::State, args: &[Value]) { -/// let Some(val) = args[0].to_integer() else { -/// return; -/// }; -/// *state += val; -/// } -/// -/// fn finalize(state: Self::State) -> Value { -/// Value::from_integer(state + 1) -/// } -///} -/// -/// ``` -#[proc_macro] -pub fn register_extension(input: TokenStream) -> TokenStream { - let input_ast = parse_macro_input!(input as RegisterExtensionInput); - let RegisterExtensionInput { - aggregates, - scalars, - vtabs, - vfs, - } = input_ast; - - let scalar_calls = scalars.iter().map(|scalar_ident| { - let register_fn = - syn::Ident::new(&format!("register_{}", scalar_ident), scalar_ident.span()); - quote! { - { - let result = unsafe { #register_fn(api)}; - if !result.is_ok() { - return result; - } - } - } - }); - - let aggregate_calls = aggregates.iter().map(|agg_ident| { - let register_fn = syn::Ident::new(&format!("register_{}", agg_ident), agg_ident.span()); - quote! { - { - let result = unsafe{ #agg_ident::#register_fn(api)}; - if !result.is_ok() { - return result; - } - } - } - }); - let vtab_calls = vtabs.iter().map(|vtab_ident| { - let register_fn = syn::Ident::new(&format!("register_{}", vtab_ident), vtab_ident.span()); - quote! { - { - let result = unsafe{ #vtab_ident::#register_fn(api)}; - if !result.is_ok() { - return result; - } - } - } - }); - let vfs_calls = vfs.iter().map(|vfs_ident| { - let register_fn = syn::Ident::new(&format!("register_{}", vfs_ident), vfs_ident.span()); - quote! { - { - let result = unsafe { #register_fn(api) }; - if !result.is_ok() { - return result; - } - } - } - }); - let static_vfs = vfs.iter().map(|vfs_ident| { - let static_register = - syn::Ident::new(&format!("register_static_{}", vfs_ident), vfs_ident.span()); - quote! { - { - let result = api.add_builtin_vfs(unsafe { #static_register()}); - if !result.is_ok() { - return result; - } - } - } - }); - let static_aggregates = aggregate_calls.clone(); - let static_scalars = scalar_calls.clone(); - let static_vtabs = vtab_calls.clone(); - - let expanded = quote! { - #[cfg(not(target_family = "wasm"))] - #[cfg(not(feature = "static"))] - #[global_allocator] - static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; - - #[cfg(feature = "static")] - pub unsafe extern "C" fn register_extension_static(api: &mut ::limbo_ext::ExtensionApi) -> ::limbo_ext::ResultCode { - #(#static_scalars)* - - #(#static_aggregates)* - - #(#static_vtabs)* - - #[cfg(not(target_family = "wasm"))] - #(#static_vfs)* - - ::limbo_ext::ResultCode::OK - } - - #[cfg(not(feature = "static"))] - #[no_mangle] - pub unsafe extern "C" fn register_extension(api: &::limbo_ext::ExtensionApi) -> ::limbo_ext::ResultCode { - #(#scalar_calls)* - - #(#aggregate_calls)* - - #(#vtab_calls)* - - #(#vfs_calls)* - - ::limbo_ext::ResultCode::OK - } - }; - - TokenStream::from(expanded) + ext::derive_vfs_module(input) }