diff --git a/core/translate/aggregation.rs b/core/translate/aggregation.rs index 943d1a1cb..c1b462ed2 100644 --- a/core/translate/aggregation.rs +++ b/core/translate/aggregation.rs @@ -155,14 +155,12 @@ pub fn translate_aggregation_step( target_register } AggFunc::Count | AggFunc::Count0 => { - let expr_reg = if agg.args.is_empty() { - program.alloc_register() - } else { - let expr = &agg.args[0]; - let expr_reg = program.alloc_register(); - let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?; - expr_reg - }; + if agg.args.len() != 1 { + crate::bail_parse_error!("count bad number of arguments"); + } + let expr = &agg.args[0]; + let expr_reg = program.alloc_register(); + let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?; handle_distinct(program, agg, expr_reg); program.emit_insn(Insn::AggStep { acc_reg: target_register, diff --git a/core/translate/group_by.rs b/core/translate/group_by.rs index fd24c7b87..3f45fa335 100644 --- a/core/translate/group_by.rs +++ b/core/translate/group_by.rs @@ -928,6 +928,9 @@ pub fn translate_aggregation_step_groupby( target_register } AggFunc::Count | AggFunc::Count0 => { + if num_args != 1 { + crate::bail_parse_error!("count bad number of arguments"); + } let expr_reg = agg_arg_source.translate(program, 0)?; handle_distinct(program, agg_arg_source.aggregate(), expr_reg); program.emit_insn(Insn::AggStep { diff --git a/core/translate/plan.rs b/core/translate/plan.rs index eba50ce89..e43cdbd76 100644 --- a/core/translate/plan.rs +++ b/core/translate/plan.rs @@ -1048,6 +1048,24 @@ pub struct Aggregate { } impl Aggregate { + pub fn new(func: AggFunc, args: &[Box], expr: &Expr, distinctness: Distinctness) -> Self { + let agg_args = if args.is_empty() { + // The AggStep instruction requires at least one argument. For functions that accept + // zero arguments (e.g. COUNT()), we insert a dummy literal so that AggStep remains valid. + // This does not cause ambiguity: the resolver has already verified that the function + // takes zero arguments, so the dummy value will be ignored. + vec![Expr::Literal(ast::Literal::Numeric("1".to_string()))] + } else { + args.iter().map(|arg| *arg.clone()).collect() + }; + Aggregate { + func, + args: agg_args, + original_expr: expr.clone(), + distinctness, + } + } + pub fn is_distinct(&self) -> bool { self.distinctness.is_distinct() } diff --git a/core/translate/planner.rs b/core/translate/planner.rs index 1799ed42b..522256a25 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -73,12 +73,7 @@ pub fn resolve_aggregates( "DISTINCT aggregate functions must have exactly one argument" ); } - aggs.push(Aggregate { - func: f, - args: args.iter().map(|arg| *arg.clone()).collect(), - original_expr: expr.clone(), - distinctness, - }); + aggs.push(Aggregate::new(f, args, expr, distinctness)); contains_aggregates = true; } _ => { @@ -95,12 +90,7 @@ pub fn resolve_aggregates( ); } if let Ok(Func::Agg(f)) = Func::resolve_function(name.as_str(), 0) { - aggs.push(Aggregate { - func: f, - args: vec![], - original_expr: expr.clone(), - distinctness: Distinctness::NonDistinct, - }); + aggs.push(Aggregate::new(f, &[], expr, Distinctness::NonDistinct)); contains_aggregates = true; } } diff --git a/core/translate/select.rs b/core/translate/select.rs index 8641e2347..59239094e 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -371,27 +371,7 @@ fn prepare_one_select_plan( } match Func::resolve_function(name.as_str(), args_count) { Ok(Func::Agg(f)) => { - let agg_args = match (args.is_empty(), &f) { - (true, crate::function::AggFunc::Count0) => { - // COUNT() case - vec![ast::Expr::Literal(ast::Literal::Numeric( - "1".to_string(), - )) - .into()] - } - (true, _) => crate::bail_parse_error!( - "Aggregate function {} requires arguments", - name.as_str() - ), - (false, _) => args.clone(), - }; - - let agg = Aggregate { - func: f, - args: agg_args.iter().map(|arg| *arg.clone()).collect(), - original_expr: *expr.clone(), - distinctness, - }; + let agg = Aggregate::new(f, args, expr, distinctness); aggregate_expressions.push(agg); plan.result_columns.push(ResultSetColumn { alias: maybe_alias.as_ref().map(|alias| match alias { @@ -446,15 +426,12 @@ fn prepare_one_select_plan( contains_aggregates, }); } else { - let agg = Aggregate { - func: AggFunc::External(f.func.clone().into()), - args: args - .iter() - .map(|arg| *arg.clone()) - .collect(), - original_expr: *expr.clone(), + let agg = Aggregate::new( + AggFunc::External(f.func.clone().into()), + args, + expr, distinctness, - }; + ); aggregate_expressions.push(agg); plan.result_columns.push(ResultSetColumn { alias: maybe_alias.as_ref().map(|alias| { @@ -488,14 +465,8 @@ fn prepare_one_select_plan( } match Func::resolve_function(name.as_str(), 0) { Ok(Func::Agg(f)) => { - let agg = Aggregate { - func: f, - args: vec![ast::Expr::Literal(ast::Literal::Numeric( - "1".to_string(), - ))], - original_expr: *expr.clone(), - distinctness: Distinctness::NonDistinct, - }; + let agg = + Aggregate::new(f, &[], expr, Distinctness::NonDistinct); aggregate_expressions.push(agg); plan.result_columns.push(ResultSetColumn { alias: maybe_alias.as_ref().map(|alias| match alias { diff --git a/testing/groupby.test b/testing/groupby.test index c159b9892..dc5418110 100755 --- a/testing/groupby.test +++ b/testing/groupby.test @@ -145,6 +145,18 @@ do_execsql_test group_by_count_star { select u.first_name, count(*) from users u group by u.first_name limit 1; } {Aaron|41} +do_execsql_test group_by_count_star_in_expression { + select u.first_name, count(*) % 3 from users u group by u.first_name order by u.first_name limit 3; +} {Aaron|2 +Abigail|1 +Adam|0} + +do_execsql_test group_by_count_no_args_in_expression { + select u.first_name, count() % 3 from users u group by u.first_name order by u.first_name limit 3; +} {Aaron|2 +Abigail|1 +Adam|0} + do_execsql_test having { select u.first_name, round(avg(u.age)) from users u group by u.first_name having avg(u.age) > 97 order by avg(u.age) desc limit 5; } {Nina|100.0