diff --git a/core/lib.rs b/core/lib.rs index 637e91c92..c3149b64d 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -40,6 +40,7 @@ pub mod numeric; mod numeric; use crate::storage::checksum::CHECKSUM_REQUIRED_RESERVED_BYTES; +use crate::storage::encryption::AtomicCipherMode; use crate::translate::display::PlanContext; use crate::translate::pragma::TURSO_CDC_DEFAULT_TABLE_NAME; #[cfg(all(feature = "fs", feature = "conn_raw_api"))] @@ -93,7 +94,7 @@ pub use storage::{ wal::{CheckpointMode, CheckpointResult, Wal, WalFile, WalFileShared}, }; use tracing::{instrument, Level}; -use turso_macros::match_ignore_ascii_case; +use turso_macros::{match_ignore_ascii_case, AtomicEnum}; use turso_parser::ast::fmt::ToTokens; use turso_parser::{ast, ast::Cmd, parser::Parser}; use types::IOResult; @@ -178,7 +179,7 @@ impl EncryptionOpts { pub type Result = std::result::Result; -#[derive(Clone, Copy, PartialEq, Eq, Debug)] +#[derive(Clone, AtomicEnum, Copy, PartialEq, Eq, Debug)] enum TransactionState { Write { schema_did_change: bool }, Read, @@ -186,7 +187,7 @@ enum TransactionState { None, } -#[derive(Clone, Copy, PartialEq, Eq, Debug)] +#[derive(Debug, AtomicEnum, Clone, Copy, PartialEq, Eq)] pub enum SyncMode { Off = 0, Full = 2, @@ -562,7 +563,7 @@ impl Database { schema: RwLock::new(self.schema.lock().unwrap().clone()), database_schemas: RwLock::new(std::collections::HashMap::new()), auto_commit: AtomicBool::new(true), - transaction_state: RwLock::new(TransactionState::None), + transaction_state: AtomicTransactionState::new(TransactionState::None), last_insert_rowid: AtomicI64::new(0), last_change: AtomicI64::new(0), total_changes: AtomicI64::new(0), @@ -580,8 +581,8 @@ impl Database { metrics: RwLock::new(ConnectionMetrics::new()), is_nested_stmt: AtomicBool::new(false), encryption_key: RwLock::new(None), - encryption_cipher_mode: RwLock::new(None), - sync_mode: RwLock::new(SyncMode::Full), + encryption_cipher_mode: AtomicCipherMode::new(CipherMode::None), + sync_mode: AtomicSyncMode::new(SyncMode::Full), data_sync_retry: AtomicBool::new(false), busy_timeout: RwLock::new(Duration::new(0, 0)), is_mvcc_bootstrap_connection: AtomicBool::new(is_mvcc_bootstrap_connection), @@ -604,7 +605,7 @@ impl Database { /// we need to read the page_size from the database header. fn read_page_size_from_db_header(&self) -> Result { turso_assert!( - self.db_state.is_initialized(), + self.db_state.get().is_initialized(), "read_page_size_from_db_header called on uninitialized database" ); turso_assert!( @@ -622,7 +623,7 @@ impl Database { fn read_reserved_space_bytes_from_db_header(&self) -> Result { turso_assert!( - self.db_state.is_initialized(), + self.db_state.get().is_initialized(), "read_reserved_space_bytes_from_db_header called on uninitialized database" ); turso_assert!( @@ -658,7 +659,7 @@ impl Database { return Ok(page_size); } } - if self.db_state.is_initialized() { + if self.db_state.get().is_initialized() { Ok(self.read_page_size_from_db_header()?) } else { let Some(size) = requested_page_size else { @@ -674,7 +675,7 @@ impl Database { /// if the database is initialized i.e. it exists on disk, return the reserved space bytes from /// the header or None fn maybe_get_reserved_space_bytes(&self) -> Result> { - if self.db_state.is_initialized() { + if self.db_state.get().is_initialized() { Ok(Some(self.read_reserved_space_bytes_from_db_header()?)) } else { Ok(None) @@ -696,7 +697,7 @@ impl Database { drop(shared_wal); let buffer_pool = self.buffer_pool.clone(); - if self.db_state.is_initialized() { + if self.db_state.get().is_initialized() { buffer_pool.finalize_with_page_size(page_size.get() as usize)?; } @@ -729,7 +730,7 @@ impl Database { let buffer_pool = self.buffer_pool.clone(); - if self.db_state.is_initialized() { + if self.db_state.get().is_initialized() { buffer_pool.finalize_with_page_size(page_size.get() as usize)?; } @@ -1067,7 +1068,7 @@ pub struct Connection { database_schemas: RwLock>>, /// Whether to automatically commit transaction auto_commit: AtomicBool, - transaction_state: RwLock, + transaction_state: AtomicTransactionState, last_insert_rowid: AtomicI64, last_change: AtomicI64, total_changes: AtomicI64, @@ -1096,8 +1097,8 @@ pub struct Connection { /// Generally this is only true for ParseSchema. is_nested_stmt: AtomicBool, encryption_key: RwLock>, - encryption_cipher_mode: RwLock>, - sync_mode: RwLock, + encryption_cipher_mode: AtomicCipherMode, + sync_mode: AtomicSyncMode, data_sync_retry: AtomicBool, /// User defined max accumulated Busy timeout duration /// Default is 0 (no timeout) @@ -1238,8 +1239,7 @@ impl Connection { let reparse_result = self.reparse_schema(); - let previous = - std::mem::replace(&mut *self.transaction_state.write(), TransactionState::None); + let previous = self.transaction_state.swap(TransactionState::None); turso_assert!( matches!(previous, TransactionState::None | TransactionState::Read), "unexpected end transaction state" @@ -1519,7 +1519,7 @@ impl Connection { let _ = conn.pragma_update("cipher", encryption_opts.cipher.to_string()); let _ = conn.pragma_update("hexkey", encryption_opts.hexkey.to_string()); let pager = conn.pager.read(); - if db.db_state.is_initialized() { + if db.db_state.get().is_initialized() { // Clear page cache so the header page can be reread from disk and decrypted using the encryption context. pager.clear_page_cache(false); } @@ -1597,9 +1597,9 @@ impl Connection { header.schema_cookie.get() < version, "cookie can't go back in time" ); - *self.transaction_state.write() = TransactionState::Write { + self.set_tx_state(TransactionState::Write { schema_did_change: true, - }; + }); self.with_schema_mut(|schema| schema.schema_version = version); header.schema_cookie = version.into(); }) @@ -1682,9 +1682,9 @@ impl Connection { })?; // start write transaction and disable auto-commit mode as SQL can be executed within WAL session (at caller own risk) - *self.transaction_state.write() = TransactionState::Write { + self.set_tx_state(TransactionState::Write { schema_did_change: false, - }; + }); self.auto_commit.store(false, Ordering::SeqCst); Ok(()) @@ -2029,7 +2029,7 @@ impl Connection { } pub fn is_db_initialized(&self) -> bool { - self.db.db_state.is_initialized() + self.db.db_state.get().is_initialized() } fn get_pager_from_database_index(&self, index: &usize) -> Arc { @@ -2259,11 +2259,11 @@ impl Connection { } pub fn get_sync_mode(&self) -> SyncMode { - *self.sync_mode.read() + self.sync_mode.get() } pub fn set_sync_mode(&self, mode: SyncMode) { - *self.sync_mode.write() = mode; + self.sync_mode.set(mode); } pub fn get_data_sync_retry(&self) -> bool { @@ -2289,7 +2289,7 @@ impl Connection { pub fn set_encryption_cipher(&self, cipher_mode: CipherMode) -> Result<()> { tracing::trace!("setting encryption cipher for connection"); - *self.encryption_cipher_mode.write() = Some(cipher_mode); + self.encryption_cipher_mode.set(cipher_mode); self.set_encryption_context() } @@ -2300,7 +2300,10 @@ impl Connection { } pub fn get_encryption_cipher_mode(&self) -> Option { - *self.encryption_cipher_mode.read() + match self.encryption_cipher_mode.get() { + CipherMode::None => None, + mode => Some(mode), + } } // if both key and cipher are set, set encryption context on pager @@ -2309,8 +2312,8 @@ impl Connection { let Some(key) = key_guard.as_ref() else { return Ok(()); }; - let cipher_guard = self.encryption_cipher_mode.read(); - let Some(cipher_mode) = *cipher_guard else { + let cipher_mode = self.get_encryption_cipher_mode(); + let Some(cipher_mode) = cipher_mode else { return Ok(()); }; tracing::trace!("setting encryption ctx for connection"); @@ -2348,11 +2351,11 @@ impl Connection { } fn set_tx_state(&self, state: TransactionState) { - *self.transaction_state.write() = state; + self.transaction_state.set(state); } fn get_tx_state(&self) -> TransactionState { - *self.transaction_state.read() + self.transaction_state.get() } pub(crate) fn get_mv_tx_id(&self) -> Option { diff --git a/core/mvcc/database/checkpoint_state_machine.rs b/core/mvcc/database/checkpoint_state_machine.rs index 8c3e109e3..712cc8048 100644 --- a/core/mvcc/database/checkpoint_state_machine.rs +++ b/core/mvcc/database/checkpoint_state_machine.rs @@ -325,9 +325,9 @@ impl CheckpointStateMachine { } result?; if self.update_transaction_state { - *self.connection.transaction_state.write() = TransactionState::Write { + self.connection.set_tx_state(TransactionState::Write { schema_did_change: false, - }; // TODO: schema_did_change?? + }); // TODO: schema_did_change?? } self.lock_states.pager_write_tx = true; self.state = CheckpointState::WriteRow { @@ -534,7 +534,7 @@ impl CheckpointStateMachine { self.lock_states.pager_read_tx = false; self.lock_states.pager_write_tx = false; if self.update_transaction_state { - *self.connection.transaction_state.write() = TransactionState::None; + self.connection.set_tx_state(TransactionState::None); } let header = self .pager @@ -623,12 +623,12 @@ impl StateTransition for CheckpointStateMachine { if self.lock_states.pager_write_tx { self.pager.rollback_tx(self.connection.as_ref()); if self.update_transaction_state { - *self.connection.transaction_state.write() = TransactionState::None; + self.connection.set_tx_state(TransactionState::None); } } else if self.lock_states.pager_read_tx { self.pager.end_read_tx(); if self.update_transaction_state { - *self.connection.transaction_state.write() = TransactionState::None; + self.connection.set_tx_state(TransactionState::None); } } if self.lock_states.blocking_checkpoint_lock_held { diff --git a/core/storage/btree.rs b/core/storage/btree.rs index 440c0de79..0c5279608 100644 --- a/core/storage/btree.rs +++ b/core/storage/btree.rs @@ -662,7 +662,7 @@ impl BTreeNodeState { impl BTreeCursor { pub fn new(pager: Arc, root_page: i64, num_columns: usize) -> Self { - let valid_state = if root_page == 1 && !pager.db_state.is_initialized() { + let valid_state = if root_page == 1 && !pager.db_state.get().is_initialized() { CursorValidState::Invalid } else { CursorValidState::Valid diff --git a/core/storage/encryption.rs b/core/storage/encryption.rs index 156d17c0b..f0184dbc0 100644 --- a/core/storage/encryption.rs +++ b/core/storage/encryption.rs @@ -10,7 +10,7 @@ use aes_gcm::{ aead::{Aead, AeadCore, KeyInit, OsRng}, Aes128Gcm, Aes256Gcm, Key, Nonce, }; -use turso_macros::match_ignore_ascii_case; +use turso_macros::{match_ignore_ascii_case, AtomicEnum}; /// Encryption Scheme /// We support two major algorithms: AEGIS, AES GCM. These algorithms picked so that they also do @@ -319,8 +319,9 @@ define_aegis_cipher!( "AEGIS-128X4" ); -#[derive(Debug, Clone, Copy, PartialEq)] +#[derive(Debug, AtomicEnum, Clone, Copy, PartialEq)] pub enum CipherMode { + None, Aes128Gcm, Aes256Gcm, Aegis256, @@ -363,6 +364,7 @@ impl std::fmt::Display for CipherMode { CipherMode::Aegis128X4 => write!(f, "aegis128x4"), CipherMode::Aegis256X2 => write!(f, "aegis256x2"), CipherMode::Aegis256X4 => write!(f, "aegis256x4"), + CipherMode::None => write!(f, "None"), } } } @@ -380,6 +382,7 @@ impl CipherMode { CipherMode::Aegis128L => 16, CipherMode::Aegis128X2 => 16, CipherMode::Aegis128X4 => 16, + CipherMode::None => 0, } } @@ -394,6 +397,7 @@ impl CipherMode { CipherMode::Aegis128L => 16, CipherMode::Aegis128X2 => 16, CipherMode::Aegis128X4 => 16, + CipherMode::None => 0, } } @@ -408,6 +412,7 @@ impl CipherMode { CipherMode::Aegis128L => 16, CipherMode::Aegis128X2 => 16, CipherMode::Aegis128X4 => 16, + CipherMode::None => 0, } } @@ -427,6 +432,7 @@ impl CipherMode { CipherMode::Aegis128L => 6, CipherMode::Aegis128X2 => 7, CipherMode::Aegis128X4 => 8, + CipherMode::None => 0, } } @@ -503,6 +509,11 @@ impl EncryptionContext { CipherMode::Aegis128L => Cipher::Aegis128L(Box::new(Aegis128LCipher::new(key))), CipherMode::Aegis128X2 => Cipher::Aegis128X2(Box::new(Aegis128X2Cipher::new(key))), CipherMode::Aegis128X4 => Cipher::Aegis128X4(Box::new(Aegis128X4Cipher::new(key))), + CipherMode::None => { + return Err(LimboError::InvalidArgument( + "must select valid CipherMode".into(), + )) + } }; Ok(Self { cipher_mode, diff --git a/core/storage/pager.rs b/core/storage/pager.rs index 2e6b50920..ddb791a0f 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -25,6 +25,7 @@ use std::sync::atomic::{ }; use std::sync::{Arc, Mutex}; use tracing::{instrument, trace, Level}; +use turso_macros::AtomicEnum; use super::btree::btree_init_page; use super::page_cache::{CacheError, CacheResizeResult, PageCache, PageCacheKey}; @@ -57,7 +58,7 @@ impl HeaderRef { tracing::trace!("HeaderRef::from_pager - {:?}", state); match state { HeaderRefState::Start => { - if !pager.db_state.is_initialized() { + if !pager.db_state.get().is_initialized() { return Err(LimboError::Page1NotAlloc); } @@ -97,7 +98,7 @@ impl HeaderRefMut { tracing::trace!(?state); match state { HeaderRefState::Start => { - if !pager.db_state.is_initialized() { + if !pager.db_state.get().is_initialized() { return Err(LimboError::Page1NotAlloc); } @@ -416,57 +417,19 @@ impl From for AutoVacuumMode { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[repr(usize)] +#[derive(Debug, AtomicEnum, Clone, Copy, PartialEq, Eq)] pub enum DbState { - Uninitialized = Self::UNINITIALIZED, - Initializing = Self::INITIALIZING, - Initialized = Self::INITIALIZED, + Uninitialized, + Initializing, + Initialized, } impl DbState { - pub(self) const UNINITIALIZED: usize = 0; - pub(self) const INITIALIZING: usize = 1; - pub(self) const INITIALIZED: usize = 2; - - #[inline] pub fn is_initialized(&self) -> bool { matches!(self, DbState::Initialized) } } -#[derive(Debug)] -#[repr(transparent)] -pub struct AtomicDbState(AtomicUsize); - -impl AtomicDbState { - #[inline] - pub const fn new(state: DbState) -> Self { - Self(AtomicUsize::new(state as usize)) - } - - #[inline] - pub fn set(&self, state: DbState) { - self.0.store(state as usize, Ordering::Release); - } - - #[inline] - pub fn get(&self) -> DbState { - let v = self.0.load(Ordering::Acquire); - match v { - DbState::UNINITIALIZED => DbState::Uninitialized, - DbState::INITIALIZING => DbState::Initializing, - DbState::INITIALIZED => DbState::Initialized, - _ => unreachable!(), - } - } - - #[inline] - pub fn is_initialized(&self) -> bool { - self.get().is_initialized() - } -} - #[derive(Debug, Clone)] #[cfg(not(feature = "omit_autovacuum"))] enum PtrMapGetState { @@ -621,7 +584,7 @@ impl Pager { db_state: Arc, init_lock: Arc>, ) -> Result { - let allocate_page1_state = if !db_state.is_initialized() { + let allocate_page1_state = if !db_state.get().is_initialized() { RwLock::new(AllocatePage1State::Start) } else { RwLock::new(AllocatePage1State::Done) @@ -1131,7 +1094,7 @@ impl Pager { #[instrument(skip_all, level = Level::DEBUG)] pub fn maybe_allocate_page1(&self) -> Result> { - if !self.db_state.is_initialized() { + if !self.db_state.get().is_initialized() { if let Ok(_lock) = self.init_lock.try_lock() { match (self.db_state.get(), self.allocating_page1()) { // In case of being empty or (allocating and this connection is performing allocation) then allocate the first page diff --git a/core/util.rs b/core/util.rs index 26696eabe..77062fd7d 100644 --- a/core/util.rs +++ b/core/util.rs @@ -11,6 +11,7 @@ use crate::{ LimboError, OpenFlags, Result, Statement, StepResult, SymbolTable, }; use crate::{Connection, MvStore, IO}; +use std::sync::atomic::AtomicU8; use std::{ collections::HashMap, rc::Rc, diff --git a/macros/src/atomic_enum.rs b/macros/src/atomic_enum.rs new file mode 100644 index 000000000..14dd97ac0 --- /dev/null +++ b/macros/src/atomic_enum.rs @@ -0,0 +1,290 @@ +use proc_macro::TokenStream; +use quote::quote; +use syn::{parse_macro_input, Data, DeriveInput, Fields, Type}; + +pub(crate) fn derive_atomic_enum_inner(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + let name = &input.ident; + let atomic_name = syn::Ident::new(&format!("Atomic{name}"), name.span()); + + let variants = match &input.data { + Data::Enum(data) => &data.variants, + _ => { + return syn::Error::new_spanned(input, "AtomicEnum can only be derived for enums") + .to_compile_error() + .into(); + } + }; + + // get info about variants to determine how we have to encode them + let mut has_bool_field = false; + let mut has_u8_field = false; + let mut max_discriminant = 0u8; + + for (idx, variant) in variants.iter().enumerate() { + max_discriminant = idx as u8; + match &variant.fields { + Fields::Unit => {} + Fields::Named(fields) if fields.named.len() == 1 => { + let field = &fields.named[0]; + if is_bool_type(&field.ty) { + has_bool_field = true; + } else if is_u8_or_i8_type(&field.ty) { + has_u8_field = true; + } else { + return syn::Error::new_spanned( + field, + "AtomicEnum only supports bool, u8, or i8 fields", + ) + .to_compile_error() + .into(); + } + } + Fields::Unnamed(fields) if fields.unnamed.len() == 1 => { + let field = &fields.unnamed[0]; + if is_bool_type(&field.ty) { + has_bool_field = true; + } else if is_u8_or_i8_type(&field.ty) { + has_u8_field = true; + } else { + return syn::Error::new_spanned( + field, + "AtomicEnum only supports bool, u8, or i8 fields", + ) + .to_compile_error() + .into(); + } + } + _ => { + return syn::Error::new_spanned( + variant, + "AtomicEnum only supports unit variants or variants with a single field", + ) + .to_compile_error() + .into(); + } + } + } + + let (storage_type, atomic_type) = if has_u8_field || (has_bool_field && max_discriminant > 127) + { + // Need u16: 8 bits for discriminant, 8 bits for data + (quote! { u16 }, quote! { ::std::sync::atomic::AtomicU16 }) + } else { + // Can use u8: 7 bits for discriminant, 1 bit for bool (if any) + (quote! { u8 }, quote! { ::std::sync::atomic::AtomicU8 }) + }; + + let use_u16 = has_u8_field || (has_bool_field && max_discriminant > 127); + + let to_storage = variants.iter().enumerate().map(|(idx, variant)| { + let var_name = &variant.ident; + let disc = idx as u8; // The discriminant here is just the variant's index + + match &variant.fields { + // Simple unit variant, just store the discriminant + Fields::Unit => { + if use_u16 { + quote! { #name::#var_name => #disc as u16 } + } else { + quote! { #name::#var_name => #disc } + } + } + Fields::Named(fields) => { + // Named field variant like `Write { schema_did_change: bool }` + let field = &fields.named[0]; + let field_name = &field.ident; + + if is_bool_type(&field.ty) { + if use_u16 { + // Pack as: [discriminant_byte | bool_as_byte] + // Example: Write {true} with disc=3 becomes: b100000011 + quote! { + #name::#var_name { ref #field_name } => { + (#disc as u16) | ((*#field_name as u16) << 8) + } + } + } else { + // Same as above but with u8, so only 1 bit for bool + // Example: Write{true} with disc=3 becomes: b10000011 + quote! { + #name::#var_name { ref #field_name } => { + #disc | ((*#field_name as u8) << 7) + } + } + } + } else { + // u8/i8 field always uses u16 to have enough bits + // Pack as: [discriminant_byte | value_byte] + quote! { + #name::#var_name { ref #field_name } => { + (#disc as u16) | ((*#field_name as u16) << 8) + } + } + } + } + Fields::Unnamed(_) => { + // same strategy as above, but for tuple variants like `Write(bool)` + if is_bool_type(&variant.fields.iter().next().unwrap().ty) { + if use_u16 { + quote! { + #name::#var_name(ref val) => { + (#disc as u16) | ((*val as u16) << 8) + } + } + } else { + quote! { + #name::#var_name(ref val) => { + #disc | ((*val as u8) << 7) + } + } + } + } else { + quote! { + #name::#var_name(ref val) => { + (#disc as u16) | ((*val as u16) << 8) + } + } + } + } + } + }); + + // Generate the match arms for decoding the storage representation back to enum + let from_storage = variants.iter().enumerate().map(|(idx, variant)| { + let var_name = &variant.ident; + let disc = idx as u8; + + match &variant.fields { + Fields::Unit => quote! { #disc => #name::#var_name }, + Fields::Named(fields) => { + let field = &fields.named[0]; + let field_name = &field.ident; + + if is_bool_type(&field.ty) { + if use_u16 { + // Extract bool from high byte: check if non-zero + quote! { + #disc => #name::#var_name { + #field_name: (val >> 8) != 0 + } + } + } else { + // check single bool value at bit 7 + quote! { + #disc => #name::#var_name { + #field_name: (val & 0x80) != 0 + } + } + } + } else { + quote! { + #disc => #name::#var_name { + // Extract u8/i8 from high byte and cast to appropriate type + #field_name: (val >> 8) as _ + } + } + } + } + Fields::Unnamed(_) => { + if is_bool_type(&variant.fields.iter().next().unwrap().ty) { + if use_u16 { + quote! { #disc => #name::#var_name((val >> 8) != 0) } + } else { + quote! { #disc => #name::#var_name((val & 0x80) != 0) } + } + } else { + quote! { #disc => #name::#var_name((val >> 8) as _) } + } + } + } + }); + + let discriminant_mask = if use_u16 { + quote! { 0xFF } + } else { + quote! { 0x7F } + }; + let to_storage_arms_copy = to_storage.clone(); + + let expanded = quote! { + #[derive(Debug)] + /// Atomic wrapper for #name + pub struct #atomic_name(#atomic_type); + + impl #atomic_name { + /// Encode enum into storage representation + /// Discriminant in lower bits, field data in upper bits + #[inline] + fn to_storage(val: &#name) -> #storage_type { + match val { + #(#to_storage_arms_copy),* + } + } + + /// Decode storage representation into enum + /// Panics on invalid discriminant + #[inline] + fn from_storage(val: #storage_type) -> #name { + let discriminant = (val & #discriminant_mask) as u8; + match discriminant { + #(#from_storage,)* + _ => panic!(concat!("Invalid ", stringify!(#name), " discriminant: {}"), discriminant), + } + } + + /// Create new atomic enum with initial value + #[inline] + pub const fn new(val: #name) -> Self { + // Can't call to_storage in const context, so inline it + let storage = match val { + #(#to_storage),* + }; + Self(#atomic_type::new(storage)) + } + + #[inline] + /// Load and convert the current value to expected enum + pub fn get(&self) -> #name { + Self::from_storage(self.0.load(::std::sync::atomic::Ordering::SeqCst)) + } + + #[inline] + /// Convert and store new value + pub fn set(&self, val: #name) { + self.0.store(Self::to_storage(&val), ::std::sync::atomic::Ordering::SeqCst) + } + + #[inline] + /// Store new value and return previous value + pub fn swap(&self, val: #name) -> #name { + let prev = self.0.swap(Self::to_storage(&val), ::std::sync::atomic::Ordering::SeqCst); + Self::from_storage(prev) + } + } + + impl From<#name> for #atomic_name { + fn from(val: #name) -> Self { + Self::new(val) + } + } + }; + + TokenStream::from(expanded) +} + +fn is_bool_type(ty: &Type) -> bool { + if let Type::Path(path) = ty { + path.path.is_ident("bool") + } else { + false + } +} + +fn is_u8_or_i8_type(ty: &Type) -> bool { + if let Type::Path(path) = ty { + path.path.is_ident("u8") || path.path.is_ident("i8") + } else { + false + } +} diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 8df01da22..9db89bad2 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -1,5 +1,6 @@ mod ext; extern crate proc_macro; +mod atomic_enum; use proc_macro::{token_stream::IntoIter, Group, TokenStream, TokenTree}; use std::collections::HashMap; @@ -464,3 +465,30 @@ pub fn derive_vfs_module(input: TokenStream) -> TokenStream { pub fn match_ignore_ascii_case(input: TokenStream) -> TokenStream { ext::match_ignore_ascci_case(input) } + +/// Derive macro for creating atomic wrappers for enums +/// +/// Supports: +/// - Unit variants +/// - Variants with single bool/u8/i8 fields +/// - Named or unnamed fields +/// +/// Algorithm: +/// - Uses u8 representation, splitting bits for variant discriminant and field data +/// - For bool fields: high bit for bool, lower 7 bits for discriminant +/// - For u8/i8 fields: uses u16 internally (8 bits discriminant, 8 bits data) +/// +/// Example: +/// ```ignore +/// #[derive(AtomicEnum)] +/// enum TransactionState { +/// Write { schema_did_change: bool }, +/// Read, +/// PendingUpgrade, +/// None, +/// } +/// ``` +#[proc_macro_derive(AtomicEnum)] +pub fn derive_atomic_enum(input: TokenStream) -> TokenStream { + atomic_enum::derive_atomic_enum_inner(input) +}