Merge 'Add AtomicEnum proc macro to generate atomic wrappers to replace RwLocks' from Preston Thorpe

This PR adds the following derive macro
`AtomicEnum`
for the cases like the following:
```rust
pub enum SyncMode {
    Off = 0,
    Full = 2,
}
// or
pub enum CipherMode {
    Aes128Gcm,
    Aes256Gcm,
    Aegis256,
    Aegis128L,
    Aegis128X2,
    Aegis128X4,
    Aegis256X2,
    Aegis256X4,
}
```
Which are very basic enums, but which currently either require a
`RwLock` (the current solution for both of the above), or they require a
hand rolled atomic wrapper to keep the state without the lock.
```rust
pub struct AtomicDbState(AtomicUsize);

impl AtomicDbState {
    #[inline]
    pub const fn new(state: DbState) -> Self {
        Self(AtomicUsize::new(state as usize))
    }
```
This PR adds `AtomicEnum` derive macro which generates and let's us use
`AtomicDbState` or `AtomicCipherMode`, and derives `get`, `set` and
`swap` methods on them.
Each enum can have up to 1 named or unnamed field, and it supports i8/u8
and boolean types, which it encodes into half of a u16, with the
discriminant in the other half. Otherwise, it will just use a u8 and
encode the boolean into the 7th bit.

Closes #3766
This commit is contained in:
Pekka Enberg
2025-10-22 17:07:58 +03:00
committed by GitHub
8 changed files with 381 additions and 85 deletions

View File

