From 3835a29f4704ea6751bd6cea5591c868d0ecbc64 Mon Sep 17 00:00:00 2001 From: Jussi Saurio Date: Fri, 23 May 2025 16:01:35 +0300 Subject: [PATCH] refactor: use walk_expr() in resolve_aggregates() --- core/translate/planner.rs | 132 ++++++++++++++++++-------------------- 1 file changed, 63 insertions(+), 69 deletions(-) diff --git a/core/translate/planner.rs b/core/translate/planner.rs index 25e3a2da3..226768d05 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -1,4 +1,5 @@ use super::{ + expr::walk_expr, plan::{ Aggregate, ColumnUsedMask, Distinctness, EvalAt, IterationDirection, JoinInfo, JoinOrderMember, Operation, Plan, ResultSetColumn, SelectPlan, SelectQueryType, @@ -21,82 +22,75 @@ use limbo_sqlite3_parser::ast::{ pub const ROWID: &str = "rowid"; -pub fn resolve_aggregates(expr: &Expr, aggs: &mut Vec) -> Result { - if aggs - .iter() - .any(|a| exprs_are_equivalent(&a.original_expr, expr)) - { - return Ok(true); - } - match expr { - Expr::FunctionCall { - name, - args, - distinctness, - .. - } => { - let args_count = if let Some(args) = &args { - args.len() - } else { - 0 - }; - match Func::resolve_function(normalize_ident(name.0.as_str()).as_str(), args_count) { - Ok(Func::Agg(f)) => { - let distinctness = Distinctness::from_ast(distinctness.as_ref()); - let num_args = args.as_ref().map_or(0, |args| args.len()); - if distinctness.is_distinct() && num_args != 1 { - crate::bail_parse_error!( - "DISTINCT aggregate functions must have exactly one argument" - ); +pub fn resolve_aggregates(top_level_expr: &Expr, aggs: &mut Vec) -> Result { + let mut contains_aggregates = false; + walk_expr(top_level_expr, &mut |expr: &Expr| -> Result<()> { + if aggs + .iter() + .any(|a| exprs_are_equivalent(&a.original_expr, expr)) + { + contains_aggregates = true; + return Ok(()); + } + match expr { + Expr::FunctionCall { + name, + args, + distinctness, + .. + } => { + let args_count = if let Some(args) = &args { + args.len() + } else { + 0 + }; + match Func::resolve_function(normalize_ident(name.0.as_str()).as_str(), args_count) + { + Ok(Func::Agg(f)) => { + let distinctness = Distinctness::from_ast(distinctness.as_ref()); + let num_args = args.as_ref().map_or(0, |args| args.len()); + if distinctness.is_distinct() && num_args != 1 { + crate::bail_parse_error!( + "DISTINCT aggregate functions must have exactly one argument" + ); + } + aggs.push(Aggregate { + func: f, + args: args.clone().unwrap_or_default(), + original_expr: expr.clone(), + distinctness, + }); + contains_aggregates = true; } - aggs.push(Aggregate { - func: f, - args: args.clone().unwrap_or_default(), - original_expr: expr.clone(), - distinctness, - }); - Ok(true) - } - _ => { - let mut contains_aggregates = false; - if let Some(args) = args { - for arg in args.iter() { - contains_aggregates |= resolve_aggregates(arg, aggs)?; + _ => { + if let Some(args) = args { + for arg in args.iter() { + contains_aggregates |= resolve_aggregates(arg, aggs)?; + } } } - Ok(contains_aggregates) } } - } - Expr::FunctionCallStar { name, .. } => { - if let Ok(Func::Agg(f)) = - Func::resolve_function(normalize_ident(name.0.as_str()).as_str(), 0) - { - aggs.push(Aggregate { - func: f, - args: vec![], - original_expr: expr.clone(), - distinctness: Distinctness::NonDistinct, - }); - Ok(true) - } else { - Ok(false) + Expr::FunctionCallStar { name, .. } => { + if let Ok(Func::Agg(f)) = + Func::resolve_function(normalize_ident(name.0.as_str()).as_str(), 0) + { + aggs.push(Aggregate { + func: f, + args: vec![], + original_expr: expr.clone(), + distinctness: Distinctness::NonDistinct, + }); + contains_aggregates = true; + } } + _ => {} } - Expr::Binary(lhs, _, rhs) => { - let mut contains_aggregates = false; - contains_aggregates |= resolve_aggregates(lhs, aggs)?; - contains_aggregates |= resolve_aggregates(rhs, aggs)?; - Ok(contains_aggregates) - } - Expr::Unary(_, expr) => { - let mut contains_aggregates = false; - contains_aggregates |= resolve_aggregates(expr, aggs)?; - Ok(contains_aggregates) - } - // TODO: handle other expressions that may contain aggregates - _ => Ok(false), - } + + Ok(()) + })?; + + Ok(contains_aggregates) } pub fn bind_column_references(