diff --git a/core/error.rs b/core/error.rs index 87c686932..3dd4841ad 100644 --- a/core/error.rs +++ b/core/error.rs @@ -109,7 +109,7 @@ impl From for CompletionError { } } -#[derive(Debug, Copy, Clone, Error)] +#[derive(Debug, Copy, Clone, PartialEq, Error)] pub enum CompletionError { #[error("I/O error: {0}")] IOError(std::io::ErrorKind), diff --git a/core/io/mod.rs b/core/io/mod.rs index e537c393d..126a539f8 100644 --- a/core/io/mod.rs +++ b/core/io/mod.rs @@ -7,6 +7,7 @@ use parking_lot::Once; use std::cell::RefCell; use std::fmt; use std::ptr::NonNull; +use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::{Arc, OnceLock}; use std::{fmt::Debug, pin::Pin}; @@ -136,7 +137,6 @@ pub struct Completion { inner: Arc, } -#[derive(Debug)] struct CompletionInner { completion_type: CompletionType, /// None means we completed successfully @@ -145,6 +145,130 @@ struct CompletionInner { needs_link: bool, /// before calling callback we check if done is true done: Once, + /// Optional parent group this completion belongs to + parent: OnceLock>, +} + +impl fmt::Debug for CompletionInner { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("CompletionInner") + .field("completion_type", &self.completion_type) + .field("needs_link", &self.needs_link) + .field("parent", &self.parent.get().is_some()) + .finish() + } +} + +pub struct CompletionGroup { + completions: Vec, + callback: Box) + Send + Sync>, +} + +impl CompletionGroup { + pub fn new(callback: F) -> Self + where + F: Fn(Result) + Send + Sync + 'static, + { + Self { + completions: Vec::new(), + callback: Box::new(callback), + } + } + + pub fn add(&mut self, completion: &Completion) { + if !completion.finished() || completion.has_error() { + self.completions.push(completion.clone()); + } + // Skip successfully finished completions + } + + pub fn build(self) -> Completion { + let total = self.completions.len(); + if total == 0 { + let group_completion = GroupCompletion::new(self.callback, 
0); + return Completion::new(CompletionType::Group(group_completion)); + } + let group_completion = GroupCompletion::new(self.callback, total); + let group = Completion::new(CompletionType::Group(group_completion)); + + for mut c in self.completions { + // If the completion has not completed, link it to the group. + if !c.finished() { + c.link_internal(&group); + continue; + } + let group_inner = match &group.inner.completion_type { + CompletionType::Group(g) => &g.inner, + _ => unreachable!(), + }; + // Return early if there was an error. + if let Some(err) = c.get_error() { + let _ = group_inner.result.set(Some(err)); + group_inner.outstanding.store(0, Ordering::SeqCst); + (group_inner.complete)(Err(err)); + return group; + } + // Mark the successful completion as done. + group_inner.outstanding.fetch_sub(1, Ordering::SeqCst); + } + + let group_inner = match &group.inner.completion_type { + CompletionType::Group(g) => &g.inner, + _ => unreachable!(), + }; + if group_inner.outstanding.load(Ordering::SeqCst) == 0 { + (group_inner.complete)(Ok(0)); + } + group + } +} + +pub struct GroupCompletion { + inner: Arc, +} + +impl fmt::Debug for GroupCompletion { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("GroupCompletion") + .field( + "outstanding", + &self.inner.outstanding.load(Ordering::SeqCst), + ) + .finish() + } +} + +struct GroupCompletionInner { + /// Number of completions that need to finish + outstanding: AtomicUsize, + /// Callback to invoke when all completions finish + complete: Box) + Send + Sync>, + /// Cached result after all completions finish + result: OnceLock>, +} + +impl GroupCompletion { + pub fn new(complete: F, outstanding: usize) -> Self + where + F: Fn(Result) + Send + Sync + 'static, + { + Self { + inner: Arc::new(GroupCompletionInner { + outstanding: AtomicUsize::new(outstanding), + complete: Box::new(complete), + result: OnceLock::new(), + }), + } + } + + pub fn callback(&self, result: Result) { + assert_eq!( 
+ self.inner.outstanding.load(Ordering::SeqCst), + 0, + "callback called before all completions finished" + ); + (self.inner.complete)(result); + } } impl Debug for CompletionType { @@ -154,6 +278,7 @@ impl Debug for CompletionType { Self::Write(..) => f.debug_tuple("Write").finish(), Self::Sync(..) => f.debug_tuple("Sync").finish(), Self::Truncate(..) => f.debug_tuple("Truncate").finish(), + Self::Group(..) => f.debug_tuple("Group").finish(), } } } @@ -163,6 +288,7 @@ pub enum CompletionType { Write(WriteCompletion), Sync(SyncCompletion), Truncate(TruncateCompletion), + Group(GroupCompletion), } impl Completion { @@ -173,6 +299,7 @@ impl Completion { result: OnceLock::new(), needs_link: false, done: Once::new(), + parent: OnceLock::new(), }), } } @@ -184,6 +311,7 @@ impl Completion { result: OnceLock::new(), needs_link: true, done: Once::new(), + parent: OnceLock::new(), }), } } @@ -245,7 +373,13 @@ impl Completion { } pub fn is_completed(&self) -> bool { - self.inner.result.get().is_some_and(|val| val.is_none()) + match &self.inner.completion_type { + CompletionType::Group(g) => { + g.inner.outstanding.load(Ordering::SeqCst) == 0 + && g.inner.result.get().is_none_or(|e| e.is_none()) + } + _ => self.inner.result.get().is_some_and(|val| val.is_none()), + } } pub fn has_error(&self) -> bool { @@ -253,12 +387,22 @@ impl Completion { } pub fn get_error(&self) -> Option<CompletionError> { - self.inner.result.get().and_then(|res| *res) + match &self.inner.completion_type { + CompletionType::Group(g) => { + // For groups, check the group's cached result field + // (set when the last completion finishes) + g.inner.result.get().and_then(|res| *res) + } + _ => self.inner.result.get().and_then(|res| *res), + } } /// Checks if the Completion completed or errored pub fn finished(&self) -> bool { + match &self.inner.completion_type { + CompletionType::Group(g) => g.inner.outstanding.load(Ordering::SeqCst) == 0, + _ => self.inner.result.get().is_some(), + } - self.inner.result.get().is_some() } pub fn complete(&self, 
result: i32) { @@ -282,11 +426,27 @@ impl Completion { CompletionType::Write(w) => w.callback(result), CompletionType::Sync(s) => s.callback(result), // fix CompletionType::Truncate(t) => t.callback(result), + CompletionType::Group(g) => g.callback(result), }; self.inner .result .set(result.err()) .expect("result must be set only once"); + + if let Some(group) = self.inner.parent.get() { + // Capture first error in group + if let Err(err) = result { + let _ = group.result.set(Some(err)); + } + let prev = group.outstanding.fetch_sub(1, Ordering::SeqCst); + + // If this was the last completion, call the group callback + if prev == 1 { + let group_result = group.result.get().and_then(|e| *e); + (group.complete)(group_result.map_or(Ok(0), Err)); + } + // TODO: remove self from parent group + } }); } @@ -307,6 +467,19 @@ impl Completion { _ => unreachable!(), } } + + /// Link this completion to a group completion (internal use only) + fn link_internal(&mut self, group: &Completion) { + let group_inner = match &group.inner.completion_type { + CompletionType::Group(g) => &g.inner, + _ => panic!("link_internal() requires a group completion"), + }; + + // Set the parent (can only be set once) + if self.inner.parent.set(group_inner.clone()).is_err() { + panic!("completion can only be linked once"); + } + } } pub struct ReadCompletion { @@ -563,3 +736,221 @@ pub use memory::MemoryIO; pub mod clock; mod common; pub use clock::Clock; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_completion_group_empty() { + let group = CompletionGroup::new(|_| {}); + let group = group.build(); + assert!(group.finished()); + assert!(group.is_completed()); + assert!(group.get_error().is_none()); + } + + #[test] + fn test_completion_group_single_completion() { + let mut group = CompletionGroup::new(|_| {}); + let c = Completion::new_write(|_| {}); + group.add(&c); + let group = group.build(); + + assert!(!group.finished()); + assert!(!group.is_completed()); + + c.complete(0); 
+ + assert!(group.finished()); + assert!(group.is_completed()); + assert!(group.get_error().is_none()); + } + + #[test] + fn test_completion_group_multiple_completions() { + let mut group = CompletionGroup::new(|_| {}); + let c1 = Completion::new_write(|_| {}); + let c2 = Completion::new_write(|_| {}); + let c3 = Completion::new_write(|_| {}); + group.add(&c1); + group.add(&c2); + group.add(&c3); + let group = group.build(); + + assert!(!group.is_completed()); + assert!(!group.finished()); + + c1.complete(0); + assert!(!group.is_completed()); + assert!(!group.finished()); + + c2.complete(0); + assert!(!group.is_completed()); + assert!(!group.finished()); + + c3.complete(0); + assert!(group.is_completed()); + assert!(group.finished()); + } + + #[test] + fn test_completion_group_with_error() { + let mut group = CompletionGroup::new(|_| {}); + let c1 = Completion::new_write(|_| {}); + let c2 = Completion::new_write(|_| {}); + group.add(&c1); + group.add(&c2); + let group = group.build(); + + c1.complete(0); + c2.error(CompletionError::Aborted); + + assert!(group.finished()); + assert!(!group.is_completed()); + assert_eq!(group.get_error(), Some(CompletionError::Aborted)); + } + + #[test] + fn test_completion_group_callback() { + use std::sync::atomic::{AtomicBool, Ordering}; + let called = Arc::new(AtomicBool::new(false)); + let called_clone = called.clone(); + + let mut group = CompletionGroup::new(move |_| { + called_clone.store(true, Ordering::SeqCst); + }); + + let c1 = Completion::new_write(|_| {}); + let c2 = Completion::new_write(|_| {}); + group.add(&c1); + group.add(&c2); + let group = group.build(); + + assert!(!called.load(Ordering::SeqCst)); + + c1.complete(0); + assert!(!called.load(Ordering::SeqCst)); + + c2.complete(0); + assert!(called.load(Ordering::SeqCst)); + assert!(group.finished()); + assert!(group.is_completed()); + } + + #[test] + fn test_completion_group_some_already_completed() { + // Test some completions added to group, then finish before 
build() + let mut group = CompletionGroup::new(|_| {}); + let c1 = Completion::new_write(|_| {}); + let c2 = Completion::new_write(|_| {}); + let c3 = Completion::new_write(|_| {}); + + // Add all to group while pending + group.add(&c1); + group.add(&c2); + group.add(&c3); + + // Complete c1 and c2 AFTER adding but BEFORE build() + c1.complete(0); + c2.complete(0); + + let group = group.build(); + + // c1 and c2 finished before build(), so outstanding should account for them + // Only c3 should be pending + assert!(!group.finished()); + assert!(!group.is_completed()); + + // Complete c3 + c3.complete(0); + + // Now the group should be finished + assert!(group.finished()); + assert!(group.is_completed()); + assert!(group.get_error().is_none()); + } + + #[test] + fn test_completion_group_all_already_completed() { + // Test when all completions are already finished before build() + let mut group = CompletionGroup::new(|_| {}); + let c1 = Completion::new_write(|_| {}); + let c2 = Completion::new_write(|_| {}); + + // Complete both before adding to group + c1.complete(0); + c2.complete(0); + + group.add(&c1); + group.add(&c2); + + let group = group.build(); + + // All completions were already complete, so group should be finished immediately + assert!(group.finished()); + assert!(group.is_completed()); + assert!(group.get_error().is_none()); + } + + #[test] + fn test_completion_group_mixed_finished_and_pending() { + use std::sync::atomic::{AtomicBool, Ordering}; + let called = Arc::new(AtomicBool::new(false)); + let called_clone = called.clone(); + + let mut group = CompletionGroup::new(move |_| { + called_clone.store(true, Ordering::SeqCst); + }); + + let c1 = Completion::new_write(|_| {}); + let c2 = Completion::new_write(|_| {}); + let c3 = Completion::new_write(|_| {}); + let c4 = Completion::new_write(|_| {}); + + // Complete c1 and c3 before adding to group + c1.complete(0); + c3.complete(0); + + group.add(&c1); + group.add(&c2); + group.add(&c3); + 
group.add(&c4); + + let group = group.build(); + + // Only c2 and c4 should be pending + assert!(!group.finished()); + assert!(!called.load(Ordering::SeqCst)); + + c2.complete(0); + assert!(!group.finished()); + assert!(!called.load(Ordering::SeqCst)); + + c4.complete(0); + assert!(group.finished()); + assert!(group.is_completed()); + assert!(called.load(Ordering::SeqCst)); + } + + #[test] + fn test_completion_group_already_completed_with_error() { + // Test when a completion finishes with error before build() + let mut group = CompletionGroup::new(|_| {}); + let c1 = Completion::new_write(|_| {}); + let c2 = Completion::new_write(|_| {}); + + // Complete c1 with error before adding to group + c1.error(CompletionError::Aborted); + + group.add(&c1); + group.add(&c2); + + let group = group.build(); + + // Group should immediately fail with the error + assert!(group.finished()); + assert!(!group.is_completed()); + assert_eq!(group.get_error(), Some(CompletionError::Aborted)); + } +} diff --git a/core/lib.rs b/core/lib.rs index 5e0288670..11b85be81 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -57,8 +57,8 @@ pub use io::UnixIO; #[cfg(all(feature = "fs", target_os = "linux", feature = "io_uring"))] pub use io::UringIO; pub use io::{ - Buffer, Completion, CompletionType, File, MemoryIO, OpenFlags, PlatformIO, SyscallIO, - WriteCompletion, IO, + Buffer, Completion, CompletionType, File, GroupCompletion, MemoryIO, OpenFlags, PlatformIO, + SyscallIO, WriteCompletion, IO, }; use parking_lot::RwLock; use schema::Schema; diff --git a/core/storage/wal.rs b/core/storage/wal.rs index 881afacd0..c590219bd 100644 --- a/core/storage/wal.rs +++ b/core/storage/wal.rs @@ -15,6 +15,7 @@ use super::buffer_pool::BufferPool; use super::pager::{PageRef, Pager}; use super::sqlite3_ondisk::{self, checksum_wal, WalHeader, WAL_MAGIC_BE, WAL_MAGIC_LE}; use crate::fast_lock::SpinLock; +use crate::io::CompletionGroup; use crate::io::{clock, File, IO}; use 
crate::storage::database::EncryptionOrChecksum; use crate::storage::sqlite3_ondisk::{ @@ -23,8 +24,8 @@ use crate::storage::sqlite3_ondisk::{ }; use crate::types::{IOCompletions, IOResult}; use crate::{ - bail_corrupt_error, io_yield_many, io_yield_one, return_if_io, turso_assert, Buffer, - Completion, CompletionError, IOContext, LimboError, Result, + bail_corrupt_error, io_yield_one, return_if_io, turso_assert, Buffer, Completion, + CompletionError, IOContext, LimboError, Result, }; #[derive(Debug, Clone, Default)] @@ -1823,8 +1824,9 @@ impl WalFile { // to prevent serialization, and we try to issue reads and flush batches concurrently // if at all possible, at the cost of some batching potential. CheckpointState::Processing => { - // Gather I/O completions, estimate with MAX_INFLIGHT_WRITES to prevent realloc - let mut completions = Vec::with_capacity(MAX_INFLIGHT_WRITES); + // Gather I/O completions using a completion group + let mut nr_completions = 0; + let mut group = CompletionGroup::new(|_| {}); // Check and clean any completed writes from pending flush if self.ongoing_checkpoint.process_inflight_writes() { @@ -1891,7 +1893,8 @@ impl WalFile { // the frame requirements let inflight = self.issue_wal_read_into_buffer(page_id as usize, target_frame)?; - completions.push(inflight.completion.clone()); + group.add(&inflight.completion); + nr_completions += 1; self.ongoing_checkpoint.inflight_reads.push(inflight); self.ongoing_checkpoint.current_page += 1; } @@ -1903,12 +1906,15 @@ impl WalFile { let batch_map = self.ongoing_checkpoint.pending_writes.take(); if !batch_map.is_empty() { let done_flag = self.ongoing_checkpoint.add_write(); - completions.extend(write_pages_vectored(pager, batch_map, done_flag)?); + for c in write_pages_vectored(pager, batch_map, done_flag)? 
{ + group.add(&c); + nr_completions += 1; + } } } - if !completions.is_empty() { - io_yield_many!(completions); + if nr_completions > 0 { + io_yield_one!(group.build()); } else if self.ongoing_checkpoint.complete() { self.ongoing_checkpoint.state = CheckpointState::Finalize; }