Sorter: compute values upfront instead of deserializing on every comparison

This commit is contained in:
Jussi Saurio
2025-10-09 15:01:47 +03:00
parent 7948259d37
commit e0461dd78a

View File

@@ -1,6 +1,5 @@
use turso_parser::ast::SortOrder;
use std::cell::{RefCell, UnsafeCell};
use std::cmp::{Eq, Ord, Ordering, PartialEq, PartialOrd, Reverse};
use std::collections::BinaryHeap;
use std::rc::Rc;
@@ -204,7 +203,7 @@ impl Sorter {
};
match record {
Some(record) => {
if let Some(error) = record.deserialization_error.replace(None) {
if let Some(error) = record.deserialization_error {
// If there was a key deserialization error during the comparison, return the error.
return Err(error);
}
@@ -614,12 +613,14 @@ impl SortedChunk {
struct SortableImmutableRecord {
record: ImmutableRecord,
cursor: RecordCursor,
// SAFETY: borrows from 'self
key_values: UnsafeCell<Vec<ValueRef<'static>>>,
/// SAFETY: borrows from self
/// These are precomputed on record construction so that they can be reused during
/// sorting comparisons.
key_values: Vec<ValueRef<'static>>,
key_len: usize,
index_key_info: Rc<Vec<KeyInfo>>,
/// The key deserialization error, if any.
deserialization_error: RefCell<Option<LimboError>>,
deserialization_error: Option<LimboError>,
}
impl SortableImmutableRecord {
@@ -634,45 +635,40 @@ impl SortableImmutableRecord {
index_key_info.len() >= cursor.serial_types.len(),
"index_key_info.len() < cursor.serial_types.len()"
);
// Pre-compute all key values upfront
let mut key_values = Vec::with_capacity(key_len);
let mut deserialization_error = None;
for i in 0..key_len {
match cursor.deserialize_column(&record, i) {
Ok(value) => {
// SAFETY: We're storing the value with 'static lifetime but it's actually bound to the record
// This is safe because the record lives as long as this struct
let value: ValueRef<'static> = unsafe { std::mem::transmute(value) };
key_values.push(value);
}
Err(err) => {
deserialization_error = Some(err);
break;
}
}
}
Ok(Self {
record,
cursor,
key_values: UnsafeCell::new(Vec::with_capacity(key_len)),
key_values,
index_key_info,
deserialization_error: RefCell::new(None),
deserialization_error,
key_len,
})
}
fn key_value<'a>(&'a self, i: usize) -> Option<ValueRef<'a>> {
// SAFETY: there are no other active references
let key_values = unsafe { &mut *self.key_values.get() };
if i >= key_values.len() {
assert_eq!(key_values.len(), i, "access must be sequential");
let value = match self.cursor.deserialize_column(&self.record, i) {
Ok(value) => value,
Err(err) => {
self.deserialization_error.replace(Some(err));
return None;
}
};
// SAFETY: no 'static lifetime is exposed, all references are bound to 'self
let value: ValueRef<'static> = unsafe { std::mem::transmute(value) };
key_values.push(value);
}
Some(key_values[i])
}
}
impl Ord for SortableImmutableRecord {
fn cmp(&self, other: &Self) -> Ordering {
if self.deserialization_error.borrow().is_some()
|| other.deserialization_error.borrow().is_some()
{
if self.deserialization_error.is_some() || other.deserialization_error.is_some() {
// If one of the records has a deserialization error, circumvent the comparison and return early.
return Ordering::Equal;
}
@@ -682,13 +678,8 @@ impl Ord for SortableImmutableRecord {
);
for i in 0..self.key_len {
let Some(this_key_value) = self.key_value(i) else {
return Ordering::Equal;
};
let Some(other_key_value) = other.key_value(i) else {
return Ordering::Equal;
};
let this_key_value = self.key_values[i];
let other_key_value = other.key_values[i];
let column_order = self.index_key_info[i].sort_order;
let collation = self.index_key_info[i].collation;