mirror of
https://github.com/aljazceru/turso.git
synced 2025-12-29 14:04:22 +01:00
291 lines
10 KiB
Rust
291 lines
10 KiB
Rust
use proc_macro::TokenStream;
|
|
use quote::quote;
|
|
use syn::{parse_macro_input, Data, DeriveInput, Fields, Type};
|
|
|
|
pub(crate) fn derive_atomic_enum_inner(input: TokenStream) -> TokenStream {
|
|
let input = parse_macro_input!(input as DeriveInput);
|
|
let name = &input.ident;
|
|
let atomic_name = syn::Ident::new(&format!("Atomic{name}"), name.span());
|
|
|
|
let variants = match &input.data {
|
|
Data::Enum(data) => &data.variants,
|
|
_ => {
|
|
return syn::Error::new_spanned(input, "AtomicEnum can only be derived for enums")
|
|
.to_compile_error()
|
|
.into();
|
|
}
|
|
};
|
|
|
|
// get info about variants to determine how we have to encode them
|
|
let mut has_bool_field = false;
|
|
let mut has_u8_field = false;
|
|
let mut max_discriminant = 0u8;
|
|
|
|
for (idx, variant) in variants.iter().enumerate() {
|
|
max_discriminant = idx as u8;
|
|
match &variant.fields {
|
|
Fields::Unit => {}
|
|
Fields::Named(fields) if fields.named.len() == 1 => {
|
|
let field = &fields.named[0];
|
|
if is_bool_type(&field.ty) {
|
|
has_bool_field = true;
|
|
} else if is_u8_or_i8_type(&field.ty) {
|
|
has_u8_field = true;
|
|
} else {
|
|
return syn::Error::new_spanned(
|
|
field,
|
|
"AtomicEnum only supports bool, u8, or i8 fields",
|
|
)
|
|
.to_compile_error()
|
|
.into();
|
|
}
|
|
}
|
|
Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
|
|
let field = &fields.unnamed[0];
|
|
if is_bool_type(&field.ty) {
|
|
has_bool_field = true;
|
|
} else if is_u8_or_i8_type(&field.ty) {
|
|
has_u8_field = true;
|
|
} else {
|
|
return syn::Error::new_spanned(
|
|
field,
|
|
"AtomicEnum only supports bool, u8, or i8 fields",
|
|
)
|
|
.to_compile_error()
|
|
.into();
|
|
}
|
|
}
|
|
_ => {
|
|
return syn::Error::new_spanned(
|
|
variant,
|
|
"AtomicEnum only supports unit variants or variants with a single field",
|
|
)
|
|
.to_compile_error()
|
|
.into();
|
|
}
|
|
}
|
|
}
|
|
|
|
let (storage_type, atomic_type) = if has_u8_field || (has_bool_field && max_discriminant > 127)
|
|
{
|
|
// Need u16: 8 bits for discriminant, 8 bits for data
|
|
(quote! { u16 }, quote! { ::std::sync::atomic::AtomicU16 })
|
|
} else {
|
|
// Can use u8: 7 bits for discriminant, 1 bit for bool (if any)
|
|
(quote! { u8 }, quote! { ::std::sync::atomic::AtomicU8 })
|
|
};
|
|
|
|
let use_u16 = has_u8_field || (has_bool_field && max_discriminant > 127);
|
|
|
|
let to_storage = variants.iter().enumerate().map(|(idx, variant)| {
|
|
let var_name = &variant.ident;
|
|
let disc = idx as u8; // The discriminant here is just the variant's index
|
|
|
|
match &variant.fields {
|
|
// Simple unit variant, just store the discriminant
|
|
Fields::Unit => {
|
|
if use_u16 {
|
|
quote! { #name::#var_name => #disc as u16 }
|
|
} else {
|
|
quote! { #name::#var_name => #disc }
|
|
}
|
|
}
|
|
Fields::Named(fields) => {
|
|
// Named field variant like `Write { schema_did_change: bool }`
|
|
let field = &fields.named[0];
|
|
let field_name = &field.ident;
|
|
|
|
if is_bool_type(&field.ty) {
|
|
if use_u16 {
|
|
// Pack as: [discriminant_byte | bool_as_byte]
|
|
// Example: Write {true} with disc=3 becomes: b100000011
|
|
quote! {
|
|
#name::#var_name { ref #field_name } => {
|
|
(#disc as u16) | ((*#field_name as u16) << 8)
|
|
}
|
|
}
|
|
} else {
|
|
// Same as above but with u8, so only 1 bit for bool
|
|
// Example: Write{true} with disc=3 becomes: b10000011
|
|
quote! {
|
|
#name::#var_name { ref #field_name } => {
|
|
#disc | ((*#field_name as u8) << 7)
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
// u8/i8 field always uses u16 to have enough bits
|
|
// Pack as: [discriminant_byte | value_byte]
|
|
quote! {
|
|
#name::#var_name { ref #field_name } => {
|
|
(#disc as u16) | ((*#field_name as u16) << 8)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
Fields::Unnamed(_) => {
|
|
// same strategy as above, but for tuple variants like `Write(bool)`
|
|
if is_bool_type(&variant.fields.iter().next().unwrap().ty) {
|
|
if use_u16 {
|
|
quote! {
|
|
#name::#var_name(ref val) => {
|
|
(#disc as u16) | ((*val as u16) << 8)
|
|
}
|
|
}
|
|
} else {
|
|
quote! {
|
|
#name::#var_name(ref val) => {
|
|
#disc | ((*val as u8) << 7)
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
quote! {
|
|
#name::#var_name(ref val) => {
|
|
(#disc as u16) | ((*val as u16) << 8)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
});
|
|
|
|
// Generate the match arms for decoding the storage representation back to enum
|
|
let from_storage = variants.iter().enumerate().map(|(idx, variant)| {
|
|
let var_name = &variant.ident;
|
|
let disc = idx as u8;
|
|
|
|
match &variant.fields {
|
|
Fields::Unit => quote! { #disc => #name::#var_name },
|
|
Fields::Named(fields) => {
|
|
let field = &fields.named[0];
|
|
let field_name = &field.ident;
|
|
|
|
if is_bool_type(&field.ty) {
|
|
if use_u16 {
|
|
// Extract bool from high byte: check if non-zero
|
|
quote! {
|
|
#disc => #name::#var_name {
|
|
#field_name: (val >> 8) != 0
|
|
}
|
|
}
|
|
} else {
|
|
// check single bool value at bit 7
|
|
quote! {
|
|
#disc => #name::#var_name {
|
|
#field_name: (val & 0x80) != 0
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
quote! {
|
|
#disc => #name::#var_name {
|
|
// Extract u8/i8 from high byte and cast to appropriate type
|
|
#field_name: (val >> 8) as _
|
|
}
|
|
}
|
|
}
|
|
}
|
|
Fields::Unnamed(_) => {
|
|
if is_bool_type(&variant.fields.iter().next().unwrap().ty) {
|
|
if use_u16 {
|
|
quote! { #disc => #name::#var_name((val >> 8) != 0) }
|
|
} else {
|
|
quote! { #disc => #name::#var_name((val & 0x80) != 0) }
|
|
}
|
|
} else {
|
|
quote! { #disc => #name::#var_name((val >> 8) as _) }
|
|
}
|
|
}
|
|
}
|
|
});
|
|
|
|
let discriminant_mask = if use_u16 {
|
|
quote! { 0xFF }
|
|
} else {
|
|
quote! { 0x7F }
|
|
};
|
|
let to_storage_arms_copy = to_storage.clone();
|
|
|
|
let expanded = quote! {
|
|
#[derive(Debug)]
|
|
/// Atomic wrapper for #name
|
|
pub struct #atomic_name(#atomic_type);
|
|
|
|
impl #atomic_name {
|
|
/// Encode enum into storage representation
|
|
/// Discriminant in lower bits, field data in upper bits
|
|
#[inline]
|
|
fn to_storage(val: &#name) -> #storage_type {
|
|
match val {
|
|
#(#to_storage_arms_copy),*
|
|
}
|
|
}
|
|
|
|
/// Decode storage representation into enum
|
|
/// Panics on invalid discriminant
|
|
#[inline]
|
|
fn from_storage(val: #storage_type) -> #name {
|
|
let discriminant = (val & #discriminant_mask) as u8;
|
|
match discriminant {
|
|
#(#from_storage,)*
|
|
_ => panic!(concat!("Invalid ", stringify!(#name), " discriminant: {}"), discriminant),
|
|
}
|
|
}
|
|
|
|
/// Create new atomic enum with initial value
|
|
#[inline]
|
|
pub const fn new(val: #name) -> Self {
|
|
// Can't call to_storage in const context, so inline it
|
|
let storage = match val {
|
|
#(#to_storage),*
|
|
};
|
|
Self(#atomic_type::new(storage))
|
|
}
|
|
|
|
#[inline]
|
|
/// Load and convert the current value to expected enum
|
|
pub fn get(&self) -> #name {
|
|
Self::from_storage(self.0.load(::std::sync::atomic::Ordering::SeqCst))
|
|
}
|
|
|
|
#[inline]
|
|
/// Convert and store new value
|
|
pub fn set(&self, val: #name) {
|
|
self.0.store(Self::to_storage(&val), ::std::sync::atomic::Ordering::SeqCst)
|
|
}
|
|
|
|
#[inline]
|
|
/// Store new value and return previous value
|
|
pub fn swap(&self, val: #name) -> #name {
|
|
let prev = self.0.swap(Self::to_storage(&val), ::std::sync::atomic::Ordering::SeqCst);
|
|
Self::from_storage(prev)
|
|
}
|
|
}
|
|
|
|
impl From<#name> for #atomic_name {
|
|
fn from(val: #name) -> Self {
|
|
Self::new(val)
|
|
}
|
|
}
|
|
};
|
|
|
|
TokenStream::from(expanded)
|
|
}
|
|
|
|
fn is_bool_type(ty: &Type) -> bool {
|
|
if let Type::Path(path) = ty {
|
|
path.path.is_ident("bool")
|
|
} else {
|
|
false
|
|
}
|
|
}
|
|
|
|
fn is_u8_or_i8_type(ty: &Type) -> bool {
|
|
if let Type::Path(path) = ty {
|
|
path.path.is_ident("u8") || path.path.is_ident("i8")
|
|
} else {
|
|
false
|
|
}
|
|
}
|