Merge 'refactor/btree&vdbe: fold index key info (sort order, collations) into a single struct' from Jussi Saurio

These are nearly always used together in some form, so it makes sense to
colocate them, and it also makes many code paths simpler, as we don't
separately pass `collations` and `key_sort_order` around
As a side effect, as the bitfield-based `IndexKeySortOrder` is removed,
we now remove the arbitrary 64 column restriction for indexes, see e.g.
this sim failure which fails to 64+ index columns (not sure why it uses
an index if they are disabled):
https://github.com/tursodatabase/turso/actions/runs/16339391964/job/4615
8045158

Closes #2131
This commit is contained in:
Jussi Saurio
2025-07-17 11:55:56 +03:00
4 changed files with 226 additions and 356 deletions

View File

@@ -11,11 +11,11 @@ use crate::{
LEAF_PAGE_HEADER_SIZE_BYTES, LEFT_CHILD_PTR_SIZE_BYTES,
},
},
translate::{collate::CollationSeq, plan::IterationDirection},
translate::plan::IterationDirection,
turso_assert,
types::{
find_compare, get_tie_breaker_from_seek_op, IndexKeyInfo, IndexKeySortOrder,
ParseRecordState, RecordCompare, RecordCursor, SeekResult,
find_compare, get_tie_breaker_from_seek_op, IndexInfo, ParseRecordState, RecordCompare,
RecordCursor, SeekResult,
},
MvCursor,
};
@@ -507,17 +507,14 @@ pub struct BTreeCursor {
reusable_immutable_record: RefCell<Option<ImmutableRecord>>,
/// Reusable immutable record, used to allow better allocation strategy.
parse_record_state: RefCell<ParseRecordState>,
pub index_key_info: Option<IndexKeyInfo>,
/// Information about the index key structure (sort order, collation, etc)
pub index_info: Option<IndexInfo>,
/// Maintain count of the number of records in the btree. Used for the `Count` opcode
count: usize,
/// Stores the cursor context before rebalancing so that a seek can be done later
context: Option<CursorContext>,
/// Store whether the Cursor is in a valid state. Meaning if it is pointing to a valid cell index or not
pub valid_state: CursorValidState,
/// Colations for Index Btree constraint checks
/// Contains the Collation Seq for the whole Index
/// This Vec should be empty for Table Btree
pub collations: Vec<CollationSeq>,
seek_state: CursorSeekState,
/// Separate state to read a record with overflow pages. This separation from `state` is necessary as
/// we can be in a function that relies on `state`, but also needs to process overflow pages
@@ -568,7 +565,6 @@ impl BTreeCursor {
mv_cursor: Option<Rc<RefCell<MvCursor>>>,
pager: Rc<Pager>,
root_page: usize,
collations: Vec<CollationSeq>,
num_columns: usize,
) -> Self {
Self {
@@ -586,11 +582,10 @@ impl BTreeCursor {
stack: RefCell::new([const { None }; BTCURSOR_MAX_DEPTH + 1]),
},
reusable_immutable_record: RefCell::new(None),
index_key_info: None,
index_info: None,
count: 0,
context: None,
valid_state: CursorValidState::Valid,
collations,
seek_state: CursorSeekState::Start,
read_overflow_state: RefCell::new(None),
find_cell_state: FindCellState(None),
@@ -605,7 +600,7 @@ impl BTreeCursor {
root_page: usize,
num_columns: usize,
) -> Self {
Self::new(mv_cursor, pager, root_page, Vec::new(), num_columns)
Self::new(mv_cursor, pager, root_page, num_columns)
}
pub fn new_index(
@@ -613,23 +608,15 @@ impl BTreeCursor {
pager: Rc<Pager>,
root_page: usize,
index: &Index,
collations: Vec<CollationSeq>,
num_columns: usize,
) -> Self {
let mut cursor = Self::new(mv_cursor, pager, root_page, collations, num_columns);
cursor.index_key_info = Some(IndexKeyInfo::new_from_index(index));
let mut cursor = Self::new(mv_cursor, pager, root_page, num_columns);
cursor.index_info = Some(IndexInfo::new_from_index(index));
cursor
}
pub fn key_sort_order(&self) -> IndexKeySortOrder {
match &self.index_key_info {
Some(index_key_info) => index_key_info.sort_order,
None => IndexKeySortOrder::default(),
}
}
pub fn has_rowid(&self) -> bool {
match &self.index_key_info {
match &self.index_info {
Some(index_key_info) => index_key_info.has_rowid,
None => true, // currently we don't support WITHOUT ROWID tables
}
@@ -1493,9 +1480,13 @@ impl BTreeCursor {
let iter_dir = cmp.iteration_direction();
let key_values = index_key.get_values();
let index_info_default = IndexKeyInfo::default();
let index_info = *self.index_key_info.as_ref().unwrap_or(&index_info_default);
let record_comparer = find_compare(&key_values, &index_info, &self.collations);
let record_comparer = {
let index_info = self
.index_info
.as_ref()
.expect("indexbtree_move_to without index_info");
find_compare(&key_values, index_info)
};
tracing::debug!("Using record comparison strategy: {:?}", record_comparer);
let tie_breaker = get_tie_breaker_from_seek_op(cmp);
@@ -1639,8 +1630,9 @@ impl BTreeCursor {
.compare(
record,
&key_values,
&index_info,
&self.collations,
self.index_info
.as_ref()
.expect("indexbtree_move_to without index_info"),
0,
tie_breaker,
)
@@ -1848,9 +1840,13 @@ impl BTreeCursor {
seek_op: SeekOp,
) -> Result<IOResult<SeekResult>> {
let key_values = key.get_values();
let index_info_default = IndexKeyInfo::default();
let index_info = *self.index_key_info.as_ref().unwrap_or(&index_info_default);
let record_comparer = find_compare(&key_values, &index_info, &self.collations);
let record_comparer = {
let index_info = self
.index_info
.as_ref()
.expect("indexbtree_seek without index_info");
find_compare(&key_values, index_info)
};
tracing::debug!(
"Using record comparison strategy for seek: {:?}",
@@ -1972,7 +1968,9 @@ impl BTreeCursor {
key_values.as_slice(),
seek_op,
&record_comparer,
&index_info,
self.index_info
.as_ref()
.expect("indexbtree_seek without index_info"),
);
if found {
nearest_matching_cell.set(Some(cur_cell_idx as usize));
@@ -2006,21 +2004,14 @@ impl BTreeCursor {
key_values: &[RefValue],
seek_op: SeekOp,
record_comparer: &RecordCompare,
index_info: &IndexKeyInfo,
index_info: &IndexInfo,
) -> (Ordering, bool) {
let record = self.get_immutable_record();
let record = record.as_ref().unwrap();
let tie_breaker = get_tie_breaker_from_seek_op(seek_op);
let cmp = record_comparer
.compare(
record,
key_values,
index_info,
&self.collations,
0,
tie_breaker,
)
.compare(record, key_values, index_info, 0, tie_breaker)
.unwrap();
let found = match seek_op {
@@ -2189,8 +2180,7 @@ impl BTreeCursor {
.as_ref()
.unwrap()
.get_values().as_slice(),
self.key_sort_order(),
&self.collations,
&self.index_info.as_ref().unwrap().key_info,
);
if cmp == Ordering::Equal {
tracing::debug!("IndexLeafCell: found exact match with cell_idx={cell_idx}, overwriting");
@@ -3925,8 +3915,11 @@ impl BTreeCursor {
compare_immutable(
key_values.as_slice(),
record_same_number_cols,
self.key_sort_order(),
&self.collations,
self.index_info
.as_ref()
.expect("indexbtree_find_cell without index_info")
.key_info
.as_slice(),
)
}
};
@@ -5098,10 +5091,6 @@ impl BTreeCursor {
}
}
pub fn collations(&self) -> &[CollationSeq] {
&self.collations
}
pub fn read_page(&self, page_idx: usize) -> Result<BTreePage> {
btree_read_page(&self.pager, page_idx)
}
@@ -6527,10 +6516,12 @@ mod tests {
};
use sorted_vec::SortedVec;
use test_log::test;
use turso_sqlite3_parser::ast::SortOrder;
use super::*;
use crate::{
io::{Buffer, Completion, CompletionType, MemoryIO, OpenFlags, IO},
schema::IndexColumn,
storage::{database::DatabaseFile, page_cache::DumbLruPageCache},
types::Text,
vdbe::Register,
@@ -7097,7 +7088,6 @@ mod tests {
fn btree_index_insert_fuzz_run(attempts: usize, inserts: usize) {
use crate::storage::pager::CreateBTreeFlags;
let num_columns = 5;
let (mut rng, seed) = if std::env::var("SEED").is_ok() {
let seed = std::env::var("SEED").unwrap();
@@ -7119,8 +7109,31 @@ mod tests {
panic!("btree_create returned IO in test, unexpected")
}
};
let mut cursor =
BTreeCursor::new_table(None, pager.clone(), index_root_page, num_columns);
let index_def = Index {
name: "testindex".to_string(),
columns: (0..10)
.map(|i| IndexColumn {
name: format!("test{}", i),
order: SortOrder::Asc,
collation: None,
pos_in_table: i,
default: None,
})
.collect(),
table_name: "test".to_string(),
root_page: index_root_page,
unique: false,
ephemeral: false,
has_rowid: false,
};
let num_columns = index_def.columns.len();
let mut cursor = BTreeCursor::new_index(
None,
pager.clone(),
index_root_page,
&index_def,
num_columns,
);
let mut keys = SortedVec::new();
tracing::info!("seed: {seed}");
for i in 0..inserts {
@@ -7129,7 +7142,7 @@ mod tests {
let key = {
let result;
loop {
let cols = (0..10)
let cols = (0..num_columns)
.map(|_| (rng.next_u64() % (1 << 30)) as i64)
.collect::<Vec<_>>();
if seen.contains(&cols) {
@@ -8410,7 +8423,7 @@ mod tests {
pub fn test_read_write_payload_with_offset() {
let (pager, root_page, _, _) = empty_btree();
let num_columns = 5;
let mut cursor = BTreeCursor::new(None, pager.clone(), root_page, vec![], num_columns);
let mut cursor = BTreeCursor::new(None, pager.clone(), root_page, num_columns);
let offset = 2; // blobs data starts at offset 2
let initial_text = "hello world";
let initial_blob = initial_text.as_bytes().to_vec();
@@ -8487,7 +8500,7 @@ mod tests {
pub fn test_read_write_payload_with_overflow_page() {
let (pager, root_page, _, _) = empty_btree();
let num_columns = 5;
let mut cursor = BTreeCursor::new(None, pager.clone(), root_page, vec![], num_columns);
let mut cursor = BTreeCursor::new(None, pager.clone(), root_page, num_columns);
let mut large_blob = vec![b'A'; 40960 - 11]; // insert large blob. 40960 = 10 page long.
let hello_world = b"hello world";
large_blob.extend_from_slice(hello_world);

View File

@@ -14,7 +14,7 @@ use crate::translate::plan::IterationDirection;
use crate::vdbe::sorter::Sorter;
use crate::vdbe::Register;
use crate::vtab::VirtualTableCursor;
use crate::Result;
use crate::{turso_assert, Result};
use std::fmt::{Debug, Display};
const MAX_REAL_SIZE: u8 = 15;
@@ -1441,68 +1441,57 @@ fn sqlite_int_float_compare(int_val: i64, float_val: f64) -> std::cmp::Ordering
}
}
/// A bitfield that represents the comparison spec for index keys.
/// Since indexed columns can individually specify ASC/DESC, each key must
/// be compared differently.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[repr(transparent)]
pub struct IndexKeySortOrder(pub u64);
impl IndexKeySortOrder {
pub fn get_sort_order_for_col(&self, column_idx: usize) -> SortOrder {
assert!(column_idx < 64, "column index out of range: {column_idx}");
match self.0 & (1 << column_idx) {
0 => SortOrder::Asc,
_ => SortOrder::Desc,
}
}
pub fn from_index(index: &Index) -> Self {
let mut spec = 0;
for (i, column) in index.columns.iter().enumerate() {
spec |= ((column.order == SortOrder::Desc) as u64) << i;
}
IndexKeySortOrder(spec)
}
pub fn from_list(order: &[SortOrder]) -> Self {
let mut spec = 0;
for (i, order) in order.iter().enumerate() {
spec |= ((*order == SortOrder::Desc) as u64) << i;
}
IndexKeySortOrder(spec)
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct KeyInfo {
pub sort_order: SortOrder,
pub collation: CollationSeq,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[derive(Debug, Clone, PartialEq, Eq)]
/// Metadata about an index, used for handling and comparing index keys.
///
/// This struct provides information about the sorting order of columns,
/// whether the index includes a row ID, and the total number of columns
/// in the index.
pub struct IndexKeyInfo {
pub struct IndexInfo {
/// Specifies the sorting order (ascending or descending) for each column in the index.
pub sort_order: IndexKeySortOrder,
pub key_info: Vec<KeyInfo>,
/// Indicates whether the index includes a row ID column.
pub has_rowid: bool,
/// The total number of columns in the index, including the row ID column if present.
pub num_cols: usize,
}
impl Default for IndexKeyInfo {
impl Default for IndexInfo {
fn default() -> Self {
Self {
sort_order: IndexKeySortOrder::default(),
key_info: vec![],
has_rowid: true,
num_cols: 1,
}
}
}
impl IndexKeyInfo {
impl IndexInfo {
pub fn new_from_index(index: &Index) -> Self {
Self {
sort_order: IndexKeySortOrder::from_index(index),
key_info: {
let mut key_info: Vec<KeyInfo> = index
.columns
.iter()
.map(|c| KeyInfo {
sort_order: c.order,
collation: c.collation.unwrap_or_default(),
})
.collect();
if index.has_rowid {
key_info.push(KeyInfo {
sort_order: SortOrder::Asc,
collation: CollationSeq::Binary,
});
}
key_info
},
has_rowid: index.has_rowid,
num_cols: index.columns.len() + (index.has_rowid as usize),
}
@@ -1512,13 +1501,13 @@ impl IndexKeyInfo {
pub fn compare_immutable(
l: &[RefValue],
r: &[RefValue],
index_key_sort_order: IndexKeySortOrder,
collations: &[CollationSeq],
column_info: &[KeyInfo],
) -> std::cmp::Ordering {
assert_eq!(l.len(), r.len());
turso_assert!(column_info.len() >= l.len(), "column_info.len() < l.len()");
for (i, (l, r)) in l.iter().zip(r).enumerate() {
let column_order = index_key_sort_order.get_sort_order_for_col(i);
let collation = collations.get(i).copied().unwrap_or_default();
let column_order = column_info[i].sort_order;
let collation = column_info[i].collation;
let cmp = match (l, r) {
(RefValue::Text(left), RefValue::Text(right)) => {
collation.compare_strings(left.as_str(), right.as_str())
@@ -1547,39 +1536,31 @@ impl RecordCompare {
&self,
serialized: &ImmutableRecord,
unpacked: &[RefValue],
index_info: &IndexKeyInfo,
collations: &[CollationSeq],
index_info: &IndexInfo,
skip: usize,
tie_breaker: std::cmp::Ordering,
) -> Result<std::cmp::Ordering> {
match self {
RecordCompare::Int => {
compare_records_int(serialized, unpacked, index_info, collations, tie_breaker)
compare_records_int(serialized, unpacked, index_info, tie_breaker)
}
RecordCompare::String => {
compare_records_string(serialized, unpacked, index_info, collations, tie_breaker)
compare_records_string(serialized, unpacked, index_info, tie_breaker)
}
RecordCompare::Generic => {
compare_records_generic(serialized, unpacked, index_info, skip, tie_breaker)
}
RecordCompare::Generic => compare_records_generic(
serialized,
unpacked,
index_info,
collations,
skip,
tie_breaker,
),
}
}
}
pub fn find_compare(
unpacked: &[RefValue],
index_info: &IndexKeyInfo,
collations: &[CollationSeq],
) -> RecordCompare {
pub fn find_compare(unpacked: &[RefValue], index_info: &IndexInfo) -> RecordCompare {
if !unpacked.is_empty() && index_info.num_cols <= 13 {
match &unpacked[0] {
RefValue::Integer(_) => RecordCompare::Int,
RefValue::Text(_) if is_binary_collation(collations, 0) => RecordCompare::String,
RefValue::Text(_) if index_info.key_info[0].collation == CollationSeq::Binary => {
RecordCompare::String
}
_ => RecordCompare::Generic,
}
} else {
@@ -1644,20 +1625,16 @@ pub fn get_tie_breaker_from_seek_op(seek_op: SeekOp) -> std::cmp::Ordering {
fn compare_records_int(
serialized: &ImmutableRecord,
unpacked: &[RefValue],
index_info: &IndexKeyInfo,
collations: &[CollationSeq],
index_info: &IndexInfo,
tie_breaker: std::cmp::Ordering,
) -> Result<std::cmp::Ordering> {
turso_assert!(
index_info.key_info.len() >= unpacked.len(),
"index_info.key_info.len() < unpacked.len()"
);
let payload = serialized.get_payload();
if payload.len() < 2 {
return compare_records_generic(
serialized,
unpacked,
index_info,
collations,
0,
tie_breaker,
);
return compare_records_generic(serialized, unpacked, index_info, 0, tie_breaker);
}
let (header_size, offset_1st_serialtype) = read_varint(payload)?;
@@ -1675,30 +1652,16 @@ fn compare_records_int(
let serialtype_is_integer = matches!(first_serial_type, 1..=6 | 8 | 9);
if !serialtype_is_integer {
return compare_records_generic(
serialized,
unpacked,
index_info,
collations,
0,
tie_breaker,
);
return compare_records_generic(serialized, unpacked, index_info, 0, tie_breaker);
}
let data_start = header_size;
let lhs_int = read_integer(&payload[data_start..], first_serial_type as u8)?;
let RefValue::Integer(rhs_int) = unpacked[0] else {
return compare_records_generic(
serialized,
unpacked,
index_info,
collations,
0,
tie_breaker,
);
return compare_records_generic(serialized, unpacked, index_info, 0, tie_breaker);
};
let comparison = match index_info.sort_order.get_sort_order_for_col(0) {
let comparison = match index_info.key_info[0].sort_order {
SortOrder::Asc => lhs_int.cmp(&rhs_int),
SortOrder::Desc => lhs_int.cmp(&rhs_int).reverse(),
};
@@ -1706,14 +1669,7 @@ fn compare_records_int(
std::cmp::Ordering::Equal => {
// First fields equal, compare remaining fields if any
if unpacked.len() > 1 {
return compare_records_generic(
serialized,
unpacked,
index_info,
collations,
1,
tie_breaker,
);
return compare_records_generic(serialized, unpacked, index_info, 1, tie_breaker);
}
Ok(tie_breaker)
}
@@ -1762,20 +1718,16 @@ fn compare_records_int(
fn compare_records_string(
serialized: &ImmutableRecord,
unpacked: &[RefValue],
index_info: &IndexKeyInfo,
collations: &[CollationSeq],
index_info: &IndexInfo,
tie_breaker: std::cmp::Ordering,
) -> Result<std::cmp::Ordering> {
turso_assert!(
index_info.key_info.len() >= unpacked.len(),
"index_info.key_info.len() < unpacked.len()"
);
let payload = serialized.get_payload();
if payload.len() < 2 {
return compare_records_generic(
serialized,
unpacked,
index_info,
collations,
0,
tie_breaker,
);
return compare_records_generic(serialized, unpacked, index_info, 0, tie_breaker);
}
let (header_size, offset_1st_serialtype) = read_varint(payload)?;
@@ -1793,25 +1745,11 @@ fn compare_records_string(
let serialtype_is_string = first_serial_type >= 13 && (first_serial_type & 1) == 1;
if !serialtype_is_string {
return compare_records_generic(
serialized,
unpacked,
index_info,
collations,
0,
tie_breaker,
);
return compare_records_generic(serialized, unpacked, index_info, 0, tie_breaker);
}
let RefValue::Text(rhs_text) = &unpacked[0] else {
return compare_records_generic(
serialized,
unpacked,
index_info,
collations,
0,
tie_breaker,
);
return compare_records_generic(serialized, unpacked, index_info, 0, tie_breaker);
};
let string_len = (first_serial_type as usize - 13) / 2;
@@ -1823,24 +1761,13 @@ fn compare_records_string(
let (lhs_value, _) = read_value(&payload[data_start..], serial_type)?;
let RefValue::Text(lhs_text) = lhs_value else {
return compare_records_generic(
serialized,
unpacked,
index_info,
collations,
0,
tie_breaker,
);
return compare_records_generic(serialized, unpacked, index_info, 0, tie_breaker);
};
let comparison = if let Some(collation) = collations.first() {
collation.compare_strings(lhs_text.as_str(), rhs_text.as_str())
} else {
// No collation case
lhs_text.value.to_slice().cmp(rhs_text.value.to_slice())
};
let collation = index_info.key_info[0].collation;
let comparison = collation.compare_strings(lhs_text.as_str(), rhs_text.as_str());
let final_comparison = match index_info.sort_order.get_sort_order_for_col(0) {
let final_comparison = match index_info.key_info[0].sort_order {
SortOrder::Asc => comparison,
SortOrder::Desc => comparison.reverse(),
};
@@ -1849,7 +1776,7 @@ fn compare_records_string(
std::cmp::Ordering::Equal => {
let len_cmp = lhs_text.value.len.cmp(&rhs_text.value.len);
if len_cmp != std::cmp::Ordering::Equal {
let adjusted = match index_info.sort_order.get_sort_order_for_col(0) {
let adjusted = match index_info.key_info[0].sort_order {
SortOrder::Asc => len_cmp,
SortOrder::Desc => len_cmp.reverse(),
};
@@ -1857,14 +1784,7 @@ fn compare_records_string(
}
if unpacked.len() > 1 {
return compare_records_generic(
serialized,
unpacked,
index_info,
collations,
1,
tie_breaker,
);
return compare_records_generic(serialized, unpacked, index_info, 1, tie_breaker);
}
Ok(tie_breaker)
}
@@ -1907,11 +1827,14 @@ fn compare_records_string(
pub fn compare_records_generic(
serialized: &ImmutableRecord,
unpacked: &[RefValue],
index_info: &IndexKeyInfo,
collations: &[CollationSeq],
index_info: &IndexInfo,
skip: usize,
tie_breaker: std::cmp::Ordering,
) -> Result<std::cmp::Ordering> {
turso_assert!(
index_info.key_info.len() >= unpacked.len(),
"index_info.key_info.len() < unpacked.len()"
);
let payload = serialized.get_payload();
if payload.is_empty() {
return Ok(std::cmp::Ordering::Less);
@@ -1961,13 +1884,9 @@ pub fn compare_records_generic(
};
let comparison = match (&lhs_value, rhs_value) {
(RefValue::Text(lhs_text), RefValue::Text(rhs_text)) => {
if let Some(collation) = collations.get(field_idx) {
collation.compare_strings(lhs_text.as_str(), rhs_text.as_str())
} else {
lhs_text.value.to_slice().cmp(rhs_text.value.to_slice())
}
}
(RefValue::Text(lhs_text), RefValue::Text(rhs_text)) => index_info.key_info[field_idx]
.collation
.compare_strings(lhs_text.as_str(), rhs_text.as_str()),
(RefValue::Integer(lhs_int), RefValue::Float(rhs_float)) => {
sqlite_int_float_compare(*lhs_int, *rhs_float)
@@ -1980,7 +1899,7 @@ pub fn compare_records_generic(
_ => lhs_value.partial_cmp(rhs_value).unwrap(),
};
let final_comparison = match index_info.sort_order.get_sort_order_for_col(field_idx) {
let final_comparison = match index_info.key_info[field_idx].sort_order {
SortOrder::Asc => comparison,
SortOrder::Desc => comparison.reverse(),
};
@@ -1995,11 +1914,6 @@ pub fn compare_records_generic(
Ok(tie_breaker)
}
#[inline(always)]
fn is_binary_collation(collations: &[CollationSeq], col_idx: usize) -> bool {
collations[col_idx] == CollationSeq::Binary
}
const I8_LOW: i64 = -128;
const I8_HIGH: i64 = 127;
const I16_LOW: i64 = -32768;
@@ -2428,15 +2342,14 @@ mod tests {
pub fn compare_immutable_for_testing(
l: &[RefValue],
r: &[RefValue],
index_key_sort_order: IndexKeySortOrder,
collations: &[CollationSeq],
index_key_info: &[KeyInfo],
tie_breaker: std::cmp::Ordering,
) -> std::cmp::Ordering {
let min_len = l.len().min(r.len());
for i in 0..min_len {
let column_order = index_key_sort_order.get_sort_order_for_col(i);
let collation = collations.get(i).copied().unwrap_or_default();
let column_order = index_key_info[i].sort_order;
let collation = index_key_info[i].collation;
let cmp = match (&l[i], &r[i]) {
(RefValue::Text(left), RefValue::Text(right)) => {
@@ -2461,9 +2374,20 @@ mod tests {
ImmutableRecord::from_registers(&registers, registers.len())
}
fn create_index_info(num_cols: usize, sort_orders: Vec<SortOrder>) -> IndexKeyInfo {
IndexKeyInfo {
sort_order: IndexKeySortOrder::from_list(sort_orders.as_slice()),
fn create_index_info(
num_cols: usize,
sort_orders: Vec<SortOrder>,
collations: Vec<CollationSeq>,
) -> IndexInfo {
IndexInfo {
key_info: sort_orders
.into_iter()
.zip(collations)
.map(|(sort_order, collation)| KeyInfo {
sort_order,
collation,
})
.collect(),
has_rowid: false,
num_cols,
}
@@ -2503,8 +2427,7 @@ mod tests {
fn assert_compare_matches_full_comparison(
serialized_values: Vec<Value>,
unpacked_values: Vec<RefValue>,
index_info: &IndexKeyInfo,
collations: &[CollationSeq],
index_info: &IndexInfo,
test_name: &str,
) {
let serialized = create_record(serialized_values.clone());
@@ -2517,21 +2440,13 @@ mod tests {
let gold_result = compare_immutable_for_testing(
&serialized_ref_values,
&unpacked_values,
index_info.sort_order,
collations,
&index_info.key_info,
tie_breaker,
);
let comparer = find_compare(&unpacked_values, index_info, collations);
let comparer = find_compare(&unpacked_values, index_info);
let optimized_result = comparer
.compare(
&serialized,
&unpacked_values,
index_info,
collations,
0,
tie_breaker,
)
.compare(&serialized, &unpacked_values, index_info, 0, tie_breaker)
.unwrap();
assert_eq!(
@@ -2540,15 +2455,9 @@ mod tests {
test_name, gold_result, optimized_result, comparer
);
let generic_result = compare_records_generic(
&serialized,
&unpacked_values,
index_info,
collations,
0,
tie_breaker,
)
.unwrap();
let generic_result =
compare_records_generic(&serialized, &unpacked_values, index_info, 0, tie_breaker)
.unwrap();
assert_eq!(
gold_result, generic_result,
"Test '{}' failed with generic: Full Comparison: {:?}, Generic: {:?}",
@@ -2616,8 +2525,11 @@ mod tests {
#[test]
fn test_integer_fast_path() {
let index_info = create_index_info(2, vec![SortOrder::Asc, SortOrder::Asc]);
let collations = vec![CollationSeq::Binary, CollationSeq::Binary];
let index_info = create_index_info(
2,
vec![SortOrder::Asc, SortOrder::Asc],
vec![CollationSeq::Binary; 2],
);
let test_cases = vec![
(
@@ -2678,7 +2590,6 @@ mod tests {
serialized_values,
unpacked_values,
&index_info,
&collations,
test_name,
);
}
@@ -2686,8 +2597,11 @@ mod tests {
#[test]
fn test_string_fast_path() {
let index_info = create_index_info(2, vec![SortOrder::Asc, SortOrder::Asc]);
let collations = vec![CollationSeq::Binary, CollationSeq::Binary];
let index_info = create_index_info(
2,
vec![SortOrder::Asc, SortOrder::Asc],
vec![CollationSeq::Binary; 2],
);
let test_cases = vec![
(
@@ -2739,7 +2653,6 @@ mod tests {
serialized_values,
unpacked_values,
&index_info,
&collations,
test_name,
);
}
@@ -2747,8 +2660,7 @@ mod tests {
#[test]
fn test_type_precedence() {
let index_info = create_index_info(1, vec![SortOrder::Asc]);
let collations = vec![CollationSeq::Binary];
let index_info = create_index_info(1, vec![SortOrder::Asc], vec![CollationSeq::Binary]);
// Test SQLite type precedence: NULL < Numbers < Text < Blob
let test_cases = vec![
@@ -2823,7 +2735,6 @@ mod tests {
serialized_values,
unpacked_values,
&index_info,
&collations,
test_name,
);
}
@@ -2831,8 +2742,11 @@ mod tests {
#[test]
fn test_sort_order_desc() {
let index_info = create_index_info(2, vec![SortOrder::Desc, SortOrder::Asc]);
let collations = vec![CollationSeq::Binary, CollationSeq::Binary];
let index_info = create_index_info(
2,
vec![SortOrder::Desc, SortOrder::Asc],
vec![CollationSeq::Binary; 2],
);
let test_cases = vec![
// DESC order should reverse first field comparison
@@ -2862,7 +2776,6 @@ mod tests {
serialized_values,
unpacked_values,
&index_info,
&collations,
test_name,
);
}
@@ -2870,12 +2783,8 @@ mod tests {
#[test]
fn test_edge_cases() {
let index_info = create_index_info(3, vec![SortOrder::Asc, SortOrder::Asc, SortOrder::Asc]);
let collations = vec![
CollationSeq::Binary,
CollationSeq::Binary,
CollationSeq::Binary,
];
let index_info =
create_index_info(15, vec![SortOrder::Asc; 15], vec![CollationSeq::Binary; 15]);
let test_cases = vec![
(
@@ -2923,7 +2832,6 @@ mod tests {
serialized_values,
unpacked_values,
&index_info,
&collations,
test_name,
);
}
@@ -2931,12 +2839,11 @@ mod tests {
#[test]
fn test_skip_parameter() {
let index_info = create_index_info(3, vec![SortOrder::Asc, SortOrder::Asc, SortOrder::Asc]);
let collations = vec![
CollationSeq::Binary,
CollationSeq::Binary,
CollationSeq::Binary,
];
let index_info = create_index_info(
3,
vec![SortOrder::Asc, SortOrder::Asc, SortOrder::Asc],
vec![CollationSeq::Binary; 3],
);
let serialized = create_record(vec![
Value::Integer(1),
@@ -2950,24 +2857,10 @@ mod tests {
];
let tie_breaker = std::cmp::Ordering::Equal;
let result_skip_0 = compare_records_generic(
&serialized,
&unpacked,
&index_info,
&collations,
0,
tie_breaker,
)
.unwrap();
let result_skip_1 = compare_records_generic(
&serialized,
&unpacked,
&index_info,
&collations,
1,
tie_breaker,
)
.unwrap();
let result_skip_0 =
compare_records_generic(&serialized, &unpacked, &index_info, 0, tie_breaker).unwrap();
let result_skip_1 =
compare_records_generic(&serialized, &unpacked, &index_info, 1, tie_breaker).unwrap();
assert_eq!(result_skip_0, std::cmp::Ordering::Less);
@@ -2976,17 +2869,21 @@ mod tests {
#[test]
fn test_strategy_selection() {
let index_info_small =
create_index_info(3, vec![SortOrder::Asc, SortOrder::Asc, SortOrder::Asc]);
let index_info_large = create_index_info(15, vec![SortOrder::Asc; 15]);
let collations = vec![CollationSeq::Binary, CollationSeq::Binary];
let collations_small = vec![CollationSeq::Binary; 3];
let collations_large = vec![CollationSeq::Binary; 15];
let index_info_small = create_index_info(
3,
vec![SortOrder::Asc, SortOrder::Asc, SortOrder::Asc],
collations_small,
);
let index_info_large = create_index_info(15, vec![SortOrder::Asc; 15], collations_large);
let int_values = vec![
RefValue::Integer(42),
RefValue::Text(TextRef::from_str("hello")),
];
assert!(matches!(
find_compare(&int_values, &index_info_small, &collations),
find_compare(&int_values, &index_info_small),
RecordCompare::Int
));
@@ -2995,19 +2892,19 @@ mod tests {
RefValue::Integer(42),
];
assert!(matches!(
find_compare(&string_values, &index_info_small, &collations),
find_compare(&string_values, &index_info_small),
RecordCompare::String
));
let large_values: Vec<RefValue> = (0..15).map(RefValue::Integer).collect();
assert!(matches!(
find_compare(&large_values, &index_info_large, &collations),
find_compare(&large_values, &index_info_large),
RecordCompare::Generic
));
let blob_values = vec![RefValue::Blob(RawSlice::from_slice(&[1, 2, 3]))];
assert!(matches!(
find_compare(&blob_values, &index_info_small, &collations),
find_compare(&blob_values, &index_info_small),
RecordCompare::Generic
));
}

View File

@@ -907,31 +907,11 @@ pub fn op_open_read(
.replace(Cursor::new_btree(cursor));
}
CursorType::BTreeIndex(index) => {
let conn = program.connection.clone();
let schema = conn.schema.borrow();
let table = schema
.get_table(&index.table_name)
.and_then(|table| table.btree());
let collations = table.map_or(Vec::new(), |table| {
index
.columns
.iter()
.map(|c| {
table
.columns
.get(c.pos_in_table)
.unwrap()
.collation
.unwrap_or_default()
})
.collect()
});
let cursor = BTreeCursor::new_index(
mv_cursor,
pager.clone(),
*root_page,
index.as_ref(),
collations,
num_columns,
);
cursors
@@ -2824,10 +2804,9 @@ pub fn op_idx_ge(
registers_to_ref_values(&state.registers[*start_reg..*start_reg + *num_regs]);
let tie_breaker = get_tie_breaker_from_idx_comp_op(insn);
let ord = compare_records_generic(
&idx_record, // The serialized record from the index
&values, // The record built from registers
&cursor.index_key_info.unwrap(), // Sort order flags
cursor.collations(), // Collation sequences
&idx_record, // The serialized record from the index
&values, // The record built from registers
cursor.index_info.as_ref().unwrap(), // Sort order flags
0,
tie_breaker,
)?;
@@ -2896,8 +2875,7 @@ pub fn op_idx_le(
let ord = compare_records_generic(
&idx_record,
&values,
&cursor.index_key_info.unwrap(),
cursor.collations(),
cursor.index_info.as_ref().unwrap(),
0,
tie_breaker,
)?;
@@ -2948,8 +2926,7 @@ pub fn op_idx_gt(
let ord = compare_records_generic(
&idx_record,
&values,
&cursor.index_key_info.unwrap(),
cursor.collations(),
cursor.index_info.as_ref().unwrap(),
0,
tie_breaker,
)?;
@@ -3001,8 +2978,7 @@ pub fn op_idx_lt(
let ord = compare_records_generic(
&idx_record,
&values,
&cursor.index_key_info.unwrap(),
cursor.collations(),
cursor.index_info.as_ref().unwrap(),
0,
tie_breaker,
)?;
@@ -5152,8 +5128,7 @@ pub fn op_idx_insert(
let conflict = compare_immutable(
existing_key.as_slice(),
inserted_key_vals,
cursor.key_sort_order(),
&cursor.collations,
&cursor.index_info.as_ref().unwrap().key_info,
) == std::cmp::Ordering::Equal;
if conflict {
if flags.has(IdxInsertFlags::NO_OP_DUPLICATE) {
@@ -5586,27 +5561,11 @@ pub fn op_open_write(
.and_then(|table| table.btree());
let num_columns = index.columns.len();
let collations = table.map_or(Vec::new(), |table| {
index
.columns
.iter()
.map(|c| {
table
.columns
.get(c.pos_in_table)
.unwrap()
.collation
.unwrap_or_default()
})
.collect()
});
let cursor = BTreeCursor::new_index(
mv_cursor,
pager.clone(),
root_page as usize,
index.as_ref(),
collations,
num_columns,
);
cursors
@@ -5695,7 +5654,7 @@ pub fn op_destroy(
todo!("temp databases not implemented yet.");
}
// TODO not sure if should be BTreeCursor::new_table or BTreeCursor::new_index here or neither and just pass an emtpy vec
let mut cursor = BTreeCursor::new(None, pager.clone(), *root, Vec::new(), 0);
let mut cursor = BTreeCursor::new(None, pager.clone(), *root, 0);
let former_root_page_result = cursor.btree_destroy()?;
if let IOResult::Done(former_root_page) = former_root_page_result {
state.registers[*former_root_reg] =
@@ -6151,11 +6110,6 @@ pub fn op_open_ephemeral(
pager.clone(),
root_page as usize,
index,
index
.columns
.iter()
.map(|c| c.collation.unwrap_or_default())
.collect(),
num_columns,
)
} else {

View File

@@ -2,25 +2,31 @@ use turso_sqlite3_parser::ast::SortOrder;
use crate::{
translate::collate::CollationSeq,
types::{compare_immutable, ImmutableRecord, IndexKeySortOrder},
types::{compare_immutable, ImmutableRecord, KeyInfo},
};
pub struct Sorter {
records: Vec<ImmutableRecord>,
current: Option<ImmutableRecord>,
order: IndexKeySortOrder,
key_len: usize,
collations: Vec<CollationSeq>,
index_key_info: Vec<KeyInfo>,
}
impl Sorter {
pub fn new(order: &[SortOrder], collations: Vec<CollationSeq>) -> Self {
assert_eq!(order.len(), collations.len());
Self {
records: Vec::new(),
current: None,
key_len: order.len(),
order: IndexKeySortOrder::from_list(order),
collations,
index_key_info: order
.iter()
.zip(collations)
.map(|(order, collation)| KeyInfo {
sort_order: *order,
collation,
})
.collect(),
}
}
pub fn is_empty(&self) -> bool {
@@ -49,7 +55,7 @@ impl Sorter {
&b_values[..]
};
compare_immutable(a_key, b_key, self.order, &self.collations)
compare_immutable(a_key, b_key, &self.index_key_info)
});
self.records.reverse();
self.next()