@@ -40,6 +40,7 @@ pub mod numeric;
mod numeric;
use crate::storage::checksum::CHECKSUM_REQUIRED_RESERVED_BYTES;
use crate::storage::encryption::AtomicCipherMode;
use crate::translate::display::PlanContext;
use crate::translate::pragma::TURSO_CDC_DEFAULT_TABLE_NAME;
#[cfg(all(feature = "fs", feature = "conn_raw_api"))]
@@ -93,7 +94,7 @@ pub use storage::{
wal::{CheckpointMode, CheckpointResult, Wal, WalFile, WalFileShared},
};
use tracing::{instrument, Level};
use turso_macros::match_ignore_ascii_case;
use turso_macros::{match_ignore_ascii_case, AtomicEnum};
use turso_parser::ast::fmt::ToTokens;
use turso_parser::{ast, ast::Cmd, parser::Parser};
use types::IOResult;
@@ -178,7 +179,7 @@ impl EncryptionOpts {
pub type Result<T, E = LimboError> = std::result::Result<T, E>;
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
#[derive(Clone, AtomicEnum, Copy, PartialEq, Eq, Debug)]
enum TransactionState {
Write { schema_did_change: bool },
Read,
@@ -186,7 +187,7 @@ enum TransactionState {
None,
}
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
#[derive(Debug, AtomicEnum, Clone, Copy, PartialEq, Eq)]
pub enum SyncMode {
Off = 0,
Full = 2,
@@ -562,7 +563,7 @@ impl Database {
schema: RwLock::new(self.schema.lock().unwrap().clone()),
database_schemas: RwLock::new(std::collections::HashMap::new()),
auto_commit: AtomicBool::new(true),
transaction_state: RwLock::new(TransactionState::None),
transaction_state: AtomicTransactionState::new(TransactionState::None),
last_insert_rowid: AtomicI64::new(0),
last_change: AtomicI64::new(0),
total_changes: AtomicI64::new(0),
@@ -580,8 +581,8 @@ impl Database {
metrics: RwLock::new(ConnectionMetrics::new()),
is_nested_stmt: AtomicBool::new(false),
encryption_key: RwLock::new(None),
encryption_cipher_mode: RwLock::new(None),
sync_mode: RwLock::new(SyncMode::Full),
encryption_cipher_mode: AtomicCipherMode::new(CipherMode::None),
sync_mode: AtomicSyncMode::new(SyncMode::Full),
data_sync_retry: AtomicBool::new(false),
busy_timeout: RwLock::new(Duration::new(0, 0)),
is_mvcc_bootstrap_connection: AtomicBool::new(is_mvcc_bootstrap_connection),
@@ -604,7 +605,7 @@ impl Database {
/// we need to read the page_size from the database header.
fn read_page_size_from_db_header(&self) -> Result<PageSize> {
turso_assert!(
self.db_state.is_initialized(),
self.db_state.get().is_initialized(),
"read_page_size_from_db_header called on uninitialized database"
);
turso_assert!(
@@ -622,7 +623,7 @@ impl Database {
fn read_reserved_space_bytes_from_db_header(&self) -> Result<u8> {
turso_assert!(
self.db_state.is_initialized(),
self.db_state.get().is_initialized(),
"read_reserved_space_bytes_from_db_header called on uninitialized database"
);
turso_assert!(
@@ -658,7 +659,7 @@ impl Database {
return Ok(page_size);
}
}
if self.db_state.is_initialized() {
if self.db_state.get().is_initialized() {
Ok(self.read_page_size_from_db_header()?)
} else {
let Some(size) = requested_page_size else {
@@ -674,7 +675,7 @@ impl Database {
/// if the database is initialized i.e. it exists on disk, return the reserved space bytes from
/// the header or None
fn maybe_get_reserved_space_bytes(&self) -> Result<Option<u8>> {
if self.db_state.is_initialized() {
if self.db_state.get().is_initialized() {
Ok(Some(self.read_reserved_space_bytes_from_db_header()?))
} else {
Ok(None)
@@ -696,7 +697,7 @@ impl Database {
drop(shared_wal);
let buffer_pool = self.buffer_pool.clone();
if self.db_state.is_initialized() {
if self.db_state.get().is_initialized() {
buffer_pool.finalize_with_page_size(page_size.get() as usize)?;
}
@@ -729,7 +730,7 @@ impl Database {
let buffer_pool = self.buffer_pool.clone();
if self.db_state.is_initialized() {
if self.db_state.get().is_initialized() {
buffer_pool.finalize_with_page_size(page_size.get() as usize)?;
}
@@ -1067,7 +1068,7 @@ pub struct Connection {
database_schemas: RwLock<std::collections::HashMap<usize, Arc<Schema>>>,
/// Whether to automatically commit transaction
auto_commit: AtomicBool,
transaction_state: RwLock<TransactionState>,
transaction_state: AtomicTransactionState,
last_insert_rowid: AtomicI64,
last_change: AtomicI64,
total_changes: AtomicI64,
@@ -1096,8 +1097,8 @@ pub struct Connection {
/// Generally this is only true for ParseSchema.
is_nested_stmt: AtomicBool,
encryption_key: RwLock<Option<EncryptionKey>>,
encryption_cipher_mode: RwLock<Option<CipherMode>>,
sync_mode: RwLock<SyncMode>,
encryption_cipher_mode: AtomicCipherMode,
sync_mode: AtomicSyncMode,
data_sync_retry: AtomicBool,
/// User defined max accumulated Busy timeout duration
/// Default is 0 (no timeout)
@@ -1238,8 +1239,7 @@ impl Connection {
let reparse_result = self.reparse_schema();
let previous =
std::mem::replace(&mut *self.transaction_state.write(), TransactionState::None);
let previous = self.transaction_state.swap(TransactionState::None);
turso_assert!(
matches!(previous, TransactionState::None | TransactionState::Read),
"unexpected end transaction state"
@@ -1519,7 +1519,7 @@ impl Connection {
let _ = conn.pragma_update("cipher", encryption_opts.cipher.to_string());
let _ = conn.pragma_update("hexkey", encryption_opts.hexkey.to_string());
let pager = conn.pager.read();
if db.db_state.is_initialized() {
if db.db_state.get().is_initialized() {
// Clear page cache so the header page can be reread from disk and decrypted using the encryption context.
pager.clear_page_cache(false);
}
@@ -1597,9 +1597,9 @@ impl Connection {
header.schema_cookie.get() < version,
"cookie can't go back in time"
);
*self.transaction_state.write() = TransactionState::Write {
self.set_tx_state(TransactionState::Write {
schema_did_change: true,
};
});
self.with_schema_mut(|schema| schema.schema_version = version);
header.schema_cookie = version.into();
})
@@ -1682,9 +1682,9 @@ impl Connection {
})?;
// start write transaction and disable auto-commit mode as SQL can be executed within WAL session (at caller own risk)
*self.transaction_state.write() = TransactionState::Write {
self.set_tx_state(TransactionState::Write {
schema_did_change: false,
};
});
self.auto_commit.store(false, Ordering::SeqCst);
Ok(())
@@ -2029,7 +2029,7 @@ impl Connection {
}
pub fn is_db_initialized(&self) -> bool {
self.db.db_state.is_initialized()
self.db.db_state.get().is_initialized()
}
fn get_pager_from_database_index(&self, index: &usize) -> Arc<Pager> {
@@ -2259,11 +2259,11 @@ impl Connection {
}
pub fn get_sync_mode(&self) -> SyncMode {
*self.sync_mode.read()
self.sync_mode.get()
}
pub fn set_sync_mode(&self, mode: SyncMode) {
*self.sync_mode.write() = mode;
self.sync_mode.set(mode);
}
pub fn get_data_sync_retry(&self) -> bool {
@@ -2289,7 +2289,7 @@ impl Connection {
pub fn set_encryption_cipher(&self, cipher_mode: CipherMode) -> Result<()> {
tracing::trace!("setting encryption cipher for connection");
*self.encryption_cipher_mode.write() = Some(cipher_mode);
self.encryption_cipher_mode.set(cipher_mode);
self.set_encryption_context()
}
@@ -2300,7 +2300,10 @@ impl Connection {
}
pub fn get_encryption_cipher_mode(&self) -> Option<CipherMode> {
*self.encryption_cipher_mode.read()
match self.encryption_cipher_mode.get() {
CipherMode::None => None,
mode => Some(mode),
}
}
// if both key and cipher are set, set encryption context on pager
@@ -2309,8 +2312,8 @@ impl Connection {
let Some(key) = key_guard.as_ref() else {
return Ok(());
};
let cipher_guard = self.encryption_cipher_mode.read();
let Some(cipher_mode) = *cipher_guard else {
let cipher_mode = self.get_encryption_cipher_mode();
let Some(cipher_mode) = cipher_mode else {
return Ok(());
};
tracing::trace!("setting encryption ctx for connection");
@@ -2348,11 +2351,11 @@ impl Connection {
}
fn set_tx_state(&self, state: TransactionState) {
*self.transaction_state.write() = state;
self.transaction_state.set(state);
}
fn get_tx_state(&self) -> TransactionState {
*self.transaction_state.read()
self.transaction_state.get()
}
pub(crate) fn get_mv_tx_id(&self) -> Option<u64> {

View File

@@ -325,9 +325,9 @@ impl<Clock: LogicalClock> CheckpointStateMachine<Clock> {
}
result?;
if self.update_transaction_state {
*self.connection.transaction_state.write() = TransactionState::Write {
self.connection.set_tx_state(TransactionState::Write {
schema_did_change: false,
}; // TODO: schema_did_change??
}); // TODO: schema_did_change??
}
self.lock_states.pager_write_tx = true;
self.state = CheckpointState::WriteRow {
@@ -534,7 +534,7 @@ impl<Clock: LogicalClock> CheckpointStateMachine<Clock> {
self.lock_states.pager_read_tx = false;
self.lock_states.pager_write_tx = false;
if self.update_transaction_state {
*self.connection.transaction_state.write() = TransactionState::None;
self.connection.set_tx_state(TransactionState::None);
}
let header = self
.pager
@@ -623,12 +623,12 @@ impl<Clock: LogicalClock> StateTransition for CheckpointStateMachine<Clock> {
if self.lock_states.pager_write_tx {
self.pager.rollback_tx(self.connection.as_ref());
if self.update_transaction_state {
*self.connection.transaction_state.write() = TransactionState::None;
self.connection.set_tx_state(TransactionState::None);
}
} else if self.lock_states.pager_read_tx {
self.pager.end_read_tx();
if self.update_transaction_state {
*self.connection.transaction_state.write() = TransactionState::None;
self.connection.set_tx_state(TransactionState::None);
}
}
if self.lock_states.blocking_checkpoint_lock_held {

View File

@@ -662,7 +662,7 @@ impl BTreeNodeState {
impl BTreeCursor {
pub fn new(pager: Arc<Pager>, root_page: i64, num_columns: usize) -> Self {
let valid_state = if root_page == 1 && !pager.db_state.is_initialized() {
let valid_state = if root_page == 1 && !pager.db_state.get().is_initialized() {
CursorValidState::Invalid
} else {
CursorValidState::Valid

View File

@@ -10,7 +10,7 @@ use aes_gcm::{
aead::{Aead, AeadCore, KeyInit, OsRng},
Aes128Gcm, Aes256Gcm, Key, Nonce,
};
use turso_macros::match_ignore_ascii_case;
use turso_macros::{match_ignore_ascii_case, AtomicEnum};
/// Encryption Scheme
/// We support two major algorithms: AEGIS, AES GCM. These algorithms picked so that they also do
@@ -319,8 +319,9 @@ define_aegis_cipher!(
"AEGIS-128X4"
);
#[derive(Debug, Clone, Copy, PartialEq)]
#[derive(Debug, AtomicEnum, Clone, Copy, PartialEq)]
pub enum CipherMode {
None,
Aes128Gcm,
Aes256Gcm,
Aegis256,
@@ -363,6 +364,7 @@ impl std::fmt::Display for CipherMode {
CipherMode::Aegis128X4 => write!(f, "aegis128x4"),
CipherMode::Aegis256X2 => write!(f, "aegis256x2"),
CipherMode::Aegis256X4 => write!(f, "aegis256x4"),
CipherMode::None => write!(f, "None"),
}
}
}
@@ -380,6 +382,7 @@ impl CipherMode {
CipherMode::Aegis128L => 16,
CipherMode::Aegis128X2 => 16,
CipherMode::Aegis128X4 => 16,
CipherMode::None => 0,
}
}
@@ -394,6 +397,7 @@ impl CipherMode {
CipherMode::Aegis128L => 16,
CipherMode::Aegis128X2 => 16,
CipherMode::Aegis128X4 => 16,
CipherMode::None => 0,
}
}
@@ -408,6 +412,7 @@ impl CipherMode {
CipherMode::Aegis128L => 16,
CipherMode::Aegis128X2 => 16,
CipherMode::Aegis128X4 => 16,
CipherMode::None => 0,
}
}
@@ -427,6 +432,7 @@ impl CipherMode {
CipherMode::Aegis128L => 6,
CipherMode::Aegis128X2 => 7,
CipherMode::Aegis128X4 => 8,
CipherMode::None => 0,
}
}
@@ -503,6 +509,11 @@ impl EncryptionContext {
CipherMode::Aegis128L => Cipher::Aegis128L(Box::new(Aegis128LCipher::new(key))),
CipherMode::Aegis128X2 => Cipher::Aegis128X2(Box::new(Aegis128X2Cipher::new(key))),
CipherMode::Aegis128X4 => Cipher::Aegis128X4(Box::new(Aegis128X4Cipher::new(key))),
CipherMode::None => {
return Err(LimboError::InvalidArgument(
"must select valid CipherMode".into(),
))
}
};
Ok(Self {
cipher_mode,

View File

@@ -25,6 +25,7 @@ use std::sync::atomic::{
};
use std::sync::{Arc, Mutex};
use tracing::{instrument, trace, Level};
use turso_macros::AtomicEnum;
use super::btree::btree_init_page;
use super::page_cache::{CacheError, CacheResizeResult, PageCache, PageCacheKey};
@@ -57,7 +58,7 @@ impl HeaderRef {
tracing::trace!("HeaderRef::from_pager - {:?}", state);
match state {
HeaderRefState::Start => {
if !pager.db_state.is_initialized() {
if !pager.db_state.get().is_initialized() {
return Err(LimboError::Page1NotAlloc);
}
@@ -97,7 +98,7 @@ impl HeaderRefMut {
tracing::trace!(?state);
match state {
HeaderRefState::Start => {
if !pager.db_state.is_initialized() {
if !pager.db_state.get().is_initialized() {
return Err(LimboError::Page1NotAlloc);
}
@@ -416,57 +417,19 @@ impl From<u8> for AutoVacuumMode {
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(usize)]
#[derive(Debug, AtomicEnum, Clone, Copy, PartialEq, Eq)]
pub enum DbState {
Uninitialized = Self::UNINITIALIZED,
Initializing = Self::INITIALIZING,
Initialized = Self::INITIALIZED,
Uninitialized,
Initializing,
Initialized,
}
impl DbState {
pub(self) const UNINITIALIZED: usize = 0;
pub(self) const INITIALIZING: usize = 1;
pub(self) const INITIALIZED: usize = 2;
#[inline]
pub fn is_initialized(&self) -> bool {
matches!(self, DbState::Initialized)
}
}
#[derive(Debug)]
#[repr(transparent)]
pub struct AtomicDbState(AtomicUsize);
impl AtomicDbState {
#[inline]
pub const fn new(state: DbState) -> Self {
Self(AtomicUsize::new(state as usize))
}
#[inline]
pub fn set(&self, state: DbState) {
self.0.store(state as usize, Ordering::Release);
}
#[inline]
pub fn get(&self) -> DbState {
let v = self.0.load(Ordering::Acquire);
match v {
DbState::UNINITIALIZED => DbState::Uninitialized,
DbState::INITIALIZING => DbState::Initializing,
DbState::INITIALIZED => DbState::Initialized,
_ => unreachable!(),
}
}
#[inline]
pub fn is_initialized(&self) -> bool {
self.get().is_initialized()
}
}
#[derive(Debug, Clone)]
#[cfg(not(feature = "omit_autovacuum"))]
enum PtrMapGetState {
@@ -621,7 +584,7 @@ impl Pager {
db_state: Arc<AtomicDbState>,
init_lock: Arc<Mutex<()>>,
) -> Result<Self> {
let allocate_page1_state = if !db_state.is_initialized() {
let allocate_page1_state = if !db_state.get().is_initialized() {
RwLock::new(AllocatePage1State::Start)
} else {
RwLock::new(AllocatePage1State::Done)
@@ -1131,7 +1094,7 @@ impl Pager {
#[instrument(skip_all, level = Level::DEBUG)]
pub fn maybe_allocate_page1(&self) -> Result<IOResult<()>> {
if !self.db_state.is_initialized() {
if !self.db_state.get().is_initialized() {
if let Ok(_lock) = self.init_lock.try_lock() {
match (self.db_state.get(), self.allocating_page1()) {
// In case of being empty or (allocating and this connection is performing allocation) then allocate the first page

View File

@@ -11,6 +11,7 @@ use crate::{
LimboError, OpenFlags, Result, Statement, StepResult, SymbolTable,
};
use crate::{Connection, MvStore, IO};
use std::sync::atomic::AtomicU8;
use std::{
collections::HashMap,
rc::Rc,

290
macros/src/atomic_enum.rs Normal file
View File

@@ -0,0 +1,290 @@
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, Data, DeriveInput, Fields, Type};
pub(crate) fn derive_atomic_enum_inner(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let name = &input.ident;
let atomic_name = syn::Ident::new(&format!("Atomic{name}"), name.span());
let variants = match &input.data {
Data::Enum(data) => &data.variants,
_ => {
return syn::Error::new_spanned(input, "AtomicEnum can only be derived for enums")
.to_compile_error()
.into();
}
};
// get info about variants to determine how we have to encode them
let mut has_bool_field = false;
let mut has_u8_field = false;
let mut max_discriminant = 0u8;
for (idx, variant) in variants.iter().enumerate() {
max_discriminant = idx as u8;
match &variant.fields {
Fields::Unit => {}
Fields::Named(fields) if fields.named.len() == 1 => {
let field = &fields.named[0];
if is_bool_type(&field.ty) {
has_bool_field = true;
} else if is_u8_or_i8_type(&field.ty) {
has_u8_field = true;
} else {
return syn::Error::new_spanned(
field,
"AtomicEnum only supports bool, u8, or i8 fields",
)
.to_compile_error()
.into();
}
}
Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
let field = &fields.unnamed[0];
if is_bool_type(&field.ty) {
has_bool_field = true;
} else if is_u8_or_i8_type(&field.ty) {
has_u8_field = true;
} else {
return syn::Error::new_spanned(
field,
"AtomicEnum only supports bool, u8, or i8 fields",
)
.to_compile_error()
.into();
}
}
_ => {
return syn::Error::new_spanned(
variant,
"AtomicEnum only supports unit variants or variants with a single field",
)
.to_compile_error()
.into();
}
}
}
let (storage_type, atomic_type) = if has_u8_field || (has_bool_field && max_discriminant > 127)
{
// Need u16: 8 bits for discriminant, 8 bits for data
(quote! { u16 }, quote! { ::std::sync::atomic::AtomicU16 })
} else {
// Can use u8: 7 bits for discriminant, 1 bit for bool (if any)
(quote! { u8 }, quote! { ::std::sync::atomic::AtomicU8 })
};
let use_u16 = has_u8_field || (has_bool_field && max_discriminant > 127);
let to_storage = variants.iter().enumerate().map(|(idx, variant)| {
let var_name = &variant.ident;
let disc = idx as u8; // The discriminant here is just the variant's index
match &variant.fields {
// Simple unit variant, just store the discriminant
Fields::Unit => {
if use_u16 {
quote! { #name::#var_name => #disc as u16 }
} else {
quote! { #name::#var_name => #disc }
}
}
Fields::Named(fields) => {
// Named field variant like `Write { schema_did_change: bool }`
let field = &fields.named[0];
let field_name = &field.ident;
if is_bool_type(&field.ty) {
if use_u16 {
// Pack as: [discriminant_byte | bool_as_byte]
// Example: Write {true} with disc=3 becomes: b100000011
quote! {
#name::#var_name { ref #field_name } => {
(#disc as u16) | ((*#field_name as u16) << 8)
}
}
} else {
// Same as above but with u8, so only 1 bit for bool
// Example: Write{true} with disc=3 becomes: b10000011
quote! {
#name::#var_name { ref #field_name } => {
#disc | ((*#field_name as u8) << 7)
}
}
}
} else {
// u8/i8 field always uses u16 to have enough bits
// Pack as: [discriminant_byte | value_byte]
quote! {
#name::#var_name { ref #field_name } => {
(#disc as u16) | ((*#field_name as u16) << 8)
}
}
}
}
Fields::Unnamed(_) => {
// same strategy as above, but for tuple variants like `Write(bool)`
if is_bool_type(&variant.fields.iter().next().unwrap().ty) {
if use_u16 {
quote! {
#name::#var_name(ref val) => {
(#disc as u16) | ((*val as u16) << 8)
}
}
} else {
quote! {
#name::#var_name(ref val) => {
#disc | ((*val as u8) << 7)
}
}
}
} else {
quote! {
#name::#var_name(ref val) => {
(#disc as u16) | ((*val as u16) << 8)
}
}
}
}
}
});
// Generate the match arms for decoding the storage representation back to enum
let from_storage = variants.iter().enumerate().map(|(idx, variant)| {
let var_name = &variant.ident;
let disc = idx as u8;
match &variant.fields {
Fields::Unit => quote! { #disc => #name::#var_name },
Fields::Named(fields) => {
let field = &fields.named[0];
let field_name = &field.ident;
if is_bool_type(&field.ty) {
if use_u16 {
// Extract bool from high byte: check if non-zero
quote! {
#disc => #name::#var_name {
#field_name: (val >> 8) != 0
}
}
} else {
// check single bool value at bit 7
quote! {
#disc => #name::#var_name {
#field_name: (val & 0x80) != 0
}
}
}
} else {
quote! {
#disc => #name::#var_name {
// Extract u8/i8 from high byte and cast to appropriate type
#field_name: (val >> 8) as _
}
}
}
}
Fields::Unnamed(_) => {
if is_bool_type(&variant.fields.iter().next().unwrap().ty) {
if use_u16 {
quote! { #disc => #name::#var_name((val >> 8) != 0) }
} else {
quote! { #disc => #name::#var_name((val & 0x80) != 0) }
}
} else {
quote! { #disc => #name::#var_name((val >> 8) as _) }
}
}
}
});
let discriminant_mask = if use_u16 {
quote! { 0xFF }
} else {
quote! { 0x7F }
};
let to_storage_arms_copy = to_storage.clone();
let expanded = quote! {
#[derive(Debug)]
/// Atomic wrapper for #name
pub struct #atomic_name(#atomic_type);
impl #atomic_name {
/// Encode enum into storage representation
/// Discriminant in lower bits, field data in upper bits
#[inline]
fn to_storage(val: &#name) -> #storage_type {
match val {
#(#to_storage_arms_copy),*
}
}
/// Decode storage representation into enum
/// Panics on invalid discriminant
#[inline]
fn from_storage(val: #storage_type) -> #name {
let discriminant = (val & #discriminant_mask) as u8;
match discriminant {
#(#from_storage,)*
_ => panic!(concat!("Invalid ", stringify!(#name), " discriminant: {}"), discriminant),
}
}
/// Create new atomic enum with initial value
#[inline]
pub const fn new(val: #name) -> Self {
// Can't call to_storage in const context, so inline it
let storage = match val {
#(#to_storage),*
};
Self(#atomic_type::new(storage))
}
#[inline]
/// Load and convert the current value to expected enum
pub fn get(&self) -> #name {
Self::from_storage(self.0.load(::std::sync::atomic::Ordering::SeqCst))
}
#[inline]
/// Convert and store new value
pub fn set(&self, val: #name) {
self.0.store(Self::to_storage(&val), ::std::sync::atomic::Ordering::SeqCst)
}
#[inline]
/// Store new value and return previous value
pub fn swap(&self, val: #name) -> #name {
let prev = self.0.swap(Self::to_storage(&val), ::std::sync::atomic::Ordering::SeqCst);
Self::from_storage(prev)
}
}
impl From<#name> for #atomic_name {
fn from(val: #name) -> Self {
Self::new(val)
}
}
};
TokenStream::from(expanded)
}
fn is_bool_type(ty: &Type) -> bool {
if let Type::Path(path) = ty {
path.path.is_ident("bool")
} else {
false
}
}
fn is_u8_or_i8_type(ty: &Type) -> bool {
if let Type::Path(path) = ty {
path.path.is_ident("u8") || path.path.is_ident("i8")
} else {
false
}
}

View File

@@ -1,5 +1,6 @@
mod ext;
extern crate proc_macro;
mod atomic_enum;
use proc_macro::{token_stream::IntoIter, Group, TokenStream, TokenTree};
use std::collections::HashMap;
@@ -464,3 +465,30 @@ pub fn derive_vfs_module(input: TokenStream) -> TokenStream {
pub fn match_ignore_ascii_case(input: TokenStream) -> TokenStream {
ext::match_ignore_ascci_case(input)
}
/// Derive macro for creating atomic wrappers for enums
///
/// Supports:
/// - Unit variants
/// - Variants with single bool/u8/i8 fields
/// - Named or unnamed fields
///
/// Algorithm:
/// - Uses u8 representation, splitting bits for variant discriminant and field data
/// - For bool fields: high bit for bool, lower 7 bits for discriminant
/// - For u8/i8 fields: uses u16 internally (8 bits discriminant, 8 bits data)
///
/// Example:
/// ```ignore
/// #[derive(AtomicEnum)]
/// enum TransactionState {
/// Write { schema_did_change: bool },
/// Read,
/// PendingUpgrade,
/// None,
/// }
/// ```
#[proc_macro_derive(AtomicEnum)]
pub fn derive_atomic_enum(input: TokenStream) -> TokenStream {
atomic_enum::derive_atomic_enum_inner(input)
}