change completion callbacks to take a Result param + create separate functions to declare a completion errored

This commit is contained in:
pedrocarlo
2025-08-13 17:17:14 -03:00
parent 71ca221390
commit ab3b68e360
9 changed files with 114 additions and 43 deletions

View File

@@ -1,6 +1,6 @@
use crate::storage::buffer_pool::ArenaBuffer;
use crate::storage::sqlite3_ondisk::WAL_FRAME_HEADER_SIZE;
use crate::{BufferPool, CompletionError, Result};
use crate::{turso_assert, BufferPool, CompletionError, Result};
use bitflags::bitflags;
use cfg_block::cfg_block;
use std::cell::RefCell;
@@ -37,13 +37,15 @@ pub trait File: Send + Sync {
let total_written = total_written.clone();
let _cloned = buf.clone();
Completion::new_write(move |n| {
// reference buffer in callback to ensure alive for async io
let _buf = _cloned.clone();
// accumulate bytes actually reported by the backend
total_written.fetch_add(n as usize, Ordering::Relaxed);
if outstanding.fetch_sub(1, Ordering::AcqRel) == 1 {
// last one finished
c_main.complete(total_written.load(Ordering::Acquire) as i32);
if let Ok(n) = n {
// reference buffer in callback to ensure alive for async io
let _buf = _cloned.clone();
// accumulate bytes actually reported by the backend
total_written.fetch_add(n as usize, Ordering::Relaxed);
if outstanding.fetch_sub(1, Ordering::AcqRel) == 1 {
// last one finished
c_main.complete(total_written.load(Ordering::Acquire) as i32);
}
}
})
};
@@ -110,10 +112,10 @@ pub trait IO: Clock + Send + Sync {
}
}
pub type ReadComplete = dyn Fn(Arc<Buffer>, i32);
pub type WriteComplete = dyn Fn(i32);
pub type SyncComplete = dyn Fn(i32);
pub type TruncateComplete = dyn Fn(i32);
pub type ReadComplete = dyn Fn(Result<(Arc<Buffer>, i32), CompletionError>);
pub type WriteComplete = dyn Fn(Result<i32, CompletionError>);
pub type SyncComplete = dyn Fn(Result<i32, CompletionError>);
pub type TruncateComplete = dyn Fn(Result<i32, CompletionError>);
#[must_use]
#[derive(Debug, Clone)]
@@ -159,7 +161,7 @@ impl Completion {
pub fn new_write<F>(complete: F) -> Self
where
F: Fn(i32) + 'static,
F: Fn(Result<i32, CompletionError>) + 'static,
{
Self::new(CompletionType::Write(WriteCompletion::new(Box::new(
complete,
@@ -168,7 +170,7 @@ impl Completion {
pub fn new_read<F>(buf: Arc<Buffer>, complete: F) -> Self
where
F: Fn(Arc<Buffer>, i32) + 'static,
F: Fn(Result<(Arc<Buffer>, i32), CompletionError>) + 'static,
{
Self::new(CompletionType::Read(ReadCompletion::new(
buf,
@@ -177,7 +179,7 @@ impl Completion {
}
pub fn new_sync<F>(complete: F) -> Self
where
F: Fn(i32) + 'static,
F: Fn(Result<i32, CompletionError>) + 'static,
{
Self::new(CompletionType::Sync(SyncCompletion::new(Box::new(
complete,
@@ -186,7 +188,7 @@ impl Completion {
pub fn new_trunc<F>(complete: F) -> Self
where
F: Fn(i32) + 'static,
F: Fn(Result<i32, CompletionError>) + 'static,
{
Self::new(CompletionType::Truncate(TruncateCompletion::new(Box::new(
complete,
@@ -209,17 +211,39 @@ impl Completion {
}
pub fn complete(&self, result: i32) {
if !self.inner.is_completed.get() {
if !self.has_error() && !self.inner.is_completed.get() {
let result = Ok(result);
match &self.inner.completion_type {
CompletionType::Read(r) => r.complete(result),
CompletionType::Write(w) => w.complete(result),
CompletionType::Sync(s) => s.complete(result), // fix
CompletionType::Truncate(t) => t.complete(result),
CompletionType::Read(r) => r.callback(result),
CompletionType::Write(w) => w.callback(result),
CompletionType::Sync(s) => s.callback(result), // fix
CompletionType::Truncate(t) => t.callback(result),
};
self.inner.is_completed.set(true);
}
}
pub fn error(&self, err: CompletionError) {
turso_assert!(
!self.is_completed(),
"should not error a completed Completion"
);
if !self.has_error() {
let result = Err(err);
match &self.inner.completion_type {
CompletionType::Read(r) => r.callback(result),
CompletionType::Write(w) => w.callback(result),
CompletionType::Sync(s) => s.callback(result), // fix
CompletionType::Truncate(t) => t.callback(result),
};
self.inner.error.get_or_init(|| err);
}
}
pub fn abort(&self) {
self.error(CompletionError::Aborted);
}
/// only call this method if you are sure that the completion is
/// a ReadCompletion, panics otherwise
pub fn as_read(&self) -> &ReadCompletion {
@@ -253,8 +277,8 @@ impl ReadCompletion {
&self.buf
}
pub fn complete(&self, bytes_read: i32) {
(self.complete)(self.buf.clone(), bytes_read);
pub fn callback(&self, bytes_read: Result<i32, CompletionError>) {
(self.complete)(bytes_read.map(|b| (self.buf.clone(), b)));
}
}
@@ -267,7 +291,7 @@ impl WriteCompletion {
Self { complete }
}
pub fn complete(&self, bytes_written: i32) {
pub fn callback(&self, bytes_written: Result<i32, CompletionError>) {
(self.complete)(bytes_written);
}
}
@@ -281,7 +305,7 @@ impl SyncCompletion {
Self { complete }
}
pub fn complete(&self, res: i32) {
pub fn callback(&self, res: Result<i32, CompletionError>) {
(self.complete)(res);
}
}
@@ -295,7 +319,7 @@ impl TruncateCompletion {
Self { complete }
}
pub fn complete(&self, res: i32) {
pub fn callback(&self, res: Result<i32, CompletionError>) {
(self.complete)(res);
}
}

View File

@@ -448,7 +448,7 @@ impl Database {
"header read must be a multiple of 512 for O_DIRECT"
);
let buf = Arc::new(Buffer::new_temporary(PageSize::MIN as usize));
let c = Completion::new_read(buf.clone(), move |_buf, _| {});
let c = Completion::new_read(buf.clone(), move |_res| {});
let c = self.db_file.read_header(c)?;
self.io.wait_for_completion(c)?;
let page_size = u16::from_be_bytes(buf.as_slice()[16..18].try_into().unwrap());

View File

@@ -62,7 +62,7 @@ use crate::storage::database::DatabaseStorage;
use crate::storage::pager::Pager;
use crate::storage::wal::{PendingFlush, READMARK_NOT_USED};
use crate::types::{RawSlice, RefValue, SerialType, SerialTypeKind, TextRef, TextSubtype};
use crate::{bail_corrupt_error, turso_assert, File, Result, WalFileShared};
use crate::{bail_corrupt_error, turso_assert, CompletionError, File, Result, WalFileShared};
use std::cell::{Cell, UnsafeCell};
use std::collections::{BTreeMap, HashMap};
use std::mem::MaybeUninit;
@@ -874,7 +874,10 @@ pub fn begin_read_page(
let buf = buffer_pool.get_page();
#[allow(clippy::arc_with_non_send_sync)]
let buf = Arc::new(buf);
let complete = Box::new(move |mut buf: Arc<Buffer>, bytes_read: i32| {
let complete = Box::new(move |res: Result<(Arc<Buffer>, i32), CompletionError>| {
let Ok((mut buf, bytes_read)) = res else {
return;
};
let buf_len = buf.len();
turso_assert!(
(allow_empty_read && bytes_read == 0) || bytes_read == buf_len as i32,
@@ -922,7 +925,10 @@ pub fn begin_write_btree_page(pager: &Pager, page: &PageRef) -> Result<Completio
let write_complete = {
let buf_copy = buffer.clone();
Box::new(move |bytes_written: i32| {
Box::new(move |res: Result<i32, CompletionError>| {
let Ok(bytes_written) = res else {
return;
};
tracing::trace!("finish_write_btree_page");
let buf_copy = buf_copy.clone();
let buf_len = buf_copy.len();
@@ -1016,6 +1022,9 @@ pub fn write_pages_vectored(
let total_sz = (page_sz * run_bufs.len()) as i32;
let c = Completion::new_write(move |res| {
let Ok(res) = res else {
return;
};
// writev calls can sometimes return partial writes, but our `pwritev`
// implementation aggregates any partial writes and calls completion with total
turso_assert!(total_sz == res, "failed to write expected size");
@@ -1586,7 +1595,10 @@ pub fn read_entire_wal_dumb(file: &Arc<dyn File>) -> Result<Arc<UnsafeCell<WalFi
}));
let wal_file_shared_for_completion = wal_file_shared_ret.clone();
let complete: Box<ReadComplete> = Box::new(move |buf: Arc<Buffer>, bytes_read: i32| {
let complete: Box<ReadComplete> = Box::new(move |res: Result<(Arc<Buffer>, i32), _>| {
let Ok((buf, bytes_read)) = res else {
return;
};
let buf_slice = buf.as_slice();
turso_assert!(
bytes_read == buf_slice.len() as i32,
@@ -1884,7 +1896,10 @@ pub fn begin_write_wal_header(io: &Arc<dyn File>, header: &WalHeader) -> Result<
};
let cloned = buffer.clone();
let write_complete = move |bytes_written: i32| {
let write_complete = move |res: Result<i32, CompletionError>| {
let Ok(bytes_written) = res else {
return;
};
// make sure to reference buffer so it's alive for async IO
let _buf = cloned.clone();
turso_assert!(

View File

@@ -20,7 +20,8 @@ use crate::storage::sqlite3_ondisk::{
};
use crate::types::{IOCompletions, IOResult};
use crate::{
bail_corrupt_error, io_yield_many, io_yield_one, turso_assert, Buffer, LimboError, Result,
bail_corrupt_error, io_yield_many, io_yield_one, turso_assert, Buffer, CompletionError,
LimboError, Result,
};
use crate::{Completion, Page};
@@ -912,7 +913,10 @@ impl Wal for WalFile {
let offset = self.frame_offset(frame_id);
page.set_locked();
let frame = page.clone();
let complete = Box::new(move |buf: Arc<Buffer>, bytes_read: i32| {
let complete = Box::new(move |res: Result<(Arc<Buffer>, i32), CompletionError>| {
let Ok((buf, bytes_read)) = res else {
return;
};
let buf_len = buf.len();
turso_assert!(
bytes_read == buf_len as i32,
@@ -934,7 +938,10 @@ impl Wal for WalFile {
tracing::debug!("read_frame({})", frame_id);
let offset = self.frame_offset(frame_id);
let (frame_ptr, frame_len) = (frame.as_mut_ptr(), frame.len());
let complete = Box::new(move |buf: Arc<Buffer>, bytes_read: i32| {
let complete = Box::new(move |res: Result<(Arc<Buffer>, i32), CompletionError>| {
let Ok((buf, bytes_read)) = res else {
return;
};
let buf_len = buf.len();
turso_assert!(
bytes_read == buf_len as i32,
@@ -985,7 +992,10 @@ impl Wal for WalFile {
let (page_ptr, page_len) = (page.as_ptr(), page.len());
let complete = Box::new({
let conflict = conflict.clone();
move |buf: Arc<Buffer>, bytes_read: i32| {
move |res: Result<(Arc<Buffer>, i32), CompletionError>| {
let Ok((buf, bytes_read)) = res else {
return;
};
let buf_len = buf.len();
turso_assert!(
bytes_read == buf_len as i32,
@@ -1077,7 +1087,10 @@ impl Wal for WalFile {
let c = Completion::new_write({
let frame_bytes = frame_bytes.clone();
move |bytes_written| {
move |res: Result<i32, CompletionError>| {
let Ok(bytes_written) = res else {
return;
};
let frame_len = frame_bytes.len();
turso_assert!(
bytes_written == frame_len as i32,

View File

@@ -18,7 +18,7 @@ use crate::{
types::{IOResult, ImmutableRecord, KeyInfo, RecordCursor, RefValue},
Result,
};
use crate::{io_yield_many, io_yield_one, return_if_io};
use crate::{io_yield_many, io_yield_one, return_if_io, CompletionError};
#[derive(Debug, Clone, Copy)]
enum SortState {
@@ -489,7 +489,10 @@ impl SortedChunk {
let stored_buffer_copy = self.buffer.clone();
let stored_buffer_len_copy = self.buffer_len.clone();
let total_bytes_read_copy = self.total_bytes_read.clone();
let read_complete = Box::new(move |buf: Arc<Buffer>, bytes_read: i32| {
let read_complete = Box::new(move |res: Result<(Arc<Buffer>, i32), CompletionError>| {
let Ok((buf, bytes_read)) = res else {
return;
};
let read_buf_ref = buf.clone();
let read_buf = read_buf_ref.as_slice();
@@ -547,7 +550,10 @@ impl SortedChunk {
let buffer_ref_copy = buffer_ref.clone();
let chunk_io_state_copy = self.io_state.clone();
let write_complete = Box::new(move |bytes_written: i32| {
let write_complete = Box::new(move |res: Result<i32, CompletionError>| {
let Ok(bytes_written) = res else {
return;
};
chunk_io_state_copy.set(SortedChunkIOState::WriteComplete);
let buf_len = buffer_ref_copy.len();
if bytes_written < buf_len as i32 {

View File

@@ -0,0 +1 @@

View File

@@ -74,7 +74,10 @@ pub async fn db_bootstrap<C: ProtocolIO>(
buffer.as_mut_slice().copy_from_slice(chunk);
let mut completions = Vec::with_capacity(dbs.len());
for db in dbs {
let c = Completion::new_write(move |size| {
let c = Completion::new_write(move |res| {
let Ok(size) = res else {
return;
};
// todo(sivukhin): we need to error out in case of partial read
assert!(size as usize == content_len);
});
@@ -818,7 +821,10 @@ pub async fn reset_wal_file(
WAL_HEADER + WAL_FRAME_SIZE * (frames_count as usize)
};
tracing::debug!("reset db wal to the size of {} frames", frames_count);
let c = Completion::new_trunc(move |rc| {
let c = Completion::new_trunc(move |res| {
let Ok(rc) = res else {
return;
};
assert!(rc as usize == 0);
});
let c = wal.truncate(wal_size, c)?;

View File

@@ -53,7 +53,12 @@ impl IoOperations for Arc<dyn turso_core::IO> {
file: Arc<dyn turso_core::File>,
len: usize,
) -> Result<()> {
let c = Completion::new_trunc(move |rc| tracing::debug!("file truncated: rc={}", rc));
let c = Completion::new_trunc(move |rc| {
let Ok(rc) = rc else {
return;
};
tracing::debug!("file truncated: rc={}", rc);
});
let c = file.truncate(len, c)?;
while !c.is_completed() {
coro.yield_(ProtocolCommand::IO).await?;