use std::{
    cell::UnsafeCell,
    ops::{Deref, DerefMut},
    sync::atomic::{AtomicBool, Ordering},
};

/// A lightweight lock protecting a value of type `T`.
///
/// Acquisition spins on an [`AtomicBool`], calling [`std::thread::yield_now`]
/// while the lock is observed as held. The protected value is reached through
/// the [`FastLockGuard`] returned by [`FastLock::lock`]; the lock is released
/// when that guard is dropped.
#[derive(Debug)]
pub struct FastLock<T> {
    /// `true` while some [`FastLockGuard`] holds the lock.
    lock: AtomicBool,
    /// The protected value; only accessed while `lock` is held or through `&mut self`.
    value: UnsafeCell<T>,
}

/// Guard granting exclusive access to the value inside a [`FastLock`].
///
/// Dereferences to `T`. The lock is released when the guard is dropped, so the
/// guard must be bound to a variable for the lock to stay held.
#[must_use = "if unused the FastLock will immediately unlock"]
pub struct FastLockGuard<'a, T> {
    lock: &'a FastLock<T>,
}

impl<T> FastLockGuard<'_, T> {
    /// Returns a mutable reference to the protected value.
    ///
    /// Takes `&mut self` so that this guard can hand out at most one live
    /// `&mut T` at a time. Equivalent to dereferencing the guard mutably.
    pub fn get_mut(&mut self) -> &mut T {
        &mut **self
    }
}

impl<T> Deref for FastLockGuard<'_, T> {
    type Target = T;

    fn deref(&self) -> &T {
        // SAFETY: a guard only exists while the lock is held, so no other guard
        // can reach `value`. The returned borrow is tied to `&self`, so it
        // cannot overlap a `&mut T` obtained from this same guard.
        unsafe { &*self.lock.value.get() }
    }
}

impl<T> DerefMut for FastLockGuard<'_, T> {
    fn deref_mut(&mut self) -> &mut T {
        // SAFETY: the lock is held (see `deref`), and `&mut self` guarantees this
        // is the only borrow of the value handed out by this guard right now.
        unsafe { &mut *self.lock.value.get() }
    }
}

impl<T> Drop for FastLockGuard<'_, T> {
    fn drop(&mut self) {
        self.lock.unlock();
    }
}

// SAFETY: `&FastLockGuard` yields `&T` through `Deref`, so sharing a guard
// between threads is only sound when `T: Sync`. This explicit impl replaces the
// auto impl, which would otherwise only require `T: Send`.
unsafe impl<T: Sync> Sync for FastLockGuard<'_, T> {}

// SAFETY: moving the lock to another thread moves the owned `T` with it.
unsafe impl<T: Send> Send for FastLock<T> {}
// SAFETY: through a shared `&FastLock<T>`, threads only reach `T` one at a time
// via an exclusive guard, effectively passing `T` between threads; that
// requires `T: Send` (the same bound `std::sync::Mutex` uses).
unsafe impl<T: Send> Sync for FastLock<T> {}

impl<T> FastLock<T> {
    /// Creates a new, unlocked `FastLock` holding `value`.
    pub fn new(value: T) -> Self {
        Self {
            lock: AtomicBool::new(false),
            value: UnsafeCell::new(value),
        }
    }

    /// Blocks until the lock is acquired and returns a guard for the value.
    ///
    /// The lock is released when the returned guard is dropped. The lock is not
    /// reentrant: calling `lock` again on the same thread while a guard is alive
    /// never returns.
    pub fn lock(&self) -> FastLockGuard<'_, T> {
        // `Acquire` on success pairs with the `Release` in `unlock`, making the
        // previous holder's writes visible. Failure needs no ordering.
        while self
            .lock
            .compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_err()
        {
            // Wait with plain loads rather than repeated read-modify-writes so
            // contending threads don't keep stealing the cache line.
            while self.lock.load(Ordering::Relaxed) {
                std::thread::yield_now();
            }
        }
        FastLockGuard { lock: self }
    }

    /// Releases the lock. Only called from `FastLockGuard::drop`, which
    /// guarantees the lock is currently held by that guard.
    fn unlock(&self) {
        // `Release` publishes all writes made while the lock was held.
        let was_locked = self.lock.swap(false, Ordering::Release);
        debug_assert!(was_locked, "FastLock unlocked while not held");
    }

    /// Returns a mutable reference to the value without locking.
    ///
    /// The `&mut self` receiver proves that no guard (and no other reference to
    /// this lock) exists, so no synchronization is needed.
    pub fn get_mut(&mut self) -> &mut T {
        self.value.get_mut()
    }

    /// Consumes the lock and returns the protected value.
    pub fn into_inner(self) -> T {
        self.value.into_inner()
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::FastLock;

    #[test]
    fn test_fast_lock_multiple_thread_sum() {
        const NTHREADS: usize = 1000;
        let lock = Arc::new(FastLock::new(0));
        let threads: Vec<_> = (0..NTHREADS)
            .map(|_| {
                let lock = Arc::clone(&lock);
                std::thread::spawn(move || {
                    // Bind the guard so the lock is actually held during the increment.
                    let mut guard = lock.lock();
                    *guard += 1;
                })
            })
            .collect();
        for thread in threads {
            thread.join().unwrap();
        }
        assert_eq!(*lock.lock(), NTHREADS);
    }
}