Consolidate methods emitting AggStep

This commit is contained in:
Piotr Rzysko
2025-08-31 11:07:54 +02:00
parent cdba1f1b87
commit 6f1cd17fcf
3 changed files with 89 additions and 339 deletions

View File

@@ -61,7 +61,7 @@ pub fn emit_ungrouped_aggregation<'a>(
Ok(())
}
pub fn emit_collseq_if_needed(
fn emit_collseq_if_needed(
program: &mut ProgramBuilder,
referenced_tables: &TableReferences,
expr: &ast::Expr,
@@ -134,6 +134,7 @@ pub fn handle_distinct(program: &mut ProgramBuilder, agg: &Aggregate, agg_arg_re
/// the data from the sorter.
/// * In grouped cases where no sorting is required, arguments are retrieved directly
/// from registers allocated in the main loop.
/// * In ungrouped cases, arguments are computed directly from the `args` expressions.
pub enum AggArgumentSource<'a> {
/// The aggregate function arguments are retrieved from a pseudo cursor
/// which reads from the GROUP BY sorter.
@@ -149,6 +150,8 @@ pub enum AggArgumentSource<'a> {
src_reg_start: usize,
aggregate: &'a Aggregate,
},
/// The aggregate function arguments are retrieved by evaluating expressions.
Expression { aggregate: &'a Aggregate },
}
impl<'a> AggArgumentSource<'a> {
@@ -176,10 +179,16 @@ impl<'a> AggArgumentSource<'a> {
}
}
/// Create a new [AggArgumentSource] that retrieves the values by evaluating `args` expressions.
pub fn new_from_expression(aggregate: &'a Aggregate) -> Self {
Self::Expression { aggregate }
}
pub fn aggregate(&self) -> &Aggregate {
match self {
AggArgumentSource::PseudoCursor { aggregate, .. } => aggregate,
AggArgumentSource::Register { aggregate, .. } => aggregate,
AggArgumentSource::Expression { aggregate } => aggregate,
}
}
@@ -187,22 +196,31 @@ impl<'a> AggArgumentSource<'a> {
match self {
AggArgumentSource::PseudoCursor { aggregate, .. } => &aggregate.func,
AggArgumentSource::Register { aggregate, .. } => &aggregate.func,
AggArgumentSource::Expression { aggregate } => &aggregate.func,
}
}
pub fn args(&self) -> &[ast::Expr] {
match self {
AggArgumentSource::PseudoCursor { aggregate, .. } => &aggregate.args,
AggArgumentSource::Register { aggregate, .. } => &aggregate.args,
AggArgumentSource::Expression { aggregate } => &aggregate.args,
}
}
pub fn num_args(&self) -> usize {
match self {
AggArgumentSource::PseudoCursor { aggregate, .. } => aggregate.args.len(),
AggArgumentSource::Register { aggregate, .. } => aggregate.args.len(),
AggArgumentSource::Expression { aggregate } => aggregate.args.len(),
}
}
/// Read the value of an aggregate function argument
pub fn translate(&self, program: &mut ProgramBuilder, arg_idx: usize) -> Result<usize> {
pub fn translate(
&self,
program: &mut ProgramBuilder,
referenced_tables: &TableReferences,
resolver: &Resolver,
arg_idx: usize,
) -> Result<usize> {
match self {
AggArgumentSource::PseudoCursor {
cursor_id,
@@ -221,31 +239,47 @@ impl<'a> AggArgumentSource<'a> {
src_reg_start: start_reg,
..
} => Ok(*start_reg + arg_idx),
AggArgumentSource::Expression { aggregate } => {
let dest_reg = program.alloc_register();
translate_expr(
program,
Some(referenced_tables),
&aggregate.args[arg_idx],
dest_reg,
resolver,
)
}
}
}
}
/// Emits the bytecode for processing an aggregate step.
/// E.g. in `SELECT SUM(price) FROM t`, 'price' is evaluated for every row, and the result is added to the accumulator.
///
/// This is distinct from the final step, which is called after the main loop has finished processing
/// This is distinct from the final step, which is called after a single group has been entirely accumulated,
/// and the actual result value of the aggregation is materialized.
///
/// Ungrouped aggregation is a special case of grouped aggregation that involves a single group.
///
/// Examples:
/// * In `SELECT SUM(price) FROM t`, `price` is evaluated for each row and added to the accumulator.
/// * In `SELECT product_category, SUM(price) FROM t GROUP BY product_category`, `price` is evaluated for
/// each row in the group and added to that groups accumulator.
pub fn translate_aggregation_step(
program: &mut ProgramBuilder,
referenced_tables: &TableReferences,
agg: &Aggregate,
agg_arg_source: AggArgumentSource,
target_register: usize,
resolver: &Resolver,
) -> Result<usize> {
let dest = match agg.func {
let num_args = agg_arg_source.num_args();
let func = agg_arg_source.agg_func();
let dest = match func {
AggFunc::Avg => {
if agg.args.len() != 1 {
if num_args != 1 {
crate::bail_parse_error!("avg bad number of arguments");
}
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
@@ -255,18 +289,16 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::Count | AggFunc::Count0 => {
if agg.args.len() != 1 {
if num_args != 1 {
crate::bail_parse_error!("count bad number of arguments");
}
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: if matches!(agg.func, AggFunc::Count0) {
func: if matches!(func, AggFunc::Count0) {
AggFunc::Count0
} else {
AggFunc::Count
@@ -275,18 +307,16 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::GroupConcat => {
if agg.args.len() != 1 && agg.args.len() != 2 {
if num_args != 1 && num_args != 2 {
crate::bail_parse_error!("group_concat bad number of arguments");
}
let expr_reg = program.alloc_register();
let delimiter_reg = program.alloc_register();
let expr = &agg.args[0];
let delimiter_expr: ast::Expr;
if agg.args.len() == 2 {
match &agg.args[1] {
if num_args == 2 {
match &agg_arg_source.args()[1] {
arg @ ast::Expr::Column { .. } => {
delimiter_expr = arg.clone();
}
@@ -299,8 +329,8 @@ pub fn translate_aggregation_step(
delimiter_expr = ast::Expr::Literal(ast::Literal::String(String::from("\",\"")));
}
translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
translate_expr(
program,
Some(referenced_tables),
@@ -319,13 +349,12 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::Max => {
if agg.args.len() != 1 {
if num_args != 1 {
crate::bail_parse_error!("max bad number of arguments");
}
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
let expr = &agg_arg_source.args()[0];
emit_collseq_if_needed(program, referenced_tables, expr);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
@@ -336,13 +365,12 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::Min => {
if agg.args.len() != 1 {
if num_args != 1 {
crate::bail_parse_error!("min bad number of arguments");
}
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
let expr = &agg_arg_source.args()[0];
emit_collseq_if_needed(program, referenced_tables, expr);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
@@ -354,23 +382,12 @@ pub fn translate_aggregation_step(
}
#[cfg(feature = "json")]
AggFunc::JsonGroupObject | AggFunc::JsonbGroupObject => {
if agg.args.len() != 2 {
if num_args != 2 {
crate::bail_parse_error!("max bad number of arguments");
}
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let value_expr = &agg.args[1];
let value_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let _ = translate_expr(
program,
Some(referenced_tables),
value_expr,
value_reg,
resolver,
)?;
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
let value_reg = agg_arg_source.translate(program, referenced_tables, resolver, 1)?;
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
@@ -382,13 +399,11 @@ pub fn translate_aggregation_step(
}
#[cfg(feature = "json")]
AggFunc::JsonGroupArray | AggFunc::JsonbGroupArray => {
if agg.args.len() != 1 {
if num_args != 1 {
crate::bail_parse_error!("max bad number of arguments");
}
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
@@ -398,15 +413,13 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::StringAgg => {
if agg.args.len() != 2 {
if num_args != 2 {
crate::bail_parse_error!("string_agg bad number of arguments");
}
let expr_reg = program.alloc_register();
let delimiter_reg = program.alloc_register();
let expr = &agg.args[0];
let delimiter_expr = match &agg.args[1] {
let delimiter_expr = match &agg_arg_source.args()[1] {
arg @ ast::Expr::Column { .. } => arg.clone(),
ast::Expr::Literal(ast::Literal::String(s)) => {
ast::Expr::Literal(ast::Literal::String(s.to_string()))
@@ -414,7 +427,7 @@ pub fn translate_aggregation_step(
_ => crate::bail_parse_error!("Incorrect delimiter parameter"),
};
translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
translate_expr(
program,
Some(referenced_tables),
@@ -433,13 +446,11 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::Sum => {
if agg.args.len() != 1 {
if num_args != 1 {
crate::bail_parse_error!("sum bad number of arguments");
}
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
@@ -449,13 +460,11 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::Total => {
if agg.args.len() != 1 {
if num_args != 1 {
crate::bail_parse_error!("total bad number of arguments");
}
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
@@ -465,31 +474,24 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::External(ref func) => {
let expr_reg = program.alloc_register();
let argc = func.agg_args().map_err(|_| {
LimboError::ExtensionError(
"External aggregate function called with wrong number of arguments".to_string(),
)
})?;
if argc != agg.args.len() {
if argc != num_args {
crate::bail_parse_error!(
"External aggregate function called with wrong number of arguments"
);
}
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
for i in 0..argc {
if i != 0 {
let _ = program.alloc_register();
let _ = agg_arg_source.translate(program, referenced_tables, resolver, i)?;
}
let _ = translate_expr(
program,
Some(referenced_tables),
&agg.args[i],
expr_reg + i,
resolver,
)?;
// invariant: distinct aggregates are only supported for single-argument functions
if argc == 1 {
handle_distinct(program, agg, expr_reg + i);
handle_distinct(program, agg_arg_source.aggregate(), expr_reg + i);
}
}
program.emit_insn(Insn::AggStep {

View File

@@ -1,18 +1,16 @@
use turso_parser::ast;
use super::{
aggregation::handle_distinct,
emitter::{Resolver, TranslateCtx},
emitter::TranslateCtx,
expr::{translate_condition_expr, translate_expr, ConditionMetadata},
order_by::order_by_sorter_insert,
plan::{Distinctness, GroupBy, SelectPlan, TableReferences},
plan::{Distinctness, GroupBy, SelectPlan},
result_row::emit_select_result,
};
use crate::translate::aggregation::{emit_collseq_if_needed, AggArgumentSource};
use crate::translate::aggregation::{translate_aggregation_step, AggArgumentSource};
use crate::translate::expr::{walk_expr, WalkControl};
use crate::translate::plan::ResultSetColumn;
use crate::{
function::AggFunc,
schema::PseudoCursorType,
translate::collate::CollationSeq,
util::exprs_are_equivalent,
@@ -21,7 +19,7 @@ use crate::{
insn::Insn,
BranchOffset,
},
LimboError, Result,
Result,
};
/// Labels needed for various jumps in GROUP BY handling.
@@ -509,7 +507,7 @@ pub fn group_by_process_single_group(
AggArgumentSource::new_from_registers(start_reg_aggs + offset, agg)
}
};
translate_aggregation_step_groupby(
translate_aggregation_step(
program,
&plan.table_references,
agg_arg_source,
@@ -799,253 +797,3 @@ pub fn group_by_emit_row_phase<'a>(
program.preassign_label_to_next_insn(labels.label_group_by_end);
Ok(())
}
/// Emits the bytecode for processing an aggregate step within a GROUP BY clause.
/// Eg. in `SELECT product_category, SUM(price) FROM t GROUP BY line_item`, 'price' is evaluated for every row
/// where the 'product_category' is the same, and the result is added to the accumulator for that category.
///
/// This is distinct from the final step, which is called after a single group has been entirely accumulated,
/// and the actual result value of the aggregation is materialized.
pub fn translate_aggregation_step_groupby(
program: &mut ProgramBuilder,
referenced_tables: &TableReferences,
agg_arg_source: AggArgumentSource,
target_register: usize,
resolver: &Resolver,
) -> Result<usize> {
let num_args = agg_arg_source.num_args();
let dest = match agg_arg_source.agg_func() {
AggFunc::Avg => {
if num_args != 1 {
crate::bail_parse_error!("avg bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::Avg,
});
target_register
}
AggFunc::Count | AggFunc::Count0 => {
if num_args != 1 {
crate::bail_parse_error!("count bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: if matches!(agg_arg_source.agg_func(), AggFunc::Count0) {
AggFunc::Count0
} else {
AggFunc::Count
},
});
target_register
}
AggFunc::GroupConcat => {
let num_args = agg_arg_source.num_args();
if num_args != 1 && num_args != 2 {
crate::bail_parse_error!("group_concat bad number of arguments");
}
let delimiter_reg = program.alloc_register();
let delimiter_expr: ast::Expr;
if num_args == 2 {
match &agg_arg_source.args()[1] {
arg @ ast::Expr::Column { .. } => {
delimiter_expr = arg.clone();
}
ast::Expr::Literal(ast::Literal::String(s)) => {
delimiter_expr = ast::Expr::Literal(ast::Literal::String(s.to_string()));
}
_ => crate::bail_parse_error!("Incorrect delimiter parameter"),
};
} else {
delimiter_expr = ast::Expr::Literal(ast::Literal::String(String::from("\",\"")));
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
translate_expr(
program,
Some(referenced_tables),
&delimiter_expr,
delimiter_reg,
resolver,
)?;
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: delimiter_reg,
func: AggFunc::GroupConcat,
});
target_register
}
AggFunc::Max => {
if num_args != 1 {
crate::bail_parse_error!("max bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
let expr = &agg_arg_source.args()[0];
emit_collseq_if_needed(program, referenced_tables, expr);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::Max,
});
target_register
}
AggFunc::Min => {
if num_args != 1 {
crate::bail_parse_error!("min bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
let expr = &agg_arg_source.args()[0];
emit_collseq_if_needed(program, referenced_tables, expr);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::Min,
});
target_register
}
#[cfg(feature = "json")]
AggFunc::JsonGroupArray | AggFunc::JsonbGroupArray => {
if num_args != 1 {
crate::bail_parse_error!("min bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::JsonGroupArray,
});
target_register
}
#[cfg(feature = "json")]
AggFunc::JsonGroupObject | AggFunc::JsonbGroupObject => {
if num_args != 2 {
crate::bail_parse_error!("max bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
let value_reg = agg_arg_source.translate(program, 1)?;
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: value_reg,
func: AggFunc::JsonGroupObject,
});
target_register
}
AggFunc::StringAgg => {
if num_args != 2 {
crate::bail_parse_error!("string_agg bad number of arguments");
}
let delimiter_reg = program.alloc_register();
let delimiter_expr = match &agg_arg_source.args()[1] {
arg @ ast::Expr::Column { .. } => arg.clone(),
ast::Expr::Literal(ast::Literal::String(s)) => {
ast::Expr::Literal(ast::Literal::String(s.to_string()))
}
_ => crate::bail_parse_error!("Incorrect delimiter parameter"),
};
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
translate_expr(
program,
Some(referenced_tables),
&delimiter_expr,
delimiter_reg,
resolver,
)?;
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: delimiter_reg,
func: AggFunc::StringAgg,
});
target_register
}
AggFunc::Sum => {
if num_args != 1 {
crate::bail_parse_error!("sum bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::Sum,
});
target_register
}
AggFunc::Total => {
if num_args != 1 {
crate::bail_parse_error!("total bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::Total,
});
target_register
}
AggFunc::External(ref func) => {
let argc = func.agg_args().map_err(|_| {
LimboError::ExtensionError(
"External aggregate function called with wrong number of arguments".to_string(),
)
})?;
if argc != num_args {
crate::bail_parse_error!(
"External aggregate function called with wrong number of arguments"
);
}
let expr_reg = agg_arg_source.translate(program, 0)?;
for i in 0..argc {
if i != 0 {
let _ = agg_arg_source.translate(program, i)?;
}
// invariant: distinct aggregates are only supported for single-argument functions
if argc == 1 {
handle_distinct(program, agg_arg_source.aggregate(), expr_reg + i);
}
}
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::External(func.clone()),
});
target_register
}
};
Ok(dest)
}

View File

@@ -19,7 +19,7 @@ use crate::{
};
use super::{
aggregation::translate_aggregation_step,
aggregation::{translate_aggregation_step, AggArgumentSource},
emitter::{OperationMode, TranslateCtx},
expr::{
translate_condition_expr, translate_expr, translate_expr_no_constant_opt,
@@ -868,7 +868,7 @@ fn emit_loop_source(
translate_aggregation_step(
program,
&plan.table_references,
agg,
AggArgumentSource::new_from_expression(agg),
reg,
&t_ctx.resolver,
)?;