Merge 'Initial pass on vector extension' from Pekka Enberg

This pull requests adds some libSQL vector extension functions such as
`vector()` and `vector_distance_cos()`, which can be used for exact
nearest neighbor search as follows:
```
limbo> SELECT embedding, vector_distance_cos(embedding, '[9, 9, 9]')
   ...> FROM movies ORDER BY vector_distance_cos(embedding, '[9, 9, 9]');
[4, 5, 6]|0.013072490692138672
[1, 2, 3]|0.07417994737625122
```
Note that libSQL also support approximate nearest neighbour search with
DiskANN indexing, which is something we eventually want to port to Limbo
as well.

Closes #798
This commit is contained in:
Pekka Enberg
2025-01-29 09:43:04 +02:00
10 changed files with 409 additions and 2 deletions

View File

@@ -616,3 +616,15 @@ The `regexp` extension is compatible with [sqlean-regexp](https://github.com/nal
| regexp_substr(source, pattern) | Yes | |
| regexp_capture(source, pattern[, n]) | No | |
| regexp_replace(source, pattern, replacement) | No | |
### Vector
The `vector` extension is compatible with libSQL native vector search.
| Function | Status | Comment |
|------------------------------------------------|--------|---------|
| vector(x) | Yes | |
| vector32(x) | Yes | |
| vector64(x) | Yes | |
| vector_extract(x) | Yes | |
| vector_distance_cos(x, y) | Yes | |

8
Cargo.lock generated
View File

@@ -1342,6 +1342,7 @@ dependencies = [
"limbo_percentile",
"limbo_regexp",
"limbo_uuid",
"limbo_vector",
"log",
"miette",
"mimalloc",
@@ -1441,6 +1442,13 @@ dependencies = [
"uuid",
]
[[package]]
name = "limbo_vector"
version = "0.0.13"
dependencies = [
"limbo_ext",
]
[[package]]
name = "linux-raw-sys"
version = "0.4.15"

View File

@@ -18,6 +18,7 @@ members = [
"sqlite3",
"tests",
"extensions/percentile",
"extensions/vector",
]
exclude = ["perf/latency/limbo"]

View File

@@ -62,7 +62,7 @@ limbo-wasm:
cargo build --package limbo-wasm --target wasm32-wasi
.PHONY: limbo-wasm
test: limbo test-compat test-sqlite3 test-shell test-extensions
test: limbo test-compat test-vector test-sqlite3 test-shell test-extensions
.PHONY: test
test-extensions: limbo
@@ -78,6 +78,10 @@ test-compat:
SQLITE_EXEC=$(SQLITE_EXEC) ./testing/all.test
.PHONY: test-compat
test-vector:
SQLITE_EXEC=$(SQLITE_EXEC) ./testing/vector.test
.PHONY: test-vector
test-sqlite3: limbo-c
LIBS="$(SQLITE_LIB)" HEADERS="$(SQLITE_LIB_HEADERS)" make -C sqlite3/tests test
.PHONY: test-sqlite3

View File

@@ -14,7 +14,7 @@ name = "limbo_core"
path = "lib.rs"
[features]
default = ["fs", "json", "uuid", "io_uring"]
default = ["fs", "json", "uuid", "vector", "io_uring"]
fs = []
json = [
"dep:jsonb",
@@ -22,6 +22,7 @@ json = [
"dep:pest_derive",
]
uuid = ["limbo_uuid/static"]
vector = ["limbo_vector/static"]
io_uring = ["dep:io-uring", "rustix/io_uring"]
percentile = ["limbo_percentile/static"]
regexp = ["limbo_regexp/static"]
@@ -61,6 +62,7 @@ rand = "0.8.5"
bumpalo = { version = "3.16.0", features = ["collections", "boxed"] }
limbo_macros = { path = "../macros" }
limbo_uuid = { path = "../extensions/uuid", optional = true, features = ["static"] }
limbo_vector = { path = "../extensions/vector", optional = true, features = ["static"] }
limbo_regexp = { path = "../extensions/regexp", optional = true, features = ["static"] }
limbo_percentile = { path = "../extensions/percentile", optional = true, features = ["static"] }
miette = "7.4.0"

View File

@@ -80,6 +80,10 @@ impl Database {
if unsafe { !limbo_uuid::register_extension_static(&ext_api).is_ok() } {
return Err("Failed to register uuid extension".to_string());
}
#[cfg(feature = "vector")]
if unsafe { !limbo_vector::register_extension_static(&ext_api).is_ok() } {
return Err("Failed to register vector extension".to_string());
}
#[cfg(feature = "percentile")]
if unsafe { !limbo_percentile::register_extension_static(&ext_api).is_ok() } {
return Err("Failed to register percentile extension".to_string());

View File

@@ -0,0 +1,16 @@
[package]
name = "limbo_vector"
version.workspace = true
authors.workspace = true
edition.workspace = true
license.workspace = true
repository.workspace = true
[lib]
crate-type = ["cdylib", "lib"]
[features]
static= [ "limbo_ext/static" ]
[dependencies]
limbo_ext = { path = "../core", features = ["static"] }

View File

@@ -0,0 +1,77 @@
use limbo_ext::{register_extension, scalar, ResultCode, Value};
mod vector;
use vector::*;
#[derive(Debug)]
enum Error {
InvalidType,
InvalidFormat,
InvalidDimensions,
}
type Result<T> = std::result::Result<T, Error>;
#[scalar(name = "vector32", alias = "vector")]
fn vector32(args: &[Value]) -> Value {
if args.len() != 1 {
return Value::error(ResultCode::Error);
}
let Ok(x) = parse_vector(&args[0], Some(VectorType::Float32)) else {
return Value::error(ResultCode::Error);
};
vector_serialize_f32(x)
}
#[scalar(name = "vector64")]
fn vector64(args: &[Value]) -> Value {
if args.len() != 1 {
return Value::error(ResultCode::Error);
}
let Ok(x) = parse_vector(&args[0], Some(VectorType::Float64)) else {
return Value::error(ResultCode::Error);
};
vector_serialize_f64(x)
}
#[scalar(name = "vector_extract")]
fn vector_extract(args: &[Value]) -> Value {
if args.len() != 1 {
return Value::error(ResultCode::Error);
}
let Some(blob) = args[0].to_blob() else {
return Value::error(ResultCode::Error);
};
if blob.is_empty() {
return Value::from_text("[]".to_string());
}
let Ok(vector_type) = vector_type(&blob) else {
return Value::error(ResultCode::Error);
};
let Ok(vector) = vector_deserialize(vector_type, &blob) else {
return Value::error(ResultCode::Error);
};
Value::from_text(vector_to_text(&vector))
}
#[scalar(name = "vector_distance_cos")]
fn vector_distance_cos(args: &[Value]) -> Value {
if args.len() != 2 {
return Value::error(ResultCode::Error);
}
let Ok(x) = parse_vector(&args[0], None) else {
return Value::error(ResultCode::Error);
};
let Ok(y) = parse_vector(&args[1], None) else {
return Value::error(ResultCode::Error);
};
let Ok(dist) = do_vector_distance_cos(&x, &y) else {
return Value::error(ResultCode::Error);
};
Value::from_float(dist)
}
register_extension! {
scalars: { vector32, vector64, vector_extract, vector_distance_cos },
}

View File

@@ -0,0 +1,269 @@
use limbo_ext::{Value, ValueType};
use crate::{Error, Result};
#[derive(Debug, PartialEq)]
pub enum VectorType {
Float32,
Float64,
}
impl VectorType {
pub fn size_to_dims(&self, size: usize) -> usize {
match self {
VectorType::Float32 => size / 4,
VectorType::Float64 => size / 8,
}
}
}
#[derive(Debug)]
pub struct Vector {
pub vector_type: VectorType,
pub dims: usize,
pub data: Vec<u8>,
}
impl Vector {
pub fn as_f32_slice(&self) -> &[f32] {
unsafe { std::slice::from_raw_parts(self.data.as_ptr() as *const f32, self.dims) }
}
pub fn as_f64_slice(&self) -> &[f64] {
unsafe { std::slice::from_raw_parts(self.data.as_ptr() as *const f64, self.dims) }
}
}
/// Parse a vector in text representation into a Vector.
///
/// The format of a vector in text representation looks as follows:
///
/// ```console
/// [1.0, 2.0, 3.0]
/// ```
pub fn parse_string_vector(vector_type: VectorType, value: &Value) -> Result<Vector> {
let Some(text) = value.to_text() else {
return Err(Error::InvalidFormat);
};
let text = text.trim();
let mut chars = text.chars();
if chars.next() != Some('[') || chars.last() != Some(']') {
return Err(Error::InvalidFormat);
}
let mut data: Vec<u8> = Vec::new();
let text = &text[1..text.len() - 1];
if text.trim().is_empty() {
return Ok(Vector {
vector_type,
dims: 0,
data,
});
}
let xs = text.split(',');
for x in xs {
let x = x.trim();
if x.is_empty() {
return Err(Error::InvalidFormat);
}
match vector_type {
VectorType::Float32 => {
let x = x.parse::<f32>().map_err(|_| Error::InvalidFormat)?;
data.extend_from_slice(&x.to_le_bytes());
}
VectorType::Float64 => {
let x = x.parse::<f64>().map_err(|_| Error::InvalidFormat)?;
data.extend_from_slice(&x.to_le_bytes());
}
};
}
let dims = vector_type.size_to_dims(data.len());
Ok(Vector {
vector_type,
dims,
data,
})
}
pub fn parse_vector(value: &Value, vec_ty: Option<VectorType>) -> Result<Vector> {
match value.value_type() {
ValueType::Text => parse_string_vector(vec_ty.unwrap_or(VectorType::Float32), value),
ValueType::Blob => {
let Some(blob) = value.to_blob() else {
return Err(Error::InvalidFormat);
};
let vector_type = vector_type(&blob)?;
if let Some(vec_ty) = vec_ty {
if vec_ty != vector_type {
return Err(Error::InvalidType);
}
}
vector_deserialize(vector_type, &blob)
}
_ => Err(Error::InvalidType),
}
}
pub fn vector_to_text(vector: &Vector) -> String {
let mut text = String::new();
text.push('[');
match vector.vector_type {
VectorType::Float32 => {
let data = vector.as_f32_slice();
for i in 0..vector.dims {
text.push_str(&data[i].to_string());
if i < vector.dims - 1 {
text.push(',');
}
}
}
VectorType::Float64 => {
let data = vector.as_f64_slice();
for i in 0..vector.dims {
text.push_str(&data[i].to_string());
if i < vector.dims - 1 {
text.push(',');
}
}
}
}
text.push(']');
text
}
pub fn vector_deserialize(vector_type: VectorType, blob: &[u8]) -> Result<Vector> {
match vector_type {
VectorType::Float32 => vector_deserialize_f32(blob),
VectorType::Float64 => vector_deserialize_f64(blob),
}
}
pub fn vector_serialize_f64(x: Vector) -> Value {
let mut blob = Vec::with_capacity(x.dims * 8 + 1);
blob.extend_from_slice(&x.data);
blob.push(2);
Value::from_blob(blob)
}
pub fn vector_deserialize_f64(blob: &[u8]) -> Result<Vector> {
Ok(Vector {
vector_type: VectorType::Float64,
dims: (blob.len() - 1) / 8,
data: blob[..blob.len() - 1].to_vec(),
})
}
pub fn vector_serialize_f32(x: Vector) -> Value {
Value::from_blob(x.data)
}
pub fn vector_deserialize_f32(blob: &[u8]) -> Result<Vector> {
Ok(Vector {
vector_type: VectorType::Float32,
dims: blob.len() / 4,
data: blob.to_vec(),
})
}
pub fn do_vector_distance_cos(v1: &Vector, v2: &Vector) -> Result<f64> {
match v1.vector_type {
VectorType::Float32 => vector_f32_distance_cos(v1, v2),
VectorType::Float64 => vector_f64_distance_cos(v1, v2),
}
}
pub fn vector_f32_distance_cos(v1: &Vector, v2: &Vector) -> Result<f64> {
if v1.dims != v2.dims {
return Err(Error::InvalidDimensions);
}
if v1.vector_type != v2.vector_type {
return Err(Error::InvalidType);
}
let (mut dot, mut norm1, mut norm2) = (0.0, 0.0, 0.0);
let v1_data = v1.as_f32_slice();
let v2_data = v2.as_f32_slice();
for i in 0..v1.dims {
let e1 = v1_data[i];
let e2 = v2_data[i];
dot += e1 * e2;
norm1 += e1 * e1;
norm2 += e2 * e2;
}
Ok(1.0 - (dot / (norm1 * norm2).sqrt()) as f64)
}
pub fn vector_f64_distance_cos(v1: &Vector, v2: &Vector) -> Result<f64> {
if v1.dims != v2.dims {
return Err(Error::InvalidDimensions);
}
if v1.vector_type != v2.vector_type {
return Err(Error::InvalidType);
}
let (mut dot, mut norm1, mut norm2) = (0.0, 0.0, 0.0);
let v1_data = v1.as_f64_slice();
let v2_data = v2.as_f64_slice();
for i in 0..v1.dims {
let e1 = v1_data[i];
let e2 = v2_data[i];
dot += e1 * e2;
norm1 += e1 * e1;
norm2 += e2 * e2;
}
Ok(1.0 - (dot / (norm1 * norm2).sqrt()))
}
pub fn vector_type(blob: &[u8]) -> Result<VectorType> {
if blob.is_empty() {
return Err(Error::InvalidFormat);
}
// Even-sized blobs are always float32.
if blob.len() % 2 == 0 {
return Ok(VectorType::Float32);
}
// Odd-sized blobs have type byte at the end
let (data_blob, type_byte) = blob.split_at(blob.len() - 1);
let vector_type = type_byte[0];
match vector_type {
1 => {
if data_blob.len() % 4 != 0 {
return Err(Error::InvalidFormat);
}
Ok(VectorType::Float32)
}
2 => {
if data_blob.len() % 8 != 0 {
return Err(Error::InvalidFormat);
}
Ok(VectorType::Float64)
}
_ => Err(Error::InvalidType),
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_string_vector_zero_length() {
let value = Value::from_text("[]".to_string());
let vector = parse_string_vector(VectorType::Float32, &value).unwrap();
assert_eq!(vector.dims, 0);
assert_eq!(vector.vector_type, VectorType::Float32);
}
#[test]
fn test_parse_string_vector_valid_whitespace() {
let value = Value::from_text(" [ 1.0 , 2.0 , 3.0 ] ".to_string());
let vector = parse_string_vector(VectorType::Float32, &value).unwrap();
assert_eq!(vector.dims, 3);
assert_eq!(vector.vector_type, VectorType::Float32);
}
#[test]
fn test_parse_string_vector_valid() {
let value = Value::from_text("[1.0, 2.0, 3.0]".to_string());
let vector = parse_string_vector(VectorType::Float32, &value).unwrap();
assert_eq!(vector.dims, 3);
assert_eq!(vector.vector_type, VectorType::Float32);
}
}

14
testing/vector.test Executable file
View File

@@ -0,0 +1,14 @@
#!/usr/bin/env tclsh
set testdir [file dirname $argv0]
source $testdir/tester.tcl
do_execsql_test vector-functions-valid {
SELECT vector_extract(vector('[]'));
SELECT vector_extract(vector(' [ 1 , 2 , 3 ] '));
SELECT vector_extract(vector('[-1000000000000000000]'));
} {
{[]}
{[1,2,3]}
{[-1000000000000000000]}
}