diff --git a/core/function.rs b/core/function.rs
index f064970c3..a9b9750d7 100644
--- a/core/function.rs
+++ b/core/function.rs
@@ -158,6 +158,7 @@ pub enum VectorFunc {
     VectorExtract,
     VectorDistanceCos,
     VectorDistanceL2,
+    VectorDistanceJaccard,
     VectorConcat,
     VectorSlice,
 }
@@ -178,6 +179,7 @@ impl Display for VectorFunc {
             Self::VectorExtract => "vector_extract".to_string(),
             Self::VectorDistanceCos => "vector_distance_cos".to_string(),
             Self::VectorDistanceL2 => "vector_distance_l2".to_string(),
+            Self::VectorDistanceJaccard => "vector_distance_jaccard".to_string(),
             Self::VectorConcat => "vector_concat".to_string(),
             Self::VectorSlice => "vector_slice".to_string(),
         };
@@ -871,6 +873,7 @@ impl Func {
             "vector_extract" => Ok(Self::Vector(VectorFunc::VectorExtract)),
             "vector_distance_cos" => Ok(Self::Vector(VectorFunc::VectorDistanceCos)),
             "vector_distance_l2" => Ok(Self::Vector(VectorFunc::VectorDistanceL2)),
+            "vector_distance_jaccard" => Ok(Self::Vector(VectorFunc::VectorDistanceJaccard)),
             "vector_concat" => Ok(Self::Vector(VectorFunc::VectorConcat)),
             "vector_slice" => Ok(Self::Vector(VectorFunc::VectorSlice)),
             _ => crate::bail_parse_error!("no such function: {}", name),
diff --git a/core/translate/expr.rs b/core/translate/expr.rs
index 0a84e1548..77a464538 100644
--- a/core/translate/expr.rs
+++ b/core/translate/expr.rs
@@ -921,6 +921,15 @@ pub fn translate_expr(
                         emit_function_call(program, func_ctx, &[regs, regs + 1], target_register)?;
                         Ok(target_register)
                     }
+                    VectorFunc::VectorDistanceJaccard => {
+                        let args = expect_arguments_exact!(args, 2, vector_func);
+                        let regs = program.alloc_registers(2);
+                        translate_expr(program, referenced_tables, &args[0], regs, resolver)?;
+                        translate_expr(program, referenced_tables, &args[1], regs + 1, resolver)?;
+
+                        emit_function_call(program, func_ctx, &[regs, regs + 1], target_register)?;
+                        Ok(target_register)
+                    }
                     VectorFunc::VectorConcat => {
                         let args = expect_arguments_exact!(args, 2, vector_func);
                         let regs = program.alloc_registers(2);
diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs
index 87103b8d4..b68bdf866 100644
--- a/core/vdbe/execute.rs
+++ b/core/vdbe/execute.rs
@@ -20,7 +20,7 @@ use crate::types::{
 use crate::util::normalize_ident;
 use crate::vdbe::insn::InsertFlags;
 use crate::vdbe::{registers_to_ref_values, TxnCleanup};
-use crate::vector::{vector32_sparse, vector_concat, vector_slice};
+use crate::vector::{vector32_sparse, vector_concat, vector_distance_jaccard, vector_slice};
 use crate::{
     error::{
         LimboError, SQLITE_CONSTRAINT, SQLITE_CONSTRAINT_NOTNULL, SQLITE_CONSTRAINT_PRIMARYKEY,
@@ -5219,6 +5219,11 @@ pub fn op_function(
                     vector_distance_l2(&state.registers[*start_reg..*start_reg + arg_count])?;
                 state.registers[*dest] = Register::Value(result);
             }
+            VectorFunc::VectorDistanceJaccard => {
+                let result =
+                    vector_distance_jaccard(&state.registers[*start_reg..*start_reg + arg_count])?;
+                state.registers[*dest] = Register::Value(result);
+            }
             VectorFunc::VectorConcat => {
                 let result = vector_concat(&state.registers[*start_reg..*start_reg + arg_count])?;
                 state.registers[*dest] = Register::Value(result);
diff --git a/core/vector/mod.rs b/core/vector/mod.rs
index 8c77f7e7c..514780a78 100644
--- a/core/vector/mod.rs
+++ b/core/vector/mod.rs
@@ -111,6 +111,28 @@ pub fn vector_distance_l2(args: &[Register]) -> Result<Value> {
     Ok(Value::Float(dist))
 }
 
+/// Implements `vector_distance_jaccard` for a slice of argument registers.
+///
+/// Parses both arguments as vectors and returns their Jaccard distance
+/// wrapped in [`Value::Float`].
+///
+/// # Errors
+///
+/// Returns [`LimboError::ConversionError`] unless exactly two arguments are
+/// supplied; errors from parsing or from the distance kernel are propagated.
+pub fn vector_distance_jaccard(args: &[Register]) -> Result<Value> {
+    if args.len() != 2 {
+        return Err(LimboError::ConversionError(
+            "distance_jaccard requires exactly two arguments".to_string(),
+        ));
+    }
+
+    let x = parse_vector(&args[0], None)?;
+    let y = parse_vector(&args[1], None)?;
+    let dist = operations::jaccard::vector_distance_jaccard(&x, &y)?;
+    Ok(Value::Float(dist))
+}
+
 pub fn vector_concat(args: &[Register]) -> Result<Value> {
     if args.len() != 2 {
         return Err(LimboError::InvalidArgument(
diff --git a/core/vector/operations/jaccard.rs b/core/vector/operations/jaccard.rs
new file mode 100644
index
000000000..f545d82bd
--- /dev/null
+++ b/core/vector/operations/jaccard.rs
@@ -0,0 +1,115 @@
+use crate::{
+    vector::vector_types::{Vector, VectorSparse, VectorType},
+    LimboError, Result,
+};
+
+/// Computes the weighted Jaccard distance between two vectors:
+/// `1 - sum(min(a_i, b_i)) / sum(max(a_i, b_i))`.
+///
+/// Dispatches on `v1.vector_type` to the matching dense or sparse kernel.
+/// The result is `NaN` when the sum of element-wise maxima is zero (for
+/// example, two all-zero vectors), because the ratio is undefined there.
+///
+/// # Errors
+///
+/// Returns [`LimboError::ConversionError`] when the vectors differ in
+/// `dims` or in `vector_type`.
+pub fn vector_distance_jaccard(v1: &Vector, v2: &Vector) -> Result<f64> {
+    if v1.dims != v2.dims {
+        return Err(LimboError::ConversionError(
+            "Vectors must have the same dimensions".to_string(),
+        ));
+    }
+    if v1.vector_type != v2.vector_type {
+        return Err(LimboError::ConversionError(
+            "Vectors must be of the same type".to_string(),
+        ));
+    }
+    match v1.vector_type {
+        VectorType::Float32Dense => Ok(vector_f32_distance_jaccard(
+            v1.as_f32_slice(),
+            v2.as_f32_slice(),
+        )),
+        VectorType::Float64Dense => Ok(vector_f64_distance_jaccard(
+            v1.as_f64_slice(),
+            v2.as_f64_slice(),
+        )),
+        VectorType::Float32Sparse => Ok(vector_f32_sparse_distance_jaccard(
+            v1.as_f32_sparse(),
+            v2.as_f32_sparse(),
+        )),
+    }
+}
+
+/// Weighted Jaccard distance over two dense `f32` slices.
+///
+/// The sums are accumulated in `f32`; only the final ratio is widened to
+/// `f64`. Returns `NaN` when the sum of maxima is zero.
+fn vector_f32_distance_jaccard(v1: &[f32], v2: &[f32]) -> f64 {
+    let (mut min_sum, mut max_sum) = (0.0, 0.0);
+    for (&a, &b) in v1.iter().zip(v2.iter()) {
+        min_sum += a.min(b);
+        max_sum += a.max(b);
+    }
+    if max_sum == 0.0 {
+        return f64::NAN;
+    }
+    // Similarity is min_sum / max_sum; the distance is its complement.
+    1. - (min_sum / max_sum) as f64
+}
+
+/// Weighted Jaccard distance over two dense `f64` slices.
+///
+/// Returns `NaN` when the sum of maxima is zero.
+fn vector_f64_distance_jaccard(v1: &[f64], v2: &[f64]) -> f64 {
+    let (mut min_sum, mut max_sum) = (0.0, 0.0);
+    for (&a, &b) in v1.iter().zip(v2.iter()) {
+        min_sum += a.min(b);
+        max_sum += a.max(b);
+    }
+    if max_sum == 0.0 {
+        return f64::NAN;
+    }
+    // Similarity is min_sum / max_sum; the distance is its complement.
+    1. - min_sum / max_sum
+}
+
+/// Weighted Jaccard distance over two sparse `f32` vectors.
+///
+/// Walks both index lists with two cursors, advancing whichever cursor points
+/// at the smaller index. An index present in both vectors contributes the
+/// min and max of the two values; an index present in only one vector adds
+/// its value to the max sum and nothing to the min sum. Returns `NaN` when
+/// the sum of maxima is zero.
+fn vector_f32_sparse_distance_jaccard(v1: VectorSparse, v2: VectorSparse) -> f64 {
+    let mut v1_pos = 0;
+    let mut v2_pos = 0;
+    let (mut min_sum, mut max_sum) = (0.0, 0.0);
+    while v1_pos < v1.idx.len() && v2_pos < v2.idx.len() {
+        if v1.idx[v1_pos] == v2.idx[v2_pos] {
+            min_sum += v1.values[v1_pos].min(v2.values[v2_pos]);
+            max_sum += v1.values[v1_pos].max(v2.values[v2_pos]);
+            v1_pos += 1;
+            v2_pos += 1;
+        } else if v1.idx[v1_pos] < v2.idx[v2_pos] {
+            max_sum += v1.values[v1_pos];
+            v1_pos += 1;
+        } else {
+            max_sum += v2.values[v2_pos];
+            v2_pos += 1;
+        }
+    }
+    // Drain whichever vector still has entries; they have no counterpart.
+    while v1_pos < v1.idx.len() {
+        max_sum += v1.values[v1_pos];
+        v1_pos += 1;
+    }
+    while v2_pos < v2.idx.len() {
+        max_sum += v2.values[v2_pos];
+        v2_pos += 1;
+    }
+    if max_sum == 0.0 {
+        return f64::NAN;
+    }
+    1. - (min_sum / max_sum) as f64
+}
diff --git a/core/vector/operations/mod.rs b/core/vector/operations/mod.rs
index c0d10a0f0..9b1a20ada 100644
--- a/core/vector/operations/mod.rs
+++ b/core/vector/operations/mod.rs
@@ -2,6 +2,7 @@ pub mod concat;
 pub mod convert;
 pub mod distance_cos;
 pub mod distance_l2;
+pub mod jaccard;
 pub mod serialize;
 pub mod slice;
 pub mod text;