mirror of
https://github.com/aljazceru/turso.git
synced 2026-01-15 14:14:20 +01:00
chnage RecordCompare::compare to use an iterator
This commit is contained in:
@@ -1723,14 +1723,20 @@ pub enum RecordCompare {
|
||||
}
|
||||
|
||||
impl RecordCompare {
|
||||
pub fn compare(
|
||||
pub fn compare<V, E, I>(
|
||||
&self,
|
||||
serialized: &ImmutableRecord,
|
||||
unpacked: &[ValueRef],
|
||||
unpacked: I,
|
||||
index_info: &IndexInfo,
|
||||
skip: usize,
|
||||
tie_breaker: std::cmp::Ordering,
|
||||
) -> Result<std::cmp::Ordering> {
|
||||
) -> Result<std::cmp::Ordering>
|
||||
where
|
||||
V: AsValueRef,
|
||||
E: ExactSizeIterator<Item = V>,
|
||||
I: IntoIterator<IntoIter = E, Item = E::Item>,
|
||||
{
|
||||
let unpacked = unpacked.into_iter();
|
||||
match self {
|
||||
RecordCompare::Int => {
|
||||
compare_records_int(serialized, unpacked, index_info, tie_breaker)
|
||||
@@ -1813,12 +1819,16 @@ pub fn get_tie_breaker_from_seek_op(seek_op: SeekOp) -> std::cmp::Ordering {
|
||||
/// 4. **Sort order**: Applies ascending/descending order to comparison result
|
||||
/// 5. **Remaining fields**: If first field is equal and more fields exist,
|
||||
/// delegates to `compare_records_generic()` with `skip=1`
|
||||
fn compare_records_int(
|
||||
fn compare_records_int<V, I>(
|
||||
serialized: &ImmutableRecord,
|
||||
unpacked: &[ValueRef],
|
||||
unpacked: I,
|
||||
index_info: &IndexInfo,
|
||||
tie_breaker: std::cmp::Ordering,
|
||||
) -> Result<std::cmp::Ordering> {
|
||||
) -> Result<std::cmp::Ordering>
|
||||
where
|
||||
V: AsValueRef,
|
||||
I: ExactSizeIterator<Item = V>,
|
||||
{
|
||||
let payload = serialized.get_payload();
|
||||
if payload.len() < 2 {
|
||||
return compare_records_generic(serialized, unpacked, index_info, 0, tie_breaker);
|
||||
@@ -1845,7 +1855,9 @@ fn compare_records_int(
|
||||
let data_start = header_size;
|
||||
|
||||
let lhs_int = read_integer(&payload[data_start..], first_serial_type as u8)?;
|
||||
let ValueRef::Integer(rhs_int) = unpacked[0] else {
|
||||
let mut unpacked = unpacked.peekable();
|
||||
// Do not consume iterator here
|
||||
let ValueRef::Integer(rhs_int) = unpacked.peek().unwrap().as_value_ref() else {
|
||||
return compare_records_generic(serialized, unpacked, index_info, 0, tie_breaker);
|
||||
};
|
||||
let comparison = match index_info.key_info[0].sort_order {
|
||||
@@ -1902,12 +1914,16 @@ fn compare_records_int(
|
||||
/// 4. **Length comparison**: If strings are equal, compares lengths
|
||||
/// 5. **Remaining fields**: If first field is equal and more fields exist,
|
||||
/// delegates to `compare_records_generic()` with `skip=1`
|
||||
fn compare_records_string(
|
||||
fn compare_records_string<V, I>(
|
||||
serialized: &ImmutableRecord,
|
||||
unpacked: &[ValueRef],
|
||||
unpacked: I,
|
||||
index_info: &IndexInfo,
|
||||
tie_breaker: std::cmp::Ordering,
|
||||
) -> Result<std::cmp::Ordering> {
|
||||
) -> Result<std::cmp::Ordering>
|
||||
where
|
||||
V: AsValueRef,
|
||||
I: ExactSizeIterator<Item = V>,
|
||||
{
|
||||
let payload = serialized.get_payload();
|
||||
if payload.len() < 2 {
|
||||
return compare_records_generic(serialized, unpacked, index_info, 0, tie_breaker);
|
||||
@@ -1931,7 +1947,9 @@ fn compare_records_string(
|
||||
return compare_records_generic(serialized, unpacked, index_info, 0, tie_breaker);
|
||||
}
|
||||
|
||||
let ValueRef::Text(rhs_text) = &unpacked[0] else {
|
||||
let mut unpacked = unpacked.peekable();
|
||||
|
||||
let ValueRef::Text(rhs_text) = unpacked.peek().unwrap().as_value_ref() else {
|
||||
return compare_records_generic(serialized, unpacked, index_info, 0, tie_breaker);
|
||||
};
|
||||
|
||||
@@ -1948,7 +1966,7 @@ fn compare_records_string(
|
||||
};
|
||||
|
||||
let collation = index_info.key_info[0].collation;
|
||||
let comparison = collation.compare_strings(&lhs_text, rhs_text);
|
||||
let comparison = collation.compare_strings(&lhs_text, &rhs_text);
|
||||
|
||||
let final_comparison = match index_info.key_info[0].sort_order {
|
||||
SortOrder::Asc => comparison,
|
||||
@@ -2007,13 +2025,17 @@ fn compare_records_string(
|
||||
/// The serialized and unpacked records do not have to contain the same number
|
||||
/// of fields. If all fields that appear in both records are equal, then
|
||||
/// `tie_breaker` is returned.
|
||||
pub fn compare_records_generic(
|
||||
pub fn compare_records_generic<V, I>(
|
||||
serialized: &ImmutableRecord,
|
||||
unpacked: &[ValueRef],
|
||||
unpacked: I,
|
||||
index_info: &IndexInfo,
|
||||
skip: usize,
|
||||
tie_breaker: std::cmp::Ordering,
|
||||
) -> Result<std::cmp::Ordering> {
|
||||
) -> Result<std::cmp::Ordering>
|
||||
where
|
||||
V: AsValueRef,
|
||||
I: ExactSizeIterator<Item = V>,
|
||||
{
|
||||
let payload = serialized.get_payload();
|
||||
if payload.is_empty() {
|
||||
return Ok(std::cmp::Ordering::Less);
|
||||
@@ -2042,15 +2064,21 @@ pub fn compare_records_generic(
|
||||
data_pos += serial_type.size();
|
||||
}
|
||||
}
|
||||
// assumes that that the `unpacked' iterator was not skipped outside this function call`
|
||||
let mut unpacked = unpacked.skip(skip);
|
||||
|
||||
let mut field_idx = skip;
|
||||
let field_limit = unpacked.len().min(index_info.key_info.len());
|
||||
while field_idx < field_limit && header_pos < header_end {
|
||||
|
||||
while let Some(rhs_value) = unpacked.next() {
|
||||
let rhs_value = &rhs_value.as_value_ref();
|
||||
if field_idx >= field_limit || header_pos >= header_end {
|
||||
break;
|
||||
}
|
||||
let (serial_type_raw, bytes_read) = read_varint(&payload[header_pos..])?;
|
||||
header_pos += bytes_read;
|
||||
|
||||
let serial_type = SerialType::try_from(serial_type_raw)?;
|
||||
let rhs_value = &unpacked[field_idx];
|
||||
|
||||
let lhs_value = match serial_type.kind() {
|
||||
SerialTypeKind::ConstInt0 => ValueRef::Integer(0),
|
||||
@@ -2729,12 +2757,17 @@ mod tests {
|
||||
"Test '{test_name}' failed: Full Comparison: {gold_result:?}, Optimized: {optimized_result:?}, Strategy: {comparer:?}"
|
||||
);
|
||||
|
||||
let generic_result =
|
||||
compare_records_generic(&serialized, &unpacked_values, index_info, 0, tie_breaker)
|
||||
.unwrap();
|
||||
let generic_result = compare_records_generic(
|
||||
&serialized,
|
||||
unpacked_values.iter(),
|
||||
index_info,
|
||||
0,
|
||||
tie_breaker,
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
gold_result, generic_result,
|
||||
"Test '{test_name}' failed with generic: Full Comparison: {gold_result:?}, Generic: {generic_result:?}"
|
||||
"Test '{test_name}' failed with generic: Full Comparison: {gold_result:?}, Generic: {generic_result:?}\n LHS: {serialized_values:?}\n RHS: {unpacked_values:?}"
|
||||
);
|
||||
}
|
||||
|
||||
@@ -2859,6 +2892,9 @@ mod tests {
|
||||
];
|
||||
|
||||
for (serialized_values, unpacked_values, test_name) in test_cases {
|
||||
println!(
|
||||
"Testing integer fast path `{test_name}`\nLHS: {serialized_values:?}\nRHS: {unpacked_values:?}"
|
||||
);
|
||||
assert_compare_matches_full_comparison(
|
||||
serialized_values,
|
||||
unpacked_values,
|
||||
@@ -3131,9 +3167,11 @@ mod tests {
|
||||
|
||||
let tie_breaker = std::cmp::Ordering::Equal;
|
||||
let result_skip_0 =
|
||||
compare_records_generic(&serialized, &unpacked, &index_info, 0, tie_breaker).unwrap();
|
||||
compare_records_generic(&serialized, unpacked.iter(), &index_info, 0, tie_breaker)
|
||||
.unwrap();
|
||||
let result_skip_1 =
|
||||
compare_records_generic(&serialized, &unpacked, &index_info, 1, tie_breaker).unwrap();
|
||||
compare_records_generic(&serialized, unpacked.iter(), &index_info, 1, tie_breaker)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result_skip_0, std::cmp::Ordering::Less);
|
||||
|
||||
|
||||
@@ -3452,7 +3452,7 @@ pub fn op_idx_ge(
|
||||
let tie_breaker = get_tie_breaker_from_idx_comp_op(insn);
|
||||
let ord = compare_records_generic(
|
||||
&idx_record, // The serialized record from the index
|
||||
&values, // The record built from registers
|
||||
values, // The record built from registers
|
||||
cursor.get_index_info(), // Sort order flags
|
||||
0,
|
||||
tie_breaker,
|
||||
@@ -3520,7 +3520,7 @@ pub fn op_idx_le(
|
||||
let tie_breaker = get_tie_breaker_from_idx_comp_op(insn);
|
||||
let ord = compare_records_generic(
|
||||
&idx_record,
|
||||
&values,
|
||||
values,
|
||||
cursor.get_index_info(),
|
||||
0,
|
||||
tie_breaker,
|
||||
@@ -3571,7 +3571,7 @@ pub fn op_idx_gt(
|
||||
let tie_breaker = get_tie_breaker_from_idx_comp_op(insn);
|
||||
let ord = compare_records_generic(
|
||||
&idx_record,
|
||||
&values,
|
||||
values,
|
||||
cursor.get_index_info(),
|
||||
0,
|
||||
tie_breaker,
|
||||
@@ -3623,7 +3623,7 @@ pub fn op_idx_lt(
|
||||
let tie_breaker = get_tie_breaker_from_idx_comp_op(insn);
|
||||
let ord = compare_records_generic(
|
||||
&idx_record,
|
||||
&values,
|
||||
values,
|
||||
cursor.get_index_info(),
|
||||
0,
|
||||
tie_breaker,
|
||||
|
||||
@@ -1092,11 +1092,10 @@ fn make_record(registers: &[Register], start_reg: &usize, count: &usize) -> Immu
|
||||
ImmutableRecord::from_registers(regs, regs.len())
|
||||
}
|
||||
|
||||
pub fn registers_to_ref_values<'a>(registers: &'a [Register]) -> Vec<ValueRef<'a>> {
|
||||
registers
|
||||
.iter()
|
||||
.map(|reg| reg.get_value().as_ref())
|
||||
.collect()
|
||||
pub fn registers_to_ref_values<'a>(
|
||||
registers: &'a [Register],
|
||||
) -> impl ExactSizeIterator<Item = ValueRef<'a>> {
|
||||
registers.iter().map(|reg| reg.get_value().as_ref())
|
||||
}
|
||||
|
||||
#[instrument(skip(program), level = Level::DEBUG)]
|
||||
|
||||
Reference in New Issue
Block a user