diff --git a/core/io/completions.rs b/core/io/completions.rs index a381324c6..9e5ed605c 100644 --- a/core/io/completions.rs +++ b/core/io/completions.rs @@ -11,10 +11,10 @@ use parking_lot::Mutex; use crate::{Buffer, CompletionError}; -pub type ReadComplete = dyn Fn(Result<(Arc, i32), CompletionError>); -pub type WriteComplete = dyn Fn(Result); -pub type SyncComplete = dyn Fn(Result); -pub type TruncateComplete = dyn Fn(Result); +pub type ReadComplete = dyn Fn(Result<(Arc, i32), CompletionError>) + Send + Sync; +pub type WriteComplete = dyn Fn(Result) + Send + Sync; +pub type SyncComplete = dyn Fn(Result) + Send + Sync; +pub type TruncateComplete = dyn Fn(Result) + Send + Sync; #[must_use] #[derive(Debug, Clone)] @@ -275,7 +275,7 @@ impl Completion { pub fn new_write_linked(complete: F) -> Self where - F: Fn(Result) + 'static, + F: Fn(Result) + Send + Sync + 'static, { Self::new_linked(CompletionType::Write(WriteCompletion::new(Box::new( complete, @@ -284,7 +284,7 @@ impl Completion { pub fn new_write(complete: F) -> Self where - F: Fn(Result) + 'static, + F: Fn(Result) + Send + Sync + 'static, { Self::new(CompletionType::Write(WriteCompletion::new(Box::new( complete, @@ -293,7 +293,7 @@ impl Completion { pub fn new_read(buf: Arc, complete: F) -> Self where - F: Fn(Result<(Arc, i32), CompletionError>) + 'static, + F: Fn(Result<(Arc, i32), CompletionError>) + Send + Sync + 'static, { Self::new(CompletionType::Read(ReadCompletion::new( buf, @@ -302,7 +302,7 @@ impl Completion { } pub fn new_sync(complete: F) -> Self where - F: Fn(Result) + 'static, + F: Fn(Result) + Send + Sync + 'static, { Self::new(CompletionType::Sync(SyncCompletion::new(Box::new( complete, @@ -311,7 +311,7 @@ impl Completion { pub fn new_trunc(complete: F) -> Self where - F: Fn(Result) + 'static, + F: Fn(Result) + Send + Sync + 'static, { Self::new(CompletionType::Truncate(TruncateCompletion::new(Box::new( complete, diff --git a/core/storage/buffer_pool.rs b/core/storage/buffer_pool.rs index 
6bac73f71..ed0e3a1ec 100644 --- a/core/storage/buffer_pool.rs +++ b/core/storage/buffer_pool.rs @@ -26,6 +26,10 @@ pub struct ArenaBuffer { len: usize, } +// Unsound: writing and reading from different threads can be dangerous with the current ArenaBuffer implementation without additional explicit synchronization +unsafe impl Sync for ArenaBuffer {} +unsafe impl Send for ArenaBuffer {} + impl ArenaBuffer { const fn new( arena: Weak, diff --git a/core/storage/wal.rs b/core/storage/wal.rs index eebfbbd98..7b8ae3dd1 100644 --- a/core/storage/wal.rs +++ b/core/storage/wal.rs @@ -4,13 +4,14 @@ use rustc_hash::{FxHashMap, FxHashSet}; use std::array; use std::borrow::Cow; use std::collections::BTreeMap; +use std::sync::Mutex; use strum::EnumString; use tracing::{instrument, Level}; use parking_lot::RwLock; use std::fmt::{Debug, Formatter}; use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, AtomicUsize, Ordering}; -use std::{cell::Cell, fmt, sync::Arc}; +use std::{fmt, sync::Arc}; use super::buffer_pool::BufferPool; use super::pager::{PageRef, Pager}; @@ -1138,10 +1139,15 @@ impl Wal for WalFile { } #[instrument(skip_all, level = Level::DEBUG)] + // todo(sivukhin): change the API to accept Buffer or some other owned type + // this method involves IO and crosses an "async" boundary, so juggling references is bad and dangerous fn read_frame_raw(&self, frame_id: u64, frame: &mut [u8]) -> Result { tracing::debug!("read_frame_raw({})", frame_id); let offset = self.frame_offset(frame_id); - let (frame_ptr, frame_len) = (frame.as_mut_ptr(), frame.len()); + + // HACK: *mut u8 is not Send and can't be sent between threads safely, so cast it to usize + // at the time of writing this comment this is *safe*, as all callers immediately call the synchronous method wait_for_completion and hold the necessary references + let (frame_ptr, frame_len) = (frame.as_mut_ptr() as usize, frame.len()); let encryption_ctx = { let io_ctx = self.io_ctx.read(); @@ -1157,6 +1163,7 @@ impl Wal for WalFile { 
"read({bytes_read}) != expected({buf_len})" ); let buf_ptr = buf.as_ptr(); + let frame_ptr = frame_ptr as *mut u8; let frame_ref: &mut [u8] = unsafe { std::slice::from_raw_parts_mut(frame_ptr, frame_len) }; @@ -1194,6 +1201,8 @@ impl Wal for WalFile { } #[instrument(skip_all, level = Level::DEBUG)] + // todo(sivukhin): change the API to accept Buffer or some other owned type + // this method involves IO and crosses an "async" boundary, so juggling references is bad and dangerous fn write_frame_raw( &mut self, buffer_pool: Arc, @@ -1226,8 +1235,12 @@ impl Wal for WalFile { if frame_id <= self.max_frame.load(Ordering::Acquire) { // just validate if page content from the frame matches frame in the WAL let offset = self.frame_offset(frame_id); - let conflict = Arc::new(Cell::new(false)); - let (page_ptr, page_len) = (page.as_ptr(), page.len()); + let conflict = Arc::new(Mutex::new(false)); + + // HACK: a raw pointer can't be shared between threads safely, so cast it to usize + // at the time of writing this comment this is *safe*, as the function immediately calls the synchronous method wait_for_completion and holds the necessary references + let (page_ptr, page_len) = (page.as_ptr() as usize, page.len()); + let complete = Box::new({ let conflict = conflict.clone(); move |res: Result<(Arc, i32), CompletionError>| { @@ -1239,9 +1252,9 @@ impl Wal for WalFile { bytes_read == buf_len as i32, "read({bytes_read}) != expected({buf_len})" ); - let page = unsafe { std::slice::from_raw_parts(page_ptr, page_len) }; + let page = unsafe { std::slice::from_raw_parts(page_ptr as *mut u8, page_len) }; if buf.as_slice() != page { - conflict.set(true); + *conflict.lock().unwrap() = true; } } }); @@ -1258,7 +1271,7 @@ impl Wal for WalFile { &self.io_ctx.read(), )?; self.io.wait_for_completion(c)?; - return if conflict.get() { + return if *conflict.lock().unwrap() { Err(LimboError::Conflict(format!( "frame content differs from the WAL: frame_id={frame_id}" ))) @@ -2506,11 +2519,7 @@ pub mod test { use 
parking_lot::RwLock; #[cfg(unix)] use std::os::unix::fs::MetadataExt; - use std::{ - cell::Cell, - rc::Rc, - sync::{atomic::Ordering, Arc}, - }; + use std::sync::{atomic::Ordering, Arc, Mutex}; #[allow(clippy::arc_with_non_send_sync)] pub(crate) fn get_database() -> (Arc, std::path::PathBuf) { let mut path = tempfile::tempdir().unwrap().keep(); @@ -2536,17 +2545,16 @@ pub mod test { let _ = conn.execute("insert into test (value) values ('test1'), ('test2'), ('test3')"); let wal = db.shared_wal.write(); let wal_file = wal.file.as_ref().unwrap().clone(); - let done = Rc::new(Cell::new(false)); + let done = Arc::new(Mutex::new(false)); let _done = done.clone(); let _ = wal_file.truncate( WAL_HEADER_SIZE as u64, Completion::new_trunc(move |_| { - let done = _done.clone(); - done.set(true); + *_done.lock().unwrap() = true; }), ); assert!(wal_file.size().unwrap() == WAL_HEADER_SIZE as u64); - assert!(done.get()); + assert!(*done.lock().unwrap()); } #[test]