refactor: use walk_expr() in resolve_aggregates()

This commit is contained in:
Jussi Saurio
2025-05-23 16:01:35 +03:00
parent 2ab5c5f6a9
commit 3835a29f47

View File

@@ -1,4 +1,5 @@
use super::{
expr::walk_expr,
plan::{
Aggregate, ColumnUsedMask, Distinctness, EvalAt, IterationDirection, JoinInfo,
JoinOrderMember, Operation, Plan, ResultSetColumn, SelectPlan, SelectQueryType,
@@ -21,82 +22,75 @@ use limbo_sqlite3_parser::ast::{
pub const ROWID: &str = "rowid";
pub fn resolve_aggregates(expr: &Expr, aggs: &mut Vec<Aggregate>) -> Result<bool> {
if aggs
.iter()
.any(|a| exprs_are_equivalent(&a.original_expr, expr))
{
return Ok(true);
}
match expr {
Expr::FunctionCall {
name,
args,
distinctness,
..
} => {
let args_count = if let Some(args) = &args {
args.len()
} else {
0
};
match Func::resolve_function(normalize_ident(name.0.as_str()).as_str(), args_count) {
Ok(Func::Agg(f)) => {
let distinctness = Distinctness::from_ast(distinctness.as_ref());
let num_args = args.as_ref().map_or(0, |args| args.len());
if distinctness.is_distinct() && num_args != 1 {
crate::bail_parse_error!(
"DISTINCT aggregate functions must have exactly one argument"
);
pub fn resolve_aggregates(top_level_expr: &Expr, aggs: &mut Vec<Aggregate>) -> Result<bool> {
let mut contains_aggregates = false;
walk_expr(top_level_expr, &mut |expr: &Expr| -> Result<()> {
if aggs
.iter()
.any(|a| exprs_are_equivalent(&a.original_expr, expr))
{
contains_aggregates = true;
return Ok(());
}
match expr {
Expr::FunctionCall {
name,
args,
distinctness,
..
} => {
let args_count = if let Some(args) = &args {
args.len()
} else {
0
};
match Func::resolve_function(normalize_ident(name.0.as_str()).as_str(), args_count)
{
Ok(Func::Agg(f)) => {
let distinctness = Distinctness::from_ast(distinctness.as_ref());
let num_args = args.as_ref().map_or(0, |args| args.len());
if distinctness.is_distinct() && num_args != 1 {
crate::bail_parse_error!(
"DISTINCT aggregate functions must have exactly one argument"
);
}
aggs.push(Aggregate {
func: f,
args: args.clone().unwrap_or_default(),
original_expr: expr.clone(),
distinctness,
});
contains_aggregates = true;
}
aggs.push(Aggregate {
func: f,
args: args.clone().unwrap_or_default(),
original_expr: expr.clone(),
distinctness,
});
Ok(true)
}
_ => {
let mut contains_aggregates = false;
if let Some(args) = args {
for arg in args.iter() {
contains_aggregates |= resolve_aggregates(arg, aggs)?;
_ => {
if let Some(args) = args {
for arg in args.iter() {
contains_aggregates |= resolve_aggregates(arg, aggs)?;
}
}
}
Ok(contains_aggregates)
}
}
}
Expr::FunctionCallStar { name, .. } => {
if let Ok(Func::Agg(f)) =
Func::resolve_function(normalize_ident(name.0.as_str()).as_str(), 0)
{
aggs.push(Aggregate {
func: f,
args: vec![],
original_expr: expr.clone(),
distinctness: Distinctness::NonDistinct,
});
Ok(true)
} else {
Ok(false)
Expr::FunctionCallStar { name, .. } => {
if let Ok(Func::Agg(f)) =
Func::resolve_function(normalize_ident(name.0.as_str()).as_str(), 0)
{
aggs.push(Aggregate {
func: f,
args: vec![],
original_expr: expr.clone(),
distinctness: Distinctness::NonDistinct,
});
contains_aggregates = true;
}
}
_ => {}
}
Expr::Binary(lhs, _, rhs) => {
let mut contains_aggregates = false;
contains_aggregates |= resolve_aggregates(lhs, aggs)?;
contains_aggregates |= resolve_aggregates(rhs, aggs)?;
Ok(contains_aggregates)
}
Expr::Unary(_, expr) => {
let mut contains_aggregates = false;
contains_aggregates |= resolve_aggregates(expr, aggs)?;
Ok(contains_aggregates)
}
// TODO: handle other expressions that may contain aggregates
_ => Ok(false),
}
Ok(())
})?;
Ok(contains_aggregates)
}
pub fn bind_column_references(