diff --git a/COMPAT.md b/COMPAT.md index f3b6f1a50..f7082afa1 100644 --- a/COMPAT.md +++ b/COMPAT.md @@ -616,3 +616,15 @@ The `regexp` extension is compatible with [sqlean-regexp](https://github.com/nal | regexp_substr(source, pattern) | Yes | | | regexp_capture(source, pattern[, n]) | No | | | regexp_replace(source, pattern, replacement) | No | | + +### Vector + +The `vector` extension is compatible with libSQL native vector search. + +| Function | Status | Comment | +|------------------------------------------------|--------|---------| +| vector(x) | Yes | | +| vector32(x) | Yes | | +| vector64(x) | Yes | | +| vector_extract(x) | Yes | | +| vector_distance_cos(x, y) | Yes | | diff --git a/Cargo.lock b/Cargo.lock index 0e2deb1ea..c0c2abab1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1342,6 +1342,7 @@ dependencies = [ "limbo_percentile", "limbo_regexp", "limbo_uuid", + "limbo_vector", "log", "miette", "mimalloc", @@ -1441,6 +1442,13 @@ dependencies = [ "uuid", ] +[[package]] +name = "limbo_vector" +version = "0.0.13" +dependencies = [ + "limbo_ext", +] + [[package]] name = "linux-raw-sys" version = "0.4.15" diff --git a/Cargo.toml b/Cargo.toml index 0ffbdf6ac..3ff1ad34a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ members = [ "sqlite3", "tests", "extensions/percentile", + "extensions/vector", ] exclude = ["perf/latency/limbo"] diff --git a/Makefile b/Makefile index 69fcccf83..b66e80d76 100644 --- a/Makefile +++ b/Makefile @@ -62,7 +62,7 @@ limbo-wasm: cargo build --package limbo-wasm --target wasm32-wasi .PHONY: limbo-wasm -test: limbo test-compat test-sqlite3 test-shell test-extensions +test: limbo test-compat test-vector test-sqlite3 test-shell test-extensions .PHONY: test test-extensions: limbo @@ -78,6 +78,10 @@ test-compat: SQLITE_EXEC=$(SQLITE_EXEC) ./testing/all.test .PHONY: test-compat +test-vector: + SQLITE_EXEC=$(SQLITE_EXEC) ./testing/vector.test +.PHONY: test-vector + test-sqlite3: limbo-c LIBS="$(SQLITE_LIB)" HEADERS="$(SQLITE_LIB_HEADERS)" make -C sqlite3/tests test .PHONY: test-sqlite3 diff --git a/core/Cargo.toml b/core/Cargo.toml index 694617ec3..1fedb535c 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -14,7 +14,7 @@ name = "limbo_core" path = "lib.rs" [features] -default = ["fs", "json", "uuid", "io_uring"] +default = ["fs", "json", "uuid", "vector", "io_uring"] fs = [] json = [ "dep:jsonb", @@ -22,6 +22,7 @@ json = [ "dep:pest_derive", ] uuid = ["limbo_uuid/static"] +vector = ["limbo_vector/static"] io_uring = ["dep:io-uring", "rustix/io_uring"] percentile = ["limbo_percentile/static"] regexp = ["limbo_regexp/static"] @@ -61,6 +62,7 @@ rand = "0.8.5" bumpalo = { version = "3.16.0", features = ["collections", "boxed"] } limbo_macros = { path = "../macros" } limbo_uuid = { path = "../extensions/uuid", optional = true, features = ["static"] } +limbo_vector = { path = "../extensions/vector", optional = true, features = ["static"] } limbo_regexp = { path = "../extensions/regexp", optional = true, features = ["static"] } limbo_percentile = { path = "../extensions/percentile", optional = true, features = ["static"] } miette = "7.4.0" diff --git a/core/ext/mod.rs b/core/ext/mod.rs index cbd2fa258..c38a99f9c 100644 --- a/core/ext/mod.rs +++ b/core/ext/mod.rs @@ -80,6 +80,10 @@ impl Database { if unsafe { !limbo_uuid::register_extension_static(&ext_api).is_ok() } { return Err("Failed to register uuid extension".to_string()); } + #[cfg(feature = "vector")] + if unsafe { !limbo_vector::register_extension_static(&ext_api).is_ok() } { + return Err("Failed to register vector extension".to_string()); + } #[cfg(feature = "percentile")] if unsafe { !limbo_percentile::register_extension_static(&ext_api).is_ok() } { return Err("Failed to register percentile extension".to_string()); diff --git a/extensions/vector/Cargo.toml b/extensions/vector/Cargo.toml new file mode 100644 index 000000000..e80863fd2 --- /dev/null +++ b/extensions/vector/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "limbo_vector" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true + +[lib] +crate-type = ["cdylib", "lib"] + +[features] +static= [ "limbo_ext/static" ] + +[dependencies] +limbo_ext = { path = "../core", features = ["static"] } diff --git a/extensions/vector/src/lib.rs b/extensions/vector/src/lib.rs new file mode 100644 index 000000000..388368e12 --- /dev/null +++ b/extensions/vector/src/lib.rs @@ -0,0 +1,77 @@ +use limbo_ext::{register_extension, scalar, ResultCode, Value}; + +mod vector; + +use vector::*; + +#[derive(Debug)] +enum Error { + InvalidType, + InvalidFormat, + InvalidDimensions, +} + +type Result = std::result::Result; + +#[scalar(name = "vector32", alias = "vector")] +fn vector32(args: &[Value]) -> Value { + if args.len() != 1 { + return Value::error(ResultCode::Error); + } + let Ok(x) = parse_vector(&args[0], Some(VectorType::Float32)) else { + return Value::error(ResultCode::Error); + }; + vector_serialize_f32(x) +} + +#[scalar(name = "vector64")] +fn vector64(args: &[Value]) -> Value { + if args.len() != 1 { + return Value::error(ResultCode::Error); + } + let Ok(x) = parse_vector(&args[0], Some(VectorType::Float64)) else { + return Value::error(ResultCode::Error); + }; + vector_serialize_f64(x) +} + +#[scalar(name = "vector_extract")] +fn vector_extract(args: &[Value]) -> Value { + if args.len() != 1 { + return Value::error(ResultCode::Error); + } + let Some(blob) = args[0].to_blob() else { + return Value::error(ResultCode::Error); + }; + if blob.is_empty() { + return Value::from_text("[]".to_string()); + } + let Ok(vector_type) = vector_type(&blob) else { + return Value::error(ResultCode::Error); + }; + let Ok(vector) = vector_deserialize(vector_type, &blob) else { + return Value::error(ResultCode::Error); + }; + Value::from_text(vector_to_text(&vector)) +} + +#[scalar(name = "vector_distance_cos")] +fn vector_distance_cos(args: &[Value]) -> Value { + if args.len() != 2 { + return Value::error(ResultCode::Error); + } + let Ok(x) = parse_vector(&args[0], None) else { + return Value::error(ResultCode::Error); + }; + let Ok(y) = parse_vector(&args[1], None) else { + return Value::error(ResultCode::Error); + }; + let Ok(dist) = do_vector_distance_cos(&x, &y) else { + return Value::error(ResultCode::Error); + }; + Value::from_float(dist) +} + +register_extension! { + scalars: { vector32, vector64, vector_extract, vector_distance_cos }, +} diff --git a/extensions/vector/src/vector.rs b/extensions/vector/src/vector.rs new file mode 100644 index 000000000..ba6f8128f --- /dev/null +++ b/extensions/vector/src/vector.rs @@ -0,0 +1,269 @@ +use limbo_ext::{Value, ValueType}; + +use crate::{Error, Result}; + +#[derive(Debug, PartialEq)] +pub enum VectorType { + Float32, + Float64, +} + +impl VectorType { + pub fn size_to_dims(&self, size: usize) -> usize { + match self { + VectorType::Float32 => size / 4, + VectorType::Float64 => size / 8, + } + } +} + +#[derive(Debug)] +pub struct Vector { + pub vector_type: VectorType, + pub dims: usize, + pub data: Vec, +} + +impl Vector { + pub fn as_f32_slice(&self) -> &[f32] { + unsafe { std::slice::from_raw_parts(self.data.as_ptr() as *const f32, self.dims) } + } + + pub fn as_f64_slice(&self) -> &[f64] { + unsafe { std::slice::from_raw_parts(self.data.as_ptr() as *const f64, self.dims) } + } +} + +/// Parse a vector in text representation into a Vector. +/// +/// The format of a vector in text representation looks as follows: +/// +/// ```console +/// [1.0, 2.0, 3.0] +/// ``` +pub fn parse_string_vector(vector_type: VectorType, value: &Value) -> Result { + let Some(text) = value.to_text() else { + return Err(Error::InvalidFormat); + }; + let text = text.trim(); + let mut chars = text.chars(); + if chars.next() != Some('[') || chars.last() != Some(']') { + return Err(Error::InvalidFormat); + } + let mut data: Vec = Vec::new(); + let text = &text[1..text.len() - 1]; + if text.trim().is_empty() { + return Ok(Vector { + vector_type, + dims: 0, + data, + }); + } + let xs = text.split(','); + for x in xs { + let x = x.trim(); + if x.is_empty() { + return Err(Error::InvalidFormat); + } + match vector_type { + VectorType::Float32 => { + let x = x.parse::().map_err(|_| Error::InvalidFormat)?; + data.extend_from_slice(&x.to_le_bytes()); + } + VectorType::Float64 => { + let x = x.parse::().map_err(|_| Error::InvalidFormat)?; + data.extend_from_slice(&x.to_le_bytes()); + } + }; + } + let dims = vector_type.size_to_dims(data.len()); + Ok(Vector { + vector_type, + dims, + data, + }) +} + +pub fn parse_vector(value: &Value, vec_ty: Option) -> Result { + match value.value_type() { + ValueType::Text => parse_string_vector(vec_ty.unwrap_or(VectorType::Float32), value), + ValueType::Blob => { + let Some(blob) = value.to_blob() else { + return Err(Error::InvalidFormat); + }; + let vector_type = vector_type(&blob)?; + if let Some(vec_ty) = vec_ty { + if vec_ty != vector_type { + return Err(Error::InvalidType); + } + } + vector_deserialize(vector_type, &blob) + } + _ => Err(Error::InvalidType), + } +} + +pub fn vector_to_text(vector: &Vector) -> String { + let mut text = String::new(); + text.push('['); + match vector.vector_type { + VectorType::Float32 => { + let data = vector.as_f32_slice(); + for i in 0..vector.dims { + text.push_str(&data[i].to_string()); + if i < vector.dims - 1 { + text.push(','); + } + } + } + VectorType::Float64 => { + let data = vector.as_f64_slice(); + for i in 0..vector.dims { + text.push_str(&data[i].to_string()); + if i < vector.dims - 1 { + text.push(','); + } + } + } + } + text.push(']'); + text +} + +pub fn vector_deserialize(vector_type: VectorType, blob: &[u8]) -> Result { + match vector_type { + VectorType::Float32 => vector_deserialize_f32(blob), + VectorType::Float64 => vector_deserialize_f64(blob), + } +} + +pub fn vector_serialize_f64(x: Vector) -> Value { + let mut blob = Vec::with_capacity(x.dims * 8 + 1); + blob.extend_from_slice(&x.data); + blob.push(2); + Value::from_blob(blob) +} + +pub fn vector_deserialize_f64(blob: &[u8]) -> Result { + Ok(Vector { + vector_type: VectorType::Float64, + dims: (blob.len() - 1) / 8, + data: blob[..blob.len() - 1].to_vec(), + }) +} + +pub fn vector_serialize_f32(x: Vector) -> Value { + Value::from_blob(x.data) +} + +pub fn vector_deserialize_f32(blob: &[u8]) -> Result { + Ok(Vector { + vector_type: VectorType::Float32, + dims: blob.len() / 4, + data: blob.to_vec(), + }) +} + +pub fn do_vector_distance_cos(v1: &Vector, v2: &Vector) -> Result { + match v1.vector_type { + VectorType::Float32 => vector_f32_distance_cos(v1, v2), + VectorType::Float64 => vector_f64_distance_cos(v1, v2), + } +} + +pub fn vector_f32_distance_cos(v1: &Vector, v2: &Vector) -> Result { + if v1.dims != v2.dims { + return Err(Error::InvalidDimensions); + } + if v1.vector_type != v2.vector_type { + return Err(Error::InvalidType); + } + let (mut dot, mut norm1, mut norm2) = (0.0, 0.0, 0.0); + let v1_data = v1.as_f32_slice(); + let v2_data = v2.as_f32_slice(); + for i in 0..v1.dims { + let e1 = v1_data[i]; + let e2 = v2_data[i]; + dot += e1 * e2; + norm1 += e1 * e1; + norm2 += e2 * e2; + } + Ok(1.0 - (dot / (norm1 * norm2).sqrt()) as f64) +} + +pub fn vector_f64_distance_cos(v1: &Vector, v2: &Vector) -> Result { + if v1.dims != v2.dims { + return Err(Error::InvalidDimensions); + } + if v1.vector_type != v2.vector_type { + return Err(Error::InvalidType); + } + let (mut dot, mut norm1, mut norm2) = (0.0, 0.0, 0.0); + let v1_data = v1.as_f64_slice(); + let v2_data = v2.as_f64_slice(); + for i in 0..v1.dims { + let e1 = v1_data[i]; + let e2 = v2_data[i]; + dot += e1 * e2; + norm1 += e1 * e1; + norm2 += e2 * e2; + } + Ok(1.0 - (dot / (norm1 * norm2).sqrt())) +} + +pub fn vector_type(blob: &[u8]) -> Result { + if blob.is_empty() { + return Err(Error::InvalidFormat); + } + // Even-sized blobs are always float32. + if blob.len() % 2 == 0 { + return Ok(VectorType::Float32); + } + // Odd-sized blobs have type byte at the end + let (data_blob, type_byte) = blob.split_at(blob.len() - 1); + let vector_type = type_byte[0]; + match vector_type { + 1 => { + if data_blob.len() % 4 != 0 { + return Err(Error::InvalidFormat); + } + Ok(VectorType::Float32) + } + 2 => { + if data_blob.len() % 8 != 0 { + return Err(Error::InvalidFormat); + } + Ok(VectorType::Float64) + } + _ => Err(Error::InvalidType), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_string_vector_zero_length() { + let value = Value::from_text("[]".to_string()); + let vector = parse_string_vector(VectorType::Float32, &value).unwrap(); + assert_eq!(vector.dims, 0); + assert_eq!(vector.vector_type, VectorType::Float32); + } + + #[test] + fn test_parse_string_vector_valid_whitespace() { + let value = Value::from_text(" [ 1.0 , 2.0 , 3.0 ] ".to_string()); + let vector = parse_string_vector(VectorType::Float32, &value).unwrap(); + assert_eq!(vector.dims, 3); + assert_eq!(vector.vector_type, VectorType::Float32); + } + + #[test] + fn test_parse_string_vector_valid() { + let value = Value::from_text("[1.0, 2.0, 3.0]".to_string()); + let vector = parse_string_vector(VectorType::Float32, &value).unwrap(); + assert_eq!(vector.dims, 3); + assert_eq!(vector.vector_type, VectorType::Float32); + } +} diff --git a/testing/vector.test b/testing/vector.test new file mode 100755 index 000000000..7cae2ca5e --- /dev/null +++ b/testing/vector.test @@ -0,0 +1,14 @@ +#!/usr/bin/env tclsh + +set testdir [file dirname $argv0] +source $testdir/tester.tcl + +do_execsql_test vector-functions-valid { + SELECT vector_extract(vector('[]')); + SELECT vector_extract(vector(' [ 1 , 2 , 3 ] ')); + SELECT vector_extract(vector('[-1000000000000000000]')); +} { + {[]} + {[1,2,3]} + {[-1000000000000000000]} +}