implement operations for sparse vectors

This commit is contained in:
Nikita Sivukhin
2025-10-09 20:52:58 +04:00
parent 84643dc4f2
commit 585d11b736
3 changed files with 84 additions and 5 deletions

View File

@@ -25,7 +25,6 @@ pub fn vector_concat(v1: &Vector, v2: &Vector) -> Result<Vector> {
data.extend_from_slice(&v2.data[v2.data.len() / 2..]);
data
}
_ => todo!(),
};
Ok(Vector {

View File

@@ -27,7 +27,6 @@ pub fn vector_distance_cos(v1: &Vector, v2: &Vector) -> Result<f64> {
v1.as_f32_sparse(),
v2.as_f32_sparse(),
)),
_ => todo!(),
}
}
@@ -98,7 +97,7 @@ fn vector_f32_sparse_distance_cos(v1: VectorSparse<f32>, v2: VectorSparse<f32>)
v1_pos += 1;
}
while v2_pos < v2.idx.len() {
norm1 += v2.values[v2_pos] * v2.values[v2_pos];
norm2 += v2.values[v2_pos] * v2.values[v2_pos];
v2_pos += 1;
}
@@ -131,4 +130,22 @@ mod tests {
assert_eq!(vector_f64_distance_cos(&[1.0, 2.0], &[-1.0, -2.0]), 2.0);
assert_eq!(vector_f64_distance_cos(&[1.0, 2.0], &[-2.0, 1.0]), 1.0);
}
/// The sparse cosine distance must agree (within 1e-7) with the dense `f32`
/// implementation evaluated on the same vectors written out in dense form.
#[test]
fn test_distance_cos_f32_sparse() {
    // Sparse operands: {0: 1.0, 1: 2.0} and {1: 1.0, 2: 3.0}.
    let sparse = vector_f32_sparse_distance_cos(
        VectorSparse {
            idx: &[0, 1],
            values: &[1.0, 2.0],
        },
        VectorSparse {
            idx: &[1, 2],
            values: &[1.0, 3.0],
        },
    );
    // The same two vectors with the absent coordinates spelled out as zeros.
    let dense = vector_f32_distance_cos(&[1.0, 2.0, 0.0], &[0.0, 1.0, 3.0]);

    let gap = (sparse - dense).abs();
    assert!(gap < 1e-7);
}
}

View File

@@ -25,7 +25,6 @@ pub fn vector_distance_l2(v1: &Vector, v2: &Vector) -> Result<f64> {
v1.as_f32_sparse(),
v2.as_f32_sparse(),
)),
_ => todo!(),
}
}
@@ -57,12 +56,22 @@ fn vector_f32_sparse_distance_l2(v1: VectorSparse<f32>, v2: VectorSparse<f32>) -
v1_pos += 1;
v2_pos += 1;
} else if v1.idx[v1_pos] < v2.idx[v2_pos] {
sum += v1.values[v1_pos].powi(2);
v1_pos += 1;
} else {
sum += v2.values[v2_pos].powi(2);
v2_pos += 1;
}
}
sum as f64
while v1_pos < v1.idx.len() {
sum += v1.values[v1_pos].powi(2);
v1_pos += 1;
}
while v2_pos < v2.idx.len() {
sum += v2.values[v2_pos].powi(2);
v2_pos += 1;
}
(sum as f64).sqrt()
}
#[cfg(test)]
@@ -98,4 +107,58 @@ mod tests {
let query = (2..7).map(|x| x as f32).collect::<Vec<f32>>();
assert_eq!(vector_f32_distance_l2(&v, &query), 20.0_f64.sqrt());
}
/// Checks the dense `f32` L2 distance against hand-computed values for empty,
/// zero, identical, opposite, and perpendicular-style inputs.
#[test]
fn test_distance_l2_f32() {
    // Each row is (left operand, right operand, expected distance).
    let cases: [(&[f32], &[f32], f64); 5] = [
        (&[], &[], 0.0),
        (&[1.0, 2.0], &[0.0, 0.0], (1f64 + 2f64 * 2f64).sqrt()),
        (&[1.0, 2.0], &[1.0, 2.0], 0.0),
        (
            &[1.0, 2.0],
            &[-1.0, -2.0],
            (2f64 * 2f64 + 4f64 * 4f64).sqrt(),
        ),
        (
            &[1.0, 2.0],
            &[-2.0, 1.0],
            (3f64 * 3f64 + 1f64 * 1f64).sqrt(),
        ),
    ];
    for &(lhs, rhs, expected) in &cases {
        assert_eq!(vector_f32_distance_l2(lhs, rhs), expected);
    }
}
/// Checks the dense `f64` L2 distance against hand-computed values for empty,
/// zero, identical, opposite, and perpendicular-style inputs.
#[test]
fn test_distance_l2_f64() {
    // Each row is (left operand, right operand, expected distance).
    let cases: [(&[f64], &[f64], f64); 5] = [
        (&[], &[], 0.0),
        (&[1.0, 2.0], &[0.0, 0.0], (1f64 + 2f64 * 2f64).sqrt()),
        (&[1.0, 2.0], &[1.0, 2.0], 0.0),
        (
            &[1.0, 2.0],
            &[-1.0, -2.0],
            (2f64 * 2f64 + 4f64 * 4f64).sqrt(),
        ),
        (
            &[1.0, 2.0],
            &[-2.0, 1.0],
            (3f64 * 3f64 + 1f64 * 1f64).sqrt(),
        ),
    ];
    for &(lhs, rhs, expected) in &cases {
        assert_eq!(vector_f64_distance_l2(lhs, rhs), expected);
    }
}
/// The sparse L2 distance must agree (within 1e-7) with the dense `f32`
/// implementation evaluated on the same vectors written out in dense form.
#[test]
fn test_distance_l2_f32_sparse() {
    // Sparse operands: {0: 1.0, 1: 2.0} and {1: 1.0, 2: 3.0}.
    let sparse = vector_f32_sparse_distance_l2(
        VectorSparse {
            idx: &[0, 1],
            values: &[1.0, 2.0],
        },
        VectorSparse {
            idx: &[1, 2],
            values: &[1.0, 3.0],
        },
    );
    // The same two vectors with the absent coordinates spelled out as zeros.
    let dense = vector_f32_distance_l2(&[1.0, 2.0, 0.0], &[0.0, 1.0, 3.0]);

    let gap = (sparse - dense).abs();
    assert!(gap < 1e-7);
}
}