diff --git a/macros/src/atomic_enum.rs b/macros/src/atomic_enum.rs new file mode 100644 index 000000000..d1248f9d8 --- /dev/null +++ b/macros/src/atomic_enum.rs @@ -0,0 +1,290 @@ +use proc_macro::TokenStream; +use quote::quote; +use syn::{parse_macro_input, Data, DeriveInput, Fields, Type}; + +pub(crate) fn derive_atomic_enum_inner(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + let name = &input.ident; + let atomic_name = syn::Ident::new(&format!("Atomic{name}"), name.span()); + + let variants = match &input.data { + Data::Enum(data) => &data.variants, + _ => { + return syn::Error::new_spanned(input, "AtomicEnum can only be derived for enums") + .to_compile_error() + .into(); + } + }; + + // get info about variants to determine how we have to encode them + let mut has_bool_field = false; + let mut has_u8_field = false; + let mut max_discriminant = 0u8; + + for (idx, variant) in variants.iter().enumerate() { + max_discriminant = idx as u8; + match &variant.fields { + Fields::Unit => {} + Fields::Named(fields) if fields.named.len() == 1 => { + let field = &fields.named[0]; + if is_bool_type(&field.ty) { + has_bool_field = true; + } else if is_u8_or_i8_type(&field.ty) { + has_u8_field = true; + } else { + return syn::Error::new_spanned( + field, + "AtomicEnum only supports bool, u8, or i8 fields", + ) + .to_compile_error() + .into(); + } + } + Fields::Unnamed(fields) if fields.unnamed.len() == 1 => { + let field = &fields.unnamed[0]; + if is_bool_type(&field.ty) { + has_bool_field = true; + } else if is_u8_or_i8_type(&field.ty) { + has_u8_field = true; + } else { + return syn::Error::new_spanned( + field, + "AtomicEnum only supports bool, u8, or i8 fields", + ) + .to_compile_error() + .into(); + } + } + _ => { + return syn::Error::new_spanned( + variant, + "AtomicEnum only supports unit variants or variants with a single field", + ) + .to_compile_error() + .into(); + } + } + } + + let (storage_type, atomic_type) = if has_u8_field || (has_bool_field && max_discriminant > 127) + { + // Need u16: 8 bits for discriminant, 8 bits for data + (quote! { u16 }, quote! { ::std::sync::atomic::AtomicU16 }) + } else { + // Can use u8: 7 bits for discriminant, 1 bit for bool (if any) + (quote! { u8 }, quote! { ::std::sync::atomic::AtomicU8 }) + }; + + let use_u16 = has_u8_field || (has_bool_field && max_discriminant > 127); + + let to_storage = variants.iter().enumerate().map(|(idx, variant)| { + let var_name = &variant.ident; + let disc = idx as u8; // The discriminant here is just the variant's index + + match &variant.fields { + // Simple unit variant, just store the discriminant + Fields::Unit => { + if use_u16 { + quote! { #name::#var_name => #disc as u16 } + } else { + quote! { #name::#var_name => #disc } + } + } + Fields::Named(fields) => { + // Named field variant like `Write { schema_did_change: bool }` + let field = &fields.named[0]; + let field_name = &field.ident; + + if is_bool_type(&field.ty) { + if use_u16 { + // Pack as: [discriminant_byte | bool_as_byte] + // Example: Write {true} with disc=3 becomes: b100000011 + quote! { + #name::#var_name { ref #field_name } => { + (#disc as u16) | ((*#field_name as u16) << 8) + } + } + } else { + // Same as above but with u8, so only 1 bit for bool + // Example: Write{true} with disc=3 becomes: b10000011 + quote! { + #name::#var_name { ref #field_name } => { + #disc | ((*#field_name as u8) << 7) + } + } + } + } else { + // u8/i8 field always uses u16 to have enough bits + // Pack as: [discriminant_byte | value_byte] + quote! { + #name::#var_name { ref #field_name } => { + (#disc as u16) | ((*#field_name as u16) << 8) + } + } + } + } + Fields::Unnamed(_) => { + // same strategy as above, but for tuple variants like `Write(bool)` + if is_bool_type(&variant.fields.iter().next().unwrap().ty) { + if use_u16 { + quote! { + #name::#var_name(ref val) => { + (#disc as u16) | ((*val as u16) << 8) + } + } + } else { + quote! { + #name::#var_name(ref val) => { + #disc | ((*val as u8) << 7) + } + } + } + } else { + quote! { + #name::#var_name(ref val) => { + (#disc as u16) | ((*val as u16) << 8) + } + } + } + } + } + }); + + // Generate the match arms for decoding the storage representation back to enum + let from_storage = variants.iter().enumerate().map(|(idx, variant)| { + let var_name = &variant.ident; + let disc = idx as u8; + + match &variant.fields { + Fields::Unit => quote! { #disc => #name::#var_name }, + Fields::Named(fields) => { + let field = &fields.named[0]; + let field_name = &field.ident; + + if is_bool_type(&field.ty) { + if use_u16 { + // Extract bool from high byte: check if non-zero + quote! { + #disc => #name::#var_name { + #field_name: (val >> 8) != 0 + } + } + } else { + // check single bool value at bit 7 + quote! { + #disc => #name::#var_name { + #field_name: (val & 0x80) != 0 + } + } + } + } else { + quote! { + #disc => #name::#var_name { + // Extract u8/i8 from high byte and cast to appropriate type + #field_name: (val >> 8) as _ + } + } + } + } + Fields::Unnamed(_) => { + if is_bool_type(&variant.fields.iter().next().unwrap().ty) { + if use_u16 { + quote! { #disc => #name::#var_name((val >> 8) != 0) } + } else { + quote! { #disc => #name::#var_name((val & 0x80) != 0) } + } + } else { + quote! { #disc => #name::#var_name((val >> 8) as _) } + } + } + } + }); + + let discriminant_mask = if use_u16 { + quote! { 0xFF } + } else { + quote! { 0x7F } + }; + let to_storage_arms_copy = to_storage.clone(); + + let expanded = quote! { + #[derive(Debug)] + /// Atomic wrapper for #name + pub struct #atomic_name(#atomic_type); + + impl #atomic_name { + /// Encode enum into storage representation + /// Discriminant in lower bits, field data in upper bits + #[inline] + fn to_storage(val: &#name) -> #storage_type { + match val { + #(#to_storage_arms_copy),* + } + } + + /// Decode storage representation into enum + /// Panics on invalid discriminant + #[inline] + fn from_storage(val: #storage_type) -> #name { + let discriminant = (val & #discriminant_mask) as u8; + match discriminant { + #(#from_storage,)* + _ => panic!(concat!("Invalid ", stringify!(#name), " discriminant: {}"), discriminant), + } + } + + /// Create new atomic enum with initial value + #[inline] + pub const fn new(val: #name) -> Self { + // Can't call to_storage in const context, so inline it + let storage = match val { + #(#to_storage),* + }; + Self(#atomic_type::new(storage)) + } + + #[inline] + /// Load and convert the current value to expected enum + pub fn get(&self) -> #name { + Self::from_storage(self.0.load(::std::sync::atomic::Ordering::Acquire)) + } + + #[inline] + /// Convert and store new value + pub fn set(&self, val: #name) { + self.0.store(Self::to_storage(&val), ::std::sync::atomic::Ordering::Release) + } + + #[inline] + /// Store new value and return previous value + pub fn swap(&self, val: #name) -> #name { + let prev = self.0.swap(Self::to_storage(&val), ::std::sync::atomic::Ordering::AcqRel); + Self::from_storage(prev) + } + } + + impl From<#name> for #atomic_name { + fn from(val: #name) -> Self { + Self::new(val) + } + } + }; + + TokenStream::from(expanded) +} + +fn is_bool_type(ty: &Type) -> bool { + if let Type::Path(path) = ty { + path.path.is_ident("bool") + } else { + false + } +} + +fn is_u8_or_i8_type(ty: &Type) -> bool { + if let Type::Path(path) = ty { + path.path.is_ident("u8") || path.path.is_ident("i8") + } else { + false + } +} diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 8df01da22..9c95cbd1a 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -1,5 +1,6 @@ mod ext; extern crate proc_macro; +mod atomic_enum; use proc_macro::{token_stream::IntoIter, Group, TokenStream, TokenTree}; use std::collections::HashMap; @@ -464,3 +465,30 @@ pub fn derive_vfs_module(input: TokenStream) -> TokenStream { pub fn match_ignore_ascii_case(input: TokenStream) -> TokenStream { ext::match_ignore_ascci_case(input) } + +/// Derive macro for creating atomic wrappers for enums +/// +/// Supports: +/// - Unit variants +/// - Variants with single bool/u8/i8 fields +/// - Named or unnamed fields +/// +/// Algorithm: +/// - Uses u8 representation, splitting bits for variant discriminant and field data +/// - For bool fields: high bit for bool, lower 7 bits for discriminant +/// - For u8/i8 fields: uses u16 internally (8 bits discriminant, 8 bits data) +/// +/// Example: +/// ```rust +/// #[derive(AtomicEnum)] +/// enum TransactionState { +/// Write { schema_did_change: bool }, +/// Read, +/// PendingUpgrade, +/// None, +/// } +/// ``` +#[proc_macro_derive(AtomicEnum)] +pub fn derive_atomic_enum(input: TokenStream) -> TokenStream { + atomic_enum::derive_atomic_enum_inner(input) +}