Merge 'Toy index improvements' from Nikita Sivukhin

This PR implements a more sophisticated algorithm in the toy vector sparse
index: components are now enumerated in order of frequency (so that unpopular
"features" are checked first), and a length threshold is estimated, which can
give better results than the current top-k set.
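
As a rough sketch of that enumeration order (the names `collect_candidates`
and `length_threshold` below are made up for illustration and are not the
actual index code), the idea is to scan the rarest components first and stop
once the candidate set is large enough:

    // Illustrative only: enumerate query components starting from the least
    // frequent one (shortest posting list) and stop once a length threshold
    // of candidate rows has been collected.
    fn collect_candidates(
        mut components: Vec<(u32, Vec<u64>)>, // (component id, posting list of row ids)
        length_threshold: usize,
    ) -> Vec<u64> {
        // unpopular "features" first: shortest posting lists first
        components.sort_by_key(|(_, postings)| postings.len());
        let mut candidates = Vec::new();
        for (_, postings) in components {
            candidates.extend(postings);
            if candidates.len() >= length_threshold {
                break;
            }
        }
        candidates.sort_unstable();
        candidates.dedup();
        candidates
    }

    fn main() {
        let components = vec![(0, vec![1, 2, 3, 4]), (7, vec![5]), (3, vec![2, 6])];
        // the rare components 7 and 3 are scanned before the popular component 0
        assert_eq!(collect_candidates(components, 3), vec![2, 5, 6]);
    }
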
This PR also adds an optional `delta` parameter which enables approximate
search: the results returned have scores no more than `delta` away from the
optimal ones.
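
The fuzz test in this diff asserts essentially that guarantee. Written as a
standalone check (the function name `within_delta` is invented here, it is
not part of the PR), the property looks like this:

    // For every rank i, the approximate result's distance `a` must not be
    // better than the exact distance `b` at the same rank, and must not be
    // worse (larger) than it by more than `delta` (with a small tolerance).
    fn within_delta(exact: &[f64], approximate: &[f64], delta: f64) -> bool {
        const EPS: f64 = 1e-5;
        approximate.iter().zip(exact.iter()).all(|(&a, &b)| {
            (a >= b || (a - b).abs() < EPS)
                && (a - delta <= b || (a - delta - b).abs() < EPS)
        })
    }

    fn main() {
        let exact = [0.10, 0.20, 0.35];
        let approx = [0.10, 0.24, 0.39];
        assert!(within_delta(&exact, &approx, 0.05));
        assert!(!within_delta(&exact, &approx, 0.01));
    }
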
To implement this index method, the index code was slightly adjusted to allow
storing some non-key payload in the index rows. An index can now hold N
columns where the first K <= N columns are used as the identity (previously K
was always equal to N).
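
A minimal sketch of what that means for row comparison (simplified integer
columns instead of the real ValueRef/KeyInfo types; this mirrors the
`compare_immutable` change in this diff, which truncates both sides to the
first `column_info.len()` values): ordering looks only at the K key columns
and ignores any trailing payload columns.

    use std::cmp::Ordering;

    // Simplified stand-in for comparing index rows that carry extra payload:
    // only the first `key_columns` values form the identity, the rest is ignored.
    fn compare_key_prefix(l: &[i64], r: &[i64], key_columns: usize) -> Ordering {
        assert!(l.len() >= key_columns && r.len() >= key_columns);
        l.iter().take(key_columns).cmp(r.iter().take(key_columns))
    }

    fn main() {
        // same 2-column key, different trailing payload: still considered equal
        assert_eq!(compare_key_prefix(&[1, 2, 99], &[1, 2, 7], 2), Ordering::Equal);
        assert_eq!(compare_key_prefix(&[1, 3, 0], &[1, 2, 0], 2), Ordering::Greater);
    }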

Reviewed-by: Jussi Saurio <jussi.saurio@gmail.com>

Closes #3862
Pekka Enberg, 2025-11-07 08:29:47 +02:00 (committed by GitHub)
5 changed files with 943 additions and 255 deletions

File diff suppressed because it is too large

View File

@@ -18,7 +18,9 @@ pub fn translate_integrity_check(
root_pages.push(table.root_page);
if let Some(indexes) = schema.indexes.get(table.name.as_str()) {
for index in indexes.iter() {
root_pages.push(index.root_page);
if index.root_page > 0 {
root_pages.push(index.root_page);
}
}
}
};

View File

@@ -16,7 +16,7 @@ use crate::translate::plan::IterationDirection;
use crate::vdbe::sorter::Sorter;
use crate::vdbe::Register;
use crate::vtab::VirtualTableCursor;
use crate::{turso_assert, Completion, CompletionError, Result, IO};
use crate::{Completion, CompletionError, Result, IO};
use std::fmt::{Debug, Display};
use std::task::Waker;
@@ -1594,9 +1594,21 @@ pub fn compare_immutable(
r: &[ValueRef],
column_info: &[KeyInfo],
) -> std::cmp::Ordering {
assert_eq!(l.len(), r.len());
turso_assert!(column_info.len() >= l.len(), "column_info.len() < l.len()");
for (i, (l, r)) in l.iter().zip(r).enumerate() {
assert!(
l.len() >= column_info.len(),
"{} < {}",
l.len(),
column_info.len()
);
assert!(
r.len() >= column_info.len(),
"{} < {}",
r.len(),
column_info.len()
);
let l_values = l.iter().take(column_info.len());
let r_values = r.iter().take(column_info.len());
for (i, (l, r)) in l_values.zip(r_values).enumerate() {
let column_order = column_info[i].sort_order;
let collation = column_info[i].collation;
let cmp = match (l, r) {
@@ -1720,10 +1732,6 @@ fn compare_records_int(
index_info: &IndexInfo,
tie_breaker: std::cmp::Ordering,
) -> Result<std::cmp::Ordering> {
turso_assert!(
index_info.key_info.len() >= unpacked.len(),
"index_info.key_info.len() < unpacked.len()"
);
let payload = serialized.get_payload();
if payload.len() < 2 {
return compare_records_generic(serialized, unpacked, index_info, 0, tie_breaker);
@@ -1813,10 +1821,6 @@ fn compare_records_string(
index_info: &IndexInfo,
tie_breaker: std::cmp::Ordering,
) -> Result<std::cmp::Ordering> {
turso_assert!(
index_info.key_info.len() >= unpacked.len(),
"index_info.key_info.len() < unpacked.len()"
);
let payload = serialized.get_payload();
if payload.len() < 2 {
return compare_records_generic(serialized, unpacked, index_info, 0, tie_breaker);
@@ -1926,10 +1930,6 @@ pub fn compare_records_generic(
skip: usize,
tie_breaker: std::cmp::Ordering,
) -> Result<std::cmp::Ordering> {
turso_assert!(
index_info.key_info.len() >= unpacked.len(),
"index_info.key_info.len() < unpacked.len()"
);
let payload = serialized.get_payload();
if payload.is_empty() {
return Ok(std::cmp::Ordering::Less);
@@ -1960,7 +1960,8 @@ pub fn compare_records_generic(
}
let mut field_idx = skip;
while field_idx < unpacked.len() && header_pos < header_end {
let field_limit = unpacked.len().min(index_info.key_info.len());
while field_idx < field_limit && header_pos < header_end {
let (serial_type_raw, bytes_read) = read_varint(&payload[header_pos..])?;
header_pos += bytes_read;

View File

@@ -37,7 +37,7 @@ pub fn vector_slice(vector: &Vector, start: usize, end: usize) -> Result<Vector<
continue;
}
values.extend_from_slice(&value.to_le_bytes());
idx.extend_from_slice(&i.to_le_bytes());
idx.extend_from_slice(&((i - start) as u32).to_le_bytes());
}
values.extend_from_slice(&idx);
Ok(Vector {

View File

@@ -75,7 +75,10 @@ fn test_vector_sparse_ivf_create_destroy() {
run(&tmp_db, || cursor.create(&conn)).unwrap();
}
conn.wal_insert_end(true).unwrap();
assert_eq!(schema_rows(), vec!["t", "t_idx_scratch"]);
assert_eq!(
schema_rows(),
vec!["t", "t_idx_inverted_index", "t_idx_stats"]
);
conn.wal_insert_begin().unwrap();
{
@@ -264,8 +267,8 @@ fn test_vector_sparse_ivf_fuzz() {
const MOD: u32 = 5;
let (mut rng, _) = rng_from_time_or_env();
let mut keys = Vec::new();
for _ in 0..10 {
let mut operation = 0;
for delta in [0.0, 0.01, 0.05, 0.1, 0.5] {
let seed = rng.next_u64();
tracing::info!("======== seed: {} ========", seed);
@@ -274,11 +277,19 @@ fn test_vector_sparse_ivf_fuzz() {
TempDatabase::new_with_rusqlite("CREATE TABLE t(key TEXT PRIMARY KEY, embedding)");
let index_db =
TempDatabase::new_with_rusqlite("CREATE TABLE t(key TEXT PRIMARY KEY, embedding)");
tracing::info!(
"simple_db: {:?}, index_db: {:?}",
simple_db.path,
index_db.path,
);
let simple_conn = simple_db.connect_limbo();
let index_conn = index_db.connect_limbo();
simple_conn.wal_auto_checkpoint_disable();
index_conn.wal_auto_checkpoint_disable();
index_conn
.execute("CREATE INDEX t_idx ON t USING toy_vector_sparse_ivf (embedding)")
.execute(format!("CREATE INDEX t_idx ON t USING toy_vector_sparse_ivf (embedding) WITH (delta = {delta})"))
.unwrap();
let vector = |rng: &mut ChaCha8Rng| {
let mut values = Vec::with_capacity(DIMS);
for _ in 0..DIMS {
@@ -291,13 +302,15 @@ fn test_vector_sparse_ivf_fuzz() {
format!("[{}]", values.join(", "))
};
let mut keys = Vec::new();
for _ in 0..200 {
let choice = rng.next_u32() % 4;
operation += 1;
if choice == 0 {
let key = rng.next_u64().to_string();
let v = vector(&mut rng);
let sql = format!("INSERT INTO t VALUES ('{key}', vector32_sparse('{v}'))");
tracing::info!("{}", sql);
tracing::info!("({}) {}", operation, sql);
simple_conn.execute(&sql).unwrap();
index_conn.execute(sql).unwrap();
keys.push(key);
@@ -307,14 +320,14 @@ fn test_vector_sparse_ivf_fuzz() {
let v = vector(&mut rng);
let sql =
format!("UPDATE t SET embedding = vector32_sparse('{v}') WHERE key = '{key}'",);
tracing::info!("{}", sql);
tracing::info!("({}) {}", operation, sql);
simple_conn.execute(&sql).unwrap();
index_conn.execute(&sql).unwrap();
} else if choice == 2 && !keys.is_empty() {
let idx = rng.next_u32() as usize % keys.len();
let key = &keys[idx];
let sql = format!("DELETE FROM t WHERE key = '{key}'");
tracing::info!("{}", sql);
tracing::info!("({}) {}", operation, sql);
simple_conn.execute(&sql).unwrap();
index_conn.execute(&sql).unwrap();
keys.remove(idx);
@@ -322,20 +335,42 @@ fn test_vector_sparse_ivf_fuzz() {
let v = vector(&mut rng);
let k = rng.next_u32() % 20 + 1;
let sql = format!("SELECT key, vector_distance_jaccard(embedding, vector32_sparse('{v}')) as d FROM t ORDER BY d LIMIT {k}");
tracing::info!("{}", sql);
tracing::info!("({}) {}", operation, sql);
let simple_rows = limbo_exec_rows(&simple_db, &simple_conn, &sql);
let index_rows = limbo_exec_rows(&index_db, &index_conn, &sql);
tracing::info!("simple: {:?}, index_rows: {:?}", simple_rows, index_rows);
assert!(index_rows.len() <= simple_rows.len());
for (a, b) in index_rows.iter().zip(simple_rows.iter()) {
assert_eq!(a, b);
if delta == 0.0 {
assert_eq!(a, b);
} else {
match (&a[1], &b[1]) {
(rusqlite::types::Value::Real(a), rusqlite::types::Value::Real(b)) => {
assert!(
*a >= *b || (*a - *b).abs() < 1e-5,
"a={}, b={}, delta={}",
*a,
*b,
delta
);
assert!(
*a - delta <= *b || (*a - delta - *b).abs() < 1e-5,
"a={}, b={}, delta={}",
*a,
*b,
delta
);
}
_ => panic!("unexpected column values"),
}
}
}
for row in simple_rows.iter().skip(index_rows.len()) {
match row[1] {
rusqlite::types::Value::Real(r) => assert_eq!(r, 1.0),
rusqlite::types::Value::Real(r) => assert!((1.0 - r) < 1e-5),
_ => panic!("unexpected simple row value"),
}
}
tracing::info!("simple: {:?}, index_rows: {:?}", simple_rows, index_rows);
}
}
}