From 978a78b79a51b9d5f7cf1db398d85863a40f4b89 Mon Sep 17 00:00:00 2001 From: Piotr Rzysko Date: Sun, 31 Aug 2025 06:51:26 +0200 Subject: [PATCH 1/6] Handle COLLATE clause in grouped aggregations Previously, it was only applied to ungrouped aggregations. --- core/translate/aggregation.rs | 2 +- core/translate/group_by.rs | 6 +++++- testing/collate.test | 28 ++++++++++++++++++++++++++++ 3 files changed, 34 insertions(+), 2 deletions(-) diff --git a/core/translate/aggregation.rs b/core/translate/aggregation.rs index ad902097e..943d1a1cb 100644 --- a/core/translate/aggregation.rs +++ b/core/translate/aggregation.rs @@ -61,7 +61,7 @@ pub fn emit_ungrouped_aggregation<'a>( Ok(()) } -fn emit_collseq_if_needed( +pub fn emit_collseq_if_needed( program: &mut ProgramBuilder, referenced_tables: &TableReferences, expr: &ast::Expr, diff --git a/core/translate/group_by.rs b/core/translate/group_by.rs index 0a524f348..fd24c7b87 100644 --- a/core/translate/group_by.rs +++ b/core/translate/group_by.rs @@ -14,7 +14,7 @@ use crate::{ }, Result, }; - +use crate::translate::aggregation::emit_collseq_if_needed; use super::{ aggregation::handle_distinct, emitter::{Resolver, TranslateCtx}, @@ -991,6 +991,8 @@ pub fn translate_aggregation_step_groupby( } let expr_reg = agg_arg_source.translate(program, 0)?; handle_distinct(program, agg_arg_source.aggregate(), expr_reg); + let expr = &agg_arg_source.args()[0]; + emit_collseq_if_needed(program, referenced_tables, expr); program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, @@ -1005,6 +1007,8 @@ pub fn translate_aggregation_step_groupby( } let expr_reg = agg_arg_source.translate(program, 0)?; handle_distinct(program, agg_arg_source.aggregate(), expr_reg); + let expr = &agg_arg_source.args()[0]; + emit_collseq_if_needed(program, referenced_tables, expr); program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, diff --git a/testing/collate.test b/testing/collate.test index 2b9aa3ff3..d985286b9 100755 --- a/testing/collate.test +++ b/testing/collate.test @@ -74,3 +74,31 @@ do_execsql_test_on_specific_db {:memory:} collate_aggregation_explicit_nocase { insert into fruits(name) values ('Apple') ,('banana') ,('CHERRY'); select max(name collate nocase) from fruits; } {CHERRY} + +do_execsql_test_on_specific_db {:memory:} collate_grouped_aggregation_default_binary { + create table fruits(name collate binary, category text); + insert into fruits(name, category) values ('Apple', 'A'), ('banana', 'A'), ('CHERRY', 'B'), ('blueberry', 'B'); + select max(name) from fruits group by category; +} {banana +blueberry} + +do_execsql_test_on_specific_db {:memory:} collate_grouped_aggregation_default_nocase { + create table fruits(name collate nocase, category text); + insert into fruits(name, category) values ('Apple', 'A'), ('banana', 'A'), ('CHERRY', 'B'), ('blueberry', 'B'); + select max(name) from fruits group by category; +} {banana +CHERRY} + +do_execsql_test_on_specific_db {:memory:} collate_grouped_aggregation_explicit_binary { + create table fruits(name collate nocase, category text); + insert into fruits(name, category) values ('Apple', 'A'), ('banana', 'A'), ('CHERRY', 'B'), ('blueberry', 'B'); + select max(name collate binary) from fruits group by category; +} {banana +blueberry} + +do_execsql_test_on_specific_db {:memory:} collate_groupped_aggregation_explicit_nocase { + create table fruits(name collate binary, category text); + insert into fruits(name, category) values ('Apple', 'A'), ('banana', 'A'), ('CHERRY', 'B'), ('blueberry', 'B'); + select max(name collate nocase) from fruits group by category; +} {banana +CHERRY} From 3ad4016080642df89a15329cfa8b16f677e4e34c Mon Sep 17 00:00:00 2001 From: Piotr Rzysko Date: Sun, 31 Aug 2025 08:03:34 +0200 Subject: [PATCH 2/6] Fix handling of zero-argument grouped aggregations This commit consolidates the creation of the Aggregate struct, which was previously handled differently in `prepare_one_select_plan` and `resolve_aggregates`. That discrepancy caused inconsistent handling of zero-argument aggregates. The queries added in the new tests would previously trigger a panic. --- core/translate/aggregation.rs | 14 +++++------ core/translate/group_by.rs | 3 +++ core/translate/plan.rs | 18 ++++++++++++++ core/translate/planner.rs | 14 ++--------- core/translate/select.rs | 45 +++++++---------------------------- testing/groupby.test | 12 ++++++++++ 6 files changed, 49 insertions(+), 57 deletions(-) diff --git a/core/translate/aggregation.rs b/core/translate/aggregation.rs index 943d1a1cb..c1b462ed2 100644 --- a/core/translate/aggregation.rs +++ b/core/translate/aggregation.rs @@ -155,14 +155,12 @@ pub fn translate_aggregation_step( target_register } AggFunc::Count | AggFunc::Count0 => { - let expr_reg = if agg.args.is_empty() { - program.alloc_register() - } else { - let expr = &agg.args[0]; - let expr_reg = program.alloc_register(); - let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?; - expr_reg - }; + if agg.args.len() != 1 { + crate::bail_parse_error!("count bad number of arguments"); + } + let expr = &agg.args[0]; + let expr_reg = program.alloc_register(); + let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?; handle_distinct(program, agg, expr_reg); program.emit_insn(Insn::AggStep { acc_reg: target_register, diff --git a/core/translate/group_by.rs b/core/translate/group_by.rs index fd24c7b87..3f45fa335 100644 --- a/core/translate/group_by.rs +++ b/core/translate/group_by.rs @@ -928,6 +928,9 @@ pub fn translate_aggregation_step_groupby( target_register } AggFunc::Count | AggFunc::Count0 => { + if num_args != 1 { + crate::bail_parse_error!("count bad number of arguments"); + } let expr_reg = agg_arg_source.translate(program, 0)?; handle_distinct(program, agg_arg_source.aggregate(), expr_reg); program.emit_insn(Insn::AggStep { diff --git a/core/translate/plan.rs b/core/translate/plan.rs index eba50ce89..e43cdbd76 100644 --- a/core/translate/plan.rs +++ b/core/translate/plan.rs @@ -1048,6 +1048,24 @@ pub struct Aggregate { } impl Aggregate { + pub fn new(func: AggFunc, args: &[Box], expr: &Expr, distinctness: Distinctness) -> Self { + let agg_args = if args.is_empty() { + // The AggStep instruction requires at least one argument. For functions that accept + // zero arguments (e.g. COUNT()), we insert a dummy literal so that AggStep remains valid. + // This does not cause ambiguity: the resolver has already verified that the function + // takes zero arguments, so the dummy value will be ignored. + vec![Expr::Literal(ast::Literal::Numeric("1".to_string()))] + } else { + args.iter().map(|arg| *arg.clone()).collect() + }; + Aggregate { + func, + args: agg_args, + original_expr: expr.clone(), + distinctness, + } + } + pub fn is_distinct(&self) -> bool { self.distinctness.is_distinct() } diff --git a/core/translate/planner.rs b/core/translate/planner.rs index 1799ed42b..522256a25 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -73,12 +73,7 @@ pub fn resolve_aggregates( "DISTINCT aggregate functions must have exactly one argument" ); } - aggs.push(Aggregate { - func: f, - args: args.iter().map(|arg| *arg.clone()).collect(), - original_expr: expr.clone(), - distinctness, - }); + aggs.push(Aggregate::new(f, args, expr, distinctness)); contains_aggregates = true; } _ => { @@ -95,12 +90,7 @@ pub fn resolve_aggregates( ); } if let Ok(Func::Agg(f)) = Func::resolve_function(name.as_str(), 0) { - aggs.push(Aggregate { - func: f, - args: vec![], - original_expr: expr.clone(), - distinctness: Distinctness::NonDistinct, - }); + aggs.push(Aggregate::new(f, &[], expr, Distinctness::NonDistinct)); contains_aggregates = true; } } diff --git a/core/translate/select.rs b/core/translate/select.rs index 8641e2347..59239094e 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -371,27 +371,7 @@ fn prepare_one_select_plan( } match Func::resolve_function(name.as_str(), args_count) { Ok(Func::Agg(f)) => { - let agg_args = match (args.is_empty(), &f) { - (true, crate::function::AggFunc::Count0) => { - // COUNT() case - vec![ast::Expr::Literal(ast::Literal::Numeric( - "1".to_string(), - )) - .into()] - } - (true, _) => crate::bail_parse_error!( - "Aggregate function {} requires arguments", - name.as_str() - ), - (false, _) => args.clone(), - }; - - let agg = Aggregate { - func: f, - args: agg_args.iter().map(|arg| *arg.clone()).collect(), - original_expr: *expr.clone(), - distinctness, - }; + let agg = Aggregate::new(f, args, expr, distinctness); aggregate_expressions.push(agg); plan.result_columns.push(ResultSetColumn { alias: maybe_alias.as_ref().map(|alias| match alias { @@ -446,15 +426,12 @@ fn prepare_one_select_plan( contains_aggregates, }); } else { - let agg = Aggregate { - func: AggFunc::External(f.func.clone().into()), - args: args - .iter() - .map(|arg| *arg.clone()) - .collect(), - original_expr: *expr.clone(), + let agg = Aggregate::new( + AggFunc::External(f.func.clone().into()), + args, + expr, distinctness, - }; + ); aggregate_expressions.push(agg); plan.result_columns.push(ResultSetColumn { alias: maybe_alias.as_ref().map(|alias| { @@ -488,14 +465,8 @@ fn prepare_one_select_plan( } match Func::resolve_function(name.as_str(), 0) { Ok(Func::Agg(f)) => { - let agg = Aggregate { - func: f, - args: vec![ast::Expr::Literal(ast::Literal::Numeric( - "1".to_string(), - ))], - original_expr: *expr.clone(), - distinctness: Distinctness::NonDistinct, - }; + let agg = + Aggregate::new(f, &[], expr, Distinctness::NonDistinct); aggregate_expressions.push(agg); plan.result_columns.push(ResultSetColumn { alias: maybe_alias.as_ref().map(|alias| match alias { diff --git a/testing/groupby.test b/testing/groupby.test index c159b9892..dc5418110 100755 --- a/testing/groupby.test +++ b/testing/groupby.test @@ -145,6 +145,18 @@ do_execsql_test group_by_count_star { select u.first_name, count(*) from users u group by u.first_name limit 1; } {Aaron|41} +do_execsql_test group_by_count_star_in_expression { + select u.first_name, count(*) % 3 from users u group by u.first_name order by u.first_name limit 3; +} {Aaron|2 +Abigail|1 +Adam|0} + +do_execsql_test group_by_count_no_args_in_expression { + select u.first_name, count() % 3 from users u group by u.first_name order by u.first_name limit 3; +} {Aaron|2 +Abigail|1 +Adam|0} + do_execsql_test having { select u.first_name, round(avg(u.age)) from users u group by u.first_name having avg(u.age) > 97 order by avg(u.age) desc limit 5; } {Nina|100.0 From 7d179bd9fee86121f0d9725b5d3459b0ab50b98b Mon Sep 17 00:00:00 2001 From: Piotr Rzysko Date: Sun, 31 Aug 2025 10:44:36 +0200 Subject: [PATCH 3/6] Fix handling of multiple arguments in aggregate functions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This bug occurred when arguments were read for the GROUP BY sorter — all arguments were incorrectly resolved to the first column. Added tests confirm that aggregates now work correctly both with and without the sorter. --- core/translate/group_by.rs | 6 +++++- testing/agg-functions.test | 33 ++++++++++++++++++++++++++++++++- 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/core/translate/group_by.rs b/core/translate/group_by.rs index 3f45fa335..dbeabd417 100644 --- a/core/translate/group_by.rs +++ b/core/translate/group_by.rs @@ -479,7 +479,11 @@ impl<'a> GroupByAggArgumentSource<'a> { dest_reg_start, .. } => { - program.emit_column_or_rowid(*cursor_id, *col_start, dest_reg_start + arg_idx); + program.emit_column_or_rowid( + *cursor_id, + *col_start + arg_idx, + dest_reg_start + arg_idx, + ); Ok(dest_reg_start + arg_idx) } GroupByAggArgumentSource::Register { diff --git a/testing/agg-functions.test b/testing/agg-functions.test index 9becf56a4..13b4600d7 100755 --- a/testing/agg-functions.test +++ b/testing/agg-functions.test @@ -143,4 +143,35 @@ do_execsql_test select-agg-json-array-object { do_execsql_test select-distinct-agg-functions { SELECT sum(distinct age), count(distinct age), avg(distinct age) FROM users; -} {5050|100|50.5} \ No newline at end of file +} {5050|100|50.5} + +do_execsql_test select-json-group-object { + select price, + json_group_object(cast (id as text), name) + from products + group by price + order by price; +} {1.0|{"9":"boots"} +18.0|{"3":"shirt"} +25.0|{"4":"sweater"} +33.0|{"10":"coat"} +70.0|{"6":"shorts"} +74.0|{"5":"sweatshirt"} +78.0|{"7":"jeans"} +79.0|{"1":"hat"} +81.0|{"11":"accessories"} +82.0|{"2":"cap","8":"sneakers"}} + +do_execsql_test select-json-group-object-no-sorting-required { + select age, + json_group_object(cast (id as text), first_name) + from users + where first_name like 'Am%' + group by age + order by age + limit 5; +} {1|{"6737":"Amy"} +2|{"2297":"Amy","3580":"Amanda"} +3|{"3437":"Amanda"} +5|{"2378":"Amy","3227":"Amy","5605":"Amanda"} +7|{"2454":"Amber"}} From 0a85883ee2860e123e3d5006fd647b98d21cf292 Mon Sep 17 00:00:00 2001 From: Piotr Rzysko Date: Sun, 31 Aug 2025 10:57:33 +0200 Subject: [PATCH 4/6] Support external aggregate functions in GROUP BY --- core/translate/group_by.rs | 32 ++++++++++++++-- testing/cli_tests/extensions.py | 66 +++++++++++++++++++++++++-------- 2 files changed, 79 insertions(+), 19 deletions(-) diff --git a/core/translate/group_by.rs b/core/translate/group_by.rs index dbeabd417..006495eb1 100644 --- a/core/translate/group_by.rs +++ b/core/translate/group_by.rs @@ -12,7 +12,7 @@ use crate::{ insn::Insn, BranchOffset, }, - Result, + LimboError, Result, }; use crate::translate::aggregation::emit_collseq_if_needed; use super::{ @@ -1119,8 +1119,34 @@ pub fn translate_aggregation_step_groupby( }); target_register } - AggFunc::External(_) => { - todo!("External aggregate functions are not yet supported in GROUP BY"); + AggFunc::External(ref func) => { + let argc = func.agg_args().map_err(|_| { + LimboError::ExtensionError( + "External aggregate function called with wrong number of arguments".to_string(), + ) + })?; + if argc != num_args { + crate::bail_parse_error!( + "External aggregate function called with wrong number of arguments" + ); + } + let expr_reg = agg_arg_source.translate(program, 0)?; + for i in 0..argc { + if i != 0 { + let _ = agg_arg_source.translate(program, i)?; + } + // invariant: distinct aggregates are only supported for single-argument functions + if argc == 1 { + handle_distinct(program, agg_arg_source.aggregate(), expr_reg + i); + } + } + program.emit_insn(Insn::AggStep { + acc_reg: target_register, + col: expr_reg, + delimiter: 0, + func: AggFunc::External(func.clone()), + }); + target_register } }; Ok(dest) diff --git a/testing/cli_tests/extensions.py b/testing/cli_tests/extensions.py index 8b0e3c3a5..8ce7341f0 100755 --- a/testing/cli_tests/extensions.py +++ b/testing/cli_tests/extensions.py @@ -7,22 +7,22 @@ from cli_tests.test_turso_cli import TestTursoShell sqlite_exec = "./scripts/limbo-sqlite3" sqlite_flags = os.getenv("SQLITE_FLAGS", "-q").split(" ") -test_data = """CREATE TABLE numbers ( id INTEGER PRIMARY KEY, value FLOAT NOT NULL); -INSERT INTO numbers (value) VALUES (1.0); -INSERT INTO numbers (value) VALUES (2.0); -INSERT INTO numbers (value) VALUES (3.0); -INSERT INTO numbers (value) VALUES (4.0); -INSERT INTO numbers (value) VALUES (5.0); -INSERT INTO numbers (value) VALUES (6.0); -INSERT INTO numbers (value) VALUES (7.0); -CREATE TABLE test (value REAL, percent REAL); -INSERT INTO test values (10, 25); -INSERT INTO test values (20, 25); -INSERT INTO test values (30, 25); -INSERT INTO test values (40, 25); -INSERT INTO test values (50, 25); -INSERT INTO test values (60, 25); -INSERT INTO test values (70, 25); +test_data = """CREATE TABLE numbers ( id INTEGER PRIMARY KEY, value FLOAT NOT NULL, category TEXT DEFAULT 'A'); +INSERT INTO numbers (value, category) VALUES (1.0, 'A'); +INSERT INTO numbers (value, category) VALUES (2.0, 'A'); +INSERT INTO numbers (value, category) VALUES (3.0, 'A'); +INSERT INTO numbers (value, category) VALUES (4.0, 'B'); +INSERT INTO numbers (value, category) VALUES (5.0, 'B'); +INSERT INTO numbers (value, category) VALUES (6.0, 'B'); +INSERT INTO numbers (value, category) VALUES (7.0, 'B'); +CREATE TABLE test (value REAL, percent REAL, category TEXT); +INSERT INTO test values (10, 25, 'A'); +INSERT INTO test values (20, 25, 'A'); +INSERT INTO test values (30, 25, 'B'); +INSERT INTO test values (40, 25, 'C'); +INSERT INTO test values (50, 25, 'C'); +INSERT INTO test values (60, 25, 'C'); +INSERT INTO test values (70, 25, 'D'); """ @@ -174,6 +174,39 @@ def test_aggregates(): limbo.quit() +def test_grouped_aggregates(): + limbo = TestTursoShell(init_commands=test_data) + extension_path = "./target/debug/liblimbo_percentile" + limbo.execute_dot(f".load {extension_path}") + + limbo.run_test_fn( + "SELECT median(value) FROM numbers GROUP BY category;", + lambda res: "2.0\n5.5" == res, + "median aggregate function works", + ) + limbo.run_test_fn( + "SELECT percentile(value, percent) FROM test GROUP BY category;", + lambda res: "12.5\n30.0\n45.0\n70.0" == res, + "grouped aggregate percentile function with 2 arguments works", + ) + limbo.run_test_fn( + "SELECT percentile(value, 55) FROM test GROUP BY category;", + lambda res: "15.5\n30.0\n51.0\n70.0" == res, + "grouped aggregate percentile function with 1 argument works", + ) + limbo.run_test_fn( + "SELECT percentile_cont(value, 0.25) FROM test GROUP BY category;", + lambda res: "12.5\n30.0\n45.0\n70.0" == res, + "grouped aggregate percentile_cont function works", + ) + limbo.run_test_fn( + "SELECT percentile_disc(value, 0.55) FROM test GROUP BY category;", + lambda res: "10.0\n30.0\n50.0\n70.0" == res, + "grouped aggregate percentile_disc function works", + ) + limbo.quit() + + # Encoders and decoders def validate_url_encode(a): return a == "%2Fhello%3Ftext%3D%28%E0%B2%A0_%E0%B2%A0%29" @@ -770,6 +803,7 @@ def main(): test_regexp() test_uuid() test_aggregates() + test_grouped_aggregates() test_crypto() test_series() test_ipaddr() From cdba1f1b87796ac50c941b7fc60304a9e18d6133 Mon Sep 17 00:00:00 2001 From: Piotr Rzysko Date: Sun, 31 Aug 2025 11:39:50 +0200 Subject: [PATCH 5/6] Generalize GroupByAggArgumentSource This is primarily a mechanical change: the enum was moved between files, renamed, and its comments updated so it is no longer strictly tied to GROUP BY aggregations. This prepares the enum for reuse with ungrouped aggregations. --- core/translate/aggregation.rs | 100 +++++++++++++++++++++++++ core/translate/group_by.rs | 136 +++++----------------------------- 2 files changed, 117 insertions(+), 119 deletions(-) diff --git a/core/translate/aggregation.rs b/core/translate/aggregation.rs index c1b462ed2..2e43c3ea9 100644 --- a/core/translate/aggregation.rs +++ b/core/translate/aggregation.rs @@ -125,6 +125,106 @@ pub fn handle_distinct(program: &mut ProgramBuilder, agg: &Aggregate, agg_arg_re }); } +/// Enum representing the source of the aggregate function arguments +/// +/// Aggregate arguments can come from different sources, depending on how the aggregation +/// is evaluated: +/// * In the common grouped case, the aggregate function arguments are first inserted +/// into a sorter in the main loop, and in the group by aggregation phase we read +/// the data from the sorter. +/// * In grouped cases where no sorting is required, arguments are retrieved directly +/// from registers allocated in the main loop. +pub enum AggArgumentSource<'a> { + /// The aggregate function arguments are retrieved from a pseudo cursor + /// which reads from the GROUP BY sorter. + PseudoCursor { + cursor_id: usize, + col_start: usize, + dest_reg_start: usize, + aggregate: &'a Aggregate, + }, + /// The aggregate function arguments are retrieved from a contiguous block of registers + /// allocated in the main loop for that given aggregate function. + Register { + src_reg_start: usize, + aggregate: &'a Aggregate, + }, +} + +impl<'a> AggArgumentSource<'a> { + /// Create a new [AggArgumentSource] that retrieves the values from a GROUP BY sorter. + pub fn new_from_cursor( + program: &mut ProgramBuilder, + cursor_id: usize, + col_start: usize, + aggregate: &'a Aggregate, + ) -> Self { + let dest_reg_start = program.alloc_registers(aggregate.args.len()); + Self::PseudoCursor { + cursor_id, + col_start, + dest_reg_start, + aggregate, + } + } + /// Create a new [AggArgumentSource] that retrieves the values directly from an already + /// populated register or registers. + pub fn new_from_registers(src_reg_start: usize, aggregate: &'a Aggregate) -> Self { + Self::Register { + src_reg_start, + aggregate, + } + } + + pub fn aggregate(&self) -> &Aggregate { + match self { + AggArgumentSource::PseudoCursor { aggregate, .. } => aggregate, + AggArgumentSource::Register { aggregate, .. } => aggregate, + } + } + + pub fn agg_func(&self) -> &AggFunc { + match self { + AggArgumentSource::PseudoCursor { aggregate, .. } => &aggregate.func, + AggArgumentSource::Register { aggregate, .. } => &aggregate.func, + } + } + pub fn args(&self) -> &[ast::Expr] { + match self { + AggArgumentSource::PseudoCursor { aggregate, .. } => &aggregate.args, + AggArgumentSource::Register { aggregate, .. } => &aggregate.args, + } + } + pub fn num_args(&self) -> usize { + match self { + AggArgumentSource::PseudoCursor { aggregate, .. } => aggregate.args.len(), + AggArgumentSource::Register { aggregate, .. } => aggregate.args.len(), + } + } + /// Read the value of an aggregate function argument + pub fn translate(&self, program: &mut ProgramBuilder, arg_idx: usize) -> Result { + match self { + AggArgumentSource::PseudoCursor { + cursor_id, + col_start, + dest_reg_start, + .. + } => { + program.emit_column_or_rowid( + *cursor_id, + *col_start + arg_idx, + dest_reg_start + arg_idx, + ); + Ok(dest_reg_start + arg_idx) + } + AggArgumentSource::Register { + src_reg_start: start_reg, + .. + } => Ok(*start_reg + arg_idx), + } + } +} + /// Emits the bytecode for processing an aggregate step. /// E.g. in `SELECT SUM(price) FROM t`, 'price' is evaluated for every row, and the result is added to the accumulator. /// diff --git a/core/translate/group_by.rs b/core/translate/group_by.rs index 006495eb1..715bb5c93 100644 --- a/core/translate/group_by.rs +++ b/core/translate/group_by.rs @@ -1,5 +1,14 @@ use turso_parser::ast; +use super::{ + aggregation::handle_distinct, + emitter::{Resolver, TranslateCtx}, + expr::{translate_condition_expr, translate_expr, ConditionMetadata}, + order_by::order_by_sorter_insert, + plan::{Distinctness, GroupBy, SelectPlan, TableReferences}, + result_row::emit_select_result, +}; +use crate::translate::aggregation::{emit_collseq_if_needed, AggArgumentSource}; use crate::translate::expr::{walk_expr, WalkControl}; use crate::translate::plan::ResultSetColumn; use crate::{ @@ -14,15 +23,6 @@ use crate::{ }, LimboError, Result, }; -use crate::translate::aggregation::emit_collseq_if_needed; -use super::{ - aggregation::handle_distinct, - emitter::{Resolver, TranslateCtx}, - expr::{translate_condition_expr, translate_expr, ConditionMetadata}, - order_by::order_by_sorter_insert, - plan::{Aggregate, Distinctness, GroupBy, SelectPlan, TableReferences}, - result_row::emit_select_result, -}; /// Labels needed for various jumps in GROUP BY handling. #[derive(Debug)] @@ -394,106 +394,6 @@ pub enum GroupByRowSource { }, } -/// Enum representing the source of the aggregate function arguments -/// emitted for a group by aggregation. -/// In the common case, the aggregate function arguments are first inserted -/// into a sorter in the main loop, and in the group by aggregation phase -/// we read the data from the sorter. -/// -/// In the alternative case, no sorting is required for group by, -/// and the aggregate function arguments are retrieved directly from -/// registers allocated in the main loop. -pub enum GroupByAggArgumentSource<'a> { - /// The aggregate function arguments are retrieved from a pseudo cursor - /// which reads from the GROUP BY sorter. - PseudoCursor { - cursor_id: usize, - col_start: usize, - dest_reg_start: usize, - aggregate: &'a Aggregate, - }, - /// The aggregate function arguments are retrieved from a contiguous block of registers - /// allocated in the main loop for that given aggregate function. - Register { - src_reg_start: usize, - aggregate: &'a Aggregate, - }, -} - -impl<'a> GroupByAggArgumentSource<'a> { - /// Create a new [GroupByAggArgumentSource] that retrieves the values from a GROUP BY sorter. - pub fn new_from_cursor( - program: &mut ProgramBuilder, - cursor_id: usize, - col_start: usize, - aggregate: &'a Aggregate, - ) -> Self { - let dest_reg_start = program.alloc_registers(aggregate.args.len()); - Self::PseudoCursor { - cursor_id, - col_start, - dest_reg_start, - aggregate, - } - } - /// Create a new [GroupByAggArgumentSource] that retrieves the values directly from an already - /// populated register or registers. - pub fn new_from_registers(src_reg_start: usize, aggregate: &'a Aggregate) -> Self { - Self::Register { - src_reg_start, - aggregate, - } - } - - pub fn aggregate(&self) -> &Aggregate { - match self { - GroupByAggArgumentSource::PseudoCursor { aggregate, .. } => aggregate, - GroupByAggArgumentSource::Register { aggregate, .. } => aggregate, - } - } - - pub fn agg_func(&self) -> &AggFunc { - match self { - GroupByAggArgumentSource::PseudoCursor { aggregate, .. } => &aggregate.func, - GroupByAggArgumentSource::Register { aggregate, .. } => &aggregate.func, - } - } - pub fn args(&self) -> &[ast::Expr] { - match self { - GroupByAggArgumentSource::PseudoCursor { aggregate, .. } => &aggregate.args, - GroupByAggArgumentSource::Register { aggregate, .. } => &aggregate.args, - } - } - pub fn num_args(&self) -> usize { - match self { - GroupByAggArgumentSource::PseudoCursor { aggregate, .. } => aggregate.args.len(), - GroupByAggArgumentSource::Register { aggregate, .. } => aggregate.args.len(), - } - } - /// Read the value of an aggregate function argument either from sorter data or directly from a register. - pub fn translate(&self, program: &mut ProgramBuilder, arg_idx: usize) -> Result { - match self { - GroupByAggArgumentSource::PseudoCursor { - cursor_id, - col_start, - dest_reg_start, - .. - } => { - program.emit_column_or_rowid( - *cursor_id, - *col_start + arg_idx, - dest_reg_start + arg_idx, - ); - Ok(dest_reg_start + arg_idx) - } - GroupByAggArgumentSource::Register { - src_reg_start: start_reg, - .. - } => Ok(*start_reg + arg_idx), - } - } -} - /// Emits bytecode for processing a single GROUP BY group. pub fn group_by_process_single_group( program: &mut ProgramBuilder, @@ -597,18 +497,16 @@ pub fn group_by_process_single_group( .expect("aggregate registers must be initialized"); let agg_result_reg = start_reg + i; let agg_arg_source = match &row_source { - GroupByRowSource::Sorter { pseudo_cursor, .. } => { - GroupByAggArgumentSource::new_from_cursor( - program, - *pseudo_cursor, - cursor_index + offset, - agg, - ) - } + GroupByRowSource::Sorter { pseudo_cursor, .. } => AggArgumentSource::new_from_cursor( + program, + *pseudo_cursor, + cursor_index + offset, + agg, + ), GroupByRowSource::MainLoop { start_reg_src, .. } => { // Aggregation arguments are always placed in the registers that follow any scalars. let start_reg_aggs = start_reg_src + t_ctx.non_aggregate_expressions.len(); - GroupByAggArgumentSource::new_from_registers(start_reg_aggs + offset, agg) + AggArgumentSource::new_from_registers(start_reg_aggs + offset, agg) } }; translate_aggregation_step_groupby( @@ -911,7 +809,7 @@ pub fn group_by_emit_row_phase<'a>( pub fn translate_aggregation_step_groupby( program: &mut ProgramBuilder, referenced_tables: &TableReferences, - agg_arg_source: GroupByAggArgumentSource, + agg_arg_source: AggArgumentSource, target_register: usize, resolver: &Resolver, ) -> Result { From 6f1cd17fcf7c99d2e1d2424f448bc63eab1cf750 Mon Sep 17 00:00:00 2001 From: Piotr Rzysko Date: Sun, 31 Aug 2025 11:07:54 +0200 Subject: [PATCH 6/6] Consolidate methods emitting AggStep --- core/translate/aggregation.rs | 162 ++++++++++----------- core/translate/group_by.rs | 262 +--------------------------------- core/translate/main_loop.rs | 4 +- 3 files changed, 89 insertions(+), 339 deletions(-) diff --git a/core/translate/aggregation.rs b/core/translate/aggregation.rs index 2e43c3ea9..7a1de9776 100644 --- a/core/translate/aggregation.rs +++ b/core/translate/aggregation.rs @@ -61,7 +61,7 @@ pub fn emit_ungrouped_aggregation<'a>( Ok(()) } -pub fn emit_collseq_if_needed( +fn emit_collseq_if_needed( program: &mut ProgramBuilder, referenced_tables: &TableReferences, expr: &ast::Expr, @@ -134,6 +134,7 @@ pub fn handle_distinct(program: &mut ProgramBuilder, agg: &Aggregate, agg_arg_re /// the data from the sorter. /// * In grouped cases where no sorting is required, arguments are retrieved directly /// from registers allocated in the main loop. +/// * In ungrouped cases, arguments are computed directly from the `args` expressions. pub enum AggArgumentSource<'a> { /// The aggregate function arguments are retrieved from a pseudo cursor /// which reads from the GROUP BY sorter. @@ -149,6 +150,8 @@ pub enum AggArgumentSource<'a> { src_reg_start: usize, aggregate: &'a Aggregate, }, + /// The aggregate function arguments are retrieved by evaluating expressions. + Expression { aggregate: &'a Aggregate }, } impl<'a> AggArgumentSource<'a> { @@ -176,10 +179,16 @@ impl<'a> AggArgumentSource<'a> { } } + /// Create a new [AggArgumentSource] that retrieves the values by evaluating `args` expressions. + pub fn new_from_expression(aggregate: &'a Aggregate) -> Self { + Self::Expression { aggregate } + } + pub fn aggregate(&self) -> &Aggregate { match self { AggArgumentSource::PseudoCursor { aggregate, .. } => aggregate, AggArgumentSource::Register { aggregate, .. } => aggregate, + AggArgumentSource::Expression { aggregate } => aggregate, } } @@ -187,22 +196,31 @@ impl<'a> AggArgumentSource<'a> { match self { AggArgumentSource::PseudoCursor { aggregate, .. } => &aggregate.func, AggArgumentSource::Register { aggregate, .. } => &aggregate.func, + AggArgumentSource::Expression { aggregate } => &aggregate.func, } } pub fn args(&self) -> &[ast::Expr] { match self { AggArgumentSource::PseudoCursor { aggregate, .. } => &aggregate.args, AggArgumentSource::Register { aggregate, .. } => &aggregate.args, + AggArgumentSource::Expression { aggregate } => &aggregate.args, } } pub fn num_args(&self) -> usize { match self { AggArgumentSource::PseudoCursor { aggregate, .. } => aggregate.args.len(), AggArgumentSource::Register { aggregate, .. } => aggregate.args.len(), + AggArgumentSource::Expression { aggregate } => aggregate.args.len(), } } /// Read the value of an aggregate function argument - pub fn translate(&self, program: &mut ProgramBuilder, arg_idx: usize) -> Result { + pub fn translate( + &self, + program: &mut ProgramBuilder, + referenced_tables: &TableReferences, + resolver: &Resolver, + arg_idx: usize, + ) -> Result { match self { AggArgumentSource::PseudoCursor { cursor_id, @@ -221,31 +239,47 @@ impl<'a> AggArgumentSource<'a> { src_reg_start: start_reg, .. } => Ok(*start_reg + arg_idx), + AggArgumentSource::Expression { aggregate } => { + let dest_reg = program.alloc_register(); + translate_expr( + program, + Some(referenced_tables), + &aggregate.args[arg_idx], + dest_reg, + resolver, + ) + } } } } /// Emits the bytecode for processing an aggregate step. -/// E.g. in `SELECT SUM(price) FROM t`, 'price' is evaluated for every row, and the result is added to the accumulator. /// -/// This is distinct from the final step, which is called after the main loop has finished processing +/// This is distinct from the final step, which is called after a single group has been entirely accumulated, /// and the actual result value of the aggregation is materialized. +/// +/// Ungrouped aggregation is a special case of grouped aggregation that involves a single group. +/// +/// Examples: +/// * In `SELECT SUM(price) FROM t`, `price` is evaluated for each row and added to the accumulator. +/// * In `SELECT product_category, SUM(price) FROM t GROUP BY product_category`, `price` is evaluated for +/// each row in the group and added to that group’s accumulator. pub fn translate_aggregation_step( program: &mut ProgramBuilder, referenced_tables: &TableReferences, - agg: &Aggregate, + agg_arg_source: AggArgumentSource, target_register: usize, resolver: &Resolver, ) -> Result { - let dest = match agg.func { + let num_args = agg_arg_source.num_args(); + let func = agg_arg_source.agg_func(); + let dest = match func { AggFunc::Avg => { - if agg.args.len() != 1 { + if num_args != 1 { crate::bail_parse_error!("avg bad number of arguments"); } - let expr = &agg.args[0]; - let expr_reg = program.alloc_register(); - let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?; - handle_distinct(program, agg, expr_reg); + let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, @@ -255,18 +289,16 @@ pub fn translate_aggregation_step( target_register } AggFunc::Count | AggFunc::Count0 => { - if agg.args.len() != 1 { + if num_args != 1 { crate::bail_parse_error!("count bad number of arguments"); } - let expr = &agg.args[0]; - let expr_reg = program.alloc_register(); - let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?; - handle_distinct(program, agg, expr_reg); + let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, delimiter: 0, - func: if matches!(agg.func, AggFunc::Count0) { + func: if matches!(func, AggFunc::Count0) { AggFunc::Count0 } else { AggFunc::Count @@ -275,18 +307,16 @@ pub fn translate_aggregation_step( target_register } AggFunc::GroupConcat => { - if agg.args.len() != 1 && agg.args.len() != 2 { + if num_args != 1 && num_args != 2 { crate::bail_parse_error!("group_concat bad number of arguments"); } - let expr_reg = program.alloc_register(); let delimiter_reg = program.alloc_register(); - let expr = &agg.args[0]; let delimiter_expr: ast::Expr; - if agg.args.len() == 2 { - match &agg.args[1] { + if num_args == 2 { + match &agg_arg_source.args()[1] { arg @ ast::Expr::Column { .. } => { delimiter_expr = arg.clone(); } @@ -299,8 +329,8 @@ pub fn translate_aggregation_step( delimiter_expr = ast::Expr::Literal(ast::Literal::String(String::from("\",\""))); } - translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?; - handle_distinct(program, agg, expr_reg); + let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); translate_expr( program, Some(referenced_tables), @@ -319,13 +349,12 @@ pub fn translate_aggregation_step( target_register } AggFunc::Max => { - if agg.args.len() != 1 { + if num_args != 1 { crate::bail_parse_error!("max bad number of arguments"); } - let expr = &agg.args[0]; - let expr_reg = program.alloc_register(); - let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?; - handle_distinct(program, agg, expr_reg); + let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); + let expr = &agg_arg_source.args()[0]; emit_collseq_if_needed(program, referenced_tables, expr); program.emit_insn(Insn::AggStep { acc_reg: target_register, @@ -336,13 +365,12 @@ pub fn translate_aggregation_step( target_register } AggFunc::Min => { - if agg.args.len() != 1 { + if num_args != 1 { crate::bail_parse_error!("min bad number of arguments"); } - let expr = &agg.args[0]; - let expr_reg = program.alloc_register(); - let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?; - handle_distinct(program, agg, expr_reg); + let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); + let expr = &agg_arg_source.args()[0]; emit_collseq_if_needed(program, referenced_tables, expr); program.emit_insn(Insn::AggStep { acc_reg: target_register, @@ -354,23 +382,12 @@ pub fn translate_aggregation_step( } #[cfg(feature = "json")] AggFunc::JsonGroupObject | AggFunc::JsonbGroupObject => { - if agg.args.len() != 2 { + if num_args != 2 { crate::bail_parse_error!("max bad number of arguments"); } - let expr = &agg.args[0]; - let expr_reg = program.alloc_register(); - let value_expr = &agg.args[1]; - let value_reg = program.alloc_register(); - - let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?; - handle_distinct(program, agg, expr_reg); - let _ = translate_expr( - program, - Some(referenced_tables), - value_expr, - value_reg, - resolver, - )?; + let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); + let value_reg = agg_arg_source.translate(program, referenced_tables, resolver, 1)?; program.emit_insn(Insn::AggStep { acc_reg: target_register, @@ -382,13 +399,11 @@ pub fn translate_aggregation_step( } #[cfg(feature = "json")] AggFunc::JsonGroupArray | AggFunc::JsonbGroupArray => { - if agg.args.len() != 1 { + if num_args != 1 { crate::bail_parse_error!("max bad number of arguments"); } - let expr = &agg.args[0]; - let expr_reg = program.alloc_register(); - let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?; - handle_distinct(program, agg, expr_reg); + let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, @@ -398,15 +413,13 @@ pub fn translate_aggregation_step( target_register } AggFunc::StringAgg => { - if agg.args.len() != 2 { + if num_args != 2 { crate::bail_parse_error!("string_agg bad number of arguments"); } - let expr_reg = program.alloc_register(); let delimiter_reg = program.alloc_register(); - let expr = &agg.args[0]; - let delimiter_expr = match &agg.args[1] { + let delimiter_expr = match &agg_arg_source.args()[1] { arg @ ast::Expr::Column { .. } => arg.clone(), ast::Expr::Literal(ast::Literal::String(s)) => { ast::Expr::Literal(ast::Literal::String(s.to_string())) @@ -414,7 +427,7 @@ pub fn translate_aggregation_step( _ => crate::bail_parse_error!("Incorrect delimiter parameter"), }; - translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?; + let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?; translate_expr( program, Some(referenced_tables), @@ -433,13 +446,11 @@ pub fn translate_aggregation_step( target_register } AggFunc::Sum => { - if agg.args.len() != 1 { + if num_args != 1 { crate::bail_parse_error!("sum bad number of arguments"); } - let expr = &agg.args[0]; - let expr_reg = program.alloc_register(); - let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?; - handle_distinct(program, agg, expr_reg); + let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, @@ -449,13 +460,11 @@ pub fn translate_aggregation_step( target_register } AggFunc::Total => { - if agg.args.len() != 1 { + if num_args != 1 { crate::bail_parse_error!("total bad number of arguments"); } - let expr = &agg.args[0]; - let expr_reg = program.alloc_register(); - let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?; - handle_distinct(program, agg, expr_reg); + let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, @@ -465,31 +474,24 @@ pub fn translate_aggregation_step( target_register } AggFunc::External(ref func) => { - let expr_reg = program.alloc_register(); let argc = func.agg_args().map_err(|_| { LimboError::ExtensionError( "External aggregate function called with wrong number of arguments".to_string(), ) })?; - if argc != agg.args.len() { + if argc != num_args { crate::bail_parse_error!( "External aggregate function called with wrong number of arguments" ); } + let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?; for i in 0..argc { if i != 0 { - let _ = program.alloc_register(); + let _ = agg_arg_source.translate(program, referenced_tables, resolver, i)?; } - let _ = translate_expr( - program, - Some(referenced_tables), - &agg.args[i], - expr_reg + i, - resolver, - )?; // invariant: distinct aggregates are only supported for single-argument functions if argc == 1 { - handle_distinct(program, agg, expr_reg + i); + handle_distinct(program, agg_arg_source.aggregate(), expr_reg + i); } } program.emit_insn(Insn::AggStep { diff --git a/core/translate/group_by.rs b/core/translate/group_by.rs index 715bb5c93..37c05cc45 100644 --- a/core/translate/group_by.rs +++ b/core/translate/group_by.rs @@ -1,18 +1,16 @@ use turso_parser::ast; use super::{ - aggregation::handle_distinct, - emitter::{Resolver, TranslateCtx}, + emitter::TranslateCtx, expr::{translate_condition_expr, translate_expr, ConditionMetadata}, order_by::order_by_sorter_insert, - plan::{Distinctness, GroupBy, SelectPlan, TableReferences}, + plan::{Distinctness, GroupBy, SelectPlan}, result_row::emit_select_result, }; -use crate::translate::aggregation::{emit_collseq_if_needed, AggArgumentSource}; +use crate::translate::aggregation::{translate_aggregation_step, AggArgumentSource}; use crate::translate::expr::{walk_expr, WalkControl}; use crate::translate::plan::ResultSetColumn; use crate::{ - function::AggFunc, schema::PseudoCursorType, translate::collate::CollationSeq, util::exprs_are_equivalent, @@ -21,7 +19,7 @@ use crate::{ insn::Insn, BranchOffset, }, - LimboError, Result, + Result, }; /// Labels needed for various jumps in GROUP BY handling. @@ -509,7 +507,7 @@ pub fn group_by_process_single_group( AggArgumentSource::new_from_registers(start_reg_aggs + offset, agg) } }; - translate_aggregation_step_groupby( + translate_aggregation_step( program, &plan.table_references, agg_arg_source, @@ -799,253 +797,3 @@ pub fn group_by_emit_row_phase<'a>( program.preassign_label_to_next_insn(labels.label_group_by_end); Ok(()) } - -/// Emits the bytecode for processing an aggregate step within a GROUP BY clause. -/// Eg. in `SELECT product_category, SUM(price) FROM t GROUP BY line_item`, 'price' is evaluated for every row -/// where the 'product_category' is the same, and the result is added to the accumulator for that category. -/// -/// This is distinct from the final step, which is called after a single group has been entirely accumulated, -/// and the actual result value of the aggregation is materialized. -pub fn translate_aggregation_step_groupby( - program: &mut ProgramBuilder, - referenced_tables: &TableReferences, - agg_arg_source: AggArgumentSource, - target_register: usize, - resolver: &Resolver, -) -> Result { - let num_args = agg_arg_source.num_args(); - let dest = match agg_arg_source.agg_func() { - AggFunc::Avg => { - if num_args != 1 { - crate::bail_parse_error!("avg bad number of arguments"); - } - let expr_reg = agg_arg_source.translate(program, 0)?; - handle_distinct(program, agg_arg_source.aggregate(), expr_reg); - program.emit_insn(Insn::AggStep { - acc_reg: target_register, - col: expr_reg, - delimiter: 0, - func: AggFunc::Avg, - }); - target_register - } - AggFunc::Count | AggFunc::Count0 => { - if num_args != 1 { - crate::bail_parse_error!("count bad number of arguments"); - } - let expr_reg = agg_arg_source.translate(program, 0)?; - handle_distinct(program, agg_arg_source.aggregate(), expr_reg); - program.emit_insn(Insn::AggStep { - acc_reg: target_register, - col: expr_reg, - delimiter: 0, - func: if matches!(agg_arg_source.agg_func(), AggFunc::Count0) { - AggFunc::Count0 - } else { - AggFunc::Count - }, - }); - target_register - } - AggFunc::GroupConcat => { - let num_args = agg_arg_source.num_args(); - if num_args != 1 && num_args != 2 { - crate::bail_parse_error!("group_concat bad number of arguments"); - } - - let delimiter_reg = program.alloc_register(); - - let delimiter_expr: ast::Expr; - - if num_args == 2 { - match &agg_arg_source.args()[1] { - arg @ ast::Expr::Column { .. } => { - delimiter_expr = arg.clone(); - } - ast::Expr::Literal(ast::Literal::String(s)) => { - delimiter_expr = ast::Expr::Literal(ast::Literal::String(s.to_string())); - } - _ => crate::bail_parse_error!("Incorrect delimiter parameter"), - }; - } else { - delimiter_expr = ast::Expr::Literal(ast::Literal::String(String::from("\",\""))); - } - - let expr_reg = agg_arg_source.translate(program, 0)?; - handle_distinct(program, agg_arg_source.aggregate(), expr_reg); - translate_expr( - program, - Some(referenced_tables), - &delimiter_expr, - delimiter_reg, - resolver, - )?; - - program.emit_insn(Insn::AggStep { - acc_reg: target_register, - col: expr_reg, - delimiter: delimiter_reg, - func: AggFunc::GroupConcat, - }); - - target_register - } - AggFunc::Max => { - if num_args != 1 { - crate::bail_parse_error!("max bad number of arguments"); - } - let expr_reg = agg_arg_source.translate(program, 0)?; - handle_distinct(program, agg_arg_source.aggregate(), expr_reg); - let expr = &agg_arg_source.args()[0]; - emit_collseq_if_needed(program, referenced_tables, expr); - program.emit_insn(Insn::AggStep { - acc_reg: target_register, - col: expr_reg, - delimiter: 0, - func: AggFunc::Max, - }); - target_register - } - AggFunc::Min => { - if num_args != 1 { - crate::bail_parse_error!("min bad number of arguments"); - } - let expr_reg = agg_arg_source.translate(program, 0)?; - handle_distinct(program, agg_arg_source.aggregate(), expr_reg); - let expr = &agg_arg_source.args()[0]; - emit_collseq_if_needed(program, referenced_tables, expr); - program.emit_insn(Insn::AggStep { - acc_reg: target_register, - col: expr_reg, - delimiter: 0, - func: AggFunc::Min, - }); - target_register - } - #[cfg(feature = "json")] - AggFunc::JsonGroupArray | AggFunc::JsonbGroupArray => { - if num_args != 1 { - crate::bail_parse_error!("min bad number of arguments"); - } - let expr_reg = agg_arg_source.translate(program, 0)?; - handle_distinct(program, agg_arg_source.aggregate(), expr_reg); - program.emit_insn(Insn::AggStep { - acc_reg: target_register, - col: expr_reg, - delimiter: 0, - func: AggFunc::JsonGroupArray, - }); - target_register - } - #[cfg(feature = "json")] - AggFunc::JsonGroupObject | AggFunc::JsonbGroupObject => { - if num_args != 2 { - crate::bail_parse_error!("max bad number of arguments"); - } - - let expr_reg = agg_arg_source.translate(program, 0)?; - handle_distinct(program, agg_arg_source.aggregate(), expr_reg); - let value_reg = agg_arg_source.translate(program, 1)?; - - program.emit_insn(Insn::AggStep { - acc_reg: target_register, - col: expr_reg, - delimiter: value_reg, - func: AggFunc::JsonGroupObject, - }); - target_register - } - AggFunc::StringAgg => { - if num_args != 2 { - crate::bail_parse_error!("string_agg bad number of arguments"); - } - - let delimiter_reg = program.alloc_register(); - - let delimiter_expr = match &agg_arg_source.args()[1] { - arg @ ast::Expr::Column { .. } => arg.clone(), - ast::Expr::Literal(ast::Literal::String(s)) => { - ast::Expr::Literal(ast::Literal::String(s.to_string())) - } - _ => crate::bail_parse_error!("Incorrect delimiter parameter"), - }; - - let expr_reg = agg_arg_source.translate(program, 0)?; - handle_distinct(program, agg_arg_source.aggregate(), expr_reg); - translate_expr( - program, - Some(referenced_tables), - &delimiter_expr, - delimiter_reg, - resolver, - )?; - - program.emit_insn(Insn::AggStep { - acc_reg: target_register, - col: expr_reg, - delimiter: delimiter_reg, - func: AggFunc::StringAgg, - }); - - target_register - } - AggFunc::Sum => { - if num_args != 1 { - crate::bail_parse_error!("sum bad number of arguments"); - } - let expr_reg = agg_arg_source.translate(program, 0)?; - handle_distinct(program, agg_arg_source.aggregate(), expr_reg); - program.emit_insn(Insn::AggStep { - acc_reg: target_register, - col: expr_reg, - delimiter: 0, - func: AggFunc::Sum, - }); - target_register - } - AggFunc::Total => { - if num_args != 1 { - crate::bail_parse_error!("total bad number of arguments"); - } - let expr_reg = agg_arg_source.translate(program, 0)?; - handle_distinct(program, agg_arg_source.aggregate(), expr_reg); - program.emit_insn(Insn::AggStep { - acc_reg: target_register, - col: expr_reg, - delimiter: 0, - func: AggFunc::Total, - }); - target_register - } - AggFunc::External(ref func) => { - let argc = func.agg_args().map_err(|_| { - LimboError::ExtensionError( - "External aggregate function called with wrong number of arguments".to_string(), - ) - })?; - if argc != num_args { - crate::bail_parse_error!( - "External aggregate function called with wrong number of arguments" - ); - } - let expr_reg = agg_arg_source.translate(program, 0)?; - for i in 0..argc { - if i != 0 { - let _ = agg_arg_source.translate(program, i)?; - } - // invariant: distinct aggregates are only supported for single-argument functions - if argc == 1 { - handle_distinct(program, agg_arg_source.aggregate(), expr_reg + i); - } - } - program.emit_insn(Insn::AggStep { - acc_reg: target_register, - col: expr_reg, - delimiter: 0, - func: AggFunc::External(func.clone()), - }); - target_register - } - }; - Ok(dest) -} diff --git a/core/translate/main_loop.rs b/core/translate/main_loop.rs index 2617c86d8..06d801705 100644 --- a/core/translate/main_loop.rs +++ b/core/translate/main_loop.rs @@ -19,7 +19,7 @@ use crate::{ }; use super::{ - aggregation::translate_aggregation_step, + aggregation::{translate_aggregation_step, AggArgumentSource}, emitter::{OperationMode, TranslateCtx}, expr::{ translate_condition_expr, translate_expr, translate_expr_no_constant_opt, @@ -868,7 +868,7 @@ fn emit_loop_source( translate_aggregation_step( program, &plan.table_references, - agg, + AggArgumentSource::new_from_expression(agg), reg, &t_ctx.resolver, )?;