diff --git a/COMPAT.md b/COMPAT.md index dbc93962c..a436e80ef 100644 --- a/COMPAT.md +++ b/COMPAT.md @@ -632,7 +632,7 @@ The `vector` extension is compatible with libSQL native vector search. | vector_distance_cos(x, y) | Yes | | | vector_distance_l2(x, y) | Yes |Euclidean distance| | vector_concat(x, y) | Yes | | -| vector_slice(x, start_index, length) | Yes | | +| vector_slice(x, start_index, end_index) | Yes | | ### Time diff --git a/core/vector/mod.rs b/core/vector/mod.rs index ebbf5a7ef..51c1f0868 100644 --- a/core/vector/mod.rs +++ b/core/vector/mod.rs @@ -142,18 +142,18 @@ pub fn vector_slice(args: &[Register]) -> Result { .as_int() .ok_or_else(|| LimboError::InvalidArgument("start index must be an integer".into()))?; - let length = args[2] + let end_index = args[2] .get_owned_value() .as_int() - .ok_or_else(|| LimboError::InvalidArgument("length must be an integer".into()))?; + .ok_or_else(|| LimboError::InvalidArgument("end_index must be an integer".into()))?; - if start_index < 0 || length < 0 { + if start_index < 0 || end_index < 0 { return Err(LimboError::InvalidArgument( - "start index and length must be non-negative".into(), + "start index and end_index must be non-negative".into(), )); } - let result = vector_types::vector_slice(&vector, start_index as usize, length as usize)?; + let result = vector_types::vector_slice(&vector, start_index as usize, end_index as usize)?; Ok(match result.vector_type { VectorType::Float32 => vector_serialize_f32(result), diff --git a/core/vector/vector_types.rs b/core/vector/vector_types.rs index 853928186..a0cf2dfe8 100644 --- a/core/vector/vector_types.rs +++ b/core/vector/vector_types.rs @@ -334,21 +334,26 @@ pub fn vector_concat(v1: &Vector, v2: &Vector) -> Result { }) } -pub fn vector_slice(vector: &Vector, start_idx: usize, length: usize) -> Result { +pub fn vector_slice(vector: &Vector, start_idx: usize, end_idx: usize) -> Result { fn extract_bytes( slice: &[T], start: usize, - len: usize, + end: usize, to_bytes: impl Fn(&T) -> [u8; N], ) -> Result> { - if start + len > slice.len() { + if start > end { + return Err(LimboError::InvalidArgument( + "start index must not be greater than end index".into(), + )); + } + if end > slice.len() || start >= slice.len() { return Err(LimboError::ConversionError( "vector_slice range out of bounds".into(), )); } - let mut buf = Vec::with_capacity(len * N); - for item in &slice[start..start + len] { + let mut buf = Vec::with_capacity((end - start) * N); + for item in &slice[start..end] { buf.extend_from_slice(&to_bytes(item)); } Ok(buf) @@ -357,13 +362,13 @@ pub fn vector_slice(vector: &Vector, start_idx: usize, length: usize) -> Result< let (vector_type, data) = match vector.vector_type { VectorType::Float32 => ( VectorType::Float32, - extract_bytes::(vector.as_f32_slice(), start_idx, length, |v| { + extract_bytes::(vector.as_f32_slice(), start_idx, end_idx, |v| { v.to_le_bytes() })?, ), VectorType::Float64 => ( VectorType::Float64, - extract_bytes::(vector.as_f64_slice(), start_idx, length, |v| { + extract_bytes::(vector.as_f64_slice(), start_idx, end_idx, |v| { v.to_le_bytes() })?, ), @@ -371,7 +376,7 @@ pub fn vector_slice(vector: &Vector, start_idx: usize, length: usize) -> Result< Ok(Vector { vector_type, - dims: length, + dims: end_idx - start_idx, data, }) }