use super::{DistanceCalculator, DistanceType}; use crate::vector::vector_types::{Vector, VectorType}; use crate::Result; #[derive(Debug, Clone)] pub struct Euclidean; impl DistanceCalculator for Euclidean { fn distance_type() -> DistanceType { DistanceType::Euclidean } fn calculate(v1: &Vector, v2: &Vector) -> Result { match v1.vector_type { VectorType::Float32 => Ok(euclidean_distance_f32(v1.as_f32_slice(), v2.as_f32_slice())), VectorType::Float64 => Ok(euclidean_distance_f64(v1.as_f64_slice(), v2.as_f64_slice())), } } } fn euclidean_distance_f32(v1: &[f32], v2: &[f32]) -> f64 { let sum = v1 .iter() .zip(v2.iter()) .map(|(a, b)| (a - b).powi(2)) .sum::() as f64; sum.sqrt() } fn euclidean_distance_f64(v1: &[f64], v2: &[f64]) -> f64 { let sum = v1 .iter() .zip(v2.iter()) .map(|(a, b)| (a - b).powi(2)) .sum::(); sum.sqrt() } #[cfg(test)] mod tests { use super::*; #[test] fn test_euclidean_distance_f32() { let vectors = [ (0..8).map(|x| x as f32).collect::>(), (1..9).map(|x| x as f32).collect::>(), (2..10).map(|x| x as f32).collect::>(), (3..11).map(|x| x as f32).collect::>(), ]; let query = (2..10).map(|x| x as f32).collect::>(); let expected: Vec = vec![ 32.0_f64.sqrt(), 8.0_f64.sqrt(), 0.0_f64.sqrt(), 8.0_f64.sqrt(), ]; let results = vectors .iter() .map(|v| euclidean_distance_f32(&query, v)) .collect::>(); assert_eq!(results, expected); } #[test] fn test_odd_len() { let v = (0..5).map(|x| x as f32).collect::>(); let query = (2..7).map(|x| x as f32).collect::>(); assert_eq!(euclidean_distance_f32(&v, &query), 20.0_f64.sqrt()); } }