diff --git a/Cargo.lock b/Cargo.lock index c0c2abab1..9edc113b1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -672,6 +672,16 @@ dependencies = [ "log", ] +[[package]] +name = "env_logger" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a19187fea3ac7e84da7dacf48de0c45d63c6a76f9490dae389aead16c243fce3" +dependencies = [ + "log", + "regex", +] + [[package]] name = "env_logger" version = "0.10.2" @@ -1447,6 +1457,9 @@ name = "limbo_vector" version = "0.0.13" dependencies = [ "limbo_ext", + "quickcheck", + "quickcheck_macros", + "rand", ] [[package]] @@ -2053,6 +2066,28 @@ dependencies = [ "memchr", ] +[[package]] +name = "quickcheck" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "588f6378e4dd99458b60ec275b4477add41ce4fa9f64dcba6f15adccb19b50d6" +dependencies = [ + "env_logger 0.8.4", + "log", + "rand", +] + +[[package]] +name = "quickcheck_macros" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b22a693222d716a9587786f37ac3f6b4faedb5b80c23914e7303ff5a1d8016e9" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "quote" version = "1.0.38" diff --git a/extensions/vector/Cargo.toml b/extensions/vector/Cargo.toml index e80863fd2..e6886ba1a 100644 --- a/extensions/vector/Cargo.toml +++ b/extensions/vector/Cargo.toml @@ -11,6 +11,12 @@ crate-type = ["cdylib", "lib"] [features] static= [ "limbo_ext/static" ] +default = ["quickcheck/default"] [dependencies] limbo_ext = { path = "../core", features = ["static"] } + +[dev-dependencies] +quickcheck = { version = "1.0", default-features = false } +quickcheck_macros = { version = "1.0", default-features = false } +rand = "0.8" # Required for quickcheck diff --git a/extensions/vector/src/vector.rs b/extensions/vector/src/vector.rs index ba6f8128f..7c1e7c77d 100644 --- a/extensions/vector/src/vector.rs +++ b/extensions/vector/src/vector.rs @@ -2,7 +2,7 @@ use limbo_ext::{Value, ValueType}; use crate::{Error, Result}; -#[derive(Debug, PartialEq)] +#[derive(Debug, Clone, PartialEq)] pub enum VectorType { Float32, Float64, @@ -68,10 +68,16 @@ pub fn parse_string_vector(vector_type: VectorType, value: &Value) -> Result { let x = x.parse::().map_err(|_| Error::InvalidFormat)?; + if !x.is_finite() { + return Err(Error::InvalidFormat); + } data.extend_from_slice(&x.to_le_bytes()); } VectorType::Float64 => { let x = x.parse::().map_err(|_| Error::InvalidFormat)?; + if !x.is_finite() { + return Err(Error::InvalidFormat); + } data.extend_from_slice(&x.to_le_bytes()); } }; @@ -181,6 +187,12 @@ pub fn vector_f32_distance_cos(v1: &Vector, v2: &Vector) -> Result { let (mut dot, mut norm1, mut norm2) = (0.0, 0.0, 0.0); let v1_data = v1.as_f32_slice(); let v2_data = v2.as_f32_slice(); + + // Check for non-finite values + if v1_data.iter().any(|x| !x.is_finite()) || v2_data.iter().any(|x| !x.is_finite()) { + return Err(Error::InvalidFormat); + } + for i in 0..v1.dims { let e1 = v1_data[i]; let e2 = v2_data[i]; @@ -188,6 +200,12 @@ pub fn vector_f32_distance_cos(v1: &Vector, v2: &Vector) -> Result { norm1 += e1 * e1; norm2 += e2 * e2; } + + // Check for zero norms to avoid division by zero + if norm1 == 0.0 || norm2 == 0.0 { + return Err(Error::InvalidFormat); + } + Ok(1.0 - (dot / (norm1 * norm2).sqrt()) as f64) } @@ -201,6 +219,12 @@ pub fn vector_f64_distance_cos(v1: &Vector, v2: &Vector) -> Result { let (mut dot, mut norm1, mut norm2) = (0.0, 0.0, 0.0); let v1_data = v1.as_f64_slice(); let v2_data = v2.as_f64_slice(); + + // Check for non-finite values + if v1_data.iter().any(|x| !x.is_finite()) || v2_data.iter().any(|x| !x.is_finite()) { + return Err(Error::InvalidFormat); + } + for i in 0..v1.dims { let e1 = v1_data[i]; let e2 = v2_data[i]; @@ -208,6 +232,12 @@ pub fn vector_f64_distance_cos(v1: &Vector, v2: &Vector) -> Result { norm1 += e1 * e1; norm2 += e2 * e2; } + + // Check for zero norms + if norm1 == 0.0 || norm2 == 0.0 { + return Err(Error::InvalidFormat); + } + Ok(1.0 - (dot / (norm1 * norm2).sqrt())) } @@ -242,6 +272,250 @@ pub fn vector_type(blob: &[u8]) -> Result { #[cfg(test)] mod tests { use super::*; + use quickcheck::{Arbitrary, Gen}; + use quickcheck_macros::quickcheck; + + // Helper to generate arbitrary vectors of specific type and dimensions + #[derive(Debug, Clone)] + struct ArbitraryVector { + vector_type: VectorType, + data: Vec, + } + + /// How to create an arbitrary vector of DIMS dims. + impl ArbitraryVector { + fn generate_f32_vector(g: &mut Gen) -> Vec { + (0..DIMS) + .map(|_| { + loop { + let f = f32::arbitrary(g); + // f32::arbitrary() can generate "problem values" like NaN, infinity, and very small values + // Skip these values + if f.is_finite() && f.abs() >= 1e-6 { + // Scale to [-1, 1] range + return f % 2.0 - 1.0; + } + } + }) + .collect() + } + + fn generate_f64_vector(g: &mut Gen) -> Vec { + (0..DIMS) + .map(|_| { + loop { + let f = f64::arbitrary(g); + // f64::arbitrary() can generate "problem values" like NaN, infinity, and very small values + // Skip these values + if f.is_finite() && f.abs() >= 1e-6 { + // Scale to [-1, 1] range + return f % 2.0 - 1.0; + } + } + }) + .collect() + } + } + + /// Convert an ArbitraryVector to a Vector. + impl From> for Vector { + fn from(v: ArbitraryVector) -> Self { + Vector { + vector_type: v.vector_type, + dims: DIMS, + data: v.data, + } + } + } + + /// Implement the quickcheck Arbitrary trait for ArbitraryVector. + impl Arbitrary for ArbitraryVector { + fn arbitrary(g: &mut Gen) -> Self { + let vector_type = if bool::arbitrary(g) { + VectorType::Float32 + } else { + VectorType::Float64 + }; + + let data = match vector_type { + VectorType::Float32 => { + let floats = Self::generate_f32_vector(g); + floats.iter().flat_map(|f| f.to_le_bytes()).collect() + } + VectorType::Float64 => { + let floats = Self::generate_f64_vector(g); + floats.iter().flat_map(|f| f.to_le_bytes()).collect() + } + }; + + ArbitraryVector { vector_type, data } + } + } + + #[quickcheck] + fn prop_vector_type_identification_2d(v: ArbitraryVector<2>) -> bool { + test_vector_type::<2>(v.into()) + } + + #[quickcheck] + fn prop_vector_type_identification_3d(v: ArbitraryVector<3>) -> bool { + test_vector_type::<3>(v.into()) + } + + #[quickcheck] + fn prop_vector_type_identification_4d(v: ArbitraryVector<4>) -> bool { + test_vector_type::<4>(v.into()) + } + + #[quickcheck] + fn prop_vector_type_identification_100d(v: ArbitraryVector<100>) -> bool { + test_vector_type::<100>(v.into()) + } + + #[quickcheck] + fn prop_vector_type_identification_1536d(v: ArbitraryVector<1536>) -> bool { + test_vector_type::<1536>(v.into()) + } + + /// Test if the vector type identification is correct for a given vector. + fn test_vector_type(v: Vector) -> bool { + let vtype = v.vector_type.clone(); + let value = match &vtype { + VectorType::Float32 => vector_serialize_f32(v), + VectorType::Float64 => vector_serialize_f64(v), + }; + + let blob = value.to_blob().unwrap(); + match vector_type(&blob) { + Ok(detected_type) => detected_type == vtype, + Err(_) => false, + } + } + + #[quickcheck] + fn prop_slice_conversion_safety_2d(v: ArbitraryVector<2>) -> bool { + test_slice_conversion::<2>(v.into()) + } + + #[quickcheck] + fn prop_slice_conversion_safety_3d(v: ArbitraryVector<3>) -> bool { + test_slice_conversion::<3>(v.into()) + } + + #[quickcheck] + fn prop_slice_conversion_safety_4d(v: ArbitraryVector<4>) -> bool { + test_slice_conversion::<4>(v.into()) + } + + #[quickcheck] + fn prop_slice_conversion_safety_100d(v: ArbitraryVector<100>) -> bool { + test_slice_conversion::<100>(v.into()) + } + + #[quickcheck] + fn prop_slice_conversion_safety_1536d(v: ArbitraryVector<1536>) -> bool { + test_slice_conversion::<1536>(v.into()) + } + + /// Test if the slice conversion is safe for a given vector: + /// - The slice length matches the dimensions + /// - The data length is correct (4 bytes per float for f32, 8 bytes per float for f64) + fn test_slice_conversion(v: Vector) -> bool { + match v.vector_type { + VectorType::Float32 => { + let slice = v.as_f32_slice(); + // Check if the slice length matches the dimensions and the data length is correct (4 bytes per float) + slice.len() == DIMS && (slice.len() * 4 == v.data.len()) + } + VectorType::Float64 => { + let slice = v.as_f64_slice(); + // Check if the slice length matches the dimensions and the data length is correct (8 bytes per float) + slice.len() == DIMS && (slice.len() * 8 == v.data.len()) + } + } + } + + // Test size_to_dims calculation with different dimensions + #[quickcheck] + fn prop_size_to_dims_calculation_2d(v: ArbitraryVector<2>) -> bool { + test_size_to_dims::<2>(v.into()) + } + + #[quickcheck] + fn prop_size_to_dims_calculation_3d(v: ArbitraryVector<3>) -> bool { + test_size_to_dims::<3>(v.into()) + } + + #[quickcheck] + fn prop_size_to_dims_calculation_4d(v: ArbitraryVector<4>) -> bool { + test_size_to_dims::<4>(v.into()) + } + + #[quickcheck] + fn prop_size_to_dims_calculation_100d(v: ArbitraryVector<100>) -> bool { + test_size_to_dims::<100>(v.into()) + } + + #[quickcheck] + fn prop_size_to_dims_calculation_1536d(v: ArbitraryVector<1536>) -> bool { + test_size_to_dims::<1536>(v.into()) + } + + /// Test if the size_to_dims calculation is correct for a given vector. + fn test_size_to_dims(v: Vector) -> bool { + let size = v.data.len(); + let calculated_dims = v.vector_type.size_to_dims(size); + calculated_dims == DIMS + } + + #[quickcheck] + fn prop_vector_distance_safety_2d(v1: ArbitraryVector<2>, v2: ArbitraryVector<2>) -> bool { + test_vector_distance::<2>(&v1.into(), &v2.into()) + } + + #[quickcheck] + fn prop_vector_distance_safety_3d(v1: ArbitraryVector<3>, v2: ArbitraryVector<3>) -> bool { + test_vector_distance::<3>(&v1.into(), &v2.into()) + } + + #[quickcheck] + fn prop_vector_distance_safety_4d(v1: ArbitraryVector<4>, v2: ArbitraryVector<4>) -> bool { + test_vector_distance::<4>(&v1.into(), &v2.into()) + } + + #[quickcheck] + fn prop_vector_distance_safety_100d( + v1: ArbitraryVector<100>, + v2: ArbitraryVector<100>, + ) -> bool { + test_vector_distance::<100>(&v1.into(), &v2.into()) + } + + #[quickcheck] + fn prop_vector_distance_safety_1536d( + v1: ArbitraryVector<1536>, + v2: ArbitraryVector<1536>, + ) -> bool { + test_vector_distance::<1536>(&v1.into(), &v2.into()) + } + + /// Test if the vector distance calculation is correct for a given pair of vectors: + /// - The vectors have the same dimensions + /// - The vectors have the same type + /// - The distance must be between 0 and 2 + fn test_vector_distance(v1: &Vector, v2: &Vector) -> bool { + if v1.vector_type != v2.vector_type { + // Skip test if types are different + return true; + } + match do_vector_distance_cos(&v1, &v2) { + Ok(distance) => { + // Cosine distance is always between 0 and 2 + distance >= 0.0 && distance <= 2.0 + } + Err(_) => false, + } + } #[test] fn parse_string_vector_zero_length() {