diff --git a/core/storage/btree.rs b/core/storage/btree.rs index 3ece33ab8..871d49d78 100644 --- a/core/storage/btree.rs +++ b/core/storage/btree.rs @@ -11,11 +11,11 @@ use crate::{ LEAF_PAGE_HEADER_SIZE_BYTES, LEFT_CHILD_PTR_SIZE_BYTES, }, }, - translate::{collate::CollationSeq, plan::IterationDirection}, + translate::plan::IterationDirection, turso_assert, types::{ - find_compare, get_tie_breaker_from_seek_op, IndexKeyInfo, IndexKeySortOrder, - ParseRecordState, RecordCompare, RecordCursor, SeekResult, + find_compare, get_tie_breaker_from_seek_op, IndexInfo, ParseRecordState, RecordCompare, + RecordCursor, SeekResult, }, MvCursor, }; @@ -507,17 +507,14 @@ pub struct BTreeCursor { reusable_immutable_record: RefCell>, /// Reusable immutable record, used to allow better allocation strategy. parse_record_state: RefCell, - pub index_key_info: Option, + /// Information about the index key structure (sort order, collation, etc) + pub index_info: Option, /// Maintain count of the number of records in the btree. Used for the `Count` opcode count: usize, /// Stores the cursor context before rebalancing so that a seek can be done later context: Option, /// Store whether the Cursor is in a valid state. Meaning if it is pointing to a valid cell index or not pub valid_state: CursorValidState, - /// Colations for Index Btree constraint checks - /// Contains the Collation Seq for the whole Index - /// This Vec should be empty for Table Btree - pub collations: Vec, seek_state: CursorSeekState, /// Separate state to read a record with overflow pages. 
This separation from `state` is necessary as /// we can be in a function that relies on `state`, but also needs to process overflow pages @@ -568,7 +565,6 @@ impl BTreeCursor { mv_cursor: Option>>, pager: Rc, root_page: usize, - collations: Vec, num_columns: usize, ) -> Self { Self { @@ -586,11 +582,10 @@ impl BTreeCursor { stack: RefCell::new([const { None }; BTCURSOR_MAX_DEPTH + 1]), }, reusable_immutable_record: RefCell::new(None), - index_key_info: None, + index_info: None, count: 0, context: None, valid_state: CursorValidState::Valid, - collations, seek_state: CursorSeekState::Start, read_overflow_state: RefCell::new(None), find_cell_state: FindCellState(None), @@ -605,7 +600,7 @@ impl BTreeCursor { root_page: usize, num_columns: usize, ) -> Self { - Self::new(mv_cursor, pager, root_page, Vec::new(), num_columns) + Self::new(mv_cursor, pager, root_page, num_columns) } pub fn new_index( @@ -613,23 +608,15 @@ impl BTreeCursor { pager: Rc, root_page: usize, index: &Index, - collations: Vec, num_columns: usize, ) -> Self { - let mut cursor = Self::new(mv_cursor, pager, root_page, collations, num_columns); - cursor.index_key_info = Some(IndexKeyInfo::new_from_index(index)); + let mut cursor = Self::new(mv_cursor, pager, root_page, num_columns); + cursor.index_info = Some(IndexInfo::new_from_index(index)); cursor } - pub fn key_sort_order(&self) -> IndexKeySortOrder { - match &self.index_key_info { - Some(index_key_info) => index_key_info.sort_order, - None => IndexKeySortOrder::default(), - } - } - pub fn has_rowid(&self) -> bool { - match &self.index_key_info { + match &self.index_info { Some(index_key_info) => index_key_info.has_rowid, None => true, // currently we don't support WITHOUT ROWID tables } @@ -1493,9 +1480,13 @@ impl BTreeCursor { let iter_dir = cmp.iteration_direction(); let key_values = index_key.get_values(); - let index_info_default = IndexKeyInfo::default(); - let index_info = *self.index_key_info.as_ref().unwrap_or(&index_info_default); - let 
record_comparer = find_compare(&key_values, &index_info, &self.collations); + let record_comparer = { + let index_info = self + .index_info + .as_ref() + .expect("indexbtree_move_to without index_info"); + find_compare(&key_values, index_info) + }; tracing::debug!("Using record comparison strategy: {:?}", record_comparer); let tie_breaker = get_tie_breaker_from_seek_op(cmp); @@ -1639,8 +1630,9 @@ impl BTreeCursor { .compare( record, &key_values, - &index_info, - &self.collations, + self.index_info + .as_ref() + .expect("indexbtree_move_to without index_info"), 0, tie_breaker, ) @@ -1848,9 +1840,13 @@ impl BTreeCursor { seek_op: SeekOp, ) -> Result> { let key_values = key.get_values(); - let index_info_default = IndexKeyInfo::default(); - let index_info = *self.index_key_info.as_ref().unwrap_or(&index_info_default); - let record_comparer = find_compare(&key_values, &index_info, &self.collations); + let record_comparer = { + let index_info = self + .index_info + .as_ref() + .expect("indexbtree_seek without index_info"); + find_compare(&key_values, index_info) + }; tracing::debug!( "Using record comparison strategy for seek: {:?}", @@ -1972,7 +1968,9 @@ impl BTreeCursor { key_values.as_slice(), seek_op, &record_comparer, - &index_info, + self.index_info + .as_ref() + .expect("indexbtree_seek without index_info"), ); if found { nearest_matching_cell.set(Some(cur_cell_idx as usize)); @@ -2006,21 +2004,14 @@ impl BTreeCursor { key_values: &[RefValue], seek_op: SeekOp, record_comparer: &RecordCompare, - index_info: &IndexKeyInfo, + index_info: &IndexInfo, ) -> (Ordering, bool) { let record = self.get_immutable_record(); let record = record.as_ref().unwrap(); let tie_breaker = get_tie_breaker_from_seek_op(seek_op); let cmp = record_comparer - .compare( - record, - key_values, - index_info, - &self.collations, - 0, - tie_breaker, - ) + .compare(record, key_values, index_info, 0, tie_breaker) .unwrap(); let found = match seek_op { @@ -2189,8 +2180,7 @@ impl BTreeCursor { 
.as_ref() .unwrap() .get_values().as_slice(), - self.key_sort_order(), - &self.collations, + &self.index_info.as_ref().unwrap().key_info, ); if cmp == Ordering::Equal { tracing::debug!("IndexLeafCell: found exact match with cell_idx={cell_idx}, overwriting"); @@ -3925,8 +3915,11 @@ impl BTreeCursor { compare_immutable( key_values.as_slice(), record_same_number_cols, - self.key_sort_order(), - &self.collations, + self.index_info + .as_ref() + .expect("indexbtree_find_cell without index_info") + .key_info + .as_slice(), ) } }; @@ -5098,10 +5091,6 @@ impl BTreeCursor { } } - pub fn collations(&self) -> &[CollationSeq] { - &self.collations - } - pub fn read_page(&self, page_idx: usize) -> Result { btree_read_page(&self.pager, page_idx) } @@ -6527,10 +6516,12 @@ mod tests { }; use sorted_vec::SortedVec; use test_log::test; + use turso_sqlite3_parser::ast::SortOrder; use super::*; use crate::{ io::{Buffer, Completion, CompletionType, MemoryIO, OpenFlags, IO}, + schema::IndexColumn, storage::{database::DatabaseFile, page_cache::DumbLruPageCache}, types::Text, vdbe::Register, @@ -7097,7 +7088,6 @@ mod tests { fn btree_index_insert_fuzz_run(attempts: usize, inserts: usize) { use crate::storage::pager::CreateBTreeFlags; - let num_columns = 5; let (mut rng, seed) = if std::env::var("SEED").is_ok() { let seed = std::env::var("SEED").unwrap(); @@ -7119,8 +7109,31 @@ mod tests { panic!("btree_create returned IO in test, unexpected") } }; - let mut cursor = - BTreeCursor::new_table(None, pager.clone(), index_root_page, num_columns); + let index_def = Index { + name: "testindex".to_string(), + columns: (0..10) + .map(|i| IndexColumn { + name: format!("test{}", i), + order: SortOrder::Asc, + collation: None, + pos_in_table: i, + default: None, + }) + .collect(), + table_name: "test".to_string(), + root_page: index_root_page, + unique: false, + ephemeral: false, + has_rowid: false, + }; + let num_columns = index_def.columns.len(); + let mut cursor = BTreeCursor::new_index( + None, + 
pager.clone(), + index_root_page, + &index_def, + num_columns, + ); let mut keys = SortedVec::new(); tracing::info!("seed: {seed}"); for i in 0..inserts { @@ -7129,7 +7142,7 @@ mod tests { let key = { let result; loop { - let cols = (0..10) + let cols = (0..num_columns) .map(|_| (rng.next_u64() % (1 << 30)) as i64) .collect::>(); if seen.contains(&cols) { @@ -8410,7 +8423,7 @@ mod tests { pub fn test_read_write_payload_with_offset() { let (pager, root_page, _, _) = empty_btree(); let num_columns = 5; - let mut cursor = BTreeCursor::new(None, pager.clone(), root_page, vec![], num_columns); + let mut cursor = BTreeCursor::new(None, pager.clone(), root_page, num_columns); let offset = 2; // blobs data starts at offset 2 let initial_text = "hello world"; let initial_blob = initial_text.as_bytes().to_vec(); @@ -8487,7 +8500,7 @@ mod tests { pub fn test_read_write_payload_with_overflow_page() { let (pager, root_page, _, _) = empty_btree(); let num_columns = 5; - let mut cursor = BTreeCursor::new(None, pager.clone(), root_page, vec![], num_columns); + let mut cursor = BTreeCursor::new(None, pager.clone(), root_page, num_columns); let mut large_blob = vec![b'A'; 40960 - 11]; // insert large blob. 40960 = 10 page long. let hello_world = b"hello world"; large_blob.extend_from_slice(hello_world); diff --git a/core/types.rs b/core/types.rs index fd5d76721..49e1ebcc0 100644 --- a/core/types.rs +++ b/core/types.rs @@ -14,7 +14,7 @@ use crate::translate::plan::IterationDirection; use crate::vdbe::sorter::Sorter; use crate::vdbe::Register; use crate::vtab::VirtualTableCursor; -use crate::Result; +use crate::{turso_assert, Result}; use std::fmt::{Debug, Display}; const MAX_REAL_SIZE: u8 = 15; @@ -1441,68 +1441,57 @@ fn sqlite_int_float_compare(int_val: i64, float_val: f64) -> std::cmp::Ordering } } -/// A bitfield that represents the comparison spec for index keys. -/// Since indexed columns can individually specify ASC/DESC, each key must -/// be compared differently. 
-#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -#[repr(transparent)] -pub struct IndexKeySortOrder(pub u64); - -impl IndexKeySortOrder { - pub fn get_sort_order_for_col(&self, column_idx: usize) -> SortOrder { - assert!(column_idx < 64, "column index out of range: {column_idx}"); - match self.0 & (1 << column_idx) { - 0 => SortOrder::Asc, - _ => SortOrder::Desc, - } - } - - pub fn from_index(index: &Index) -> Self { - let mut spec = 0; - for (i, column) in index.columns.iter().enumerate() { - spec |= ((column.order == SortOrder::Desc) as u64) << i; - } - IndexKeySortOrder(spec) - } - - pub fn from_list(order: &[SortOrder]) -> Self { - let mut spec = 0; - for (i, order) in order.iter().enumerate() { - spec |= ((*order == SortOrder::Desc) as u64) << i; - } - IndexKeySortOrder(spec) - } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct KeyInfo { + pub sort_order: SortOrder, + pub collation: CollationSeq, } -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Debug, Clone, PartialEq, Eq)] /// Metadata about an index, used for handling and comparing index keys. /// /// This struct provides information about the sorting order of columns, /// whether the index includes a row ID, and the total number of columns /// in the index. -pub struct IndexKeyInfo { +pub struct IndexInfo { /// Specifies the sorting order (ascending or descending) for each column in the index. - pub sort_order: IndexKeySortOrder, + pub key_info: Vec, /// Indicates whether the index includes a row ID column. pub has_rowid: bool, /// The total number of columns in the index, including the row ID column if present. 
pub num_cols: usize, } -impl Default for IndexKeyInfo { +impl Default for IndexInfo { fn default() -> Self { Self { - sort_order: IndexKeySortOrder::default(), + key_info: vec![], has_rowid: true, num_cols: 1, } } } -impl IndexKeyInfo { +impl IndexInfo { pub fn new_from_index(index: &Index) -> Self { Self { - sort_order: IndexKeySortOrder::from_index(index), + key_info: { + let mut key_info: Vec = index + .columns + .iter() + .map(|c| KeyInfo { + sort_order: c.order, + collation: c.collation.unwrap_or_default(), + }) + .collect(); + if index.has_rowid { + key_info.push(KeyInfo { + sort_order: SortOrder::Asc, + collation: CollationSeq::Binary, + }); + } + key_info + }, has_rowid: index.has_rowid, num_cols: index.columns.len() + (index.has_rowid as usize), } @@ -1512,13 +1501,13 @@ impl IndexKeyInfo { pub fn compare_immutable( l: &[RefValue], r: &[RefValue], - index_key_sort_order: IndexKeySortOrder, - collations: &[CollationSeq], + column_info: &[KeyInfo], ) -> std::cmp::Ordering { assert_eq!(l.len(), r.len()); + turso_assert!(column_info.len() >= l.len(), "column_info.len() < l.len()"); for (i, (l, r)) in l.iter().zip(r).enumerate() { - let column_order = index_key_sort_order.get_sort_order_for_col(i); - let collation = collations.get(i).copied().unwrap_or_default(); + let column_order = column_info[i].sort_order; + let collation = column_info[i].collation; let cmp = match (l, r) { (RefValue::Text(left), RefValue::Text(right)) => { collation.compare_strings(left.as_str(), right.as_str()) @@ -1547,39 +1536,31 @@ impl RecordCompare { &self, serialized: &ImmutableRecord, unpacked: &[RefValue], - index_info: &IndexKeyInfo, - collations: &[CollationSeq], + index_info: &IndexInfo, skip: usize, tie_breaker: std::cmp::Ordering, ) -> Result { match self { RecordCompare::Int => { - compare_records_int(serialized, unpacked, index_info, collations, tie_breaker) + compare_records_int(serialized, unpacked, index_info, tie_breaker) } RecordCompare::String => { - 
compare_records_string(serialized, unpacked, index_info, collations, tie_breaker) + compare_records_string(serialized, unpacked, index_info, tie_breaker) + } + RecordCompare::Generic => { + compare_records_generic(serialized, unpacked, index_info, skip, tie_breaker) } - RecordCompare::Generic => compare_records_generic( - serialized, - unpacked, - index_info, - collations, - skip, - tie_breaker, - ), } } } -pub fn find_compare( - unpacked: &[RefValue], - index_info: &IndexKeyInfo, - collations: &[CollationSeq], -) -> RecordCompare { +pub fn find_compare(unpacked: &[RefValue], index_info: &IndexInfo) -> RecordCompare { if !unpacked.is_empty() && index_info.num_cols <= 13 { match &unpacked[0] { RefValue::Integer(_) => RecordCompare::Int, - RefValue::Text(_) if is_binary_collation(collations, 0) => RecordCompare::String, + RefValue::Text(_) if index_info.key_info[0].collation == CollationSeq::Binary => { + RecordCompare::String + } _ => RecordCompare::Generic, } } else { @@ -1644,20 +1625,16 @@ pub fn get_tie_breaker_from_seek_op(seek_op: SeekOp) -> std::cmp::Ordering { fn compare_records_int( serialized: &ImmutableRecord, unpacked: &[RefValue], - index_info: &IndexKeyInfo, - collations: &[CollationSeq], + index_info: &IndexInfo, tie_breaker: std::cmp::Ordering, ) -> Result { + turso_assert!( + index_info.key_info.len() >= unpacked.len(), + "index_info.key_info.len() < unpacked.len()" + ); let payload = serialized.get_payload(); if payload.len() < 2 { - return compare_records_generic( - serialized, - unpacked, - index_info, - collations, - 0, - tie_breaker, - ); + return compare_records_generic(serialized, unpacked, index_info, 0, tie_breaker); } let (header_size, offset_1st_serialtype) = read_varint(payload)?; @@ -1675,30 +1652,16 @@ fn compare_records_int( let serialtype_is_integer = matches!(first_serial_type, 1..=6 | 8 | 9); if !serialtype_is_integer { - return compare_records_generic( - serialized, - unpacked, - index_info, - collations, - 0, - tie_breaker, - ); 
+ return compare_records_generic(serialized, unpacked, index_info, 0, tie_breaker); } let data_start = header_size; let lhs_int = read_integer(&payload[data_start..], first_serial_type as u8)?; let RefValue::Integer(rhs_int) = unpacked[0] else { - return compare_records_generic( - serialized, - unpacked, - index_info, - collations, - 0, - tie_breaker, - ); + return compare_records_generic(serialized, unpacked, index_info, 0, tie_breaker); }; - let comparison = match index_info.sort_order.get_sort_order_for_col(0) { + let comparison = match index_info.key_info[0].sort_order { SortOrder::Asc => lhs_int.cmp(&rhs_int), SortOrder::Desc => lhs_int.cmp(&rhs_int).reverse(), }; @@ -1706,14 +1669,7 @@ fn compare_records_int( std::cmp::Ordering::Equal => { // First fields equal, compare remaining fields if any if unpacked.len() > 1 { - return compare_records_generic( - serialized, - unpacked, - index_info, - collations, - 1, - tie_breaker, - ); + return compare_records_generic(serialized, unpacked, index_info, 1, tie_breaker); } Ok(tie_breaker) } @@ -1762,20 +1718,16 @@ fn compare_records_int( fn compare_records_string( serialized: &ImmutableRecord, unpacked: &[RefValue], - index_info: &IndexKeyInfo, - collations: &[CollationSeq], + index_info: &IndexInfo, tie_breaker: std::cmp::Ordering, ) -> Result { + turso_assert!( + index_info.key_info.len() >= unpacked.len(), + "index_info.key_info.len() < unpacked.len()" + ); let payload = serialized.get_payload(); if payload.len() < 2 { - return compare_records_generic( - serialized, - unpacked, - index_info, - collations, - 0, - tie_breaker, - ); + return compare_records_generic(serialized, unpacked, index_info, 0, tie_breaker); } let (header_size, offset_1st_serialtype) = read_varint(payload)?; @@ -1793,25 +1745,11 @@ fn compare_records_string( let serialtype_is_string = first_serial_type >= 13 && (first_serial_type & 1) == 1; if !serialtype_is_string { - return compare_records_generic( - serialized, - unpacked, - index_info, - 
collations, - 0, - tie_breaker, - ); + return compare_records_generic(serialized, unpacked, index_info, 0, tie_breaker); } let RefValue::Text(rhs_text) = &unpacked[0] else { - return compare_records_generic( - serialized, - unpacked, - index_info, - collations, - 0, - tie_breaker, - ); + return compare_records_generic(serialized, unpacked, index_info, 0, tie_breaker); }; let string_len = (first_serial_type as usize - 13) / 2; @@ -1823,24 +1761,13 @@ fn compare_records_string( let (lhs_value, _) = read_value(&payload[data_start..], serial_type)?; let RefValue::Text(lhs_text) = lhs_value else { - return compare_records_generic( - serialized, - unpacked, - index_info, - collations, - 0, - tie_breaker, - ); + return compare_records_generic(serialized, unpacked, index_info, 0, tie_breaker); }; - let comparison = if let Some(collation) = collations.first() { - collation.compare_strings(lhs_text.as_str(), rhs_text.as_str()) - } else { - // No collation case - lhs_text.value.to_slice().cmp(rhs_text.value.to_slice()) - }; + let collation = index_info.key_info[0].collation; + let comparison = collation.compare_strings(lhs_text.as_str(), rhs_text.as_str()); - let final_comparison = match index_info.sort_order.get_sort_order_for_col(0) { + let final_comparison = match index_info.key_info[0].sort_order { SortOrder::Asc => comparison, SortOrder::Desc => comparison.reverse(), }; @@ -1849,7 +1776,7 @@ fn compare_records_string( std::cmp::Ordering::Equal => { let len_cmp = lhs_text.value.len.cmp(&rhs_text.value.len); if len_cmp != std::cmp::Ordering::Equal { - let adjusted = match index_info.sort_order.get_sort_order_for_col(0) { + let adjusted = match index_info.key_info[0].sort_order { SortOrder::Asc => len_cmp, SortOrder::Desc => len_cmp.reverse(), }; @@ -1857,14 +1784,7 @@ fn compare_records_string( } if unpacked.len() > 1 { - return compare_records_generic( - serialized, - unpacked, - index_info, - collations, - 1, - tie_breaker, - ); + return 
compare_records_generic(serialized, unpacked, index_info, 1, tie_breaker); } Ok(tie_breaker) } @@ -1907,11 +1827,14 @@ fn compare_records_string( pub fn compare_records_generic( serialized: &ImmutableRecord, unpacked: &[RefValue], - index_info: &IndexKeyInfo, - collations: &[CollationSeq], + index_info: &IndexInfo, skip: usize, tie_breaker: std::cmp::Ordering, ) -> Result { + turso_assert!( + index_info.key_info.len() >= unpacked.len(), + "index_info.key_info.len() < unpacked.len()" + ); let payload = serialized.get_payload(); if payload.is_empty() { return Ok(std::cmp::Ordering::Less); @@ -1961,13 +1884,9 @@ pub fn compare_records_generic( }; let comparison = match (&lhs_value, rhs_value) { - (RefValue::Text(lhs_text), RefValue::Text(rhs_text)) => { - if let Some(collation) = collations.get(field_idx) { - collation.compare_strings(lhs_text.as_str(), rhs_text.as_str()) - } else { - lhs_text.value.to_slice().cmp(rhs_text.value.to_slice()) - } - } + (RefValue::Text(lhs_text), RefValue::Text(rhs_text)) => index_info.key_info[field_idx] + .collation + .compare_strings(lhs_text.as_str(), rhs_text.as_str()), (RefValue::Integer(lhs_int), RefValue::Float(rhs_float)) => { sqlite_int_float_compare(*lhs_int, *rhs_float) @@ -1980,7 +1899,7 @@ pub fn compare_records_generic( _ => lhs_value.partial_cmp(rhs_value).unwrap(), }; - let final_comparison = match index_info.sort_order.get_sort_order_for_col(field_idx) { + let final_comparison = match index_info.key_info[field_idx].sort_order { SortOrder::Asc => comparison, SortOrder::Desc => comparison.reverse(), }; @@ -1995,11 +1914,6 @@ pub fn compare_records_generic( Ok(tie_breaker) } -#[inline(always)] -fn is_binary_collation(collations: &[CollationSeq], col_idx: usize) -> bool { - collations[col_idx] == CollationSeq::Binary -} - const I8_LOW: i64 = -128; const I8_HIGH: i64 = 127; const I16_LOW: i64 = -32768; @@ -2428,15 +2342,14 @@ mod tests { pub fn compare_immutable_for_testing( l: &[RefValue], r: &[RefValue], - 
index_key_sort_order: IndexKeySortOrder, - collations: &[CollationSeq], + index_key_info: &[KeyInfo], tie_breaker: std::cmp::Ordering, ) -> std::cmp::Ordering { let min_len = l.len().min(r.len()); for i in 0..min_len { - let column_order = index_key_sort_order.get_sort_order_for_col(i); - let collation = collations.get(i).copied().unwrap_or_default(); + let column_order = index_key_info[i].sort_order; + let collation = index_key_info[i].collation; let cmp = match (&l[i], &r[i]) { (RefValue::Text(left), RefValue::Text(right)) => { @@ -2461,9 +2374,20 @@ mod tests { ImmutableRecord::from_registers(&registers, registers.len()) } - fn create_index_info(num_cols: usize, sort_orders: Vec) -> IndexKeyInfo { - IndexKeyInfo { - sort_order: IndexKeySortOrder::from_list(sort_orders.as_slice()), + fn create_index_info( + num_cols: usize, + sort_orders: Vec, + collations: Vec, + ) -> IndexInfo { + IndexInfo { + key_info: sort_orders + .into_iter() + .zip(collations) + .map(|(sort_order, collation)| KeyInfo { + sort_order, + collation, + }) + .collect(), has_rowid: false, num_cols, } @@ -2503,8 +2427,7 @@ mod tests { fn assert_compare_matches_full_comparison( serialized_values: Vec, unpacked_values: Vec, - index_info: &IndexKeyInfo, - collations: &[CollationSeq], + index_info: &IndexInfo, test_name: &str, ) { let serialized = create_record(serialized_values.clone()); @@ -2517,21 +2440,13 @@ mod tests { let gold_result = compare_immutable_for_testing( &serialized_ref_values, &unpacked_values, - index_info.sort_order, - collations, + &index_info.key_info, tie_breaker, ); - let comparer = find_compare(&unpacked_values, index_info, collations); + let comparer = find_compare(&unpacked_values, index_info); let optimized_result = comparer - .compare( - &serialized, - &unpacked_values, - index_info, - collations, - 0, - tie_breaker, - ) + .compare(&serialized, &unpacked_values, index_info, 0, tie_breaker) .unwrap(); assert_eq!( @@ -2540,15 +2455,9 @@ mod tests { test_name, gold_result, 
optimized_result, comparer ); - let generic_result = compare_records_generic( - &serialized, - &unpacked_values, - index_info, - collations, - 0, - tie_breaker, - ) - .unwrap(); + let generic_result = + compare_records_generic(&serialized, &unpacked_values, index_info, 0, tie_breaker) + .unwrap(); assert_eq!( gold_result, generic_result, "Test '{}' failed with generic: Full Comparison: {:?}, Generic: {:?}", @@ -2616,8 +2525,11 @@ mod tests { #[test] fn test_integer_fast_path() { - let index_info = create_index_info(2, vec![SortOrder::Asc, SortOrder::Asc]); - let collations = vec![CollationSeq::Binary, CollationSeq::Binary]; + let index_info = create_index_info( + 2, + vec![SortOrder::Asc, SortOrder::Asc], + vec![CollationSeq::Binary; 2], + ); let test_cases = vec![ ( @@ -2678,7 +2590,6 @@ mod tests { serialized_values, unpacked_values, &index_info, - &collations, test_name, ); } @@ -2686,8 +2597,11 @@ mod tests { #[test] fn test_string_fast_path() { - let index_info = create_index_info(2, vec![SortOrder::Asc, SortOrder::Asc]); - let collations = vec![CollationSeq::Binary, CollationSeq::Binary]; + let index_info = create_index_info( + 2, + vec![SortOrder::Asc, SortOrder::Asc], + vec![CollationSeq::Binary; 2], + ); let test_cases = vec![ ( @@ -2739,7 +2653,6 @@ mod tests { serialized_values, unpacked_values, &index_info, - &collations, test_name, ); } @@ -2747,8 +2660,7 @@ mod tests { #[test] fn test_type_precedence() { - let index_info = create_index_info(1, vec![SortOrder::Asc]); - let collations = vec![CollationSeq::Binary]; + let index_info = create_index_info(1, vec![SortOrder::Asc], vec![CollationSeq::Binary]); // Test SQLite type precedence: NULL < Numbers < Text < Blob let test_cases = vec![ @@ -2823,7 +2735,6 @@ mod tests { serialized_values, unpacked_values, &index_info, - &collations, test_name, ); } @@ -2831,8 +2742,11 @@ mod tests { #[test] fn test_sort_order_desc() { - let index_info = create_index_info(2, vec![SortOrder::Desc, SortOrder::Asc]); - let 
collations = vec![CollationSeq::Binary, CollationSeq::Binary]; + let index_info = create_index_info( + 2, + vec![SortOrder::Desc, SortOrder::Asc], + vec![CollationSeq::Binary; 2], + ); let test_cases = vec![ // DESC order should reverse first field comparison @@ -2862,7 +2776,6 @@ mod tests { serialized_values, unpacked_values, &index_info, - &collations, test_name, ); } @@ -2870,12 +2783,8 @@ mod tests { #[test] fn test_edge_cases() { - let index_info = create_index_info(3, vec![SortOrder::Asc, SortOrder::Asc, SortOrder::Asc]); - let collations = vec![ - CollationSeq::Binary, - CollationSeq::Binary, - CollationSeq::Binary, - ]; + let index_info = + create_index_info(15, vec![SortOrder::Asc; 15], vec![CollationSeq::Binary; 15]); let test_cases = vec![ ( @@ -2923,7 +2832,6 @@ mod tests { serialized_values, unpacked_values, &index_info, - &collations, test_name, ); } @@ -2931,12 +2839,11 @@ mod tests { #[test] fn test_skip_parameter() { - let index_info = create_index_info(3, vec![SortOrder::Asc, SortOrder::Asc, SortOrder::Asc]); - let collations = vec![ - CollationSeq::Binary, - CollationSeq::Binary, - CollationSeq::Binary, - ]; + let index_info = create_index_info( + 3, + vec![SortOrder::Asc, SortOrder::Asc, SortOrder::Asc], + vec![CollationSeq::Binary; 3], + ); let serialized = create_record(vec![ Value::Integer(1), @@ -2950,24 +2857,10 @@ mod tests { ]; let tie_breaker = std::cmp::Ordering::Equal; - let result_skip_0 = compare_records_generic( - &serialized, - &unpacked, - &index_info, - &collations, - 0, - tie_breaker, - ) - .unwrap(); - let result_skip_1 = compare_records_generic( - &serialized, - &unpacked, - &index_info, - &collations, - 1, - tie_breaker, - ) - .unwrap(); + let result_skip_0 = + compare_records_generic(&serialized, &unpacked, &index_info, 0, tie_breaker).unwrap(); + let result_skip_1 = + compare_records_generic(&serialized, &unpacked, &index_info, 1, tie_breaker).unwrap(); assert_eq!(result_skip_0, std::cmp::Ordering::Less); @@ -2976,17 
+2869,21 @@ mod tests { #[test] fn test_strategy_selection() { - let index_info_small = - create_index_info(3, vec![SortOrder::Asc, SortOrder::Asc, SortOrder::Asc]); - let index_info_large = create_index_info(15, vec![SortOrder::Asc; 15]); - let collations = vec![CollationSeq::Binary, CollationSeq::Binary]; + let collations_small = vec![CollationSeq::Binary; 3]; + let collations_large = vec![CollationSeq::Binary; 15]; + let index_info_small = create_index_info( + 3, + vec![SortOrder::Asc, SortOrder::Asc, SortOrder::Asc], + collations_small, + ); + let index_info_large = create_index_info(15, vec![SortOrder::Asc; 15], collations_large); let int_values = vec![ RefValue::Integer(42), RefValue::Text(TextRef::from_str("hello")), ]; assert!(matches!( - find_compare(&int_values, &index_info_small, &collations), + find_compare(&int_values, &index_info_small), RecordCompare::Int )); @@ -2995,19 +2892,19 @@ mod tests { RefValue::Integer(42), ]; assert!(matches!( - find_compare(&string_values, &index_info_small, &collations), + find_compare(&string_values, &index_info_small), RecordCompare::String )); let large_values: Vec = (0..15).map(RefValue::Integer).collect(); assert!(matches!( - find_compare(&large_values, &index_info_large, &collations), + find_compare(&large_values, &index_info_large), RecordCompare::Generic )); let blob_values = vec![RefValue::Blob(RawSlice::from_slice(&[1, 2, 3]))]; assert!(matches!( - find_compare(&blob_values, &index_info_small, &collations), + find_compare(&blob_values, &index_info_small), RecordCompare::Generic )); } diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index 1ec6d9f0f..3d1b95049 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -907,31 +907,11 @@ pub fn op_open_read( .replace(Cursor::new_btree(cursor)); } CursorType::BTreeIndex(index) => { - let conn = program.connection.clone(); - let schema = conn.schema.borrow(); - let table = schema - .get_table(&index.table_name) - .and_then(|table| table.btree()); - let 
collations = table.map_or(Vec::new(), |table| { - index - .columns - .iter() - .map(|c| { - table - .columns - .get(c.pos_in_table) - .unwrap() - .collation - .unwrap_or_default() - }) - .collect() - }); let cursor = BTreeCursor::new_index( mv_cursor, pager.clone(), *root_page, index.as_ref(), - collations, num_columns, ); cursors @@ -2824,10 +2804,9 @@ pub fn op_idx_ge( registers_to_ref_values(&state.registers[*start_reg..*start_reg + *num_regs]); let tie_breaker = get_tie_breaker_from_idx_comp_op(insn); let ord = compare_records_generic( - &idx_record, // The serialized record from the index - &values, // The record built from registers - &cursor.index_key_info.unwrap(), // Sort order flags - cursor.collations(), // Collation sequences + &idx_record, // The serialized record from the index + &values, // The record built from registers + cursor.index_info.as_ref().unwrap(), // Sort order flags 0, tie_breaker, )?; @@ -2896,8 +2875,7 @@ pub fn op_idx_le( let ord = compare_records_generic( &idx_record, &values, - &cursor.index_key_info.unwrap(), - cursor.collations(), + cursor.index_info.as_ref().unwrap(), 0, tie_breaker, )?; @@ -2948,8 +2926,7 @@ pub fn op_idx_gt( let ord = compare_records_generic( &idx_record, &values, - &cursor.index_key_info.unwrap(), - cursor.collations(), + cursor.index_info.as_ref().unwrap(), 0, tie_breaker, )?; @@ -3001,8 +2978,7 @@ pub fn op_idx_lt( let ord = compare_records_generic( &idx_record, &values, - &cursor.index_key_info.unwrap(), - cursor.collations(), + cursor.index_info.as_ref().unwrap(), 0, tie_breaker, )?; @@ -5152,8 +5128,7 @@ pub fn op_idx_insert( let conflict = compare_immutable( existing_key.as_slice(), inserted_key_vals, - cursor.key_sort_order(), - &cursor.collations, + &cursor.index_info.as_ref().unwrap().key_info, ) == std::cmp::Ordering::Equal; if conflict { if flags.has(IdxInsertFlags::NO_OP_DUPLICATE) { @@ -5586,27 +5561,11 @@ pub fn op_open_write( .and_then(|table| table.btree()); let num_columns = 
index.columns.len(); - let collations = table.map_or(Vec::new(), |table| { - index - .columns - .iter() - .map(|c| { - table - .columns - .get(c.pos_in_table) - .unwrap() - .collation - .unwrap_or_default() - }) - .collect() - }); - let cursor = BTreeCursor::new_index( mv_cursor, pager.clone(), root_page as usize, index.as_ref(), - collations, num_columns, ); cursors @@ -5695,7 +5654,7 @@ pub fn op_destroy( todo!("temp databases not implemented yet."); } // TODO not sure if should be BTreeCursor::new_table or BTreeCursor::new_index here or neither and just pass an emtpy vec - let mut cursor = BTreeCursor::new(None, pager.clone(), *root, Vec::new(), 0); + let mut cursor = BTreeCursor::new(None, pager.clone(), *root, 0); let former_root_page_result = cursor.btree_destroy()?; if let IOResult::Done(former_root_page) = former_root_page_result { state.registers[*former_root_reg] = @@ -6151,11 +6110,6 @@ pub fn op_open_ephemeral( pager.clone(), root_page as usize, index, - index - .columns - .iter() - .map(|c| c.collation.unwrap_or_default()) - .collect(), num_columns, ) } else { diff --git a/core/vdbe/sorter.rs b/core/vdbe/sorter.rs index ae6e71e97..b2b245d88 100644 --- a/core/vdbe/sorter.rs +++ b/core/vdbe/sorter.rs @@ -2,25 +2,31 @@ use turso_sqlite3_parser::ast::SortOrder; use crate::{ translate::collate::CollationSeq, - types::{compare_immutable, ImmutableRecord, IndexKeySortOrder}, + types::{compare_immutable, ImmutableRecord, KeyInfo}, }; pub struct Sorter { records: Vec, current: Option, - order: IndexKeySortOrder, key_len: usize, - collations: Vec, + index_key_info: Vec, } impl Sorter { pub fn new(order: &[SortOrder], collations: Vec) -> Self { + assert_eq!(order.len(), collations.len()); Self { records: Vec::new(), current: None, key_len: order.len(), - order: IndexKeySortOrder::from_list(order), - collations, + index_key_info: order + .iter() + .zip(collations) + .map(|(order, collation)| KeyInfo { + sort_order: *order, + collation, + }) + .collect(), } } pub 
fn is_empty(&self) -> bool { @@ -49,7 +55,7 @@ impl Sorter { &b_values[..] }; - compare_immutable(a_key, b_key, self.order, &self.collations) + compare_immutable(a_key, b_key, &self.index_key_info) }); self.records.reverse(); self.next()