diff --git a/core/vector/mod.rs b/core/vector/mod.rs index 5a722c34c..dae03a7e1 100644 --- a/core/vector/mod.rs +++ b/core/vector/mod.rs @@ -13,15 +13,8 @@ pub fn vector32(args: &[Register]) -> Result { "vector32 requires exactly one argument".to_string(), )); } - let x = parse_vector(&args[0], Some(VectorType::Float32Dense))?; - // Extract the Vec from Value - if let Value::Blob(data) = vector_serialize_f32(x) { - Ok(Value::Blob(data)) - } else { - Err(LimboError::ConversionError( - "Failed to serialize vector".to_string(), - )) - } + let vector = parse_vector(&args[0], Some(VectorType::Float32Dense))?; + Ok(operations::serialize::vector_serialize(vector)) } pub fn vector64(args: &[Register]) -> Result { @@ -30,15 +23,8 @@ pub fn vector64(args: &[Register]) -> Result { "vector64 requires exactly one argument".to_string(), )); } - let x = parse_vector(&args[0], Some(VectorType::Float64Dense))?; - // Extract the Vec from Value - if let Value::Blob(data) = vector_serialize_f64(x) { - Ok(Value::Blob(data)) - } else { - Err(LimboError::ConversionError( - "Failed to serialize vector".to_string(), - )) - } + let vector = parse_vector(&args[0], Some(VectorType::Float64Dense))?; + Ok(operations::serialize::vector_serialize(vector)) } pub fn vector_extract(args: &[Register]) -> Result { @@ -63,7 +49,7 @@ pub fn vector_extract(args: &[Register]) -> Result { let vector_type = vector_type(blob)?; let vector = vector_deserialize(vector_type, blob)?; - Ok(Value::build_text(vector_to_text(&vector))) + Ok(Value::build_text(operations::text::vector_to_text(&vector))) } pub fn vector_distance_cos(args: &[Register]) -> Result { @@ -102,10 +88,7 @@ pub fn vector_concat(args: &[Register]) -> Result { let x = parse_vector(&args[0], None)?; let y = parse_vector(&args[1], None)?; let vector = operations::concat::vector_concat(&x, &y)?; - match vector.vector_type { - VectorType::Float32Dense => Ok(vector_serialize_f32(vector)), - VectorType::Float64Dense => Ok(vector_serialize_f64(vector)), - } + Ok(operations::serialize::vector_serialize(vector)) } pub fn vector_slice(args: &[Register]) -> Result { @@ -136,8 +119,5 @@ pub fn vector_slice(args: &[Register]) -> Result { let result = operations::slice::vector_slice(&vector, start_index as usize, end_index as usize)?; - Ok(match result.vector_type { - VectorType::Float32Dense => vector_serialize_f32(result), - VectorType::Float64Dense => vector_serialize_f64(result), - }) + Ok(operations::serialize::vector_serialize(result)) } diff --git a/core/vector/operations/deserialize.rs b/core/vector/operations/deserialize.rs new file mode 100644 index 000000000..e69de29bb diff --git a/core/vector/operations/mod.rs b/core/vector/operations/mod.rs index f136eb9bb..55b249aa8 100644 --- a/core/vector/operations/mod.rs +++ b/core/vector/operations/mod.rs @@ -1,4 +1,7 @@ pub mod concat; +pub mod deserialize; pub mod distance_cos; pub mod distance_l2; +pub mod serialize; pub mod slice; +pub mod text; diff --git a/core/vector/operations/serialize.rs b/core/vector/operations/serialize.rs new file mode 100644 index 000000000..fca0bae0b --- /dev/null +++ b/core/vector/operations/serialize.rs @@ -0,0 +1,22 @@ +use crate::{ + vector::vector_types::{Vector, VectorType}, + Value, +}; + +pub fn vector_serialize(x: Vector) -> Value { + match x.vector_type { + VectorType::Float32Dense => vector_f32_serialize(x), + VectorType::Float64Dense => vector_f64_serialize(x), + } +} + +fn vector_f64_serialize(x: Vector) -> Value { + let mut blob = Vec::with_capacity(x.dims * 8 + 1); + blob.extend_from_slice(&x.data); + blob.push(2); + Value::from_blob(blob) +} + +fn vector_f32_serialize(x: Vector) -> Value { + Value::from_blob(x.data) +} diff --git a/core/vector/operations/text.rs b/core/vector/operations/text.rs new file mode 100644 index 000000000..e522e89af --- /dev/null +++ b/core/vector/operations/text.rs @@ -0,0 +1,96 @@ +use crate::{ + vector::vector_types::{Vector, VectorType}, + LimboError, Result, +}; + +pub fn vector_to_text(vector: &Vector) -> String { + let mut text = String::new(); + text.push('['); + match vector.vector_type { + VectorType::Float32Dense => { + let data = vector.as_f32_slice(); + for (i, value) in data.iter().enumerate().take(vector.dims) { + text.push_str(&value.to_string()); + if i < vector.dims - 1 { + text.push(','); + } + } + } + VectorType::Float64Dense => { + let data = vector.as_f64_slice(); + for (i, value) in data.iter().enumerate().take(vector.dims) { + text.push_str(&value.to_string()); + if i < vector.dims - 1 { + text.push(','); + } + } + } + } + text.push(']'); + text +} + +/// Parse a vector in text representation into a Vector. +/// +/// The format of a vector in text representation looks as follows: +/// +/// ```console +/// [1.0, 2.0, 3.0] +/// ``` +pub fn vector_from_text(vector_type: VectorType, text: &str) -> Result { + let text = text.trim(); + let mut chars = text.chars(); + if chars.next() != Some('[') || chars.last() != Some(']') { + return Err(LimboError::ConversionError( + "Invalid vector value".to_string(), + )); + } + let mut data: Vec = Vec::new(); + let text = &text[1..text.len() - 1]; + if text.trim().is_empty() { + return Ok(Vector { + vector_type, + dims: 0, + data, + }); + } + let xs = text.split(','); + for x in xs { + let x = x.trim(); + if x.is_empty() { + return Err(LimboError::ConversionError( + "Invalid vector value".to_string(), + )); + } + match vector_type { + VectorType::Float32Dense => { + let x = x + .parse::() + .map_err(|_| LimboError::ConversionError("Invalid vector value".to_string()))?; + if !x.is_finite() { + return Err(LimboError::ConversionError( + "Invalid vector value".to_string(), + )); + } + data.extend_from_slice(&x.to_le_bytes()); + } + VectorType::Float64Dense => { + let x = x + .parse::() + .map_err(|_| LimboError::ConversionError("Invalid vector value".to_string()))?; + if !x.is_finite() { + return Err(LimboError::ConversionError( + "Invalid vector value".to_string(), + )); + } + data.extend_from_slice(&x.to_le_bytes()); + } + }; + } + let dims = vector_type.size_to_dims(data.len()); + Ok(Vector { + vector_type, + dims, + data, + }) +} diff --git a/core/vector/vector_types.rs b/core/vector/vector_types.rs index d396414e7..c13d7e1ba 100644 --- a/core/vector/vector_types.rs +++ b/core/vector/vector_types.rs @@ -1,5 +1,6 @@ -use crate::types::{Value, ValueType}; +use crate::types::ValueType; use crate::vdbe::Register; +use crate::vector::operations; use crate::{LimboError, Result}; #[derive(Debug, Clone, PartialEq, Copy)] @@ -82,81 +83,11 @@ impl Vector { } } -/// Parse a vector in text representation into a Vector. -/// -/// The format of a vector in text representation looks as follows: -/// -/// ```console -/// [1.0, 2.0, 3.0] -/// ``` -pub fn parse_string_vector(vector_type: VectorType, value: &Value) -> Result { - let Some(text) = value.to_text() else { - return Err(LimboError::ConversionError( - "Invalid vector value".to_string(), - )); - }; - let text = text.trim(); - let mut chars = text.chars(); - if chars.next() != Some('[') || chars.last() != Some(']') { - return Err(LimboError::ConversionError( - "Invalid vector value".to_string(), - )); - } - let mut data: Vec = Vec::new(); - let text = &text[1..text.len() - 1]; - if text.trim().is_empty() { - return Ok(Vector { - vector_type, - dims: 0, - data, - }); - } - let xs = text.split(','); - for x in xs { - let x = x.trim(); - if x.is_empty() { - return Err(LimboError::ConversionError( - "Invalid vector value".to_string(), - )); - } - match vector_type { - VectorType::Float32Dense => { - let x = x - .parse::() - .map_err(|_| LimboError::ConversionError("Invalid vector value".to_string()))?; - if !x.is_finite() { - return Err(LimboError::ConversionError( - "Invalid vector value".to_string(), - )); - } - data.extend_from_slice(&x.to_le_bytes()); - } - VectorType::Float64Dense => { - let x = x - .parse::() - .map_err(|_| LimboError::ConversionError("Invalid vector value".to_string()))?; - if !x.is_finite() { - return Err(LimboError::ConversionError( - "Invalid vector value".to_string(), - )); - } - data.extend_from_slice(&x.to_le_bytes()); - } - }; - } - let dims = vector_type.size_to_dims(data.len()); - Ok(Vector { - vector_type, - dims, - data, - }) -} - pub fn parse_vector(value: &Register, vec_ty: Option) -> Result { match value.get_value().value_type() { - ValueType::Text => parse_string_vector( + ValueType::Text => operations::text::vector_from_text( vec_ty.unwrap_or(VectorType::Float32Dense), - value.get_value(), + value.get_value().to_text().expect("value must be text"), ), ValueType::Blob => { let Some(blob) = value.get_value().to_blob() else { @@ -180,33 +111,6 @@ pub fn parse_vector(value: &Register, vec_ty: Option) -> Result String { - let mut text = String::new(); - text.push('['); - match vector.vector_type { - VectorType::Float32Dense => { - let data = vector.as_f32_slice(); - for (i, value) in data.iter().enumerate().take(vector.dims) { - text.push_str(&value.to_string()); - if i < vector.dims - 1 { - text.push(','); - } - } - } - VectorType::Float64Dense => { - let data = vector.as_f64_slice(); - for (i, value) in data.iter().enumerate().take(vector.dims) { - text.push_str(&value.to_string()); - if i < vector.dims - 1 { - text.push(','); - } - } - } - } - text.push(']'); - text -} - pub fn vector_deserialize(vector_type: VectorType, blob: &[u8]) -> Result { match vector_type { VectorType::Float32Dense => vector_deserialize_f32(blob), @@ -214,13 +118,6 @@ pub fn vector_deserialize(vector_type: VectorType, blob: &[u8]) -> Result Value { - let mut blob = Vec::with_capacity(x.dims * 8 + 1); - blob.extend_from_slice(&x.data); - blob.push(2); - Value::from_blob(blob) -} - pub fn vector_deserialize_f64(blob: &[u8]) -> Result { Ok(Vector { vector_type: VectorType::Float64Dense, @@ -229,10 +126,6 @@ pub fn vector_deserialize_f64(blob: &[u8]) -> Result { }) } -pub fn vector_serialize_f32(x: Vector) -> Value { - Value::from_blob(x.data) -} - pub fn vector_deserialize_f32(blob: &[u8]) -> Result { Ok(Vector { vector_type: VectorType::Float32Dense, @@ -385,11 +278,7 @@ mod tests { /// Test if the vector type identification is correct for a given vector. fn test_vector_type(v: Vector) -> bool { let vtype = v.vector_type; - let value = match &vtype { - VectorType::Float32Dense => vector_serialize_f32(v), - VectorType::Float64Dense => vector_serialize_f64(v), - }; - + let value = operations::serialize::vector_serialize(v); let blob = value.to_blob().unwrap(); match vector_type(blob) { Ok(detected_type) => detected_type == vtype, @@ -517,24 +406,27 @@ mod tests { #[test] fn parse_string_vector_zero_length() { - let value = Value::from_text("[]"); - let vector = parse_string_vector(VectorType::Float32Dense, &value).unwrap(); + let vector = operations::text::vector_from_text(VectorType::Float32Dense, "[]").unwrap(); assert_eq!(vector.dims, 0); assert_eq!(vector.vector_type, VectorType::Float32Dense); } #[test] fn test_parse_string_vector_valid_whitespace() { - let value = Value::from_text(" [ 1.0 , 2.0 , 3.0 ] "); - let vector = parse_string_vector(VectorType::Float32Dense, &value).unwrap(); + let vector = operations::text::vector_from_text( + VectorType::Float32Dense, + " [ 1.0 , 2.0 , 3.0 ] ", + ) + .unwrap(); assert_eq!(vector.dims, 3); assert_eq!(vector.vector_type, VectorType::Float32Dense); } #[test] fn test_parse_string_vector_valid() { - let value = Value::from_text("[1.0, 2.0, 3.0]"); - let vector = parse_string_vector(VectorType::Float32Dense, &value).unwrap(); + let vector = + operations::text::vector_from_text(VectorType::Float32Dense, "[1.0, 2.0, 3.0]") + .unwrap(); assert_eq!(vector.dims, 3); assert_eq!(vector.vector_type, VectorType::Float32Dense); } @@ -567,11 +459,10 @@ mod tests { /// Test that a vector can be converted to text and back without loss of precision fn test_vector_text_roundtrip(v: Vector) -> bool { // Convert to text - let text = vector_to_text(&v); + let text = operations::text::vector_to_text(&v); // Parse back from text - let value = Value::from_text(&text); - let parsed = parse_string_vector(v.vector_type, &value); + let parsed = operations::text::vector_from_text(v.vector_type, &text); match parsed { Ok(parsed_vector) => {