From 482e93bfd0b99b225a07d110aba98e2138e7e7ba Mon Sep 17 00:00:00 2001 From: Sachin Singh Date: Fri, 11 Apr 2025 05:54:23 +0530 Subject: [PATCH 1/3] feat: add likelihood scalar function --- core/function.rs | 3 +++ core/translate/expr.rs | 47 +++++++++++++++++++++++++++++++++++++++ core/vdbe/execute.rs | 50 +++++++++++++++++++++++++++++++++++++++++- 3 files changed, 99 insertions(+), 1 deletion(-) diff --git a/core/function.rs b/core/function.rs index 4c235cca5..904ca7b93 100644 --- a/core/function.rs +++ b/core/function.rs @@ -293,6 +293,7 @@ pub enum ScalarFunc { StrfTime, Printf, Likely, + Likelihood, } impl Display for ScalarFunc { @@ -348,6 +349,7 @@ impl Display for ScalarFunc { Self::StrfTime => "strftime".to_string(), Self::Printf => "printf".to_string(), Self::Likely => "likely".to_string(), + Self::Likelihood => "likelihood".to_string(), }; write!(f, "{}", str) } @@ -599,6 +601,7 @@ impl Func { "sqlite_source_id" => Ok(Self::Scalar(ScalarFunc::SqliteSourceId)), "replace" => Ok(Self::Scalar(ScalarFunc::Replace)), "likely" => Ok(Self::Scalar(ScalarFunc::Likely)), + "likelihood" => Ok(Self::Scalar(ScalarFunc::Likelihood)), #[cfg(feature = "json")] "json" => Ok(Self::Json(JsonFunc::Json)), #[cfg(feature = "json")] diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 7bb0dc228..3827dac63 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -1598,6 +1598,53 @@ pub fn translate_expr( }); Ok(target_register) } + ScalarFunc::Likelihood => { + let args = if let Some(args) = args { + if args.len() != 2 { + crate::bail_parse_error!( + "likelihood function must have exactly 2 arguments", + ); + } + args + } else { + crate::bail_parse_error!("likelihood function with no arguments",); + }; + + if let ast::Expr::Literal(ast::Literal::Numeric(ref value)) = args[1] { + if let Ok(probability) = value.parse::() { + if probability < 0.0 || probability > 1.0 { + crate::bail_parse_error!( + "likelihood second argument must be between 0.0 and 1.0", + ); + } + } else { + crate::bail_parse_error!( + "likelihood second argument must be a floating point constant", + ); + } + } else { + crate::bail_parse_error!( + "likelihood second argument must be a numeric literal", + ); + } + + let start_reg = program.alloc_register(); + translate_and_mark( + program, + referenced_tables, + &args[0], + start_reg, + resolver, + )?; + + program.emit_insn(Insn::Copy { + src_reg: start_reg, + dst_reg: target_register, + amount: 0, + }); + + Ok(target_register) + } } } Func::Math(math_func) => match math_func.arity() { diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index 6827ba83b..0dafeada2 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -3505,6 +3505,14 @@ pub fn op_function( let result = exec_likely(value.get_owned_value()); state.registers[*dest] = Register::OwnedValue(result); } + ScalarFunc::Likelihood => { + assert_eq!(arg_count, 2); + let value = &state.registers[*start_reg]; + let probability = &state.registers[*start_reg + 1]; + let result = + exec_likelihood(value.get_owned_value(), probability.get_owned_value()); + state.registers[*dest] = Register::OwnedValue(result); + } }, crate::function::Func::Vector(vector_func) => match vector_func { VectorFunc::Vector => { @@ -5365,6 +5373,10 @@ fn exec_likely(reg: &OwnedValue) -> OwnedValue { reg.clone() } +fn exec_likelihood(reg: &OwnedValue, _probability: &OwnedValue) -> OwnedValue { + reg.clone() +} + pub fn exec_add(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue { let result = match (lhs, rhs) { (OwnedValue::Integer(lhs), OwnedValue::Integer(rhs)) => { @@ -6248,7 +6260,7 @@ mod tests { } use crate::vdbe::{ - execute::{exec_likely, exec_replace}, + execute::{exec_likelihood, exec_likely, exec_replace}, Bitfield, Register, }; @@ -7165,6 +7177,42 @@ mod tests { assert_eq!(exec_likely(&input), expected); } + #[test] + fn test_likelihood() { + let value = OwnedValue::build_text("limbo"); + let prob = OwnedValue::Float(0.5); + assert_eq!(exec_likelihood(&value, &prob), value); + + let value = OwnedValue::build_text("database"); + let prob = OwnedValue::Float(0.9375); + assert_eq!(exec_likelihood(&value, &prob), value); + + let value = OwnedValue::Integer(100); + let prob = OwnedValue::Float(0.0625); + assert_eq!(exec_likelihood(&value, &prob), value); + + let value = OwnedValue::Float(12.34); + let prob = OwnedValue::Float(0.5); + assert_eq!(exec_likelihood(&value, &prob), value); + + let value = OwnedValue::Null; + let prob = OwnedValue::Float(0.5); + assert_eq!(exec_likelihood(&value, &prob), value); + + let value = OwnedValue::Blob(vec![1, 2, 3, 4]); + let prob = OwnedValue::Float(0.5); + assert_eq!(exec_likelihood(&value, &prob), value); + + let prob = OwnedValue::Integer(1); + assert_eq!(exec_likelihood(&value, &prob), value); + + let prob = OwnedValue::build_text("0.5"); + assert_eq!(exec_likelihood(&value, &prob), value); + + let prob = OwnedValue::Null; + assert_eq!(exec_likelihood(&value, &prob), value); + } + #[test] fn test_bitfield() { let mut bitfield = Bitfield::<4>::new(); From 5ffdd42f12a22353e45c91100b2ced8269a9373a Mon Sep 17 00:00:00 2001 From: Sachin Singh Date: Fri, 11 Apr 2025 06:02:07 +0530 Subject: [PATCH 2/3] Additional tests --- COMPAT.md | 2 +- testing/scalar-functions.test | 36 +++++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/COMPAT.md b/COMPAT.md index e85a47725..d1ce96b96 100644 --- a/COMPAT.md +++ b/COMPAT.md @@ -226,7 +226,7 @@ Feature support of [sqlite expr syntax](https://www.sqlite.org/lang_expr.html). | length(X) | Yes | | | like(X,Y) | Yes | | | like(X,Y,Z) | Yes | | -| likelihood(X,Y) | No | | +| likelihood(X,Y) | Yes | | | likely(X) | Yes | | | load_extension(X) | Yes | sqlite3 extensions not yet supported | | load_extension(X,Y) | No | | diff --git a/testing/scalar-functions.test b/testing/scalar-functions.test index 09e99a8f3..a63e80467 100755 --- a/testing/scalar-functions.test +++ b/testing/scalar-functions.test @@ -211,6 +211,42 @@ do_execsql_test likely-null { select likely(NULL) } {} +do_execsql_test likelihood-string { + SELECT likelihood('limbo', 0.5); +} {limbo} + +do_execsql_test likelihood-string-high-probability { + SELECT likelihood('database', 0.9375); +} {database} + +do_execsql_test likelihood-integer { + SELECT likelihood(100, 0.0625); +} {100} + +do_execsql_test likelihood-integer-probability-1 { + SELECT likelihood(42, 1); +} {42} + +do_execsql_test likelihood-decimal { + SELECT likelihood(12.34, 0.5); +} {12.34} + +do_execsql_test likelihood-null { + SELECT likelihood(NULL, 0.5); +} {} + +do_execsql_test likelihood-blob { + SELECT hex(likelihood(x'01020304', 0.5)); +} {01020304} + +do_execsql_test likelihood-zero-probability { + SELECT likelihood(999, 0); +} {999} + +do_execsql_test likelihood-extreme-probability { + SELECT likelihood(999, 1); +} {999} + do_execsql_test unhex-str-ab { SELECT unhex('6162'); } {ab} From 01fa02364d19b558fc83fdb6695de0fb0d8a5f63 Mon Sep 17 00:00:00 2001 From: Sachin Singh Date: Fri, 11 Apr 2025 08:34:29 +0530 Subject: [PATCH 3/3] correctly handle edge cases --- core/translate/expr.rs | 17 +++++++++++------ core/vdbe/execute.rs | 5 +---- testing/scalar-functions.test | 8 ++------ 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 3827dac63..02be1db8d 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -1602,29 +1602,34 @@ pub fn translate_expr( let args = if let Some(args) = args { if args.len() != 2 { crate::bail_parse_error!( - "likelihood function must have exactly 2 arguments", + "likelihood() function must have exactly 2 arguments", ); } args } else { - crate::bail_parse_error!("likelihood function with no arguments",); + crate::bail_parse_error!("likelihood() function with no arguments",); }; if let ast::Expr::Literal(ast::Literal::Numeric(ref value)) = args[1] { if let Ok(probability) = value.parse::() { - if probability < 0.0 || probability > 1.0 { + if !(0.0..=1.0).contains(&probability) { crate::bail_parse_error!( - "likelihood second argument must be between 0.0 and 1.0", + "second argument of likelihood() must be between 0.0 and 1.0", + ); + } + if !value.contains('.') { + crate::bail_parse_error!( + "second argument of likelihood() must be a floating point number with decimal point", ); } } else { crate::bail_parse_error!( - "likelihood second argument must be a floating point constant", + "second argument of likelihood() must be a floating point constant", ); } } else { crate::bail_parse_error!( - "likelihood second argument must be a numeric literal", + "second argument of likelihood() must be a numeric literal", ); } diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index 0dafeada2..8017e8c96 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -7188,7 +7188,7 @@ mod tests { assert_eq!(exec_likelihood(&value, &prob), value); let value = OwnedValue::Integer(100); - let prob = OwnedValue::Float(0.0625); + let prob = OwnedValue::Float(1.0); assert_eq!(exec_likelihood(&value, &prob), value); let value = OwnedValue::Float(12.34); @@ -7203,9 +7203,6 @@ mod tests { let prob = OwnedValue::Float(0.5); assert_eq!(exec_likelihood(&value, &prob), value); - let prob = OwnedValue::Integer(1); - assert_eq!(exec_likelihood(&value, &prob), value); - let prob = OwnedValue::build_text("0.5"); assert_eq!(exec_likelihood(&value, &prob), value); diff --git a/testing/scalar-functions.test b/testing/scalar-functions.test index a63e80467..807c4971d 100755 --- a/testing/scalar-functions.test +++ b/testing/scalar-functions.test @@ -224,7 +224,7 @@ do_execsql_test likelihood-integer { } {100} do_execsql_test likelihood-integer-probability-1 { - SELECT likelihood(42, 1); + SELECT likelihood(42, 1.0); } {42} do_execsql_test likelihood-decimal { @@ -240,11 +240,7 @@ do_execsql_test likelihood-blob { } {01020304} do_execsql_test likelihood-zero-probability { - SELECT likelihood(999, 0); -} {999} - -do_execsql_test likelihood-extreme-probability { - SELECT likelihood(999, 1); + SELECT likelihood(999, 0.0); } {999} do_execsql_test unhex-str-ab {