mirror of
https://github.com/aljazceru/turso.git
synced 2026-01-07 02:04:21 +01:00
non group by cols are displayed in group by agg statements
This commit is contained in:
@@ -262,8 +262,8 @@ pub fn emit_query<'a>(
|
||||
init_order_by(program, t_ctx, order_by)?;
|
||||
}
|
||||
|
||||
if let Some(ref mut group_by) = plan.group_by {
|
||||
init_group_by(program, t_ctx, group_by, &plan.aggregates)?;
|
||||
if let Some(ref group_by) = plan.group_by {
|
||||
init_group_by(program, t_ctx, group_by, &plan)?;
|
||||
}
|
||||
init_loop(
|
||||
program,
|
||||
|
||||
@@ -50,15 +50,22 @@ pub fn init_group_by(
|
||||
program: &mut ProgramBuilder,
|
||||
t_ctx: &mut TranslateCtx,
|
||||
group_by: &GroupBy,
|
||||
aggregates: &[Aggregate],
|
||||
plan: &SelectPlan,
|
||||
) -> Result<()> {
|
||||
let num_aggs = aggregates.len();
|
||||
let num_aggs = plan.aggregates.len();
|
||||
|
||||
// Calculate this count only once
|
||||
let non_aggregate_count = plan
|
||||
.result_columns
|
||||
.iter()
|
||||
.filter(|rc| !rc.contains_aggregates)
|
||||
.count();
|
||||
|
||||
let sort_cursor = program.alloc_cursor_id(None, CursorType::Sorter);
|
||||
|
||||
let reg_abort_flag = program.alloc_register();
|
||||
let reg_group_exprs_cmp = program.alloc_registers(group_by.exprs.len());
|
||||
let reg_group_exprs_acc = program.alloc_registers(group_by.exprs.len());
|
||||
let reg_group_exprs_acc = program.alloc_registers(non_aggregate_count);
|
||||
let reg_agg_exprs_start = program.alloc_registers(num_aggs);
|
||||
let reg_sorter_key = program.alloc_register();
|
||||
|
||||
@@ -71,7 +78,7 @@ pub fn init_group_by(
|
||||
}
|
||||
program.emit_insn(Insn::SorterOpen {
|
||||
cursor_id: sort_cursor,
|
||||
columns: aggregates.len() + group_by.exprs.len(),
|
||||
columns: non_aggregate_count + plan.aggregates.len(),
|
||||
order: Record::new(order),
|
||||
});
|
||||
|
||||
@@ -156,14 +163,23 @@ pub fn emit_group_by<'a>(
|
||||
|
||||
let group_by = plan.group_by.as_ref().unwrap();
|
||||
|
||||
// Calculate these values once
|
||||
let non_aggregate_count = plan
|
||||
.result_columns
|
||||
.iter()
|
||||
.filter(|rc| !rc.contains_aggregates)
|
||||
.count();
|
||||
|
||||
let agg_args_count = plan
|
||||
.aggregates
|
||||
.iter()
|
||||
.map(|agg| agg.args.len())
|
||||
.sum::<usize>();
|
||||
|
||||
// all group by columns and all arguments of agg functions are in the sorter.
|
||||
// the sort keys are the group by columns (the aggregation within groups is done based on how long the sort keys remain the same)
|
||||
let sorter_column_count = group_by.exprs.len()
|
||||
+ plan
|
||||
.aggregates
|
||||
.iter()
|
||||
.map(|agg| agg.args.len())
|
||||
.sum::<usize>();
|
||||
let sorter_column_count = non_aggregate_count + agg_args_count;
|
||||
|
||||
// sorter column names do not matter
|
||||
let ty = crate::schema::Type::Null;
|
||||
let pseudo_columns = (0..sorter_column_count)
|
||||
@@ -238,11 +254,6 @@ pub fn emit_group_by<'a>(
|
||||
});
|
||||
|
||||
// New group, move current group by columns into the comparison register
|
||||
program.emit_insn(Insn::Move {
|
||||
source_reg: groups_start_reg,
|
||||
dest_reg: reg_group_exprs_cmp,
|
||||
count: group_by.exprs.len(),
|
||||
});
|
||||
|
||||
program.add_comment(
|
||||
program.offset(),
|
||||
@@ -253,6 +264,12 @@ pub fn emit_group_by<'a>(
|
||||
return_reg: reg_subrtn_acc_output_return_offset,
|
||||
});
|
||||
|
||||
program.emit_insn(Insn::Move {
|
||||
source_reg: groups_start_reg,
|
||||
dest_reg: reg_group_exprs_cmp,
|
||||
count: group_by.exprs.len(),
|
||||
});
|
||||
|
||||
program.add_comment(program.offset(), "check abort flag");
|
||||
program.emit_insn(Insn::IfPos {
|
||||
reg: reg_abort_flag,
|
||||
@@ -269,7 +286,7 @@ pub fn emit_group_by<'a>(
|
||||
// Accumulate the values into the aggregations
|
||||
program.resolve_label(agg_step_label, program.offset());
|
||||
let start_reg = t_ctx.reg_agg_start.unwrap();
|
||||
let mut cursor_index = group_by.exprs.len();
|
||||
let mut cursor_index = non_aggregate_count;
|
||||
for (i, agg) in plan.aggregates.iter().enumerate() {
|
||||
let agg_result_reg = start_reg + i;
|
||||
translate_aggregation_step_groupby(
|
||||
@@ -296,7 +313,7 @@ pub fn emit_group_by<'a>(
|
||||
});
|
||||
|
||||
// Read the group by columns for a finished group
|
||||
for i in 0..group_by.exprs.len() {
|
||||
for i in 0..non_aggregate_count {
|
||||
let key_reg = reg_group_exprs_acc + i;
|
||||
let sorter_column_index = i;
|
||||
program.emit_insn(Insn::Column {
|
||||
@@ -363,6 +380,12 @@ pub fn emit_group_by<'a>(
|
||||
});
|
||||
}
|
||||
|
||||
// Cache expressions we need multiple times
|
||||
let filtered_results = plan
|
||||
.result_columns
|
||||
.iter()
|
||||
.filter(|rc| !rc.contains_aggregates)
|
||||
.collect::<Vec<_>>();
|
||||
// we now have the group by columns in registers (group_exprs_start_register..group_exprs_start_register + group_by.len() - 1)
|
||||
// and the agg results in (agg_start_reg..agg_start_reg + aggregates.len() - 1)
|
||||
// we need to call translate_expr on each result column, but replace the expr with a register copy in case any part of the
|
||||
@@ -373,6 +396,24 @@ pub fn emit_group_by<'a>(
|
||||
.expr_to_reg_cache
|
||||
.push((expr, reg_group_exprs_acc + i));
|
||||
}
|
||||
|
||||
// Offset for the next expressions after group_by
|
||||
let mut offset = group_by.exprs.len();
|
||||
|
||||
for rc in filtered_results.iter() {
|
||||
let expr = &rc.expr;
|
||||
|
||||
// skip cols that are already in group by
|
||||
if !matches!(expr, ast::Expr::Column { .. })
|
||||
|| !is_column_in_group_by(expr, &group_by.exprs)
|
||||
{
|
||||
t_ctx
|
||||
.resolver
|
||||
.expr_to_reg_cache
|
||||
.push((expr, reg_group_exprs_acc + offset));
|
||||
offset += 1;
|
||||
}
|
||||
}
|
||||
for (i, agg) in plan.aggregates.iter().enumerate() {
|
||||
t_ctx
|
||||
.resolver
|
||||
@@ -420,7 +461,7 @@ pub fn emit_group_by<'a>(
|
||||
let start_reg = reg_group_exprs_acc;
|
||||
program.emit_insn(Insn::Null {
|
||||
dest: start_reg,
|
||||
dest_end: Some(start_reg + group_by.exprs.len() + plan.aggregates.len() - 1),
|
||||
dest_end: Some(start_reg + non_aggregate_count + plan.aggregates.len() - 1),
|
||||
});
|
||||
|
||||
program.emit_insn(Insn::Integer {
|
||||
@@ -668,3 +709,29 @@ pub fn translate_aggregation_step_groupby(
|
||||
};
|
||||
Ok(dest)
|
||||
}
|
||||
|
||||
pub fn is_column_in_group_by(expr: &ast::Expr, group_by_exprs: &[ast::Expr]) -> bool {
|
||||
if let ast::Expr::Column {
|
||||
database: _,
|
||||
table: _,
|
||||
column: col,
|
||||
is_rowid_alias: _,
|
||||
} = expr
|
||||
{
|
||||
group_by_exprs.iter().any(|ex| {
|
||||
if let ast::Expr::Column {
|
||||
database: _,
|
||||
table: _,
|
||||
column: group_col,
|
||||
is_rowid_alias: _,
|
||||
} = ex
|
||||
{
|
||||
col == group_col
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ use super::{
|
||||
aggregation::translate_aggregation_step,
|
||||
emitter::{OperationMode, TranslateCtx},
|
||||
expr::{translate_condition_expr, translate_expr, ConditionMetadata},
|
||||
group_by::is_column_in_group_by,
|
||||
order_by::{order_by_sorter_insert, sorter_insert},
|
||||
plan::{
|
||||
IterationDirection, Operation, Search, SelectPlan, SelectQueryType, TableReference,
|
||||
@@ -599,7 +600,12 @@ fn emit_loop_source(
|
||||
LoopEmitTarget::GroupBySorter => {
|
||||
let group_by = plan.group_by.as_ref().unwrap();
|
||||
let aggregates = &plan.aggregates;
|
||||
let sort_keys_count = group_by.exprs.len();
|
||||
let non_aggregate_columns = plan
|
||||
.result_columns
|
||||
.iter()
|
||||
.filter(|rc| !rc.contains_aggregates)
|
||||
.collect::<Vec<_>>();
|
||||
let sort_keys_count = non_aggregate_columns.len();
|
||||
let aggregate_arguments_count = plan
|
||||
.aggregates
|
||||
.iter()
|
||||
@@ -621,6 +627,25 @@ fn emit_loop_source(
|
||||
&t_ctx.resolver,
|
||||
)?;
|
||||
}
|
||||
|
||||
if group_by.exprs.len() + aggregates.len() != plan.result_columns.len() {
|
||||
for rc in non_aggregate_columns.iter() {
|
||||
let expr = &rc.expr;
|
||||
if !is_column_in_group_by(expr, &group_by.exprs) {
|
||||
let key_reg = cur_reg;
|
||||
cur_reg += 1;
|
||||
translate_expr(
|
||||
program,
|
||||
Some(&plan.table_references),
|
||||
expr,
|
||||
key_reg,
|
||||
&t_ctx.resolver,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Process non-aggregate result columns that aren't already in group_by
|
||||
|
||||
// Then we have the aggregate arguments.
|
||||
for agg in aggregates.iter() {
|
||||
// Here we are collecting scalars for the group by sorter, which will include
|
||||
@@ -692,14 +717,14 @@ fn emit_loop_source(
|
||||
|
||||
let col_start = t_ctx.reg_result_cols_start.unwrap();
|
||||
|
||||
for (i, rc) in plan.result_columns.iter().enumerate() {
|
||||
if rc.contains_aggregates {
|
||||
// Do nothing, aggregates are computed above
|
||||
// if this result column is e.g. something like sum(x) + 1 or length(sum(x)), we do not want to translate that (+1) or length() yet,
|
||||
// it will be computed after the aggregations are finalized.
|
||||
continue;
|
||||
}
|
||||
// Process only non-aggregate columns
|
||||
let non_agg_columns = plan
|
||||
.result_columns
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, rc)| !rc.contains_aggregates);
|
||||
|
||||
for (i, rc) in non_agg_columns {
|
||||
let reg = col_start + i;
|
||||
|
||||
translate_expr(
|
||||
|
||||
Reference in New Issue
Block a user