diff --git a/COMPAT.md b/COMPAT.md index 2f9a954f7..3d07558c8 100644 --- a/COMPAT.md +++ b/COMPAT.md @@ -226,7 +226,7 @@ Feature support of [sqlite expr syntax](https://www.sqlite.org/lang_expr.html). | length(X) | Yes | | | like(X,Y) | Yes | | | like(X,Y,Z) | Yes | | -| likelihood(X,Y) | No | | +| likelihood(X,Y) | Yes | | | likely(X) | Yes | | | load_extension(X) | Yes | sqlite3 extensions not yet supported | | load_extension(X,Y) | No | | diff --git a/core/function.rs b/core/function.rs index 41613c8c8..da246e719 100644 --- a/core/function.rs +++ b/core/function.rs @@ -294,6 +294,7 @@ pub enum ScalarFunc { Printf, Likely, TimeDiff, + Likelihood, } impl Display for ScalarFunc { @@ -350,6 +351,7 @@ impl Display for ScalarFunc { Self::Printf => "printf".to_string(), Self::Likely => "likely".to_string(), Self::TimeDiff => "timediff".to_string(), + Self::Likelihood => "likelihood".to_string(), }; write!(f, "{}", str) } @@ -607,6 +609,7 @@ impl Func { "sqlite_source_id" => Ok(Self::Scalar(ScalarFunc::SqliteSourceId)), "replace" => Ok(Self::Scalar(ScalarFunc::Replace)), "likely" => Ok(Self::Scalar(ScalarFunc::Likely)), + "likelihood" => Ok(Self::Scalar(ScalarFunc::Likelihood)), #[cfg(feature = "json")] "json" => Ok(Self::Json(JsonFunc::Json)), #[cfg(feature = "json")] diff --git a/core/translate/expr.rs b/core/translate/expr.rs index fbd7680f4..958005259 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -1625,6 +1625,58 @@ pub fn translate_expr( }); Ok(target_register) } + ScalarFunc::Likelihood => { + let args = if let Some(args) = args { + if args.len() != 2 { + crate::bail_parse_error!( + "likelihood() function must have exactly 2 arguments", + ); + } + args + } else { + crate::bail_parse_error!("likelihood() function with no arguments",); + }; + + if let ast::Expr::Literal(ast::Literal::Numeric(ref value)) = args[1] { + if let Ok(probability) = value.parse::() { + if !(0.0..=1.0).contains(&probability) { + crate::bail_parse_error!( + "second argument of likelihood() must be between 0.0 and 1.0", + ); + } + if !value.contains('.') { + crate::bail_parse_error!( + "second argument of likelihood() must be a floating point number with decimal point", + ); + } + } else { + crate::bail_parse_error!( + "second argument of likelihood() must be a floating point constant", + ); + } + } else { + crate::bail_parse_error!( + "second argument of likelihood() must be a numeric literal", + ); + } + + let start_reg = program.alloc_register(); + translate_and_mark( + program, + referenced_tables, + &args[0], + start_reg, + resolver, + )?; + + program.emit_insn(Insn::Copy { + src_reg: start_reg, + dst_reg: target_register, + amount: 0, + }); + + Ok(target_register) + } } } Func::Math(math_func) => match math_func.arity() { diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index e5194536d..4d2a96d10 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -3520,6 +3520,14 @@ pub fn op_function( let result = exec_likely(value.get_owned_value()); state.registers[*dest] = Register::OwnedValue(result); } + ScalarFunc::Likelihood => { + assert_eq!(arg_count, 2); + let value = &state.registers[*start_reg]; + let probability = &state.registers[*start_reg + 1]; + let result = + exec_likelihood(value.get_owned_value(), probability.get_owned_value()); + state.registers[*dest] = Register::OwnedValue(result); + } }, crate::function::Func::Vector(vector_func) => match vector_func { VectorFunc::Vector => { @@ -5380,6 +5388,10 @@ fn exec_likely(reg: &OwnedValue) -> OwnedValue { reg.clone() } +fn exec_likelihood(reg: &OwnedValue, _probability: &OwnedValue) -> OwnedValue { + reg.clone() +} + pub fn exec_add(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue { let result = match (lhs, rhs) { (OwnedValue::Integer(lhs), OwnedValue::Integer(rhs)) => { @@ -6263,7 +6275,7 @@ mod tests { } use crate::vdbe::{ - execute::{exec_likely, exec_replace}, + execute::{exec_likelihood, exec_likely, exec_replace}, Bitfield, Register, }; @@ -7180,6 +7192,39 @@ mod tests { assert_eq!(exec_likely(&input), expected); } + #[test] + fn test_likelihood() { + let value = OwnedValue::build_text("limbo"); + let prob = OwnedValue::Float(0.5); + assert_eq!(exec_likelihood(&value, &prob), value); + + let value = OwnedValue::build_text("database"); + let prob = OwnedValue::Float(0.9375); + assert_eq!(exec_likelihood(&value, &prob), value); + + let value = OwnedValue::Integer(100); + let prob = OwnedValue::Float(1.0); + assert_eq!(exec_likelihood(&value, &prob), value); + + let value = OwnedValue::Float(12.34); + let prob = OwnedValue::Float(0.5); + assert_eq!(exec_likelihood(&value, &prob), value); + + let value = OwnedValue::Null; + let prob = OwnedValue::Float(0.5); + assert_eq!(exec_likelihood(&value, &prob), value); + + let value = OwnedValue::Blob(vec![1, 2, 3, 4]); + let prob = OwnedValue::Float(0.5); + assert_eq!(exec_likelihood(&value, &prob), value); + + let prob = OwnedValue::build_text("0.5"); + assert_eq!(exec_likelihood(&value, &prob), value); + + let prob = OwnedValue::Null; + assert_eq!(exec_likelihood(&value, &prob), value); + } + #[test] fn test_bitfield() { let mut bitfield = Bitfield::<4>::new(); diff --git a/testing/scalar-functions.test b/testing/scalar-functions.test index 09e99a8f3..807c4971d 100755 --- a/testing/scalar-functions.test +++ b/testing/scalar-functions.test @@ -211,6 +211,38 @@ do_execsql_test likely-null { select likely(NULL) } {} +do_execsql_test likelihood-string { + SELECT likelihood('limbo', 0.5); +} {limbo} + +do_execsql_test likelihood-string-high-probability { + SELECT likelihood('database', 0.9375); +} {database} + +do_execsql_test likelihood-integer { + SELECT likelihood(100, 0.0625); +} {100} + +do_execsql_test likelihood-integer-probability-1 { + SELECT likelihood(42, 1.0); +} {42} + +do_execsql_test likelihood-decimal { + SELECT likelihood(12.34, 0.5); +} {12.34} + +do_execsql_test likelihood-null { + SELECT likelihood(NULL, 0.5); +} {} + +do_execsql_test likelihood-blob { + SELECT hex(likelihood(x'01020304', 0.5)); +} {01020304} + +do_execsql_test likelihood-zero-probability { + SELECT likelihood(999, 0.0); +} {999} + do_execsql_test unhex-str-ab { SELECT unhex('6162'); } {ab}