diff --git a/core/benches/benchmark.rs b/core/benches/benchmark.rs index fda384712..5db976086 100644 --- a/core/benches/benchmark.rs +++ b/core/benches/benchmark.rs @@ -1,5 +1,6 @@ use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; use pprof::criterion::{Output, PProfProfiler}; +use regex::Regex; use std::{sync::Arc, time::Instant}; use turso_core::{Database, PlatformIO}; @@ -272,9 +273,41 @@ fn bench_prepare_query(criterion: &mut Criterion) { "SELECT 1", "SELECT * FROM users LIMIT 1", "SELECT first_name, count(1) FROM users GROUP BY first_name HAVING count(1) > 1 ORDER BY count(1) LIMIT 1", + "SELECT + first_name, + last_name, + state, + city, + age + 10, + LENGTH(email), + UPPER(first_name), + LOWER(last_name), + SUBSTR(phone_number, 1, 3), + zipcode || '-' || state, + AVG(age) + 5, + MAX(age) - MIN(age), + ROUND(AVG(age), 1), + SUM(age) / COUNT(*), + COUNT(*), + COUNT(email), + SUM(age), + AVG(age), + MIN(age), + MAX(age), + SUM(CASE WHEN age >= 18 THEN 1 ELSE 0 END), + SUM(CASE WHEN age < 18 THEN 1 ELSE 0 END), + AVG(CASE WHEN age >= 18 THEN age ELSE NULL END), + MAX(CASE WHEN age >= 18 THEN age ELSE NULL END) + FROM users + GROUP BY state, city", ]; + let whitespace_re = Regex::new(r"\s+").unwrap(); for query in queries.iter() { + // Normalize whitespace in the query string by replacing all sequences of whitespace with a single space. + let query = whitespace_re.replace_all(query, " ").to_string(); + let query = query.as_str(); + let mut group = criterion.benchmark_group(format!("Prepare `{query}`")); group.bench_with_input( diff --git a/core/translate/planner.rs b/core/translate/planner.rs index 9e9a6e0de..b0b696bb4 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -10,6 +10,7 @@ use super::{ select::prepare_select_plan, SymbolTable, }; +use crate::function::{AggFunc, ExtFunc}; use crate::translate::expr::WalkControl; use crate::{ function::Func, @@ -30,18 +31,12 @@ pub const ROWID: &str = "rowid"; pub fn resolve_aggregates( schema: &Schema, + syms: &SymbolTable, top_level_expr: &Expr, aggs: &mut Vec, ) -> Result { let mut contains_aggregates = false; walk_expr(top_level_expr, &mut |expr: &Expr| -> Result { - if aggs - .iter() - .any(|a| exprs_are_equivalent(&a.original_expr, expr)) - { - contains_aggregates = true; - return Ok(WalkControl::Continue); - } match expr { Expr::FunctionCall { name, @@ -61,27 +56,42 @@ pub fn resolve_aggregates( ); } let args_count = args.len(); + let distinctness = Distinctness::from_ast(distinctness.as_ref()); + + if !schema.indexes_enabled() && distinctness.is_distinct() { + crate::bail_parse_error!( + "SELECT with DISTINCT is not allowed without indexes enabled" + ); + } + if distinctness.is_distinct() && args_count != 1 { + crate::bail_parse_error!( + "DISTINCT aggregate functions must have exactly one argument" + ); + } match Func::resolve_function(name.as_str(), args_count) { Ok(Func::Agg(f)) => { - let distinctness = Distinctness::from_ast(distinctness.as_ref()); - if !schema.indexes_enabled() && distinctness.is_distinct() { - crate::bail_parse_error!( - "SELECT with DISTINCT is not allowed without indexes enabled" - ); - } - if distinctness.is_distinct() && args.len() != 1 { - crate::bail_parse_error!( - "DISTINCT aggregate functions must have exactly one argument" - ); - } - aggs.push(Aggregate::new(f, args, expr, distinctness)); + add_aggregate_if_not_exists(aggs, expr, args, distinctness, f); contains_aggregates = true; + return Ok(WalkControl::SkipChildren); } - _ => { - for arg in args.iter() { - contains_aggregates |= resolve_aggregates(schema, arg, aggs)?; + Err(e) => { + if let Some(f) = syms.resolve_function(name.as_str(), args_count) { + if let ExtFunc::Aggregate { .. } = f.as_ref().func { + add_aggregate_if_not_exists( + aggs, + expr, + args, + distinctness, + AggFunc::External(f.func.clone().into()), + ); + contains_aggregates = true; + return Ok(WalkControl::SkipChildren); + } + } else { + return Err(e); } } + _ => {} } } Expr::FunctionCallStar { name, filter_over } => { @@ -90,9 +100,26 @@ pub fn resolve_aggregates( "FILTER clause is not supported yet in aggregate functions" ); } - if let Ok(Func::Agg(f)) = Func::resolve_function(name.as_str(), 0) { - aggs.push(Aggregate::new(f, &[], expr, Distinctness::NonDistinct)); - contains_aggregates = true; + match Func::resolve_function(name.as_str(), 0) { + Ok(Func::Agg(f)) => { + add_aggregate_if_not_exists(aggs, expr, &[], Distinctness::NonDistinct, f); + contains_aggregates = true; + return Ok(WalkControl::SkipChildren); + } + Ok(_) => { + crate::bail_parse_error!("Invalid aggregate function: {}", name.as_str()); + } + Err(e) => match e { + crate::LimboError::ParseError(e) => { + crate::bail_parse_error!("{}", e); + } + _ => { + crate::bail_parse_error!( + "Invalid aggregate function: {}", + name.as_str() + ); + } + }, } } _ => {} @@ -104,6 +131,21 @@ pub fn resolve_aggregates( Ok(contains_aggregates) } +fn add_aggregate_if_not_exists( + aggs: &mut Vec, + expr: &Expr, + args: &[Box], + distinctness: Distinctness, + func: AggFunc, +) { + if aggs + .iter() + .all(|a| !exprs_are_equivalent(&a.original_expr, expr)) + { + aggs.push(Aggregate::new(func, args, expr, distinctness)); + } +} + pub fn bind_column_references( top_level_expr: &mut Expr, referenced_tables: &mut TableReferences, diff --git a/core/translate/select.rs b/core/translate/select.rs index 59239094e..cee03b87a 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -3,10 +3,9 @@ use super::plan::{ select_star, Distinctness, JoinOrderMember, Operation, OuterQueryReference, QueryDestination, Search, TableReferences, WhereTerm, }; -use crate::function::{AggFunc, ExtFunc, Func}; use crate::schema::Table; use crate::translate::optimizer::optimize_plan; -use crate::translate::plan::{Aggregate, GroupBy, Plan, ResultSetColumn, SelectPlan}; +use crate::translate::plan::{GroupBy, Plan, ResultSetColumn, SelectPlan}; use crate::translate::planner::{ bind_column_references, break_predicate_at_and_boundaries, parse_from, parse_limit, parse_where, resolve_aggregates, @@ -340,177 +339,16 @@ fn prepare_one_select_plan( Some(&plan.result_columns), connection, )?; - match expr.as_ref() { - ast::Expr::FunctionCall { - name, - distinctness, - args, - filter_over, - order_by, - } => { - if filter_over.filter_clause.is_some() - || filter_over.over_clause.is_some() - { - crate::bail_parse_error!( - "FILTER clause is not supported yet in aggregate functions" - ); - } - if !order_by.is_empty() { - crate::bail_parse_error!("ORDER BY clause is not supported yet in aggregate functions"); - } - let args_count = args.len(); - let distinctness = Distinctness::from_ast(distinctness.as_ref()); - - if !schema.indexes_enabled() && distinctness.is_distinct() { - crate::bail_parse_error!( - "SELECT with DISTINCT is not allowed without indexes enabled" - ); - } - if distinctness.is_distinct() && args_count != 1 { - crate::bail_parse_error!("DISTINCT aggregate functions must have exactly one argument"); - } - match Func::resolve_function(name.as_str(), args_count) { - Ok(Func::Agg(f)) => { - let agg = Aggregate::new(f, args, expr, distinctness); - aggregate_expressions.push(agg); - plan.result_columns.push(ResultSetColumn { - alias: maybe_alias.as_ref().map(|alias| match alias { - ast::As::Elided(alias) => { - alias.as_str().to_string() - } - ast::As::As(alias) => alias.as_str().to_string(), - }), - expr: *expr.clone(), - contains_aggregates: true, - }); - } - Ok(_) => { - let contains_aggregates = resolve_aggregates( - schema, - expr, - &mut aggregate_expressions, - )?; - plan.result_columns.push(ResultSetColumn { - alias: maybe_alias.as_ref().map(|alias| match alias { - ast::As::Elided(alias) => { - alias.as_str().to_string() - } - ast::As::As(alias) => alias.as_str().to_string(), - }), - expr: *expr.clone(), - contains_aggregates, - }); - } - Err(e) => { - if let Some(f) = - syms.resolve_function(name.as_str(), args_count) - { - if let ExtFunc::Scalar(_) = f.as_ref().func { - let contains_aggregates = resolve_aggregates( - schema, - expr, - &mut aggregate_expressions, - )?; - plan.result_columns.push(ResultSetColumn { - alias: maybe_alias.as_ref().map(|alias| { - match alias { - ast::As::Elided(alias) => { - alias.as_str().to_string() - } - ast::As::As(alias) => { - alias.as_str().to_string() - } - } - }), - expr: *expr.clone(), - contains_aggregates, - }); - } else { - let agg = Aggregate::new( - AggFunc::External(f.func.clone().into()), - args, - expr, - distinctness, - ); - aggregate_expressions.push(agg); - plan.result_columns.push(ResultSetColumn { - alias: maybe_alias.as_ref().map(|alias| { - match alias { - ast::As::Elided(alias) => { - alias.as_str().to_string() - } - ast::As::As(alias) => { - alias.as_str().to_string() - } - } - }), - expr: *expr.clone(), - contains_aggregates: true, - }); - } - continue; // Continue with the normal flow instead of returning - } else { - return Err(e); - } - } - } - } - ast::Expr::FunctionCallStar { name, filter_over } => { - if filter_over.filter_clause.is_some() - || filter_over.over_clause.is_some() - { - crate::bail_parse_error!( - "FILTER clause is not supported yet in aggregate functions" - ); - } - match Func::resolve_function(name.as_str(), 0) { - Ok(Func::Agg(f)) => { - let agg = - Aggregate::new(f, &[], expr, Distinctness::NonDistinct); - aggregate_expressions.push(agg); - plan.result_columns.push(ResultSetColumn { - alias: maybe_alias.as_ref().map(|alias| match alias { - ast::As::Elided(alias) => { - alias.as_str().to_string() - } - ast::As::As(alias) => alias.as_str().to_string(), - }), - expr: *expr.clone(), - contains_aggregates: true, - }); - } - Ok(_) => { - crate::bail_parse_error!( - "Invalid aggregate function: {}", - name.as_str() - ); - } - Err(e) => match e { - crate::LimboError::ParseError(e) => { - crate::bail_parse_error!("{}", e); - } - _ => { - crate::bail_parse_error!( - "Invalid aggregate function: {}", - name.as_str() - ); - } - }, - } - } - expr => { - let contains_aggregates = - resolve_aggregates(schema, expr, &mut aggregate_expressions)?; - plan.result_columns.push(ResultSetColumn { - alias: maybe_alias.as_ref().map(|alias| match alias { - ast::As::Elided(alias) => alias.as_str().to_string(), - ast::As::As(alias) => alias.as_str().to_string(), - }), - expr: expr.clone(), - contains_aggregates, - }); - } - } + let contains_aggregates = + resolve_aggregates(schema, syms, expr, &mut aggregate_expressions)?; + plan.result_columns.push(ResultSetColumn { + alias: maybe_alias.as_ref().map(|alias| match alias { + ast::As::Elided(alias) => alias.as_str().to_string(), + ast::As::As(alias) => alias.as_str().to_string(), + }), + expr: expr.as_ref().clone(), + contains_aggregates, + }); } } } @@ -554,7 +392,7 @@ fn prepare_one_select_plan( connection, )?; let contains_aggregates = - resolve_aggregates(schema, expr, &mut aggregate_expressions)?; + resolve_aggregates(schema, syms, expr, &mut aggregate_expressions)?; if !contains_aggregates { // TODO: sqlite allows HAVING clauses with non aggregate expressions like // HAVING id = 5. We should support this too eventually (I guess). @@ -586,7 +424,7 @@ fn prepare_one_select_plan( Some(&plan.result_columns), connection, )?; - resolve_aggregates(schema, &o.expr, &mut plan.aggregates)?; + resolve_aggregates(schema, syms, &o.expr, &mut plan.aggregates)?; key.push((o.expr, o.order.unwrap_or(ast::SortOrder::Asc))); } diff --git a/testing/agg-functions.test b/testing/agg-functions.test index 13b4600d7..52c45ce9b 100755 --- a/testing/agg-functions.test +++ b/testing/agg-functions.test @@ -175,3 +175,27 @@ do_execsql_test select-json-group-object-no-sorting-required { 3|{"3437":"Amanda"} 5|{"2378":"Amy","3227":"Amy","5605":"Amanda"} 7|{"2454":"Amber"}} + +do_execsql_test_error_content select-max-star { + SELECT max(*) FROM users; +} {"wrong number of arguments to function"} + +do_execsql_test_error_content select-max-star-in-expression { + SELECT CASE WHEN max(*) > 0 THEN 1 ELSE 0 END FROM users; +} {"wrong number of arguments to function"} + +do_execsql_test_error select-scalar-func-star { + SELECT abs(*) FROM users; +} {.*(Invalid aggregate function|wrong number of arguments to function).*} + +do_execsql_test_error select-scalar-func-star-in-expression { + SELECT CASE WHEN abs(*) > 0 THEN 1 ELSE 0 END FROM users; +} {.*(Invalid aggregate function|wrong number of arguments to function).*} + +do_execsql_test_error_content select-nested-agg-func { + SELECT max(abs(sum(age))), sum(age) FROM users; +} {"misuse of aggregate function"} + +do_execsql_test_error_content select-nested-agg-func-in-expression { + SELECT CASE WHEN max(abs(sum(age))) > 0 THEN 1 ELSE 0 END, sum(age) FROM users; +} {"misuse of aggregate function"} diff --git a/testing/cli_tests/extensions.py b/testing/cli_tests/extensions.py index 8ce7341f0..f53621fb9 100755 --- a/testing/cli_tests/extensions.py +++ b/testing/cli_tests/extensions.py @@ -153,6 +153,11 @@ def test_aggregates(): validate_median, "median agg function works", ) + limbo.run_test_fn( + "select CASE WHEN median(value) > 0 THEN median(value) ELSE 0 END from numbers;", + validate_median, + "median agg function wrapped in expression works", + ) limbo.execute_dot("INSERT INTO numbers (value) VALUES (8.0);\n") limbo.run_test_fn( "select median(value) from numbers;", @@ -184,6 +189,11 @@ def test_grouped_aggregates(): lambda res: "2.0\n5.5" == res, "median aggregate function works", ) + limbo.run_test_fn( + "select CASE WHEN median(value) > 0 THEN median(value) ELSE 0 END from numbers GROUP BY category;", + lambda res: "2.0\n5.5" == res, + "median aggregate function wrapped in expression works", + ) limbo.run_test_fn( "SELECT percentile(value, percent) FROM test GROUP BY category;", lambda res: "12.5\n30.0\n45.0\n70.0" == res,