mirror of
https://github.com/aljazceru/turso.git
synced 2026-01-03 00:14:21 +01:00
Merge 'Initial pass on vector extension' from Pekka Enberg
This pull requests adds some libSQL vector extension functions such as `vector()` and `vector_distance_cos()`, which can be used for exact nearest neighbor search as follows: ``` limbo> SELECT embedding, vector_distance_cos(embedding, '[9, 9, 9]') ...> FROM movies ORDER BY vector_distance_cos(embedding, '[9, 9, 9]'); [4, 5, 6]|0.013072490692138672 [1, 2, 3]|0.07417994737625122 ``` Note that libSQL also support approximate nearest neighbour search with DiskANN indexing, which is something we eventually want to port to Limbo as well. Closes #798
This commit is contained in:
12
COMPAT.md
12
COMPAT.md
@@ -616,3 +616,15 @@ The `regexp` extension is compatible with [sqlean-regexp](https://github.com/nal
|
||||
| regexp_substr(source, pattern) | Yes | |
|
||||
| regexp_capture(source, pattern[, n]) | No | |
|
||||
| regexp_replace(source, pattern, replacement) | No | |
|
||||
|
||||
### Vector
|
||||
|
||||
The `vector` extension is compatible with libSQL native vector search.
|
||||
|
||||
| Function | Status | Comment |
|
||||
|------------------------------------------------|--------|---------|
|
||||
| vector(x) | Yes | |
|
||||
| vector32(x) | Yes | |
|
||||
| vector64(x) | Yes | |
|
||||
| vector_extract(x) | Yes | |
|
||||
| vector_distance_cos(x, y) | Yes | |
|
||||
|
||||
8
Cargo.lock
generated
8
Cargo.lock
generated
@@ -1342,6 +1342,7 @@ dependencies = [
|
||||
"limbo_percentile",
|
||||
"limbo_regexp",
|
||||
"limbo_uuid",
|
||||
"limbo_vector",
|
||||
"log",
|
||||
"miette",
|
||||
"mimalloc",
|
||||
@@ -1441,6 +1442,13 @@ dependencies = [
|
||||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "limbo_vector"
|
||||
version = "0.0.13"
|
||||
dependencies = [
|
||||
"limbo_ext",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linux-raw-sys"
|
||||
version = "0.4.15"
|
||||
|
||||
@@ -18,6 +18,7 @@ members = [
|
||||
"sqlite3",
|
||||
"tests",
|
||||
"extensions/percentile",
|
||||
"extensions/vector",
|
||||
]
|
||||
exclude = ["perf/latency/limbo"]
|
||||
|
||||
|
||||
6
Makefile
6
Makefile
@@ -62,7 +62,7 @@ limbo-wasm:
|
||||
cargo build --package limbo-wasm --target wasm32-wasi
|
||||
.PHONY: limbo-wasm
|
||||
|
||||
test: limbo test-compat test-sqlite3 test-shell test-extensions
|
||||
test: limbo test-compat test-vector test-sqlite3 test-shell test-extensions
|
||||
.PHONY: test
|
||||
|
||||
test-extensions: limbo
|
||||
@@ -78,6 +78,10 @@ test-compat:
|
||||
SQLITE_EXEC=$(SQLITE_EXEC) ./testing/all.test
|
||||
.PHONY: test-compat
|
||||
|
||||
test-vector:
|
||||
SQLITE_EXEC=$(SQLITE_EXEC) ./testing/vector.test
|
||||
.PHONY: test-vector
|
||||
|
||||
test-sqlite3: limbo-c
|
||||
LIBS="$(SQLITE_LIB)" HEADERS="$(SQLITE_LIB_HEADERS)" make -C sqlite3/tests test
|
||||
.PHONY: test-sqlite3
|
||||
|
||||
@@ -14,7 +14,7 @@ name = "limbo_core"
|
||||
path = "lib.rs"
|
||||
|
||||
[features]
|
||||
default = ["fs", "json", "uuid", "io_uring"]
|
||||
default = ["fs", "json", "uuid", "vector", "io_uring"]
|
||||
fs = []
|
||||
json = [
|
||||
"dep:jsonb",
|
||||
@@ -22,6 +22,7 @@ json = [
|
||||
"dep:pest_derive",
|
||||
]
|
||||
uuid = ["limbo_uuid/static"]
|
||||
vector = ["limbo_vector/static"]
|
||||
io_uring = ["dep:io-uring", "rustix/io_uring"]
|
||||
percentile = ["limbo_percentile/static"]
|
||||
regexp = ["limbo_regexp/static"]
|
||||
@@ -61,6 +62,7 @@ rand = "0.8.5"
|
||||
bumpalo = { version = "3.16.0", features = ["collections", "boxed"] }
|
||||
limbo_macros = { path = "../macros" }
|
||||
limbo_uuid = { path = "../extensions/uuid", optional = true, features = ["static"] }
|
||||
limbo_vector = { path = "../extensions/vector", optional = true, features = ["static"] }
|
||||
limbo_regexp = { path = "../extensions/regexp", optional = true, features = ["static"] }
|
||||
limbo_percentile = { path = "../extensions/percentile", optional = true, features = ["static"] }
|
||||
miette = "7.4.0"
|
||||
|
||||
@@ -80,6 +80,10 @@ impl Database {
|
||||
if unsafe { !limbo_uuid::register_extension_static(&ext_api).is_ok() } {
|
||||
return Err("Failed to register uuid extension".to_string());
|
||||
}
|
||||
#[cfg(feature = "vector")]
|
||||
if unsafe { !limbo_vector::register_extension_static(&ext_api).is_ok() } {
|
||||
return Err("Failed to register vector extension".to_string());
|
||||
}
|
||||
#[cfg(feature = "percentile")]
|
||||
if unsafe { !limbo_percentile::register_extension_static(&ext_api).is_ok() } {
|
||||
return Err("Failed to register percentile extension".to_string());
|
||||
|
||||
16
extensions/vector/Cargo.toml
Normal file
16
extensions/vector/Cargo.toml
Normal file
@@ -0,0 +1,16 @@
|
||||
[package]
|
||||
name = "limbo_vector"
|
||||
version.workspace = true
|
||||
authors.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
repository.workspace = true
|
||||
|
||||
[lib]
|
||||
crate-type = ["cdylib", "lib"]
|
||||
|
||||
[features]
|
||||
static= [ "limbo_ext/static" ]
|
||||
|
||||
[dependencies]
|
||||
limbo_ext = { path = "../core", features = ["static"] }
|
||||
77
extensions/vector/src/lib.rs
Normal file
77
extensions/vector/src/lib.rs
Normal file
@@ -0,0 +1,77 @@
|
||||
use limbo_ext::{register_extension, scalar, ResultCode, Value};
|
||||
|
||||
mod vector;
|
||||
|
||||
use vector::*;
|
||||
|
||||
#[derive(Debug)]
|
||||
enum Error {
|
||||
InvalidType,
|
||||
InvalidFormat,
|
||||
InvalidDimensions,
|
||||
}
|
||||
|
||||
type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
#[scalar(name = "vector32", alias = "vector")]
|
||||
fn vector32(args: &[Value]) -> Value {
|
||||
if args.len() != 1 {
|
||||
return Value::error(ResultCode::Error);
|
||||
}
|
||||
let Ok(x) = parse_vector(&args[0], Some(VectorType::Float32)) else {
|
||||
return Value::error(ResultCode::Error);
|
||||
};
|
||||
vector_serialize_f32(x)
|
||||
}
|
||||
|
||||
#[scalar(name = "vector64")]
|
||||
fn vector64(args: &[Value]) -> Value {
|
||||
if args.len() != 1 {
|
||||
return Value::error(ResultCode::Error);
|
||||
}
|
||||
let Ok(x) = parse_vector(&args[0], Some(VectorType::Float64)) else {
|
||||
return Value::error(ResultCode::Error);
|
||||
};
|
||||
vector_serialize_f64(x)
|
||||
}
|
||||
|
||||
#[scalar(name = "vector_extract")]
|
||||
fn vector_extract(args: &[Value]) -> Value {
|
||||
if args.len() != 1 {
|
||||
return Value::error(ResultCode::Error);
|
||||
}
|
||||
let Some(blob) = args[0].to_blob() else {
|
||||
return Value::error(ResultCode::Error);
|
||||
};
|
||||
if blob.is_empty() {
|
||||
return Value::from_text("[]".to_string());
|
||||
}
|
||||
let Ok(vector_type) = vector_type(&blob) else {
|
||||
return Value::error(ResultCode::Error);
|
||||
};
|
||||
let Ok(vector) = vector_deserialize(vector_type, &blob) else {
|
||||
return Value::error(ResultCode::Error);
|
||||
};
|
||||
Value::from_text(vector_to_text(&vector))
|
||||
}
|
||||
|
||||
#[scalar(name = "vector_distance_cos")]
|
||||
fn vector_distance_cos(args: &[Value]) -> Value {
|
||||
if args.len() != 2 {
|
||||
return Value::error(ResultCode::Error);
|
||||
}
|
||||
let Ok(x) = parse_vector(&args[0], None) else {
|
||||
return Value::error(ResultCode::Error);
|
||||
};
|
||||
let Ok(y) = parse_vector(&args[1], None) else {
|
||||
return Value::error(ResultCode::Error);
|
||||
};
|
||||
let Ok(dist) = do_vector_distance_cos(&x, &y) else {
|
||||
return Value::error(ResultCode::Error);
|
||||
};
|
||||
Value::from_float(dist)
|
||||
}
|
||||
|
||||
register_extension! {
|
||||
scalars: { vector32, vector64, vector_extract, vector_distance_cos },
|
||||
}
|
||||
269
extensions/vector/src/vector.rs
Normal file
269
extensions/vector/src/vector.rs
Normal file
@@ -0,0 +1,269 @@
|
||||
use limbo_ext::{Value, ValueType};
|
||||
|
||||
use crate::{Error, Result};
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub enum VectorType {
|
||||
Float32,
|
||||
Float64,
|
||||
}
|
||||
|
||||
impl VectorType {
|
||||
pub fn size_to_dims(&self, size: usize) -> usize {
|
||||
match self {
|
||||
VectorType::Float32 => size / 4,
|
||||
VectorType::Float64 => size / 8,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Vector {
|
||||
pub vector_type: VectorType,
|
||||
pub dims: usize,
|
||||
pub data: Vec<u8>,
|
||||
}
|
||||
|
||||
impl Vector {
|
||||
pub fn as_f32_slice(&self) -> &[f32] {
|
||||
unsafe { std::slice::from_raw_parts(self.data.as_ptr() as *const f32, self.dims) }
|
||||
}
|
||||
|
||||
pub fn as_f64_slice(&self) -> &[f64] {
|
||||
unsafe { std::slice::from_raw_parts(self.data.as_ptr() as *const f64, self.dims) }
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a vector in text representation into a Vector.
|
||||
///
|
||||
/// The format of a vector in text representation looks as follows:
|
||||
///
|
||||
/// ```console
|
||||
/// [1.0, 2.0, 3.0]
|
||||
/// ```
|
||||
pub fn parse_string_vector(vector_type: VectorType, value: &Value) -> Result<Vector> {
|
||||
let Some(text) = value.to_text() else {
|
||||
return Err(Error::InvalidFormat);
|
||||
};
|
||||
let text = text.trim();
|
||||
let mut chars = text.chars();
|
||||
if chars.next() != Some('[') || chars.last() != Some(']') {
|
||||
return Err(Error::InvalidFormat);
|
||||
}
|
||||
let mut data: Vec<u8> = Vec::new();
|
||||
let text = &text[1..text.len() - 1];
|
||||
if text.trim().is_empty() {
|
||||
return Ok(Vector {
|
||||
vector_type,
|
||||
dims: 0,
|
||||
data,
|
||||
});
|
||||
}
|
||||
let xs = text.split(',');
|
||||
for x in xs {
|
||||
let x = x.trim();
|
||||
if x.is_empty() {
|
||||
return Err(Error::InvalidFormat);
|
||||
}
|
||||
match vector_type {
|
||||
VectorType::Float32 => {
|
||||
let x = x.parse::<f32>().map_err(|_| Error::InvalidFormat)?;
|
||||
data.extend_from_slice(&x.to_le_bytes());
|
||||
}
|
||||
VectorType::Float64 => {
|
||||
let x = x.parse::<f64>().map_err(|_| Error::InvalidFormat)?;
|
||||
data.extend_from_slice(&x.to_le_bytes());
|
||||
}
|
||||
};
|
||||
}
|
||||
let dims = vector_type.size_to_dims(data.len());
|
||||
Ok(Vector {
|
||||
vector_type,
|
||||
dims,
|
||||
data,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn parse_vector(value: &Value, vec_ty: Option<VectorType>) -> Result<Vector> {
|
||||
match value.value_type() {
|
||||
ValueType::Text => parse_string_vector(vec_ty.unwrap_or(VectorType::Float32), value),
|
||||
ValueType::Blob => {
|
||||
let Some(blob) = value.to_blob() else {
|
||||
return Err(Error::InvalidFormat);
|
||||
};
|
||||
let vector_type = vector_type(&blob)?;
|
||||
if let Some(vec_ty) = vec_ty {
|
||||
if vec_ty != vector_type {
|
||||
return Err(Error::InvalidType);
|
||||
}
|
||||
}
|
||||
vector_deserialize(vector_type, &blob)
|
||||
}
|
||||
_ => Err(Error::InvalidType),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vector_to_text(vector: &Vector) -> String {
|
||||
let mut text = String::new();
|
||||
text.push('[');
|
||||
match vector.vector_type {
|
||||
VectorType::Float32 => {
|
||||
let data = vector.as_f32_slice();
|
||||
for i in 0..vector.dims {
|
||||
text.push_str(&data[i].to_string());
|
||||
if i < vector.dims - 1 {
|
||||
text.push(',');
|
||||
}
|
||||
}
|
||||
}
|
||||
VectorType::Float64 => {
|
||||
let data = vector.as_f64_slice();
|
||||
for i in 0..vector.dims {
|
||||
text.push_str(&data[i].to_string());
|
||||
if i < vector.dims - 1 {
|
||||
text.push(',');
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
text.push(']');
|
||||
text
|
||||
}
|
||||
|
||||
pub fn vector_deserialize(vector_type: VectorType, blob: &[u8]) -> Result<Vector> {
|
||||
match vector_type {
|
||||
VectorType::Float32 => vector_deserialize_f32(blob),
|
||||
VectorType::Float64 => vector_deserialize_f64(blob),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vector_serialize_f64(x: Vector) -> Value {
|
||||
let mut blob = Vec::with_capacity(x.dims * 8 + 1);
|
||||
blob.extend_from_slice(&x.data);
|
||||
blob.push(2);
|
||||
Value::from_blob(blob)
|
||||
}
|
||||
|
||||
pub fn vector_deserialize_f64(blob: &[u8]) -> Result<Vector> {
|
||||
Ok(Vector {
|
||||
vector_type: VectorType::Float64,
|
||||
dims: (blob.len() - 1) / 8,
|
||||
data: blob[..blob.len() - 1].to_vec(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn vector_serialize_f32(x: Vector) -> Value {
|
||||
Value::from_blob(x.data)
|
||||
}
|
||||
|
||||
pub fn vector_deserialize_f32(blob: &[u8]) -> Result<Vector> {
|
||||
Ok(Vector {
|
||||
vector_type: VectorType::Float32,
|
||||
dims: blob.len() / 4,
|
||||
data: blob.to_vec(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn do_vector_distance_cos(v1: &Vector, v2: &Vector) -> Result<f64> {
|
||||
match v1.vector_type {
|
||||
VectorType::Float32 => vector_f32_distance_cos(v1, v2),
|
||||
VectorType::Float64 => vector_f64_distance_cos(v1, v2),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vector_f32_distance_cos(v1: &Vector, v2: &Vector) -> Result<f64> {
|
||||
if v1.dims != v2.dims {
|
||||
return Err(Error::InvalidDimensions);
|
||||
}
|
||||
if v1.vector_type != v2.vector_type {
|
||||
return Err(Error::InvalidType);
|
||||
}
|
||||
let (mut dot, mut norm1, mut norm2) = (0.0, 0.0, 0.0);
|
||||
let v1_data = v1.as_f32_slice();
|
||||
let v2_data = v2.as_f32_slice();
|
||||
for i in 0..v1.dims {
|
||||
let e1 = v1_data[i];
|
||||
let e2 = v2_data[i];
|
||||
dot += e1 * e2;
|
||||
norm1 += e1 * e1;
|
||||
norm2 += e2 * e2;
|
||||
}
|
||||
Ok(1.0 - (dot / (norm1 * norm2).sqrt()) as f64)
|
||||
}
|
||||
|
||||
pub fn vector_f64_distance_cos(v1: &Vector, v2: &Vector) -> Result<f64> {
|
||||
if v1.dims != v2.dims {
|
||||
return Err(Error::InvalidDimensions);
|
||||
}
|
||||
if v1.vector_type != v2.vector_type {
|
||||
return Err(Error::InvalidType);
|
||||
}
|
||||
let (mut dot, mut norm1, mut norm2) = (0.0, 0.0, 0.0);
|
||||
let v1_data = v1.as_f64_slice();
|
||||
let v2_data = v2.as_f64_slice();
|
||||
for i in 0..v1.dims {
|
||||
let e1 = v1_data[i];
|
||||
let e2 = v2_data[i];
|
||||
dot += e1 * e2;
|
||||
norm1 += e1 * e1;
|
||||
norm2 += e2 * e2;
|
||||
}
|
||||
Ok(1.0 - (dot / (norm1 * norm2).sqrt()))
|
||||
}
|
||||
|
||||
pub fn vector_type(blob: &[u8]) -> Result<VectorType> {
|
||||
if blob.is_empty() {
|
||||
return Err(Error::InvalidFormat);
|
||||
}
|
||||
// Even-sized blobs are always float32.
|
||||
if blob.len() % 2 == 0 {
|
||||
return Ok(VectorType::Float32);
|
||||
}
|
||||
// Odd-sized blobs have type byte at the end
|
||||
let (data_blob, type_byte) = blob.split_at(blob.len() - 1);
|
||||
let vector_type = type_byte[0];
|
||||
match vector_type {
|
||||
1 => {
|
||||
if data_blob.len() % 4 != 0 {
|
||||
return Err(Error::InvalidFormat);
|
||||
}
|
||||
Ok(VectorType::Float32)
|
||||
}
|
||||
2 => {
|
||||
if data_blob.len() % 8 != 0 {
|
||||
return Err(Error::InvalidFormat);
|
||||
}
|
||||
Ok(VectorType::Float64)
|
||||
}
|
||||
_ => Err(Error::InvalidType),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_string_vector_zero_length() {
|
||||
let value = Value::from_text("[]".to_string());
|
||||
let vector = parse_string_vector(VectorType::Float32, &value).unwrap();
|
||||
assert_eq!(vector.dims, 0);
|
||||
assert_eq!(vector.vector_type, VectorType::Float32);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_string_vector_valid_whitespace() {
|
||||
let value = Value::from_text(" [ 1.0 , 2.0 , 3.0 ] ".to_string());
|
||||
let vector = parse_string_vector(VectorType::Float32, &value).unwrap();
|
||||
assert_eq!(vector.dims, 3);
|
||||
assert_eq!(vector.vector_type, VectorType::Float32);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_string_vector_valid() {
|
||||
let value = Value::from_text("[1.0, 2.0, 3.0]".to_string());
|
||||
let vector = parse_string_vector(VectorType::Float32, &value).unwrap();
|
||||
assert_eq!(vector.dims, 3);
|
||||
assert_eq!(vector.vector_type, VectorType::Float32);
|
||||
}
|
||||
}
|
||||
14
testing/vector.test
Executable file
14
testing/vector.test
Executable file
@@ -0,0 +1,14 @@
|
||||
#!/usr/bin/env tclsh
|
||||
|
||||
set testdir [file dirname $argv0]
|
||||
source $testdir/tester.tcl
|
||||
|
||||
do_execsql_test vector-functions-valid {
|
||||
SELECT vector_extract(vector('[]'));
|
||||
SELECT vector_extract(vector(' [ 1 , 2 , 3 ] '));
|
||||
SELECT vector_extract(vector('[-1000000000000000000]'));
|
||||
} {
|
||||
{[]}
|
||||
{[1,2,3]}
|
||||
{[-1000000000000000000]}
|
||||
}
|
||||
Reference in New Issue
Block a user