mirror of
https://github.com/aljazceru/turso.git
synced 2025-12-18 09:04:19 +01:00
convert vector functions to use AsValueRef
This commit is contained in:
@@ -1496,6 +1496,42 @@ impl<'a> ValueRef<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn to_text(&self) -> Option<&'a str> {
|
||||||
|
match self {
|
||||||
|
Self::Text(t) => Some(t.as_str()),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn as_blob(&self) -> &'a [u8] {
|
||||||
|
match self {
|
||||||
|
Self::Blob(b) => b,
|
||||||
|
_ => panic!("as_blob must be called only for Value::Blob"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn as_float(&self) -> f64 {
|
||||||
|
match self {
|
||||||
|
Self::Float(f) => *f,
|
||||||
|
Self::Integer(i) => *i as f64,
|
||||||
|
_ => panic!("as_float must be called only for Value::Float or Value::Integer"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn as_int(&self) -> Option<i64> {
|
||||||
|
match self {
|
||||||
|
Self::Integer(i) => Some(*i),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn as_uint(&self) -> u64 {
|
||||||
|
match self {
|
||||||
|
Self::Integer(i) => (*i).cast_unsigned(),
|
||||||
|
_ => 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn to_owned(&self) -> Value {
|
pub fn to_owned(&self) -> Value {
|
||||||
match self {
|
match self {
|
||||||
ValueRef::Null => Value::Null,
|
ValueRef::Null => Value::Null,
|
||||||
@@ -1508,6 +1544,16 @@ impl<'a> ValueRef<'a> {
|
|||||||
ValueRef::Blob(b) => Value::Blob(b.to_vec()),
|
ValueRef::Blob(b) => Value::Blob(b.to_vec()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn value_type(&self) -> ValueType {
|
||||||
|
match self {
|
||||||
|
Self::Null => ValueType::Null,
|
||||||
|
Self::Integer(_) => ValueType::Integer,
|
||||||
|
Self::Float(_) => ValueType::Float,
|
||||||
|
Self::Text(_) => ValueType::Text,
|
||||||
|
Self::Blob(_) => ValueType::Blob,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Display for ValueRef<'_> {
|
impl Display for ValueRef<'_> {
|
||||||
|
|||||||
@@ -5219,51 +5219,52 @@ pub fn op_function(
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
crate::function::Func::Vector(vector_func) => match vector_func {
|
crate::function::Func::Vector(vector_func) => {
|
||||||
|
let values =
|
||||||
|
registers_to_ref_values(&state.registers[*start_reg..*start_reg + arg_count]);
|
||||||
|
match vector_func {
|
||||||
VectorFunc::Vector => {
|
VectorFunc::Vector => {
|
||||||
let result = vector32(&state.registers[*start_reg..*start_reg + arg_count])?;
|
let result = vector32(values)?;
|
||||||
state.registers[*dest] = Register::Value(result);
|
state.registers[*dest] = Register::Value(result);
|
||||||
}
|
}
|
||||||
VectorFunc::Vector32 => {
|
VectorFunc::Vector32 => {
|
||||||
let result = vector32(&state.registers[*start_reg..*start_reg + arg_count])?;
|
let result = vector32(values)?;
|
||||||
state.registers[*dest] = Register::Value(result);
|
state.registers[*dest] = Register::Value(result);
|
||||||
}
|
}
|
||||||
VectorFunc::Vector32Sparse => {
|
VectorFunc::Vector32Sparse => {
|
||||||
let result = vector32_sparse(&state.registers[*start_reg..*start_reg + arg_count])?;
|
let result = vector32_sparse(values)?;
|
||||||
state.registers[*dest] = Register::Value(result);
|
state.registers[*dest] = Register::Value(result);
|
||||||
}
|
}
|
||||||
VectorFunc::Vector64 => {
|
VectorFunc::Vector64 => {
|
||||||
let result = vector64(&state.registers[*start_reg..*start_reg + arg_count])?;
|
let result = vector64(values)?;
|
||||||
state.registers[*dest] = Register::Value(result);
|
state.registers[*dest] = Register::Value(result);
|
||||||
}
|
}
|
||||||
VectorFunc::VectorExtract => {
|
VectorFunc::VectorExtract => {
|
||||||
let result = vector_extract(&state.registers[*start_reg..*start_reg + arg_count])?;
|
let result = vector_extract(values)?;
|
||||||
state.registers[*dest] = Register::Value(result);
|
state.registers[*dest] = Register::Value(result);
|
||||||
}
|
}
|
||||||
VectorFunc::VectorDistanceCos => {
|
VectorFunc::VectorDistanceCos => {
|
||||||
let result =
|
let result = vector_distance_cos(values)?;
|
||||||
vector_distance_cos(&state.registers[*start_reg..*start_reg + arg_count])?;
|
|
||||||
state.registers[*dest] = Register::Value(result);
|
state.registers[*dest] = Register::Value(result);
|
||||||
}
|
}
|
||||||
VectorFunc::VectorDistanceL2 => {
|
VectorFunc::VectorDistanceL2 => {
|
||||||
let result =
|
let result = vector_distance_l2(values)?;
|
||||||
vector_distance_l2(&state.registers[*start_reg..*start_reg + arg_count])?;
|
|
||||||
state.registers[*dest] = Register::Value(result);
|
state.registers[*dest] = Register::Value(result);
|
||||||
}
|
}
|
||||||
VectorFunc::VectorDistanceJaccard => {
|
VectorFunc::VectorDistanceJaccard => {
|
||||||
let result =
|
let result = vector_distance_jaccard(values)?;
|
||||||
vector_distance_jaccard(&state.registers[*start_reg..*start_reg + arg_count])?;
|
|
||||||
state.registers[*dest] = Register::Value(result);
|
state.registers[*dest] = Register::Value(result);
|
||||||
}
|
}
|
||||||
VectorFunc::VectorConcat => {
|
VectorFunc::VectorConcat => {
|
||||||
let result = vector_concat(&state.registers[*start_reg..*start_reg + arg_count])?;
|
let result = vector_concat(values)?;
|
||||||
state.registers[*dest] = Register::Value(result);
|
state.registers[*dest] = Register::Value(result);
|
||||||
}
|
}
|
||||||
VectorFunc::VectorSlice => {
|
VectorFunc::VectorSlice => {
|
||||||
let result = vector_slice(&state.registers[*start_reg..*start_reg + arg_count])?;
|
let result = vector_slice(values)?;
|
||||||
state.registers[*dest] = Register::Value(result)
|
state.registers[*dest] = Register::Value(result)
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
|
}
|
||||||
crate::function::Func::External(f) => match f.func {
|
crate::function::Func::External(f) => match f.func {
|
||||||
ExtFunc::Scalar(f) => {
|
ExtFunc::Scalar(f) => {
|
||||||
if arg_count == 0 {
|
if arg_count == 0 {
|
||||||
|
|||||||
@@ -1,21 +1,26 @@
|
|||||||
|
use crate::types::AsValueRef;
|
||||||
use crate::types::Value;
|
use crate::types::Value;
|
||||||
use crate::types::ValueType;
|
use crate::types::ValueType;
|
||||||
use crate::vdbe::Register;
|
|
||||||
use crate::LimboError;
|
use crate::LimboError;
|
||||||
use crate::Result;
|
use crate::Result;
|
||||||
|
use crate::ValueRef;
|
||||||
|
|
||||||
pub mod operations;
|
pub mod operations;
|
||||||
pub mod vector_types;
|
pub mod vector_types;
|
||||||
use vector_types::*;
|
use vector_types::*;
|
||||||
|
|
||||||
pub fn parse_vector(value: &Register, type_hint: Option<VectorType>) -> Result<Vector> {
|
pub fn parse_vector<'a>(
|
||||||
match value.get_value().value_type() {
|
value: &'a (impl AsValueRef + 'a),
|
||||||
|
type_hint: Option<VectorType>,
|
||||||
|
) -> Result<Vector<'a>> {
|
||||||
|
let value = value.as_value_ref();
|
||||||
|
match value.value_type() {
|
||||||
ValueType::Text => operations::text::vector_from_text(
|
ValueType::Text => operations::text::vector_from_text(
|
||||||
type_hint.unwrap_or(VectorType::Float32Dense),
|
type_hint.unwrap_or(VectorType::Float32Dense),
|
||||||
value.get_value().to_text().expect("value must be text"),
|
value.to_text().expect("value must be text"),
|
||||||
),
|
),
|
||||||
ValueType::Blob => {
|
ValueType::Blob => {
|
||||||
let Some(blob) = value.get_value().to_blob() else {
|
let Some(blob) = value.to_blob() else {
|
||||||
return Err(LimboError::ConversionError(
|
return Err(LimboError::ConversionError(
|
||||||
"Invalid vector value".to_string(),
|
"Invalid vector value".to_string(),
|
||||||
));
|
));
|
||||||
@@ -28,48 +33,77 @@ pub fn parse_vector(value: &Register, type_hint: Option<VectorType>) -> Result<V
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn vector32(args: &[Register]) -> Result<Value> {
|
pub fn vector32<I, E, V>(args: I) -> Result<Value>
|
||||||
|
where
|
||||||
|
V: AsValueRef,
|
||||||
|
E: ExactSizeIterator<Item = V>,
|
||||||
|
I: IntoIterator<IntoIter = E, Item = V>,
|
||||||
|
{
|
||||||
|
let mut args = args.into_iter();
|
||||||
if args.len() != 1 {
|
if args.len() != 1 {
|
||||||
return Err(LimboError::ConversionError(
|
return Err(LimboError::ConversionError(
|
||||||
"vector32 requires exactly one argument".to_string(),
|
"vector32 requires exactly one argument".to_string(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
let vector = parse_vector(&args[0], Some(VectorType::Float32Dense))?;
|
let value = args.next().unwrap();
|
||||||
|
let vector = parse_vector(&value, Some(VectorType::Float32Dense))?;
|
||||||
let vector = operations::convert::vector_convert(vector, VectorType::Float32Dense)?;
|
let vector = operations::convert::vector_convert(vector, VectorType::Float32Dense)?;
|
||||||
Ok(operations::serialize::vector_serialize(vector))
|
Ok(operations::serialize::vector_serialize(vector))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn vector32_sparse(args: &[Register]) -> Result<Value> {
|
pub fn vector32_sparse<I, E, V>(args: I) -> Result<Value>
|
||||||
|
where
|
||||||
|
V: AsValueRef,
|
||||||
|
E: ExactSizeIterator<Item = V>,
|
||||||
|
I: IntoIterator<IntoIter = E, Item = V>,
|
||||||
|
{
|
||||||
|
let mut args = args.into_iter();
|
||||||
if args.len() != 1 {
|
if args.len() != 1 {
|
||||||
return Err(LimboError::ConversionError(
|
return Err(LimboError::ConversionError(
|
||||||
"vector32_sparse requires exactly one argument".to_string(),
|
"vector32_sparse requires exactly one argument".to_string(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
let vector = parse_vector(&args[0], Some(VectorType::Float32Sparse))?;
|
let value = args.next().unwrap();
|
||||||
|
let vector = parse_vector(&value, Some(VectorType::Float32Sparse))?;
|
||||||
let vector = operations::convert::vector_convert(vector, VectorType::Float32Sparse)?;
|
let vector = operations::convert::vector_convert(vector, VectorType::Float32Sparse)?;
|
||||||
Ok(operations::serialize::vector_serialize(vector))
|
Ok(operations::serialize::vector_serialize(vector))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn vector64(args: &[Register]) -> Result<Value> {
|
pub fn vector64<I, E, V>(args: I) -> Result<Value>
|
||||||
|
where
|
||||||
|
V: AsValueRef,
|
||||||
|
E: ExactSizeIterator<Item = V>,
|
||||||
|
I: IntoIterator<IntoIter = E, Item = V>,
|
||||||
|
{
|
||||||
|
let mut args = args.into_iter();
|
||||||
if args.len() != 1 {
|
if args.len() != 1 {
|
||||||
return Err(LimboError::ConversionError(
|
return Err(LimboError::ConversionError(
|
||||||
"vector64 requires exactly one argument".to_string(),
|
"vector64 requires exactly one argument".to_string(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
let vector = parse_vector(&args[0], Some(VectorType::Float64Dense))?;
|
let value = args.next().unwrap();
|
||||||
|
let vector = parse_vector(&value, Some(VectorType::Float64Dense))?;
|
||||||
let vector = operations::convert::vector_convert(vector, VectorType::Float64Dense)?;
|
let vector = operations::convert::vector_convert(vector, VectorType::Float64Dense)?;
|
||||||
Ok(operations::serialize::vector_serialize(vector))
|
Ok(operations::serialize::vector_serialize(vector))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn vector_extract(args: &[Register]) -> Result<Value> {
|
pub fn vector_extract<I, E, V>(args: I) -> Result<Value>
|
||||||
|
where
|
||||||
|
V: AsValueRef,
|
||||||
|
E: ExactSizeIterator<Item = V>,
|
||||||
|
I: IntoIterator<IntoIter = E, Item = V>,
|
||||||
|
{
|
||||||
|
let mut args = args.into_iter();
|
||||||
if args.len() != 1 {
|
if args.len() != 1 {
|
||||||
return Err(LimboError::ConversionError(
|
return Err(LimboError::ConversionError(
|
||||||
"vector_extract requires exactly one argument".to_string(),
|
"vector_extract requires exactly one argument".to_string(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let blob = match &args[0].get_value() {
|
let value = args.next().unwrap();
|
||||||
Value::Blob(b) => b,
|
let value = value.as_value_ref();
|
||||||
|
let blob = match value {
|
||||||
|
ValueRef::Blob(b) => b,
|
||||||
_ => {
|
_ => {
|
||||||
return Err(LimboError::ConversionError(
|
return Err(LimboError::ConversionError(
|
||||||
"Expected blob value".to_string(),
|
"Expected blob value".to_string(),
|
||||||
@@ -81,78 +115,120 @@ pub fn vector_extract(args: &[Register]) -> Result<Value> {
|
|||||||
return Ok(Value::build_text("[]"));
|
return Ok(Value::build_text("[]"));
|
||||||
}
|
}
|
||||||
|
|
||||||
let vector = Vector::from_vec(blob.to_vec())?;
|
let vector = Vector::from_slice(blob)?;
|
||||||
Ok(Value::build_text(operations::text::vector_to_text(&vector)))
|
Ok(Value::build_text(operations::text::vector_to_text(&vector)))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn vector_distance_cos(args: &[Register]) -> Result<Value> {
|
pub fn vector_distance_cos<I, E, V>(args: I) -> Result<Value>
|
||||||
|
where
|
||||||
|
V: AsValueRef,
|
||||||
|
E: ExactSizeIterator<Item = V>,
|
||||||
|
I: IntoIterator<IntoIter = E, Item = V>,
|
||||||
|
{
|
||||||
|
let mut args = args.into_iter();
|
||||||
if args.len() != 2 {
|
if args.len() != 2 {
|
||||||
return Err(LimboError::ConversionError(
|
return Err(LimboError::ConversionError(
|
||||||
"vector_distance_cos requires exactly two arguments".to_string(),
|
"vector_distance_cos requires exactly two arguments".to_string(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let x = parse_vector(&args[0], None)?;
|
let value_0 = args.next().unwrap();
|
||||||
let y = parse_vector(&args[1], None)?;
|
let value_1 = args.next().unwrap();
|
||||||
|
let x = parse_vector(&value_0, None)?;
|
||||||
|
let y = parse_vector(&value_1, None)?;
|
||||||
let dist = operations::distance_cos::vector_distance_cos(&x, &y)?;
|
let dist = operations::distance_cos::vector_distance_cos(&x, &y)?;
|
||||||
Ok(Value::Float(dist))
|
Ok(Value::Float(dist))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn vector_distance_l2(args: &[Register]) -> Result<Value> {
|
pub fn vector_distance_l2<I, E, V>(args: I) -> Result<Value>
|
||||||
|
where
|
||||||
|
V: AsValueRef,
|
||||||
|
E: ExactSizeIterator<Item = V>,
|
||||||
|
I: IntoIterator<IntoIter = E, Item = V>,
|
||||||
|
{
|
||||||
|
let mut args = args.into_iter();
|
||||||
if args.len() != 2 {
|
if args.len() != 2 {
|
||||||
return Err(LimboError::ConversionError(
|
return Err(LimboError::ConversionError(
|
||||||
"distance_l2 requires exactly two arguments".to_string(),
|
"distance_l2 requires exactly two arguments".to_string(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let x = parse_vector(&args[0], None)?;
|
let value_0 = args.next().unwrap();
|
||||||
let y = parse_vector(&args[1], None)?;
|
let value_1 = args.next().unwrap();
|
||||||
|
let x = parse_vector(&value_0, None)?;
|
||||||
|
let y = parse_vector(&value_1, None)?;
|
||||||
let dist = operations::distance_l2::vector_distance_l2(&x, &y)?;
|
let dist = operations::distance_l2::vector_distance_l2(&x, &y)?;
|
||||||
Ok(Value::Float(dist))
|
Ok(Value::Float(dist))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn vector_distance_jaccard(args: &[Register]) -> Result<Value> {
|
pub fn vector_distance_jaccard<I, E, V>(args: I) -> Result<Value>
|
||||||
|
where
|
||||||
|
V: AsValueRef,
|
||||||
|
E: ExactSizeIterator<Item = V>,
|
||||||
|
I: IntoIterator<IntoIter = E, Item = V>,
|
||||||
|
{
|
||||||
|
let mut args = args.into_iter();
|
||||||
if args.len() != 2 {
|
if args.len() != 2 {
|
||||||
return Err(LimboError::ConversionError(
|
return Err(LimboError::ConversionError(
|
||||||
"distance_jaccard requires exactly two arguments".to_string(),
|
"distance_jaccard requires exactly two arguments".to_string(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let x = parse_vector(&args[0], None)?;
|
let value_0 = args.next().unwrap();
|
||||||
let y = parse_vector(&args[1], None)?;
|
let value_1 = args.next().unwrap();
|
||||||
|
let x = parse_vector(&value_0, None)?;
|
||||||
|
let y = parse_vector(&value_1, None)?;
|
||||||
let dist = operations::jaccard::vector_distance_jaccard(&x, &y)?;
|
let dist = operations::jaccard::vector_distance_jaccard(&x, &y)?;
|
||||||
Ok(Value::Float(dist))
|
Ok(Value::Float(dist))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn vector_concat(args: &[Register]) -> Result<Value> {
|
pub fn vector_concat<I, E, V>(args: I) -> Result<Value>
|
||||||
|
where
|
||||||
|
V: AsValueRef,
|
||||||
|
E: ExactSizeIterator<Item = V>,
|
||||||
|
I: IntoIterator<IntoIter = E, Item = V>,
|
||||||
|
{
|
||||||
|
let mut args = args.into_iter();
|
||||||
if args.len() != 2 {
|
if args.len() != 2 {
|
||||||
return Err(LimboError::InvalidArgument(
|
return Err(LimboError::InvalidArgument(
|
||||||
"concat requires exactly two arguments".into(),
|
"concat requires exactly two arguments".into(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let x = parse_vector(&args[0], None)?;
|
let value_0 = args.next().unwrap();
|
||||||
let y = parse_vector(&args[1], None)?;
|
let value_1 = args.next().unwrap();
|
||||||
|
let x = parse_vector(&value_0, None)?;
|
||||||
|
let y = parse_vector(&value_1, None)?;
|
||||||
let vector = operations::concat::vector_concat(&x, &y)?;
|
let vector = operations::concat::vector_concat(&x, &y)?;
|
||||||
Ok(operations::serialize::vector_serialize(vector))
|
Ok(operations::serialize::vector_serialize(vector))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn vector_slice(args: &[Register]) -> Result<Value> {
|
pub fn vector_slice<I, E, V>(args: I) -> Result<Value>
|
||||||
|
where
|
||||||
|
V: AsValueRef,
|
||||||
|
E: ExactSizeIterator<Item = V>,
|
||||||
|
I: IntoIterator<IntoIter = E, Item = V>,
|
||||||
|
{
|
||||||
|
let mut args = args.into_iter();
|
||||||
if args.len() != 3 {
|
if args.len() != 3 {
|
||||||
return Err(LimboError::InvalidArgument(
|
return Err(LimboError::InvalidArgument(
|
||||||
"vector_slice requires exactly three arguments".into(),
|
"vector_slice requires exactly three arguments".into(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
let value_0 = args.next().unwrap();
|
||||||
|
let value_1 = args.next().unwrap();
|
||||||
|
let value_1 = value_1.as_value_ref();
|
||||||
|
|
||||||
let vector = parse_vector(&args[0], None)?;
|
let value_2 = args.next().unwrap();
|
||||||
|
let value_2 = value_2.as_value_ref();
|
||||||
|
|
||||||
let start_index = args[1]
|
let vector = parse_vector(&value_0, None)?;
|
||||||
.get_value()
|
|
||||||
|
let start_index = value_1
|
||||||
.as_int()
|
.as_int()
|
||||||
.ok_or_else(|| LimboError::InvalidArgument("start index must be an integer".into()))?;
|
.ok_or_else(|| LimboError::InvalidArgument("start index must be an integer".into()))?;
|
||||||
|
|
||||||
let end_index = args[2]
|
let end_index = value_2
|
||||||
.get_value()
|
|
||||||
.as_int()
|
.as_int()
|
||||||
.ok_or_else(|| LimboError::InvalidArgument("end_index must be an integer".into()))?;
|
.ok_or_else(|| LimboError::InvalidArgument("end_index must be an integer".into()))?;
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user