non group by cols are displayed in group by agg statements

This commit is contained in:
Ihor Andrianov
2025-03-30 18:39:19 +03:00
parent 4fd1dcdc73
commit 2bcdd4e404
3 changed files with 120 additions and 28 deletions

View File

@@ -262,8 +262,8 @@ pub fn emit_query<'a>(
init_order_by(program, t_ctx, order_by)?;
}
if let Some(ref mut group_by) = plan.group_by {
init_group_by(program, t_ctx, group_by, &plan.aggregates)?;
if let Some(ref group_by) = plan.group_by {
init_group_by(program, t_ctx, group_by, &plan)?;
}
init_loop(
program,

View File

@@ -50,15 +50,22 @@ pub fn init_group_by(
program: &mut ProgramBuilder,
t_ctx: &mut TranslateCtx,
group_by: &GroupBy,
aggregates: &[Aggregate],
plan: &SelectPlan,
) -> Result<()> {
let num_aggs = aggregates.len();
let num_aggs = plan.aggregates.len();
// Calculate this count only once
let non_aggregate_count = plan
.result_columns
.iter()
.filter(|rc| !rc.contains_aggregates)
.count();
let sort_cursor = program.alloc_cursor_id(None, CursorType::Sorter);
let reg_abort_flag = program.alloc_register();
let reg_group_exprs_cmp = program.alloc_registers(group_by.exprs.len());
let reg_group_exprs_acc = program.alloc_registers(group_by.exprs.len());
let reg_group_exprs_acc = program.alloc_registers(non_aggregate_count);
let reg_agg_exprs_start = program.alloc_registers(num_aggs);
let reg_sorter_key = program.alloc_register();
@@ -71,7 +78,7 @@ pub fn init_group_by(
}
program.emit_insn(Insn::SorterOpen {
cursor_id: sort_cursor,
columns: aggregates.len() + group_by.exprs.len(),
columns: non_aggregate_count + plan.aggregates.len(),
order: Record::new(order),
});
@@ -156,14 +163,23 @@ pub fn emit_group_by<'a>(
let group_by = plan.group_by.as_ref().unwrap();
// Calculate these values once
let non_aggregate_count = plan
.result_columns
.iter()
.filter(|rc| !rc.contains_aggregates)
.count();
let agg_args_count = plan
.aggregates
.iter()
.map(|agg| agg.args.len())
.sum::<usize>();
// all group by columns and all arguments of agg functions are in the sorter.
// the sort keys are the group by columns (the aggregation within groups is done based on how long the sort keys remain the same)
let sorter_column_count = group_by.exprs.len()
+ plan
.aggregates
.iter()
.map(|agg| agg.args.len())
.sum::<usize>();
let sorter_column_count = non_aggregate_count + agg_args_count;
// sorter column names do not matter
let ty = crate::schema::Type::Null;
let pseudo_columns = (0..sorter_column_count)
@@ -238,11 +254,6 @@ pub fn emit_group_by<'a>(
});
// New group, move current group by columns into the comparison register
program.emit_insn(Insn::Move {
source_reg: groups_start_reg,
dest_reg: reg_group_exprs_cmp,
count: group_by.exprs.len(),
});
program.add_comment(
program.offset(),
@@ -253,6 +264,12 @@ pub fn emit_group_by<'a>(
return_reg: reg_subrtn_acc_output_return_offset,
});
program.emit_insn(Insn::Move {
source_reg: groups_start_reg,
dest_reg: reg_group_exprs_cmp,
count: group_by.exprs.len(),
});
program.add_comment(program.offset(), "check abort flag");
program.emit_insn(Insn::IfPos {
reg: reg_abort_flag,
@@ -269,7 +286,7 @@ pub fn emit_group_by<'a>(
// Accumulate the values into the aggregations
program.resolve_label(agg_step_label, program.offset());
let start_reg = t_ctx.reg_agg_start.unwrap();
let mut cursor_index = group_by.exprs.len();
let mut cursor_index = non_aggregate_count;
for (i, agg) in plan.aggregates.iter().enumerate() {
let agg_result_reg = start_reg + i;
translate_aggregation_step_groupby(
@@ -296,7 +313,7 @@ pub fn emit_group_by<'a>(
});
// Read the group by columns for a finished group
for i in 0..group_by.exprs.len() {
for i in 0..non_aggregate_count {
let key_reg = reg_group_exprs_acc + i;
let sorter_column_index = i;
program.emit_insn(Insn::Column {
@@ -363,6 +380,12 @@ pub fn emit_group_by<'a>(
});
}
// Cache expressions we need multiple times
let filtered_results = plan
.result_columns
.iter()
.filter(|rc| !rc.contains_aggregates)
.collect::<Vec<_>>();
// we now have the group by columns in registers (group_exprs_start_register..group_exprs_start_register + group_by.len() - 1)
// and the agg results in (agg_start_reg..agg_start_reg + aggregates.len() - 1)
// we need to call translate_expr on each result column, but replace the expr with a register copy in case any part of the
@@ -373,6 +396,24 @@ pub fn emit_group_by<'a>(
.expr_to_reg_cache
.push((expr, reg_group_exprs_acc + i));
}
// Offset for the next expressions after group_by
let mut offset = group_by.exprs.len();
for rc in filtered_results.iter() {
let expr = &rc.expr;
// skip cols that are already in group by
if !matches!(expr, ast::Expr::Column { .. })
|| !is_column_in_group_by(expr, &group_by.exprs)
{
t_ctx
.resolver
.expr_to_reg_cache
.push((expr, reg_group_exprs_acc + offset));
offset += 1;
}
}
for (i, agg) in plan.aggregates.iter().enumerate() {
t_ctx
.resolver
@@ -420,7 +461,7 @@ pub fn emit_group_by<'a>(
let start_reg = reg_group_exprs_acc;
program.emit_insn(Insn::Null {
dest: start_reg,
dest_end: Some(start_reg + group_by.exprs.len() + plan.aggregates.len() - 1),
dest_end: Some(start_reg + non_aggregate_count + plan.aggregates.len() - 1),
});
program.emit_insn(Insn::Integer {
@@ -668,3 +709,29 @@ pub fn translate_aggregation_step_groupby(
};
Ok(dest)
}
/// Reports whether `expr` is a column reference that also appears among the
/// `GROUP BY` expressions.
///
/// Returns `false` when `expr` is not an `ast::Expr::Column`. `GROUP BY`
/// expressions that are not plain column references are ignored.
///
/// Two column references match only when both their `table` and `column`
/// fields are equal. Comparing `column` alone would report a false match for
/// a column of a *different* table that happens to carry the same `column`
/// value (e.g. in a join), which would make callers skip a column that is not
/// actually grouped on.
///
/// NOTE(review): `database` and `is_rowid_alias` are deliberately not
/// compared; this assumes `table` alone is enough to tell tables apart within
/// one query — confirm against the `ast::Expr::Column` definition.
pub fn is_column_in_group_by(expr: &ast::Expr, group_by_exprs: &[ast::Expr]) -> bool {
    // Anything other than a plain column reference can never match.
    let (table, column) = match expr {
        ast::Expr::Column { table, column, .. } => (table, column),
        _ => return false,
    };

    group_by_exprs.iter().any(|group_expr| {
        matches!(
            group_expr,
            ast::Expr::Column {
                table: group_table,
                column: group_column,
                ..
            } if group_table == table && group_column == column
        )
    })
}

View File

@@ -15,6 +15,7 @@ use super::{
aggregation::translate_aggregation_step,
emitter::{OperationMode, TranslateCtx},
expr::{translate_condition_expr, translate_expr, ConditionMetadata},
group_by::is_column_in_group_by,
order_by::{order_by_sorter_insert, sorter_insert},
plan::{
IterationDirection, Operation, Search, SelectPlan, SelectQueryType, TableReference,
@@ -599,7 +600,12 @@ fn emit_loop_source(
LoopEmitTarget::GroupBySorter => {
let group_by = plan.group_by.as_ref().unwrap();
let aggregates = &plan.aggregates;
let sort_keys_count = group_by.exprs.len();
let non_aggregate_columns = plan
.result_columns
.iter()
.filter(|rc| !rc.contains_aggregates)
.collect::<Vec<_>>();
let sort_keys_count = non_aggregate_columns.len();
let aggregate_arguments_count = plan
.aggregates
.iter()
@@ -621,6 +627,25 @@ fn emit_loop_source(
&t_ctx.resolver,
)?;
}
if group_by.exprs.len() + aggregates.len() != plan.result_columns.len() {
for rc in non_aggregate_columns.iter() {
let expr = &rc.expr;
if !is_column_in_group_by(expr, &group_by.exprs) {
let key_reg = cur_reg;
cur_reg += 1;
translate_expr(
program,
Some(&plan.table_references),
expr,
key_reg,
&t_ctx.resolver,
)?;
}
}
}
// Process non-aggregate result columns that aren't already in group_by
// Then we have the aggregate arguments.
for agg in aggregates.iter() {
// Here we are collecting scalars for the group by sorter, which will include
@@ -692,14 +717,14 @@ fn emit_loop_source(
let col_start = t_ctx.reg_result_cols_start.unwrap();
for (i, rc) in plan.result_columns.iter().enumerate() {
if rc.contains_aggregates {
// Do nothing, aggregates are computed above
// if this result column is e.g. something like sum(x) + 1 or length(sum(x)), we do not want to translate that (+1) or length() yet,
// it will be computed after the aggregations are finalized.
continue;
}
// Process only non-aggregate columns
let non_agg_columns = plan
.result_columns
.iter()
.enumerate()
.filter(|(_, rc)| !rc.contains_aggregates);
for (i, rc) in non_agg_columns {
let reg = col_start + i;
translate_expr(