From e4d79a6516bd23ded1222b80776274c8b659f865 Mon Sep 17 00:00:00 2001 From: bit-aloo Date: Wed, 30 Jul 2025 06:07:03 +0530 Subject: [PATCH] add vec_concat execution flow --- core/function.rs | 3 +++ core/translate/expr.rs | 9 +++++++++ core/vdbe/execute.rs | 5 +++++ core/vector/mod.rs | 23 +++++++++++++++++++++++ core/vector/vector_types.rs | 6 +++--- 5 files changed, 43 insertions(+), 3 deletions(-) diff --git a/core/function.rs b/core/function.rs index aa24450fa..7cc6d9cc3 100644 --- a/core/function.rs +++ b/core/function.rs @@ -158,6 +158,7 @@ pub enum VectorFunc { VectorExtract, VectorDistanceCos, VectorDistanceEuclidean, + VectorConcat, } impl VectorFunc { @@ -176,6 +177,7 @@ impl Display for VectorFunc { Self::VectorDistanceCos => "vector_distance_cos".to_string(), // We use `distance_l2` to reduce user input Self::VectorDistanceEuclidean => "vector_distance_l2".to_string(), + Self::VectorConcat => "vector_concat".to_string(), }; write!(f, "{str}") } @@ -838,6 +840,7 @@ impl Func { "vector_extract" => Ok(Self::Vector(VectorFunc::VectorExtract)), "vector_distance_cos" => Ok(Self::Vector(VectorFunc::VectorDistanceCos)), "vector_distance_l2" => Ok(Self::Vector(VectorFunc::VectorDistanceEuclidean)), + "vector_concat" => Ok(Self::Vector(VectorFunc::VectorConcat)), _ => crate::bail_parse_error!("no such function: {}", name), } } diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 6fa51125b..6de758324 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -917,6 +917,15 @@ pub fn translate_expr( translate_expr(program, referenced_tables, &args[0], regs, resolver)?; translate_expr(program, referenced_tables, &args[1], regs + 1, resolver)?; + emit_function_call(program, func_ctx, &[regs, regs + 1], target_register)?; + Ok(target_register) + } + VectorFunc::VectorConcat => { + let args = expect_arguments_exact!(args, 2, vector_func); + let regs = program.alloc_registers(2); + translate_expr(program, referenced_tables, &args[0], regs, resolver)?; + translate_expr(program, referenced_tables, &args[1], regs + 1, resolver)?; + emit_function_call(program, func_ctx, &[regs, regs + 1], target_register)?; Ok(target_register) } diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index a9f21431e..c67320f09 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -16,6 +16,7 @@ use crate::types::{ use crate::util::normalize_ident; use crate::vdbe::insn::InsertFlags; use crate::vdbe::registers_to_ref_values; +use crate::vector::vector_concat; use crate::{ error::{ LimboError, SQLITE_CONSTRAINT, SQLITE_CONSTRAINT_NOTNULL, SQLITE_CONSTRAINT_PRIMARYKEY, @@ -4557,6 +4558,10 @@ pub fn op_function( vector_distance_l2(&state.registers[*start_reg..*start_reg + arg_count])?; state.registers[*dest] = Register::Value(result); } + VectorFunc::VectorConcat => { + let result = vector_concat(&state.registers[*start_reg..*start_reg + arg_count])?; + state.registers[*dest] = Register::Value(result); + } }, crate::function::Func::External(f) => match f.func { ExtFunc::Scalar(f) => { diff --git a/core/vector/mod.rs b/core/vector/mod.rs index ec69c99d4..1742c0d96 100644 --- a/core/vector/mod.rs +++ b/core/vector/mod.rs @@ -104,3 +104,26 @@ pub fn vector_distance_l2(args: &[Register]) -> Result { let dist = Euclidean::calculate(&x, &y)?; Ok(Value::Float(dist)) } + +pub fn vector_concat(args: &[Register]) -> Result { + if args.len() != 2 { + return Err(LimboError::ConversionError( + "distance_concat requires exactly two arguments".to_string(), + )); + } + + let x = parse_vector(&args[0], None)?; + let y = parse_vector(&args[1], None)?; + + if x.vector_type != y.vector_type { + return Err(LimboError::ConversionError( + "Vectors must be of the same type".to_string(), + )); + } + + let vector = vector_types::vector_concat(&x, &y)?; + match vector.vector_type { + VectorType::Float32 => Ok(vector_serialize_f32(vector)), + VectorType::Float64 => Ok(vector_serialize_f64(vector)), + } +} diff --git a/core/vector/vector_types.rs b/core/vector/vector_types.rs index 10e0258ce..c7cf13d4f 100644 --- a/core/vector/vector_types.rs +++ b/core/vector/vector_types.rs @@ -626,12 +626,12 @@ mod tests { fn test_vector_concat() { let input = "[1.0, 2.0, 3.0]"; let value = Value::from_text(input); - + let vec1 = parse_string_vector(VectorType::Float32, &value).unwrap(); let vec2 = parse_string_vector(VectorType::Float32, &value).unwrap(); - + let result = vector_concat(&vec1, &vec2).unwrap(); - + assert_eq!(result.dims, 6); assert_eq!(result.vector_type, VectorType::Float32); }