From c91c22a6a84af216e7945d9c0b2e23888d98d174 Mon Sep 17 00:00:00 2001 From: pedrocarlo Date: Wed, 6 Aug 2025 13:12:45 -0300 Subject: [PATCH] state machine for `next` --- core/vdbe/sorter.rs | 153 ++++++++++++++++++++++++-------------------- 1 file changed, 85 insertions(+), 68 deletions(-) diff --git a/core/vdbe/sorter.rs b/core/vdbe/sorter.rs index 52771e5b5..9ce453372 100644 --- a/core/vdbe/sorter.rs +++ b/core/vdbe/sorter.rs @@ -7,6 +7,7 @@ use std::rc::Rc; use std::sync::Arc; use tempfile; +use crate::return_if_io; use crate::{ error::LimboError, io::{Buffer, Completion, File, OpenFlags, IO}, @@ -152,11 +153,6 @@ impl Sorter { // Make sure all chunks read at least one record into their buffer. for chunk in self.chunks.iter_mut() { match chunk.io_state.get() { - SortedChunkIOState::WriteComplete => { - all_read_complete = false; - // Write complete, we can now read from the chunk. - let _c = chunk.read()?; - } SortedChunkIOState::WaitingForWrite | SortedChunkIOState::WaitingForRead => { all_read_complete = false; } @@ -203,11 +199,10 @@ impl Sorter { } } - fn push_to_chunk_heap(&mut self, chunk_idx: usize) -> Result<()> { + fn push_to_chunk_heap(&mut self, chunk_idx: usize) -> Result> { let chunk = &mut self.chunks[chunk_idx]; - if chunk.has_more() { - let record = chunk.next()?.unwrap(); + if let Some(record) = return_if_io!(chunk.next()) { self.chunk_heap.push(( Reverse(SortableImmutableRecord::new( record, @@ -220,7 +215,8 @@ impl Sorter { self.wait_for_read_complete.push(chunk_idx); } } - Ok(()) + + Ok(IOResult::Done(())) } fn flush(&mut self) -> Result<()> { @@ -259,6 +255,12 @@ impl Sorter { } } +#[derive(Debug, Clone, Copy)] +enum NextState { + Start, + Finish, +} + struct SortedChunk { /// The chunk file. file: Arc, @@ -274,6 +276,8 @@ struct SortedChunk { io_state: Rc>, /// The total number of bytes read from the chunk file. total_bytes_read: Rc>, + /// State machine for [SortedChunk::next] + next_state: NextState, } impl SortedChunk { @@ -286,6 +290,7 @@ impl SortedChunk { records: Vec::new(), io_state: Rc::new(Cell::new(SortedChunkIOState::None)), total_bytes_read: Rc::new(Cell::new(0)), + next_state: NextState::Start, } } @@ -293,66 +298,81 @@ impl SortedChunk { !self.records.is_empty() || self.io_state.get() != SortedChunkIOState::ReadEOF } - fn next(&mut self) -> Result> { - let mut buffer_len = self.buffer_len.get(); - if self.records.is_empty() && buffer_len == 0 { - return Ok(None); - } - - if self.records.is_empty() { - let mut buffer_ref = self.buffer.borrow_mut(); - let buffer = buffer_ref.as_mut_slice(); - let mut buffer_offset = 0; - while buffer_offset < buffer_len { - // Extract records from the buffer until we run out of the buffer or we hit an incomplete record. - let (record_size, bytes_read) = - match read_varint(&buffer[buffer_offset..buffer_len]) { - Ok((record_size, bytes_read)) => (record_size as usize, bytes_read), - Err(LimboError::Corrupt(_)) - if self.io_state.get() != SortedChunkIOState::ReadEOF => - { - // Failed to decode a partial varint. - break; - } - Err(e) => { - return Err(e); - } - }; - if record_size > buffer_len - (buffer_offset + bytes_read) { - if self.io_state.get() == SortedChunkIOState::ReadEOF { - crate::bail_corrupt_error!("Incomplete record"); + fn next(&mut self) -> Result>> { + loop { + match self.next_state { + NextState::Start => { + let mut buffer_len = self.buffer_len.get(); + if self.records.is_empty() && buffer_len == 0 { + return Ok(IOResult::Done(None)); + } + + if self.records.is_empty() { + let mut buffer_ref = self.buffer.borrow_mut(); + let buffer = buffer_ref.as_mut_slice(); + let mut buffer_offset = 0; + while buffer_offset < buffer_len { + // Extract records from the buffer until we run out of the buffer or we hit an incomplete record. + let (record_size, bytes_read) = + match read_varint(&buffer[buffer_offset..buffer_len]) { + Ok((record_size, bytes_read)) => { + (record_size as usize, bytes_read) + } + Err(LimboError::Corrupt(_)) + if self.io_state.get() != SortedChunkIOState::ReadEOF => + { + // Failed to decode a partial varint. + break; + } + Err(e) => { + return Err(e); + } + }; + if record_size > buffer_len - (buffer_offset + bytes_read) { + if self.io_state.get() == SortedChunkIOState::ReadEOF { + crate::bail_corrupt_error!("Incomplete record"); + } + break; + } + buffer_offset += bytes_read; + + let mut record = ImmutableRecord::new(record_size); + record.start_serialization( + &buffer[buffer_offset..buffer_offset + record_size], + ); + buffer_offset += record_size; + + self.records.push(record); + } + if buffer_offset < buffer_len { + buffer.copy_within(buffer_offset..buffer_len, 0); + buffer_len -= buffer_offset; + } else { + buffer_len = 0; + } + self.buffer_len.set(buffer_len); + + self.records.reverse(); + } + + self.next_state = NextState::Finish; + if self.records.len() == 1 && self.io_state.get() != SortedChunkIOState::ReadEOF + { + // We've consumed the last record. Read more payload into the buffer. + if self.chunk_size - self.total_bytes_read.get() == 0 { + self.io_state.set(SortedChunkIOState::ReadEOF); + } else { + let _c = self.read()?; + return Ok(IOResult::IO); + } } - break; } - buffer_offset += bytes_read; - - let mut record = ImmutableRecord::new(record_size); - record.start_serialization(&buffer[buffer_offset..buffer_offset + record_size]); - buffer_offset += record_size; - - self.records.push(record); - } - if buffer_offset < buffer_len { - buffer.copy_within(buffer_offset..buffer_len, 0); - buffer_len -= buffer_offset; - } else { - buffer_len = 0; - } - self.buffer_len.set(buffer_len); - - self.records.reverse(); - } - - let record = self.records.pop(); - if self.records.is_empty() && self.io_state.get() != SortedChunkIOState::ReadEOF { - // We've consumed the last record. Read more payload into the buffer. - if self.chunk_size - self.total_bytes_read.get() == 0 { - self.io_state.set(SortedChunkIOState::ReadEOF); - } else { - let _c = self.read()?; + NextState::Finish => { + self.next_state = NextState::Start; + return Ok(IOResult::Done(self.records.pop())); + } } } - Ok(record) } fn read(&mut self) -> Result { @@ -427,9 +447,7 @@ impl SortedChunk { let buffer_ref = Arc::new(buffer); let buffer_ref_copy = buffer_ref.clone(); - let chunk_io_state_copy = self.io_state.clone(); let write_complete = Box::new(move |bytes_written: i32| { - chunk_io_state_copy.set(SortedChunkIOState::WriteComplete); let buf_len = buffer_ref_copy.len(); if bytes_written < buf_len as i32 { tracing::error!("wrote({bytes_written}) less than expected({buf_len})"); @@ -563,7 +581,6 @@ enum SortedChunkIOState { ReadComplete, ReadEOF, WaitingForWrite, - WriteComplete, None, }