mirror of
https://github.com/aljazceru/turso.git
synced 2026-01-22 01:24:18 +01:00
Merge 'Add AtomicEnum proc macro to generate atomic wrappers to replace RwLocks' from Preston Thorpe
This PR adds the following derive macro
`AtomicEnum`
for the cases like the following:
```rust
pub enum SyncMode {
Off = 0,
Full = 2,
}
// or
pub enum CipherMode {
Aes128Gcm,
Aes256Gcm,
Aegis256,
Aegis128L,
Aegis128X2,
Aegis128X4,
Aegis256X2,
Aegis256X4,
}
```
Which are very basic enums, but which currently either require a
`RwLock` (the current solution for both of the above), or they require a
hand rolled atomic wrapper to keep the state without the lock.
```rust
pub struct AtomicDbState(AtomicUsize);
impl AtomicDbState {
#[inline]
pub const fn new(state: DbState) -> Self {
Self(AtomicUsize::new(state as usize))
}
```
This PR adds `AtomicEnum` derive macro which generates and let's us use
`AtomicDbState` or `AtomicCipherMode`, and derives `get`, `set` and
`swap` methods on them.
Each enum can have up to 1 named or unnamed field, and it supports i8/u8
and boolean types, which it encodes into half of a u16, with the
discriminant in the other half. Otherwise, it will just use a u8 and
encode the boolean into the 7th bit.
Closes #3766
This commit is contained in:
65
core/lib.rs
65
core/lib.rs
@@ -40,6 +40,7 @@ pub mod numeric;
|
||||
mod numeric;
|
||||
|
||||
use crate::storage::checksum::CHECKSUM_REQUIRED_RESERVED_BYTES;
|
||||
use crate::storage::encryption::AtomicCipherMode;
|
||||
use crate::translate::display::PlanContext;
|
||||
use crate::translate::pragma::TURSO_CDC_DEFAULT_TABLE_NAME;
|
||||
#[cfg(all(feature = "fs", feature = "conn_raw_api"))]
|
||||
@@ -93,7 +94,7 @@ pub use storage::{
|
||||
wal::{CheckpointMode, CheckpointResult, Wal, WalFile, WalFileShared},
|
||||
};
|
||||
use tracing::{instrument, Level};
|
||||
use turso_macros::match_ignore_ascii_case;
|
||||
use turso_macros::{match_ignore_ascii_case, AtomicEnum};
|
||||
use turso_parser::ast::fmt::ToTokens;
|
||||
use turso_parser::{ast, ast::Cmd, parser::Parser};
|
||||
use types::IOResult;
|
||||
@@ -178,7 +179,7 @@ impl EncryptionOpts {
|
||||
|
||||
pub type Result<T, E = LimboError> = std::result::Result<T, E>;
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
|
||||
#[derive(Clone, AtomicEnum, Copy, PartialEq, Eq, Debug)]
|
||||
enum TransactionState {
|
||||
Write { schema_did_change: bool },
|
||||
Read,
|
||||
@@ -186,7 +187,7 @@ enum TransactionState {
|
||||
None,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
|
||||
#[derive(Debug, AtomicEnum, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum SyncMode {
|
||||
Off = 0,
|
||||
Full = 2,
|
||||
@@ -562,7 +563,7 @@ impl Database {
|
||||
schema: RwLock::new(self.schema.lock().unwrap().clone()),
|
||||
database_schemas: RwLock::new(std::collections::HashMap::new()),
|
||||
auto_commit: AtomicBool::new(true),
|
||||
transaction_state: RwLock::new(TransactionState::None),
|
||||
transaction_state: AtomicTransactionState::new(TransactionState::None),
|
||||
last_insert_rowid: AtomicI64::new(0),
|
||||
last_change: AtomicI64::new(0),
|
||||
total_changes: AtomicI64::new(0),
|
||||
@@ -580,8 +581,8 @@ impl Database {
|
||||
metrics: RwLock::new(ConnectionMetrics::new()),
|
||||
is_nested_stmt: AtomicBool::new(false),
|
||||
encryption_key: RwLock::new(None),
|
||||
encryption_cipher_mode: RwLock::new(None),
|
||||
sync_mode: RwLock::new(SyncMode::Full),
|
||||
encryption_cipher_mode: AtomicCipherMode::new(CipherMode::None),
|
||||
sync_mode: AtomicSyncMode::new(SyncMode::Full),
|
||||
data_sync_retry: AtomicBool::new(false),
|
||||
busy_timeout: RwLock::new(Duration::new(0, 0)),
|
||||
is_mvcc_bootstrap_connection: AtomicBool::new(is_mvcc_bootstrap_connection),
|
||||
@@ -604,7 +605,7 @@ impl Database {
|
||||
/// we need to read the page_size from the database header.
|
||||
fn read_page_size_from_db_header(&self) -> Result<PageSize> {
|
||||
turso_assert!(
|
||||
self.db_state.is_initialized(),
|
||||
self.db_state.get().is_initialized(),
|
||||
"read_page_size_from_db_header called on uninitialized database"
|
||||
);
|
||||
turso_assert!(
|
||||
@@ -622,7 +623,7 @@ impl Database {
|
||||
|
||||
fn read_reserved_space_bytes_from_db_header(&self) -> Result<u8> {
|
||||
turso_assert!(
|
||||
self.db_state.is_initialized(),
|
||||
self.db_state.get().is_initialized(),
|
||||
"read_reserved_space_bytes_from_db_header called on uninitialized database"
|
||||
);
|
||||
turso_assert!(
|
||||
@@ -658,7 +659,7 @@ impl Database {
|
||||
return Ok(page_size);
|
||||
}
|
||||
}
|
||||
if self.db_state.is_initialized() {
|
||||
if self.db_state.get().is_initialized() {
|
||||
Ok(self.read_page_size_from_db_header()?)
|
||||
} else {
|
||||
let Some(size) = requested_page_size else {
|
||||
@@ -674,7 +675,7 @@ impl Database {
|
||||
/// if the database is initialized i.e. it exists on disk, return the reserved space bytes from
|
||||
/// the header or None
|
||||
fn maybe_get_reserved_space_bytes(&self) -> Result<Option<u8>> {
|
||||
if self.db_state.is_initialized() {
|
||||
if self.db_state.get().is_initialized() {
|
||||
Ok(Some(self.read_reserved_space_bytes_from_db_header()?))
|
||||
} else {
|
||||
Ok(None)
|
||||
@@ -696,7 +697,7 @@ impl Database {
|
||||
drop(shared_wal);
|
||||
|
||||
let buffer_pool = self.buffer_pool.clone();
|
||||
if self.db_state.is_initialized() {
|
||||
if self.db_state.get().is_initialized() {
|
||||
buffer_pool.finalize_with_page_size(page_size.get() as usize)?;
|
||||
}
|
||||
|
||||
@@ -729,7 +730,7 @@ impl Database {
|
||||
|
||||
let buffer_pool = self.buffer_pool.clone();
|
||||
|
||||
if self.db_state.is_initialized() {
|
||||
if self.db_state.get().is_initialized() {
|
||||
buffer_pool.finalize_with_page_size(page_size.get() as usize)?;
|
||||
}
|
||||
|
||||
@@ -1067,7 +1068,7 @@ pub struct Connection {
|
||||
database_schemas: RwLock<std::collections::HashMap<usize, Arc<Schema>>>,
|
||||
/// Whether to automatically commit transaction
|
||||
auto_commit: AtomicBool,
|
||||
transaction_state: RwLock<TransactionState>,
|
||||
transaction_state: AtomicTransactionState,
|
||||
last_insert_rowid: AtomicI64,
|
||||
last_change: AtomicI64,
|
||||
total_changes: AtomicI64,
|
||||
@@ -1096,8 +1097,8 @@ pub struct Connection {
|
||||
/// Generally this is only true for ParseSchema.
|
||||
is_nested_stmt: AtomicBool,
|
||||
encryption_key: RwLock<Option<EncryptionKey>>,
|
||||
encryption_cipher_mode: RwLock<Option<CipherMode>>,
|
||||
sync_mode: RwLock<SyncMode>,
|
||||
encryption_cipher_mode: AtomicCipherMode,
|
||||
sync_mode: AtomicSyncMode,
|
||||
data_sync_retry: AtomicBool,
|
||||
/// User defined max accumulated Busy timeout duration
|
||||
/// Default is 0 (no timeout)
|
||||
@@ -1238,8 +1239,7 @@ impl Connection {
|
||||
|
||||
let reparse_result = self.reparse_schema();
|
||||
|
||||
let previous =
|
||||
std::mem::replace(&mut *self.transaction_state.write(), TransactionState::None);
|
||||
let previous = self.transaction_state.swap(TransactionState::None);
|
||||
turso_assert!(
|
||||
matches!(previous, TransactionState::None | TransactionState::Read),
|
||||
"unexpected end transaction state"
|
||||
@@ -1519,7 +1519,7 @@ impl Connection {
|
||||
let _ = conn.pragma_update("cipher", encryption_opts.cipher.to_string());
|
||||
let _ = conn.pragma_update("hexkey", encryption_opts.hexkey.to_string());
|
||||
let pager = conn.pager.read();
|
||||
if db.db_state.is_initialized() {
|
||||
if db.db_state.get().is_initialized() {
|
||||
// Clear page cache so the header page can be reread from disk and decrypted using the encryption context.
|
||||
pager.clear_page_cache(false);
|
||||
}
|
||||
@@ -1597,9 +1597,9 @@ impl Connection {
|
||||
header.schema_cookie.get() < version,
|
||||
"cookie can't go back in time"
|
||||
);
|
||||
*self.transaction_state.write() = TransactionState::Write {
|
||||
self.set_tx_state(TransactionState::Write {
|
||||
schema_did_change: true,
|
||||
};
|
||||
});
|
||||
self.with_schema_mut(|schema| schema.schema_version = version);
|
||||
header.schema_cookie = version.into();
|
||||
})
|
||||
@@ -1682,9 +1682,9 @@ impl Connection {
|
||||
})?;
|
||||
|
||||
// start write transaction and disable auto-commit mode as SQL can be executed within WAL session (at caller own risk)
|
||||
*self.transaction_state.write() = TransactionState::Write {
|
||||
self.set_tx_state(TransactionState::Write {
|
||||
schema_did_change: false,
|
||||
};
|
||||
});
|
||||
self.auto_commit.store(false, Ordering::SeqCst);
|
||||
|
||||
Ok(())
|
||||
@@ -2029,7 +2029,7 @@ impl Connection {
|
||||
}
|
||||
|
||||
pub fn is_db_initialized(&self) -> bool {
|
||||
self.db.db_state.is_initialized()
|
||||
self.db.db_state.get().is_initialized()
|
||||
}
|
||||
|
||||
fn get_pager_from_database_index(&self, index: &usize) -> Arc<Pager> {
|
||||
@@ -2259,11 +2259,11 @@ impl Connection {
|
||||
}
|
||||
|
||||
pub fn get_sync_mode(&self) -> SyncMode {
|
||||
*self.sync_mode.read()
|
||||
self.sync_mode.get()
|
||||
}
|
||||
|
||||
pub fn set_sync_mode(&self, mode: SyncMode) {
|
||||
*self.sync_mode.write() = mode;
|
||||
self.sync_mode.set(mode);
|
||||
}
|
||||
|
||||
pub fn get_data_sync_retry(&self) -> bool {
|
||||
@@ -2289,7 +2289,7 @@ impl Connection {
|
||||
|
||||
pub fn set_encryption_cipher(&self, cipher_mode: CipherMode) -> Result<()> {
|
||||
tracing::trace!("setting encryption cipher for connection");
|
||||
*self.encryption_cipher_mode.write() = Some(cipher_mode);
|
||||
self.encryption_cipher_mode.set(cipher_mode);
|
||||
self.set_encryption_context()
|
||||
}
|
||||
|
||||
@@ -2300,7 +2300,10 @@ impl Connection {
|
||||
}
|
||||
|
||||
pub fn get_encryption_cipher_mode(&self) -> Option<CipherMode> {
|
||||
*self.encryption_cipher_mode.read()
|
||||
match self.encryption_cipher_mode.get() {
|
||||
CipherMode::None => None,
|
||||
mode => Some(mode),
|
||||
}
|
||||
}
|
||||
|
||||
// if both key and cipher are set, set encryption context on pager
|
||||
@@ -2309,8 +2312,8 @@ impl Connection {
|
||||
let Some(key) = key_guard.as_ref() else {
|
||||
return Ok(());
|
||||
};
|
||||
let cipher_guard = self.encryption_cipher_mode.read();
|
||||
let Some(cipher_mode) = *cipher_guard else {
|
||||
let cipher_mode = self.get_encryption_cipher_mode();
|
||||
let Some(cipher_mode) = cipher_mode else {
|
||||
return Ok(());
|
||||
};
|
||||
tracing::trace!("setting encryption ctx for connection");
|
||||
@@ -2348,11 +2351,11 @@ impl Connection {
|
||||
}
|
||||
|
||||
fn set_tx_state(&self, state: TransactionState) {
|
||||
*self.transaction_state.write() = state;
|
||||
self.transaction_state.set(state);
|
||||
}
|
||||
|
||||
fn get_tx_state(&self) -> TransactionState {
|
||||
*self.transaction_state.read()
|
||||
self.transaction_state.get()
|
||||
}
|
||||
|
||||
pub(crate) fn get_mv_tx_id(&self) -> Option<u64> {
|
||||
|
||||
@@ -325,9 +325,9 @@ impl<Clock: LogicalClock> CheckpointStateMachine<Clock> {
|
||||
}
|
||||
result?;
|
||||
if self.update_transaction_state {
|
||||
*self.connection.transaction_state.write() = TransactionState::Write {
|
||||
self.connection.set_tx_state(TransactionState::Write {
|
||||
schema_did_change: false,
|
||||
}; // TODO: schema_did_change??
|
||||
}); // TODO: schema_did_change??
|
||||
}
|
||||
self.lock_states.pager_write_tx = true;
|
||||
self.state = CheckpointState::WriteRow {
|
||||
@@ -534,7 +534,7 @@ impl<Clock: LogicalClock> CheckpointStateMachine<Clock> {
|
||||
self.lock_states.pager_read_tx = false;
|
||||
self.lock_states.pager_write_tx = false;
|
||||
if self.update_transaction_state {
|
||||
*self.connection.transaction_state.write() = TransactionState::None;
|
||||
self.connection.set_tx_state(TransactionState::None);
|
||||
}
|
||||
let header = self
|
||||
.pager
|
||||
@@ -623,12 +623,12 @@ impl<Clock: LogicalClock> StateTransition for CheckpointStateMachine<Clock> {
|
||||
if self.lock_states.pager_write_tx {
|
||||
self.pager.rollback_tx(self.connection.as_ref());
|
||||
if self.update_transaction_state {
|
||||
*self.connection.transaction_state.write() = TransactionState::None;
|
||||
self.connection.set_tx_state(TransactionState::None);
|
||||
}
|
||||
} else if self.lock_states.pager_read_tx {
|
||||
self.pager.end_read_tx();
|
||||
if self.update_transaction_state {
|
||||
*self.connection.transaction_state.write() = TransactionState::None;
|
||||
self.connection.set_tx_state(TransactionState::None);
|
||||
}
|
||||
}
|
||||
if self.lock_states.blocking_checkpoint_lock_held {
|
||||
|
||||
@@ -662,7 +662,7 @@ impl BTreeNodeState {
|
||||
|
||||
impl BTreeCursor {
|
||||
pub fn new(pager: Arc<Pager>, root_page: i64, num_columns: usize) -> Self {
|
||||
let valid_state = if root_page == 1 && !pager.db_state.is_initialized() {
|
||||
let valid_state = if root_page == 1 && !pager.db_state.get().is_initialized() {
|
||||
CursorValidState::Invalid
|
||||
} else {
|
||||
CursorValidState::Valid
|
||||
|
||||
@@ -10,7 +10,7 @@ use aes_gcm::{
|
||||
aead::{Aead, AeadCore, KeyInit, OsRng},
|
||||
Aes128Gcm, Aes256Gcm, Key, Nonce,
|
||||
};
|
||||
use turso_macros::match_ignore_ascii_case;
|
||||
use turso_macros::{match_ignore_ascii_case, AtomicEnum};
|
||||
|
||||
/// Encryption Scheme
|
||||
/// We support two major algorithms: AEGIS, AES GCM. These algorithms picked so that they also do
|
||||
@@ -319,8 +319,9 @@ define_aegis_cipher!(
|
||||
"AEGIS-128X4"
|
||||
);
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
#[derive(Debug, AtomicEnum, Clone, Copy, PartialEq)]
|
||||
pub enum CipherMode {
|
||||
None,
|
||||
Aes128Gcm,
|
||||
Aes256Gcm,
|
||||
Aegis256,
|
||||
@@ -363,6 +364,7 @@ impl std::fmt::Display for CipherMode {
|
||||
CipherMode::Aegis128X4 => write!(f, "aegis128x4"),
|
||||
CipherMode::Aegis256X2 => write!(f, "aegis256x2"),
|
||||
CipherMode::Aegis256X4 => write!(f, "aegis256x4"),
|
||||
CipherMode::None => write!(f, "None"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -380,6 +382,7 @@ impl CipherMode {
|
||||
CipherMode::Aegis128L => 16,
|
||||
CipherMode::Aegis128X2 => 16,
|
||||
CipherMode::Aegis128X4 => 16,
|
||||
CipherMode::None => 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -394,6 +397,7 @@ impl CipherMode {
|
||||
CipherMode::Aegis128L => 16,
|
||||
CipherMode::Aegis128X2 => 16,
|
||||
CipherMode::Aegis128X4 => 16,
|
||||
CipherMode::None => 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -408,6 +412,7 @@ impl CipherMode {
|
||||
CipherMode::Aegis128L => 16,
|
||||
CipherMode::Aegis128X2 => 16,
|
||||
CipherMode::Aegis128X4 => 16,
|
||||
CipherMode::None => 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -427,6 +432,7 @@ impl CipherMode {
|
||||
CipherMode::Aegis128L => 6,
|
||||
CipherMode::Aegis128X2 => 7,
|
||||
CipherMode::Aegis128X4 => 8,
|
||||
CipherMode::None => 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -503,6 +509,11 @@ impl EncryptionContext {
|
||||
CipherMode::Aegis128L => Cipher::Aegis128L(Box::new(Aegis128LCipher::new(key))),
|
||||
CipherMode::Aegis128X2 => Cipher::Aegis128X2(Box::new(Aegis128X2Cipher::new(key))),
|
||||
CipherMode::Aegis128X4 => Cipher::Aegis128X4(Box::new(Aegis128X4Cipher::new(key))),
|
||||
CipherMode::None => {
|
||||
return Err(LimboError::InvalidArgument(
|
||||
"must select valid CipherMode".into(),
|
||||
))
|
||||
}
|
||||
};
|
||||
Ok(Self {
|
||||
cipher_mode,
|
||||
|
||||
@@ -25,6 +25,7 @@ use std::sync::atomic::{
|
||||
};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use tracing::{instrument, trace, Level};
|
||||
use turso_macros::AtomicEnum;
|
||||
|
||||
use super::btree::btree_init_page;
|
||||
use super::page_cache::{CacheError, CacheResizeResult, PageCache, PageCacheKey};
|
||||
@@ -57,7 +58,7 @@ impl HeaderRef {
|
||||
tracing::trace!("HeaderRef::from_pager - {:?}", state);
|
||||
match state {
|
||||
HeaderRefState::Start => {
|
||||
if !pager.db_state.is_initialized() {
|
||||
if !pager.db_state.get().is_initialized() {
|
||||
return Err(LimboError::Page1NotAlloc);
|
||||
}
|
||||
|
||||
@@ -97,7 +98,7 @@ impl HeaderRefMut {
|
||||
tracing::trace!(?state);
|
||||
match state {
|
||||
HeaderRefState::Start => {
|
||||
if !pager.db_state.is_initialized() {
|
||||
if !pager.db_state.get().is_initialized() {
|
||||
return Err(LimboError::Page1NotAlloc);
|
||||
}
|
||||
|
||||
@@ -416,57 +417,19 @@ impl From<u8> for AutoVacuumMode {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[repr(usize)]
|
||||
#[derive(Debug, AtomicEnum, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum DbState {
|
||||
Uninitialized = Self::UNINITIALIZED,
|
||||
Initializing = Self::INITIALIZING,
|
||||
Initialized = Self::INITIALIZED,
|
||||
Uninitialized,
|
||||
Initializing,
|
||||
Initialized,
|
||||
}
|
||||
|
||||
impl DbState {
|
||||
pub(self) const UNINITIALIZED: usize = 0;
|
||||
pub(self) const INITIALIZING: usize = 1;
|
||||
pub(self) const INITIALIZED: usize = 2;
|
||||
|
||||
#[inline]
|
||||
pub fn is_initialized(&self) -> bool {
|
||||
matches!(self, DbState::Initialized)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[repr(transparent)]
|
||||
pub struct AtomicDbState(AtomicUsize);
|
||||
|
||||
impl AtomicDbState {
|
||||
#[inline]
|
||||
pub const fn new(state: DbState) -> Self {
|
||||
Self(AtomicUsize::new(state as usize))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn set(&self, state: DbState) {
|
||||
self.0.store(state as usize, Ordering::Release);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn get(&self) -> DbState {
|
||||
let v = self.0.load(Ordering::Acquire);
|
||||
match v {
|
||||
DbState::UNINITIALIZED => DbState::Uninitialized,
|
||||
DbState::INITIALIZING => DbState::Initializing,
|
||||
DbState::INITIALIZED => DbState::Initialized,
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn is_initialized(&self) -> bool {
|
||||
self.get().is_initialized()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[cfg(not(feature = "omit_autovacuum"))]
|
||||
enum PtrMapGetState {
|
||||
@@ -621,7 +584,7 @@ impl Pager {
|
||||
db_state: Arc<AtomicDbState>,
|
||||
init_lock: Arc<Mutex<()>>,
|
||||
) -> Result<Self> {
|
||||
let allocate_page1_state = if !db_state.is_initialized() {
|
||||
let allocate_page1_state = if !db_state.get().is_initialized() {
|
||||
RwLock::new(AllocatePage1State::Start)
|
||||
} else {
|
||||
RwLock::new(AllocatePage1State::Done)
|
||||
@@ -1131,7 +1094,7 @@ impl Pager {
|
||||
|
||||
#[instrument(skip_all, level = Level::DEBUG)]
|
||||
pub fn maybe_allocate_page1(&self) -> Result<IOResult<()>> {
|
||||
if !self.db_state.is_initialized() {
|
||||
if !self.db_state.get().is_initialized() {
|
||||
if let Ok(_lock) = self.init_lock.try_lock() {
|
||||
match (self.db_state.get(), self.allocating_page1()) {
|
||||
// In case of being empty or (allocating and this connection is performing allocation) then allocate the first page
|
||||
|
||||
@@ -11,6 +11,7 @@ use crate::{
|
||||
LimboError, OpenFlags, Result, Statement, StepResult, SymbolTable,
|
||||
};
|
||||
use crate::{Connection, MvStore, IO};
|
||||
use std::sync::atomic::AtomicU8;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
rc::Rc,
|
||||
|
||||
290
macros/src/atomic_enum.rs
Normal file
290
macros/src/atomic_enum.rs
Normal file
@@ -0,0 +1,290 @@
|
||||
use proc_macro::TokenStream;
|
||||
use quote::quote;
|
||||
use syn::{parse_macro_input, Data, DeriveInput, Fields, Type};
|
||||
|
||||
pub(crate) fn derive_atomic_enum_inner(input: TokenStream) -> TokenStream {
|
||||
let input = parse_macro_input!(input as DeriveInput);
|
||||
let name = &input.ident;
|
||||
let atomic_name = syn::Ident::new(&format!("Atomic{name}"), name.span());
|
||||
|
||||
let variants = match &input.data {
|
||||
Data::Enum(data) => &data.variants,
|
||||
_ => {
|
||||
return syn::Error::new_spanned(input, "AtomicEnum can only be derived for enums")
|
||||
.to_compile_error()
|
||||
.into();
|
||||
}
|
||||
};
|
||||
|
||||
// get info about variants to determine how we have to encode them
|
||||
let mut has_bool_field = false;
|
||||
let mut has_u8_field = false;
|
||||
let mut max_discriminant = 0u8;
|
||||
|
||||
for (idx, variant) in variants.iter().enumerate() {
|
||||
max_discriminant = idx as u8;
|
||||
match &variant.fields {
|
||||
Fields::Unit => {}
|
||||
Fields::Named(fields) if fields.named.len() == 1 => {
|
||||
let field = &fields.named[0];
|
||||
if is_bool_type(&field.ty) {
|
||||
has_bool_field = true;
|
||||
} else if is_u8_or_i8_type(&field.ty) {
|
||||
has_u8_field = true;
|
||||
} else {
|
||||
return syn::Error::new_spanned(
|
||||
field,
|
||||
"AtomicEnum only supports bool, u8, or i8 fields",
|
||||
)
|
||||
.to_compile_error()
|
||||
.into();
|
||||
}
|
||||
}
|
||||
Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
|
||||
let field = &fields.unnamed[0];
|
||||
if is_bool_type(&field.ty) {
|
||||
has_bool_field = true;
|
||||
} else if is_u8_or_i8_type(&field.ty) {
|
||||
has_u8_field = true;
|
||||
} else {
|
||||
return syn::Error::new_spanned(
|
||||
field,
|
||||
"AtomicEnum only supports bool, u8, or i8 fields",
|
||||
)
|
||||
.to_compile_error()
|
||||
.into();
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
return syn::Error::new_spanned(
|
||||
variant,
|
||||
"AtomicEnum only supports unit variants or variants with a single field",
|
||||
)
|
||||
.to_compile_error()
|
||||
.into();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let (storage_type, atomic_type) = if has_u8_field || (has_bool_field && max_discriminant > 127)
|
||||
{
|
||||
// Need u16: 8 bits for discriminant, 8 bits for data
|
||||
(quote! { u16 }, quote! { ::std::sync::atomic::AtomicU16 })
|
||||
} else {
|
||||
// Can use u8: 7 bits for discriminant, 1 bit for bool (if any)
|
||||
(quote! { u8 }, quote! { ::std::sync::atomic::AtomicU8 })
|
||||
};
|
||||
|
||||
let use_u16 = has_u8_field || (has_bool_field && max_discriminant > 127);
|
||||
|
||||
let to_storage = variants.iter().enumerate().map(|(idx, variant)| {
|
||||
let var_name = &variant.ident;
|
||||
let disc = idx as u8; // The discriminant here is just the variant's index
|
||||
|
||||
match &variant.fields {
|
||||
// Simple unit variant, just store the discriminant
|
||||
Fields::Unit => {
|
||||
if use_u16 {
|
||||
quote! { #name::#var_name => #disc as u16 }
|
||||
} else {
|
||||
quote! { #name::#var_name => #disc }
|
||||
}
|
||||
}
|
||||
Fields::Named(fields) => {
|
||||
// Named field variant like `Write { schema_did_change: bool }`
|
||||
let field = &fields.named[0];
|
||||
let field_name = &field.ident;
|
||||
|
||||
if is_bool_type(&field.ty) {
|
||||
if use_u16 {
|
||||
// Pack as: [discriminant_byte | bool_as_byte]
|
||||
// Example: Write {true} with disc=3 becomes: b100000011
|
||||
quote! {
|
||||
#name::#var_name { ref #field_name } => {
|
||||
(#disc as u16) | ((*#field_name as u16) << 8)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Same as above but with u8, so only 1 bit for bool
|
||||
// Example: Write{true} with disc=3 becomes: b10000011
|
||||
quote! {
|
||||
#name::#var_name { ref #field_name } => {
|
||||
#disc | ((*#field_name as u8) << 7)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// u8/i8 field always uses u16 to have enough bits
|
||||
// Pack as: [discriminant_byte | value_byte]
|
||||
quote! {
|
||||
#name::#var_name { ref #field_name } => {
|
||||
(#disc as u16) | ((*#field_name as u16) << 8)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Fields::Unnamed(_) => {
|
||||
// same strategy as above, but for tuple variants like `Write(bool)`
|
||||
if is_bool_type(&variant.fields.iter().next().unwrap().ty) {
|
||||
if use_u16 {
|
||||
quote! {
|
||||
#name::#var_name(ref val) => {
|
||||
(#disc as u16) | ((*val as u16) << 8)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
#name::#var_name(ref val) => {
|
||||
#disc | ((*val as u8) << 7)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
#name::#var_name(ref val) => {
|
||||
(#disc as u16) | ((*val as u16) << 8)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Generate the match arms for decoding the storage representation back to enum
|
||||
let from_storage = variants.iter().enumerate().map(|(idx, variant)| {
|
||||
let var_name = &variant.ident;
|
||||
let disc = idx as u8;
|
||||
|
||||
match &variant.fields {
|
||||
Fields::Unit => quote! { #disc => #name::#var_name },
|
||||
Fields::Named(fields) => {
|
||||
let field = &fields.named[0];
|
||||
let field_name = &field.ident;
|
||||
|
||||
if is_bool_type(&field.ty) {
|
||||
if use_u16 {
|
||||
// Extract bool from high byte: check if non-zero
|
||||
quote! {
|
||||
#disc => #name::#var_name {
|
||||
#field_name: (val >> 8) != 0
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// check single bool value at bit 7
|
||||
quote! {
|
||||
#disc => #name::#var_name {
|
||||
#field_name: (val & 0x80) != 0
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
#disc => #name::#var_name {
|
||||
// Extract u8/i8 from high byte and cast to appropriate type
|
||||
#field_name: (val >> 8) as _
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Fields::Unnamed(_) => {
|
||||
if is_bool_type(&variant.fields.iter().next().unwrap().ty) {
|
||||
if use_u16 {
|
||||
quote! { #disc => #name::#var_name((val >> 8) != 0) }
|
||||
} else {
|
||||
quote! { #disc => #name::#var_name((val & 0x80) != 0) }
|
||||
}
|
||||
} else {
|
||||
quote! { #disc => #name::#var_name((val >> 8) as _) }
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let discriminant_mask = if use_u16 {
|
||||
quote! { 0xFF }
|
||||
} else {
|
||||
quote! { 0x7F }
|
||||
};
|
||||
let to_storage_arms_copy = to_storage.clone();
|
||||
|
||||
let expanded = quote! {
|
||||
#[derive(Debug)]
|
||||
/// Atomic wrapper for #name
|
||||
pub struct #atomic_name(#atomic_type);
|
||||
|
||||
impl #atomic_name {
|
||||
/// Encode enum into storage representation
|
||||
/// Discriminant in lower bits, field data in upper bits
|
||||
#[inline]
|
||||
fn to_storage(val: &#name) -> #storage_type {
|
||||
match val {
|
||||
#(#to_storage_arms_copy),*
|
||||
}
|
||||
}
|
||||
|
||||
/// Decode storage representation into enum
|
||||
/// Panics on invalid discriminant
|
||||
#[inline]
|
||||
fn from_storage(val: #storage_type) -> #name {
|
||||
let discriminant = (val & #discriminant_mask) as u8;
|
||||
match discriminant {
|
||||
#(#from_storage,)*
|
||||
_ => panic!(concat!("Invalid ", stringify!(#name), " discriminant: {}"), discriminant),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create new atomic enum with initial value
|
||||
#[inline]
|
||||
pub const fn new(val: #name) -> Self {
|
||||
// Can't call to_storage in const context, so inline it
|
||||
let storage = match val {
|
||||
#(#to_storage),*
|
||||
};
|
||||
Self(#atomic_type::new(storage))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Load and convert the current value to expected enum
|
||||
pub fn get(&self) -> #name {
|
||||
Self::from_storage(self.0.load(::std::sync::atomic::Ordering::SeqCst))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Convert and store new value
|
||||
pub fn set(&self, val: #name) {
|
||||
self.0.store(Self::to_storage(&val), ::std::sync::atomic::Ordering::SeqCst)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Store new value and return previous value
|
||||
pub fn swap(&self, val: #name) -> #name {
|
||||
let prev = self.0.swap(Self::to_storage(&val), ::std::sync::atomic::Ordering::SeqCst);
|
||||
Self::from_storage(prev)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<#name> for #atomic_name {
|
||||
fn from(val: #name) -> Self {
|
||||
Self::new(val)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
TokenStream::from(expanded)
|
||||
}
|
||||
|
||||
fn is_bool_type(ty: &Type) -> bool {
|
||||
if let Type::Path(path) = ty {
|
||||
path.path.is_ident("bool")
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn is_u8_or_i8_type(ty: &Type) -> bool {
|
||||
if let Type::Path(path) = ty {
|
||||
path.path.is_ident("u8") || path.path.is_ident("i8")
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
mod ext;
|
||||
extern crate proc_macro;
|
||||
mod atomic_enum;
|
||||
use proc_macro::{token_stream::IntoIter, Group, TokenStream, TokenTree};
|
||||
use std::collections::HashMap;
|
||||
|
||||
@@ -464,3 +465,30 @@ pub fn derive_vfs_module(input: TokenStream) -> TokenStream {
|
||||
pub fn match_ignore_ascii_case(input: TokenStream) -> TokenStream {
|
||||
ext::match_ignore_ascci_case(input)
|
||||
}
|
||||
|
||||
/// Derive macro for creating atomic wrappers for enums
|
||||
///
|
||||
/// Supports:
|
||||
/// - Unit variants
|
||||
/// - Variants with single bool/u8/i8 fields
|
||||
/// - Named or unnamed fields
|
||||
///
|
||||
/// Algorithm:
|
||||
/// - Uses u8 representation, splitting bits for variant discriminant and field data
|
||||
/// - For bool fields: high bit for bool, lower 7 bits for discriminant
|
||||
/// - For u8/i8 fields: uses u16 internally (8 bits discriminant, 8 bits data)
|
||||
///
|
||||
/// Example:
|
||||
/// ```ignore
|
||||
/// #[derive(AtomicEnum)]
|
||||
/// enum TransactionState {
|
||||
/// Write { schema_did_change: bool },
|
||||
/// Read,
|
||||
/// PendingUpgrade,
|
||||
/// None,
|
||||
/// }
|
||||
/// ```
|
||||
#[proc_macro_derive(AtomicEnum)]
|
||||
pub fn derive_atomic_enum(input: TokenStream) -> TokenStream {
|
||||
atomic_enum::derive_atomic_enum_inner(input)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user