From dde00c3bc5bc193358592007ebd9498559e63d46 Mon Sep 17 00:00:00 2001 From: Brayan Jules Date: Sun, 21 Jul 2024 15:13:22 -0400 Subject: [PATCH 1/3] implementation of scalar functions min and max --- core/function.rs | 6 +++ core/translate/expr.rs | 54 ++++++++++++++++++++++++++ core/vdbe.rs | 72 ++++++++++++++++++++++++++++++++++- testing/scalar-functions.test | 25 +++++++++++- 4 files changed, 155 insertions(+), 2 deletions(-) diff --git a/core/function.rs b/core/function.rs index 0ec15ce28..44ab386a1 100644 --- a/core/function.rs +++ b/core/function.rs @@ -38,6 +38,8 @@ pub enum SingleRowFunc { Trim, Round, Length, + Min, + Max, } impl ToString for SingleRowFunc { @@ -52,6 +54,8 @@ impl ToString for SingleRowFunc { SingleRowFunc::Trim => "trim".to_string(), SingleRowFunc::Round => "round".to_string(), SingleRowFunc::Length => "length".to_string(), + SingleRowFunc::Min => "min_arr".to_string(), + SingleRowFunc::Max => "max_arr".to_string(), } } } @@ -84,6 +88,8 @@ impl FromStr for Func { "trim" => Ok(Func::SingleRow(SingleRowFunc::Trim)), "round" => Ok(Func::SingleRow(SingleRowFunc::Round)), "length" => Ok(Func::SingleRow(SingleRowFunc::Length)), + "min_arr" => Ok(Func::SingleRow(SingleRowFunc::Min)), + "max_arr" => Ok(Func::SingleRow(SingleRowFunc::Max)), _ => Err(()), } } diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 53816b24c..0506359a2 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -377,6 +377,60 @@ pub fn translate_expr( }); Ok(target_register) } + SingleRowFunc::Min => { + let args = if let Some(args) = args { + if args.len() < 1 { + anyhow::bail!( + "Parse error: min function with less than one argument" + ); + } + args + } else { + anyhow::bail!("Parse error: min function with no arguments"); + }; + for arg in args { + let reg = program.alloc_register(); + let _ = translate_expr(program, select, arg, reg)?; + match arg { + ast::Expr::Literal(_) => program.mark_last_insn_constant(), + _ => {} + } + } + + program.emit_insn(Insn::Function { + start_reg: target_register + 1, + dest: target_register, + func: SingleRowFunc::Min, + }); + Ok(target_register) + } + SingleRowFunc::Max => { + let args = if let Some(args) = args { + if args.len() < 1 { + anyhow::bail!( + "Parse error: max function with less than one argument" + ); + } + args + } else { + anyhow::bail!("Parse error: max function with no arguments"); + }; + for arg in args { + let reg = program.alloc_register(); + let _ = translate_expr(program, select, arg, reg)?; + match arg { + ast::Expr::Literal(_) => program.mark_last_insn_constant(), + _ => {} + } + } + + program.emit_insn(Insn::Function { + start_reg: target_register + 1, + dest: target_register, + func: SingleRowFunc::Max, + }); + Ok(target_register) + } } } None => { diff --git a/core/vdbe.rs b/core/vdbe.rs index 9f1b2d97e..68ddb5fff 100644 --- a/core/vdbe.rs +++ b/core/vdbe.rs @@ -1362,6 +1362,32 @@ impl Program { state.registers[*dest] = result; state.pc += 1; } + SingleRowFunc::Min => { + let start_reg = *start_reg; + let reg_values = state.registers[start_reg..state.registers.len()] + .iter() + .collect(); + let min_fn = |a, b| if a < b { a } else { b }; + if let Some(value) = exec_minmax(reg_values, min_fn) { + state.registers[*dest] = value; + } else { + state.registers[*dest] = OwnedValue::Null; + } + state.pc += 1; + } + SingleRowFunc::Max => { + let start_reg = *start_reg; + let reg_values = state.registers[start_reg..state.registers.len()] + .iter() + .collect(); + let max_fn = |a, b| if a > b { a } else { b }; + if let Some(value) = exec_minmax(reg_values, max_fn) { + state.registers[*dest] = value; + } else { + state.registers[*dest] = OwnedValue::Null; + } + state.pc += 1; + } }, } } @@ -1990,6 +2016,13 @@ fn exec_like(pattern: &str, text: &str) -> bool { re.is_match(text) } +fn exec_minmax<'a>( + regs: Vec<&'a OwnedValue>, + op: fn(&'a OwnedValue, &'a OwnedValue) -> &'a OwnedValue, +) -> Option { + regs.into_iter().reduce(|a, b| op(a, b)).cloned() +} + fn exec_round(reg: &OwnedValue, precision: Option) -> OwnedValue { let precision = match precision { Some(OwnedValue::Text(x)) => x.parse().unwrap_or(0.0), @@ -2045,7 +2078,7 @@ fn exec_if(reg: &OwnedValue, null_reg: &OwnedValue, not: bool) -> bool { #[cfg(test)] mod tests { use super::{ - exec_abs, exec_if, exec_length, exec_like, exec_lower, exec_random, exec_round, exec_trim, + exec_abs, exec_if, exec_length, exec_like, exec_lower, exec_minmax, exec_random, exec_round, exec_trim, exec_upper, OwnedValue, }; use std::rc::Rc; @@ -2070,6 +2103,43 @@ mod tests { assert_eq!(exec_length(&expected_blob), expected_len); } + #[test] + fn test_minmax() { + let min_fn = |a, b| if a < b { a } else { b }; + let max_fn = |a, b| if a > b { a } else { b }; + let input_int_vec = vec![&OwnedValue::Integer(-1), &OwnedValue::Integer(10)]; + assert_eq!( + exec_minmax(input_int_vec.clone(), min_fn), + Some(OwnedValue::Integer(-1)) + ); + assert_eq!( + exec_minmax(input_int_vec.clone(), max_fn), + Some(OwnedValue::Integer(10)) + ); + + let str1 = OwnedValue::Text(Rc::new(String::from("A"))); + let str2 = OwnedValue::Text(Rc::new(String::from("z"))); + let input_str_vec = vec![&str2, &str1]; + assert_eq!( + exec_minmax(input_str_vec.clone(), min_fn), + Some(OwnedValue::Text(Rc::new(String::from("A")))) + ); + assert_eq!( + exec_minmax(input_str_vec.clone(), max_fn), + Some(OwnedValue::Text(Rc::new(String::from("z")))) + ); + + let input_null_vec = vec![&OwnedValue::Null, &OwnedValue::Null]; + assert_eq!( + exec_minmax(input_null_vec.clone(), min_fn), + Some(OwnedValue::Null) + ); + assert_eq!( + exec_minmax(input_null_vec.clone(), max_fn), + Some(OwnedValue::Null) + ); + } + #[test] fn test_trim() { let input_str = OwnedValue::Text(Rc::new(String::from(" Bob and Alice "))); diff --git a/testing/scalar-functions.test b/testing/scalar-functions.test index a3b3d2114..2640c10ce 100644 --- a/testing/scalar-functions.test +++ b/testing/scalar-functions.test @@ -141,4 +141,27 @@ do_execsql_test length-null { do_execsql_test length-empty-text { SELECT length(''); -} {0} \ No newline at end of file +} {0} +do_execsql_test min-number { + select min_arr(-10,2,3) +} {-10} + +do_execsql_test min-str { + select min_arr('b','a','z') +} {a} + +do_execsql_test min-null { + select min_arr(null,null) +} {} + +do_execsql_test max-number { + select max_arr(-10,2,3) +} {3} + +do_execsql_test max-str { + select max_arr('b','a','z') +} {z} + +do_execsql_test max-null { + select max_arr(null,null) +} {} \ No newline at end of file From b81f7d9acd681009069f4e2297a0206b23b4d87a Mon Sep 17 00:00:00 2001 From: Brayan Jules Date: Sun, 21 Jul 2024 15:38:23 -0400 Subject: [PATCH 2/3] add cursor_hint to min and max scalar functions --- core/translate/expr.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 0506359a2..a5f380e6e 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -390,7 +390,7 @@ pub fn translate_expr( }; for arg in args { let reg = program.alloc_register(); - let _ = translate_expr(program, select, arg, reg)?; + let _ = translate_expr(program, select, arg, reg, cursor_hint)?; match arg { ast::Expr::Literal(_) => program.mark_last_insn_constant(), _ => {} @@ -417,7 +417,7 @@ pub fn translate_expr( }; for arg in args { let reg = program.alloc_register(); - let _ = translate_expr(program, select, arg, reg)?; + let _ = translate_expr(program, select, arg, reg, cursor_hint)?; match arg { ast::Expr::Literal(_) => program.mark_last_insn_constant(), _ => {} From c2270017371ac982e585ef23acc2c3689dabd6ab Mon Sep 17 00:00:00 2001 From: Brayan Jules Date: Mon, 22 Jul 2024 17:00:34 -0400 Subject: [PATCH 3/3] support handling functions with the same name but different parameters number --- core/function.rs | 22 ++++++++++------------ core/translate/expr.rs | 6 ++++-- testing/scalar-functions.test | 12 ++++++------ 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/core/function.rs b/core/function.rs index 44ab386a1..f3d7650b8 100644 --- a/core/function.rs +++ b/core/function.rs @@ -1,4 +1,3 @@ -use std::str::FromStr; #[derive(Debug, Clone, PartialEq)] pub enum AggFunc { @@ -54,8 +53,8 @@ impl ToString for SingleRowFunc { SingleRowFunc::Trim => "trim".to_string(), SingleRowFunc::Round => "round".to_string(), SingleRowFunc::Length => "length".to_string(), - SingleRowFunc::Min => "min_arr".to_string(), - SingleRowFunc::Max => "max_arr".to_string(), + SingleRowFunc::Min => "min".to_string(), + SingleRowFunc::Max => "max".to_string(), } } } @@ -66,16 +65,17 @@ pub enum Func { SingleRow(SingleRowFunc), } -impl FromStr for Func { - type Err = (); +impl Func{ - fn from_str(s: &str) -> Result { - match s { + pub fn resolve_function(name: &str, arg_count:usize) -> Result{ + match name { "avg" => Ok(Func::Agg(AggFunc::Avg)), "count" => Ok(Func::Agg(AggFunc::Count)), "group_concat" => Ok(Func::Agg(AggFunc::GroupConcat)), - "max" => Ok(Func::Agg(AggFunc::Max)), - "min" => Ok(Func::Agg(AggFunc::Min)), + "max" if arg_count == 0 || arg_count == 1 => Ok(Func::Agg(AggFunc::Max)), + "max" if arg_count > 1 => Ok(Func::SingleRow(SingleRowFunc::Max)), + "min" if arg_count == 0 || arg_count == 1 => Ok(Func::Agg(AggFunc::Min)), + "min" if arg_count > 1 => Ok(Func::SingleRow(SingleRowFunc::Min)), "string_agg" => Ok(Func::Agg(AggFunc::StringAgg)), "sum" => Ok(Func::Agg(AggFunc::Sum)), "total" => Ok(Func::Agg(AggFunc::Total)), @@ -88,9 +88,7 @@ impl FromStr for Func { "trim" => Ok(Func::SingleRow(SingleRowFunc::Trim)), "round" => Ok(Func::SingleRow(SingleRowFunc::Round)), "length" => Ok(Func::SingleRow(SingleRowFunc::Length)), - "min_arr" => Ok(Func::SingleRow(SingleRowFunc::Min)), - "max_arr" => Ok(Func::SingleRow(SingleRowFunc::Max)), _ => Err(()), } } -} +} \ No newline at end of file diff --git a/core/translate/expr.rs b/core/translate/expr.rs index a5f380e6e..41d70e9f7 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -225,7 +225,8 @@ pub fn translate_expr( args, filter_over: _, } => { - let func_type: Option = normalize_ident(name.0.as_str()).as_str().parse().ok(); + let args_count = if let Some(args) = args { args.len() } else { 0 }; + let func_type: Option = Func::resolve_function(normalize_ident(name.0.as_str()).as_str(),args_count).ok(); match func_type { Some(Func::Agg(_)) => { @@ -585,7 +586,8 @@ pub fn analyze_expr<'a>(expr: &'a Expr, column_info_out: &mut ColumnInfo<'a>) { args, filter_over: _, } => { - let func_type = match normalize_ident(name.0.as_str()).as_str().parse() { + let args_count = if let Some(args) = args { args.len() } else { 0 }; + let func_type = match Func::resolve_function(normalize_ident(name.0.as_str()).as_str(),args_count) { Ok(func) => Some(func), Err(_) => None, }; diff --git a/testing/scalar-functions.test b/testing/scalar-functions.test index 2640c10ce..3e245f9f1 100644 --- a/testing/scalar-functions.test +++ b/testing/scalar-functions.test @@ -143,25 +143,25 @@ do_execsql_test length-empty-text { SELECT length(''); } {0} do_execsql_test min-number { - select min_arr(-10,2,3) + select min(-10,2,3) } {-10} do_execsql_test min-str { - select min_arr('b','a','z') + select min('b','a','z') } {a} do_execsql_test min-null { - select min_arr(null,null) + select min(null,null) } {} do_execsql_test max-number { - select max_arr(-10,2,3) + select max(-10,2,3) } {3} do_execsql_test max-str { - select max_arr('b','a','z') + select max('b','a','z') } {z} do_execsql_test max-null { - select max_arr(null,null) + select max(null,null) } {} \ No newline at end of file