mirror of
https://github.com/aljazceru/turso.git
synced 2025-12-25 03:54:21 +01:00
Merge 'Unify resolution of aggregate functions' from Piotr Rżysko
This PR unifies the logic for resolving aggregate functions. Previously,
bare aggregates (e.g. `SELECT max(a) FROM t1`) and aggregates wrapped in
expressions (e.g. `SELECT max(a) + 1 FROM t1`) were handled differently,
which led to duplicated code. Now both cases are resolved consistently.
The added benchmark shows a small improvement:
```
Prepare `SELECT first_name, last_name, state, city, age + 10, LENGTH(email), UPPER(first_name), LOWE...
time: [59.791 µs 59.898 µs 60.006 µs]
change: [-7.7090% -7.2760% -6.8242%] (p = 0.00 < 0.05)
Performance has improved.
Found 10 outliers among 100 measurements (10.00%)
8 (8.00%) high mild
2 (2.00%) high severe
```
For an existing benchmark, no change:
```
Prepare `SELECT first_name, count(1) FROM users GROUP BY first_name HAVING count(1) > 1 ORDER BY cou...
time: [11.895 µs 11.913 µs 11.931 µs]
change: [-0.2545% +0.2426% +0.6960%] (p = 0.34 > 0.05)
No change in performance detected.
Found 8 outliers among 100 measurements (8.00%)
1 (1.00%) low severe
2 (2.00%) high mild
5 (5.00%) high severe
```
Reviewed-by: Jussi Saurio <jussi.saurio@gmail.com>
Reviewed-by: Preston Thorpe <preston@turso.tech>
Closes #2884
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
|
||||
use pprof::criterion::{Output, PProfProfiler};
|
||||
use regex::Regex;
|
||||
use std::{sync::Arc, time::Instant};
|
||||
use turso_core::{Database, PlatformIO};
|
||||
|
||||
@@ -272,9 +273,41 @@ fn bench_prepare_query(criterion: &mut Criterion) {
|
||||
"SELECT 1",
|
||||
"SELECT * FROM users LIMIT 1",
|
||||
"SELECT first_name, count(1) FROM users GROUP BY first_name HAVING count(1) > 1 ORDER BY count(1) LIMIT 1",
|
||||
"SELECT
|
||||
first_name,
|
||||
last_name,
|
||||
state,
|
||||
city,
|
||||
age + 10,
|
||||
LENGTH(email),
|
||||
UPPER(first_name),
|
||||
LOWER(last_name),
|
||||
SUBSTR(phone_number, 1, 3),
|
||||
zipcode || '-' || state,
|
||||
AVG(age) + 5,
|
||||
MAX(age) - MIN(age),
|
||||
ROUND(AVG(age), 1),
|
||||
SUM(age) / COUNT(*),
|
||||
COUNT(*),
|
||||
COUNT(email),
|
||||
SUM(age),
|
||||
AVG(age),
|
||||
MIN(age),
|
||||
MAX(age),
|
||||
SUM(CASE WHEN age >= 18 THEN 1 ELSE 0 END),
|
||||
SUM(CASE WHEN age < 18 THEN 1 ELSE 0 END),
|
||||
AVG(CASE WHEN age >= 18 THEN age ELSE NULL END),
|
||||
MAX(CASE WHEN age >= 18 THEN age ELSE NULL END)
|
||||
FROM users
|
||||
GROUP BY state, city",
|
||||
];
|
||||
|
||||
let whitespace_re = Regex::new(r"\s+").unwrap();
|
||||
for query in queries.iter() {
|
||||
// Normalize whitespace in the query string by replacing all sequences of whitespace with a single space.
|
||||
let query = whitespace_re.replace_all(query, " ").to_string();
|
||||
let query = query.as_str();
|
||||
|
||||
let mut group = criterion.benchmark_group(format!("Prepare `{query}`"));
|
||||
|
||||
group.bench_with_input(
|
||||
|
||||
@@ -10,6 +10,7 @@ use super::{
|
||||
select::prepare_select_plan,
|
||||
SymbolTable,
|
||||
};
|
||||
use crate::function::{AggFunc, ExtFunc};
|
||||
use crate::translate::expr::WalkControl;
|
||||
use crate::{
|
||||
function::Func,
|
||||
@@ -30,18 +31,12 @@ pub const ROWID: &str = "rowid";
|
||||
|
||||
pub fn resolve_aggregates(
|
||||
schema: &Schema,
|
||||
syms: &SymbolTable,
|
||||
top_level_expr: &Expr,
|
||||
aggs: &mut Vec<Aggregate>,
|
||||
) -> Result<bool> {
|
||||
let mut contains_aggregates = false;
|
||||
walk_expr(top_level_expr, &mut |expr: &Expr| -> Result<WalkControl> {
|
||||
if aggs
|
||||
.iter()
|
||||
.any(|a| exprs_are_equivalent(&a.original_expr, expr))
|
||||
{
|
||||
contains_aggregates = true;
|
||||
return Ok(WalkControl::Continue);
|
||||
}
|
||||
match expr {
|
||||
Expr::FunctionCall {
|
||||
name,
|
||||
@@ -61,27 +56,42 @@ pub fn resolve_aggregates(
|
||||
);
|
||||
}
|
||||
let args_count = args.len();
|
||||
let distinctness = Distinctness::from_ast(distinctness.as_ref());
|
||||
|
||||
if !schema.indexes_enabled() && distinctness.is_distinct() {
|
||||
crate::bail_parse_error!(
|
||||
"SELECT with DISTINCT is not allowed without indexes enabled"
|
||||
);
|
||||
}
|
||||
if distinctness.is_distinct() && args_count != 1 {
|
||||
crate::bail_parse_error!(
|
||||
"DISTINCT aggregate functions must have exactly one argument"
|
||||
);
|
||||
}
|
||||
match Func::resolve_function(name.as_str(), args_count) {
|
||||
Ok(Func::Agg(f)) => {
|
||||
let distinctness = Distinctness::from_ast(distinctness.as_ref());
|
||||
if !schema.indexes_enabled() && distinctness.is_distinct() {
|
||||
crate::bail_parse_error!(
|
||||
"SELECT with DISTINCT is not allowed without indexes enabled"
|
||||
);
|
||||
}
|
||||
if distinctness.is_distinct() && args.len() != 1 {
|
||||
crate::bail_parse_error!(
|
||||
"DISTINCT aggregate functions must have exactly one argument"
|
||||
);
|
||||
}
|
||||
aggs.push(Aggregate::new(f, args, expr, distinctness));
|
||||
add_aggregate_if_not_exists(aggs, expr, args, distinctness, f);
|
||||
contains_aggregates = true;
|
||||
return Ok(WalkControl::SkipChildren);
|
||||
}
|
||||
_ => {
|
||||
for arg in args.iter() {
|
||||
contains_aggregates |= resolve_aggregates(schema, arg, aggs)?;
|
||||
Err(e) => {
|
||||
if let Some(f) = syms.resolve_function(name.as_str(), args_count) {
|
||||
if let ExtFunc::Aggregate { .. } = f.as_ref().func {
|
||||
add_aggregate_if_not_exists(
|
||||
aggs,
|
||||
expr,
|
||||
args,
|
||||
distinctness,
|
||||
AggFunc::External(f.func.clone().into()),
|
||||
);
|
||||
contains_aggregates = true;
|
||||
return Ok(WalkControl::SkipChildren);
|
||||
}
|
||||
} else {
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
Expr::FunctionCallStar { name, filter_over } => {
|
||||
@@ -90,9 +100,26 @@ pub fn resolve_aggregates(
|
||||
"FILTER clause is not supported yet in aggregate functions"
|
||||
);
|
||||
}
|
||||
if let Ok(Func::Agg(f)) = Func::resolve_function(name.as_str(), 0) {
|
||||
aggs.push(Aggregate::new(f, &[], expr, Distinctness::NonDistinct));
|
||||
contains_aggregates = true;
|
||||
match Func::resolve_function(name.as_str(), 0) {
|
||||
Ok(Func::Agg(f)) => {
|
||||
add_aggregate_if_not_exists(aggs, expr, &[], Distinctness::NonDistinct, f);
|
||||
contains_aggregates = true;
|
||||
return Ok(WalkControl::SkipChildren);
|
||||
}
|
||||
Ok(_) => {
|
||||
crate::bail_parse_error!("Invalid aggregate function: {}", name.as_str());
|
||||
}
|
||||
Err(e) => match e {
|
||||
crate::LimboError::ParseError(e) => {
|
||||
crate::bail_parse_error!("{}", e);
|
||||
}
|
||||
_ => {
|
||||
crate::bail_parse_error!(
|
||||
"Invalid aggregate function: {}",
|
||||
name.as_str()
|
||||
);
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
@@ -104,6 +131,21 @@ pub fn resolve_aggregates(
|
||||
Ok(contains_aggregates)
|
||||
}
|
||||
|
||||
/// Registers an aggregate in `aggs` unless an equivalent one is already present.
///
/// An existing entry counts as a duplicate when its `original_expr` is
/// equivalent to `expr` according to `exprs_are_equivalent`. When no such entry
/// exists, a new `Aggregate` is built from `func`, `args`, `expr` and
/// `distinctness` and appended to the end of `aggs`; otherwise `aggs` is left
/// untouched.
fn add_aggregate_if_not_exists(
|
||||
aggs: &mut Vec<Aggregate>,
|
||||
expr: &Expr,
|
||||
args: &[Box<Expr>],
|
||||
distinctness: Distinctness,
|
||||
func: AggFunc,
|
||||
) {
|
||||
// `all` over an empty list is true, so the first aggregate is always pushed.
if aggs
|
||||
.iter()
|
||||
.all(|a| !exprs_are_equivalent(&a.original_expr, expr))
|
||||
{
|
||||
aggs.push(Aggregate::new(func, args, expr, distinctness));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bind_column_references(
|
||||
top_level_expr: &mut Expr,
|
||||
referenced_tables: &mut TableReferences,
|
||||
|
||||
@@ -3,10 +3,9 @@ use super::plan::{
|
||||
select_star, Distinctness, JoinOrderMember, Operation, OuterQueryReference, QueryDestination,
|
||||
Search, TableReferences, WhereTerm,
|
||||
};
|
||||
use crate::function::{AggFunc, ExtFunc, Func};
|
||||
use crate::schema::Table;
|
||||
use crate::translate::optimizer::optimize_plan;
|
||||
use crate::translate::plan::{Aggregate, GroupBy, Plan, ResultSetColumn, SelectPlan};
|
||||
use crate::translate::plan::{GroupBy, Plan, ResultSetColumn, SelectPlan};
|
||||
use crate::translate::planner::{
|
||||
bind_column_references, break_predicate_at_and_boundaries, parse_from, parse_limit,
|
||||
parse_where, resolve_aggregates,
|
||||
@@ -340,177 +339,16 @@ fn prepare_one_select_plan(
|
||||
Some(&plan.result_columns),
|
||||
connection,
|
||||
)?;
|
||||
match expr.as_ref() {
|
||||
ast::Expr::FunctionCall {
|
||||
name,
|
||||
distinctness,
|
||||
args,
|
||||
filter_over,
|
||||
order_by,
|
||||
} => {
|
||||
if filter_over.filter_clause.is_some()
|
||||
|| filter_over.over_clause.is_some()
|
||||
{
|
||||
crate::bail_parse_error!(
|
||||
"FILTER clause is not supported yet in aggregate functions"
|
||||
);
|
||||
}
|
||||
if !order_by.is_empty() {
|
||||
crate::bail_parse_error!("ORDER BY clause is not supported yet in aggregate functions");
|
||||
}
|
||||
let args_count = args.len();
|
||||
let distinctness = Distinctness::from_ast(distinctness.as_ref());
|
||||
|
||||
if !schema.indexes_enabled() && distinctness.is_distinct() {
|
||||
crate::bail_parse_error!(
|
||||
"SELECT with DISTINCT is not allowed without indexes enabled"
|
||||
);
|
||||
}
|
||||
if distinctness.is_distinct() && args_count != 1 {
|
||||
crate::bail_parse_error!("DISTINCT aggregate functions must have exactly one argument");
|
||||
}
|
||||
match Func::resolve_function(name.as_str(), args_count) {
|
||||
Ok(Func::Agg(f)) => {
|
||||
let agg = Aggregate::new(f, args, expr, distinctness);
|
||||
aggregate_expressions.push(agg);
|
||||
plan.result_columns.push(ResultSetColumn {
|
||||
alias: maybe_alias.as_ref().map(|alias| match alias {
|
||||
ast::As::Elided(alias) => {
|
||||
alias.as_str().to_string()
|
||||
}
|
||||
ast::As::As(alias) => alias.as_str().to_string(),
|
||||
}),
|
||||
expr: *expr.clone(),
|
||||
contains_aggregates: true,
|
||||
});
|
||||
}
|
||||
Ok(_) => {
|
||||
let contains_aggregates = resolve_aggregates(
|
||||
schema,
|
||||
expr,
|
||||
&mut aggregate_expressions,
|
||||
)?;
|
||||
plan.result_columns.push(ResultSetColumn {
|
||||
alias: maybe_alias.as_ref().map(|alias| match alias {
|
||||
ast::As::Elided(alias) => {
|
||||
alias.as_str().to_string()
|
||||
}
|
||||
ast::As::As(alias) => alias.as_str().to_string(),
|
||||
}),
|
||||
expr: *expr.clone(),
|
||||
contains_aggregates,
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
if let Some(f) =
|
||||
syms.resolve_function(name.as_str(), args_count)
|
||||
{
|
||||
if let ExtFunc::Scalar(_) = f.as_ref().func {
|
||||
let contains_aggregates = resolve_aggregates(
|
||||
schema,
|
||||
expr,
|
||||
&mut aggregate_expressions,
|
||||
)?;
|
||||
plan.result_columns.push(ResultSetColumn {
|
||||
alias: maybe_alias.as_ref().map(|alias| {
|
||||
match alias {
|
||||
ast::As::Elided(alias) => {
|
||||
alias.as_str().to_string()
|
||||
}
|
||||
ast::As::As(alias) => {
|
||||
alias.as_str().to_string()
|
||||
}
|
||||
}
|
||||
}),
|
||||
expr: *expr.clone(),
|
||||
contains_aggregates,
|
||||
});
|
||||
} else {
|
||||
let agg = Aggregate::new(
|
||||
AggFunc::External(f.func.clone().into()),
|
||||
args,
|
||||
expr,
|
||||
distinctness,
|
||||
);
|
||||
aggregate_expressions.push(agg);
|
||||
plan.result_columns.push(ResultSetColumn {
|
||||
alias: maybe_alias.as_ref().map(|alias| {
|
||||
match alias {
|
||||
ast::As::Elided(alias) => {
|
||||
alias.as_str().to_string()
|
||||
}
|
||||
ast::As::As(alias) => {
|
||||
alias.as_str().to_string()
|
||||
}
|
||||
}
|
||||
}),
|
||||
expr: *expr.clone(),
|
||||
contains_aggregates: true,
|
||||
});
|
||||
}
|
||||
continue; // Continue with the normal flow instead of returning
|
||||
} else {
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ast::Expr::FunctionCallStar { name, filter_over } => {
|
||||
if filter_over.filter_clause.is_some()
|
||||
|| filter_over.over_clause.is_some()
|
||||
{
|
||||
crate::bail_parse_error!(
|
||||
"FILTER clause is not supported yet in aggregate functions"
|
||||
);
|
||||
}
|
||||
match Func::resolve_function(name.as_str(), 0) {
|
||||
Ok(Func::Agg(f)) => {
|
||||
let agg =
|
||||
Aggregate::new(f, &[], expr, Distinctness::NonDistinct);
|
||||
aggregate_expressions.push(agg);
|
||||
plan.result_columns.push(ResultSetColumn {
|
||||
alias: maybe_alias.as_ref().map(|alias| match alias {
|
||||
ast::As::Elided(alias) => {
|
||||
alias.as_str().to_string()
|
||||
}
|
||||
ast::As::As(alias) => alias.as_str().to_string(),
|
||||
}),
|
||||
expr: *expr.clone(),
|
||||
contains_aggregates: true,
|
||||
});
|
||||
}
|
||||
Ok(_) => {
|
||||
crate::bail_parse_error!(
|
||||
"Invalid aggregate function: {}",
|
||||
name.as_str()
|
||||
);
|
||||
}
|
||||
Err(e) => match e {
|
||||
crate::LimboError::ParseError(e) => {
|
||||
crate::bail_parse_error!("{}", e);
|
||||
}
|
||||
_ => {
|
||||
crate::bail_parse_error!(
|
||||
"Invalid aggregate function: {}",
|
||||
name.as_str()
|
||||
);
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
expr => {
|
||||
let contains_aggregates =
|
||||
resolve_aggregates(schema, expr, &mut aggregate_expressions)?;
|
||||
plan.result_columns.push(ResultSetColumn {
|
||||
alias: maybe_alias.as_ref().map(|alias| match alias {
|
||||
ast::As::Elided(alias) => alias.as_str().to_string(),
|
||||
ast::As::As(alias) => alias.as_str().to_string(),
|
||||
}),
|
||||
expr: expr.clone(),
|
||||
contains_aggregates,
|
||||
});
|
||||
}
|
||||
}
|
||||
let contains_aggregates =
|
||||
resolve_aggregates(schema, syms, expr, &mut aggregate_expressions)?;
|
||||
plan.result_columns.push(ResultSetColumn {
|
||||
alias: maybe_alias.as_ref().map(|alias| match alias {
|
||||
ast::As::Elided(alias) => alias.as_str().to_string(),
|
||||
ast::As::As(alias) => alias.as_str().to_string(),
|
||||
}),
|
||||
expr: expr.as_ref().clone(),
|
||||
contains_aggregates,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -554,7 +392,7 @@ fn prepare_one_select_plan(
|
||||
connection,
|
||||
)?;
|
||||
let contains_aggregates =
|
||||
resolve_aggregates(schema, expr, &mut aggregate_expressions)?;
|
||||
resolve_aggregates(schema, syms, expr, &mut aggregate_expressions)?;
|
||||
if !contains_aggregates {
|
||||
// TODO: sqlite allows HAVING clauses with non aggregate expressions like
|
||||
// HAVING id = 5. We should support this too eventually (I guess).
|
||||
@@ -586,7 +424,7 @@ fn prepare_one_select_plan(
|
||||
Some(&plan.result_columns),
|
||||
connection,
|
||||
)?;
|
||||
resolve_aggregates(schema, &o.expr, &mut plan.aggregates)?;
|
||||
resolve_aggregates(schema, syms, &o.expr, &mut plan.aggregates)?;
|
||||
|
||||
key.push((o.expr, o.order.unwrap_or(ast::SortOrder::Asc)));
|
||||
}
|
||||
|
||||
@@ -175,3 +175,27 @@ do_execsql_test select-json-group-object-no-sorting-required {
|
||||
3|{"3437":"Amanda"}
|
||||
5|{"2378":"Amy","3227":"Amy","5605":"Amanda"}
|
||||
7|{"2454":"Amber"}}
|
||||
|
||||
do_execsql_test_error_content select-max-star {
|
||||
SELECT max(*) FROM users;
|
||||
} {"wrong number of arguments to function"}
|
||||
|
||||
do_execsql_test_error_content select-max-star-in-expression {
|
||||
SELECT CASE WHEN max(*) > 0 THEN 1 ELSE 0 END FROM users;
|
||||
} {"wrong number of arguments to function"}
|
||||
|
||||
do_execsql_test_error select-scalar-func-star {
|
||||
SELECT abs(*) FROM users;
|
||||
} {.*(Invalid aggregate function|wrong number of arguments to function).*}
|
||||
|
||||
do_execsql_test_error select-scalar-func-star-in-expression {
|
||||
SELECT CASE WHEN abs(*) > 0 THEN 1 ELSE 0 END FROM users;
|
||||
} {.*(Invalid aggregate function|wrong number of arguments to function).*}
|
||||
|
||||
do_execsql_test_error_content select-nested-agg-func {
|
||||
SELECT max(abs(sum(age))), sum(age) FROM users;
|
||||
} {"misuse of aggregate function"}
|
||||
|
||||
do_execsql_test_error_content select-nested-agg-func-in-expression {
|
||||
SELECT CASE WHEN max(abs(sum(age))) > 0 THEN 1 ELSE 0 END, sum(age) FROM users;
|
||||
} {"misuse of aggregate function"}
|
||||
|
||||
@@ -153,6 +153,11 @@ def test_aggregates():
|
||||
validate_median,
|
||||
"median agg function works",
|
||||
)
|
||||
limbo.run_test_fn(
|
||||
"select CASE WHEN median(value) > 0 THEN median(value) ELSE 0 END from numbers;",
|
||||
validate_median,
|
||||
"median agg function wrapped in expression works",
|
||||
)
|
||||
limbo.execute_dot("INSERT INTO numbers (value) VALUES (8.0);\n")
|
||||
limbo.run_test_fn(
|
||||
"select median(value) from numbers;",
|
||||
@@ -184,6 +189,11 @@ def test_grouped_aggregates():
|
||||
lambda res: "2.0\n5.5" == res,
|
||||
"median aggregate function works",
|
||||
)
|
||||
limbo.run_test_fn(
|
||||
"select CASE WHEN median(value) > 0 THEN median(value) ELSE 0 END from numbers GROUP BY category;",
|
||||
lambda res: "2.0\n5.5" == res,
|
||||
"median aggregate function wrapped in expression works",
|
||||
)
|
||||
limbo.run_test_fn(
|
||||
"SELECT percentile(value, percent) FROM test GROUP BY category;",
|
||||
lambda res: "12.5\n30.0\n45.0\n70.0" == res,
|
||||
|
||||
Reference in New Issue
Block a user