diff --git a/core/vdbe/sorter.rs b/core/vdbe/sorter.rs
index d8469458c..105d8095d 100644
--- a/core/vdbe/sorter.rs
+++ b/core/vdbe/sorter.rs
@@ -1,10 +1,10 @@
 use turso_parser::ast::SortOrder;
-use std::cell::{Cell, RefCell};
+use std::cell::RefCell;
 use std::cmp::{Eq, Ord, Ordering, PartialEq, PartialOrd, Reverse};
 use std::collections::BinaryHeap;
 use std::rc::Rc;
-use std::sync::Arc;
+use std::sync::{atomic, Arc, RwLock};
 use tempfile;
 
 use crate::types::IOCompletions;
@@ -175,7 +175,10 @@ impl Sorter {
             SortState::InitHeap => {
                 turso_assert!(
                     !self.chunks.iter().any(|chunk| {
-                        matches!(chunk.io_state.get(), SortedChunkIOState::WaitingForWrite)
+                        matches!(
+                            *chunk.io_state.read().unwrap(),
+                            SortedChunkIOState::WaitingForWrite
+                        )
                     }),
                     "chunks should been written"
                 );
@@ -231,7 +234,10 @@ impl Sorter {
             InsertState::Insert => {
                 turso_assert!(
                     !self.chunks.iter().any(|chunk| {
-                        matches!(chunk.io_state.get(), SortedChunkIOState::WaitingForWrite)
+                        matches!(
+                            *chunk.io_state.read().unwrap(),
+                            SortedChunkIOState::WaitingForWrite
+                        )
                     }),
                     "chunks should have written"
                 );
@@ -272,7 +278,7 @@ impl Sorter {
         // Make sure all chunks read at least one record into their buffer.
         turso_assert!(
             !self.chunks.iter().any(|chunk| matches!(
-                chunk.io_state.get(),
+                *chunk.io_state.read().unwrap(),
                 SortedChunkIOState::WaitingForRead
             )),
             "chunks should have been read"
@@ -292,10 +298,10 @@ impl Sorter {
     fn next_from_chunk_heap(&mut self) -> Result<IOResult<Option<SortableImmutableRecord>>> {
         // Make sure all chunks read at least one record into their buffer.
         turso_assert!(
-            !self
-                .chunks
-                .iter()
-                .any(|chunk| matches!(chunk.io_state.get(), SortedChunkIOState::WaitingForRead)),
+            !self.chunks.iter().any(|chunk| matches!(
+                *chunk.io_state.read().unwrap(),
+                SortedChunkIOState::WaitingForRead
+            )),
             "chunks should have been read"
         );
@@ -394,15 +400,15 @@ struct SortedChunk {
     /// The size of this chunk file in bytes.
     chunk_size: usize,
     /// The read buffer.
-    buffer: Rc<RefCell<Vec<u8>>>,
+    buffer: Arc<RwLock<Vec<u8>>>,
     /// The current length of the buffer.
-    buffer_len: Rc<Cell<usize>>,
+    buffer_len: Arc<atomic::AtomicUsize>,
     /// The records decoded from the chunk file.
     records: Vec<SortableImmutableRecord>,
     /// The current IO state of the chunk.
-    io_state: Rc<Cell<SortedChunkIOState>>,
+    io_state: Arc<RwLock<SortedChunkIOState>>,
     /// The total number of bytes read from the chunk file.
-    total_bytes_read: Rc<Cell<usize>>,
+    total_bytes_read: Arc<atomic::AtomicUsize>,
     /// State machine for [SortedChunk::next]
     next_state: NextState,
 }
@@ -413,26 +419,34 @@ impl SortedChunk {
             file,
             start_offset: start_offset as u64,
             chunk_size: 0,
-            buffer: Rc::new(RefCell::new(vec![0; buffer_size])),
-            buffer_len: Rc::new(Cell::new(0)),
+            buffer: Arc::new(RwLock::new(vec![0; buffer_size])),
+            buffer_len: Arc::new(atomic::AtomicUsize::new(0)),
             records: Vec::new(),
-            io_state: Rc::new(Cell::new(SortedChunkIOState::None)),
-            total_bytes_read: Rc::new(Cell::new(0)),
+            io_state: Arc::new(RwLock::new(SortedChunkIOState::None)),
+            total_bytes_read: Arc::new(atomic::AtomicUsize::new(0)),
             next_state: NextState::Start,
         }
     }
 
+    fn buffer_len(&self) -> usize {
+        self.buffer_len.load(atomic::Ordering::SeqCst)
+    }
+
+    fn set_buffer_len(&self, len: usize) {
+        self.buffer_len.store(len, atomic::Ordering::SeqCst);
+    }
+
     fn next(&mut self) -> Result<IOResult<Option<SortableImmutableRecord>>> {
         loop {
             match self.next_state {
                 NextState::Start => {
-                    let mut buffer_len = self.buffer_len.get();
+                    let mut buffer_len = self.buffer_len();
                     if self.records.is_empty() && buffer_len == 0 {
                         return Ok(IOResult::Done(None));
                     }
                     if self.records.is_empty() {
-                        let mut buffer_ref = self.buffer.borrow_mut();
+                        let mut buffer_ref = self.buffer.write().unwrap();
                         let buffer = buffer_ref.as_mut_slice();
                         let mut buffer_offset = 0;
                         while buffer_offset < buffer_len {
@@ -443,7 +457,8 @@ impl SortedChunk {
                                     (record_size as usize, bytes_read)
                                 }
                                 Err(LimboError::Corrupt(_))
-                                    if self.io_state.get() != SortedChunkIOState::ReadEOF =>
+                                    if *self.io_state.read().unwrap()
+                                        != SortedChunkIOState::ReadEOF =>
                                 {
                                     // Failed to decode a partial varint.
                                     break;
@@ -453,7 +468,7 @@ impl SortedChunk {
                                 }
                             };
                             if record_size > buffer_len - (buffer_offset + bytes_read) {
-                                if self.io_state.get() == SortedChunkIOState::ReadEOF {
+                                if *self.io_state.read().unwrap() == SortedChunkIOState::ReadEOF {
                                     crate::bail_corrupt_error!("Incomplete record");
                                 }
                                 break;
@@ -474,18 +489,21 @@ impl SortedChunk {
                         } else {
                             buffer_len = 0;
                         }
-                        self.buffer_len.set(buffer_len);
+                        self.set_buffer_len(buffer_len);
                         self.records.reverse();
                     }
                     self.next_state = NextState::Finish;
 
                     // This check is done to see if we need to read more from the chunk before popping the record
-                    if self.records.len() == 1 && self.io_state.get() != SortedChunkIOState::ReadEOF
+                    if self.records.len() == 1
+                        && *self.io_state.read().unwrap() != SortedChunkIOState::ReadEOF
                     {
                         // We've consumed the last record. Read more payload into the buffer.
-                        if self.chunk_size - self.total_bytes_read.get() == 0 {
-                            self.io_state.set(SortedChunkIOState::ReadEOF);
+                        if self.chunk_size - self.total_bytes_read.load(atomic::Ordering::SeqCst)
+                            == 0
+                        {
+                            *self.io_state.write().unwrap() = SortedChunkIOState::ReadEOF;
                         } else {
                             let c = self.read()?;
                             io_yield_one!(c);
@@ -501,10 +519,11 @@ impl SortedChunk {
     }
 
     fn read(&mut self) -> Result<Completion> {
-        self.io_state.set(SortedChunkIOState::WaitingForRead);
+        *self.io_state.write().unwrap() = SortedChunkIOState::WaitingForRead;
 
-        let read_buffer_size = self.buffer.borrow().len() - self.buffer_len.get();
-        let read_buffer_size = read_buffer_size.min(self.chunk_size - self.total_bytes_read.get());
+        let read_buffer_size = self.buffer.read().unwrap().len() - self.buffer_len();
+        let read_buffer_size = read_buffer_size
+            .min(self.chunk_size - self.total_bytes_read.load(atomic::Ordering::SeqCst));
 
         let read_buffer = Buffer::new_temporary(read_buffer_size);
         let read_buffer_ref = Arc::new(read_buffer);
@@ -522,27 +541,28 @@ impl SortedChunk {
             let bytes_read = bytes_read as usize;
 
             if bytes_read == 0 {
-                chunk_io_state_copy.set(SortedChunkIOState::ReadEOF);
+                *chunk_io_state_copy.write().unwrap() = SortedChunkIOState::ReadEOF;
                 return;
             }
-            chunk_io_state_copy.set(SortedChunkIOState::ReadComplete);
+            *chunk_io_state_copy.write().unwrap() = SortedChunkIOState::ReadComplete;
 
-            let mut stored_buf_ref = stored_buffer_copy.borrow_mut();
+            let mut stored_buf_ref = stored_buffer_copy.write().unwrap();
             let stored_buf = stored_buf_ref.as_mut_slice();
 
-            let mut stored_buf_len = stored_buffer_len_copy.get();
+            let mut stored_buf_len = stored_buffer_len_copy.load(atomic::Ordering::SeqCst);
             stored_buf[stored_buf_len..stored_buf_len + bytes_read]
                 .copy_from_slice(&read_buf[..bytes_read]);
             stored_buf_len += bytes_read;
 
-            stored_buffer_len_copy.set(stored_buf_len);
-            total_bytes_read_copy.set(total_bytes_read_copy.get() + bytes_read);
+            stored_buffer_len_copy.store(stored_buf_len, atomic::Ordering::SeqCst);
+            total_bytes_read_copy.fetch_add(bytes_read, atomic::Ordering::SeqCst);
         });
 
         let c = Completion::new_read(read_buffer_ref, read_complete);
-        let c = self
-            .file
-            .pread(self.start_offset + self.total_bytes_read.get() as u64, c)?;
+        let c = self.file.pread(
+            self.start_offset + self.total_bytes_read.load(atomic::Ordering::SeqCst) as u64,
+            c,
+        )?;
 
         Ok(c)
     }
@@ -552,8 +572,8 @@ impl SortedChunk {
         record_size_lengths: Vec<usize>,
         chunk_size: usize,
     ) -> Result<Completion> {
-        assert!(self.io_state.get() == SortedChunkIOState::None);
-        self.io_state.set(SortedChunkIOState::WaitingForWrite);
+        assert!(*self.io_state.read().unwrap() == SortedChunkIOState::None);
+        *self.io_state.write().unwrap() = SortedChunkIOState::WaitingForWrite;
         self.chunk_size = chunk_size;
 
         let buffer = Buffer::new_temporary(self.chunk_size);
@@ -578,7 +598,7 @@ impl SortedChunk {
             let Ok(bytes_written) = res else {
                 return;
             };
-            chunk_io_state_copy.set(SortedChunkIOState::WriteComplete);
+            *chunk_io_state_copy.write().unwrap() = SortedChunkIOState::WriteComplete;
             let buf_len = buffer_ref_copy.len();
             if bytes_written < buf_len as i32 {
                 tracing::error!("wrote({bytes_written}) less than expected({buf_len})");
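
Reviewer note: the pattern this patch applies throughout is swapping single-threaded interior mutability (`Rc<Cell<_>>`, `Rc<RefCell<_>>`) for thread-safe equivalents (`Arc<atomic::AtomicUsize>`, `Arc<RwLock<_>>`), so the state captured by the I/O completion callbacks is `Send + Sync`. A minimal standalone sketch of the same pattern, assuming a simulated callback on a spawned thread; the `IoState` enum and the names here are illustrative stand-ins, not part of this patch:

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, RwLock};
use std::thread;

// Hypothetical stand-in for SortedChunkIOState; illustrative only.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum IoState {
    WaitingForRead,
    ReadComplete,
}

fn main() {
    // Shared chunk state, mirroring the patched SortedChunk fields:
    // an RwLock for the enum, an AtomicUsize for the byte counter.
    // Both are Send + Sync, so a completion callback can run on an
    // I/O thread while the sorter polls the state from another.
    let io_state = Arc::new(RwLock::new(IoState::WaitingForRead));
    let total_bytes_read = Arc::new(AtomicUsize::new(0));

    // Clones moved into the callback, like the *_copy handles in the diff.
    let io_state_copy = Arc::clone(&io_state);
    let total_copy = Arc::clone(&total_bytes_read);

    // Simulated completion callback (analogous to read_complete above).
    let completion = thread::spawn(move || {
        let bytes_read = 512usize;
        total_copy.fetch_add(bytes_read, Ordering::SeqCst);
        *io_state_copy.write().unwrap() = IoState::ReadComplete;
    });

    completion.join().unwrap();
    assert_eq!(*io_state.read().unwrap(), IoState::ReadComplete);
    assert_eq!(total_bytes_read.load(Ordering::SeqCst), 512);
}
```

`SeqCst`, as used in the patch, is the most conservative memory ordering; a plain counter like this could likely use a weaker ordering, but matching the patch keeps the reasoning simple.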