use proc_macro::TokenStream; use quote::quote; use syn::{parse_macro_input, Data, DeriveInput, Fields, Type}; pub(crate) fn derive_atomic_enum_inner(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); let name = &input.ident; let atomic_name = syn::Ident::new(&format!("Atomic{name}"), name.span()); let variants = match &input.data { Data::Enum(data) => &data.variants, _ => { return syn::Error::new_spanned(input, "AtomicEnum can only be derived for enums") .to_compile_error() .into(); } }; // get info about variants to determine how we have to encode them let mut has_bool_field = false; let mut has_u8_field = false; let mut max_discriminant = 0u8; for (idx, variant) in variants.iter().enumerate() { max_discriminant = idx as u8; match &variant.fields { Fields::Unit => {} Fields::Named(fields) if fields.named.len() == 1 => { let field = &fields.named[0]; if is_bool_type(&field.ty) { has_bool_field = true; } else if is_u8_or_i8_type(&field.ty) { has_u8_field = true; } else { return syn::Error::new_spanned( field, "AtomicEnum only supports bool, u8, or i8 fields", ) .to_compile_error() .into(); } } Fields::Unnamed(fields) if fields.unnamed.len() == 1 => { let field = &fields.unnamed[0]; if is_bool_type(&field.ty) { has_bool_field = true; } else if is_u8_or_i8_type(&field.ty) { has_u8_field = true; } else { return syn::Error::new_spanned( field, "AtomicEnum only supports bool, u8, or i8 fields", ) .to_compile_error() .into(); } } _ => { return syn::Error::new_spanned( variant, "AtomicEnum only supports unit variants or variants with a single field", ) .to_compile_error() .into(); } } } let (storage_type, atomic_type) = if has_u8_field || (has_bool_field && max_discriminant > 127) { // Need u16: 8 bits for discriminant, 8 bits for data (quote! { u16 }, quote! { ::std::sync::atomic::AtomicU16 }) } else { // Can use u8: 7 bits for discriminant, 1 bit for bool (if any) (quote! { u8 }, quote! { ::std::sync::atomic::AtomicU8 }) }; let use_u16 = has_u8_field || (has_bool_field && max_discriminant > 127); let to_storage = variants.iter().enumerate().map(|(idx, variant)| { let var_name = &variant.ident; let disc = idx as u8; // The discriminant here is just the variant's index match &variant.fields { // Simple unit variant, just store the discriminant Fields::Unit => { if use_u16 { quote! { #name::#var_name => #disc as u16 } } else { quote! { #name::#var_name => #disc } } } Fields::Named(fields) => { // Named field variant like `Write { schema_did_change: bool }` let field = &fields.named[0]; let field_name = &field.ident; if is_bool_type(&field.ty) { if use_u16 { // Pack as: [discriminant_byte | bool_as_byte] // Example: Write {true} with disc=3 becomes: b100000011 quote! { #name::#var_name { ref #field_name } => { (#disc as u16) | ((*#field_name as u16) << 8) } } } else { // Same as above but with u8, so only 1 bit for bool // Example: Write{true} with disc=3 becomes: b10000011 quote! { #name::#var_name { ref #field_name } => { #disc | ((*#field_name as u8) << 7) } } } } else { // u8/i8 field always uses u16 to have enough bits // Pack as: [discriminant_byte | value_byte] quote! { #name::#var_name { ref #field_name } => { (#disc as u16) | ((*#field_name as u16) << 8) } } } } Fields::Unnamed(_) => { // same strategy as above, but for tuple variants like `Write(bool)` if is_bool_type(&variant.fields.iter().next().unwrap().ty) { if use_u16 { quote! { #name::#var_name(ref val) => { (#disc as u16) | ((*val as u16) << 8) } } } else { quote! { #name::#var_name(ref val) => { #disc | ((*val as u8) << 7) } } } } else { quote! { #name::#var_name(ref val) => { (#disc as u16) | ((*val as u16) << 8) } } } } } }); // Generate the match arms for decoding the storage representation back to enum let from_storage = variants.iter().enumerate().map(|(idx, variant)| { let var_name = &variant.ident; let disc = idx as u8; match &variant.fields { Fields::Unit => quote! { #disc => #name::#var_name }, Fields::Named(fields) => { let field = &fields.named[0]; let field_name = &field.ident; if is_bool_type(&field.ty) { if use_u16 { // Extract bool from high byte: check if non-zero quote! { #disc => #name::#var_name { #field_name: (val >> 8) != 0 } } } else { // check single bool value at bit 7 quote! { #disc => #name::#var_name { #field_name: (val & 0x80) != 0 } } } } else { quote! { #disc => #name::#var_name { // Extract u8/i8 from high byte and cast to appropriate type #field_name: (val >> 8) as _ } } } } Fields::Unnamed(_) => { if is_bool_type(&variant.fields.iter().next().unwrap().ty) { if use_u16 { quote! { #disc => #name::#var_name((val >> 8) != 0) } } else { quote! { #disc => #name::#var_name((val & 0x80) != 0) } } } else { quote! { #disc => #name::#var_name((val >> 8) as _) } } } } }); let discriminant_mask = if use_u16 { quote! { 0xFF } } else { quote! { 0x7F } }; let to_storage_arms_copy = to_storage.clone(); let expanded = quote! { #[derive(Debug)] /// Atomic wrapper for #name pub struct #atomic_name(#atomic_type); impl #atomic_name { /// Encode enum into storage representation /// Discriminant in lower bits, field data in upper bits #[inline] fn to_storage(val: &#name) -> #storage_type { match val { #(#to_storage_arms_copy),* } } /// Decode storage representation into enum /// Panics on invalid discriminant #[inline] fn from_storage(val: #storage_type) -> #name { let discriminant = (val & #discriminant_mask) as u8; match discriminant { #(#from_storage,)* _ => panic!(concat!("Invalid ", stringify!(#name), " discriminant: {}"), discriminant), } } /// Create new atomic enum with initial value #[inline] pub const fn new(val: #name) -> Self { // Can't call to_storage in const context, so inline it let storage = match val { #(#to_storage),* }; Self(#atomic_type::new(storage)) } #[inline] /// Load and convert the current value to expected enum pub fn get(&self) -> #name { Self::from_storage(self.0.load(::std::sync::atomic::Ordering::SeqCst)) } #[inline] /// Convert and store new value pub fn set(&self, val: #name) { self.0.store(Self::to_storage(&val), ::std::sync::atomic::Ordering::SeqCst) } #[inline] /// Store new value and return previous value pub fn swap(&self, val: #name) -> #name { let prev = self.0.swap(Self::to_storage(&val), ::std::sync::atomic::Ordering::SeqCst); Self::from_storage(prev) } } impl From<#name> for #atomic_name { fn from(val: #name) -> Self { Self::new(val) } } }; TokenStream::from(expanded) } fn is_bool_type(ty: &Type) -> bool { if let Type::Path(path) = ty { path.path.is_ident("bool") } else { false } } fn is_u8_or_i8_type(ty: &Type) -> bool { if let Type::Path(path) = ty { path.path.is_ident("u8") || path.path.is_ident("i8") } else { false } }