mirror of
https://github.com/aljazceru/turso.git
synced 2025-12-26 04:24:21 +01:00
Merge 'Vector speedup' from Nikita Sivukhin
This PR reduces allocation for vector distance calculation and also use [simsimd](https://github.com/ashvardanian/SimSIMD) library to execute cosine/l2 distances for dense vectors. Reviewed-by: Pere Diaz Bou <pere-altea@homail.com> Closes #3802
This commit is contained in:
21
Cargo.lock
generated
21
Cargo.lock
generated
@@ -523,10 +523,11 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.2.17"
|
||||
version = "1.2.41"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1fcb57c740ae1daf453ae85f16e37396f672b039e00d9d866e07ddb24e328e3a"
|
||||
checksum = "ac9fe6cdbb24b6ade63616c0a0688e45bb56732262c158df3c0c4bea4ca47cb7"
|
||||
dependencies = [
|
||||
"find-msvc-tools",
|
||||
"jobserver",
|
||||
"libc",
|
||||
"shlex",
|
||||
@@ -1504,6 +1505,12 @@ dependencies = [
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "find-msvc-tools"
|
||||
version = "0.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "52051878f80a721bb68ebfbc930e07b65ba72f2da88968ea5c06fd6ca3d3a127"
|
||||
|
||||
[[package]]
|
||||
name = "findshlibs"
|
||||
version = "0.10.2"
|
||||
@@ -4143,6 +4150,15 @@ dependencies = [
|
||||
"similar",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "simsimd"
|
||||
version = "6.5.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0e3f209c5a8155b8458b1a0d3a6fc9fa09d201e6086fdaae18e9e283b9274f8f"
|
||||
dependencies = [
|
||||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "slab"
|
||||
version = "0.4.9"
|
||||
@@ -4924,6 +4940,7 @@ dependencies = [
|
||||
"rustix 1.0.7",
|
||||
"ryu",
|
||||
"serde",
|
||||
"simsimd",
|
||||
"sorted-vec",
|
||||
"strum",
|
||||
"strum_macros",
|
||||
|
||||
@@ -1,6 +1,19 @@
|
||||
import { expect, test } from 'vitest'
|
||||
import { connect, Database } from './promise-default.js'
|
||||
|
||||
test('vector-test', async () => {
|
||||
const db = await connect(":memory:");
|
||||
const v1 = new Array(1024).fill(0).map((_, i) => i);
|
||||
const v2 = new Array(1024).fill(0).map((_, i) => 1024 - i);
|
||||
const result = await db.prepare(`SELECT
|
||||
vector_distance_cos(vector32('${JSON.stringify(v1)}'), vector32('${JSON.stringify(v2)}')) as cosf32,
|
||||
vector_distance_cos(vector64('${JSON.stringify(v1)}'), vector64('${JSON.stringify(v2)}')) as cosf64,
|
||||
vector_distance_l2(vector32('${JSON.stringify(v1)}'), vector32('${JSON.stringify(v2)}')) as l2f32,
|
||||
vector_distance_l2(vector64('${JSON.stringify(v1)}'), vector64('${JSON.stringify(v2)}')) as l2f64
|
||||
`).all();
|
||||
console.info(result);
|
||||
})
|
||||
|
||||
test('explain', async () => {
|
||||
const db = await connect(":memory:");
|
||||
const stmt = db.prepare("EXPLAIN SELECT 1");
|
||||
|
||||
37
cli/app.rs
37
cli/app.rs
@@ -106,44 +106,65 @@ macro_rules! row_step_result_query {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let start = Instant::now();
|
||||
let start = if $stats.is_some() {
|
||||
Some(Instant::now())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
match $rows.step() {
|
||||
Ok(StepResult::Row) => {
|
||||
if let Some(ref mut stats) = $stats {
|
||||
stats.execute_time_elapsed_samples.push(start.elapsed());
|
||||
stats
|
||||
.execute_time_elapsed_samples
|
||||
.push(start.unwrap().elapsed());
|
||||
}
|
||||
|
||||
$row_handle
|
||||
}
|
||||
Ok(StepResult::IO) => {
|
||||
let start = Instant::now();
|
||||
if let Some(ref mut stats) = $stats {
|
||||
stats.io_time_elapsed_samples.push(start.unwrap().elapsed());
|
||||
}
|
||||
let start = if $stats.is_some() {
|
||||
Some(Instant::now())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
$rows.run_once()?;
|
||||
if let Some(ref mut stats) = $stats {
|
||||
stats.io_time_elapsed_samples.push(start.elapsed());
|
||||
stats.io_time_elapsed_samples.push(start.unwrap().elapsed());
|
||||
}
|
||||
}
|
||||
Ok(StepResult::Interrupt) => {
|
||||
if let Some(ref mut stats) = $stats {
|
||||
stats.execute_time_elapsed_samples.push(start.elapsed());
|
||||
stats
|
||||
.execute_time_elapsed_samples
|
||||
.push(start.unwrap().elapsed());
|
||||
}
|
||||
break;
|
||||
}
|
||||
Ok(StepResult::Done) => {
|
||||
if let Some(ref mut stats) = $stats {
|
||||
stats.execute_time_elapsed_samples.push(start.elapsed());
|
||||
stats
|
||||
.execute_time_elapsed_samples
|
||||
.push(start.unwrap().elapsed());
|
||||
}
|
||||
break;
|
||||
}
|
||||
Ok(StepResult::Busy) => {
|
||||
if let Some(ref mut stats) = $stats {
|
||||
stats.execute_time_elapsed_samples.push(start.elapsed());
|
||||
stats
|
||||
.execute_time_elapsed_samples
|
||||
.push(start.unwrap().elapsed());
|
||||
}
|
||||
let _ = $app.writeln("database is busy");
|
||||
break;
|
||||
}
|
||||
Err(err) => {
|
||||
if let Some(ref mut stats) = $stats {
|
||||
stats.execute_time_elapsed_samples.push(start.elapsed());
|
||||
stats
|
||||
.execute_time_elapsed_samples
|
||||
.push(start.unwrap().elapsed());
|
||||
}
|
||||
let report = miette::Error::from(err).with_source_code($sql.to_owned());
|
||||
let _ = $app.writeln_fmt(format_args!("{report:?}"));
|
||||
|
||||
@@ -83,6 +83,7 @@ aegis = "0.9.0"
|
||||
twox-hash = "2.1.1"
|
||||
intrusive-collections = "0.9.7"
|
||||
roaring = "0.11.2"
|
||||
simsimd = "6.5.3"
|
||||
|
||||
[build-dependencies]
|
||||
chrono = { workspace = true, default-features = false }
|
||||
|
||||
@@ -20,7 +20,7 @@ pub fn parse_vector(value: &Register, type_hint: Option<VectorType>) -> Result<V
|
||||
"Invalid vector value".to_string(),
|
||||
));
|
||||
};
|
||||
Vector::from_blob(blob.to_vec())
|
||||
Vector::from_slice(blob)
|
||||
}
|
||||
_ => Err(LimboError::ConversionError(
|
||||
"Invalid vector type".to_string(),
|
||||
@@ -81,7 +81,7 @@ pub fn vector_extract(args: &[Register]) -> Result<Value> {
|
||||
return Ok(Value::build_text("[]"));
|
||||
}
|
||||
|
||||
let vector = Vector::from_blob(blob.to_vec())?;
|
||||
let vector = Vector::from_vec(blob.to_vec())?;
|
||||
Ok(Value::build_text(operations::text::vector_to_text(&vector)))
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ use crate::{
|
||||
LimboError, Result,
|
||||
};
|
||||
|
||||
pub fn vector_concat(v1: &Vector, v2: &Vector) -> Result<Vector> {
|
||||
pub fn vector_concat(v1: &Vector, v2: &Vector) -> Result<Vector<'static>> {
|
||||
if v1.vector_type != v2.vector_type {
|
||||
return Err(LimboError::ConversionError(
|
||||
"Mismatched vector types".into(),
|
||||
@@ -12,17 +12,17 @@ pub fn vector_concat(v1: &Vector, v2: &Vector) -> Result<Vector> {
|
||||
|
||||
let data = match v1.vector_type {
|
||||
VectorType::Float32Dense | VectorType::Float64Dense => {
|
||||
let mut data = Vec::with_capacity(v1.data.len() + v2.data.len());
|
||||
data.extend_from_slice(&v1.data);
|
||||
data.extend_from_slice(&v2.data);
|
||||
let mut data = Vec::with_capacity(v1.bin_len() + v2.bin_len());
|
||||
data.extend_from_slice(v1.bin_data());
|
||||
data.extend_from_slice(v2.bin_data());
|
||||
data
|
||||
}
|
||||
VectorType::Float32Sparse => {
|
||||
let mut data = Vec::with_capacity(v1.data.len() + v2.data.len());
|
||||
data.extend_from_slice(&v1.data[..v1.data.len() / 2]);
|
||||
data.extend_from_slice(&v2.data[..v2.data.len() / 2]);
|
||||
data.extend_from_slice(&v1.data[v1.data.len() / 2..]);
|
||||
data.extend_from_slice(&v2.data[v2.data.len() / 2..]);
|
||||
let mut data = Vec::with_capacity(v1.bin_len() + v2.bin_len());
|
||||
data.extend_from_slice(&v1.bin_data()[..v1.bin_len() / 2]);
|
||||
data.extend_from_slice(&v2.bin_data()[..v2.bin_len() / 2]);
|
||||
data.extend_from_slice(&v1.bin_data()[v1.bin_len() / 2..]);
|
||||
data.extend_from_slice(&v2.bin_data()[v2.bin_len() / 2..]);
|
||||
data
|
||||
}
|
||||
};
|
||||
@@ -30,7 +30,8 @@ pub fn vector_concat(v1: &Vector, v2: &Vector) -> Result<Vector> {
|
||||
Ok(Vector {
|
||||
vector_type: v1.vector_type,
|
||||
dims: v1.dims + v2.dims,
|
||||
data,
|
||||
owned: Some(data),
|
||||
refer: None,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -41,7 +42,7 @@ mod tests {
|
||||
vector_types::{Vector, VectorType},
|
||||
};
|
||||
|
||||
fn float32_vec_from(slice: &[f32]) -> Vector {
|
||||
fn float32_vec_from(slice: &[f32]) -> Vector<'static> {
|
||||
let mut data = Vec::new();
|
||||
for &v in slice {
|
||||
data.extend_from_slice(&v.to_le_bytes());
|
||||
@@ -50,7 +51,8 @@ mod tests {
|
||||
Vector {
|
||||
vector_type: VectorType::Float32Dense,
|
||||
dims: slice.len(),
|
||||
data,
|
||||
owned: Some(data),
|
||||
refer: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -69,7 +69,7 @@ mod tests {
|
||||
fn assert_vectors(v1: &Vector, v2: &Vector) {
|
||||
assert_eq!(v1.vector_type, v2.vector_type);
|
||||
assert_eq!(v1.dims, v2.dims);
|
||||
assert_eq!(v1.data, v2.data);
|
||||
assert_eq!(v1.bin_data(), v2.bin_data());
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -77,30 +77,33 @@ mod tests {
|
||||
let vf32 = Vector {
|
||||
vector_type: VectorType::Float32Dense,
|
||||
dims: 3,
|
||||
data: concat(&[
|
||||
owned: Some(concat(&[
|
||||
1.0f32.to_le_bytes(),
|
||||
0.0f32.to_le_bytes(),
|
||||
2.0f32.to_le_bytes(),
|
||||
]),
|
||||
])),
|
||||
refer: None,
|
||||
};
|
||||
let vf64 = Vector {
|
||||
vector_type: VectorType::Float64Dense,
|
||||
dims: 3,
|
||||
data: concat(&[
|
||||
owned: Some(concat(&[
|
||||
1.0f64.to_le_bytes(),
|
||||
0.0f64.to_le_bytes(),
|
||||
2.0f64.to_le_bytes(),
|
||||
]),
|
||||
])),
|
||||
refer: None,
|
||||
};
|
||||
let vf32_sparse = Vector {
|
||||
vector_type: VectorType::Float32Sparse,
|
||||
dims: 3,
|
||||
data: concat(&[
|
||||
owned: Some(concat(&[
|
||||
1.0f32.to_le_bytes(),
|
||||
2.0f32.to_le_bytes(),
|
||||
0u32.to_le_bytes(),
|
||||
2u32.to_le_bytes(),
|
||||
]),
|
||||
])),
|
||||
refer: None,
|
||||
};
|
||||
|
||||
let vectors = [vf32, vf64, vf32_sparse];
|
||||
@@ -110,7 +113,8 @@ mod tests {
|
||||
let v_copy = Vector {
|
||||
vector_type: v1.vector_type,
|
||||
dims: v1.dims,
|
||||
data: v1.data.clone(),
|
||||
owned: v1.owned.clone(),
|
||||
refer: None,
|
||||
};
|
||||
assert_vectors(&vector_convert(v_copy, v2.vector_type).unwrap(), v2);
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ use crate::{
|
||||
vector::vector_types::{Vector, VectorSparse, VectorType},
|
||||
LimboError, Result,
|
||||
};
|
||||
use simsimd::SpatialSimilarity;
|
||||
|
||||
pub fn vector_distance_cos(v1: &Vector, v2: &Vector) -> Result<f64> {
|
||||
if v1.dims != v2.dims {
|
||||
@@ -15,11 +16,23 @@ pub fn vector_distance_cos(v1: &Vector, v2: &Vector) -> Result<f64> {
|
||||
));
|
||||
}
|
||||
match v1.vector_type {
|
||||
VectorType::Float32Dense => Ok(vector_f32_distance_cos(
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
VectorType::Float32Dense => Ok(vector_f32_distance_cos_simsimd(
|
||||
v1.as_f32_slice(),
|
||||
v2.as_f32_slice(),
|
||||
)),
|
||||
VectorType::Float64Dense => Ok(vector_f64_distance_cos(
|
||||
#[cfg(target_family = "wasm")]
|
||||
VectorType::Float32Dense => Ok(vector_f32_distance_cos_rust(
|
||||
v1.as_f32_slice(),
|
||||
v2.as_f32_slice(),
|
||||
)),
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
VectorType::Float64Dense => Ok(vector_f64_distance_cos_simsimd(
|
||||
v1.as_f64_slice(),
|
||||
v2.as_f64_slice(),
|
||||
)),
|
||||
#[cfg(target_family = "wasm")]
|
||||
VectorType::Float64Dense => Ok(vector_f64_distance_cos_rust(
|
||||
v1.as_f64_slice(),
|
||||
v2.as_f64_slice(),
|
||||
)),
|
||||
@@ -30,44 +43,44 @@ pub fn vector_distance_cos(v1: &Vector, v2: &Vector) -> Result<f64> {
|
||||
}
|
||||
}
|
||||
|
||||
fn vector_f32_distance_cos(v1: &[f32], v2: &[f32]) -> f64 {
|
||||
let (mut dot, mut norm1, mut norm2) = (0.0, 0.0, 0.0);
|
||||
|
||||
let dims = v1.len();
|
||||
for i in 0..dims {
|
||||
let e1 = v1[i];
|
||||
let e2 = v2[i];
|
||||
dot += e1 * e2;
|
||||
norm1 += e1 * e1;
|
||||
norm2 += e2 * e2;
|
||||
}
|
||||
|
||||
// Check for zero norms to avoid division by zero
|
||||
if norm1 == 0.0 || norm2 == 0.0 {
|
||||
return f64::NAN;
|
||||
}
|
||||
|
||||
1.0 - (dot / (norm1 * norm2).sqrt()) as f64
|
||||
#[allow(dead_code)]
|
||||
fn vector_f32_distance_cos_simsimd(v1: &[f32], v2: &[f32]) -> f64 {
|
||||
f32::cosine(v1, v2).unwrap_or(f64::NAN)
|
||||
}
|
||||
|
||||
fn vector_f64_distance_cos(v1: &[f64], v2: &[f64]) -> f64 {
|
||||
// SimSIMD do not support WASM for now, so we have alternative implementation: https://github.com/ashvardanian/SimSIMD/issues/189
|
||||
#[allow(dead_code)]
|
||||
fn vector_f32_distance_cos_rust(v1: &[f32], v2: &[f32]) -> f64 {
|
||||
let (mut dot, mut norm1, mut norm2) = (0.0, 0.0, 0.0);
|
||||
|
||||
let dims = v1.len();
|
||||
for i in 0..dims {
|
||||
let e1 = v1[i];
|
||||
let e2 = v2[i];
|
||||
dot += e1 * e2;
|
||||
norm1 += e1 * e1;
|
||||
norm2 += e2 * e2;
|
||||
for (a, b) in v1.iter().zip(v2.iter()) {
|
||||
dot += a * b;
|
||||
norm1 += a * a;
|
||||
norm2 += b * b;
|
||||
}
|
||||
|
||||
// Check for zero norms
|
||||
if norm1 == 0.0 || norm2 == 0.0 {
|
||||
return f64::NAN;
|
||||
return 0.0;
|
||||
}
|
||||
(1.0 - dot / (norm1 * norm2).sqrt()) as f64
|
||||
}
|
||||
|
||||
1.0 - (dot / (norm1 * norm2).sqrt())
|
||||
#[allow(dead_code)]
|
||||
fn vector_f64_distance_cos_simsimd(v1: &[f64], v2: &[f64]) -> f64 {
|
||||
f64::cosine(v1, v2).unwrap_or(f64::NAN)
|
||||
}
|
||||
|
||||
// SimSIMD do not support WASM for now, so we have alternative implementation: https://github.com/ashvardanian/SimSIMD/issues/189
|
||||
#[allow(dead_code)]
|
||||
fn vector_f64_distance_cos_rust(v1: &[f64], v2: &[f64]) -> f64 {
|
||||
let (mut dot, mut norm1, mut norm2) = (0.0, 0.0, 0.0);
|
||||
for (a, b) in v1.iter().zip(v2.iter()) {
|
||||
dot += a * b;
|
||||
norm1 += a * a;
|
||||
norm2 += b * b;
|
||||
}
|
||||
if norm1 == 0.0 || norm2 == 0.0 {
|
||||
return 0.0;
|
||||
}
|
||||
1.0 - dot / (norm1 * norm2).sqrt()
|
||||
}
|
||||
|
||||
fn vector_f32_sparse_distance_cos(v1: VectorSparse<f32>, v2: VectorSparse<f32>) -> f64 {
|
||||
@@ -120,20 +133,26 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_vector_distance_cos_f32() {
|
||||
assert!(vector_f32_distance_cos(&[], &[]).is_nan());
|
||||
assert!(vector_f32_distance_cos(&[1.0, 2.0], &[0.0, 0.0]).is_nan());
|
||||
assert_eq!(vector_f32_distance_cos(&[1.0, 2.0], &[1.0, 2.0]), 0.0);
|
||||
assert_eq!(vector_f32_distance_cos(&[1.0, 2.0], &[-1.0, -2.0]), 2.0);
|
||||
assert_eq!(vector_f32_distance_cos(&[1.0, 2.0], &[-2.0, 1.0]), 1.0);
|
||||
assert_eq!(vector_f32_distance_cos_simsimd(&[], &[]), 0.0);
|
||||
assert_eq!(
|
||||
vector_f32_distance_cos_simsimd(&[1.0, 2.0], &[0.0, 0.0]),
|
||||
1.0
|
||||
);
|
||||
assert!(vector_f32_distance_cos_simsimd(&[1.0, 2.0], &[1.0, 2.0]).abs() < 1e-6);
|
||||
assert!((vector_f32_distance_cos_simsimd(&[1.0, 2.0], &[-1.0, -2.0]) - 2.0).abs() < 1e-6);
|
||||
assert!((vector_f32_distance_cos_simsimd(&[1.0, 2.0], &[-2.0, 1.0]) - 1.0).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_vector_distance_cos_f64() {
|
||||
assert!(vector_f64_distance_cos(&[], &[]).is_nan());
|
||||
assert!(vector_f64_distance_cos(&[1.0, 2.0], &[0.0, 0.0]).is_nan());
|
||||
assert_eq!(vector_f64_distance_cos(&[1.0, 2.0], &[1.0, 2.0]), 0.0);
|
||||
assert_eq!(vector_f64_distance_cos(&[1.0, 2.0], &[-1.0, -2.0]), 2.0);
|
||||
assert_eq!(vector_f64_distance_cos(&[1.0, 2.0], &[-2.0, 1.0]), 1.0);
|
||||
assert_eq!(vector_f64_distance_cos_simsimd(&[], &[]), 0.0);
|
||||
assert_eq!(
|
||||
vector_f64_distance_cos_simsimd(&[1.0, 2.0], &[0.0, 0.0]),
|
||||
1.0
|
||||
);
|
||||
assert!(vector_f64_distance_cos_simsimd(&[1.0, 2.0], &[1.0, 2.0]).abs() < 1e-6);
|
||||
assert!((vector_f64_distance_cos_simsimd(&[1.0, 2.0], &[-1.0, -2.0]) - 2.0).abs() < 1e-6);
|
||||
assert!((vector_f64_distance_cos_simsimd(&[1.0, 2.0], &[-2.0, 1.0]) - 1.0).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -148,7 +167,7 @@ mod tests {
|
||||
idx: &[1, 2],
|
||||
values: &[1.0, 3.0]
|
||||
},
|
||||
) - vector_f32_distance_cos(&[1.0, 2.0, 0.0], &[0.0, 1.0, 3.0]))
|
||||
) - vector_f32_distance_cos_simsimd(&[1.0, 2.0, 0.0], &[0.0, 1.0, 3.0]))
|
||||
.abs()
|
||||
< 1e-7
|
||||
);
|
||||
@@ -169,4 +188,30 @@ mod tests {
|
||||
|
||||
(d1.is_nan() && d2.is_nan()) || (d1 - d2).abs() < 1e-6
|
||||
}
|
||||
|
||||
#[quickcheck]
|
||||
fn prop_vector_distance_cos_rust_vs_simsimd_f32(
|
||||
v1: ArbitraryVector<100>,
|
||||
v2: ArbitraryVector<100>,
|
||||
) -> bool {
|
||||
let v1 = vector_convert(v1.into(), VectorType::Float32Dense).unwrap();
|
||||
let v2 = vector_convert(v2.into(), VectorType::Float32Dense).unwrap();
|
||||
let d1 = vector_f32_distance_cos_rust(v1.as_f32_slice(), v2.as_f32_slice());
|
||||
let d2 = vector_f32_distance_cos_simsimd(v1.as_f32_slice(), v2.as_f32_slice());
|
||||
println!("d1 vs d2: {d1} vs {d2}");
|
||||
(d1.is_nan() && d2.is_nan()) || (d1 - d2).abs() < 1e-4
|
||||
}
|
||||
|
||||
#[quickcheck]
|
||||
fn prop_vector_distance_cos_rust_vs_simsimd_f64(
|
||||
v1: ArbitraryVector<100>,
|
||||
v2: ArbitraryVector<100>,
|
||||
) -> bool {
|
||||
let v1 = vector_convert(v1.into(), VectorType::Float64Dense).unwrap();
|
||||
let v2 = vector_convert(v2.into(), VectorType::Float64Dense).unwrap();
|
||||
let d1 = vector_f64_distance_cos_rust(v1.as_f64_slice(), v2.as_f64_slice());
|
||||
let d2 = vector_f64_distance_cos_simsimd(v1.as_f64_slice(), v2.as_f64_slice());
|
||||
println!("d1 vs d2: {d1} vs {d2}");
|
||||
(d1.is_nan() && d2.is_nan()) || (d1 - d2).abs() < 1e-6
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ use crate::{
|
||||
vector::vector_types::{Vector, VectorSparse, VectorType},
|
||||
LimboError, Result,
|
||||
};
|
||||
use simsimd::SpatialSimilarity;
|
||||
|
||||
pub fn vector_distance_l2(v1: &Vector, v2: &Vector) -> Result<f64> {
|
||||
if v1.dims != v2.dims {
|
||||
@@ -15,12 +16,26 @@ pub fn vector_distance_l2(v1: &Vector, v2: &Vector) -> Result<f64> {
|
||||
));
|
||||
}
|
||||
match v1.vector_type {
|
||||
VectorType::Float32Dense => {
|
||||
Ok(vector_f32_distance_l2(v1.as_f32_slice(), v2.as_f32_slice()))
|
||||
}
|
||||
VectorType::Float64Dense => {
|
||||
Ok(vector_f64_distance_l2(v1.as_f64_slice(), v2.as_f64_slice()))
|
||||
}
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
VectorType::Float32Dense => Ok(vector_f32_distance_l2_simsimd(
|
||||
v1.as_f32_slice(),
|
||||
v2.as_f32_slice(),
|
||||
)),
|
||||
#[cfg(target_family = "wasm")]
|
||||
VectorType::Float32Dense => Ok(vector_f32_distance_l2_rust(
|
||||
v1.as_f32_slice(),
|
||||
v2.as_f32_slice(),
|
||||
)),
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
VectorType::Float64Dense => Ok(vector_f64_distance_l2_simsimd(
|
||||
v1.as_f64_slice(),
|
||||
v2.as_f64_slice(),
|
||||
)),
|
||||
#[cfg(target_family = "wasm")]
|
||||
VectorType::Float64Dense => Ok(vector_f64_distance_l2_rust(
|
||||
v1.as_f64_slice(),
|
||||
v2.as_f64_slice(),
|
||||
)),
|
||||
VectorType::Float32Sparse => Ok(vector_f32_sparse_distance_l2(
|
||||
v1.as_f32_sparse(),
|
||||
v2.as_f32_sparse(),
|
||||
@@ -28,7 +43,14 @@ pub fn vector_distance_l2(v1: &Vector, v2: &Vector) -> Result<f64> {
|
||||
}
|
||||
}
|
||||
|
||||
fn vector_f32_distance_l2(v1: &[f32], v2: &[f32]) -> f64 {
|
||||
#[allow(dead_code)]
|
||||
fn vector_f32_distance_l2_simsimd(v1: &[f32], v2: &[f32]) -> f64 {
|
||||
f32::euclidean(v1, v2).unwrap_or(f64::NAN)
|
||||
}
|
||||
|
||||
// SimSIMD do not support WASM for now, so we have alternative implementation: https://github.com/ashvardanian/SimSIMD/issues/189
|
||||
#[allow(dead_code)]
|
||||
fn vector_f32_distance_l2_rust(v1: &[f32], v2: &[f32]) -> f64 {
|
||||
let sum = v1
|
||||
.iter()
|
||||
.zip(v2.iter())
|
||||
@@ -37,7 +59,14 @@ fn vector_f32_distance_l2(v1: &[f32], v2: &[f32]) -> f64 {
|
||||
sum.sqrt()
|
||||
}
|
||||
|
||||
fn vector_f64_distance_l2(v1: &[f64], v2: &[f64]) -> f64 {
|
||||
#[allow(dead_code)]
|
||||
fn vector_f64_distance_l2_simsimd(v1: &[f64], v2: &[f64]) -> f64 {
|
||||
f64::euclidean(v1, v2).unwrap_or(f64::NAN)
|
||||
}
|
||||
|
||||
// SimSIMD do not support WASM for now, so we have alternative implementation: https://github.com/ashvardanian/SimSIMD/issues/189
|
||||
#[allow(dead_code)]
|
||||
fn vector_f64_distance_l2_rust(v1: &[f64], v2: &[f64]) -> f64 {
|
||||
let sum = v1
|
||||
.iter()
|
||||
.zip(v2.iter())
|
||||
@@ -102,7 +131,7 @@ mod tests {
|
||||
];
|
||||
let results = vectors
|
||||
.iter()
|
||||
.map(|v| vector_f32_distance_l2(&query, v))
|
||||
.map(|v| vector_f32_distance_l2_rust(&query, v))
|
||||
.collect::<Vec<f64>>();
|
||||
assert_eq!(results, expected);
|
||||
}
|
||||
@@ -111,41 +140,41 @@ mod tests {
|
||||
fn test_vector_distance_l2_odd_len() {
|
||||
let v = (0..5).map(|x| x as f32).collect::<Vec<f32>>();
|
||||
let query = (2..7).map(|x| x as f32).collect::<Vec<f32>>();
|
||||
assert_eq!(vector_f32_distance_l2(&v, &query), 20.0_f64.sqrt());
|
||||
assert_eq!(vector_f32_distance_l2_rust(&v, &query), 20.0_f64.sqrt());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_vector_distance_l2_f32() {
|
||||
assert_eq!(vector_f32_distance_l2(&[], &[]), 0.0);
|
||||
assert_eq!(vector_f32_distance_l2_rust(&[], &[]), 0.0);
|
||||
assert_eq!(
|
||||
vector_f32_distance_l2(&[1.0, 2.0], &[0.0, 0.0]),
|
||||
vector_f32_distance_l2_rust(&[1.0, 2.0], &[0.0, 0.0]),
|
||||
(1f64 + 2f64 * 2f64).sqrt()
|
||||
);
|
||||
assert_eq!(vector_f32_distance_l2(&[1.0, 2.0], &[1.0, 2.0]), 0.0);
|
||||
assert_eq!(vector_f32_distance_l2_rust(&[1.0, 2.0], &[1.0, 2.0]), 0.0);
|
||||
assert_eq!(
|
||||
vector_f32_distance_l2(&[1.0, 2.0], &[-1.0, -2.0]),
|
||||
vector_f32_distance_l2_rust(&[1.0, 2.0], &[-1.0, -2.0]),
|
||||
(2f64 * 2f64 + 4f64 * 4f64).sqrt()
|
||||
);
|
||||
assert_eq!(
|
||||
vector_f32_distance_l2(&[1.0, 2.0], &[-2.0, 1.0]),
|
||||
vector_f32_distance_l2_rust(&[1.0, 2.0], &[-2.0, 1.0]),
|
||||
(3f64 * 3f64 + 1f64 * 1f64).sqrt()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_vector_distance_l2_f64() {
|
||||
assert_eq!(vector_f64_distance_l2(&[], &[]), 0.0);
|
||||
assert_eq!(vector_f64_distance_l2_rust(&[], &[]), 0.0);
|
||||
assert_eq!(
|
||||
vector_f64_distance_l2(&[1.0, 2.0], &[0.0, 0.0]),
|
||||
vector_f64_distance_l2_rust(&[1.0, 2.0], &[0.0, 0.0]),
|
||||
(1f64 + 2f64 * 2f64).sqrt()
|
||||
);
|
||||
assert_eq!(vector_f64_distance_l2(&[1.0, 2.0], &[1.0, 2.0]), 0.0);
|
||||
assert_eq!(vector_f64_distance_l2_rust(&[1.0, 2.0], &[1.0, 2.0]), 0.0);
|
||||
assert_eq!(
|
||||
vector_f64_distance_l2(&[1.0, 2.0], &[-1.0, -2.0]),
|
||||
vector_f64_distance_l2_rust(&[1.0, 2.0], &[-1.0, -2.0]),
|
||||
(2f64 * 2f64 + 4f64 * 4f64).sqrt()
|
||||
);
|
||||
assert_eq!(
|
||||
vector_f64_distance_l2(&[1.0, 2.0], &[-2.0, 1.0]),
|
||||
vector_f64_distance_l2_rust(&[1.0, 2.0], &[-2.0, 1.0]),
|
||||
(3f64 * 3f64 + 1f64 * 1f64).sqrt()
|
||||
);
|
||||
}
|
||||
@@ -162,7 +191,7 @@ mod tests {
|
||||
idx: &[1, 2],
|
||||
values: &[1.0, 3.0]
|
||||
},
|
||||
) - vector_f32_distance_l2(&[1.0, 2.0, 0.0], &[0.0, 1.0, 3.0]))
|
||||
) - vector_f32_distance_l2_rust(&[1.0, 2.0, 0.0], &[0.0, 1.0, 3.0]))
|
||||
.abs()
|
||||
< 1e-7
|
||||
);
|
||||
@@ -183,4 +212,28 @@ mod tests {
|
||||
|
||||
(d1.is_nan() && d2.is_nan()) || (d1 - d2).abs() < 1e-6
|
||||
}
|
||||
|
||||
#[quickcheck]
|
||||
fn prop_vector_distance_l2_rust_vs_simsimd_f32(
|
||||
v1: ArbitraryVector<100>,
|
||||
v2: ArbitraryVector<100>,
|
||||
) -> bool {
|
||||
let v1 = vector_convert(v1.into(), VectorType::Float32Dense).unwrap();
|
||||
let v2 = vector_convert(v2.into(), VectorType::Float32Dense).unwrap();
|
||||
let d1 = vector_f32_distance_l2_rust(v1.as_f32_slice(), v2.as_f32_slice());
|
||||
let d2 = vector_f32_distance_l2_simsimd(v1.as_f32_slice(), v2.as_f32_slice());
|
||||
(d1.is_nan() && d2.is_nan()) || (d1 - d2).abs() < 1e-4
|
||||
}
|
||||
|
||||
#[quickcheck]
|
||||
fn prop_vector_distance_l2_rust_vs_simsimd_f64(
|
||||
v1: ArbitraryVector<100>,
|
||||
v2: ArbitraryVector<100>,
|
||||
) -> bool {
|
||||
let v1 = vector_convert(v1.into(), VectorType::Float64Dense).unwrap();
|
||||
let v2 = vector_convert(v2.into(), VectorType::Float64Dense).unwrap();
|
||||
let d1 = vector_f64_distance_l2_rust(v1.as_f64_slice(), v2.as_f64_slice());
|
||||
let d2 = vector_f64_distance_l2_simsimd(v1.as_f64_slice(), v2.as_f64_slice());
|
||||
(d1.is_nan() && d2.is_nan()) || (d1 - d2).abs() < 1e-6
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,17 +3,20 @@ use crate::{
|
||||
Value,
|
||||
};
|
||||
|
||||
pub fn vector_serialize(mut x: Vector) -> Value {
|
||||
pub fn vector_serialize(x: Vector) -> Value {
|
||||
match x.vector_type {
|
||||
VectorType::Float32Dense => Value::from_blob(x.data),
|
||||
VectorType::Float32Dense => Value::from_blob(x.bin_eject()),
|
||||
VectorType::Float64Dense => {
|
||||
x.data.push(2);
|
||||
Value::from_blob(x.data)
|
||||
let mut data = x.bin_eject();
|
||||
data.push(2);
|
||||
Value::from_blob(data)
|
||||
}
|
||||
VectorType::Float32Sparse => {
|
||||
x.data.extend_from_slice(&(x.dims as u32).to_le_bytes());
|
||||
x.data.push(9);
|
||||
Value::from_blob(x.data)
|
||||
let dims = x.dims;
|
||||
let mut data = x.bin_eject();
|
||||
data.extend_from_slice(&(dims as u32).to_le_bytes());
|
||||
data.push(9);
|
||||
Value::from_blob(data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ use crate::{
|
||||
LimboError, Result,
|
||||
};
|
||||
|
||||
pub fn vector_slice(vector: &Vector, start: usize, end: usize) -> Result<Vector> {
|
||||
pub fn vector_slice(vector: &Vector, start: usize, end: usize) -> Result<Vector<'static>> {
|
||||
if start > end {
|
||||
return Err(LimboError::InvalidArgument(
|
||||
"start index must not be greater than end index".into(),
|
||||
@@ -18,12 +18,14 @@ pub fn vector_slice(vector: &Vector, start: usize, end: usize) -> Result<Vector>
|
||||
VectorType::Float32Dense => Ok(Vector {
|
||||
vector_type: vector.vector_type,
|
||||
dims: end - start,
|
||||
data: vector.data[start * 4..end * 4].to_vec(),
|
||||
owned: Some(vector.bin_data()[start * 4..end * 4].to_vec()),
|
||||
refer: None,
|
||||
}),
|
||||
VectorType::Float64Dense => Ok(Vector {
|
||||
vector_type: vector.vector_type,
|
||||
dims: end - start,
|
||||
data: vector.data[start * 8..end * 8].to_vec(),
|
||||
owned: Some(vector.bin_data()[start * 8..end * 8].to_vec()),
|
||||
refer: None,
|
||||
}),
|
||||
VectorType::Float32Sparse => {
|
||||
let mut values = Vec::new();
|
||||
@@ -41,7 +43,8 @@ pub fn vector_slice(vector: &Vector, start: usize, end: usize) -> Result<Vector>
|
||||
Ok(Vector {
|
||||
vector_type: vector.vector_type,
|
||||
dims: end - start,
|
||||
data: values,
|
||||
owned: Some(values),
|
||||
refer: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -63,7 +66,8 @@ mod tests {
|
||||
Vector {
|
||||
vector_type: VectorType::Float32Dense,
|
||||
dims: slice.len(),
|
||||
data,
|
||||
owned: Some(data),
|
||||
refer: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -56,7 +56,8 @@ pub fn vector_from_text(vector_type: VectorType, text: &str) -> Result<Vector> {
|
||||
Vector {
|
||||
vector_type,
|
||||
dims: 0,
|
||||
data: Vec::new(),
|
||||
owned: Some(Vec::new()),
|
||||
refer: None,
|
||||
}
|
||||
}
|
||||
});
|
||||
@@ -69,7 +70,7 @@ pub fn vector_from_text(vector_type: VectorType, text: &str) -> Result<Vector> {
|
||||
}
|
||||
}
|
||||
|
||||
fn vector32_from_text<'a>(tokens: impl Iterator<Item = &'a str>) -> Result<Vector> {
|
||||
fn vector32_from_text<'a>(tokens: impl Iterator<Item = &'a str>) -> Result<Vector<'static>> {
|
||||
let mut data = Vec::new();
|
||||
for token in tokens {
|
||||
let value = token
|
||||
@@ -85,11 +86,12 @@ fn vector32_from_text<'a>(tokens: impl Iterator<Item = &'a str>) -> Result<Vecto
|
||||
Ok(Vector {
|
||||
vector_type: VectorType::Float32Dense,
|
||||
dims: data.len() / 4,
|
||||
data,
|
||||
owned: Some(data),
|
||||
refer: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn vector64_from_text<'a>(tokens: impl Iterator<Item = &'a str>) -> Result<Vector> {
|
||||
fn vector64_from_text<'a>(tokens: impl Iterator<Item = &'a str>) -> Result<Vector<'static>> {
|
||||
let mut data = Vec::new();
|
||||
for token in tokens {
|
||||
let value = token
|
||||
@@ -105,11 +107,12 @@ fn vector64_from_text<'a>(tokens: impl Iterator<Item = &'a str>) -> Result<Vecto
|
||||
Ok(Vector {
|
||||
vector_type: VectorType::Float64Dense,
|
||||
dims: data.len() / 8,
|
||||
data,
|
||||
owned: Some(data),
|
||||
refer: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn vector32_sparse_from_text<'a>(tokens: impl Iterator<Item = &'a str>) -> Result<Vector> {
|
||||
fn vector32_sparse_from_text<'a>(tokens: impl Iterator<Item = &'a str>) -> Result<Vector<'static>> {
|
||||
let mut idx = Vec::new();
|
||||
let mut values = Vec::new();
|
||||
let mut dims = 0u32;
|
||||
@@ -135,6 +138,7 @@ fn vector32_sparse_from_text<'a>(tokens: impl Iterator<Item = &'a str>) -> Resul
|
||||
Ok(Vector {
|
||||
vector_type: VectorType::Float32Sparse,
|
||||
dims: dims as usize,
|
||||
data: values,
|
||||
owned: Some(values),
|
||||
refer: None,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -8,10 +8,11 @@ pub enum VectorType {
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Vector {
|
||||
pub struct Vector<'a> {
|
||||
pub vector_type: VectorType,
|
||||
pub dims: usize,
|
||||
pub data: Vec<u8>,
|
||||
pub owned: Option<Vec<u8>>,
|
||||
pub refer: Option<&'a [u8]>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -20,14 +21,14 @@ pub struct VectorSparse<'a, T: std::fmt::Debug> {
|
||||
pub values: &'a [T],
|
||||
}
|
||||
|
||||
impl Vector {
|
||||
pub fn vector_type(mut blob: Vec<u8>) -> Result<(VectorType, Vec<u8>)> {
|
||||
impl<'a> Vector<'a> {
|
||||
pub fn vector_type(blob: &[u8]) -> Result<(VectorType, usize)> {
|
||||
// Even-sized blobs are always float32.
|
||||
if blob.len() % 2 == 0 {
|
||||
return Ok((VectorType::Float32Dense, blob));
|
||||
return Ok((VectorType::Float32Dense, blob.len()));
|
||||
}
|
||||
// Odd-sized blobs have type byte at the end
|
||||
let vector_type = blob.pop().unwrap();
|
||||
let vector_type = blob[blob.len() - 1];
|
||||
/*
|
||||
vector types used by LibSQL:
|
||||
(see https://github.com/tursodatabase/libsql/blob/a55bf61192bdb89e97568de593c4af5b70d24bde/libsql-sqlite3/src/vectorInt.h#L52)
|
||||
@@ -39,12 +40,12 @@ impl Vector {
|
||||
#define VECTOR_TYPE_FLOATB16 6
|
||||
*/
|
||||
match vector_type {
|
||||
1 => Ok((VectorType::Float32Dense, blob)),
|
||||
2 => Ok((VectorType::Float64Dense, blob)),
|
||||
1 => Ok((VectorType::Float32Dense, blob.len() - 1)),
|
||||
2 => Ok((VectorType::Float64Dense, blob.len() - 1)),
|
||||
3..=6 => Err(LimboError::ConversionError(
|
||||
"unsupported vector type from LibSQL".to_string(),
|
||||
)),
|
||||
9 => Ok((VectorType::Float32Sparse, blob)),
|
||||
9 => Ok((VectorType::Float32Sparse, blob.len() - 1)),
|
||||
_ => Err(LimboError::ConversionError(format!(
|
||||
"unknown vector type: {vector_type}"
|
||||
))),
|
||||
@@ -63,7 +64,8 @@ impl Vector {
|
||||
Self {
|
||||
vector_type: VectorType::Float32Dense,
|
||||
dims,
|
||||
data: values,
|
||||
owned: Some(values),
|
||||
refer: None,
|
||||
}
|
||||
}
|
||||
pub fn from_f64(mut values_f64: Vec<f64>) -> Self {
|
||||
@@ -79,7 +81,8 @@ impl Vector {
|
||||
Self {
|
||||
vector_type: VectorType::Float64Dense,
|
||||
dims,
|
||||
data: values,
|
||||
owned: Some(values),
|
||||
refer: None,
|
||||
}
|
||||
}
|
||||
pub fn from_f32_sparse(dims: usize, mut values_f32: Vec<f32>, mut idx_u32: Vec<u32>) -> Self {
|
||||
@@ -105,14 +108,27 @@ impl Vector {
|
||||
Self {
|
||||
vector_type: VectorType::Float32Sparse,
|
||||
dims,
|
||||
data: values,
|
||||
owned: Some(values),
|
||||
refer: None,
|
||||
}
|
||||
}
|
||||
pub fn from_blob(blob: Vec<u8>) -> Result<Self> {
|
||||
let (vector_type, data) = Self::vector_type(blob)?;
|
||||
Self::from_data(vector_type, data)
|
||||
pub fn from_vec(mut blob: Vec<u8>) -> Result<Self> {
|
||||
let (vector_type, len) = Self::vector_type(&blob)?;
|
||||
blob.truncate(len);
|
||||
Self::from_data(vector_type, Some(blob), None)
|
||||
}
|
||||
pub fn from_data(vector_type: VectorType, mut data: Vec<u8>) -> Result<Self> {
|
||||
pub fn from_slice(blob: &'a [u8]) -> Result<Self> {
|
||||
let (vector_type, len) = Self::vector_type(blob)?;
|
||||
Self::from_data(vector_type, None, Some(&blob[..len]))
|
||||
}
|
||||
pub fn from_data(
|
||||
vector_type: VectorType,
|
||||
owned: Option<Vec<u8>>,
|
||||
refer: Option<&'a [u8]>,
|
||||
) -> Result<Self> {
|
||||
let owned_slice = owned.as_deref();
|
||||
let refer_slice = refer.as_ref().map(|&x| x);
|
||||
let data = owned_slice.unwrap_or_else(|| refer_slice.unwrap());
|
||||
match vector_type {
|
||||
VectorType::Float32Dense => {
|
||||
if data.len() % 4 != 0 {
|
||||
@@ -124,7 +140,8 @@ impl Vector {
|
||||
Ok(Vector {
|
||||
vector_type,
|
||||
dims: data.len() / 4,
|
||||
data,
|
||||
owned,
|
||||
refer,
|
||||
})
|
||||
}
|
||||
VectorType::Float64Dense => {
|
||||
@@ -137,7 +154,8 @@ impl Vector {
|
||||
Ok(Vector {
|
||||
vector_type,
|
||||
dims: data.len() / 8,
|
||||
data,
|
||||
owned,
|
||||
refer,
|
||||
})
|
||||
}
|
||||
VectorType::Float32Sparse => {
|
||||
@@ -147,17 +165,41 @@ impl Vector {
|
||||
data.len(),
|
||||
)));
|
||||
}
|
||||
let dims_bytes = data.split_off(data.len() - 4);
|
||||
let original_len = data.len();
|
||||
let dims_bytes = &data[original_len - 4..];
|
||||
let dims = u32::from_le_bytes(dims_bytes.try_into().unwrap()) as usize;
|
||||
let owned = owned.map(|mut x| {
|
||||
x.truncate(original_len - 4);
|
||||
x
|
||||
});
|
||||
let refer = refer.map(|x| &x[0..original_len - 4]);
|
||||
let vector = Vector {
|
||||
vector_type,
|
||||
dims,
|
||||
data,
|
||||
owned,
|
||||
refer,
|
||||
};
|
||||
Ok(vector)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bin_len(&self) -> usize {
|
||||
let owned = self.owned.as_ref().map(|x| x.len());
|
||||
let refer = self.refer.as_ref().map(|x| x.len());
|
||||
owned.unwrap_or_else(|| refer.unwrap())
|
||||
}
|
||||
|
||||
pub fn bin_data(&'a self) -> &'a [u8] {
|
||||
let owned = self.owned.as_deref();
|
||||
let refer = self.refer.as_ref().map(|&x| x);
|
||||
owned.unwrap_or_else(|| refer.unwrap())
|
||||
}
|
||||
|
||||
pub fn bin_eject(self) -> Vec<u8> {
|
||||
self.owned.unwrap_or_else(|| self.refer.unwrap().to_vec())
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// This method is used to reinterpret the underlying `Vec<u8>` data
|
||||
@@ -171,12 +213,12 @@ impl Vector {
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
self.data.len(),
|
||||
self.bin_len(),
|
||||
self.dims * std::mem::size_of::<f32>(),
|
||||
"data length must equal dims * size_of::<f32>()"
|
||||
);
|
||||
|
||||
let ptr = self.data.as_ptr();
|
||||
let ptr = self.bin_data().as_ptr();
|
||||
let align = std::mem::align_of::<f32>();
|
||||
assert_eq!(
|
||||
ptr.align_offset(align),
|
||||
@@ -200,12 +242,12 @@ impl Vector {
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
self.data.len(),
|
||||
self.bin_len(),
|
||||
self.dims * std::mem::size_of::<f64>(),
|
||||
"data length must equal dims * size_of::<f64>()"
|
||||
);
|
||||
|
||||
let ptr = self.data.as_ptr();
|
||||
let ptr = self.bin_data().as_ptr();
|
||||
let align = std::mem::align_of::<f64>();
|
||||
assert_eq!(
|
||||
ptr.align_offset(align),
|
||||
@@ -218,14 +260,14 @@ impl Vector {
|
||||
|
||||
pub fn as_f32_sparse(&self) -> VectorSparse<'_, f32> {
|
||||
debug_assert!(self.vector_type == VectorType::Float32Sparse);
|
||||
let ptr = self.data.as_ptr();
|
||||
let ptr = self.bin_data().as_ptr();
|
||||
let align = std::mem::align_of::<f32>();
|
||||
assert_eq!(
|
||||
ptr.align_offset(align),
|
||||
0,
|
||||
"data pointer must be aligned to {align} bytes for f32 access"
|
||||
);
|
||||
let length = self.data.len() / 4 / 2;
|
||||
let length = self.bin_data().len() / 4 / 2;
|
||||
let values = unsafe { std::slice::from_raw_parts(ptr as *const f32, length) };
|
||||
let idx = unsafe { std::slice::from_raw_parts((ptr as *const u32).add(length), length) };
|
||||
debug_assert!(idx.is_sorted());
|
||||
@@ -292,12 +334,13 @@ pub(crate) mod tests {
|
||||
}
|
||||
|
||||
/// Convert an ArbitraryVector to a Vector.
|
||||
impl<const DIMS: usize> From<ArbitraryVector<DIMS>> for Vector {
|
||||
impl<const DIMS: usize> From<ArbitraryVector<DIMS>> for Vector<'static> {
|
||||
fn from(v: ArbitraryVector<DIMS>) -> Self {
|
||||
Vector {
|
||||
vector_type: v.vector_type,
|
||||
dims: DIMS,
|
||||
data: v.data,
|
||||
owned: Some(v.data),
|
||||
refer: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -357,7 +400,7 @@ pub(crate) mod tests {
|
||||
let vtype = v.vector_type;
|
||||
let value = operations::serialize::vector_serialize(v);
|
||||
let blob = value.to_blob().unwrap().to_vec();
|
||||
match Vector::vector_type(blob) {
|
||||
match Vector::vector_type(&blob) {
|
||||
Ok((detected_type, _)) => detected_type == vtype,
|
||||
Err(_) => false,
|
||||
}
|
||||
@@ -396,12 +439,12 @@ pub(crate) mod tests {
|
||||
VectorType::Float32Dense => {
|
||||
let slice = v.as_f32_slice();
|
||||
// Check if the slice length matches the dimensions and the data length is correct (4 bytes per float)
|
||||
slice.len() == DIMS && (slice.len() * 4 == v.data.len())
|
||||
slice.len() == DIMS && (slice.len() * 4 == v.bin_len())
|
||||
}
|
||||
VectorType::Float64Dense => {
|
||||
let slice = v.as_f64_slice();
|
||||
// Check if the slice length matches the dimensions and the data length is correct (8 bytes per float)
|
||||
slice.len() == DIMS && (slice.len() * 8 == v.data.len())
|
||||
slice.len() == DIMS && (slice.len() * 8 == v.bin_len())
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
@@ -454,12 +497,14 @@ pub(crate) mod tests {
|
||||
let a = Vector {
|
||||
vector_type: VectorType::Float32Dense,
|
||||
dims: 2,
|
||||
data: vec![0, 0, 0, 0, 52, 208, 106, 63],
|
||||
owned: Some(vec![0, 0, 0, 0, 52, 208, 106, 63]),
|
||||
refer: None,
|
||||
};
|
||||
let b = Vector {
|
||||
vector_type: VectorType::Float32Dense,
|
||||
dims: 2,
|
||||
data: vec![0, 0, 0, 0, 58, 100, 45, 192],
|
||||
owned: Some(vec![0, 0, 0, 0, 58, 100, 45, 192]),
|
||||
refer: None,
|
||||
};
|
||||
assert!(
|
||||
(operations::distance_cos::vector_distance_cos(&a, &b).unwrap() - 2.0).abs() <= 1e-6
|
||||
|
||||
Reference in New Issue
Block a user