diff --git a/core/vector/operations/concat.rs b/core/vector/operations/concat.rs index 3504ba5bd..3e7f6a4f1 100644 --- a/core/vector/operations/concat.rs +++ b/core/vector/operations/concat.rs @@ -25,7 +25,6 @@ pub fn vector_concat(v1: &Vector, v2: &Vector) -> Result { data.extend_from_slice(&v2.data[v2.data.len() / 2..]); data } - _ => todo!(), }; Ok(Vector { diff --git a/core/vector/operations/distance_cos.rs b/core/vector/operations/distance_cos.rs index e211ee0e4..19fb17688 100644 --- a/core/vector/operations/distance_cos.rs +++ b/core/vector/operations/distance_cos.rs @@ -27,7 +27,6 @@ pub fn vector_distance_cos(v1: &Vector, v2: &Vector) -> Result { v1.as_f32_sparse(), v2.as_f32_sparse(), )), - _ => todo!(), } } @@ -98,7 +97,7 @@ fn vector_f32_sparse_distance_cos(v1: VectorSparse, v2: VectorSparse) v1_pos += 1; } while v2_pos < v2.idx.len() { - norm1 += v2.values[v2_pos] * v2.values[v2_pos]; + norm2 += v2.values[v2_pos] * v2.values[v2_pos]; v2_pos += 1; } @@ -131,4 +130,22 @@ mod tests { assert_eq!(vector_f64_distance_cos(&[1.0, 2.0], &[-1.0, -2.0]), 2.0); assert_eq!(vector_f64_distance_cos(&[1.0, 2.0], &[-2.0, 1.0]), 1.0); } + + #[test] + fn test_distance_cos_f32_sparse() { + assert!( + (vector_f32_sparse_distance_cos( + VectorSparse { + idx: &[0, 1], + values: &[1.0, 2.0] + }, + VectorSparse { + idx: &[1, 2], + values: &[1.0, 3.0] + }, + ) - vector_f32_distance_cos(&[1.0, 2.0, 0.0], &[0.0, 1.0, 3.0])) + .abs() + < 1e-7 + ); + } } diff --git a/core/vector/operations/distance_l2.rs b/core/vector/operations/distance_l2.rs index 4095632d3..2599fd4d3 100644 --- a/core/vector/operations/distance_l2.rs +++ b/core/vector/operations/distance_l2.rs @@ -25,7 +25,6 @@ pub fn vector_distance_l2(v1: &Vector, v2: &Vector) -> Result { v1.as_f32_sparse(), v2.as_f32_sparse(), )), - _ => todo!(), } } @@ -57,12 +56,22 @@ fn vector_f32_sparse_distance_l2(v1: VectorSparse, v2: VectorSparse) - v1_pos += 1; v2_pos += 1; } else if v1.idx[v1_pos] < v2.idx[v2_pos] { + sum += v1.values[v1_pos].powi(2); v1_pos += 1; } else { + sum += v2.values[v2_pos].powi(2); v2_pos += 1; } } - sum as f64 + while v1_pos < v1.idx.len() { + sum += v1.values[v1_pos].powi(2); + v1_pos += 1; + } + while v2_pos < v2.idx.len() { + sum += v2.values[v2_pos].powi(2); + v2_pos += 1; + } + (sum as f64).sqrt() } #[cfg(test)] @@ -98,4 +107,58 @@ mod tests { let query = (2..7).map(|x| x as f32).collect::>(); assert_eq!(vector_f32_distance_l2(&v, &query), 20.0_f64.sqrt()); } + + #[test] + fn test_distance_l2_f32() { + assert_eq!(vector_f32_distance_l2(&[], &[]), 0.0); + assert_eq!( + vector_f32_distance_l2(&[1.0, 2.0], &[0.0, 0.0]), + (1f64 + 2f64 * 2f64).sqrt() + ); + assert_eq!(vector_f32_distance_l2(&[1.0, 2.0], &[1.0, 2.0]), 0.0); + assert_eq!( + vector_f32_distance_l2(&[1.0, 2.0], &[-1.0, -2.0]), + (2f64 * 2f64 + 4f64 * 4f64).sqrt() + ); + assert_eq!( + vector_f32_distance_l2(&[1.0, 2.0], &[-2.0, 1.0]), + (3f64 * 3f64 + 1f64 * 1f64).sqrt() + ); + } + + #[test] + fn test_distance_l2_f64() { + assert_eq!(vector_f64_distance_l2(&[], &[]), 0.0); + assert_eq!( + vector_f64_distance_l2(&[1.0, 2.0], &[0.0, 0.0]), + (1f64 + 2f64 * 2f64).sqrt() + ); + assert_eq!(vector_f64_distance_l2(&[1.0, 2.0], &[1.0, 2.0]), 0.0); + assert_eq!( + vector_f64_distance_l2(&[1.0, 2.0], &[-1.0, -2.0]), + (2f64 * 2f64 + 4f64 * 4f64).sqrt() + ); + assert_eq!( + vector_f64_distance_l2(&[1.0, 2.0], &[-2.0, 1.0]), + (3f64 * 3f64 + 1f64 * 1f64).sqrt() + ); + } + + #[test] + fn test_distance_l2_f32_sparse() { + assert!( + (vector_f32_sparse_distance_l2( + VectorSparse { + idx: &[0, 1], + values: &[1.0, 2.0] + }, + VectorSparse { + idx: &[1, 2], + values: &[1.0, 3.0] + }, + ) - vector_f32_distance_l2(&[1.0, 2.0, 0.0], &[0.0, 1.0, 3.0])) + .abs() + < 1e-7 + ); + } }