From 482e93bfd0b99b225a07d110aba98e2138e7e7ba Mon Sep 17 00:00:00 2001 From: Sachin Singh Date: Fri, 11 Apr 2025 05:54:23 +0530 Subject: [PATCH] feat: add likelihood scalar function --- core/function.rs | 3 +++ core/translate/expr.rs | 47 +++++++++++++++++++++++++++++++++++++++ core/vdbe/execute.rs | 50 +++++++++++++++++++++++++++++++++++++++++- 3 files changed, 99 insertions(+), 1 deletion(-) diff --git a/core/function.rs b/core/function.rs index 4c235cca5..904ca7b93 100644 --- a/core/function.rs +++ b/core/function.rs @@ -293,6 +293,7 @@ pub enum ScalarFunc { StrfTime, Printf, Likely, + Likelihood, } impl Display for ScalarFunc { @@ -348,6 +349,7 @@ impl Display for ScalarFunc { Self::StrfTime => "strftime".to_string(), Self::Printf => "printf".to_string(), Self::Likely => "likely".to_string(), + Self::Likelihood => "likelihood".to_string(), }; write!(f, "{}", str) } @@ -599,6 +601,7 @@ impl Func { "sqlite_source_id" => Ok(Self::Scalar(ScalarFunc::SqliteSourceId)), "replace" => Ok(Self::Scalar(ScalarFunc::Replace)), "likely" => Ok(Self::Scalar(ScalarFunc::Likely)), + "likelihood" => Ok(Self::Scalar(ScalarFunc::Likelihood)), #[cfg(feature = "json")] "json" => Ok(Self::Json(JsonFunc::Json)), #[cfg(feature = "json")] diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 7bb0dc228..3827dac63 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -1598,6 +1598,53 @@ pub fn translate_expr( }); Ok(target_register) } + ScalarFunc::Likelihood => { + let args = if let Some(args) = args { + if args.len() != 2 { + crate::bail_parse_error!( + "likelihood function must have exactly 2 arguments", + ); + } + args + } else { + crate::bail_parse_error!("likelihood function with no arguments",); + }; + + if let ast::Expr::Literal(ast::Literal::Numeric(ref value)) = args[1] { + if let Ok(probability) = value.parse::() { + if probability < 0.0 || probability > 1.0 { + crate::bail_parse_error!( + "likelihood second argument must be between 0.0 and 1.0", + ); + } + } else { + crate::bail_parse_error!( + "likelihood second argument must be a floating point constant", + ); + } + } else { + crate::bail_parse_error!( + "likelihood second argument must be a numeric literal", + ); + } + + let start_reg = program.alloc_register(); + translate_and_mark( + program, + referenced_tables, + &args[0], + start_reg, + resolver, + )?; + + program.emit_insn(Insn::Copy { + src_reg: start_reg, + dst_reg: target_register, + amount: 0, + }); + + Ok(target_register) + } } } Func::Math(math_func) => match math_func.arity() { diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index 6827ba83b..0dafeada2 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -3505,6 +3505,14 @@ pub fn op_function( let result = exec_likely(value.get_owned_value()); state.registers[*dest] = Register::OwnedValue(result); } + ScalarFunc::Likelihood => { + assert_eq!(arg_count, 2); + let value = &state.registers[*start_reg]; + let probability = &state.registers[*start_reg + 1]; + let result = + exec_likelihood(value.get_owned_value(), probability.get_owned_value()); + state.registers[*dest] = Register::OwnedValue(result); + } }, crate::function::Func::Vector(vector_func) => match vector_func { VectorFunc::Vector => { @@ -5365,6 +5373,10 @@ fn exec_likely(reg: &OwnedValue) -> OwnedValue { reg.clone() } +fn exec_likelihood(reg: &OwnedValue, _probability: &OwnedValue) -> OwnedValue { + reg.clone() +} + pub fn exec_add(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue { let result = match (lhs, rhs) { (OwnedValue::Integer(lhs), OwnedValue::Integer(rhs)) => { @@ -6248,7 +6260,7 @@ mod tests { } use crate::vdbe::{ - execute::{exec_likely, exec_replace}, + execute::{exec_likelihood, exec_likely, exec_replace}, Bitfield, Register, }; @@ -7165,6 +7177,42 @@ mod tests { assert_eq!(exec_likely(&input), expected); } + #[test] + fn test_likelihood() { + let value = OwnedValue::build_text("limbo"); + let prob = OwnedValue::Float(0.5); + assert_eq!(exec_likelihood(&value, &prob), value); + + let value = OwnedValue::build_text("database"); + let prob = OwnedValue::Float(0.9375); + assert_eq!(exec_likelihood(&value, &prob), value); + + let value = OwnedValue::Integer(100); + let prob = OwnedValue::Float(0.0625); + assert_eq!(exec_likelihood(&value, &prob), value); + + let value = OwnedValue::Float(12.34); + let prob = OwnedValue::Float(0.5); + assert_eq!(exec_likelihood(&value, &prob), value); + + let value = OwnedValue::Null; + let prob = OwnedValue::Float(0.5); + assert_eq!(exec_likelihood(&value, &prob), value); + + let value = OwnedValue::Blob(vec![1, 2, 3, 4]); + let prob = OwnedValue::Float(0.5); + assert_eq!(exec_likelihood(&value, &prob), value); + + let prob = OwnedValue::Integer(1); + assert_eq!(exec_likelihood(&value, &prob), value); + + let prob = OwnedValue::build_text("0.5"); + assert_eq!(exec_likelihood(&value, &prob), value); + + let prob = OwnedValue::Null; + assert_eq!(exec_likelihood(&value, &prob), value); + } + #[test] fn test_bitfield() { let mut bitfield = Bitfield::<4>::new();