mirror of
https://github.com/aljazceru/turso.git
synced 2025-12-18 17:14:20 +01:00
Support external aggregate functions wrapped in expressions
Handled in the same way as in `prepare_one_select_plan` for bare function calls. In `prepare_one_select_plan`, however, resolving external scalar functions is performed unnecessarily twice.
This commit is contained in:
@@ -10,6 +10,7 @@ use super::{
|
||||
select::prepare_select_plan,
|
||||
SymbolTable,
|
||||
};
|
||||
use crate::function::{AggFunc, ExtFunc};
|
||||
use crate::translate::expr::WalkControl;
|
||||
use crate::{
|
||||
function::Func,
|
||||
@@ -29,6 +30,7 @@ pub const ROWID: &str = "rowid";
|
||||
|
||||
pub fn resolve_aggregates(
|
||||
schema: &Schema,
|
||||
syms: &SymbolTable,
|
||||
top_level_expr: &Expr,
|
||||
aggs: &mut Vec<Aggregate>,
|
||||
) -> Result<bool> {
|
||||
@@ -60,22 +62,39 @@ pub fn resolve_aggregates(
|
||||
);
|
||||
}
|
||||
let args_count = args.len();
|
||||
let distinctness = Distinctness::from_ast(distinctness.as_ref());
|
||||
|
||||
if !schema.indexes_enabled() && distinctness.is_distinct() {
|
||||
crate::bail_parse_error!(
|
||||
"SELECT with DISTINCT is not allowed without indexes enabled"
|
||||
);
|
||||
}
|
||||
if distinctness.is_distinct() && args_count != 1 {
|
||||
crate::bail_parse_error!(
|
||||
"DISTINCT aggregate functions must have exactly one argument"
|
||||
);
|
||||
}
|
||||
match Func::resolve_function(name.as_str(), args_count) {
|
||||
Ok(Func::Agg(f)) => {
|
||||
let distinctness = Distinctness::from_ast(distinctness.as_ref());
|
||||
if !schema.indexes_enabled() && distinctness.is_distinct() {
|
||||
crate::bail_parse_error!(
|
||||
"SELECT with DISTINCT is not allowed without indexes enabled"
|
||||
);
|
||||
}
|
||||
if distinctness.is_distinct() && args.len() != 1 {
|
||||
crate::bail_parse_error!(
|
||||
"DISTINCT aggregate functions must have exactly one argument"
|
||||
);
|
||||
}
|
||||
aggs.push(Aggregate::new(f, args, expr, distinctness));
|
||||
contains_aggregates = true;
|
||||
}
|
||||
Err(e) => {
|
||||
if let Some(f) = syms.resolve_function(name.as_str(), args_count) {
|
||||
if let ExtFunc::Aggregate { .. } = f.as_ref().func {
|
||||
let agg = Aggregate::new(
|
||||
AggFunc::External(f.func.clone().into()),
|
||||
args,
|
||||
expr,
|
||||
distinctness,
|
||||
);
|
||||
aggs.push(agg);
|
||||
contains_aggregates = true;
|
||||
}
|
||||
} else {
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -387,6 +387,7 @@ fn prepare_one_select_plan(
|
||||
Ok(_) => {
|
||||
let contains_aggregates = resolve_aggregates(
|
||||
schema,
|
||||
syms,
|
||||
expr,
|
||||
&mut aggregate_expressions,
|
||||
)?;
|
||||
@@ -408,6 +409,7 @@ fn prepare_one_select_plan(
|
||||
if let ExtFunc::Scalar(_) = f.as_ref().func {
|
||||
let contains_aggregates = resolve_aggregates(
|
||||
schema,
|
||||
syms,
|
||||
expr,
|
||||
&mut aggregate_expressions,
|
||||
)?;
|
||||
@@ -499,8 +501,12 @@ fn prepare_one_select_plan(
|
||||
}
|
||||
}
|
||||
expr => {
|
||||
let contains_aggregates =
|
||||
resolve_aggregates(schema, expr, &mut aggregate_expressions)?;
|
||||
let contains_aggregates = resolve_aggregates(
|
||||
schema,
|
||||
syms,
|
||||
expr,
|
||||
&mut aggregate_expressions,
|
||||
)?;
|
||||
plan.result_columns.push(ResultSetColumn {
|
||||
alias: maybe_alias.as_ref().map(|alias| match alias {
|
||||
ast::As::Elided(alias) => alias.as_str().to_string(),
|
||||
@@ -554,7 +560,7 @@ fn prepare_one_select_plan(
|
||||
connection,
|
||||
)?;
|
||||
let contains_aggregates =
|
||||
resolve_aggregates(schema, expr, &mut aggregate_expressions)?;
|
||||
resolve_aggregates(schema, syms, expr, &mut aggregate_expressions)?;
|
||||
if !contains_aggregates {
|
||||
// TODO: sqlite allows HAVING clauses with non aggregate expressions like
|
||||
// HAVING id = 5. We should support this too eventually (I guess).
|
||||
@@ -586,7 +592,7 @@ fn prepare_one_select_plan(
|
||||
Some(&plan.result_columns),
|
||||
connection,
|
||||
)?;
|
||||
resolve_aggregates(schema, &o.expr, &mut plan.aggregates)?;
|
||||
resolve_aggregates(schema, syms, &o.expr, &mut plan.aggregates)?;
|
||||
|
||||
key.push((o.expr, o.order.unwrap_or(ast::SortOrder::Asc)));
|
||||
}
|
||||
|
||||
@@ -153,6 +153,11 @@ def test_aggregates():
|
||||
validate_median,
|
||||
"median agg function works",
|
||||
)
|
||||
limbo.run_test_fn(
|
||||
"select CASE WHEN median(value) > 0 THEN median(value) ELSE 0 END from numbers;",
|
||||
validate_median,
|
||||
"median agg function wrapped in expression works",
|
||||
)
|
||||
limbo.execute_dot("INSERT INTO numbers (value) VALUES (8.0);\n")
|
||||
limbo.run_test_fn(
|
||||
"select median(value) from numbers;",
|
||||
@@ -184,6 +189,11 @@ def test_grouped_aggregates():
|
||||
lambda res: "2.0\n5.5" == res,
|
||||
"median aggregate function works",
|
||||
)
|
||||
limbo.run_test_fn(
|
||||
"select CASE WHEN median(value) > 0 THEN median(value) ELSE 0 END from numbers GROUP BY category;",
|
||||
lambda res: "2.0\n5.5" == res,
|
||||
"median aggregate function wrapped in expression works",
|
||||
)
|
||||
limbo.run_test_fn(
|
||||
"SELECT percentile(value, percent) FROM test GROUP BY category;",
|
||||
lambda res: "12.5\n30.0\n45.0\n70.0" == res,
|
||||
|
||||
Reference in New Issue
Block a user