mirror of
https://github.com/aljazceru/turso.git
synced 2026-02-02 14:54:23 +01:00
add jaccard distance
This commit is contained in:
@@ -158,6 +158,7 @@ pub enum VectorFunc {
|
||||
VectorExtract,
|
||||
VectorDistanceCos,
|
||||
VectorDistanceL2,
|
||||
VectorDistanceJaccard,
|
||||
VectorConcat,
|
||||
VectorSlice,
|
||||
}
|
||||
@@ -178,6 +179,7 @@ impl Display for VectorFunc {
|
||||
Self::VectorExtract => "vector_extract".to_string(),
|
||||
Self::VectorDistanceCos => "vector_distance_cos".to_string(),
|
||||
Self::VectorDistanceL2 => "vector_distance_l2".to_string(),
|
||||
Self::VectorDistanceJaccard => "vector_distance_jaccard".to_string(),
|
||||
Self::VectorConcat => "vector_concat".to_string(),
|
||||
Self::VectorSlice => "vector_slice".to_string(),
|
||||
};
|
||||
@@ -871,6 +873,7 @@ impl Func {
|
||||
"vector_extract" => Ok(Self::Vector(VectorFunc::VectorExtract)),
|
||||
"vector_distance_cos" => Ok(Self::Vector(VectorFunc::VectorDistanceCos)),
|
||||
"vector_distance_l2" => Ok(Self::Vector(VectorFunc::VectorDistanceL2)),
|
||||
"vector_distance_jaccard" => Ok(Self::Vector(VectorFunc::VectorDistanceJaccard)),
|
||||
"vector_concat" => Ok(Self::Vector(VectorFunc::VectorConcat)),
|
||||
"vector_slice" => Ok(Self::Vector(VectorFunc::VectorSlice)),
|
||||
_ => crate::bail_parse_error!("no such function: {}", name),
|
||||
|
||||
@@ -921,6 +921,15 @@ pub fn translate_expr(
|
||||
emit_function_call(program, func_ctx, &[regs, regs + 1], target_register)?;
|
||||
Ok(target_register)
|
||||
}
|
||||
VectorFunc::VectorDistanceJaccard => {
|
||||
let args = expect_arguments_exact!(args, 2, vector_func);
|
||||
let regs = program.alloc_registers(2);
|
||||
translate_expr(program, referenced_tables, &args[0], regs, resolver)?;
|
||||
translate_expr(program, referenced_tables, &args[1], regs + 1, resolver)?;
|
||||
|
||||
emit_function_call(program, func_ctx, &[regs, regs + 1], target_register)?;
|
||||
Ok(target_register)
|
||||
}
|
||||
VectorFunc::VectorConcat => {
|
||||
let args = expect_arguments_exact!(args, 2, vector_func);
|
||||
let regs = program.alloc_registers(2);
|
||||
|
||||
@@ -20,7 +20,7 @@ use crate::types::{
|
||||
use crate::util::normalize_ident;
|
||||
use crate::vdbe::insn::InsertFlags;
|
||||
use crate::vdbe::{registers_to_ref_values, TxnCleanup};
|
||||
use crate::vector::{vector32_sparse, vector_concat, vector_slice};
|
||||
use crate::vector::{vector32_sparse, vector_concat, vector_distance_jaccard, vector_slice};
|
||||
use crate::{
|
||||
error::{
|
||||
LimboError, SQLITE_CONSTRAINT, SQLITE_CONSTRAINT_NOTNULL, SQLITE_CONSTRAINT_PRIMARYKEY,
|
||||
@@ -5219,6 +5219,11 @@ pub fn op_function(
|
||||
vector_distance_l2(&state.registers[*start_reg..*start_reg + arg_count])?;
|
||||
state.registers[*dest] = Register::Value(result);
|
||||
}
|
||||
VectorFunc::VectorDistanceJaccard => {
|
||||
let result =
|
||||
vector_distance_jaccard(&state.registers[*start_reg..*start_reg + arg_count])?;
|
||||
state.registers[*dest] = Register::Value(result);
|
||||
}
|
||||
VectorFunc::VectorConcat => {
|
||||
let result = vector_concat(&state.registers[*start_reg..*start_reg + arg_count])?;
|
||||
state.registers[*dest] = Register::Value(result);
|
||||
|
||||
@@ -111,6 +111,19 @@ pub fn vector_distance_l2(args: &[Register]) -> Result<Value> {
|
||||
Ok(Value::Float(dist))
|
||||
}
|
||||
|
||||
pub fn vector_distance_jaccard(args: &[Register]) -> Result<Value> {
|
||||
if args.len() != 2 {
|
||||
return Err(LimboError::ConversionError(
|
||||
"distance_jaccard requires exactly two arguments".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let x = parse_vector(&args[0], None)?;
|
||||
let y = parse_vector(&args[1], None)?;
|
||||
let dist = operations::jaccard::vector_distance_jaccard(&x, &y)?;
|
||||
Ok(Value::Float(dist))
|
||||
}
|
||||
|
||||
pub fn vector_concat(args: &[Register]) -> Result<Value> {
|
||||
if args.len() != 2 {
|
||||
return Err(LimboError::InvalidArgument(
|
||||
|
||||
87
core/vector/operations/jaccard.rs
Normal file
87
core/vector/operations/jaccard.rs
Normal file
@@ -0,0 +1,87 @@
|
||||
use crate::{
|
||||
vector::vector_types::{Vector, VectorSparse, VectorType},
|
||||
LimboError, Result,
|
||||
};
|
||||
|
||||
pub fn vector_distance_jaccard(v1: &Vector, v2: &Vector) -> Result<f64> {
|
||||
if v1.dims != v2.dims {
|
||||
return Err(LimboError::ConversionError(
|
||||
"Vectors must have the same dimensions".to_string(),
|
||||
));
|
||||
}
|
||||
if v1.vector_type != v2.vector_type {
|
||||
return Err(LimboError::ConversionError(
|
||||
"Vectors must be of the same type".to_string(),
|
||||
));
|
||||
}
|
||||
match v1.vector_type {
|
||||
VectorType::Float32Dense => Ok(vector_f32_distance_jaccard(
|
||||
v1.as_f32_slice(),
|
||||
v2.as_f32_slice(),
|
||||
)),
|
||||
VectorType::Float64Dense => Ok(vector_f64_distance_jaccard(
|
||||
v1.as_f64_slice(),
|
||||
v2.as_f64_slice(),
|
||||
)),
|
||||
VectorType::Float32Sparse => Ok(vector_f32_sparse_distance_jaccard(
|
||||
v1.as_f32_sparse(),
|
||||
v2.as_f32_sparse(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn vector_f32_distance_jaccard(v1: &[f32], v2: &[f32]) -> f64 {
|
||||
let (mut min_sum, mut max_sum) = (0.0, 0.0);
|
||||
for (&a, &b) in v1.iter().zip(v2.iter()) {
|
||||
min_sum += a.min(b);
|
||||
max_sum += a.max(b);
|
||||
}
|
||||
if max_sum == 0.0 {
|
||||
return f64::NAN;
|
||||
}
|
||||
1. - (min_sum / min_sum) as f64
|
||||
}
|
||||
|
||||
fn vector_f64_distance_jaccard(v1: &[f64], v2: &[f64]) -> f64 {
|
||||
let (mut min_sum, mut max_sum) = (0.0, 0.0);
|
||||
for (&a, &b) in v1.iter().zip(v2.iter()) {
|
||||
min_sum += a.min(b);
|
||||
max_sum += a.max(b);
|
||||
}
|
||||
if max_sum == 0.0 {
|
||||
return f64::NAN;
|
||||
}
|
||||
1. - min_sum / min_sum
|
||||
}
|
||||
|
||||
fn vector_f32_sparse_distance_jaccard(v1: VectorSparse<f32>, v2: VectorSparse<f32>) -> f64 {
|
||||
let mut v1_pos = 0;
|
||||
let mut v2_pos = 0;
|
||||
let (mut min_sum, mut max_sum) = (0.0, 0.0);
|
||||
while v1_pos < v1.idx.len() && v2_pos < v2.idx.len() {
|
||||
if v1.idx[v1_pos] == v2.idx[v2_pos] {
|
||||
min_sum += v1.values[v1_pos].min(v2.values[v2_pos]);
|
||||
max_sum += v1.values[v1_pos].max(v2.values[v2_pos]);
|
||||
v1_pos += 1;
|
||||
v2_pos += 1;
|
||||
} else if v1.idx[v1_pos] < v2.idx[v2_pos] {
|
||||
max_sum += v1.values[v1_pos];
|
||||
v1_pos += 1;
|
||||
} else {
|
||||
max_sum += v2.values[v2_pos];
|
||||
v2_pos += 1;
|
||||
}
|
||||
}
|
||||
while v1_pos < v1.idx.len() {
|
||||
max_sum += v1.values[v1_pos];
|
||||
v1_pos += 1;
|
||||
}
|
||||
while v2_pos < v2.idx.len() {
|
||||
max_sum += v2.values[v2_pos];
|
||||
v2_pos += 1;
|
||||
}
|
||||
if max_sum == 0.0 {
|
||||
return f64::NAN;
|
||||
}
|
||||
1. - (min_sum / max_sum) as f64
|
||||
}
|
||||
@@ -2,6 +2,7 @@ pub mod concat;
|
||||
pub mod convert;
|
||||
pub mod distance_cos;
|
||||
pub mod distance_l2;
|
||||
pub mod jaccard;
|
||||
pub mod serialize;
|
||||
pub mod slice;
|
||||
pub mod text;
|
||||
|
||||
Reference in New Issue
Block a user