diff --git a/core/io/mod.rs b/core/io/mod.rs index 756424b84..ce01b88e5 100644 --- a/core/io/mod.rs +++ b/core/io/mod.rs @@ -1,6 +1,6 @@ use crate::storage::buffer_pool::ArenaBuffer; use crate::storage::sqlite3_ondisk::WAL_FRAME_HEADER_SIZE; -use crate::{BufferPool, CompletionError, Result}; +use crate::{turso_assert, BufferPool, CompletionError, Result}; use bitflags::bitflags; use cfg_block::cfg_block; use std::cell::RefCell; @@ -37,13 +37,15 @@ pub trait File: Send + Sync { let total_written = total_written.clone(); let _cloned = buf.clone(); Completion::new_write(move |n| { - // reference buffer in callback to ensure alive for async io - let _buf = _cloned.clone(); - // accumulate bytes actually reported by the backend - total_written.fetch_add(n as usize, Ordering::Relaxed); - if outstanding.fetch_sub(1, Ordering::AcqRel) == 1 { - // last one finished - c_main.complete(total_written.load(Ordering::Acquire) as i32); + if let Ok(n) = n { + // reference buffer in callback to ensure alive for async io + let _buf = _cloned.clone(); + // accumulate bytes actually reported by the backend + total_written.fetch_add(n as usize, Ordering::Relaxed); + if outstanding.fetch_sub(1, Ordering::AcqRel) == 1 { + // last one finished + c_main.complete(total_written.load(Ordering::Acquire) as i32); + } } }) }; @@ -110,10 +112,10 @@ pub trait IO: Clock + Send + Sync { } } -pub type ReadComplete = dyn Fn(Arc, i32); -pub type WriteComplete = dyn Fn(i32); -pub type SyncComplete = dyn Fn(i32); -pub type TruncateComplete = dyn Fn(i32); +pub type ReadComplete = dyn Fn(Result<(Arc, i32), CompletionError>); +pub type WriteComplete = dyn Fn(Result); +pub type SyncComplete = dyn Fn(Result); +pub type TruncateComplete = dyn Fn(Result); #[must_use] #[derive(Debug, Clone)] @@ -159,7 +161,7 @@ impl Completion { pub fn new_write(complete: F) -> Self where - F: Fn(i32) + 'static, + F: Fn(Result) + 'static, { Self::new(CompletionType::Write(WriteCompletion::new(Box::new( complete, @@ -168,7 +170,7 @@ impl Completion { pub fn new_read(buf: Arc, complete: F) -> Self where - F: Fn(Arc, i32) + 'static, + F: Fn(Result<(Arc, i32), CompletionError>) + 'static, { Self::new(CompletionType::Read(ReadCompletion::new( buf, @@ -177,7 +179,7 @@ impl Completion { } pub fn new_sync(complete: F) -> Self where - F: Fn(i32) + 'static, + F: Fn(Result) + 'static, { Self::new(CompletionType::Sync(SyncCompletion::new(Box::new( complete, @@ -186,7 +188,7 @@ impl Completion { pub fn new_trunc(complete: F) -> Self where - F: Fn(i32) + 'static, + F: Fn(Result) + 'static, { Self::new(CompletionType::Truncate(TruncateCompletion::new(Box::new( complete, @@ -209,17 +211,39 @@ impl Completion { } pub fn complete(&self, result: i32) { - if !self.inner.is_completed.get() { + if !self.has_error() && !self.inner.is_completed.get() { + let result = Ok(result); match &self.inner.completion_type { - CompletionType::Read(r) => r.complete(result), - CompletionType::Write(w) => w.complete(result), - CompletionType::Sync(s) => s.complete(result), // fix - CompletionType::Truncate(t) => t.complete(result), + CompletionType::Read(r) => r.callback(result), + CompletionType::Write(w) => w.callback(result), + CompletionType::Sync(s) => s.callback(result), // fix + CompletionType::Truncate(t) => t.callback(result), }; self.inner.is_completed.set(true); } } + pub fn error(&self, err: CompletionError) { + turso_assert!( + !self.is_completed(), + "should not error a completed Completion" + ); + if !self.has_error() { + let result = Err(err); + match &self.inner.completion_type { + CompletionType::Read(r) => r.callback(result), + CompletionType::Write(w) => w.callback(result), + CompletionType::Sync(s) => s.callback(result), // fix + CompletionType::Truncate(t) => t.callback(result), + }; + self.inner.error.get_or_init(|| err); + } + } + + pub fn abort(&self) { + self.error(CompletionError::Aborted); + } + /// only call this method if you are sure that the completion is /// a ReadCompletion, panics otherwise pub fn as_read(&self) -> &ReadCompletion { @@ -253,8 +277,8 @@ impl ReadCompletion { &self.buf } - pub fn complete(&self, bytes_read: i32) { - (self.complete)(self.buf.clone(), bytes_read); + pub fn callback(&self, bytes_read: Result) { + (self.complete)(bytes_read.map(|b| (self.buf.clone(), b))); } } @@ -267,7 +291,7 @@ impl WriteCompletion { Self { complete } } - pub fn complete(&self, bytes_written: i32) { + pub fn callback(&self, bytes_written: Result) { (self.complete)(bytes_written); } } @@ -281,7 +305,7 @@ impl SyncCompletion { Self { complete } } - pub fn complete(&self, res: i32) { + pub fn callback(&self, res: Result) { (self.complete)(res); } } @@ -295,7 +319,7 @@ impl TruncateCompletion { Self { complete } } - pub fn complete(&self, res: i32) { + pub fn callback(&self, res: Result) { (self.complete)(res); } } diff --git a/core/lib.rs b/core/lib.rs index c7226818f..e572f5c7f 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -448,7 +448,7 @@ impl Database { "header read must be a multiple of 512 for O_DIRECT" ); let buf = Arc::new(Buffer::new_temporary(PageSize::MIN as usize)); - let c = Completion::new_read(buf.clone(), move |_buf, _| {}); + let c = Completion::new_read(buf.clone(), move |_res| {}); let c = self.db_file.read_header(c)?; self.io.wait_for_completion(c)?; let page_size = u16::from_be_bytes(buf.as_slice()[16..18].try_into().unwrap()); diff --git a/core/storage/sqlite3_ondisk.rs b/core/storage/sqlite3_ondisk.rs index e5a54437c..7091a8b9a 100644 --- a/core/storage/sqlite3_ondisk.rs +++ b/core/storage/sqlite3_ondisk.rs @@ -62,7 +62,7 @@ use crate::storage::database::DatabaseStorage; use crate::storage::pager::Pager; use crate::storage::wal::{PendingFlush, READMARK_NOT_USED}; use crate::types::{RawSlice, RefValue, SerialType, SerialTypeKind, TextRef, TextSubtype}; -use crate::{bail_corrupt_error, turso_assert, File, Result, WalFileShared}; +use crate::{bail_corrupt_error, turso_assert, CompletionError, File, Result, WalFileShared}; use std::cell::{Cell, UnsafeCell}; use std::collections::{BTreeMap, HashMap}; use std::mem::MaybeUninit; @@ -874,7 +874,10 @@ pub fn begin_read_page( let buf = buffer_pool.get_page(); #[allow(clippy::arc_with_non_send_sync)] let buf = Arc::new(buf); - let complete = Box::new(move |mut buf: Arc, bytes_read: i32| { + let complete = Box::new(move |res: Result<(Arc, i32), CompletionError>| { + let Ok((mut buf, bytes_read)) = res else { + return; + }; let buf_len = buf.len(); turso_assert!( (allow_empty_read && bytes_read == 0) || bytes_read == buf_len as i32, @@ -922,7 +925,10 @@ pub fn begin_write_btree_page(pager: &Pager, page: &PageRef) -> Result| { + let Ok(bytes_written) = res else { + return; + }; tracing::trace!("finish_write_btree_page"); let buf_copy = buf_copy.clone(); let buf_len = buf_copy.len(); @@ -1016,6 +1022,9 @@ pub fn write_pages_vectored( let total_sz = (page_sz * run_bufs.len()) as i32; let c = Completion::new_write(move |res| { + let Ok(res) = res else { + return; + }; // writev calls can sometimes return partial writes, but our `pwritev` // implementation aggregates any partial writes and calls completion with total turso_assert!(total_sz == res, "failed to write expected size"); @@ -1586,7 +1595,10 @@ pub fn read_entire_wal_dumb(file: &Arc) -> Result = Box::new(move |buf: Arc, bytes_read: i32| { + let complete: Box = Box::new(move |res: Result<(Arc, i32), _>| { + let Ok((buf, bytes_read)) = res else { + return; + }; let buf_slice = buf.as_slice(); turso_assert!( bytes_read == buf_slice.len() as i32, @@ -1884,7 +1896,10 @@ pub fn begin_write_wal_header(io: &Arc, header: &WalHeader) -> Result< }; let cloned = buffer.clone(); - let write_complete = move |bytes_written: i32| { + let write_complete = move |res: Result| { + let Ok(bytes_written) = res else { + return; + }; // make sure to reference buffer so it's alive for async IO let _buf = cloned.clone(); turso_assert!( diff --git a/core/storage/wal.rs b/core/storage/wal.rs index a8e5bb36b..b4ba856df 100644 --- a/core/storage/wal.rs +++ b/core/storage/wal.rs @@ -20,7 +20,8 @@ use crate::storage::sqlite3_ondisk::{ }; use crate::types::{IOCompletions, IOResult}; use crate::{ - bail_corrupt_error, io_yield_many, io_yield_one, turso_assert, Buffer, LimboError, Result, + bail_corrupt_error, io_yield_many, io_yield_one, turso_assert, Buffer, CompletionError, + LimboError, Result, }; use crate::{Completion, Page}; @@ -912,7 +913,10 @@ impl Wal for WalFile { let offset = self.frame_offset(frame_id); page.set_locked(); let frame = page.clone(); - let complete = Box::new(move |buf: Arc, bytes_read: i32| { + let complete = Box::new(move |res: Result<(Arc, i32), CompletionError>| { + let Ok((buf, bytes_read)) = res else { + return; + }; let buf_len = buf.len(); turso_assert!( bytes_read == buf_len as i32, @@ -934,7 +938,10 @@ impl Wal for WalFile { tracing::debug!("read_frame({})", frame_id); let offset = self.frame_offset(frame_id); let (frame_ptr, frame_len) = (frame.as_mut_ptr(), frame.len()); - let complete = Box::new(move |buf: Arc, bytes_read: i32| { + let complete = Box::new(move |res: Result<(Arc, i32), CompletionError>| { + let Ok((buf, bytes_read)) = res else { + return; + }; let buf_len = buf.len(); turso_assert!( bytes_read == buf_len as i32, @@ -985,7 +992,10 @@ impl Wal for WalFile { let (page_ptr, page_len) = (page.as_ptr(), page.len()); let complete = Box::new({ let conflict = conflict.clone(); - move |buf: Arc, bytes_read: i32| { + move |res: Result<(Arc, i32), CompletionError>| { + let Ok((buf, bytes_read)) = res else { + return; + }; let buf_len = buf.len(); turso_assert!( bytes_read == buf_len as i32, @@ -1077,7 +1087,10 @@ impl Wal for WalFile { let c = Completion::new_write({ let frame_bytes = frame_bytes.clone(); - move |bytes_written| { + move |res: Result| { + let Ok(bytes_written) = res else { + return; + }; let frame_len = frame_bytes.len(); turso_assert!( bytes_written == frame_len as i32, diff --git a/core/vdbe/sorter.rs b/core/vdbe/sorter.rs index 72d26d510..121a87d69 100644 --- a/core/vdbe/sorter.rs +++ b/core/vdbe/sorter.rs @@ -18,7 +18,7 @@ use crate::{ types::{IOResult, ImmutableRecord, KeyInfo, RecordCursor, RefValue}, Result, }; -use crate::{io_yield_many, io_yield_one, return_if_io}; +use crate::{io_yield_many, io_yield_one, return_if_io, CompletionError}; #[derive(Debug, Clone, Copy)] enum SortState { @@ -489,7 +489,10 @@ impl SortedChunk { let stored_buffer_copy = self.buffer.clone(); let stored_buffer_len_copy = self.buffer_len.clone(); let total_bytes_read_copy = self.total_bytes_read.clone(); - let read_complete = Box::new(move |buf: Arc, bytes_read: i32| { + let read_complete = Box::new(move |res: Result<(Arc, i32), CompletionError>| { + let Ok((buf, bytes_read)) = res else { + return; + }; let read_buf_ref = buf.clone(); let read_buf = read_buf_ref.as_slice(); @@ -547,7 +550,10 @@ impl SortedChunk { let buffer_ref_copy = buffer_ref.clone(); let chunk_io_state_copy = self.io_state.clone(); - let write_complete = Box::new(move |bytes_written: i32| { + let write_complete = Box::new(move |res: Result| { + let Ok(bytes_written) = res else { + return; + }; chunk_io_state_copy.set(SortedChunkIOState::WriteComplete); let buf_len = buffer_ref_copy.len(); if bytes_written < buf_len as i32 { diff --git a/packages/turso-sync-engine/src/database_sync_operations.rs b/packages/turso-sync-engine/src/database_sync_operations.rs new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/packages/turso-sync-engine/src/database_sync_operations.rs @@ -0,0 +1 @@ + diff --git a/packages/turso-sync-engine/src/io_operations.rs b/packages/turso-sync-engine/src/io_operations.rs index 8b1378917..139597f9c 100644 --- a/packages/turso-sync-engine/src/io_operations.rs +++ b/packages/turso-sync-engine/src/io_operations.rs @@ -1 +1,2 @@ + diff --git a/sync/engine/src/database_sync_operations.rs b/sync/engine/src/database_sync_operations.rs index 3851db953..8621ffe3a 100644 --- a/sync/engine/src/database_sync_operations.rs +++ b/sync/engine/src/database_sync_operations.rs @@ -74,7 +74,10 @@ pub async fn db_bootstrap( buffer.as_mut_slice().copy_from_slice(chunk); let mut completions = Vec::with_capacity(dbs.len()); for db in dbs { - let c = Completion::new_write(move |size| { + let c = Completion::new_write(move |res| { + let Ok(size) = res else { + return; + }; // todo(sivukhin): we need to error out in case of partial read assert!(size as usize == content_len); }); @@ -818,7 +821,10 @@ pub async fn reset_wal_file( WAL_HEADER + WAL_FRAME_SIZE * (frames_count as usize) }; tracing::debug!("reset db wal to the size of {} frames", frames_count); - let c = Completion::new_trunc(move |rc| { + let c = Completion::new_trunc(move |res| { + let Ok(rc) = res else { + return; + }; assert!(rc as usize == 0); }); let c = wal.truncate(wal_size, c)?; diff --git a/sync/engine/src/io_operations.rs b/sync/engine/src/io_operations.rs index 772287e70..517ad2601 100644 --- a/sync/engine/src/io_operations.rs +++ b/sync/engine/src/io_operations.rs @@ -53,7 +53,12 @@ impl IoOperations for Arc { file: Arc, len: usize, ) -> Result<()> { - let c = Completion::new_trunc(move |rc| tracing::debug!("file truncated: rc={}", rc)); + let c = Completion::new_trunc(move |rc| { + let Ok(rc) = rc else { + return; + }; + tracing::debug!("file truncated: rc={}", rc); + }); let c = file.truncate(len, c)?; while !c.is_completed() { coro.yield_(ProtocolCommand::IO).await?;