From 776aecbce0c8a77d7351c0b17b6653ebf2976d70 Mon Sep 17 00:00:00 2001 From: Krishna Vishal Date: Sun, 19 Jan 2025 03:00:18 +0530 Subject: [PATCH] 1. All aggregate functions now validate their argument count and return a parse error when called with the wrong number of arguments. 2. Unknown functions now return the error `no such function: unknown_function` --- core/function.rs | 83 ++++++++++++++++++++++++++++++----- core/translate/aggregation.rs | 8 +++- core/translate/group_by.rs | 8 +++- core/translate/select.rs | 40 ++++++++++------- core/vdbe/mod.rs | 14 ++++-- 5 files changed, 118 insertions(+), 35 deletions(-) diff --git a/core/function.rs b/core/function.rs index 85a142437..c0f6609ca 100644 --- a/core/function.rs +++ b/core/function.rs @@ -4,6 +4,8 @@ use std::rc::Rc; use limbo_ext::ScalarFunction; +use crate::LimboError; + pub struct ExternalFunc { pub name: String, pub func: ScalarFunction, @@ -67,6 +69,7 @@ impl Display for JsonFunc { pub enum AggFunc { Avg, Count, + Count0, GroupConcat, Max, Min, @@ -76,9 +79,24 @@ pub enum AggFunc { } impl AggFunc { + pub fn num_args(&self) -> usize { + match self { + Self::Avg => 1, + Self::Count0 => 0, + Self::Count => 1, + Self::GroupConcat => 1, + Self::Max => 1, + Self::Min => 1, + Self::StringAgg => 2, + Self::Sum => 1, + Self::Total => 1, + } + } + pub fn to_string(&self) -> &str { match self { Self::Avg => "avg", + Self::Count0 => "count", Self::Count => "count", Self::GroupConcat => "group_concat", Self::Max => "max", @@ -336,19 +354,64 @@ pub struct FuncCtx { } impl Func { - pub fn resolve_function(name: &str, arg_count: usize) -> Result { + pub fn resolve_function(name: &str, arg_count: usize) -> Result { match name { - "avg" => Ok(Self::Agg(AggFunc::Avg)), - "count" => Ok(Self::Agg(AggFunc::Count)), - "group_concat" => Ok(Self::Agg(AggFunc::GroupConcat)), - "max" if arg_count == 0 || arg_count == 1 => Ok(Self::Agg(AggFunc::Max)), + "avg" => { + if arg_count != 1 { + crate::bail_parse_error!("wrong number of arguments to function {}()", name) + 
Ok(Self::Agg(AggFunc::Avg)) + } + "count" => { + // Handle both COUNT() and COUNT(expr) cases + if arg_count == 0 { + Ok(Self::Agg(AggFunc::Count0)) // COUNT() case + } else if arg_count == 1 { + Ok(Self::Agg(AggFunc::Count)) // COUNT(expr) case + } else { + crate::bail_parse_error!("wrong number of arguments to function {}()", name) + } + } + "group_concat" => { + if arg_count != 1 && arg_count != 2 { + println!("{}", arg_count); + crate::bail_parse_error!("wrong number of arguments to function {}()", name) + } + Ok(Self::Agg(AggFunc::GroupConcat)) + } "max" if arg_count > 1 => Ok(Self::Scalar(ScalarFunc::Max)), - "min" if arg_count == 0 || arg_count == 1 => Ok(Self::Agg(AggFunc::Min)), + "max" => { + if arg_count < 1 { + crate::bail_parse_error!("wrong number of arguments to function {}()", name) + } + Ok(Self::Agg(AggFunc::Max)) + } "min" if arg_count > 1 => Ok(Self::Scalar(ScalarFunc::Min)), + "min" => { + if arg_count < 1 { + crate::bail_parse_error!("wrong number of arguments to function {}()", name) + } + Ok(Self::Agg(AggFunc::Min)) + } "nullif" if arg_count == 2 => Ok(Self::Scalar(ScalarFunc::Nullif)), - "string_agg" => Ok(Self::Agg(AggFunc::StringAgg)), - "sum" => Ok(Self::Agg(AggFunc::Sum)), - "total" => Ok(Self::Agg(AggFunc::Total)), + "string_agg" => { + if arg_count != 2 { + crate::bail_parse_error!("wrong number of arguments to function {}()", name) + } + Ok(Self::Agg(AggFunc::StringAgg)) + } + "sum" => { + if arg_count != 1 { + crate::bail_parse_error!("wrong number of arguments to function {}()", name) + } + Ok(Self::Agg(AggFunc::Sum)) + } + "total" => { + if arg_count != 1 { + crate::bail_parse_error!("wrong number of arguments to function {}()", name) + } + Ok(Self::Agg(AggFunc::Total)) + } "char" => Ok(Self::Scalar(ScalarFunc::Char)), "coalesce" => Ok(Self::Scalar(ScalarFunc::Coalesce)), "concat" => Ok(Self::Scalar(ScalarFunc::Concat)), @@ -432,7 +495,7 @@ impl Func { "trunc" => Ok(Self::Math(MathFunc::Trunc)), #[cfg(not(target_family = "wasm"))] 
"load_extension" => Ok(Self::Scalar(ScalarFunc::LoadExtension)), - _ => Err(()), + _ => crate::bail_parse_error!("no such function: {}", name), } } } diff --git a/core/translate/aggregation.rs b/core/translate/aggregation.rs index c8a7a520a..1456f0fbf 100644 --- a/core/translate/aggregation.rs +++ b/core/translate/aggregation.rs @@ -74,7 +74,7 @@ pub fn translate_aggregation_step( }); target_register } - AggFunc::Count => { + AggFunc::Count | AggFunc::Count0 => { let expr_reg = if agg.args.is_empty() { program.alloc_register() } else { @@ -87,7 +87,11 @@ pub fn translate_aggregation_step( acc_reg: target_register, col: expr_reg, delimiter: 0, - func: AggFunc::Count, + func: if matches!(agg.func, AggFunc::Count0) { + AggFunc::Count0 + } else { + AggFunc::Count + }, }); target_register } diff --git a/core/translate/group_by.rs b/core/translate/group_by.rs index 5855caa06..024c6cfad 100644 --- a/core/translate/group_by.rs +++ b/core/translate/group_by.rs @@ -463,14 +463,18 @@ pub fn translate_aggregation_step_groupby( }); target_register } - AggFunc::Count => { + AggFunc::Count | AggFunc::Count0 => { let expr_reg = program.alloc_register(); emit_column(program, expr_reg); program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, delimiter: 0, - func: AggFunc::Count, + func: if matches!(agg.func, AggFunc::Count0) { + AggFunc::Count0 + } else { + AggFunc::Count + }, }); target_register } diff --git a/core/translate/select.rs b/core/translate/select.rs index c37466a9b..ccb3d491f 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -116,21 +116,20 @@ pub fn prepare_select_plan( args_count, ) { Ok(Func::Agg(f)) => { - let agg_args: Result, LimboError> = match args { - // if args is None and its COUNT - None if name.0.to_uppercase() == "COUNT" => { - let count_args = vec![ast::Expr::Literal( - ast::Literal::Numeric("1".to_string()), - )]; - Ok(count_args) - } - // if args is None and the function is not COUNT - None => 
crate::bail_parse_error!( - "Aggregate function {} requires arguments", - name.0 - ), - Some(args) => Ok(args.clone()), - }; + let agg_args: Result, LimboError> = + match (args, &f) { + (None, crate::function::AggFunc::Count0) => { + // COUNT() case + Ok(vec![ast::Expr::Literal( + ast::Literal::Numeric("1".to_string()), + )]) + } + (None, _) => crate::bail_parse_error!( + "Aggregate function {} requires arguments", + name.0 + ), + (Some(args), _) => Ok(args.clone()), + }; let agg = Aggregate { func: f, @@ -163,8 +162,12 @@ pub fn prepare_select_plan( contains_aggregates, }); } - Err(_) => { - if syms.functions.contains_key(&name.0) { + Err(e) => { + // Only handle the "no such function" case specially + // All other errors should be propagated as-is + if e.to_string().starts_with("no such function: ") + && syms.functions.contains_key(&name.0) + { let contains_aggregates = resolve_aggregates( expr, &mut aggregate_expressions, @@ -179,7 +182,10 @@ pub fn prepare_select_plan( expr: expr.clone(), contains_aggregates, }); + continue; // Continue with the normal flow instead of returning } + // Propagate the original error + return Err(e); } } } diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index 43160eb40..992a9db6e 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -1202,7 +1202,7 @@ impl Program { // Total() never throws an integer overflow. 
OwnedValue::Agg(Box::new(AggContext::Sum(OwnedValue::Float(0.0)))) } - AggFunc::Count => { + AggFunc::Count | AggFunc::Count0 => { OwnedValue::Agg(Box::new(AggContext::Count(OwnedValue::Integer(0)))) } AggFunc::Max => { @@ -1270,7 +1270,13 @@ impl Program { }; *acc += col; } - AggFunc::Count => { + AggFunc::Count | AggFunc::Count0 => { + // println!("here"); + if matches!(&state.registers[*acc_reg], OwnedValue::Null) { + state.registers[*acc_reg] = OwnedValue::Agg(Box::new( + AggContext::Count(OwnedValue::Integer(0)), + )); + } let OwnedValue::Agg(agg) = state.registers[*acc_reg].borrow_mut() else { unreachable!(); @@ -1397,7 +1403,7 @@ impl Program { *acc /= count.clone(); } AggFunc::Sum | AggFunc::Total => {} - AggFunc::Count => {} + AggFunc::Count | AggFunc::Count0 => {} AggFunc::Max => {} AggFunc::Min => {} AggFunc::GroupConcat | AggFunc::StringAgg => {} @@ -1409,7 +1415,7 @@ impl Program { AggFunc::Total => { state.registers[*register] = OwnedValue::Float(0.0); } - AggFunc::Count => { + AggFunc::Count | AggFunc::Count0 => { state.registers[*register] = OwnedValue::Integer(0); } _ => {}