mirror of
https://github.com/aljazceru/turso.git
synced 2026-02-23 17:05:36 +01:00
add vec_concat execution flow
This commit is contained in:
@@ -158,6 +158,7 @@ pub enum VectorFunc {
|
||||
VectorExtract,
|
||||
VectorDistanceCos,
|
||||
VectorDistanceEuclidean,
|
||||
VectorConcat,
|
||||
}
|
||||
|
||||
impl VectorFunc {
|
||||
@@ -176,6 +177,7 @@ impl Display for VectorFunc {
|
||||
Self::VectorDistanceCos => "vector_distance_cos".to_string(),
|
||||
// We use `distance_l2` to reduce user input
|
||||
Self::VectorDistanceEuclidean => "vector_distance_l2".to_string(),
|
||||
Self::VectorConcat => "vector_concat".to_string(),
|
||||
};
|
||||
write!(f, "{str}")
|
||||
}
|
||||
@@ -838,6 +840,7 @@ impl Func {
|
||||
"vector_extract" => Ok(Self::Vector(VectorFunc::VectorExtract)),
|
||||
"vector_distance_cos" => Ok(Self::Vector(VectorFunc::VectorDistanceCos)),
|
||||
"vector_distance_l2" => Ok(Self::Vector(VectorFunc::VectorDistanceEuclidean)),
|
||||
"vector_concat" => Ok(Self::Vector(VectorFunc::VectorConcat)),
|
||||
_ => crate::bail_parse_error!("no such function: {}", name),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -917,6 +917,15 @@ pub fn translate_expr(
|
||||
translate_expr(program, referenced_tables, &args[0], regs, resolver)?;
|
||||
translate_expr(program, referenced_tables, &args[1], regs + 1, resolver)?;
|
||||
|
||||
emit_function_call(program, func_ctx, &[regs, regs + 1], target_register)?;
|
||||
Ok(target_register)
|
||||
}
|
||||
VectorFunc::VectorConcat => {
|
||||
let args = expect_arguments_exact!(args, 2, vector_func);
|
||||
let regs = program.alloc_registers(2);
|
||||
translate_expr(program, referenced_tables, &args[0], regs, resolver)?;
|
||||
translate_expr(program, referenced_tables, &args[1], regs + 1, resolver)?;
|
||||
|
||||
emit_function_call(program, func_ctx, &[regs, regs + 1], target_register)?;
|
||||
Ok(target_register)
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ use crate::types::{
|
||||
use crate::util::normalize_ident;
|
||||
use crate::vdbe::insn::InsertFlags;
|
||||
use crate::vdbe::registers_to_ref_values;
|
||||
use crate::vector::vector_concat;
|
||||
use crate::{
|
||||
error::{
|
||||
LimboError, SQLITE_CONSTRAINT, SQLITE_CONSTRAINT_NOTNULL, SQLITE_CONSTRAINT_PRIMARYKEY,
|
||||
@@ -4557,6 +4558,10 @@ pub fn op_function(
|
||||
vector_distance_l2(&state.registers[*start_reg..*start_reg + arg_count])?;
|
||||
state.registers[*dest] = Register::Value(result);
|
||||
}
|
||||
VectorFunc::VectorConcat => {
|
||||
let result = vector_concat(&state.registers[*start_reg..*start_reg + arg_count])?;
|
||||
state.registers[*dest] = Register::Value(result);
|
||||
}
|
||||
},
|
||||
crate::function::Func::External(f) => match f.func {
|
||||
ExtFunc::Scalar(f) => {
|
||||
|
||||
@@ -104,3 +104,26 @@ pub fn vector_distance_l2(args: &[Register]) -> Result<Value> {
|
||||
let dist = Euclidean::calculate(&x, &y)?;
|
||||
Ok(Value::Float(dist))
|
||||
}
|
||||
|
||||
pub fn vector_concat(args: &[Register]) -> Result<Value> {
|
||||
if args.len() != 2 {
|
||||
return Err(LimboError::ConversionError(
|
||||
"distance_concat requires exactly two arguments".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let x = parse_vector(&args[0], None)?;
|
||||
let y = parse_vector(&args[1], None)?;
|
||||
|
||||
if x.vector_type != y.vector_type {
|
||||
return Err(LimboError::ConversionError(
|
||||
"Vectors must be of the same type".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let vector = vector_types::vector_concat(&x, &y)?;
|
||||
match vector.vector_type {
|
||||
VectorType::Float32 => Ok(vector_serialize_f32(vector)),
|
||||
VectorType::Float64 => Ok(vector_serialize_f64(vector)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -626,12 +626,12 @@ mod tests {
|
||||
fn test_vector_concat() {
|
||||
let input = "[1.0, 2.0, 3.0]";
|
||||
let value = Value::from_text(input);
|
||||
|
||||
|
||||
let vec1 = parse_string_vector(VectorType::Float32, &value).unwrap();
|
||||
let vec2 = parse_string_vector(VectorType::Float32, &value).unwrap();
|
||||
|
||||
|
||||
let result = vector_concat(&vec1, &vec2).unwrap();
|
||||
|
||||
|
||||
assert_eq!(result.dims, 6);
|
||||
assert_eq!(result.vector_type, VectorType::Float32);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user