use turso_sqlite3_parser::ast::SortOrder;

use std::cell::{Cell, RefCell};
use std::cmp::{Eq, Ord, Ordering, PartialEq, PartialOrd, Reverse};
use std::collections::BinaryHeap;
use std::rc::Rc;
use std::sync::Arc;
use tempfile;

use crate::{
    error::LimboError,
    io::{
        Buffer, BufferData, Completion, CompletionType, File, OpenFlags, ReadCompletion,
        WriteCompletion, IO,
    },
    storage::sqlite3_ondisk::{read_varint, varint_len, write_varint},
    translate::collate::CollationSeq,
    types::{compare_immutable, IOResult, ImmutableRecord, KeyInfo},
    Result,
};

pub struct Sorter {
    /// The records in the in-memory buffer.
    records: Vec<SortableImmutableRecord>,
    /// The current record.
    current: Option<ImmutableRecord>,
    /// The number of values in the key.
    key_len: usize,
    /// The key info.
    index_key_info: Rc<Vec<KeyInfo>>,
    /// Sorted chunks stored on disk.
    chunks: Vec<SortedChunk>,
    /// The heap of records consumed from the chunks and their corresponding chunk index.
    chunk_heap: BinaryHeap<(Reverse<SortableImmutableRecord>, usize)>,
    /// The maximum size of the in-memory buffer in bytes before the records are flushed to a chunk file.
    max_buffer_size: usize,
    /// The current size of the in-memory buffer in bytes.
    current_buffer_size: usize,
    /// The minimum size of a chunk read buffer in bytes. The actual buffer size can be larger if the largest
    /// record in the buffer is larger than this value.
    min_chunk_read_buffer_size: usize,
    /// The maximum record payload size in the in-memory buffer.
    max_payload_size_in_buffer: usize,
    /// The IO object.
    io: Arc<dyn IO>,
    /// The indices of the chunks for which the read is not complete.
    wait_for_read_complete: Vec<usize>,
    /// The temporary directory for chunk files.
    temp_dir: Option<tempfile::TempDir>,
}

impl Sorter {
    pub fn new(
        order: &[SortOrder],
        collations: Vec<CollationSeq>,
        max_buffer_size_bytes: usize,
        min_chunk_read_buffer_size_bytes: usize,
        io: Arc<dyn IO>,
    ) -> Self {
        assert_eq!(order.len(), collations.len());
        Self {
            records: Vec::new(),
            current: None,
            key_len: order.len(),
            index_key_info: Rc::new(
                order
                    .iter()
                    .zip(collations)
                    .map(|(order, collation)| KeyInfo {
                        sort_order: *order,
                        collation,
                    })
                    .collect(),
            ),
            chunks: Vec::new(),
            chunk_heap: BinaryHeap::new(),
            max_buffer_size: max_buffer_size_bytes,
            current_buffer_size: 0,
            min_chunk_read_buffer_size: min_chunk_read_buffer_size_bytes,
            max_payload_size_in_buffer: 0,
            io,
            wait_for_read_complete: Vec::new(),
            temp_dir: None,
        }
    }

    pub fn is_empty(&self) -> bool {
        self.records.is_empty() && self.chunks.is_empty()
    }

    pub fn has_more(&self) -> bool {
        self.current.is_some()
    }
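    // The sorter is driven with a poll loop: `sort()` (and subsequently
    // `next()`) return `IOResult::IO` while chunk file IO is still in flight,
    // and the caller is expected to run the event loop and retry. A minimal
    // driver sketch (hypothetical `sorter` and `io` values; the fuzz test at
    // the bottom of this file does the same thing):
    //
    //     while let IOResult::IO = sorter.sort()? {
    //         io.run_once()?;
    //     }
    //     while sorter.has_more() {
    //         let record = sorter.record().unwrap();
    //         // ... consume the record ...
    //         while let IOResult::IO = sorter.next()? {
    //             io.run_once()?;
    //         }
    //     }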
    // We do the sorting here since this is what is called by the SorterSort instruction
    pub fn sort(&mut self) -> Result<IOResult<()>> {
        if self.chunks.is_empty() {
            self.records.sort();
            self.records.reverse();
        } else {
            self.flush()?;
            if let IOResult::IO = self.init_chunk_heap()? {
                return Ok(IOResult::IO);
            }
        }
        self.next()
    }

    pub fn next(&mut self) -> Result<IOResult<()>> {
        if self.chunks.is_empty() {
            // Serve from the in-memory buffer.
            self.current = self.records.pop().map(|r| r.record);
        } else {
            // Serve from sorted chunk files.
            match self.next_from_chunk_heap()? {
                IOResult::IO => return Ok(IOResult::IO),
                IOResult::Done(record) => self.current = record,
            }
        }
        Ok(IOResult::Done(()))
    }

    pub fn record(&self) -> Option<&ImmutableRecord> {
        self.current.as_ref()
    }

    pub fn insert(&mut self, record: &ImmutableRecord) -> Result<()> {
        let payload_size = record.get_payload().len();
        if self.current_buffer_size + payload_size > self.max_buffer_size {
            self.flush()?;
        }
        self.records.push(SortableImmutableRecord::new(
            record.clone(),
            self.key_len,
            self.index_key_info.clone(),
        ));
        self.current_buffer_size += payload_size;
        self.max_payload_size_in_buffer = self.max_payload_size_in_buffer.max(payload_size);
        Ok(())
    }

    fn init_chunk_heap(&mut self) -> Result<IOResult<()>> {
        let mut all_read_complete = true;
        // Make sure all chunks read at least one record into their buffer.
        for chunk in self.chunks.iter_mut() {
            match chunk.io_state.get() {
                SortedChunkIOState::WriteComplete => {
                    all_read_complete = false;
                    // Write complete, we can now read from the chunk.
                    chunk.read()?;
                }
                SortedChunkIOState::WaitingForWrite | SortedChunkIOState::WaitingForRead => {
                    all_read_complete = false;
                }
                SortedChunkIOState::ReadEOF | SortedChunkIOState::ReadComplete => {}
                _ => {
                    unreachable!("Unexpected chunk IO state: {:?}", chunk.io_state.get())
                }
            }
        }
        if !all_read_complete {
            return Ok(IOResult::IO);
        }

        self.chunk_heap.reserve(self.chunks.len());
        for chunk_idx in 0..self.chunks.len() {
            self.push_to_chunk_heap(chunk_idx)?;
        }
        Ok(IOResult::Done(()))
    }

    fn next_from_chunk_heap(&mut self) -> Result<IOResult<Option<ImmutableRecord>>> {
        let mut all_read_complete = true;
        for chunk_idx in self.wait_for_read_complete.iter() {
            let chunk_io_state = self.chunks[*chunk_idx].io_state.get();
            match chunk_io_state {
                SortedChunkIOState::ReadComplete | SortedChunkIOState::ReadEOF => {}
                SortedChunkIOState::WaitingForRead => {
                    all_read_complete = false;
                }
                _ => {
                    unreachable!("Unexpected chunk IO state: {:?}", chunk_io_state)
                }
            }
        }
        if !all_read_complete {
            return Ok(IOResult::IO);
        }
        self.wait_for_read_complete.clear();

        if let Some((next_record, next_chunk_idx)) = self.chunk_heap.pop() {
            self.push_to_chunk_heap(next_chunk_idx)?;
            Ok(IOResult::Done(Some(next_record.0.record)))
        } else {
            Ok(IOResult::Done(None))
        }
    }

    fn push_to_chunk_heap(&mut self, chunk_idx: usize) -> Result<()> {
        let chunk = &mut self.chunks[chunk_idx];
        if chunk.has_more() {
            let record = chunk.next()?.unwrap();
            self.chunk_heap.push((
                Reverse(SortableImmutableRecord::new(
                    record,
                    self.key_len,
                    self.index_key_info.clone(),
                )),
                chunk_idx,
            ));
            if let SortedChunkIOState::WaitingForRead = chunk.io_state.get() {
                self.wait_for_read_complete.push(chunk_idx);
            }
        }
        Ok(())
    }

    fn flush(&mut self) -> Result<()> {
        if self.records.is_empty() {
            return Ok(());
        }
        self.records.sort();

        if self.temp_dir.is_none() {
            self.temp_dir = Some(tempfile::tempdir().map_err(LimboError::IOError)?);
        }
        let chunk_file_path = self
            .temp_dir
            .as_ref()
            .unwrap()
            .path()
            .join(format!("chunk_{}", self.chunks.len()));
        let chunk_file =
            self.io
                .open_file(chunk_file_path.to_str().unwrap(), OpenFlags::Create, false)?;

        // Make sure the chunk buffer size can fit the largest record and its size varint.
        let chunk_buffer_size = self
            .min_chunk_read_buffer_size
            .max(self.max_payload_size_in_buffer + 9);
        let mut chunk = SortedChunk::new(chunk_file.clone(), chunk_buffer_size);
        chunk.write(&mut self.records)?;
        self.chunks.push(chunk);

        self.current_buffer_size = 0;
        self.max_payload_size_in_buffer = 0;
        Ok(())
    }
}
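// Each chunk file stores its records back to back, each prefixed with a
// varint-encoded payload length (see `SortedChunk::write` below):
//
//     +------------+---------+------------+---------+----
//     | varint len | payload | varint len | payload | ...
//     +------------+---------+------------+---------+----
//
// `SortedChunk::next` reverses this framing when decoding the read buffer; a
// partial varint or truncated payload at the end of the buffer simply means
// more bytes must be read from the file first.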
struct SortedChunk {
    /// The chunk file.
    file: Arc<dyn File>,
    /// The size of this chunk file in bytes.
    chunk_size: usize,
    /// The read buffer.
    buffer: Rc<RefCell<Vec<u8>>>,
    /// The current length of the buffer.
    buffer_len: Rc<Cell<usize>>,
    /// The records decoded from the chunk file.
    records: Vec<ImmutableRecord>,
    /// The current IO state of the chunk.
    io_state: Rc<Cell<SortedChunkIOState>>,
    /// The total number of bytes read from the chunk file.
    total_bytes_read: Rc<Cell<usize>>,
}

impl SortedChunk {
    fn new(file: Arc<dyn File>, buffer_size: usize) -> Self {
        Self {
            file,
            chunk_size: 0,
            buffer: Rc::new(RefCell::new(vec![0; buffer_size])),
            buffer_len: Rc::new(Cell::new(0)),
            records: Vec::new(),
            io_state: Rc::new(Cell::new(SortedChunkIOState::None)),
            total_bytes_read: Rc::new(Cell::new(0)),
        }
    }

    fn has_more(&self) -> bool {
        !self.records.is_empty() || self.io_state.get() != SortedChunkIOState::ReadEOF
    }

    fn next(&mut self) -> Result<Option<ImmutableRecord>> {
        let mut buffer_len = self.buffer_len.get();
        if self.records.is_empty() && buffer_len == 0 {
            return Ok(None);
        }
        if self.records.is_empty() {
            let mut buffer_ref = self.buffer.borrow_mut();
            let buffer = buffer_ref.as_mut_slice();

            let mut buffer_offset = 0;
            while buffer_offset < buffer_len {
                // Extract records from the buffer until we run out of the buffer or we hit an incomplete record.
                let (record_size, bytes_read) =
                    match read_varint(&buffer[buffer_offset..buffer_len]) {
                        Ok((record_size, bytes_read)) => (record_size as usize, bytes_read),
                        Err(LimboError::Corrupt(_))
                            if self.io_state.get() != SortedChunkIOState::ReadEOF =>
                        {
                            // Failed to decode a partial varint.
                            break;
                        }
                        Err(e) => {
                            return Err(e);
                        }
                    };
                if record_size > buffer_len - (buffer_offset + bytes_read) {
                    if self.io_state.get() == SortedChunkIOState::ReadEOF {
                        crate::bail_corrupt_error!("Incomplete record");
                    }
                    break;
                }
                buffer_offset += bytes_read;

                let mut record = ImmutableRecord::new(record_size);
                record.start_serialization(&buffer[buffer_offset..buffer_offset + record_size]);
                buffer_offset += record_size;

                self.records.push(record);
            }
            if buffer_offset < buffer_len {
                buffer.copy_within(buffer_offset..buffer_len, 0);
                buffer_len -= buffer_offset;
            } else {
                buffer_len = 0;
            }
            self.buffer_len.set(buffer_len);

            self.records.reverse();
        }

        let record = self.records.pop();
        if self.records.is_empty() && self.io_state.get() != SortedChunkIOState::ReadEOF {
            // We've consumed the last record. Read more payload into the buffer.
            self.read()?;
        }
        Ok(record)
    }
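    // `read()` below issues a single asynchronous `pread` and returns
    // immediately. The completion callback appends the freshly read bytes to
    // the tail of the stored buffer and advances `total_bytes_read`; a
    // zero-length read marks EOF. Because the callback runs outside of
    // `&mut self`, every field it mutates is held behind an
    // `Rc<Cell<_>>`/`Rc<RefCell<_>>` clone.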
    fn read(&mut self) -> Result<()> {
        if self.io_state.get() == SortedChunkIOState::ReadEOF {
            return Ok(());
        }
        if self.chunk_size - self.total_bytes_read.get() == 0 {
            self.io_state.set(SortedChunkIOState::ReadEOF);
            return Ok(());
        }
        self.io_state.set(SortedChunkIOState::WaitingForRead);

        let read_buffer_size = self.buffer.borrow().len() - self.buffer_len.get();
        let read_buffer_size = read_buffer_size.min(self.chunk_size - self.total_bytes_read.get());

        let drop_fn = Rc::new(|_buffer: BufferData| {});
        let read_buffer = Buffer::allocate(read_buffer_size, drop_fn);
        let read_buffer_ref = Arc::new(RefCell::new(read_buffer));

        let chunk_io_state_copy = self.io_state.clone();
        let stored_buffer_copy = self.buffer.clone();
        let stored_buffer_len_copy = self.buffer_len.clone();
        let total_bytes_read_copy = self.total_bytes_read.clone();
        let read_complete = Box::new(move |buf: Arc<RefCell<Buffer>>, bytes_read: i32| {
            let read_buf_ref = buf.borrow();
            let read_buf = read_buf_ref.as_slice();

            let bytes_read = bytes_read as usize;
            if bytes_read == 0 {
                chunk_io_state_copy.set(SortedChunkIOState::ReadEOF);
                return;
            }
            chunk_io_state_copy.set(SortedChunkIOState::ReadComplete);

            let mut stored_buf_ref = stored_buffer_copy.borrow_mut();
            let stored_buf = stored_buf_ref.as_mut_slice();
            let mut stored_buf_len = stored_buffer_len_copy.get();

            stored_buf[stored_buf_len..stored_buf_len + bytes_read]
                .copy_from_slice(&read_buf[..bytes_read]);
            stored_buf_len += bytes_read;
            stored_buffer_len_copy.set(stored_buf_len);
            total_bytes_read_copy.set(total_bytes_read_copy.get() + bytes_read);
        });

        let c = Completion::new(CompletionType::Read(ReadCompletion::new(
            read_buffer_ref,
            read_complete,
        )));
        self.file.pread(self.total_bytes_read.get(), Arc::new(c))?;
        Ok(())
    }

    fn write(&mut self, records: &mut Vec<SortableImmutableRecord>) -> Result<()> {
        assert!(self.io_state.get() == SortedChunkIOState::None);
        self.io_state.set(SortedChunkIOState::WaitingForWrite);

        self.chunk_size = 0;
        // Pre-compute varint lengths for record sizes to determine the total buffer size.
        let mut record_size_lengths = Vec::with_capacity(records.len());
        for record in records.iter() {
            let record_size = record.record.get_payload().len();
            let size_len = varint_len(record_size as u64);
            record_size_lengths.push(size_len);
            self.chunk_size += size_len + record_size;
        }

        let drop_fn = Rc::new(|_buffer: BufferData| {});
        let mut buffer = Buffer::allocate(self.chunk_size, drop_fn);
        let mut buf_pos = 0;
        let buf = buffer.as_mut_slice();
        for (record, size_len) in records.drain(..).zip(record_size_lengths) {
            let payload = record.record.get_payload();
            // Write the record size varint.
            write_varint(&mut buf[buf_pos..buf_pos + size_len], payload.len() as u64);
            buf_pos += size_len;
            // Write the record payload.
            buf[buf_pos..buf_pos + payload.len()].copy_from_slice(payload);
            buf_pos += payload.len();
        }

        let buffer_ref = Arc::new(RefCell::new(buffer));
        let buffer_ref_copy = buffer_ref.clone();
        let chunk_io_state_copy = self.io_state.clone();
        let write_complete = Box::new(move |bytes_written: i32| {
            chunk_io_state_copy.set(SortedChunkIOState::WriteComplete);
            let buf_len = buffer_ref_copy.borrow().len();
            if bytes_written < buf_len as i32 {
                tracing::error!("wrote({bytes_written}) less than expected({buf_len})");
            }
        });
        let c = Completion::new(CompletionType::Write(WriteCompletion::new(write_complete)));
        self.file.pwrite(0, buffer_ref, Arc::new(c))?;
        Ok(())
    }
}
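// Wrapper that makes records comparable by their first `key_len` values,
// honoring the per-column sort order and collation. Two usage details: the
// in-memory paths sort ascending and then reverse so that `Vec::pop` yields
// the smallest record first, and the merge path wraps records in
// `std::cmp::Reverse` so that `BinaryHeap`, a max-heap, behaves as a min-heap
// over `(record, chunk index)` pairs.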
struct SortableImmutableRecord {
    record: ImmutableRecord,
    key_len: usize,
    index_key_info: Rc<Vec<KeyInfo>>,
}

impl SortableImmutableRecord {
    fn new(record: ImmutableRecord, key_len: usize, index_key_info: Rc<Vec<KeyInfo>>) -> Self {
        Self {
            record,
            key_len,
            index_key_info,
        }
    }
}

impl Ord for SortableImmutableRecord {
    fn cmp(&self, other: &Self) -> Ordering {
        let this_values = self.record.get_values();
        let other_values = other.record.get_values();
        let a_key = if this_values.len() >= self.key_len {
            &this_values[..self.key_len]
        } else {
            &this_values[..]
        };
        let b_key = if other_values.len() >= self.key_len {
            &other_values[..self.key_len]
        } else {
            &other_values[..]
        };
        compare_immutable(a_key, b_key, self.index_key_info.as_ref())
    }
}

impl PartialOrd for SortableImmutableRecord {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl PartialEq for SortableImmutableRecord {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for SortableImmutableRecord {}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum SortedChunkIOState {
    WaitingForRead,
    ReadComplete,
    ReadEOF,
    WaitingForWrite,
    WriteComplete,
    None,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::translate::collate::CollationSeq;
    use crate::types::{ImmutableRecord, RefValue, Value, ValueType};
    use crate::PlatformIO;
    use rand_chacha::{
        rand_core::{RngCore, SeedableRng},
        ChaCha8Rng,
    };

    fn get_seed() -> u64 {
        std::env::var("SEED").map_or(
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_millis(),
            |v| {
                v.parse()
                    .expect("Failed to parse SEED environment variable as u64")
            },
        ) as u64
    }

    #[test]
    fn fuzz_external_sort() {
        let seed = get_seed();
        let mut rng = ChaCha8Rng::seed_from_u64(seed);

        let io = Arc::new(PlatformIO::new().unwrap());
        let mut sorter = Sorter::new(
            &[SortOrder::Asc],
            vec![CollationSeq::Binary],
            256,
            64,
            io.clone(),
        );

        let attempts = 8;
        for _ in 0..attempts {
            let num_records = 1000 + rng.next_u64() % 2000;
            let num_records = num_records as i64;
            let num_values = 1 + rng.next_u64() % 4;
            let value_types = generate_value_types(&mut rng, num_values as usize);

            let mut initial_records = Vec::with_capacity(num_records as usize);
            for i in (0..num_records).rev() {
                let mut values = vec![Value::Integer(i)];
                values.append(&mut generate_values(&mut rng, &value_types));
                let record = ImmutableRecord::from_values(&values, values.len());
                sorter.insert(&record).expect("Failed to insert the record");
                initial_records.push(record);
            }

            loop {
                if let IOResult::IO = sorter.sort().expect("Failed to sort the records") {
                    io.run_once().expect("Failed to run the IO");
                    continue;
                }
                break;
            }

            assert!(!sorter.is_empty());
            assert!(!sorter.chunks.is_empty());

            for i in 0..num_records {
                assert!(sorter.has_more());
                let record = sorter.record().unwrap();
                assert_eq!(record.get_values()[0], RefValue::Integer(i));
                // Check that the record remained unchanged after sorting.
                assert_eq!(record, &initial_records[(num_records - i - 1) as usize]);

                loop {
                    if let IOResult::IO = sorter.next().expect("Failed to get the next record") {
                        io.run_once().expect("Failed to run the IO");
                        continue;
                    }
                    break;
                }
            }
            assert!(!sorter.has_more());
        }
    }
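    // A minimal, deterministic sketch of the in-memory path: the buffer is
    // large enough that nothing is flushed to disk, so `sort()` completes
    // without any chunk IO. The buffer sizes here are arbitrary test values.
    #[test]
    fn sort_in_memory() {
        let io = Arc::new(PlatformIO::new().unwrap());
        let mut sorter = Sorter::new(
            &[SortOrder::Asc],
            vec![CollationSeq::Binary],
            1024 * 1024,
            64,
            io.clone(),
        );

        for i in [3i64, 1, 2] {
            let values = vec![Value::Integer(i)];
            let record = ImmutableRecord::from_values(&values, values.len());
            sorter.insert(&record).expect("Failed to insert the record");
        }

        loop {
            if let IOResult::IO = sorter.sort().expect("Failed to sort the records") {
                io.run_once().expect("Failed to run the IO");
                continue;
            }
            break;
        }
        // No flush happened, so everything is served from the in-memory buffer.
        assert!(sorter.chunks.is_empty());

        for i in 1..=3i64 {
            assert!(sorter.has_more());
            let record = sorter.record().unwrap();
            assert_eq!(record.get_values()[0], RefValue::Integer(i));
            loop {
                if let IOResult::IO = sorter.next().expect("Failed to get the next record") {
                    io.run_once().expect("Failed to run the IO");
                    continue;
                }
                break;
            }
        }
        assert!(!sorter.has_more());
    }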
    fn generate_value_types<R: RngCore>(rng: &mut R, num_values: usize) -> Vec<ValueType> {
        let mut value_types = Vec::with_capacity(num_values);
        for _ in 0..num_values {
            let value_type: ValueType = match rng.next_u64() % 4 {
                0 => ValueType::Integer,
                1 => ValueType::Float,
                2 => ValueType::Blob,
                3 => ValueType::Null,
                _ => unreachable!(),
            };
            value_types.push(value_type);
        }
        value_types
    }

    fn generate_values<R: RngCore>(rng: &mut R, value_types: &[ValueType]) -> Vec<Value> {
        let mut values = Vec::with_capacity(value_types.len());
        for value_type in value_types {
            let value = match value_type {
                ValueType::Integer => Value::Integer(rng.next_u64() as i64),
                ValueType::Float => {
                    let numerator = rng.next_u64() as f64;
                    let denominator = rng.next_u64() as f64;
                    Value::Float(numerator / denominator)
                }
                ValueType::Blob => {
                    // `fill_bytes` writes only up to the slice's length, so the vector
                    // must be sized up front rather than merely given capacity.
                    let mut blob = vec![0u8; (rng.next_u64() % 2047 + 1) as usize];
                    rng.fill_bytes(&mut blob);
                    Value::Blob(blob)
                }
                ValueType::Null => Value::Null,
                _ => unreachable!(),
            };
            values.push(value);
        }
        values
    }
}