use std::{
    cell::UnsafeCell,
    ops::{Deref, DerefMut},
    sync::atomic::{AtomicBool, Ordering},
};

/// A mutual-exclusion lock that busy-waits instead of parking the thread.
///
/// The protected value lives inside the lock and can only be reached through
/// the [`SpinLockGuard`] returned by [`SpinLock::lock`]. While a guard exists
/// the `locked` flag is `true`; dropping the guard stores `false` again.
#[derive(Debug)]
pub struct SpinLock<T> {
    /// `true` while some guard is alive.
    locked: AtomicBool,
    /// The protected value; only accessed through a live guard.
    value: UnsafeCell<T>,
}

/// Guard that proves the lock is held and dereferences to the protected value.
///
/// The lock is released when the guard is dropped.
#[must_use = "the lock is released as soon as the guard is dropped"]
pub struct SpinLockGuard<'a, T> {
    lock: &'a SpinLock<T>,
}

impl<T> Drop for SpinLockGuard<'_, T> {
    /// Releases the lock.
    fn drop(&mut self) {
        // `Release` pairs with the `Acquire` swap in `SpinLock::lock`, so every
        // write made through this guard is visible to the next lock holder.
        self.lock.locked.store(false, Ordering::Release);
    }
}

impl<T> Deref for SpinLockGuard<'_, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        // SAFETY: a guard is only created by `SpinLock::lock` after it flipped
        // `locked` from `false` to `true`, and the flag stays `true` until this
        // guard is dropped, so no other guard can alias the value meanwhile.
        unsafe { &*self.lock.value.get() }
    }
}

impl<T> DerefMut for SpinLockGuard<'_, T> {
    fn deref_mut(&mut self) -> &mut T {
        // SAFETY: same exclusivity argument as in `deref`; additionally
        // `&mut self` guarantees this is the only borrow handed out by the guard.
        unsafe { &mut *self.lock.value.get() }
    }
}

// SAFETY: moving the lock to another thread moves the `T` with it, which is
// fine exactly when `T: Send`.
unsafe impl<T: Send> Send for SpinLock<T> {}

// SAFETY: the lock hands out access to `T` to one thread at a time, which
// amounts to sending `T` between threads, so `T: Send` is required and
// sufficient (the same bound `std::sync::Mutex` uses).
unsafe impl<T: Send> Sync for SpinLock<T> {}

// SAFETY: sharing `&SpinLockGuard<T>` between threads shares `&T`, so the guard
// may only be `Sync` when `T: Sync`. Writing this impl explicitly replaces the
// auto-derived one, which would have required only `T: Send`.
unsafe impl<T: Sync> Sync for SpinLockGuard<'_, T> {}

impl<T> SpinLock<T> {
    /// Creates an unlocked lock that owns `value`.
    pub fn new(value: T) -> Self {
        Self {
            locked: AtomicBool::new(false),
            value: UnsafeCell::new(value),
        }
    }

    /// Busy-waits until the lock is free, takes it, and returns a guard.
    ///
    /// There is no re-entrancy check: calling `lock` again on the same thread
    /// while a guard is still alive spins forever.
    pub fn lock(&self) -> SpinLockGuard<'_, T> {
        // `swap` returns the previous value: `false` means we just took the lock.
        while self.locked.swap(true, Ordering::Acquire) {
            // While someone else holds the lock, wait on a plain load instead
            // of hammering the cache line with writes; retry the swap only
            // once the flag has been observed as free.
            while self.locked.load(Ordering::Relaxed) {
                std::hint::spin_loop();
            }
        }
        SpinLockGuard { lock: self }
    }

    /// Consumes the lock and returns the cell holding the protected value.
    ///
    /// Taking `self` by value means no guard can still be borrowing the lock.
    /// Call `UnsafeCell::into_inner` on the result to obtain the `T` itself.
    pub fn into_inner(self) -> UnsafeCell<T> {
        self.value
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::SpinLock;

    /// Many threads each add one under the lock; no increment may be lost.
    #[test]
    fn test_fast_lock_multiple_thread_sum() {
        const NTHREADS: usize = 1000;

        let lock = Arc::new(SpinLock::new(0));

        let threads: Vec<_> = (0..NTHREADS)
            .map(|_| {
                let lock = Arc::clone(&lock);
                std::thread::spawn(move || {
                    let mut guard = lock.lock();
                    *guard += 1;
                })
            })
            .collect();

        for thread in threads {
            thread.join().unwrap();
        }

        assert_eq!(*lock.lock(), NTHREADS);
    }
}