From ce3527df406cb34ebbfba5cb195b217d73eed57a Mon Sep 17 00:00:00 2001 From: pedrocarlo Date: Wed, 5 Nov 2025 16:40:42 -0300 Subject: [PATCH] change `RecordCompare::compare` to use an iterator --- core/types.rs | 84 ++++++++++++++++++++++++++++++++------------ core/vdbe/execute.rs | 8 ++--- core/vdbe/mod.rs | 9 +++-- 3 files changed, 69 insertions(+), 32 deletions(-) diff --git a/core/types.rs b/core/types.rs index c842205b1..f9442b9c4 100644 --- a/core/types.rs +++ b/core/types.rs @@ -1723,14 +1723,20 @@ pub enum RecordCompare { } impl RecordCompare { - pub fn compare( + pub fn compare( &self, serialized: &ImmutableRecord, - unpacked: &[ValueRef], + unpacked: I, index_info: &IndexInfo, skip: usize, tie_breaker: std::cmp::Ordering, - ) -> Result { + ) -> Result + where + V: AsValueRef, + E: ExactSizeIterator, + I: IntoIterator, + { + let unpacked = unpacked.into_iter(); match self { RecordCompare::Int => { compare_records_int(serialized, unpacked, index_info, tie_breaker) @@ -1813,12 +1819,16 @@ pub fn get_tie_breaker_from_seek_op(seek_op: SeekOp) -> std::cmp::Ordering { /// 4. **Sort order**: Applies ascending/descending order to comparison result /// 5. 
**Remaining fields**: If first field is equal and more fields exist, /// delegates to `compare_records_generic()` with `skip=1` -fn compare_records_int( +fn compare_records_int( serialized: &ImmutableRecord, - unpacked: &[ValueRef], + unpacked: I, index_info: &IndexInfo, tie_breaker: std::cmp::Ordering, -) -> Result { +) -> Result +where + V: AsValueRef, + I: ExactSizeIterator, +{ let payload = serialized.get_payload(); if payload.len() < 2 { return compare_records_generic(serialized, unpacked, index_info, 0, tie_breaker); @@ -1845,7 +1855,9 @@ fn compare_records_int( let data_start = header_size; let lhs_int = read_integer(&payload[data_start..], first_serial_type as u8)?; - let ValueRef::Integer(rhs_int) = unpacked[0] else { + let mut unpacked = unpacked.peekable(); + // Do not consume iterator here + let ValueRef::Integer(rhs_int) = unpacked.peek().unwrap().as_value_ref() else { return compare_records_generic(serialized, unpacked, index_info, 0, tie_breaker); }; let comparison = match index_info.key_info[0].sort_order { @@ -1902,12 +1914,16 @@ fn compare_records_int( /// 4. **Length comparison**: If strings are equal, compares lengths /// 5. 
**Remaining fields**: If first field is equal and more fields exist, /// delegates to `compare_records_generic()` with `skip=1` -fn compare_records_string( +fn compare_records_string( serialized: &ImmutableRecord, - unpacked: &[ValueRef], + unpacked: I, index_info: &IndexInfo, tie_breaker: std::cmp::Ordering, -) -> Result { +) -> Result +where + V: AsValueRef, + I: ExactSizeIterator, +{ let payload = serialized.get_payload(); if payload.len() < 2 { return compare_records_generic(serialized, unpacked, index_info, 0, tie_breaker); @@ -1931,7 +1947,9 @@ fn compare_records_string( return compare_records_generic(serialized, unpacked, index_info, 0, tie_breaker); } - let ValueRef::Text(rhs_text) = &unpacked[0] else { + let mut unpacked = unpacked.peekable(); + + let ValueRef::Text(rhs_text) = unpacked.peek().unwrap().as_value_ref() else { return compare_records_generic(serialized, unpacked, index_info, 0, tie_breaker); }; @@ -1948,7 +1966,7 @@ fn compare_records_string( }; let collation = index_info.key_info[0].collation; - let comparison = collation.compare_strings(&lhs_text, rhs_text); + let comparison = collation.compare_strings(&lhs_text, &rhs_text); let final_comparison = match index_info.key_info[0].sort_order { SortOrder::Asc => comparison, @@ -2007,13 +2025,17 @@ fn compare_records_string( /// The serialized and unpacked records do not have to contain the same number /// of fields. If all fields that appear in both records are equal, then /// `tie_breaker` is returned. 
-pub fn compare_records_generic( +pub fn compare_records_generic( serialized: &ImmutableRecord, - unpacked: &[ValueRef], + unpacked: I, index_info: &IndexInfo, skip: usize, tie_breaker: std::cmp::Ordering, -) -> Result { +) -> Result +where + V: AsValueRef, + I: ExactSizeIterator, +{ let payload = serialized.get_payload(); if payload.is_empty() { return Ok(std::cmp::Ordering::Less); @@ -2042,15 +2064,21 @@ pub fn compare_records_generic( data_pos += serial_type.size(); } } + // assumes that the `unpacked` iterator was not skipped outside this function call + let mut unpacked = unpacked.skip(skip); let mut field_idx = skip; let field_limit = unpacked.len().min(index_info.key_info.len()); - while field_idx < field_limit && header_pos < header_end { + + while let Some(rhs_value) = unpacked.next() { + let rhs_value = &rhs_value.as_value_ref(); + if field_idx >= field_limit || header_pos >= header_end { + break; + } let (serial_type_raw, bytes_read) = read_varint(&payload[header_pos..])?; header_pos += bytes_read; let serial_type = SerialType::try_from(serial_type_raw)?; - let rhs_value = &unpacked[field_idx]; let lhs_value = match serial_type.kind() { SerialTypeKind::ConstInt0 => ValueRef::Integer(0), @@ -2729,12 +2757,17 @@ mod tests { "Test '{test_name}' failed: Full Comparison: {gold_result:?}, Optimized: {optimized_result:?}, Strategy: {comparer:?}" ); - let generic_result = - compare_records_generic(&serialized, &unpacked_values, index_info, 0, tie_breaker) - .unwrap(); + let generic_result = compare_records_generic( + &serialized, + unpacked_values.iter(), + index_info, + 0, + tie_breaker, + ) + .unwrap(); assert_eq!( gold_result, generic_result, - "Test '{test_name}' failed with generic: Full Comparison: {gold_result:?}, Generic: {generic_result:?}" + "Test '{test_name}' failed with generic: Full Comparison: {gold_result:?}, Generic: {generic_result:?}\n LHS: {serialized_values:?}\n RHS: {unpacked_values:?}" ); } @@ -2859,6 +2892,9 @@ mod tests { ]; for 
(serialized_values, unpacked_values, test_name) in test_cases { + println!( + "Testing integer fast path `{test_name}`\nLHS: {serialized_values:?}\nRHS: {unpacked_values:?}" + ); assert_compare_matches_full_comparison( serialized_values, unpacked_values, @@ -3131,9 +3167,11 @@ mod tests { let tie_breaker = std::cmp::Ordering::Equal; let result_skip_0 = - compare_records_generic(&serialized, &unpacked, &index_info, 0, tie_breaker).unwrap(); + compare_records_generic(&serialized, unpacked.iter(), &index_info, 0, tie_breaker) + .unwrap(); let result_skip_1 = - compare_records_generic(&serialized, &unpacked, &index_info, 1, tie_breaker).unwrap(); + compare_records_generic(&serialized, unpacked.iter(), &index_info, 1, tie_breaker) + .unwrap(); assert_eq!(result_skip_0, std::cmp::Ordering::Less); diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index fb661f6e2..d211658d1 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -3452,7 +3452,7 @@ pub fn op_idx_ge( let tie_breaker = get_tie_breaker_from_idx_comp_op(insn); let ord = compare_records_generic( &idx_record, // The serialized record from the index - &values, // The record built from registers + values, // The record built from registers cursor.get_index_info(), // Sort order flags 0, tie_breaker, @@ -3520,7 +3520,7 @@ pub fn op_idx_le( let tie_breaker = get_tie_breaker_from_idx_comp_op(insn); let ord = compare_records_generic( &idx_record, - &values, + values, cursor.get_index_info(), 0, tie_breaker, @@ -3571,7 +3571,7 @@ pub fn op_idx_gt( let tie_breaker = get_tie_breaker_from_idx_comp_op(insn); let ord = compare_records_generic( &idx_record, - &values, + values, cursor.get_index_info(), 0, tie_breaker, @@ -3623,7 +3623,7 @@ pub fn op_idx_lt( let tie_breaker = get_tie_breaker_from_idx_comp_op(insn); let ord = compare_records_generic( &idx_record, - &values, + values, cursor.get_index_info(), 0, tie_breaker, diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index e77ab7e40..c139cdab4 100644 --- 
a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -1092,11 +1092,10 @@ fn make_record(registers: &[Register], start_reg: &usize, count: &usize) -> Immu ImmutableRecord::from_registers(regs, regs.len()) } -pub fn registers_to_ref_values<'a>(registers: &'a [Register]) -> Vec> { - registers - .iter() - .map(|reg| reg.get_value().as_ref()) - .collect() +pub fn registers_to_ref_values<'a>( + registers: &'a [Register], +) -> impl ExactSizeIterator> { + registers.iter().map(|reg| reg.get_value().as_ref()) } #[instrument(skip(program), level = Level::DEBUG)]