mirror of
https://github.com/aljazceru/turso.git
synced 2026-02-09 10:14:21 +01:00
return Nan for cosine distance instead of error
- errors are hard to handle in case of some scan operations (something went wrong in the middle - whoe query aborted) - it will be more flexibly if we will return NaN and let user handle situation
This commit is contained in:
@@ -4,37 +4,35 @@ use crate::{
|
||||
};
|
||||
|
||||
pub fn vector_distance_cos(v1: &Vector, v2: &Vector) -> Result<f64> {
|
||||
match v1.vector_type {
|
||||
VectorType::Float32Dense => vector_f32_distance_cos(v1, v2),
|
||||
VectorType::Float64Dense => vector_f64_distance_cos(v1, v2),
|
||||
}
|
||||
}
|
||||
|
||||
fn vector_f32_distance_cos(v1: &Vector, v2: &Vector) -> Result<f64> {
|
||||
if v1.dims != v2.dims {
|
||||
return Err(LimboError::ConversionError(
|
||||
"Invalid vector dimensions".to_string(),
|
||||
"Vectors must have the same dimensions".to_string(),
|
||||
));
|
||||
}
|
||||
if v1.vector_type != v2.vector_type {
|
||||
return Err(LimboError::ConversionError(
|
||||
"Invalid vector type".to_string(),
|
||||
"Vectors must be of the same type".to_string(),
|
||||
));
|
||||
}
|
||||
match v1.vector_type {
|
||||
VectorType::Float32Dense => Ok(vector_f32_distance_cos(
|
||||
v1.as_f32_slice(),
|
||||
v2.as_f32_slice(),
|
||||
)),
|
||||
VectorType::Float64Dense => Ok(vector_f64_distance_cos(
|
||||
v1.as_f64_slice(),
|
||||
v2.as_f64_slice(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn vector_f32_distance_cos(v1: &[f32], v2: &[f32]) -> f64 {
|
||||
let (mut dot, mut norm1, mut norm2) = (0.0, 0.0, 0.0);
|
||||
let v1_data = v1.as_f32_slice();
|
||||
let v2_data = v2.as_f32_slice();
|
||||
|
||||
// Check for non-finite values
|
||||
if v1_data.iter().any(|x| !x.is_finite()) || v2_data.iter().any(|x| !x.is_finite()) {
|
||||
return Err(LimboError::ConversionError(
|
||||
"Invalid vector value".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
for i in 0..v1.dims {
|
||||
let e1 = v1_data[i];
|
||||
let e2 = v2_data[i];
|
||||
let dims = v1.len();
|
||||
for i in 0..dims {
|
||||
let e1 = v1[i];
|
||||
let e2 = v2[i];
|
||||
dot += e1 * e2;
|
||||
norm1 += e1 * e1;
|
||||
norm2 += e2 * e2;
|
||||
@@ -42,39 +40,19 @@ fn vector_f32_distance_cos(v1: &Vector, v2: &Vector) -> Result<f64> {
|
||||
|
||||
// Check for zero norms to avoid division by zero
|
||||
if norm1 == 0.0 || norm2 == 0.0 {
|
||||
return Err(LimboError::ConversionError(
|
||||
"Invalid vector value".to_string(),
|
||||
));
|
||||
return f64::NAN;
|
||||
}
|
||||
|
||||
Ok(1.0 - (dot / (norm1 * norm2).sqrt()) as f64)
|
||||
1.0 - (dot / (norm1 * norm2).sqrt()) as f64
|
||||
}
|
||||
|
||||
fn vector_f64_distance_cos(v1: &Vector, v2: &Vector) -> Result<f64> {
|
||||
if v1.dims != v2.dims {
|
||||
return Err(LimboError::ConversionError(
|
||||
"Invalid vector dimensions".to_string(),
|
||||
));
|
||||
}
|
||||
if v1.vector_type != v2.vector_type {
|
||||
return Err(LimboError::ConversionError(
|
||||
"Invalid vector type".to_string(),
|
||||
));
|
||||
}
|
||||
fn vector_f64_distance_cos(v1: &[f64], v2: &[f64]) -> f64 {
|
||||
let (mut dot, mut norm1, mut norm2) = (0.0, 0.0, 0.0);
|
||||
let v1_data = v1.as_f64_slice();
|
||||
let v2_data = v2.as_f64_slice();
|
||||
|
||||
// Check for non-finite values
|
||||
if v1_data.iter().any(|x| !x.is_finite()) || v2_data.iter().any(|x| !x.is_finite()) {
|
||||
return Err(LimboError::ConversionError(
|
||||
"Invalid vector value".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
for i in 0..v1.dims {
|
||||
let e1 = v1_data[i];
|
||||
let e2 = v2_data[i];
|
||||
let dims = v1.len();
|
||||
for i in 0..dims {
|
||||
let e1 = v1[i];
|
||||
let e2 = v2[i];
|
||||
dot += e1 * e2;
|
||||
norm1 += e1 * e1;
|
||||
norm2 += e2 * e2;
|
||||
@@ -82,10 +60,22 @@ fn vector_f64_distance_cos(v1: &Vector, v2: &Vector) -> Result<f64> {
|
||||
|
||||
// Check for zero norms
|
||||
if norm1 == 0.0 || norm2 == 0.0 {
|
||||
return Err(LimboError::ConversionError(
|
||||
"Invalid vector value".to_string(),
|
||||
));
|
||||
return f64::NAN;
|
||||
}
|
||||
|
||||
Ok(1.0 - (dot / (norm1 * norm2).sqrt()))
|
||||
1.0 - (dot / (norm1 * norm2).sqrt())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_distance_cos_f32() {
|
||||
assert!(vector_f32_distance_cos(&[], &[]).is_nan());
|
||||
assert!(vector_f32_distance_cos(&[1.0, 2.0], &[0.0, 0.0]).is_nan());
|
||||
assert_eq!(vector_f32_distance_cos(&[1.0, 2.0], &[1.0, 2.0]), 0.0);
|
||||
assert_eq!(vector_f32_distance_cos(&[1.0, 2.0], &[-1.0, -2.0]), 2.0);
|
||||
assert_eq!(vector_f32_distance_cos(&[1.0, 2.0], &[-2.0, 1.0]), 1.0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ use crate::{
|
||||
};
|
||||
|
||||
pub fn vector_distance_l2(v1: &Vector, v2: &Vector) -> Result<f64> {
|
||||
// Validate that both vectors have the same dimensions and type
|
||||
if v1.dims != v2.dims {
|
||||
return Err(LimboError::ConversionError(
|
||||
"Vectors must have the same dimensions".to_string(),
|
||||
|
||||
Reference in New Issue
Block a user