mirror of
https://github.com/aljazceru/turso.git
synced 2025-12-19 09:34:18 +01:00
118 lines
3.3 KiB
Rust
118 lines
3.3 KiB
Rust
use crate::{
|
|
vector::vector_types::{Vector, VectorType},
|
|
LimboError, Result,
|
|
};
|
|
|
|
pub fn vector_concat(v1: &Vector, v2: &Vector) -> Result<Vector> {
|
|
if v1.vector_type != v2.vector_type {
|
|
return Err(LimboError::ConversionError(
|
|
"Mismatched vector types".into(),
|
|
));
|
|
}
|
|
|
|
let data = match v1.vector_type {
|
|
VectorType::Float32Dense | VectorType::Float64Dense => {
|
|
let mut data = Vec::with_capacity(v1.data.len() + v2.data.len());
|
|
data.extend_from_slice(&v1.data);
|
|
data.extend_from_slice(&v2.data);
|
|
data
|
|
}
|
|
VectorType::Float32Sparse => {
|
|
let mut data = Vec::with_capacity(v1.data.len() + v2.data.len());
|
|
data.extend_from_slice(&v1.data[..v1.data.len() / 2]);
|
|
data.extend_from_slice(&v2.data[..v2.data.len() / 2]);
|
|
data.extend_from_slice(&v1.data[v1.data.len() / 2..]);
|
|
data.extend_from_slice(&v2.data[v2.data.len() / 2..]);
|
|
data
|
|
}
|
|
};
|
|
|
|
Ok(Vector {
|
|
vector_type: v1.vector_type,
|
|
dims: v1.dims + v2.dims,
|
|
data,
|
|
})
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use crate::vector::{
        operations::concat::vector_concat,
        vector_types::{Vector, VectorType},
    };

    /// Builds a `Float32Dense` vector whose payload is `slice` encoded as little-endian f32 bytes.
    fn float32_vec_from(slice: &[f32]) -> Vector {
        let mut data = Vec::new();
        for &v in slice {
            data.extend_from_slice(&v.to_le_bytes());
        }

        Vector {
            vector_type: VectorType::Float32Dense,
            dims: slice.len(),
            data,
        }
    }

    /// Builds a `Float64Dense` vector whose payload is `slice` encoded as little-endian f64 bytes.
    fn float64_vec_from(slice: &[f64]) -> Vector {
        let mut data = Vec::new();
        for &v in slice {
            data.extend_from_slice(&v.to_le_bytes());
        }

        Vector {
            vector_type: VectorType::Float64Dense,
            dims: slice.len(),
            data,
        }
    }

    /// Decodes a vector's payload into owned f32 components via `Vector::as_f32_slice`.
    fn f32_slice_from_vector(vector: &Vector) -> Vec<f32> {
        vector.as_f32_slice().to_vec()
    }

    #[test]
    fn test_vector_concat_normal_case() {
        let v1 = float32_vec_from(&[1.0, 2.0, 3.0]);
        let v2 = float32_vec_from(&[4.0, 5.0, 6.0]);

        let result = vector_concat(&v1, &v2).unwrap();

        assert_eq!(result.dims, 6);
        assert_eq!(result.vector_type, VectorType::Float32Dense);
        assert_eq!(
            f32_slice_from_vector(&result),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
        );
    }

    #[test]
    fn test_vector_concat_empty_left() {
        let v1 = float32_vec_from(&[]);
        let v2 = float32_vec_from(&[4.0, 5.0]);

        let result = vector_concat(&v1, &v2).unwrap();

        assert_eq!(result.dims, 2);
        assert_eq!(f32_slice_from_vector(&result), vec![4.0, 5.0]);
    }

    #[test]
    fn test_vector_concat_empty_right() {
        let v1 = float32_vec_from(&[1.0, 2.0]);
        let v2 = float32_vec_from(&[]);

        let result = vector_concat(&v1, &v2).unwrap();

        assert_eq!(result.dims, 2);
        assert_eq!(f32_slice_from_vector(&result), vec![1.0, 2.0]);
    }

    #[test]
    fn test_vector_concat_both_empty() {
        let v1 = float32_vec_from(&[]);
        let v2 = float32_vec_from(&[]);

        let result = vector_concat(&v1, &v2).unwrap();

        assert_eq!(result.dims, 0);
        assert_eq!(f32_slice_from_vector(&result), Vec::<f32>::new());
    }

    #[test]
    fn test_vector_concat_different_lengths() {
        let v1 = float32_vec_from(&[1.0]);
        let v2 = float32_vec_from(&[2.0, 3.0, 4.0]);

        let result = vector_concat(&v1, &v2).unwrap();

        assert_eq!(result.dims, 4);
        assert_eq!(f32_slice_from_vector(&result), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_vector_concat_float64_dense() {
        let v1 = float64_vec_from(&[1.0, 2.0]);
        let v2 = float64_vec_from(&[3.0]);

        let result = vector_concat(&v1, &v2).unwrap();

        assert_eq!(result.dims, 3);
        assert_eq!(result.vector_type, VectorType::Float64Dense);
        // Dense concatenation must be a plain byte append of the two payloads.
        assert_eq!(result.data, float64_vec_from(&[1.0, 2.0, 3.0]).data);
    }

    #[test]
    fn test_vector_concat_mismatched_types() {
        let v1 = float32_vec_from(&[1.0]);
        let v2 = float64_vec_from(&[1.0]);

        // Concatenating vectors of different element types must be rejected.
        assert!(vector_concat(&v1, &v2).is_err());
    }
}
|