diff --git a/core/vector/vector_types.rs b/core/vector/vector_types.rs
index 2ec79ed1d..779f1db86 100644
--- a/core/vector/vector_types.rs
+++ b/core/vector/vector_types.rs
@@ -25,11 +25,71 @@ pub struct Vector {
 }
 
 impl Vector {
+    /// Reinterprets the underlying `Vec<u8>` buffer as a `&[f32]` slice of
+    /// `dims` elements; a zero-dimension vector yields an empty slice.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `dims` is non-zero and either:
+    /// - the buffer length is not exactly `dims * size_of::<f32>()`, or
+    /// - the buffer is not correctly aligned for `f32`.
     pub fn as_f32_slice(&self) -> &[f32] {
-        unsafe { std::slice::from_raw_parts(self.data.as_ptr() as *const f32, self.dims) }
+        if self.dims == 0 {
+            return &[];
+        }
+
+        assert_eq!(
+            Some(self.data.len()),
+            self.dims.checked_mul(std::mem::size_of::<f32>()),
+            "data length must equal dims * size_of::<f32>()"
+        );
+
+        let ptr = self.data.as_ptr();
+        let align = std::mem::align_of::<f32>();
+        assert_eq!(
+            ptr.align_offset(align),
+            0,
+            "data pointer must be aligned to {align} bytes for f32 access"
+        );
+
+        // SAFETY: the assertions above guarantee that `ptr` is aligned for
+        // `f32` and that the buffer holds exactly `dims * size_of::<f32>()`
+        // initialized bytes. Every bit pattern is a valid `f32`, and the
+        // returned slice borrows `self`, so it cannot outlive the buffer.
+        unsafe { std::slice::from_raw_parts(ptr as *const f32, self.dims) }
     }
 
+    /// Reinterprets the underlying `Vec<u8>` buffer as a `&[f64]` slice of
+    /// `dims` elements; a zero-dimension vector yields an empty slice.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `dims` is non-zero and either:
+    /// - the buffer length is not exactly `dims * size_of::<f64>()`, or
+    /// - the buffer is not correctly aligned for `f64`.
     pub fn as_f64_slice(&self) -> &[f64] {
-        unsafe { std::slice::from_raw_parts(self.data.as_ptr() as *const f64, self.dims) }
+        if self.dims == 0 {
+            return &[];
+        }
+
+        assert_eq!(
+            Some(self.data.len()),
+            self.dims.checked_mul(std::mem::size_of::<f64>()),
+            "data length must equal dims * size_of::<f64>()"
+        );
+
+        let ptr = self.data.as_ptr();
+        let align = std::mem::align_of::<f64>();
+        assert_eq!(
+            ptr.align_offset(align),
+            0,
+            "data pointer must be aligned to {align} bytes for f64 access"
+        );
+
+        // SAFETY: the assertions above guarantee that `ptr` is aligned for
+        // `f64` and that the buffer holds exactly `dims * size_of::<f64>()`
+        // initialized bytes. Every bit pattern is a valid `f64`, and the
+        // returned slice borrows `self`, so it cannot outlive the buffer.
+        unsafe { std::slice::from_raw_parts(ptr as *const f64, self.dims) }
     }
 }
@@ -281,11 +341,6 @@ pub fn vector_f64_distance_cos(v1: &Vector, v2: &Vector) -> Result<f64> {
 }
 
 pub fn vector_type(blob: &[u8]) -> Result<VectorType> {
-    if blob.is_empty() {
-        return Err(LimboError::ConversionError(
-            "Invalid vector value".to_string(),
-        ));
-    }
     // Even-sized blobs are always float32.
     if blob.len() % 2 == 0 {
         return Ok(VectorType::Float32);
@@ -706,6 +761,7 @@ mod tests {
         let v2 = float32_vec_from(&[]);
         let result = vector_concat(&v1, &v2).unwrap();
         assert_eq!(result.dims, 0);
+        assert_eq!(f32_slice_from_vector(&result), Vec::<f32>::new());
     }
 
     #[test]