From 51c75c6014dd2ffa7f02e47f789cc06e37c1bd18 Mon Sep 17 00:00:00 2001 From: Jussi Saurio Date: Sat, 17 May 2025 15:00:31 +0300 Subject: [PATCH] Support distinct aggregates in GROUP BY --- core/translate/aggregation.rs | 2 +- core/translate/emitter.rs | 3 +++ core/translate/group_by.rs | 47 ++++++++++++++++++++++++++++++++++- core/translate/main_loop.rs | 15 +++++++---- testing/groupby.test | 11 +++++++- 5 files changed, 70 insertions(+), 8 deletions(-) diff --git a/core/translate/aggregation.rs b/core/translate/aggregation.rs index dee70ebca..3e5579d1d 100644 --- a/core/translate/aggregation.rs +++ b/core/translate/aggregation.rs @@ -63,7 +63,7 @@ pub fn emit_ungrouped_aggregation<'a>( /// Emits the bytecode for handling duplicates in a distinct aggregate. /// This is used in both GROUP BY and non-GROUP BY aggregations to jump over /// the AggStep that would otherwise accumulate the same value multiple times. -fn handle_distinct(program: &mut ProgramBuilder, agg: &Aggregate, agg_arg_reg: usize) { +pub fn handle_distinct(program: &mut ProgramBuilder, agg: &Aggregate, agg_arg_reg: usize) { let AggDistinctness::Distinct { ctx } = &agg.distinctness else { return; }; diff --git a/core/translate/emitter.rs b/core/translate/emitter.rs index dd7025544..027865ffa 100644 --- a/core/translate/emitter.rs +++ b/core/translate/emitter.rs @@ -288,6 +288,7 @@ pub fn emit_query<'a>( t_ctx, &plan.table_references, &mut plan.aggregates, + plan.group_by.as_ref(), OperationMode::SELECT, )?; @@ -398,6 +399,7 @@ fn emit_program_for_delete( &mut t_ctx, &plan.table_references, &mut [], + None, OperationMode::DELETE, )?; @@ -591,6 +593,7 @@ fn emit_program_for_update( &mut t_ctx, &plan.table_references, &mut [], + None, OperationMode::UPDATE, )?; // Open indexes for update. diff --git a/core/translate/group_by.rs b/core/translate/group_by.rs index 9bb10a57e..ebfddaf9f 100644 --- a/core/translate/group_by.rs +++ b/core/translate/group_by.rs @@ -15,10 +15,11 @@ use crate::{ }; use super::{ + aggregation::handle_distinct, emitter::{Resolver, TranslateCtx}, expr::{translate_condition_expr, translate_expr, ConditionMetadata}, order_by::order_by_sorter_insert, - plan::{Aggregate, GroupBy, SelectPlan, TableReference}, + plan::{AggDistinctness, Aggregate, GroupBy, SelectPlan, TableReference}, result_row::emit_select_result, }; @@ -366,6 +367,14 @@ impl<'a> GroupByAggArgumentSource<'a> { aggregate, } } + + pub fn aggregate(&self) -> &Aggregate { + match self { + GroupByAggArgumentSource::PseudoCursor { aggregate, .. } => aggregate, + GroupByAggArgumentSource::Register { aggregate, .. } => aggregate, + } + } + pub fn agg_func(&self) -> &AggFunc { match self { GroupByAggArgumentSource::PseudoCursor { aggregate, .. } => &aggregate.func, @@ -535,6 +544,12 @@ pub fn group_by_process_single_group( agg_result_reg, &t_ctx.resolver, )?; + if let AggDistinctness::Distinct { ctx } = &agg.distinctness { + let ctx = ctx + .as_ref() + .expect("distinct aggregate context not populated"); + program.preassign_label_to_next_insn(ctx.label_on_conflict); + } offset += agg.args.len(); } @@ -873,6 +888,26 @@ pub fn group_by_emit_row_phase<'a>( dest_end: Some(start_reg + plan.group_by_sorter_column_count() - 1), }); + // Reopen ephemeral indexes for distinct aggregates (effectively clearing them). + plan.aggregates + .iter() + .filter_map(|agg| { + if let AggDistinctness::Distinct { ctx } = &agg.distinctness { + Some(ctx) + } else { + None + } + }) + .for_each(|ctx| { + let ctx = ctx + .as_ref() + .expect("distinct aggregate context not populated"); + program.emit_insn(Insn::OpenEphemeral { + cursor_id: ctx.cursor_id, + is_table: false, + }); + }); + program.emit_insn(Insn::Integer { value: 0, dest: registers.reg_data_in_acc_flag, @@ -904,6 +939,7 @@ pub fn translate_aggregation_step_groupby( crate::bail_parse_error!("avg bad number of arguments"); } let expr_reg = agg_arg_source.translate(program, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, @@ -914,6 +950,7 @@ pub fn translate_aggregation_step_groupby( } AggFunc::Count | AggFunc::Count0 => { let expr_reg = agg_arg_source.translate(program, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, @@ -951,6 +988,7 @@ pub fn translate_aggregation_step_groupby( } let expr_reg = agg_arg_source.translate(program, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); translate_expr( program, Some(referenced_tables), @@ -973,6 +1011,7 @@ pub fn translate_aggregation_step_groupby( crate::bail_parse_error!("max bad number of arguments"); } let expr_reg = agg_arg_source.translate(program, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, @@ -986,6 +1025,7 @@ pub fn translate_aggregation_step_groupby( crate::bail_parse_error!("min bad number of arguments"); } let expr_reg = agg_arg_source.translate(program, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, @@ -1000,6 +1040,7 @@ pub fn translate_aggregation_step_groupby( crate::bail_parse_error!("min bad number of arguments"); } let expr_reg = agg_arg_source.translate(program, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, @@ -1015,6 +1056,7 @@ pub fn translate_aggregation_step_groupby( } let expr_reg = agg_arg_source.translate(program, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); let value_reg = agg_arg_source.translate(program, 1)?; program.emit_insn(Insn::AggStep { @@ -1041,6 +1083,7 @@ pub fn translate_aggregation_step_groupby( }; let expr_reg = agg_arg_source.translate(program, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); translate_expr( program, Some(referenced_tables), @@ -1063,6 +1106,7 @@ pub fn translate_aggregation_step_groupby( crate::bail_parse_error!("sum bad number of arguments"); } let expr_reg = agg_arg_source.translate(program, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, @@ -1076,6 +1120,7 @@ pub fn translate_aggregation_step_groupby( crate::bail_parse_error!("total bad number of arguments"); } let expr_reg = agg_arg_source.translate(program, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, diff --git a/core/translate/main_loop.rs b/core/translate/main_loop.rs index bfce19cf2..2cefd9254 100644 --- a/core/translate/main_loop.rs +++ b/core/translate/main_loop.rs @@ -29,7 +29,7 @@ use super::{ optimizer::Optimizable, order_by::{order_by_sorter_insert, sorter_insert}, plan::{ - convert_where_to_vtab_constraint, Aggregate, IterationDirection, JoinOrderMember, + convert_where_to_vtab_constraint, Aggregate, GroupBy, IterationDirection, JoinOrderMember, Operation, Search, SeekDef, SelectPlan, SelectQueryType, TableReference, WhereTerm, }, }; @@ -72,6 +72,7 @@ pub fn init_loop( t_ctx: &mut TranslateCtx, tables: &[TableReference], aggregates: &mut [Aggregate], + group_by: Option<&GroupBy>, mode: OperationMode, ) -> Result<()> { assert!( @@ -105,10 +106,14 @@ pub fn init_loop( Some(index_name.clone()), CursorType::BTreeIndex(index.clone()), ); - program.emit_insn(Insn::OpenEphemeral { - cursor_id, - is_table: false, - }); + if group_by.is_none() { + // In GROUP BY, the ephemeral index is reinitialized for every group + // in the clear accumulator subroutine, so we only do it here if there is no GROUP BY. + program.emit_insn(Insn::OpenEphemeral { + cursor_id, + is_table: false, + }); + } agg.distinctness = AggDistinctness::Distinct { ctx: Some(DistinctAggCtx { cursor_id, diff --git a/testing/groupby.test b/testing/groupby.test index 70141be0a..1012ed658 100644 --- a/testing/groupby.test +++ b/testing/groupby.test @@ -197,4 +197,13 @@ do_execsql_test group_by_no_sorting_required { select age, count(1) from users group by age limit 3; } {1|112 2|113 -3|97} \ No newline at end of file +3|97} + +do_execsql_test distinct_agg_functions { + select first_name, sum(distinct age), count(distinct age), avg(distinct age) + from users + group by 1 + limit 3; +} {Aaron|1769|33|53.6060606060606 +Abigail|833|15|55.5333333333333 +Adam|1517|30|50.5666666666667} \ No newline at end of file