mirror of
https://github.com/aljazceru/turso.git
synced 2026-01-31 05:44:25 +01:00
Merge 'Toy index improvements' from Nikita Sivukhin
This PR implements more sophisticated algorithm in the toy vector sparse index: now we enumerate components based on the frequency (in order to check unpopular "features" first) and also estimate length threshold which can give us better results compared with current top-k set. Also, this PR adds optional `delta` parameter which can enable approximate search which will return results with score not more than `delta` away from the optimal. In order to implement this index method - index code were slightly adjusted in order to allow to store some non-key payload in the index rows. So, now index can hold N columns where first K <= N columns will be used as identity (before that K always was equal to N). Reviewed-by: Jussi Saurio <jussi.saurio@gmail.com> Closes #3862
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -18,7 +18,9 @@ pub fn translate_integrity_check(
|
||||
root_pages.push(table.root_page);
|
||||
if let Some(indexes) = schema.indexes.get(table.name.as_str()) {
|
||||
for index in indexes.iter() {
|
||||
root_pages.push(index.root_page);
|
||||
if index.root_page > 0 {
|
||||
root_pages.push(index.root_page);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -16,7 +16,7 @@ use crate::translate::plan::IterationDirection;
|
||||
use crate::vdbe::sorter::Sorter;
|
||||
use crate::vdbe::Register;
|
||||
use crate::vtab::VirtualTableCursor;
|
||||
use crate::{turso_assert, Completion, CompletionError, Result, IO};
|
||||
use crate::{Completion, CompletionError, Result, IO};
|
||||
use std::fmt::{Debug, Display};
|
||||
use std::task::Waker;
|
||||
|
||||
@@ -1594,9 +1594,21 @@ pub fn compare_immutable(
|
||||
r: &[ValueRef],
|
||||
column_info: &[KeyInfo],
|
||||
) -> std::cmp::Ordering {
|
||||
assert_eq!(l.len(), r.len());
|
||||
turso_assert!(column_info.len() >= l.len(), "column_info.len() < l.len()");
|
||||
for (i, (l, r)) in l.iter().zip(r).enumerate() {
|
||||
assert!(
|
||||
l.len() >= column_info.len(),
|
||||
"{} < {}",
|
||||
l.len(),
|
||||
column_info.len()
|
||||
);
|
||||
assert!(
|
||||
r.len() >= column_info.len(),
|
||||
"{} < {}",
|
||||
r.len(),
|
||||
column_info.len()
|
||||
);
|
||||
let l_values = l.iter().take(column_info.len());
|
||||
let r_values = r.iter().take(column_info.len());
|
||||
for (i, (l, r)) in l_values.zip(r_values).enumerate() {
|
||||
let column_order = column_info[i].sort_order;
|
||||
let collation = column_info[i].collation;
|
||||
let cmp = match (l, r) {
|
||||
@@ -1720,10 +1732,6 @@ fn compare_records_int(
|
||||
index_info: &IndexInfo,
|
||||
tie_breaker: std::cmp::Ordering,
|
||||
) -> Result<std::cmp::Ordering> {
|
||||
turso_assert!(
|
||||
index_info.key_info.len() >= unpacked.len(),
|
||||
"index_info.key_info.len() < unpacked.len()"
|
||||
);
|
||||
let payload = serialized.get_payload();
|
||||
if payload.len() < 2 {
|
||||
return compare_records_generic(serialized, unpacked, index_info, 0, tie_breaker);
|
||||
@@ -1813,10 +1821,6 @@ fn compare_records_string(
|
||||
index_info: &IndexInfo,
|
||||
tie_breaker: std::cmp::Ordering,
|
||||
) -> Result<std::cmp::Ordering> {
|
||||
turso_assert!(
|
||||
index_info.key_info.len() >= unpacked.len(),
|
||||
"index_info.key_info.len() < unpacked.len()"
|
||||
);
|
||||
let payload = serialized.get_payload();
|
||||
if payload.len() < 2 {
|
||||
return compare_records_generic(serialized, unpacked, index_info, 0, tie_breaker);
|
||||
@@ -1926,10 +1930,6 @@ pub fn compare_records_generic(
|
||||
skip: usize,
|
||||
tie_breaker: std::cmp::Ordering,
|
||||
) -> Result<std::cmp::Ordering> {
|
||||
turso_assert!(
|
||||
index_info.key_info.len() >= unpacked.len(),
|
||||
"index_info.key_info.len() < unpacked.len()"
|
||||
);
|
||||
let payload = serialized.get_payload();
|
||||
if payload.is_empty() {
|
||||
return Ok(std::cmp::Ordering::Less);
|
||||
@@ -1960,7 +1960,8 @@ pub fn compare_records_generic(
|
||||
}
|
||||
|
||||
let mut field_idx = skip;
|
||||
while field_idx < unpacked.len() && header_pos < header_end {
|
||||
let field_limit = unpacked.len().min(index_info.key_info.len());
|
||||
while field_idx < field_limit && header_pos < header_end {
|
||||
let (serial_type_raw, bytes_read) = read_varint(&payload[header_pos..])?;
|
||||
header_pos += bytes_read;
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ pub fn vector_slice(vector: &Vector, start: usize, end: usize) -> Result<Vector<
|
||||
continue;
|
||||
}
|
||||
values.extend_from_slice(&value.to_le_bytes());
|
||||
idx.extend_from_slice(&i.to_le_bytes());
|
||||
idx.extend_from_slice(&((i - start) as u32).to_le_bytes());
|
||||
}
|
||||
values.extend_from_slice(&idx);
|
||||
Ok(Vector {
|
||||
|
||||
@@ -75,7 +75,10 @@ fn test_vector_sparse_ivf_create_destroy() {
|
||||
run(&tmp_db, || cursor.create(&conn)).unwrap();
|
||||
}
|
||||
conn.wal_insert_end(true).unwrap();
|
||||
assert_eq!(schema_rows(), vec!["t", "t_idx_scratch"]);
|
||||
assert_eq!(
|
||||
schema_rows(),
|
||||
vec!["t", "t_idx_inverted_index", "t_idx_stats"]
|
||||
);
|
||||
|
||||
conn.wal_insert_begin().unwrap();
|
||||
{
|
||||
@@ -264,8 +267,8 @@ fn test_vector_sparse_ivf_fuzz() {
|
||||
const MOD: u32 = 5;
|
||||
|
||||
let (mut rng, _) = rng_from_time_or_env();
|
||||
let mut keys = Vec::new();
|
||||
for _ in 0..10 {
|
||||
let mut operation = 0;
|
||||
for delta in [0.0, 0.01, 0.05, 0.1, 0.5] {
|
||||
let seed = rng.next_u64();
|
||||
tracing::info!("======== seed: {} ========", seed);
|
||||
|
||||
@@ -274,11 +277,19 @@ fn test_vector_sparse_ivf_fuzz() {
|
||||
TempDatabase::new_with_rusqlite("CREATE TABLE t(key TEXT PRIMARY KEY, embedding)");
|
||||
let index_db =
|
||||
TempDatabase::new_with_rusqlite("CREATE TABLE t(key TEXT PRIMARY KEY, embedding)");
|
||||
tracing::info!(
|
||||
"simple_db: {:?}, index_db: {:?}",
|
||||
simple_db.path,
|
||||
index_db.path,
|
||||
);
|
||||
let simple_conn = simple_db.connect_limbo();
|
||||
let index_conn = index_db.connect_limbo();
|
||||
simple_conn.wal_auto_checkpoint_disable();
|
||||
index_conn.wal_auto_checkpoint_disable();
|
||||
index_conn
|
||||
.execute("CREATE INDEX t_idx ON t USING toy_vector_sparse_ivf (embedding)")
|
||||
.execute(format!("CREATE INDEX t_idx ON t USING toy_vector_sparse_ivf (embedding) WITH (delta = {delta})"))
|
||||
.unwrap();
|
||||
|
||||
let vector = |rng: &mut ChaCha8Rng| {
|
||||
let mut values = Vec::with_capacity(DIMS);
|
||||
for _ in 0..DIMS {
|
||||
@@ -291,13 +302,15 @@ fn test_vector_sparse_ivf_fuzz() {
|
||||
format!("[{}]", values.join(", "))
|
||||
};
|
||||
|
||||
let mut keys = Vec::new();
|
||||
for _ in 0..200 {
|
||||
let choice = rng.next_u32() % 4;
|
||||
operation += 1;
|
||||
if choice == 0 {
|
||||
let key = rng.next_u64().to_string();
|
||||
let v = vector(&mut rng);
|
||||
let sql = format!("INSERT INTO t VALUES ('{key}', vector32_sparse('{v}'))");
|
||||
tracing::info!("{}", sql);
|
||||
tracing::info!("({}) {}", operation, sql);
|
||||
simple_conn.execute(&sql).unwrap();
|
||||
index_conn.execute(sql).unwrap();
|
||||
keys.push(key);
|
||||
@@ -307,14 +320,14 @@ fn test_vector_sparse_ivf_fuzz() {
|
||||
let v = vector(&mut rng);
|
||||
let sql =
|
||||
format!("UPDATE t SET embedding = vector32_sparse('{v}') WHERE key = '{key}'",);
|
||||
tracing::info!("{}", sql);
|
||||
tracing::info!("({}) {}", operation, sql);
|
||||
simple_conn.execute(&sql).unwrap();
|
||||
index_conn.execute(&sql).unwrap();
|
||||
} else if choice == 2 && !keys.is_empty() {
|
||||
let idx = rng.next_u32() as usize % keys.len();
|
||||
let key = &keys[idx];
|
||||
let sql = format!("DELETE FROM t WHERE key = '{key}'");
|
||||
tracing::info!("{}", sql);
|
||||
tracing::info!("({}) {}", operation, sql);
|
||||
simple_conn.execute(&sql).unwrap();
|
||||
index_conn.execute(&sql).unwrap();
|
||||
keys.remove(idx);
|
||||
@@ -322,20 +335,42 @@ fn test_vector_sparse_ivf_fuzz() {
|
||||
let v = vector(&mut rng);
|
||||
let k = rng.next_u32() % 20 + 1;
|
||||
let sql = format!("SELECT key, vector_distance_jaccard(embedding, vector32_sparse('{v}')) as d FROM t ORDER BY d LIMIT {k}");
|
||||
tracing::info!("{}", sql);
|
||||
tracing::info!("({}) {}", operation, sql);
|
||||
let simple_rows = limbo_exec_rows(&simple_db, &simple_conn, &sql);
|
||||
let index_rows = limbo_exec_rows(&index_db, &index_conn, &sql);
|
||||
tracing::info!("simple: {:?}, index_rows: {:?}", simple_rows, index_rows);
|
||||
assert!(index_rows.len() <= simple_rows.len());
|
||||
for (a, b) in index_rows.iter().zip(simple_rows.iter()) {
|
||||
assert_eq!(a, b);
|
||||
if delta == 0.0 {
|
||||
assert_eq!(a, b);
|
||||
} else {
|
||||
match (&a[1], &b[1]) {
|
||||
(rusqlite::types::Value::Real(a), rusqlite::types::Value::Real(b)) => {
|
||||
assert!(
|
||||
*a >= *b || (*a - *b).abs() < 1e-5,
|
||||
"a={}, b={}, delta={}",
|
||||
*a,
|
||||
*b,
|
||||
delta
|
||||
);
|
||||
assert!(
|
||||
*a - delta <= *b || (*a - delta - *b).abs() < 1e-5,
|
||||
"a={}, b={}, delta={}",
|
||||
*a,
|
||||
*b,
|
||||
delta
|
||||
);
|
||||
}
|
||||
_ => panic!("unexpected column values"),
|
||||
}
|
||||
}
|
||||
}
|
||||
for row in simple_rows.iter().skip(index_rows.len()) {
|
||||
match row[1] {
|
||||
rusqlite::types::Value::Real(r) => assert_eq!(r, 1.0),
|
||||
rusqlite::types::Value::Real(r) => assert!((1.0 - r) < 1e-5),
|
||||
_ => panic!("unexpected simple row value"),
|
||||
}
|
||||
}
|
||||
tracing::info!("simple: {:?}, index_rows: {:?}", simple_rows, index_rows);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user