// Review summary of this patch to core/vdbe/sorter.rs.
//
// Form: a unified diff whose newlines have been collapsed, so several hunks share each physical line.
// NOTE(review): generic type arguments look stripped in this copy (e.g. `Result> {`, `RefCell>,`),
// so exact signatures cannot be read from here — confirm against the real file.
//
// Changes visible in the hunks:
// - Imports: `compare_immutable` is dropped; `turso_assert`, `RecordCursor` and `RefValue` are added.
// - `Sorter::next`: keeps the popped entry (from `self.records` or `next_from_chunk_heap`) whole,
//   takes any error out of its `deserialization_error` cell and returns it as `Err`; otherwise it
//   stores `record.record` in `self.current` (or `None` when nothing was popped).
// - `next_from_chunk_heap`: yields `next_record.0` instead of `next_record.0.record`.
// - `SortableImmutableRecord`: the `key_len` field is replaced by `cursor` and `key_values`, and a
//   `deserialization_error` cell is added. `new` becomes fallible (its call sites gain `?`): it
//   builds `RecordCursor::with_capacity(key_len)`, calls `ensure_parsed_upto(&record, key_len - 1)?`
//   and asserts `index_key_info.len() >= cursor.serial_types.len()`.
// - `try_deserialize_key`: does nothing when `idx < key_values.len()`; otherwise pushes the value
//   returned by `deserialize_column`, or stores the error in `deserialization_error`.
// - `Ord::cmp`: returns `Ordering::Equal` as soon as either side holds a deserialization error.
//   Otherwise it walks `cursor.serial_types.len()` key columns, deserializing lazily, comparing
//   `RefValue::Text` pairs with `collation.compare_strings` and everything else with `partial_cmp`,
//   reversing the result for `SortOrder::Desc`. The first non-equal column decides.
//
// NOTE(review): `key_len - 1` in `new` has no guard for `key_len == 0` — confirm callers never pass 0.
// NOTE(review): `partial_cmp(..).unwrap()` in `cmp` panics if `RefValue`'s `partial_cmp` can return
// `None` — confirm it is total for every pair of variants that can reach this branch.
// NOTE(review): `cmp` uses a plain `assert_eq!` on the two `serial_types` lengths while `new` uses
// `turso_assert!` — presumably intentional; verify.
// NOTE(review): once an error is recorded, `cmp` reports `Equal` for that record, so its position in
// the sort is unreliable; the visible code only surfaces the error when that record is later popped
// in `Sorter::next` — confirm that is acceptable.
diff --git a/core/vdbe/sorter.rs b/core/vdbe/sorter.rs index 07e8088c3..021256d2a 100644 --- a/core/vdbe/sorter.rs +++ b/core/vdbe/sorter.rs @@ -15,7 +15,8 @@ use crate::{ }, storage::sqlite3_ondisk::{read_varint, varint_len, write_varint}, translate::collate::CollationSeq, - types::{compare_immutable, IOResult, ImmutableRecord, KeyInfo}, + turso_assert, + types::{IOResult, ImmutableRecord, KeyInfo, RecordCursor, RefValue}, Result, }; @@ -107,15 +108,25 @@ impl Sorter { } pub fn next(&mut self) -> Result> { - if self.chunks.is_empty() { + let record = if self.chunks.is_empty() { // Serve from the in-memory buffer. - self.current = self.records.pop().map(|r| r.record); + self.records.pop() } else { // Serve from sorted chunk files. match self.next_from_chunk_heap()? { IOResult::IO => return Ok(IOResult::IO), - IOResult::Done(record) => self.current = record, + IOResult::Done(record) => record, } + }; + match record { + Some(record) => { + if let Some(error) = record.deserialization_error.replace(None) { + // If there was a key deserialization error during the comparison, return the error. 
+ return Err(error); + } + self.current = Some(record.record); + } + None => self.current = None, } Ok(IOResult::Done(())) } @@ -133,7 +144,7 @@ impl Sorter { record.clone(), self.key_len, self.index_key_info.clone(), - )); + )?); self.current_buffer_size += payload_size; self.max_payload_size_in_buffer = self.max_payload_size_in_buffer.max(payload_size); Ok(()) @@ -168,7 +179,7 @@ impl Sorter { Ok(IOResult::Done(())) } - fn next_from_chunk_heap(&mut self) -> Result>> { + fn next_from_chunk_heap(&mut self) -> Result>> { let mut all_read_complete = true; for chunk_idx in self.wait_for_read_complete.iter() { let chunk_io_state = self.chunks[*chunk_idx].io_state.get(); @@ -189,7 +200,7 @@ impl Sorter { if let Some((next_record, next_chunk_idx)) = self.chunk_heap.pop() { self.push_to_chunk_heap(next_chunk_idx)?; - Ok(IOResult::Done(Some(next_record.0.record))) + Ok(IOResult::Done(Some(next_record.0))) } else { Ok(IOResult::Done(None)) } @@ -205,7 +216,7 @@ impl Sorter { record, self.key_len, self.index_key_info.clone(), - )), + )?), chunk_idx, )); if let SortedChunkIOState::WaitingForRead = chunk.io_state.get() { @@ -444,38 +455,102 @@ impl SortedChunk { struct SortableImmutableRecord { record: ImmutableRecord, - key_len: usize, + cursor: RecordCursor, + key_values: RefCell>, index_key_info: Rc>, + /// The key deserialization error, if any. 
+ deserialization_error: RefCell>, } impl SortableImmutableRecord { - fn new(record: ImmutableRecord, key_len: usize, index_key_info: Rc>) -> Self { - Self { + fn new( + record: ImmutableRecord, + key_len: usize, + index_key_info: Rc>, + ) -> Result { + let mut cursor = RecordCursor::with_capacity(key_len); + cursor.ensure_parsed_upto(&record, key_len - 1)?; + turso_assert!( + index_key_info.len() >= cursor.serial_types.len(), + "index_key_info.len() < cursor.serial_types.len()" + ); + Ok(Self { record, - key_len, + cursor, + key_values: RefCell::new(Vec::with_capacity(key_len)), index_key_info, + deserialization_error: RefCell::new(None), + }) + } + + /// Attempts to deserialize the key value at the given index. + /// If the key value has already been deserialized, this does nothing. + /// The deserialized key value is stored in the `key_values` field. + /// In case of an error, the error is stored in the `deserialization_error` field. + fn try_deserialize_key(&self, idx: usize) { + let mut key_values = self.key_values.borrow_mut(); + if idx < key_values.len() { + // The key value with this index has already been deserialized. + return; + } + match self.cursor.deserialize_column(&self.record, idx) { + Ok(value) => key_values.push(value), + Err(error) => { + self.deserialization_error.replace(Some(error)); + } } } } impl Ord for SortableImmutableRecord { fn cmp(&self, other: &Self) -> Ordering { - let this_values = self.record.get_values(); - let other_values = other.record.get_values(); + if self.deserialization_error.borrow().is_some() + || other.deserialization_error.borrow().is_some() + { + // If one of the records has a deserialization error, circumvent the comparison and return early. 
+ return Ordering::Equal; + } + assert_eq!( + self.cursor.serial_types.len(), + other.cursor.serial_types.len() + ); + let this_key_values_len = self.key_values.borrow().len(); + let other_key_values_len = other.key_values.borrow().len(); - let a_key = if this_values.len() >= self.key_len { - &this_values[..self.key_len] - } else { - &this_values[..] - }; + for i in 0..self.cursor.serial_types.len() { + // Lazily deserialize the key values if they haven't been deserialized already. + if i >= this_key_values_len { + self.try_deserialize_key(i); + if self.deserialization_error.borrow().is_some() { + return Ordering::Equal; + } + } + if i >= other_key_values_len { + other.try_deserialize_key(i); + if other.deserialization_error.borrow().is_some() { + return Ordering::Equal; + } + } - let b_key = if other_values.len() >= self.key_len { - &other_values[..self.key_len] - } else { - &other_values[..] - }; + let this_key_value = &self.key_values.borrow()[i]; + let other_key_value = &other.key_values.borrow()[i]; + let column_order = self.index_key_info[i].sort_order; + let collation = self.index_key_info[i].collation; - compare_immutable(a_key, b_key, self.index_key_info.as_ref()) + let cmp = match (this_key_value, other_key_value) { + (RefValue::Text(left), RefValue::Text(right)) => { + collation.compare_strings(left.as_str(), right.as_str()) + } + _ => this_key_value.partial_cmp(other_key_value).unwrap(), + }; + if !cmp.is_eq() { + return match column_order { + SortOrder::Asc => cmp, + SortOrder::Desc => cmp.reverse(), + }; + } + } + std::cmp::Ordering::Equal } }