diff --git a/core/storage/sqlite3_ondisk.rs b/core/storage/sqlite3_ondisk.rs
index 91bdfec93..0c000a65a 100644
--- a/core/storage/sqlite3_ondisk.rs
+++ b/core/storage/sqlite3_ondisk.rs
@@ -1222,6 +1222,76 @@ pub fn read_value(buf: &[u8], serial_type: SerialType) -> Result<(RefValue, usiz
     }
 }
 
+/// Decodes the integer stored at the start of `buf` for an integer serial type.
+///
+/// Serial types 1 to 6 are big-endian two's-complement integers that are 1, 2, 3,
+/// 4, 6 and 8 bytes wide; serial types 8 and 9 are the constants 0 and 1 and read
+/// nothing from `buf`.
+///
+/// # Errors
+///
+/// Fails with a corruption error when `buf` is shorter than the encoded width or
+/// when `serial_type` is not one of the integer serial types above.
+#[inline(always)]
+pub fn read_integer(buf: &[u8], serial_type: u8) -> Result<i64> {
+    match serial_type {
+        1 => {
+            if buf.is_empty() {
+                crate::bail_corrupt_error!("Invalid 1-byte int");
+            }
+            Ok(buf[0] as i8 as i64)
+        }
+        2 => {
+            if buf.len() < 2 {
+                crate::bail_corrupt_error!("Invalid 2-byte int");
+            }
+            Ok(i16::from_be_bytes([buf[0], buf[1]]) as i64)
+        }
+        3 => {
+            if buf.len() < 3 {
+                crate::bail_corrupt_error!("Invalid 3-byte int");
+            }
+            // Widen to 32 bits by replicating the sign of the most significant byte.
+            let sign_extension = if buf[0] <= 0x7F { 0 } else { 0xFF };
+            Ok(i32::from_be_bytes([sign_extension, buf[0], buf[1], buf[2]]) as i64)
+        }
+        4 => {
+            if buf.len() < 4 {
+                crate::bail_corrupt_error!("Invalid 4-byte int");
+            }
+            Ok(i32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as i64)
+        }
+        5 => {
+            if buf.len() < 6 {
+                crate::bail_corrupt_error!("Invalid 6-byte int");
+            }
+            // Serial type 5 is a 48-bit integer; sign-extend it to 64 bits.
+            let sign_extension = if buf[0] <= 0x7F { 0 } else { 0xFF };
+            Ok(i64::from_be_bytes([
+                sign_extension,
+                sign_extension,
+                buf[0],
+                buf[1],
+                buf[2],
+                buf[3],
+                buf[4],
+                buf[5],
+            ]))
+        }
+        6 => {
+            if buf.len() < 8 {
+                crate::bail_corrupt_error!("Invalid 8-byte int");
+            }
+            Ok(i64::from_be_bytes([
+                buf[0], buf[1], buf[2], buf[3], buf[4], buf[5], buf[6], buf[7],
+            ]))
+        }
+        8 => Ok(0),
+        9 => Ok(1),
+        _ => crate::bail_corrupt_error!("Invalid serial type for integer"),
+    }
+}
+
 #[inline(always)]
 pub fn read_varint(buf: &[u8]) -> Result<(u64, usize)> {
     let mut v: u64 = 0;
diff --git a/core/types.rs b/core/types.rs
index 01cb0b115..35834204e 100644
--- a/core/types.rs
+++ b/core/types.rs
@@ -8,7 +8,7 @@ use crate::ext::{ExtValue, ExtValueType};
 use crate::pseudo::PseudoCursor;
 use crate::schema::Index;
 use crate::storage::btree::BTreeCursor;
-use crate::storage::sqlite3_ondisk::{read_varint, write_varint};
+use crate::storage::sqlite3_ondisk::{read_integer, read_value, read_varint, write_varint};
 use crate::translate::collate::CollationSeq;
 use crate::translate::plan::IterationDirection;
 use crate::vdbe::sorter::Sorter;
@@ -1402,6 +1402,55 @@ pub fn compare_immutable(
     std::cmp::Ordering::Equal
 }
 
+/// Reference comparison of two unpacked records, field by field.
+///
+/// Text fields use the collation configured for their column (the default
+/// collation when `collations` has no entry); every other pair of values goes
+/// through `RefValue`'s `partial_cmp`, with incomparable values treated as equal.
+/// When one record is a prefix of the other, the shorter record sorts first,
+/// adjusted by the sort order of the last compared column.
+pub fn compare_immutable_for_testing(
+    l: &[RefValue],
+    r: &[RefValue],
+    index_key_sort_order: IndexKeySortOrder,
+    collations: &[CollationSeq],
+) -> std::cmp::Ordering {
+    let min_len = l.len().min(r.len());
+
+    for i in 0..min_len {
+        let column_order = index_key_sort_order.get_sort_order_for_col(i);
+        let collation = collations.get(i).copied().unwrap_or_default();
+
+        let cmp = match (&l[i], &r[i]) {
+            (RefValue::Text(left), RefValue::Text(right)) => {
+                collation.compare_strings(left.as_str(), right.as_str())
+            }
+            _ => l[i].partial_cmp(&r[i]).unwrap_or(std::cmp::Ordering::Equal),
+        };
+
+        if cmp != std::cmp::Ordering::Equal {
+            return match column_order {
+                SortOrder::Asc => cmp,
+                SortOrder::Desc => cmp.reverse(),
+            };
+        }
+    }
+
+    // All common fields are equal; resolve by field count difference
+    let len_cmp = l.len().cmp(&r.len());
+
+    if len_cmp == std::cmp::Ordering::Equal {
+        std::cmp::Ordering::Equal
+    } else {
+        // Use the sort order of the last compared column (column 0 if none was compared)
+        let last_index = min_len.saturating_sub(1);
+        match index_key_sort_order.get_sort_order_for_col(last_index) {
+            SortOrder::Asc => len_cmp,
+            SortOrder::Desc => len_cmp.reverse(),
+        }
+    }
+}
+
 #[derive(Debug, Clone, Copy)]
 pub enum RecordCompare {
     Int,
@@ -1412,23 +1461,322 @@ pub enum RecordCompare {
 impl RecordCompare {
+    /// Compares a serialized record against an unpacked key with this strategy.
+    ///
+    /// Returns the ordering of `serialized` relative to `unpacked`. `skip` is
+    /// forwarded to the generic comparer only; the `Int` and `String` fast paths
+    /// always begin at the first field.
     pub fn compare(
         &self,
-        serialized: &[u8],
+        serialized: &ImmutableRecord,
         unpacked: &[RefValue],
         index_info: &IndexKeyInfo,
         collations: &[CollationSeq],
         skip: usize,
     ) -> Result<std::cmp::Ordering> {
-        Ok(std::cmp::Ordering::Equal)
+        match self {
+            RecordCompare::Int => compare_records_int(serialized, unpacked, index_info, collations),
+            RecordCompare::String => {
+                compare_records_string(serialized, unpacked, index_info, collations)
+            }
+            RecordCompare::Generic => {
+                compare_records_generic(serialized, unpacked, index_info, collations, skip)
+            }
+        }
     }
 }
 
-// pub find_compare(unpacked: &[RefValue], index_info: &IndexKeyInfo) -> RecordCompare {
-//     if unpacked.len() > 0 && index_info.num_cols <= 13 {
-//         match &unpacked[0] {
-//             RefValue::Integer(_) if can_use_int_
-//         }
-//     }
-// }
+/// Picks the comparison strategy for keys shaped like `unpacked`.
+///
+/// The fast paths are chosen only when the index has at most 13 columns and the
+/// first key field is an integer, or a text value compared with the binary
+/// collation; every other key uses [`RecordCompare::Generic`].
+// NOTE(review): the 13-column limit presumably mirrors SQLite, where it keeps the
+// record header size within a single byte -- confirm.
+pub fn find_compare(
+    unpacked: &[RefValue],
+    index_info: &IndexKeyInfo,
+    collations: &[CollationSeq],
+) -> RecordCompare {
+    if index_info.num_cols > 13 {
+        return RecordCompare::Generic;
+    }
+
+    match unpacked.first() {
+        Some(RefValue::Integer(_)) => RecordCompare::Int,
+        Some(RefValue::Text(_)) if is_binary_collation(collations, 0) => RecordCompare::String,
+        _ => RecordCompare::Generic,
+    }
+}
+
+/// Tie-break for the fast paths once every field of `unpacked` compared equal:
+/// a serialized record that carries additional fields sorts after the key.
+fn compare_field_counts(
+    serialized: &ImmutableRecord,
+    unpacked: &[RefValue],
+) -> std::cmp::Ordering {
+    let mut record_cursor = RecordCursor::new();
+    if record_cursor.len(serialized) > unpacked.len() {
+        std::cmp::Ordering::Greater
+    } else {
+        std::cmp::Ordering::Equal
+    }
+}
+
+/// Fast path for keys whose first field is an integer.
+///
+/// Reads the first field straight out of the payload when the record header is
+/// short and that field is stored with an integer serial type; anything else is
+/// delegated to [`compare_records_generic`]. When the first fields are equal the
+/// remaining fields are compared generically with the first field skipped.
+fn compare_records_int(
+    serialized: &ImmutableRecord,
+    unpacked: &[RefValue],
+    index_info: &IndexKeyInfo,
+    collations: &[CollationSeq],
+) -> Result<std::cmp::Ordering> {
+    let Some(&RefValue::Integer(rhs_int)) = unpacked.first() else {
+        return compare_records_generic(serialized, unpacked, index_info, collations, 0);
+    };
+
+    let payload = serialized.get_payload();
+    if payload.len() < 2 || payload[0] > 63 {
+        return compare_records_generic(serialized, unpacked, index_info, collations, 0);
+    }
+
+    let header_size = payload[0] as usize;
+    let first_serial_type = payload[1];
+    if !matches!(first_serial_type, 1..=6 | 8 | 9) {
+        return compare_records_generic(serialized, unpacked, index_info, collations, 0);
+    }
+
+    // The header size is also the offset of the first field's data.
+    let data_start = header_size;
+    if data_start >= payload.len() {
+        return compare_records_generic(serialized, unpacked, index_info, collations, 0);
+    }
+
+    let lhs_int = read_integer(&payload[data_start..], first_serial_type)?;
+    let comparison = match index_info.sort_order.get_sort_order_for_col(0) {
+        SortOrder::Asc => lhs_int.cmp(&rhs_int),
+        SortOrder::Desc => lhs_int.cmp(&rhs_int).reverse(),
+    };
+
+    match comparison {
+        // First fields are equal: compare the remaining fields, if any.
+        std::cmp::Ordering::Equal if unpacked.len() > 1 => {
+            compare_records_generic(serialized, unpacked, index_info, collations, 1)
+        }
+        std::cmp::Ordering::Equal => Ok(compare_field_counts(serialized, unpacked)),
+        other => Ok(other),
+    }
+}
+
+/// Fast path for keys whose first field is text compared with the binary collation.
+///
+/// Decodes the first field directly when both the header size and the first
+/// serial type are single-byte varints and the serial type denotes text;
+/// anything else is delegated to [`compare_records_generic`].
+fn compare_records_string(
+    serialized: &ImmutableRecord,
+    unpacked: &[RefValue],
+    index_info: &IndexKeyInfo,
+    collations: &[CollationSeq],
+) -> Result<std::cmp::Ordering> {
+    let Some(RefValue::Text(rhs_text)) = unpacked.first() else {
+        return compare_records_generic(serialized, unpacked, index_info, collations, 0);
+    };
+
+    let payload = serialized.get_payload();
+    // Indexing bytes 0 and 1 is only valid while both varints fit in one byte (high
+    // bit clear). Text longer than 57 bytes has a multi-byte serial type whose first
+    // byte would otherwise be mistaken for a short odd serial type.
+    if payload.len() < 2 || payload[0] >= 0x80 || payload[1] >= 0x80 {
+        return compare_records_generic(serialized, unpacked, index_info, collations, 0);
+    }
+
+    let header_size = payload[0] as usize;
+    let first_serial_type = payload[1];
+
+    // Text serial types are odd and at least 13; everything else is not a string.
+    if first_serial_type < 13 || (first_serial_type & 1) == 0 {
+        return compare_records_generic(serialized, unpacked, index_info, collations, 0);
+    }
+
+    let string_len = (first_serial_type as usize - 13) / 2;
+    let data_start = header_size;
+    if data_start + string_len > payload.len() {
+        crate::bail_corrupt_error!("Text field extends past the end of the record payload");
+    }
+
+    let serial_type = SerialType::try_from(first_serial_type as u64)?;
+    let (lhs_value, _) = read_value(&payload[data_start..], serial_type)?;
+    let RefValue::Text(lhs_text) = lhs_value else {
+        return compare_records_generic(serialized, unpacked, index_info, collations, 0);
+    };
+
+    let comparison = match collations.first() {
+        Some(collation) => collation.compare_strings(lhs_text.as_str(), rhs_text.as_str()),
+        // No collation configured: plain byte-wise comparison.
+        None => lhs_text.value.to_slice().cmp(rhs_text.value.to_slice()),
+    };
+    let comparison = match index_info.sort_order.get_sort_order_for_col(0) {
+        SortOrder::Asc => comparison,
+        SortOrder::Desc => comparison.reverse(),
+    };
+
+    match comparison {
+        // The collation already decided equality, so lengths are not compared again;
+        // move on to the remaining fields, if any.
+        std::cmp::Ordering::Equal if unpacked.len() > 1 => {
+            compare_records_generic(serialized, unpacked, index_info, collations, 1)
+        }
+        std::cmp::Ordering::Equal => Ok(compare_field_counts(serialized, unpacked)),
+        other => Ok(other),
+    }
+}
+
+/// Orders a decoded record field (`lhs`) against a key field (`rhs`) following
+/// SQLite's type precedence: NULL < numbers < text < blob.
+///
+/// Text is compared with `collation` when one is given and byte-wise otherwise.
+/// A NaN float is treated like NULL when compared with NULL or another float.
+fn compare_field(
+    lhs: &RefValue,
+    rhs: &RefValue,
+    collation: Option<&CollationSeq>,
+) -> std::cmp::Ordering {
+    use std::cmp::Ordering;
+
+    match rhs {
+        RefValue::Null => match lhs {
+            RefValue::Null => Ordering::Equal,
+            RefValue::Float(f) if f.is_nan() => Ordering::Equal,
+            // Any non-NULL record value sorts after a NULL key.
+            _ => Ordering::Greater,
+        },
+        RefValue::Integer(rhs_int) => match lhs {
+            RefValue::Null => Ordering::Less,
+            RefValue::Integer(lhs_int) => lhs_int.cmp(rhs_int),
+            RefValue::Float(lhs_float) => sqlite_int_float_compare(*rhs_int, *lhs_float).reverse(),
+            // Text and blobs sort after numbers.
+            RefValue::Text(_) | RefValue::Blob(_) => Ordering::Greater,
+        },
+        RefValue::Float(rhs_float) => match lhs {
+            RefValue::Null => Ordering::Less,
+            RefValue::Integer(lhs_int) => sqlite_int_float_compare(*lhs_int, *rhs_float),
+            RefValue::Float(lhs_float) => match (lhs_float.is_nan(), rhs_float.is_nan()) {
+                (true, true) => Ordering::Equal,
+                (true, false) => Ordering::Less,
+                (false, true) => Ordering::Greater,
+                (false, false) => lhs_float.partial_cmp(rhs_float).unwrap_or(Ordering::Equal),
+            },
+            // Text and blobs sort after numbers.
+            RefValue::Text(_) | RefValue::Blob(_) => Ordering::Greater,
+        },
+        RefValue::Text(rhs_text) => match lhs {
+            RefValue::Null | RefValue::Integer(_) | RefValue::Float(_) => Ordering::Less,
+            RefValue::Text(lhs_text) => match collation {
+                Some(seq) => seq.compare_strings(lhs_text.as_str(), rhs_text.as_str()),
+                None => lhs_text.value.to_slice().cmp(rhs_text.value.to_slice()),
+            },
+            // Blobs sort after text.
+            RefValue::Blob(_) => Ordering::Greater,
+        },
+        RefValue::Blob(rhs_blob) => match lhs {
+            // Every other type sorts before a blob.
+            RefValue::Null | RefValue::Integer(_) | RefValue::Float(_) | RefValue::Text(_) => {
+                Ordering::Less
+            }
+            RefValue::Blob(lhs_blob) => lhs_blob.to_slice().cmp(rhs_blob.to_slice()),
+        },
+    }
+}
+
+/// Compares `serialized` with `unpacked` field by field, starting at field `skip`.
+///
+/// The record header is parsed into serial types, the data of the first `skip`
+/// fields is stepped over, and each remaining field is decoded and ordered with
+/// [`compare_field`], reversed for descending columns. When all shared fields are
+/// equal the record with more fields sorts last. An empty payload sorts first.
+///
+/// # Errors
+///
+/// Fails when a varint, serial type or field value cannot be decoded, or when the
+/// header or a field lies outside the payload.
+fn compare_records_generic(
+    serialized: &ImmutableRecord,
+    unpacked: &[RefValue],
+    index_info: &IndexKeyInfo,
+    collations: &[CollationSeq],
+    skip: usize,
+) -> Result<std::cmp::Ordering> {
+    let payload = serialized.get_payload();
+    if payload.is_empty() {
+        return Ok(std::cmp::Ordering::Less);
+    }
+
+    let (header_size, mut pos) = read_varint(payload)?;
+    let header_end = header_size as usize;
+    if header_end > payload.len() {
+        crate::bail_corrupt_error!("Record header is larger than the record payload");
+    }
+
+    let mut serial_types = Vec::new();
+    while pos < header_end {
+        let (serial_type, bytes_read) = read_varint(&payload[pos..])?;
+        serial_types.push(serial_type);
+        pos += bytes_read;
+    }
+
+    // Field data starts right after the header; step over the skipped fields.
+    let mut data_pos = header_end;
+    for &raw_serial_type in serial_types.iter().take(skip) {
+        let serial_type = SerialType::try_from(raw_serial_type)?;
+        if !matches!(
+            serial_type.kind(),
+            SerialTypeKind::ConstInt0 | SerialTypeKind::ConstInt1 | SerialTypeKind::Null
+        ) {
+            data_pos += serial_type.size();
+        }
+    }
+
+    for i in skip..unpacked.len().min(serial_types.len()) {
+        let serial_type = SerialType::try_from(serial_types[i])?;
+        let lhs_value = match serial_type.kind() {
+            SerialTypeKind::ConstInt0 => RefValue::Integer(0),
+            SerialTypeKind::ConstInt1 => RefValue::Integer(1),
+            SerialTypeKind::Null => RefValue::Null,
+            _ => {
+                let Some(field_data) = payload.get(data_pos..) else {
+                    crate::bail_corrupt_error!("Record field starts past the end of the payload");
+                };
+                let (value, field_size) = read_value(field_data, serial_type)?;
+                data_pos += field_size;
+                value
+            }
+        };
+
+        let comparison = compare_field(&lhs_value, &unpacked[i], collations.get(i));
+        let comparison = match index_info.sort_order.get_sort_order_for_col(i) {
+            SortOrder::Asc => comparison,
+            SortOrder::Desc => comparison.reverse(),
+        };
+
+        // Stop at the first field that differs.
+        if comparison != std::cmp::Ordering::Equal {
+            return Ok(comparison);
+        }
+    }
+
+    Ok(serial_types.len().cmp(&unpacked.len()))
+}
+
+/// Returns whether column `col_idx` is compared with the binary collation.
+///
+/// A column without an entry in `collations` falls back to the default collation.
+#[inline(always)]
+fn is_binary_collation(collations: &[CollationSeq], col_idx: usize) -> bool {
+    collations.get(col_idx).copied().unwrap_or_default() == CollationSeq::Binary
+}
 
 const I8_LOW: i64 = -128;
 const I8_HIGH: i64 = 127;
@@ -1795,6 +2143,498 @@ impl RawSlice {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::storage::sqlite3_ondisk::read_value;
+    use crate::translate::collate::CollationSeq;
+
+    fn create_record(values: Vec<Value>) -> ImmutableRecord {
+        let registers: Vec<Register> = values.into_iter().map(Register::Value).collect();
+        ImmutableRecord::from_registers(&registers)
+    }
+
+    fn create_index_info(num_cols: usize, sort_orders: Vec<SortOrder>) -> IndexKeyInfo {
+        IndexKeyInfo {
+            sort_order: IndexKeySortOrder::from_list(sort_orders.as_slice()),
+            has_rowid: false,
+            num_cols,
+        }
+    }
+
+    fn value_to_ref_value(value: &Value) -> RefValue {
+        match value {
+            Value::Null => RefValue::Null,
+            Value::Integer(i) => RefValue::Integer(*i),
+            Value::Float(f) => RefValue::Float(*f),
+            Value::Text(text) => RefValue::Text(TextRef {
+                value: RawSlice::from_slice(&text.value),
+                subtype: text.subtype.clone(),
+            }),
+            Value::Blob(blob) => RefValue::Blob(RawSlice::from_slice(blob)),
+        }
+    }
+
+    impl TextRef {
+        fn from_str(s: &str) -> Self {
+            TextRef {
+                value: RawSlice::from_slice(s.as_bytes()),
+                subtype: crate::types::TextSubtype::Text,
+            }
+        }
+    }
+
+    impl RawSlice {
+        fn from_slice(data: &[u8]) -> Self {
+            Self {
+                data: data.as_ptr(),
+                len: data.len(),
+            }
+        }
+    }
+
+    fn assert_compare_matches_full_comparison(
+        serialized_values: Vec<Value>,
+        unpacked_values: Vec<RefValue>,
+        index_info: &IndexKeyInfo,
+        collations: &[CollationSeq],
+        test_name: &str,
+    ) {
+        let serialized = create_record(serialized_values.clone());
+
+        let serialized_ref_values: Vec<RefValue> = serialized_values
+            .iter()
+            .map(|v| value_to_ref_value(v))
+            .collect();
+
+        let gold_result = compare_immutable_for_testing(
+            &serialized_ref_values,
+            &unpacked_values,
+            index_info.sort_order.clone(),
+            collations,
+        );
+
+        let comparer = find_compare(&unpacked_values, index_info, collations);
+        let optimized_result = comparer
+            .compare(&serialized, &unpacked_values, index_info, collations, 0)
+            .unwrap();
+
+        assert_eq!(
+            gold_result, optimized_result,
+            "Test '{}' failed: Full Comparison: {:?}, Optimized: {:?}, Strategy: {:?}",
+            test_name, gold_result, optimized_result, comparer
+        );
+
+        let generic_result =
+            compare_records_generic(&serialized, &unpacked_values, index_info, collations, 0)
+                .unwrap();
+        assert_eq!(
+            gold_result, generic_result,
+            "Test '{}' failed with generic: Full Comparison: {:?}, Generic: {:?}",
+            test_name, gold_result, generic_result
+        );
+    }
+
+    #[test]
+    fn test_integer_fast_path() {
+        let index_info = create_index_info(2, vec![SortOrder::Asc, SortOrder::Asc]);
+        let collations = vec![CollationSeq::Binary, CollationSeq::Binary];
+
+        let test_cases = vec![
+            (
+                vec![Value::Integer(42)],
+                vec![RefValue::Integer(42)],
+                "equal_integers",
+            ),
+            (
+                vec![Value::Integer(10)],
+                vec![RefValue::Integer(20)],
+                "less_than_integers",
+            ),
+            (
+                vec![Value::Integer(30)],
+                vec![RefValue::Integer(20)],
+                "greater_than_integers",
+            ),
+            (
+                vec![Value::Integer(0)],
+                vec![RefValue::Integer(0)],
+                "zero_integers",
+            ),
+            (
+                vec![Value::Integer(-5)],
+                vec![RefValue::Integer(-5)],
+                "negative_integers",
+            ),
+            (
+                vec![Value::Integer(i64::MAX)],
+                vec![RefValue::Integer(i64::MAX)],
+                "max_integers",
+            ),
+            (
+                vec![Value::Integer(i64::MIN)],
+                vec![RefValue::Integer(i64::MIN)],
+                "min_integers",
+            ),
+            (
+                vec![Value::Integer(42), Value::Text(Text::new("hello"))],
+                vec![
+                    RefValue::Integer(42),
+                    RefValue::Text(TextRef::from_str("hello")),
+                ],
+                "integer_text_equal",
+            ),
+            (
+                vec![Value::Integer(42), Value::Text(Text::new("hello"))],
+                vec![
+                    RefValue::Integer(42),
+                    RefValue::Text(TextRef::from_str("world")),
+                ],
+                "integer_equal_text_different",
+            ),
+        ];
+
+        for (serialized_values, unpacked_values, test_name) in test_cases {
+            assert_compare_matches_full_comparison(
+                serialized_values,
+                unpacked_values,
+                &index_info,
+                &collations,
+                test_name,
+            );
+        }
+    }
+
+    #[test]
+    fn test_string_fast_path() {
+        let index_info = create_index_info(2, vec![SortOrder::Asc, SortOrder::Asc]);
+        let collations = vec![CollationSeq::Binary, CollationSeq::Binary];
+
+        let test_cases = vec![
+            (
+                vec![Value::Text(Text::new("hello"))],
+                vec![RefValue::Text(TextRef::from_str("hello"))],
+                "equal_strings",
+            ),
+            (
+                vec![Value::Text(Text::new("abc"))],
+                vec![RefValue::Text(TextRef::from_str("def"))],
+                "less_than_strings",
+            ),
+            (
+                vec![Value::Text(Text::new("xyz"))],
+                vec![RefValue::Text(TextRef::from_str("abc"))],
+                "greater_than_strings",
+            ),
+            (
+                vec![Value::Text(Text::new(""))],
+                vec![RefValue::Text(TextRef::from_str(""))],
+                "empty_strings",
+            ),
+            (
+                vec![Value::Text(Text::new("a"))],
+                vec![RefValue::Text(TextRef::from_str("aa"))],
+                "prefix_strings",
+            ),
+            // Multi-field with string first
+            (
+                vec![Value::Text(Text::new("hello")), Value::Integer(42)],
+                vec![
+                    RefValue::Text(TextRef::from_str("hello")),
+                    RefValue::Integer(42),
+                ],
+                "string_integer_equal",
+            ),
+            (
+                vec![Value::Text(Text::new("hello")), Value::Integer(42)],
+                vec![
+                    RefValue::Text(TextRef::from_str("hello")),
+                    RefValue::Integer(99),
+                ],
+                "string_equal_integer_different",
+            ),
+        ];
+
+        for (serialized_values, unpacked_values, test_name) in test_cases {
+            assert_compare_matches_full_comparison(
+                serialized_values,
+                unpacked_values,
+                &index_info,
+                &collations,
+                test_name,
+            );
+        }
+    }
+
+    #[test]
+    fn test_type_precedence() {
+        let index_info = create_index_info(1, vec![SortOrder::Asc]);
+        let collations = vec![CollationSeq::Binary];
+
+        // Test SQLite type precedence: NULL < Numbers < Text < Blob
+        let test_cases = vec![
+            // NULL vs others
+            (
+                vec![Value::Null],
+                vec![RefValue::Integer(42)],
+                "null_vs_integer",
+            ),
+            (
+                vec![Value::Null],
+                vec![RefValue::Float(3.14)],
+                "null_vs_float",
+            ),
+            (
+                vec![Value::Null],
+                vec![RefValue::Text(TextRef::from_str("hello"))],
+                "null_vs_text",
+            ),
+            (
+                vec![Value::Null],
+                vec![RefValue::Blob(RawSlice::from_slice(b"blob"))],
+                "null_vs_blob",
+            ),
+            // Numbers vs Text/Blob
+            (
+                vec![Value::Integer(42)],
+                vec![RefValue::Text(TextRef::from_str("hello"))],
+                "integer_vs_text",
+            ),
+            (
+                vec![Value::Float(3.14)],
+                vec![RefValue::Text(TextRef::from_str("hello"))],
+                "float_vs_text",
+            ),
+            (
+                vec![Value::Integer(42)],
+                vec![RefValue::Blob(RawSlice::from_slice(b"blob"))],
+                "integer_vs_blob",
+            ),
+            (
+                vec![Value::Float(3.14)],
+                vec![RefValue::Blob(RawSlice::from_slice(b"blob"))],
+                "float_vs_blob",
+            ),
+            // Text vs Blob
+            (
+                vec![Value::Text(Text::new("hello"))],
+                vec![RefValue::Blob(RawSlice::from_slice(b"blob"))],
+                "text_vs_blob",
+            ),
+            // Integer vs Float (affinity conversion)
+            (
+                vec![Value::Integer(42)],
+                vec![RefValue::Float(42.0)],
+                "integer_vs_equal_float",
+            ),
+            (
+                vec![Value::Integer(42)],
+                vec![RefValue::Float(42.5)],
+                "integer_vs_different_float",
+            ),
+            (
+                vec![Value::Float(42.5)],
+                vec![RefValue::Integer(42)],
+                "float_vs_integer",
+            ),
+            // Reverse direction: the serialized value has the higher-precedence type
+            (
+                vec![Value::Integer(42)],
+                vec![RefValue::Null],
+                "integer_vs_null",
+            ),
+            (
+                vec![Value::Text(Text::new("hello"))],
+                vec![RefValue::Integer(42)],
+                "text_vs_integer",
+            ),
+            (
+                vec![Value::Blob(b"blob".to_vec())],
+                vec![RefValue::Float(3.14)],
+                "blob_vs_float",
+            ),
+            (
+                vec![Value::Blob(b"blob".to_vec())],
+                vec![RefValue::Text(TextRef::from_str("hello"))],
+                "blob_vs_text",
+            ),
+        ];
+
+        for (serialized_values, unpacked_values, test_name) in test_cases {
+            assert_compare_matches_full_comparison(
+                serialized_values,
+                unpacked_values,
+                &index_info,
+                &collations,
+                test_name,
+            );
+        }
+    }
+
+    #[test]
+    fn test_sort_order_desc() {
+        let index_info = create_index_info(2, vec![SortOrder::Desc, SortOrder::Asc]);
+        let collations = vec![CollationSeq::Binary, CollationSeq::Binary];
+
+        let test_cases = vec![
+            // DESC order should reverse first field comparison
+            (
+                vec![Value::Integer(10)],
+                vec![RefValue::Integer(20)],
+                "desc_integer_reversed",
+            ),
+            (
+                vec![Value::Text(Text::new("abc"))],
+                vec![RefValue::Text(TextRef::from_str("def"))],
+                "desc_string_reversed",
+            ),
+            // Mixed sort orders
+            (
+                vec![Value::Integer(10), Value::Text(Text::new("hello"))],
+                vec![
+                    RefValue::Integer(20),
+                    RefValue::Text(TextRef::from_str("hello")),
+                ],
+                "desc_first_asc_second",
+            ),
+        ];
+
+        for (serialized_values, unpacked_values, test_name) in test_cases {
+            assert_compare_matches_full_comparison(
+                serialized_values,
+                unpacked_values,
+                &index_info,
+                &collations,
+                test_name,
+            );
+        }
+    }
+
+    #[test]
+    fn test_edge_cases() {
+        let index_info = create_index_info(3, vec![SortOrder::Asc, SortOrder::Asc, SortOrder::Asc]);
+        let collations = vec![
+            CollationSeq::Binary,
+            CollationSeq::Binary,
+            CollationSeq::Binary,
+        ];
+
+        let test_cases = vec![
+            (
+                vec![Value::Integer(42)],
+                vec![
+                    RefValue::Integer(42),
+                    RefValue::Text(TextRef::from_str("extra")),
+                ],
+                "fewer_serialized_fields",
+            ),
+            (
+                vec![Value::Integer(42), Value::Text(Text::new("extra"))],
+                vec![RefValue::Integer(42)],
+                "fewer_unpacked_fields",
+            ),
+            (vec![], vec![], "both_empty"),
+            (vec![], vec![RefValue::Integer(42)], "empty_serialized"),
+            (
+                (0..15).map(|i| Value::Integer(i)).collect(),
+                (0..15).map(|i| RefValue::Integer(i)).collect(),
+                "large_field_count",
+            ),
+            (
+                vec![Value::Blob(vec![1, 2, 3])],
+                vec![RefValue::Blob(RawSlice::from_slice(&[1, 2, 3]))],
+                "blob_first_field",
+            ),
+            (
+                vec![Value::Text(Text::new("hello")), Value::Integer(5)],
+                vec![RefValue::Text(TextRef::from_str("hello"))],
+                "equal_text_prefix_but_more_serialized_fields",
+            ),
+            (
+                vec![Value::Text(Text::new("same")), Value::Integer(5)],
+                vec![
+                    RefValue::Text(TextRef::from_str("same")),
+                    RefValue::Integer(5),
+                ],
+                "equal_text_then_equal_int",
+            ),
+        ];
+
+        for (serialized_values, unpacked_values, test_name) in test_cases {
+            assert_compare_matches_full_comparison(
+                serialized_values,
+                unpacked_values,
+                &index_info,
+                &collations,
+                test_name,
+            );
+        }
+    }
+
+    #[test]
+    fn test_skip_parameter() {
+        let index_info = create_index_info(3, vec![SortOrder::Asc, SortOrder::Asc, SortOrder::Asc]);
+        let collations = vec![
+            CollationSeq::Binary,
+            CollationSeq::Binary,
+            CollationSeq::Binary,
+        ];
+
+        let serialized = create_record(vec![
+            Value::Integer(1),
+            Value::Integer(2),
+            Value::Integer(3),
+        ]);
+        let unpacked = vec![
+            RefValue::Integer(1),
+            RefValue::Integer(99),
+            RefValue::Integer(3),
+        ];
+
+        let result_skip_0 =
+            compare_records_generic(&serialized, &unpacked, &index_info, &collations, 0).unwrap();
+        let result_skip_1 =
+            compare_records_generic(&serialized, &unpacked, &index_info, &collations, 1).unwrap();
+
+        assert_eq!(result_skip_0, std::cmp::Ordering::Less);
+
+        assert_eq!(result_skip_1, std::cmp::Ordering::Less);
+    }
+
+    #[test]
+    fn test_strategy_selection() {
+        let index_info_small =
+            create_index_info(3, vec![SortOrder::Asc, SortOrder::Asc, SortOrder::Asc]);
+        let index_info_large = create_index_info(15, vec![SortOrder::Asc; 15]);
+        let collations = vec![CollationSeq::Binary, CollationSeq::Binary];
+
+        let int_values = vec![
+            RefValue::Integer(42),
+            RefValue::Text(TextRef::from_str("hello")),
+        ];
+        assert!(matches!(
+            find_compare(&int_values, &index_info_small, &collations),
+            RecordCompare::Int
+        ));
+
+        let string_values = vec![
+            RefValue::Text(TextRef::from_str("hello")),
+            RefValue::Integer(42),
+        ];
+        assert!(matches!(
+            find_compare(&string_values, &index_info_small, &collations),
+            RecordCompare::String
+        ));
+
+        let large_values: Vec<RefValue> = (0..15).map(|i| RefValue::Integer(i)).collect();
+        assert!(matches!(
+            find_compare(&large_values, &index_info_large, &collations),
+            RecordCompare::Generic
+        ));
+
+        let blob_values = vec![RefValue::Blob(RawSlice::from_slice(&[1, 2, 3]))];
+        assert!(matches!(
+            find_compare(&blob_values, &index_info_small, &collations),
+            RecordCompare::Generic
+        ));
+    }
 
     #[test]
     fn test_serialize_null() {