diff --git a/bindings/javascript/src/lib.rs b/bindings/javascript/src/lib.rs index e15cdaf7f..aa0c4772b 100644 --- a/bindings/javascript/src/lib.rs +++ b/bindings/javascript/src/lib.rs @@ -718,6 +718,17 @@ impl turso_core::DatabaseStorage for DatabaseFile { let pos = (page_idx - 1) * size; self.file.pwrite(pos, buffer, c) } + fn write_pages( + &self, + page_idx: usize, + page_size: usize, + buffers: Vec>>, + c: turso_core::Completion, + ) -> turso_core::Result { + let pos = page_idx.saturating_sub(1) * page_size; + let c = self.file.pwritev(pos, buffers, c)?; + Ok(c) + } fn sync(&self, c: turso_core::Completion) -> turso_core::Result { self.file.sync(c) diff --git a/core/Cargo.toml b/core/Cargo.toml index 651282010..2adc63372 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -19,7 +19,7 @@ default = ["fs", "uuid", "time", "json", "series"] fs = ["turso_ext/vfs"] json = [] uuid = ["dep:uuid"] -io_uring = ["dep:io-uring", "rustix/io_uring", "dep:libc"] +io_uring = ["dep:io-uring", "rustix/io_uring"] time = [] fuzz = [] omit_autovacuum = [] @@ -29,10 +29,12 @@ series = [] [target.'cfg(target_os = "linux")'.dependencies] io-uring = { version = "0.7.5", optional = true } +libc = { version = "0.2.172" } [target.'cfg(target_family = "unix")'.dependencies] polling = "3.7.4" rustix = { version = "1.0.5", features = ["fs"] } +libc = { version = "0.2.172" } [target.'cfg(not(target_family = "wasm"))'.dependencies] mimalloc = { version = "0.1.46", default-features = false } @@ -44,7 +46,6 @@ turso_ext = { workspace = true, features = ["core_only"] } cfg_block = "0.1.1" fallible-iterator = "0.3.0" hex = "0.4.3" -libc = { version = "0.2.172", optional = true } turso_sqlite3_parser = { workspace = true } thiserror = "1.0.61" getrandom = { version = "0.2.15" } diff --git a/core/io/io_uring.rs b/core/io/io_uring.rs index f33c04db3..b2afeb652 100644 --- a/core/io/io_uring.rs +++ b/core/io/io_uring.rs @@ -2,38 +2,43 @@ use super::{common, Completion, CompletionInner, File, OpenFlags, IO}; use crate::io::clock::{Clock, Instant}; +use crate::storage::wal::CKPT_BATCH_PAGES; use crate::{turso_assert, LimboError, MemoryIO, Result}; use rustix::fs::{self, FlockOperation, OFlags}; -use std::cell::RefCell; -use std::collections::VecDeque; -use std::fmt; -use std::io::ErrorKind; -use std::os::fd::AsFd; -use std::os::unix::io::AsRawFd; -use std::rc::Rc; -use std::sync::Arc; -use thiserror::Error; +use std::{ + cell::RefCell, + collections::{HashMap, VecDeque}, + io::ErrorKind, + ops::Deref, + os::{fd::AsFd, unix::io::AsRawFd}, + rc::Rc, + sync::Arc, +}; use tracing::{debug, trace}; +/// Size of the io_uring submission and completion queues const ENTRIES: u32 = 512; + +/// Idle timeout for the sqpoll kernel thread before it needs +/// to be woken back up by a call IORING_ENTER_SQ_WAKEUP flag. +/// (handled by the io_uring crate in `submit_and_wait`) const SQPOLL_IDLE: u32 = 1000; + +/// Number of file descriptors we preallocate for io_uring. +/// NOTE: we may need to increase this when `attach` is fully implemented. const FILES: u32 = 8; -#[derive(Debug, Error)] -enum UringIOError { - IOUringCQError(i32), -} +/// Number of Vec> we preallocate on initialization +const IOVEC_POOL_SIZE: usize = 64; -impl fmt::Display for UringIOError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - UringIOError::IOUringCQError(code) => write!( - f, - "IOUring completion queue error occurred with code {code}", - ), - } - } -} +/// Maximum number of iovec entries per writev operation. 
+/// IOV_MAX is typically 1024, but we limit it to a smaller number +const MAX_IOVEC_ENTRIES: usize = CKPT_BATCH_PAGES; + +/// Maximum number of I/O operations to wait for in a single run, +/// waiting for > 1 can reduce the amount of `io_uring_enter` syscalls we +/// make, but can increase single operation latency. +const MAX_WAIT: usize = 4; pub struct UringIO { inner: Rc>, @@ -45,6 +50,8 @@ unsafe impl Sync for UringIO {} struct WrappedIOUring { ring: io_uring::IoUring, pending_ops: usize, + writev_states: HashMap, + iov_pool: IovecPool, } struct InnerUringIO { @@ -52,6 +59,39 @@ struct InnerUringIO { free_files: VecDeque, } +/// preallocated vec of iovec arrays to avoid allocations during writev operations +struct IovecPool { + pool: Vec>, +} + +impl IovecPool { + fn new() -> Self { + let pool = (0..IOVEC_POOL_SIZE) + .map(|_| { + Box::new( + [libc::iovec { + iov_base: std::ptr::null_mut(), + iov_len: 0, + }; MAX_IOVEC_ENTRIES], + ) + }) + .collect(); + Self { pool } + } + + #[inline(always)] + fn acquire(&mut self) -> Option> { + self.pool.pop() + } + + #[inline(always)] + fn release(&mut self, iovec: Box<[libc::iovec; MAX_IOVEC_ENTRIES]>) { + if self.pool.len() < IOVEC_POOL_SIZE { + self.pool.push(iovec); + } + } +} + impl UringIO { pub fn new() -> Result { let ring = match io_uring::IoUring::builder() @@ -69,6 +109,8 @@ impl UringIO { ring: WrappedIOUring { ring, pending_ops: 0, + writev_states: HashMap::new(), + iov_pool: IovecPool::new(), }, free_files: (0..FILES).collect(), }; @@ -79,6 +121,126 @@ impl UringIO { } } +/// io_uring crate decides not to export their `UseFixed` trait, so we +/// are forced to use a macro here to handle either fixed or raw file descriptors. +macro_rules! with_fd { + ($file:expr, |$fd:ident| $body:expr) => { + match $file.id() { + Some(id) => { + let $fd = io_uring::types::Fixed(id); + $body + } + None => { + let $fd = io_uring::types::Fd($file.as_raw_fd()); + $body + } + } + }; +} + +/// wrapper type to represent a possibly registered file descriptor, +/// only used in WritevState, and piggy-backs on the available methods from +/// `UringFile`, so we don't have to store the file on `WritevState`. +enum Fd { + Fixed(u32), + RawFd(i32), +} + +impl Fd { + /// to match the behavior of the File, we need to implement the same methods + fn id(&self) -> Option { + match self { + Fd::Fixed(id) => Some(*id), + Fd::RawFd(_) => None, + } + } + /// ONLY to be called by the macro, in the case where id() is None + fn as_raw_fd(&self) -> i32 { + match self { + Fd::RawFd(fd) => *fd, + _ => panic!("Cannot call as_raw_fd on a Fixed Fd"), + } + } +} + +/// State to track an ongoing writev operation in +/// the case of a partial write. 
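/// A worked illustration of the resume logic: with three 4096-byte buffers
/// (12288 bytes total) and the kernel reporting a short write of 6000 bytes,
/// `advance(6000)` leaves `current_buffer_idx = 1`, `current_buffer_offset = 1904`
/// and `total_written = 6000`, with `file_pos` moved forward by 6000, so the
/// resubmitted writev starts 1904 bytes into the second buffer and covers the
/// remaining 6288 bytes.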
+struct WritevState { + /// File descriptor/id of the file we are writing to + file_id: Fd, + /// absolute file offset for next submit + file_pos: usize, + /// current buffer index in `bufs` + current_buffer_idx: usize, + /// intra-buffer offset + current_buffer_offset: usize, + /// total bytes written so far + total_written: usize, + /// cache the sum of all buffer lengths for the total expected write + total_len: usize, + /// buffers to write + bufs: Vec>>, + /// we keep the last iovec allocation alive until final CQE + last_iov_allocation: Option>, +} + +impl WritevState { + fn new(file: &UringFile, pos: usize, bufs: Vec>>) -> Self { + let file_id = file + .id() + .map(Fd::Fixed) + .unwrap_or_else(|| Fd::RawFd(file.as_raw_fd())); + let total_len = bufs.iter().map(|b| b.borrow().len()).sum(); + Self { + file_id, + file_pos: pos, + current_buffer_idx: 0, + current_buffer_offset: 0, + total_written: 0, + bufs, + last_iov_allocation: None, + total_len, + } + } + + #[inline(always)] + fn remaining(&self) -> usize { + self.total_len - self.total_written + } + + /// Advance (idx, off, pos) after written bytes + #[inline(always)] + fn advance(&mut self, written: usize) { + let mut remaining = written; + while remaining > 0 { + let current_buf_len = { + let r = self.bufs[self.current_buffer_idx].borrow(); + r.len() + }; + let left = current_buf_len - self.current_buffer_offset; + if remaining < left { + self.current_buffer_offset += remaining; + self.file_pos += remaining; + remaining = 0; + } else { + remaining -= left; + self.file_pos += left; + self.current_buffer_idx += 1; + self.current_buffer_offset = 0; + } + } + self.total_written += written; + } + + #[inline(always)] + /// Free the allocation that keeps the iovec array alive while writev is ongoing + fn free_last_iov(&mut self, pool: &mut IovecPool) { + if let Some(allocation) = self.last_iov_allocation.take() { + pool.release(allocation); + } + } +} + impl InnerUringIO { fn register_file(&mut self, fd: i32) -> Result { if let Some(slot) = self.free_files.pop_front() { @@ -106,33 +268,119 @@ impl WrappedIOUring { fn submit_entry(&mut self, entry: &io_uring::squeue::Entry) { trace!("submit_entry({:?})", entry); unsafe { - self.ring - .submission() - .push(entry) - .expect("submission queue is full"); + let mut sub = self.ring.submission_shared(); + match sub.push(entry) { + Ok(_) => self.pending_ops += 1, + Err(e) => { + tracing::error!("Failed to submit entry: {e}"); + self.ring.submit().expect("failed to submit entry"); + sub.push(entry).expect("failed to push entry after submit"); + self.pending_ops += 1; + } + } } - self.pending_ops += 1; } - fn wait_for_completion(&mut self) -> Result<()> { - self.ring.submit_and_wait(1)?; + fn submit_and_wait(&mut self) -> Result<()> { + if self.empty() { + return Ok(()); + } + let wants = std::cmp::min(self.pending_ops, MAX_WAIT); + tracing::trace!("submit_and_wait for {wants} pending operations to complete"); + self.ring.submit_and_wait(wants)?; Ok(()) } - fn get_completion(&mut self) -> Option { - // NOTE: This works because CompletionQueue's next function pops the head of the queue. 
This is not normal behaviour of iterators - let entry = self.ring.completion().next(); - if entry.is_some() { - trace!("get_completion({:?})", entry); - // consumed an entry from completion queue, update pending_ops - self.pending_ops -= 1; - } - entry - } - fn empty(&self) -> bool { self.pending_ops == 0 } + + /// Submit or resubmit a writev operation + fn submit_writev(&mut self, key: u64, mut st: WritevState) { + st.free_last_iov(&mut self.iov_pool); + let mut iov_allocation = self.iov_pool.acquire().unwrap_or_else(|| { + // Fallback: allocate a new one if pool is exhausted + Box::new( + [libc::iovec { + iov_base: std::ptr::null_mut(), + iov_len: 0, + }; MAX_IOVEC_ENTRIES], + ) + }); + let mut iov_count = 0; + for (idx, buffer) in st + .bufs + .iter() + .enumerate() + .skip(st.current_buffer_idx) + .take(MAX_IOVEC_ENTRIES) + { + let buf = buffer.borrow(); + let buf_slice = buf.as_slice(); + // ensure we are providing a pointer to the proper offset in the buffer + let slice = if idx == st.current_buffer_idx { + &buf_slice[st.current_buffer_offset..] + } else { + buf_slice + }; + if slice.is_empty() { + continue; + } + iov_allocation[iov_count] = libc::iovec { + iov_base: slice.as_ptr() as *mut _, + iov_len: slice.len(), + }; + iov_count += 1; + } + // Store the pointers and get the pointer to the iovec array that we pass + // to the writev operation, and keep the array itself alive + let ptr = iov_allocation.as_ptr() as *mut libc::iovec; + st.last_iov_allocation = Some(iov_allocation); + + let entry = with_fd!(st.file_id, |fd| { + io_uring::opcode::Writev::new(fd, ptr, iov_count as u32) + .offset(st.file_pos as u64) + .build() + .user_data(key) + }); + // track the current state in case we get a partial write + self.writev_states.insert(key, st); + self.submit_entry(&entry); + } + + fn handle_writev_completion(&mut self, mut state: WritevState, user_data: u64, result: i32) { + if result < 0 { + let err = std::io::Error::from_raw_os_error(result); + tracing::error!("writev failed (user_data: {}): {}", user_data, err); + state.free_last_iov(&mut self.iov_pool); + completion_from_key(user_data).complete(result); + return; + } + + let written = result as usize; + state.advance(written); + match state.remaining() { + 0 => { + tracing::info!( + "writev operation completed: wrote {} bytes", + state.total_written + ); + // write complete, return iovec to pool + state.free_last_iov(&mut self.iov_pool); + completion_from_key(user_data).complete(state.total_written as i32); + } + remaining => { + tracing::trace!( + "resubmitting writev operation for user_data {}: wrote {} bytes, remaining {}", + user_data, + written, + remaining + ); + // partial write, submit next + self.submit_writev(user_data, state); + } + } + } } impl IO for UringIO { @@ -179,26 +427,28 @@ impl IO for UringIO { trace!("run_once()"); let mut inner = self.inner.borrow_mut(); let ring = &mut inner.ring; - if ring.empty() { return Ok(()); } - - ring.wait_for_completion()?; - while let Some(cqe) = ring.get_completion() { + ring.submit_and_wait()?; + loop { + let Some(cqe) = ring.ring.completion().next() else { + return Ok(()); + }; + ring.pending_ops -= 1; + let user_data = cqe.user_data(); let result = cqe.result(); - if result < 0 { - return Err(LimboError::UringIOError(format!( - "{} cqe: {:?}", - UringIOError::IOUringCQError(result), - cqe - ))); + turso_assert!( + user_data != 0, + "user_data must not be zero, we dont submit linked timeouts or cancelations that would cause this" + ); + if let Some(state) = 
ring.writev_states.remove(&user_data) { + // if we have ongoing writev state, handle it separately and don't call completion + ring.handle_writev_completion(state, user_data, result); + continue; } - let ud = cqe.user_data(); - turso_assert!(ud > 0, "therea are no linked timeouts or cancelations, all cqe user_data should be valid arc pointers"); - completion_from_key(ud).complete(result); + completion_from_key(user_data).complete(result) } - Ok(()) } fn generate_random_number(&self) -> i64 { @@ -242,24 +492,22 @@ pub struct UringFile { id: Option, } +impl Deref for UringFile { + type Target = std::fs::File; + fn deref(&self) -> &Self::Target { + &self.file + } +} + +impl UringFile { + fn id(&self) -> Option { + self.id + } +} + unsafe impl Send for UringFile {} unsafe impl Sync for UringFile {} -macro_rules! with_fd { - ($file:expr, |$fd:ident| $body:expr) => { - match $file.id { - Some(id) => { - let $fd = io_uring::types::Fixed(id); - $body - } - None => { - let $fd = io_uring::types::Fd($file.file.as_raw_fd()); - $body - } - } - }; -} - impl File for UringFile { fn lock_file(&self, exclusive: bool) -> Result<()> { let fd = self.file.as_fd(); @@ -350,6 +598,24 @@ impl File for UringFile { Ok(c) } + fn pwritev( + &self, + pos: usize, + bufs: Vec>>, + c: Completion, + ) -> Result { + // for a single buffer use pwrite directly + if bufs.len().eq(&1) { + return self.pwrite(pos, bufs[0].clone(), c.clone()); + } + tracing::trace!("pwritev(pos = {}, bufs.len() = {})", pos, bufs.len()); + let mut io = self.io.borrow_mut(); + // create state to track ongoing writev operation + let state = WritevState::new(self, pos, bufs); + io.ring.submit_writev(get_key(c.clone()), state); + Ok(c) + } + fn size(&self) -> Result { Ok(self.file.metadata()?.len()) } diff --git a/core/io/memory.rs b/core/io/memory.rs index 7dbf05d50..4d056aeb4 100644 --- a/core/io/memory.rs +++ b/core/io/memory.rs @@ -187,6 +187,49 @@ impl File for MemoryFile { Ok(c) } + fn pwritev( + &self, + pos: usize, + buffers: Vec>>, + c: Completion, + ) -> Result { + let mut offset = pos; + let mut total_written = 0; + + for buffer in buffers { + let buf = buffer.borrow(); + let buf_len = buf.len(); + if buf_len == 0 { + continue; + } + + let mut remaining = buf_len; + let mut buf_offset = 0; + let data = &buf.as_slice(); + + while remaining > 0 { + let page_no = offset / PAGE_SIZE; + let page_offset = offset % PAGE_SIZE; + let bytes_to_write = remaining.min(PAGE_SIZE - page_offset); + + { + let page = self.get_or_allocate_page(page_no); + page[page_offset..page_offset + bytes_to_write] + .copy_from_slice(&data[buf_offset..buf_offset + bytes_to_write]); + } + + offset += bytes_to_write; + buf_offset += bytes_to_write; + remaining -= bytes_to_write; + } + total_written += buf_len; + } + c.complete(total_written as i32); + self.size + .set(core::cmp::max(pos + total_written, self.size.get())); + Ok(c) + } + fn size(&self) -> Result { Ok(self.size.get() as u64) } diff --git a/core/io/mod.rs b/core/io/mod.rs index 82ef51313..8560216e8 100644 --- a/core/io/mod.rs +++ b/core/io/mod.rs @@ -18,6 +18,46 @@ pub trait File: Send + Sync { fn pwrite(&self, pos: usize, buffer: Arc>, c: Completion) -> Result; fn sync(&self, c: Completion) -> Result; + fn pwritev( + &self, + pos: usize, + buffers: Vec>>, + c: Completion, + ) -> Result { + use std::sync::atomic::{AtomicUsize, Ordering}; + if buffers.is_empty() { + c.complete(0); + return Ok(c); + } + // naive default implementation can be overridden on backends where it makes sense to + let mut pos = pos; + 
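// Worked illustration: three 4096-byte buffers starting at `pos` fan out into child
// pwrites at `pos`, `pos + 4096` and `pos + 8192`. Each child callback adds the byte
// count it reports to `total_written` and decrements `outstanding`; whichever write
// finishes last observes the counter at 1 and completes the caller's completion once,
// with the accumulated total (12288 here), regardless of completion order.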
let outstanding = Arc::new(AtomicUsize::new(buffers.len())); + let total_written = Arc::new(AtomicUsize::new(0)); + + for buf in buffers { + let len = buf.borrow().len(); + let child_c = { + let c_main = c.clone(); + let outstanding = outstanding.clone(); + let total_written = total_written.clone(); + Completion::new_write(move |n| { + // accumulate bytes actually reported by the backend + total_written.fetch_add(n as usize, Ordering::Relaxed); + if outstanding.fetch_sub(1, Ordering::AcqRel) == 1 { + // last one finished + c_main.complete(total_written.load(Ordering::Acquire) as i32); + } + }) + }; + if let Err(e) = self.pwrite(pos, buf.clone(), child_c) { + // best-effort: mark as done so caller won't wait forever + c.complete(-1); + return Err(e); + } + pos += len; + } + Ok(c) + } fn size(&self) -> Result; fn truncate(&self, len: usize, c: Completion) -> Result; } @@ -304,10 +344,10 @@ cfg_block! { pub use unix::UnixIO as PlatformIO; } - #[cfg(target_os = "windows")] { + #[cfg(target_os = "windows")] { mod windows; pub use windows::WindowsIO as PlatformIO; - pub use PlatformIO as SyscallIO; + pub use PlatformIO as SyscallIO; } #[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows", target_os = "android", target_os = "ios")))] { diff --git a/core/io/unix.rs b/core/io/unix.rs index 9cb50a3f8..7e73e6904 100644 --- a/core/io/unix.rs +++ b/core/io/unix.rs @@ -1,15 +1,15 @@ +use super::{Completion, File, MemoryIO, OpenFlags, IO}; use crate::error::LimboError; +use crate::io::clock::{Clock, Instant}; use crate::io::common; use crate::Result; - -use super::{Completion, File, MemoryIO, OpenFlags, IO}; -use crate::io::clock::{Clock, Instant}; use polling::{Event, Events, Poller}; use rustix::{ fd::{AsFd, AsRawFd}, fs::{self, FlockOperation, OFlags, OpenOptionsExt}, io::Errno, }; +use std::os::fd::RawFd; use std::{ cell::{RefCell, UnsafeCell}, mem::MaybeUninit, @@ -40,11 +40,6 @@ impl OwnedCallbacks { self.as_mut().inline_count == 0 } - fn get(&self, fd: usize) -> Option<&CompletionCallback> { - let callbacks = unsafe { &mut *self.0.get() }; - callbacks.get(fd) - } - fn remove(&self, fd: usize) -> Option { let callbacks = unsafe { &mut *self.0.get() }; callbacks.remove(fd) @@ -135,16 +130,6 @@ impl Callbacks { } } - fn get(&self, fd: usize) -> Option<&CompletionCallback> { - if let Some(pos) = self.find_inline(fd) { - let (_, callback) = unsafe { self.inline_entries[pos].assume_init_ref() }; - return Some(callback); - } else if let Some(pos) = self.heap_entries.iter().position(|&(k, _)| k == fd) { - return Some(&self.heap_entries[pos].1); - } - None - } - fn remove(&mut self, fd: usize) -> Option { if let Some(pos) = self.find_inline(fd) { let (_, callback) = unsafe { self.inline_entries[pos].assume_init_read() }; @@ -213,6 +198,35 @@ impl Clock for UnixIO { } } +fn try_pwritev_raw( + fd: RawFd, + off: u64, + bufs: &[Arc>], + start_idx: usize, + start_off: usize, +) -> std::io::Result { + const MAX_IOV: usize = 1024; + let iov_len = std::cmp::min(bufs.len() - start_idx, MAX_IOV); + let mut iov = Vec::with_capacity(iov_len); + + for (i, b) in bufs.iter().enumerate().skip(start_idx).take(iov_len) { + let r = b.borrow(); // borrow just to get pointer/len + let s = r.as_slice(); + let s = if i == start_idx { &s[start_off..] 
} else { s }; + iov.push(libc::iovec { + iov_base: s.as_ptr() as *mut _, + iov_len: s.len(), + }); + } + + let n = unsafe { libc::pwritev(fd, iov.as_ptr(), iov.len() as i32, off as i64) }; + if n < 0 { + Err(std::io::Error::last_os_error()) + } else { + Ok(n as usize) + } +} + impl IO for UnixIO { fn open_file(&self, path: &str, flags: OpenFlags, _direct: bool) -> Result> { trace!("open_file(path = {})", path); @@ -243,46 +257,129 @@ impl IO for UnixIO { if self.callbacks.is_empty() { return Ok(()); } + self.events.clear(); trace!("run_once() waits for events"); self.poller.wait(self.events.as_mut(), None)?; for event in self.events.iter() { - if let Some(cf) = self.callbacks.get(event.key) { - let result = match cf { - CompletionCallback::Read(ref file, ref c, pos) => { - let file = file.lock().unwrap(); - let r = c.as_read(); - let mut buf = r.buf_mut(); - rustix::io::pread(file.as_fd(), buf.as_mut_slice(), *pos as u64) - } - CompletionCallback::Write(ref file, _, ref buf, pos) => { - let file = file.lock().unwrap(); - let buf = buf.borrow(); - rustix::io::pwrite(file.as_fd(), buf.as_slice(), *pos as u64) - } - }; - match result { - Ok(n) => { - let cf = self - .callbacks - .remove(event.key) - .expect("callback should exist"); - match cf { - CompletionCallback::Read(_, c, _) => c.complete(0), - CompletionCallback::Write(_, c, _, _) => c.complete(n as i32), - } - } - Err(Errno::AGAIN) => (), - Err(e) => { - self.callbacks.remove(event.key); + let key = event.key; + let cb = match self.callbacks.remove(key) { + Some(cb) => cb, + None => continue, // could have been completed/removed already + }; - trace!("run_once() error: {}", e); - return Err(e.into()); + match cb { + CompletionCallback::Read(ref file, c, pos) => { + let f = file + .lock() + .map_err(|e| LimboError::LockingError(e.to_string()))?; + let r = c.as_read(); + let mut buf = r.buf_mut(); + match rustix::io::pread(f.as_fd(), buf.as_mut_slice(), pos as u64) { + Ok(n) => c.complete(n as i32), + Err(Errno::AGAIN) => { + // re-arm + unsafe { self.poller.as_mut().add(&f.as_fd(), Event::readable(key))? }; + self.callbacks.as_mut().insert( + key, + CompletionCallback::Read(file.clone(), c.clone(), pos), + ); + } + Err(e) => return Err(e.into()), + } + } + + CompletionCallback::Write(ref file, c, buf, pos) => { + let f = file + .lock() + .map_err(|e| LimboError::LockingError(e.to_string()))?; + let b = buf.borrow(); + match rustix::io::pwrite(f.as_fd(), b.as_slice(), pos as u64) { + Ok(n) => c.complete(n as i32), + Err(Errno::AGAIN) => { + unsafe { self.poller.as_mut().add(&f.as_fd(), Event::writable(key))? 
}; + self.callbacks.as_mut().insert( + key, + CompletionCallback::Write(file.clone(), c, buf.clone(), pos), + ); + } + Err(e) => return Err(e.into()), + } + } + + CompletionCallback::Writev(file, c, bufs, mut pos, mut idx, mut off) => { + let f = file + .lock() + .map_err(|e| LimboError::LockingError(e.to_string()))?; + // keep trying until WouldBlock or we're done with this event + match try_pwritev_raw(f.as_raw_fd(), pos as u64, &bufs, idx, off) { + Ok(written) => { + // advance through buffers + let mut rem = written; + while rem > 0 { + let len = { + let r = bufs[idx].borrow(); + r.len() + }; + let left = len - off; + if rem < left { + off += rem; + rem = 0; + } else { + rem -= left; + idx += 1; + off = 0; + if idx == bufs.len() { + break; + } + } + } + pos += written; + + if idx == bufs.len() { + c.complete(pos as i32); + } else { + // Not finished; re-arm and store updated state + unsafe { + self.poller.as_mut().add(&f.as_fd(), Event::writable(key))? + }; + self.callbacks.as_mut().insert( + key, + CompletionCallback::Writev( + file.clone(), + c.clone(), + bufs, + pos, + idx, + off, + ), + ); + } + break; + } + Err(e) if e.kind() == ErrorKind::WouldBlock => { + // re-arm with same state + unsafe { self.poller.as_mut().add(&f.as_fd(), Event::writable(key))? }; + self.callbacks.as_mut().insert( + key, + CompletionCallback::Writev( + file.clone(), + c.clone(), + bufs, + pos, + idx, + off, + ), + ); + break; + } + Err(e) => return Err(e.into()), } } } } + Ok(()) } @@ -312,6 +409,14 @@ enum CompletionCallback { Arc>, usize, ), + Writev( + Arc>, + Completion, + Vec>>, + usize, // absolute file offset + usize, // buf index + usize, // intra-buf offset + ), } pub struct UnixFile<'io> { @@ -431,6 +536,52 @@ impl File for UnixFile<'_> { } } + #[instrument(err, skip_all, level = Level::TRACE)] + fn pwritev( + &self, + pos: usize, + buffers: Vec>>, + c: Completion, + ) -> Result { + let file = self + .file + .lock() + .map_err(|e| LimboError::LockingError(e.to_string()))?; + + match try_pwritev_raw(file.as_raw_fd(), pos as u64, &buffers, 0, 0) { + Ok(written) => { + trace!("pwritev wrote {written}"); + c.complete(written as i32); + Ok(c) + } + Err(e) => { + if e.kind() == ErrorKind::WouldBlock { + trace!("pwritev blocks"); + } else { + return Err(e.into()); + } + // Set up state so we can resume later + let fd = file.as_raw_fd(); + self.poller + .add(&file.as_fd(), Event::writable(fd as usize))?; + let buf_idx = 0; + let buf_offset = 0; + self.callbacks.insert( + fd as usize, + CompletionCallback::Writev( + self.file.clone(), + c.clone(), + buffers, + pos, + buf_idx, + buf_offset, + ), + ); + Ok(c) + } + } + } + #[instrument(err, skip_all, level = Level::TRACE)] fn sync(&self, c: Completion) -> Result { let file = self.file.lock().unwrap(); diff --git a/core/storage/database.rs b/core/storage/database.rs index fd2555b59..0370d398c 100644 --- a/core/storage/database.rs +++ b/core/storage/database.rs @@ -16,6 +16,13 @@ pub trait DatabaseStorage: Send + Sync { buffer: Arc>, c: Completion, ) -> Result; + fn write_pages( + &self, + first_page_idx: usize, + page_size: usize, + buffers: Vec>>, + c: Completion, + ) -> Result; fn sync(&self, c: Completion) -> Result; fn size(&self) -> Result; fn truncate(&self, len: usize, c: Completion) -> Result; @@ -61,6 +68,22 @@ impl DatabaseStorage for DatabaseFile { self.file.pwrite(pos, buffer, c) } + fn write_pages( + &self, + page_idx: usize, + page_size: usize, + buffers: Vec>>, + c: Completion, + ) -> Result { + assert!(page_idx > 0); + assert!(page_size >= 
512); + assert!(page_size <= 65536); + assert_eq!(page_size & (page_size - 1), 0); + let pos = (page_idx - 1) * page_size; + let c = self.file.pwritev(pos, buffers, c)?; + Ok(c) + } + #[instrument(skip_all, level = Level::DEBUG)] fn sync(&self, c: Completion) -> Result { self.file.sync(c) @@ -120,6 +143,22 @@ impl DatabaseStorage for FileMemoryStorage { self.file.pwrite(pos, buffer, c) } + fn write_pages( + &self, + page_idx: usize, + page_size: usize, + buffer: Vec>>, + c: Completion, + ) -> Result { + assert!(page_idx > 0); + assert!(page_size >= 512); + assert!(page_size <= 65536); + assert_eq!(page_size & (page_size - 1), 0); + let pos = (page_idx - 1) * page_size; + let c = self.file.pwritev(pos, buffer, c)?; + Ok(c) + } + #[instrument(skip_all, level = Level::DEBUG)] fn sync(&self, c: Completion) -> Result { self.file.sync(c) diff --git a/core/storage/pager.rs b/core/storage/pager.rs index a31094f19..90fcb2893 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -346,7 +346,7 @@ pub struct Pager { /// Cache page_size and reserved_space at Pager init and reuse for subsequent /// `usable_space` calls. TODO: Invalidate reserved_space when we add the functionality /// to change it. - page_size: Cell>, + pub(crate) page_size: Cell>, reserved_space: OnceCell, free_page_state: RefCell, } @@ -1303,11 +1303,11 @@ impl Pager { return Ok(CheckpointResult::default()); } - let counter = Rc::new(RefCell::new(0)); + let write_counter = Rc::new(RefCell::new(0)); let mut checkpoint_result = self.io.block(|| { self.wal .borrow_mut() - .checkpoint(self, counter.clone(), mode) + .checkpoint(self, write_counter.clone(), mode) })?; if checkpoint_result.everything_backfilled() diff --git a/core/storage/sqlite3_ondisk.rs b/core/storage/sqlite3_ondisk.rs index 829f049b6..e196c2ae5 100644 --- a/core/storage/sqlite3_ondisk.rs +++ b/core/storage/sqlite3_ondisk.rs @@ -58,14 +58,15 @@ use crate::storage::btree::{payload_overflow_threshold_max, payload_overflow_thr use crate::storage::buffer_pool::BufferPool; use crate::storage::database::DatabaseStorage; use crate::storage::pager::Pager; +use crate::storage::wal::PendingFlush; use crate::types::{RawSlice, RefValue, SerialType, SerialTypeKind, TextRef, TextSubtype}; use crate::{turso_assert, File, Result, WalFileShared}; use std::cell::{RefCell, UnsafeCell}; -use std::collections::HashMap; +use std::collections::{BTreeMap, HashMap}; use std::mem::MaybeUninit; use std::pin::Pin; use std::rc::Rc; -use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; use std::sync::Arc; /// The size of the database header in bytes. @@ -852,6 +853,115 @@ pub fn begin_write_btree_page( res } +#[instrument(skip_all, level = Level::DEBUG)] +/// Write a batch of pages to the database file. +/// +/// we have a batch of pages to write, lets say the following: +/// (they are already sorted by id thanks to BTreeMap) +/// [1,2,3,6,7,9,10,11,12] +// +/// we want to collect this into runs of: +/// [1,2,3], [6,7], [9,10,11,12] +/// and submit each run as a `writev` call, +/// for 3 total syscalls instead of 9. +pub fn write_pages_vectored( + pager: &Pager, + batch: BTreeMap>>, +) -> Result { + if batch.is_empty() { + return Ok(PendingFlush::default()); + } + + // batch item array is already sorted by id, so we just need to find contiguous ranges of page_id's + // to submit as `writev`/write_pages calls. 
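// Worked illustration: with 4096-byte pages, the run [6, 7] from the example above is
// submitted as one write_pages(6, 4096, ..) call, which the storage layer turns into a
// single writev at byte offset (6 - 1) * 4096 = 20480 covering 8192 bytes.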
+ + let page_sz = pager.page_size.get().unwrap_or(DEFAULT_PAGE_SIZE) as usize; + + // Count expected number of runs to create the atomic counter we need to track each batch + let mut run_count = 0; + let mut prev_id = None; + for &id in batch.keys() { + if let Some(prev) = prev_id { + if id != prev + 1 { + run_count += 1; + } + } else { + run_count = 1; // First run + } + prev_id = Some(id); + } + + // Create the atomic counters + let runs_left = Arc::new(AtomicUsize::new(run_count)); + let done = Arc::new(AtomicBool::new(false)); + // we know how many runs, but we don't know how many buffers per run, so we can only give an + // estimate of the capacity + const EST_BUFF_CAPACITY: usize = 32; + + // Iterate through the batch, submitting each run as soon as it ends + // We can reuse this across runs without reallocating + let mut run_bufs = Vec::with_capacity(EST_BUFF_CAPACITY); + let mut run_start_id: Option = None; + let mut all_ids = Vec::with_capacity(batch.len()); + + // Iterate through the batch + let mut iter = batch.into_iter().peekable(); + + while let Some((id, item)) = iter.next() { + // Track the start of the run + if run_start_id.is_none() { + run_start_id = Some(id); + } + + // Add this page to the current run + run_bufs.push(item); + all_ids.push(id); + + // Check if this is the end of a run + let is_end_of_run = match iter.peek() { + Some(&(next_id, _)) => next_id != id + 1, + None => true, + }; + + if is_end_of_run { + let start_id = run_start_id.expect("should have a start id"); + let runs_left_cl = runs_left.clone(); + let done_cl = done.clone(); + + let c = Completion::new_write(move |_| { + if runs_left_cl.fetch_sub(1, Ordering::AcqRel) == 1 { + done_cl.store(true, Ordering::Release); + } + }); + + // Submit write operation for this run, decrementing the counter if we error + if let Err(e) = pager + .db_file + .write_pages(start_id, page_sz, run_bufs.clone(), c) + { + if runs_left.fetch_sub(1, Ordering::AcqRel) == 1 { + done.store(true, Ordering::Release); + } + return Err(e); + } + + // Reset for next run + run_bufs.clear(); + run_start_id = None; + } + } + + tracing::debug!( + "write_pages_vectored: {} pages to write, runs: {run_count}", + all_ids.len() + ); + + Ok(PendingFlush { + pages: all_ids, + done, + }) +} + #[instrument(skip_all, level = Level::DEBUG)] pub fn begin_sync( db_file: Arc, diff --git a/core/storage/wal.rs b/core/storage/wal.rs index eb55e9dc2..90e407cd9 100644 --- a/core/storage/wal.rs +++ b/core/storage/wal.rs @@ -3,7 +3,7 @@ use std::array; use std::cell::UnsafeCell; -use std::collections::HashMap; +use std::collections::{BTreeMap, HashMap}; use strum::EnumString; use tracing::{instrument, Level}; @@ -21,7 +21,7 @@ use crate::io::{File, IO}; use crate::result::LimboResult; use crate::storage::sqlite3_ondisk::{ begin_read_wal_frame, begin_read_wal_frame_raw, finish_read_page, prepare_wal_frame, - WAL_FRAME_HEADER_SIZE, WAL_HEADER_SIZE, + write_pages_vectored, WAL_FRAME_HEADER_SIZE, WAL_HEADER_SIZE, }; use crate::types::IOResult; use crate::{turso_assert, Buffer, LimboError, Result}; @@ -31,7 +31,7 @@ use self::sqlite3_ondisk::{checksum_wal, PageContent, WAL_MAGIC_BE, WAL_MAGIC_LE use super::buffer_pool::BufferPool; use super::pager::{PageRef, Pager}; -use super::sqlite3_ondisk::{self, begin_write_btree_page, WalHeader}; +use super::sqlite3_ondisk::{self, WalHeader}; pub const READMARK_NOT_USED: u32 = 0xffffffff; @@ -393,11 +393,69 @@ pub enum CheckpointState { Start, ReadFrame, WaitReadFrame, - WritePage, - WaitWritePage, + AccumulatePage, + 
FlushBatch, + WaitFlush, Done, } +/// IOV_MAX is 1024 on most systems, lets use 512 to be safe +pub const CKPT_BATCH_PAGES: usize = 512; +type PageId = usize; + +/// Batch is a collection of pages that are being checkpointed together. It is used to +/// aggregate contiguous pages into a single write operation to the database file. +pub(super) struct Batch { + items: BTreeMap>>, +} +// TODO(preston): implement the same thing for `readv` +impl Batch { + fn new() -> Self { + Self { + items: BTreeMap::new(), + } + } + fn is_full(&self) -> bool { + self.items.len() >= CKPT_BATCH_PAGES + } + fn add_to_batch(&mut self, scratch: &PageRef, pool: &Arc) { + let (id, buf_clone) = unsafe { + let inner = &*scratch.inner.get(); + let id = inner.id; + let contents = inner.contents.as_ref().expect("scratch has contents"); + let buf = contents.buffer.clone(); + (id, buf) + }; + // Insert the new batch item at the correct position + self.items.insert(id, buf_clone); + + // Re-initialize scratch with a fresh buffer + let raw = pool.get(); + let pool_clone = pool.clone(); + let drop_fn = Rc::new(move |b| pool_clone.put(b)); + let new_buf = Arc::new(RefCell::new(Buffer::new(raw, drop_fn))); + + unsafe { + let inner = &mut *scratch.inner.get(); + inner.contents = Some(PageContent::new(0, new_buf)); + // reset flags on scratch so it won't be cleared later with the real page + inner.flags.store(0, Ordering::SeqCst); + } + } +} + +impl std::ops::Deref for Batch { + type Target = BTreeMap>>; + fn deref(&self) -> &Self::Target { + &self.items + } +} +impl std::ops::DerefMut for Batch { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.items + } +} + // Checkpointing is a state machine that has multiple steps. Since there are multiple steps we save // in flight information of the checkpoint in OngoingCheckpoint. page is just a helper Page to do // page operations like reading a frame to a page, and writing a page to disk. This page should not @@ -407,13 +465,45 @@ pub enum CheckpointState { // current_page is a helper to iterate through all the pages that might have a frame in the safe // range. This is inefficient for now. 
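// With batching, the states flow Start -> ReadFrame -> WaitReadFrame -> AccumulatePage,
// looping back to ReadFrame until the batch is full (CKPT_BATCH_PAGES) or the safe
// frame range is exhausted; FlushBatch then submits the accumulated pages through
// write_pages_vectored, and WaitFlush polls the PendingFlush `done` flag before either
// resuming at ReadFrame or moving to Done.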
struct OngoingCheckpoint { - page: PageRef, + scratch_page: PageRef, + batch: Batch, state: CheckpointState, + pending_flush: Option, min_frame: u64, max_frame: u64, current_page: u64, } +pub(super) struct PendingFlush { + // page ids to clear + pub(super) pages: Vec, + // completion flag set by IO callback + pub(super) done: Arc, +} + +impl Default for PendingFlush { + fn default() -> Self { + Self::new() + } +} + +impl PendingFlush { + pub fn new() -> Self { + Self { + pages: Vec::with_capacity(CKPT_BATCH_PAGES), + done: Arc::new(AtomicBool::new(false)), + } + } + // clear the dirty flag of all pages in the pending flush batch + fn clear_dirty(&self, pager: &Pager) { + for id in &self.pages { + if let Some(p) = pager.cache_get(*id) { + p.clear_dirty(); + } + } + } +} + impl fmt::Debug for OngoingCheckpoint { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("OngoingCheckpoint") @@ -1079,7 +1169,7 @@ impl Wal for WalFile { fn checkpoint( &mut self, pager: &Pager, - write_counter: Rc>, + _write_counter: Rc>, mode: CheckpointMode, ) -> Result> { if matches!(mode, CheckpointMode::Full) { @@ -1087,7 +1177,7 @@ impl Wal for WalFile { "Full checkpoint mode is not implemented yet".into(), )); } - self.checkpoint_inner(pager, write_counter, mode) + self.checkpoint_inner(pager, _write_counter, mode) .inspect_err(|_| { let _ = self.checkpoint_guard.take(); }) @@ -1204,7 +1294,9 @@ impl WalFile { max_frame: unsafe { (*shared.get()).max_frame.load(Ordering::SeqCst) }, shared, ongoing_checkpoint: OngoingCheckpoint { - page: checkpoint_page, + scratch_page: checkpoint_page, + batch: Batch::new(), + pending_flush: None, state: CheckpointState::Start, min_frame: 0, max_frame: 0, @@ -1263,6 +1355,8 @@ impl WalFile { self.ongoing_checkpoint.max_frame = 0; self.ongoing_checkpoint.current_page = 0; self.max_frame_read_lock_index.set(NO_LOCK_HELD); + self.ongoing_checkpoint.batch.clear(); + let _ = self.ongoing_checkpoint.pending_flush.take(); self.sync_state.set(SyncState::NotSyncing); self.syncing.set(false); } @@ -1311,7 +1405,7 @@ impl WalFile { fn checkpoint_inner( &mut self, pager: &Pager, - write_counter: Rc>, + _write_counter: Rc>, mode: CheckpointMode, ) -> Result> { 'checkpoint_loop: loop { @@ -1358,7 +1452,14 @@ impl WalFile { let frame_cache = frame_cache.lock(); assert!(self.ongoing_checkpoint.current_page as usize <= pages_in_frames.len()); if self.ongoing_checkpoint.current_page as usize == pages_in_frames.len() { - self.ongoing_checkpoint.state = CheckpointState::Done; + if self.ongoing_checkpoint.batch.is_empty() { + // no more pages to checkpoint, we are done + tracing::info!("checkpoint done, no more pages to checkpoint"); + self.ongoing_checkpoint.state = CheckpointState::Done; + } else { + // flush the batch + self.ongoing_checkpoint.state = CheckpointState::FlushBatch; + } continue 'checkpoint_loop; } let page = pages_in_frames[self.ongoing_checkpoint.current_page as usize]; @@ -1374,10 +1475,10 @@ impl WalFile { page, *frame ); - self.ongoing_checkpoint.page.get().id = page as usize; + self.ongoing_checkpoint.scratch_page.get().id = page as usize; let _ = self.read_frame( *frame, - self.ongoing_checkpoint.page.clone(), + self.ongoing_checkpoint.scratch_page.clone(), self.buffer_pool.clone(), )?; self.ongoing_checkpoint.state = CheckpointState::WaitReadFrame; @@ -1387,30 +1488,65 @@ impl WalFile { self.ongoing_checkpoint.current_page += 1; } CheckpointState::WaitReadFrame => { - if self.ongoing_checkpoint.page.is_locked() { + if 
self.ongoing_checkpoint.scratch_page.is_locked() { return Ok(IOResult::IO); } else { - self.ongoing_checkpoint.state = CheckpointState::WritePage; + self.ongoing_checkpoint.state = CheckpointState::AccumulatePage; } } - CheckpointState::WritePage => { - self.ongoing_checkpoint.page.set_dirty(); - let _ = begin_write_btree_page( + CheckpointState::AccumulatePage => { + // mark before batching + self.ongoing_checkpoint.scratch_page.set_dirty(); + // we read the frame into memory, add it to our batch + self.ongoing_checkpoint + .batch + .add_to_batch(&self.ongoing_checkpoint.scratch_page, &self.buffer_pool); + + let more_pages = (self.ongoing_checkpoint.current_page as usize) + < self + .get_shared() + .pages_in_frames + .lock() + .len() + .saturating_sub(1) + && !self.ongoing_checkpoint.batch.is_full(); + + // if we can read more pages, continue reading and accumulating pages + if more_pages { + self.ongoing_checkpoint.current_page += 1; + self.ongoing_checkpoint.state = CheckpointState::ReadFrame; + } else { + // if we have enough pages in the batch, flush it + self.ongoing_checkpoint.state = CheckpointState::FlushBatch; + } + } + CheckpointState::FlushBatch => { + tracing::trace!("started checkpoint backfilling batch"); + self.ongoing_checkpoint.pending_flush = Some(write_pages_vectored( pager, - &self.ongoing_checkpoint.page, - write_counter.clone(), - )?; - self.ongoing_checkpoint.state = CheckpointState::WaitWritePage; + std::mem::take(&mut self.ongoing_checkpoint.batch), + )?); + // batch is queued + self.ongoing_checkpoint.batch.clear(); + self.ongoing_checkpoint.state = CheckpointState::WaitFlush; } - CheckpointState::WaitWritePage => { - if *write_counter.borrow() > 0 { - return Ok(IOResult::IO); + CheckpointState::WaitFlush => { + match self.ongoing_checkpoint.pending_flush.as_ref() { + Some(pf) if pf.done.load(Ordering::SeqCst) => { + // flush is done, we can continue + tracing::trace!("checkpoint backfilling batch done"); + } + Some(_) => return Ok(IOResult::IO), + None => panic!("we should have a pending flush here"), } - // If page was in cache clear it. 
- if let Some(page) = pager.cache_get(self.ongoing_checkpoint.page.get().id) { - page.clear_dirty(); - } - self.ongoing_checkpoint.page.clear_dirty(); + tracing::debug!("finished checkpoint backfilling batch"); + let pf = self + .ongoing_checkpoint + .pending_flush + .as_ref() + .expect("we should have a pending flush here"); + pf.clear_dirty(pager); + // done with batch let shared = self.get_shared(); if (self.ongoing_checkpoint.current_page as usize) < shared.pages_in_frames.lock().len() @@ -1418,6 +1554,7 @@ impl WalFile { self.ongoing_checkpoint.current_page += 1; self.ongoing_checkpoint.state = CheckpointState::ReadFrame; } else { + tracing::debug!("WaitFlush transitioning checkpoint to Done"); self.ongoing_checkpoint.state = CheckpointState::Done; } } @@ -1426,8 +1563,11 @@ impl WalFile { // In Restart or Truncate mode, we need to restart the log over and possibly truncate the file // Release all locks and return the current num of wal frames and the amount we backfilled CheckpointState::Done => { - if *write_counter.borrow() > 0 { - return Ok(IOResult::IO); + if let Some(pf) = self.ongoing_checkpoint.pending_flush.as_ref() { + turso_assert!( + pf.done.load(Ordering::Relaxed), + "checkpoint pending flush must have finished" + ); } let mut checkpoint_result = { let shared = self.get_shared(); @@ -1491,6 +1631,11 @@ impl WalFile { } else { let _ = self.checkpoint_guard.take(); } + self.ongoing_checkpoint.scratch_page.clear_dirty(); + self.ongoing_checkpoint.scratch_page.get().id = 0; + self.ongoing_checkpoint.scratch_page.get().contents = None; + let _ = self.ongoing_checkpoint.pending_flush.take(); + self.ongoing_checkpoint.batch.clear(); self.ongoing_checkpoint.state = CheckpointState::Start; return Ok(IOResult::Done(checkpoint_result)); } @@ -1918,6 +2063,25 @@ pub mod test { } } + fn count_test_table(conn: &Arc) -> i64 { + let mut stmt = conn.prepare("select count(*) from test").unwrap(); + loop { + match stmt.step() { + Ok(StepResult::Row) => { + break; + } + Ok(StepResult::IO) => { + stmt.run_once().unwrap(); + } + _ => { + panic!("Failed to step through the statement"); + } + } + } + let count: i64 = stmt.row().unwrap().get(0).unwrap(); + count + } + fn run_checkpoint_until_done( wal: &mut dyn Wal, pager: &crate::Pager, @@ -2496,6 +2660,75 @@ pub mod test { std::fs::remove_dir_all(path).unwrap(); } + #[test] + fn test_wal_checkpoint_truncate_db_file_contains_data() { + let (db, path) = get_database(); + let conn = db.connect().unwrap(); + + let walpath = { + let mut p = path.clone().into_os_string().into_string().unwrap(); + p.push_str("/test.db-wal"); + std::path::PathBuf::from(p) + }; + + conn.execute("create table test(id integer primary key, value text)") + .unwrap(); + bulk_inserts(&conn, 10, 100); + + // Get size before checkpoint + let size_before = std::fs::metadata(&walpath).unwrap().len(); + assert!(size_before > 0, "WAL file should have content"); + + // Do a TRUNCATE checkpoint + { + let pager = conn.pager.borrow(); + let mut wal = pager.wal.borrow_mut(); + run_checkpoint_until_done(&mut *wal, &pager, CheckpointMode::Truncate); + } + + // Check file size after truncate + let size_after = std::fs::metadata(&walpath).unwrap().len(); + assert_eq!(size_after, 0, "WAL file should be truncated to 0 bytes"); + + // Verify we can still write to the database + conn.execute("INSERT INTO test VALUES (1001, 'after-truncate')") + .unwrap(); + + // Check WAL has new content + let new_size = std::fs::metadata(&walpath).unwrap().len(); + assert!(new_size >= 32, "WAL file too 
small"); + let hdr = read_wal_header(&walpath); + let expected_magic = if cfg!(target_endian = "big") { + sqlite3_ondisk::WAL_MAGIC_BE + } else { + sqlite3_ondisk::WAL_MAGIC_LE + }; + assert!( + hdr.magic == expected_magic, + "bad WAL magic: {:#X}, expected: {:#X}", + hdr.magic, + sqlite3_ondisk::WAL_MAGIC_BE + ); + assert_eq!(hdr.file_format, 3007000); + assert_eq!(hdr.page_size, 4096, "invalid page size"); + assert_eq!(hdr.checkpoint_seq, 1, "invalid checkpoint_seq"); + { + let pager = conn.pager.borrow(); + let mut wal = pager.wal.borrow_mut(); + run_checkpoint_until_done(&mut *wal, &pager, CheckpointMode::Passive); + } + // delete the WAL file so we can read right from db and assert + // that everything was backfilled properly + std::fs::remove_file(&walpath).unwrap(); + + let count = count_test_table(&conn); + assert_eq!( + count, 1001, + "we should have 1001 rows in the table all together" + ); + std::fs::remove_dir_all(path).unwrap(); + } + fn read_wal_header(path: &std::path::Path) -> sqlite3_ondisk::WalHeader { use std::{fs::File, io::Read}; let mut hdr = [0u8; 32]; diff --git a/simulator/runner/file.rs b/simulator/runner/file.rs index ba3680333..9ed80e34c 100644 --- a/simulator/runner/file.rs +++ b/simulator/runner/file.rs @@ -222,6 +222,34 @@ impl File for SimulatorFile { Ok(c) } + fn pwritev( + &self, + pos: usize, + buffers: Vec>>, + c: turso_core::Completion, + ) -> Result { + self.nr_pwrite_calls.set(self.nr_pwrite_calls.get() + 1); + if self.fault.get() { + tracing::debug!("pwritev fault"); + self.nr_pwrite_faults.set(self.nr_pwrite_faults.get() + 1); + return Err(turso_core::LimboError::InternalError( + FAULT_ERROR_MSG.into(), + )); + } + if let Some(latency) = self.generate_latency_duration() { + let cloned_c = c.clone(); + let op = + Box::new(move |file: &SimulatorFile| file.inner.pwritev(pos, buffers, cloned_c)); + self.queued_io + .borrow_mut() + .push(DelayedIo { time: latency, op }); + Ok(c) + } else { + let c = self.inner.pwritev(pos, buffers, c)?; + Ok(c) + } + } + fn size(&self) -> Result { self.inner.size() } diff --git a/testing/cli_tests/vfs_bench.py b/testing/cli_tests/vfs_bench.py index b54ababf3..dc637c37b 100644 --- a/testing/cli_tests/vfs_bench.py +++ b/testing/cli_tests/vfs_bench.py @@ -48,6 +48,9 @@ def bench_one(vfs: str, sql: str, iterations: int) -> list[float]: def setup_temp_db() -> None: + # make sure we start fresh, otherwise we could end up with + # one having to checkpoint the others from the previous run + cleanup_temp_db() cmd = ["sqlite3", "testing/testing.db", ".clone testing/temp.db"] proc = subprocess.run(cmd, check=True) proc.check_returncode() @@ -57,7 +60,9 @@ def setup_temp_db() -> None: def cleanup_temp_db() -> None: if DB_FILE.exists(): DB_FILE.unlink() - os.remove("testing/temp.db-wal") + wal_file = DB_FILE.with_suffix(".db-wal") + if wal_file.exists(): + os.remove(wal_file) def main() -> None: @@ -65,7 +70,6 @@ def main() -> None: parser.add_argument("sql", help="SQL statement to execute (quote it)") parser.add_argument("iterations", type=int, help="number of repetitions") args = parser.parse_args() - setup_temp_db() sql, iterations = args.sql, args.iterations if iterations <= 0: @@ -79,12 +83,15 @@ def main() -> None: averages: Dict[str, float] = {} for vfs in vfs_list: + setup_temp_db() test(f"\n### VFS: {vfs} ###") times = bench_one(vfs, sql, iterations) info(f"All times ({vfs}):", " ".join(f"{t:.6f}" for t in times)) avg = statistics.mean(times) averages[vfs] = avg + cleanup_temp_db() + info("\n" + "-" * 60) info("Average 
runtime per VFS") info("-" * 60) @@ -106,7 +113,6 @@ def main() -> None: faster_slower = "slower" if pct > 0 else "faster" info(f"{vfs:<{name_pad}} : {avg:.6f} ({abs(pct):.1f}% {faster_slower} than {baseline})") info("-" * 60) - cleanup_temp_db() if __name__ == "__main__":
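To make the batching idea concrete, here is a minimal, self-contained sketch (plain Rust against std only; the function and variable names are illustrative, not taken from this patch) of how sorted page ids are grouped into maximal contiguous runs, each of which write_pages_vectored would submit as a single vectored write:

// Group sorted, deduplicated page ids into (first_page_id, run_length) pairs.
fn contiguous_runs(sorted_page_ids: &[usize]) -> Vec<(usize, usize)> {
    let mut runs = Vec::new();
    let mut iter = sorted_page_ids.iter().copied();
    let Some(mut start) = iter.next() else { return runs };
    let mut len = 1;
    for id in iter {
        if id == start + len {
            // Still contiguous: extend the current run.
            len += 1;
        } else {
            // Gap found: close the current run and start a new one.
            runs.push((start, len));
            start = id;
            len = 1;
        }
    }
    runs.push((start, len));
    runs
}

fn main() {
    // The example used in the write_pages_vectored doc comment.
    let ids = [1, 2, 3, 6, 7, 9, 10, 11, 12];
    assert_eq!(contiguous_runs(&ids), vec![(1, 3), (6, 2), (9, 4)]);
    // Pages are 1-based, so with 4096-byte pages a run starting at page `first`
    // lands at byte offset (first - 1) * 4096 in the database file.
    let page_size = 4096usize;
    for (first, len) in contiguous_runs(&ids) {
        println!("writev {len} pages at byte offset {}", (first - 1) * page_size);
    }
}

Because the checkpoint batch is keyed by a BTreeMap, the ids arrive already sorted, so run detection stays a single O(n) pass over the keys.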