diff --git a/core/translate/emitter.rs b/core/translate/emitter.rs index ef3c2b9ad..0eaeef58f 100644 --- a/core/translate/emitter.rs +++ b/core/translate/emitter.rs @@ -262,8 +262,8 @@ pub fn emit_query<'a>( init_order_by(program, t_ctx, order_by)?; } - if let Some(ref mut group_by) = plan.group_by { - init_group_by(program, t_ctx, group_by, &plan.aggregates)?; + if let Some(ref group_by) = plan.group_by { + init_group_by(program, t_ctx, group_by, &plan)?; } init_loop( program, diff --git a/core/translate/group_by.rs b/core/translate/group_by.rs index 13d860f16..86d6087e2 100644 --- a/core/translate/group_by.rs +++ b/core/translate/group_by.rs @@ -50,15 +50,22 @@ pub fn init_group_by( program: &mut ProgramBuilder, t_ctx: &mut TranslateCtx, group_by: &GroupBy, - aggregates: &[Aggregate], + plan: &SelectPlan, ) -> Result<()> { - let num_aggs = aggregates.len(); + let num_aggs = plan.aggregates.len(); + + // Calculate this count only once + let non_aggregate_count = plan + .result_columns + .iter() + .filter(|rc| !rc.contains_aggregates) + .count(); let sort_cursor = program.alloc_cursor_id(None, CursorType::Sorter); let reg_abort_flag = program.alloc_register(); let reg_group_exprs_cmp = program.alloc_registers(group_by.exprs.len()); - let reg_group_exprs_acc = program.alloc_registers(group_by.exprs.len()); + let reg_group_exprs_acc = program.alloc_registers(non_aggregate_count); let reg_agg_exprs_start = program.alloc_registers(num_aggs); let reg_sorter_key = program.alloc_register(); @@ -71,7 +78,7 @@ pub fn init_group_by( } program.emit_insn(Insn::SorterOpen { cursor_id: sort_cursor, - columns: aggregates.len() + group_by.exprs.len(), + columns: non_aggregate_count + plan.aggregates.len(), order: Record::new(order), }); @@ -156,14 +163,23 @@ pub fn emit_group_by<'a>( let group_by = plan.group_by.as_ref().unwrap(); + // Calculate these values once + let non_aggregate_count = plan + .result_columns + .iter() + .filter(|rc| !rc.contains_aggregates) + .count(); + + 
let agg_args_count = plan + .aggregates + .iter() + .map(|agg| agg.args.len()) + .sum::<usize>(); + // all group by columns and all arguments of agg functions are in the sorter. // the sort keys are the group by columns (the aggregation within groups is done based on how long the sort keys remain the same) - let sorter_column_count = group_by.exprs.len() - + plan - .aggregates - .iter() - .map(|agg| agg.args.len()) - .sum::<usize>(); + let sorter_column_count = non_aggregate_count + agg_args_count; + // sorter column names do not matter let ty = crate::schema::Type::Null; let pseudo_columns = (0..sorter_column_count) @@ -238,11 +254,6 @@ pub fn emit_group_by<'a>( }); // New group, move current group by columns into the comparison register - program.emit_insn(Insn::Move { - source_reg: groups_start_reg, - dest_reg: reg_group_exprs_cmp, - count: group_by.exprs.len(), - }); program.add_comment( program.offset(), @@ -253,6 +264,12 @@ pub fn emit_group_by<'a>( return_reg: reg_subrtn_acc_output_return_offset, }); + program.emit_insn(Insn::Move { + source_reg: groups_start_reg, + dest_reg: reg_group_exprs_cmp, + count: group_by.exprs.len(), + }); + program.add_comment(program.offset(), "check abort flag"); program.emit_insn(Insn::IfPos { reg: reg_abort_flag, @@ -269,7 +286,7 @@ pub fn emit_group_by<'a>( // Accumulate the values into the aggregations program.resolve_label(agg_step_label, program.offset()); let start_reg = t_ctx.reg_agg_start.unwrap(); - let mut cursor_index = group_by.exprs.len(); + let mut cursor_index = non_aggregate_count; for (i, agg) in plan.aggregates.iter().enumerate() { let agg_result_reg = start_reg + i; translate_aggregation_step_groupby( @@ -296,7 +313,7 @@ pub fn emit_group_by<'a>( }); // Read the group by columns for a finished group - for i in 0..group_by.exprs.len() { + for i in 0..non_aggregate_count { let key_reg = reg_group_exprs_acc + i; let sorter_column_index = i; program.emit_insn(Insn::Column { @@ -363,6 +380,12 @@ pub fn emit_group_by<'a>( }); }
+ // Cache expressions we need multiple times + let filtered_results = plan + .result_columns + .iter() + .filter(|rc| !rc.contains_aggregates) + .collect::<Vec<_>>(); // we now have the group by columns in registers (group_exprs_start_register..group_exprs_start_register + group_by.len() - 1) // and the agg results in (agg_start_reg..agg_start_reg + aggregates.len() - 1) // we need to call translate_expr on each result column, but replace the expr with a register copy in case any part of the @@ -373,6 +396,24 @@ pub fn emit_group_by<'a>( .expr_to_reg_cache .push((expr, reg_group_exprs_acc + i)); } + + // Offset for the next expressions after group_by + let mut offset = group_by.exprs.len(); + + for rc in filtered_results.iter() { + let expr = &rc.expr; + + // skip cols that are already in group by + if !matches!(expr, ast::Expr::Column { .. }) + || !is_column_in_group_by(expr, &group_by.exprs) + { + t_ctx + .resolver + .expr_to_reg_cache + .push((expr, reg_group_exprs_acc + offset)); + offset += 1; + } + } for (i, agg) in plan.aggregates.iter().enumerate() { t_ctx .resolver @@ -420,7 +461,7 @@ pub fn emit_group_by<'a>( let start_reg = reg_group_exprs_acc; program.emit_insn(Insn::Null { dest: start_reg, - dest_end: Some(start_reg + group_by.exprs.len() + plan.aggregates.len() - 1), + dest_end: Some(start_reg + non_aggregate_count + plan.aggregates.len() - 1), }); program.emit_insn(Insn::Integer { @@ -668,3 +709,29 @@ pub fn translate_aggregation_step_groupby( }; Ok(dest) } + +pub fn is_column_in_group_by(expr: &ast::Expr, group_by_exprs: &[ast::Expr]) -> bool { + if let ast::Expr::Column { + database: _, + table: _, + column: col, + is_rowid_alias: _, + } = expr + { + group_by_exprs.iter().any(|ex| { + if let ast::Expr::Column { + database: _, + table: _, + column: group_col, + is_rowid_alias: _, + } = ex + { + col == group_col + } else { + false + } + }) + } else { + false + } +} diff --git a/core/translate/main_loop.rs b/core/translate/main_loop.rs index
ea7b2f4ef..9c58e193d 100644 --- a/core/translate/main_loop.rs +++ b/core/translate/main_loop.rs @@ -15,6 +15,7 @@ use super::{ aggregation::translate_aggregation_step, emitter::{OperationMode, TranslateCtx}, expr::{translate_condition_expr, translate_expr, ConditionMetadata}, + group_by::is_column_in_group_by, order_by::{order_by_sorter_insert, sorter_insert}, plan::{ IterationDirection, Operation, Search, SelectPlan, SelectQueryType, TableReference, @@ -599,7 +600,12 @@ fn emit_loop_source( LoopEmitTarget::GroupBySorter => { let group_by = plan.group_by.as_ref().unwrap(); let aggregates = &plan.aggregates; - let sort_keys_count = group_by.exprs.len(); + let non_aggregate_columns = plan + .result_columns + .iter() + .filter(|rc| !rc.contains_aggregates) + .collect::<Vec<_>>(); + let sort_keys_count = non_aggregate_columns.len(); let aggregate_arguments_count = plan .aggregates .iter() @@ -621,6 +627,25 @@ fn emit_loop_source( &t_ctx.resolver, )?; } + + if group_by.exprs.len() + aggregates.len() != plan.result_columns.len() { + for rc in non_aggregate_columns.iter() { + let expr = &rc.expr; + if !is_column_in_group_by(expr, &group_by.exprs) { + let key_reg = cur_reg; + cur_reg += 1; + translate_expr( + program, + Some(&plan.table_references), + expr, + key_reg, + &t_ctx.resolver, + )?; + } + } + } + // Process non-aggregate result columns that aren't already in group_by + // Then we have the aggregate arguments. for agg in aggregates.iter() { // Here we are collecting scalars for the group by sorter, which will include @@ -692,14 +717,14 @@ fn emit_loop_source( let col_start = t_ctx.reg_result_cols_start.unwrap(); - for (i, rc) in plan.result_columns.iter().enumerate() { - if rc.contains_aggregates { - // Do nothing, aggregates are computed above - // if this result column is e.g. something like sum(x) + 1 or length(sum(x)), we do not want to translate that (+1) or length() yet, - // it will be computed after the aggregations are finalized.
- continue; - } + // Process only non-aggregate columns + let non_agg_columns = plan + .result_columns + .iter() + .enumerate() + .filter(|(_, rc)| !rc.contains_aggregates); + for (i, rc) in non_agg_columns { let reg = col_start + i; translate_expr(