From 343ccb3f72e409575fdec540349e206901f525e4 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Tue, 14 Jan 2025 11:49:34 -0500 Subject: [PATCH] Replace declare_scalar_functions in extension API with proc macro --- Cargo.lock | 46 +++++---- extensions/uuid/src/lib.rs | 196 +++++++++++++++++++------------------ limbo_extension/Cargo.toml | 1 + limbo_extension/src/lib.rs | 39 +------- macros/Cargo.toml | 5 + macros/src/args.rs | 63 ++++++++++++ macros/src/lib.rs | 122 +++++++++++++++++++++++ 7 files changed, 319 insertions(+), 153 deletions(-) create mode 100644 macros/src/args.rs diff --git a/Cargo.lock b/Cargo.lock index 7de3e3500..9cbce1207 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -331,7 +331,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -665,7 +665,7 @@ checksum = "3bf679796c0322556351f287a51b49e48f7c4986e727b5dd78c972d30e2e16cc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -809,7 +809,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -1251,6 +1251,7 @@ dependencies = [ name = "limbo_extension" version = "0.0.12" dependencies = [ + "limbo_macros", "log", ] @@ -1266,6 +1267,11 @@ dependencies = [ [[package]] name = "limbo_macros" version = "0.0.12" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.96", +] [[package]] name = "limbo_sim" @@ -1374,7 +1380,7 @@ checksum = "23c9b935fbe1d6cbd1dac857b54a688145e2d93f48db36010514d0f612d0ad67" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -1435,7 +1441,7 @@ dependencies = [ "cfg-if", "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -1595,7 +1601,7 @@ dependencies = [ "pest_meta", "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -1841,7 +1847,7 @@ dependencies = [ "proc-macro2", "pyo3-macros-backend", 
"quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -1854,7 +1860,7 @@ dependencies = [ "proc-macro2", "pyo3-build-config", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -1868,9 +1874,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.37" +version = "1.0.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" +checksum = "0e4dccaaaf89514f546c693ddc140f729f958c247918a13380cccc6078391acc" dependencies = [ "proc-macro2", ] @@ -2024,7 +2030,7 @@ dependencies = [ "regex", "relative-path", "rustc_version", - "syn 2.0.90", + "syn 2.0.96", "unicode-ident", ] @@ -2137,7 +2143,7 @@ checksum = "46f859dbbf73865c6627ed570e78961cd3ac92407a2d117204c49232485da55e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -2315,9 +2321,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.90" +version = "2.0.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "919d3b74a5dd0ccd15aeb8f93e7006bd9e14c295087c9896a110f490752bcf31" +checksum = "d5d0adab1ae378d7f53bdebc67a39f1f151407ef230f0ce2883572f5d8985c80" dependencies = [ "proc-macro2", "quote", @@ -2404,7 +2410,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -2415,7 +2421,7 @@ checksum = "7b50fa271071aae2e6ee85f842e2e28ba8cd2c5fb67f11fcb1fd70b276f9e7d4" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -2454,7 +2460,7 @@ checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -2588,7 +2594,7 @@ dependencies = [ "log", "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", "wasm-bindgen-shared", ] @@ -2623,7 +2629,7 @@ checksum = 
"30d7a95b763d3c45903ed6c81f156801839e5ee968bb07e534c44df0fcd330c2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -2916,5 +2922,5 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] diff --git a/extensions/uuid/src/lib.rs b/extensions/uuid/src/lib.rs index e6a3a4f9b..f8d8f3816 100644 --- a/extensions/uuid/src/lib.rs +++ b/extensions/uuid/src/lib.rs @@ -1,5 +1,5 @@ use limbo_extension::{ - declare_scalar_functions, register_extension, register_scalar_functions, Value, ValueType, + export_scalar, register_extension, register_scalar_functions, Value, ValueType, }; register_extension! { @@ -11,31 +11,34 @@ register_extension! { "uuid_str" => uuid_str, "uuid_blob" => uuid_blob, "uuid7_timestamp_ms" => exec_ts_from_uuid7, + "gen_random_uuid" => uuid4_str, }, } -declare_scalar_functions! { - #[args(0)] - fn uuid4_str(_args: &[Value]) -> Value { - let uuid = uuid::Uuid::new_v4().to_string(); - Value::from_text(uuid) - } +#[export_scalar] +#[args(0)] +fn uuid4_str(_args: &[Value]) -> Value { + let uuid = uuid::Uuid::new_v4().to_string(); + Value::from_text(uuid) +} - #[args(0)] - fn uuid4_blob(_args: &[Value]) -> Value { - let uuid = uuid::Uuid::new_v4(); - let bytes = uuid.as_bytes(); - Value::from_blob(bytes.to_vec()) - } +#[export_scalar] +#[args(0)] +fn uuid4_blob(_args: &[Value]) -> Value { + let uuid = uuid::Uuid::new_v4(); + let bytes = uuid.as_bytes(); + Value::from_blob(bytes.to_vec()) +} - #[args(0..=1)] - fn uuid7_str(args: &[Value]) -> Value { - let timestamp = if args.is_empty() { - let ctx = uuid::ContextV7::new(); - uuid::Timestamp::now(ctx) - } else { - let arg = &args[0]; - match arg.value_type() { +#[export_scalar] +#[args(0..=1)] +fn uuid7_str(args: &[Value]) -> Value { + let timestamp = if args.is_empty() { + let ctx = uuid::ContextV7::new(); + uuid::Timestamp::now(ctx) + } 
else { + let arg = &args[0]; + match arg.value_type() { ValueType::Integer => { let ctx = uuid::ContextV7::new(); let Some(int) = arg.to_integer() else { @@ -43,94 +46,97 @@ declare_scalar_functions! { }; uuid::Timestamp::from_unix(ctx, int as u64, 0) } - ValueType::Text => { - let Some(text) = arg.to_text() else { - return Value::null(); - }; + ValueType::Text => { + let Some(text) = arg.to_text() else { + return Value::null(); + }; match text.parse::<i64>() { Ok(unix) => { - if unix <= 0 { - return Value::null(); - } + if unix <= 0 { + return Value::null(); + } uuid::Timestamp::from_unix(uuid::ContextV7::new(), unix as u64, 0) } Err(_) => return Value::null(), } } _ => return Value::null(), - } - }; - let uuid = uuid::Uuid::new_v7(timestamp); - Value::from_text(uuid.to_string()) - } + } + }; + let uuid = uuid::Uuid::new_v7(timestamp); + Value::from_text(uuid.to_string()) +} - #[args(0..=1)] - fn uuid7_blob(args: &[Value]) -> Value { - let timestamp = if args.is_empty() { - let ctx = uuid::ContextV7::new(); - uuid::Timestamp::now(ctx) - } else if args[0].value_type() == limbo_extension::ValueType::Integer { - let ctx = uuid::ContextV7::new(); - let Some(int) = args[0].to_integer() else { - return Value::null(); - }; - uuid::Timestamp::from_unix(ctx, int as u64, 0) - } else { +#[export_scalar] +#[args(0..=1)] +fn uuid7_blob(args: &[Value]) -> Value { + let timestamp = if args.is_empty() { + let ctx = uuid::ContextV7::new(); + uuid::Timestamp::now(ctx) + } else if args[0].value_type() == limbo_extension::ValueType::Integer { + let ctx = uuid::ContextV7::new(); + let Some(int) = args[0].to_integer() else { + return Value::null(); + }; + uuid::Timestamp::from_unix(ctx, int as u64, 0) + } else { + return Value::null(); + }; + let uuid = uuid::Uuid::new_v7(timestamp); + let bytes = uuid.as_bytes(); + Value::from_blob(bytes.to_vec()) +} + +#[export_scalar] +#[args(1)] +fn exec_ts_from_uuid7(args: &[Value]) -> Value { + match args[0].value_type() { + ValueType::Blob => { + let 
Some(blob) = &args[0].to_blob() else { return Value::null(); - }; - let uuid = uuid::Uuid::new_v7(timestamp); - let bytes = uuid.as_bytes(); - Value::from_blob(bytes.to_vec()) - } - - #[args(1)] - fn exec_ts_from_uuid7(args: &[Value]) -> Value { - match args[0].value_type() { - ValueType::Blob => { - let Some(blob) = &args[0].to_blob() else { - return Value::null(); - }; - let uuid = uuid::Uuid::from_slice(blob.as_slice()).unwrap(); - let unix = uuid_to_unix(uuid.as_bytes()); - Value::from_integer(unix as i64) - } - ValueType::Text => { - let Some(text) = args[0].to_text() else { - return Value::null(); - }; - let Ok(uuid) = uuid::Uuid::parse_str(&text) else { - return Value::null(); - }; - let unix = uuid_to_unix(uuid.as_bytes()); - Value::from_integer(unix as i64) - } - _ => Value::null(), + }; + let uuid = uuid::Uuid::from_slice(blob.as_slice()).unwrap(); + let unix = uuid_to_unix(uuid.as_bytes()); + Value::from_integer(unix as i64) } - } - - #[args(1)] - fn uuid_str(args: &[Value]) -> Value { - let Some(blob) = args[0].to_blob() else { - return Value::null(); - }; - let parsed = uuid::Uuid::from_slice(blob.as_slice()).ok().map(|u| u.to_string()); - match parsed { - Some(s) => Value::from_text(s), - None => Value::null() + ValueType::Text => { + let Some(text) = args[0].to_text() else { + return Value::null(); + }; + let Ok(uuid) = uuid::Uuid::parse_str(&text) else { + return Value::null(); + }; + let unix = uuid_to_unix(uuid.as_bytes()); + Value::from_integer(unix as i64) } + _ => Value::null(), } +} - #[args(1)] - fn uuid_blob(args: &[Value]) -> Value { - let Some(text) = args[0].to_text() else { - return Value::null(); - }; - match uuid::Uuid::parse_str(&text) { - Ok(uuid) => { - Value::from_blob(uuid.as_bytes().to_vec()) - } - Err(_) => Value::null() - } +#[export_scalar] +#[args(1)] +fn uuid_str(args: &[Value]) -> Value { + let Some(blob) = args[0].to_blob() else { + return Value::null(); + }; + let parsed = uuid::Uuid::from_slice(blob.as_slice()) + .ok() + 
.map(|u| u.to_string()); + match parsed { + Some(s) => Value::from_text(s), + None => Value::null(), + } +} + +#[export_scalar] +#[args(1)] +fn uuid_blob(args: &[Value]) -> Value { + let Some(text) = args[0].to_text() else { + return Value::null(); + }; + match uuid::Uuid::parse_str(&text) { + Ok(uuid) => Value::from_blob(uuid.as_bytes().to_vec()), + Err(_) => Value::null(), } } diff --git a/limbo_extension/Cargo.toml b/limbo_extension/Cargo.toml index 2928ed853..94c0229e5 100644 --- a/limbo_extension/Cargo.toml +++ b/limbo_extension/Cargo.toml @@ -8,3 +8,4 @@ repository.workspace = true [dependencies] log = "0.4.20" +limbo_macros = { path = "../macros" } diff --git a/limbo_extension/src/lib.rs b/limbo_extension/src/lib.rs index 0666c588b..ab598cc09 100644 --- a/limbo_extension/src/lib.rs +++ b/limbo_extension/src/lib.rs @@ -1,7 +1,6 @@ use std::os::raw::{c_char, c_void}; - pub type ResultCode = i32; - +pub use limbo_macros::export_scalar; pub const RESULT_OK: ResultCode = 0; pub const RESULT_ERROR: ResultCode = 1; // TODO: more error types @@ -50,42 +49,6 @@ macro_rules! register_scalar_functions { } } -#[macro_export] -macro_rules! 
declare_scalar_functions { - ( - $( - #[args($($args_count:tt)+)] - fn $func_name:ident ($args:ident : &[Value]) -> Value $body:block - )* - ) => { - $( - extern "C" fn $func_name( - argc: i32, - argv: *const $crate::Value - ) -> $crate::Value { - let valid_args = { - match argc { - $($args_count)+ => true, - _ => false, - } - }; - if !valid_args { - return $crate::Value::null(); - } - if argc == 0 || argv.is_null() { - log::debug!("{} was called with no arguments", stringify!($func_name)); - let $args: &[$crate::Value] = &[]; - $body - } else { - let ptr_slice = unsafe{ std::slice::from_raw_parts(argv, argc as usize)}; - let $args: &[$crate::Value] = ptr_slice; - $body - } - } - )* - }; -} - #[repr(C)] #[derive(PartialEq, Eq, Clone, Copy)] pub enum ValueType { diff --git a/macros/Cargo.toml b/macros/Cargo.toml index 78a3805c6..fb41bc18b 100644 --- a/macros/Cargo.toml +++ b/macros/Cargo.toml @@ -11,3 +11,8 @@ description = "The Limbo database library" [lib] proc-macro = true + +[dependencies] +quote = "1.0.38" +proc-macro2 = "1.0.38" +syn = { version = "2.0.96", features = ["full"]} diff --git a/macros/src/args.rs b/macros/src/args.rs new file mode 100644 index 000000000..d0988b5e9 --- /dev/null +++ b/macros/src/args.rs @@ -0,0 +1,63 @@ +use syn::parse::{Parse, ParseStream, Result as ParseResult}; +use syn::{LitInt, Token}; +#[derive(Debug)] +pub enum ArgsSpec { + Exact(i32), + Range { + lower: i32, + upper: i32, + inclusive: bool, + }, +} + +pub struct ArgsAttr { + pub spec: ArgsSpec, +} + +impl Parse for ArgsAttr { + fn parse(input: ParseStream) -> ParseResult<Self> { + if input.peek(LitInt) { + let start_lit = input.parse::<LitInt>()?; + let start_val = start_lit.base10_parse::<i32>()?; + + if input.is_empty() { + return Ok(ArgsAttr { + spec: ArgsSpec::Exact(start_val), + }); + } + if input.peek(Token![..=]) { + let _dots = input.parse::<Token![..=]>()?; + let end_lit = input.parse::<LitInt>()?; + let end_val = end_lit.base10_parse::<i32>()?; + Ok(ArgsAttr { + spec: ArgsSpec::Range { + lower: start_val, 
upper: end_val, + inclusive: true, + }, + }) + } else if input.peek(Token![..]) { + let _dots = input.parse::<Token![..]>()?; + let end_lit = input.parse::<LitInt>()?; + let end_val = end_lit.base10_parse::<i32>()?; + Ok(ArgsAttr { + spec: ArgsSpec::Range { + lower: start_val, + upper: end_val, + inclusive: false, + }, + }) + } else { + Err(syn::Error::new_spanned( + start_lit, + "Expected '..' or '..=' for a range, or nothing for a single integer.", + )) + } + } else { + Err(syn::Error::new( + input.span(), + "Expected an integer or a range expression, like `0`, `0..2`, or `0..=2`.", + )) + } + } +} diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 1cbf31e7a..5b21e2a90 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -1,3 +1,5 @@ +mod args; +use args::{ArgsAttr, ArgsSpec}; extern crate proc_macro; use proc_macro::{token_stream::IntoIter, Group, TokenStream, TokenTree}; use std::collections::HashMap; @@ -133,3 +135,123 @@ fn generate_get_description( ); enum_impl.parse().unwrap() } + +use quote::quote; +use syn::{parse_macro_input, Attribute, Block, ItemFn}; +/// Macro to transform the preferred API for scalar functions in extensions into +/// an FFI-compatible function signature while validating argc +#[proc_macro_attribute] +pub fn export_scalar(_attr: TokenStream, item: TokenStream) -> TokenStream { + let mut input_fn = parse_macro_input!(item as ItemFn); + + let fn_name = &input_fn.sig.ident; + let fn_body: &Block = &input_fn.block; + + let mut extracted_spec: Option<ArgsSpec> = None; + let mut arg_err = None; + let kept_attrs: Vec<Attribute> = input_fn + .attrs + .into_iter() + .filter_map(|attr| { + if attr.path().is_ident("args") { + let parsed_attr = match attr.parse_args::<ArgsAttr>() { + Ok(p) => p, + Err(err) => { + arg_err = Some(err.to_compile_error()); + return None; + } + }; + extracted_spec = Some(parsed_attr.spec); + None + } else { + Some(attr) + } + }) + .collect(); + input_fn.attrs = kept_attrs; + if let Some(arg_err) = arg_err { + return arg_err.into(); + } + let spec = match 
extracted_spec { + Some(s) => s, + None => { + return syn::Error::new_spanned( + fn_name, + "Expected an attribute with integer or range: #[args(1)] #[args(0..2)], etc.", + ) + .to_compile_error() + .into() + } + }; + let arg_check = match spec { + ArgsSpec::Exact(exact_count) => { + quote! { + if argc != #exact_count { + log::error!( + "{} was called with {} arguments, expected exactly {}", + stringify!(#fn_name), + argc, + #exact_count + ); + return ::limbo_extension::Value::null(); + } + } + } + ArgsSpec::Range { + lower, + upper, + inclusive: true, + } => { + quote! { + if !(#lower..=#upper).contains(&argc) { + log::error!( + "{} was called with {} arguments, expected {}..={} range", + stringify!(#fn_name), + argc, + #lower, + #upper + ); + return ::limbo_extension::Value::null(); + } + } + } + ArgsSpec::Range { + lower, + upper, + inclusive: false, + } => { + quote! { + if !(#lower..#upper).contains(&argc) { + log::error!( + "{} was called with {} arguments, expected {}..{} (exclusive)", + stringify!(#fn_name), + argc, + #lower, + #upper + ); + return ::limbo_extension::Value::null(); + } + } + } + }; + let expanded = quote! { + #[export_name = stringify!(#fn_name)] + extern "C" fn #fn_name(argc: i32, argv: *const ::limbo_extension::Value) -> ::limbo_extension::Value { + #arg_check + + // from_raw_parts doesn't currently accept null ptr + if argc == 0 || argv.is_null() { + log::debug!("{} was called with no arguments", stringify!(#fn_name)); + let args: &[::limbo_extension::Value] = &[]; + #fn_body + } else { + let ptr_slice = unsafe { + std::slice::from_raw_parts(argv, argc as usize) + }; + let args: &[::limbo_extension::Value] = ptr_slice; + #fn_body + } + } + }; + TokenStream::from(expanded) +}