From 6f1cd17fcf7c99d2e1d2424f448bc63eab1cf750 Mon Sep 17 00:00:00 2001 From: Piotr Rzysko Date: Sun, 31 Aug 2025 11:07:54 +0200 Subject: [PATCH] Consolidate methods emitting AggStep --- core/translate/aggregation.rs | 162 ++++++++++----------- core/translate/group_by.rs | 262 +--------------------------------- core/translate/main_loop.rs | 4 +- 3 files changed, 89 insertions(+), 339 deletions(-) diff --git a/core/translate/aggregation.rs b/core/translate/aggregation.rs index 2e43c3ea9..7a1de9776 100644 --- a/core/translate/aggregation.rs +++ b/core/translate/aggregation.rs @@ -61,7 +61,7 @@ pub fn emit_ungrouped_aggregation<'a>( Ok(()) } -pub fn emit_collseq_if_needed( +fn emit_collseq_if_needed( program: &mut ProgramBuilder, referenced_tables: &TableReferences, expr: &ast::Expr, @@ -134,6 +134,7 @@ pub fn handle_distinct(program: &mut ProgramBuilder, agg: &Aggregate, agg_arg_re /// the data from the sorter. /// * In grouped cases where no sorting is required, arguments are retrieved directly /// from registers allocated in the main loop. +/// * In ungrouped cases, arguments are computed directly from the `args` expressions. pub enum AggArgumentSource<'a> { /// The aggregate function arguments are retrieved from a pseudo cursor /// which reads from the GROUP BY sorter. @@ -149,6 +150,8 @@ pub enum AggArgumentSource<'a> { src_reg_start: usize, aggregate: &'a Aggregate, }, + /// The aggregate function arguments are retrieved by evaluating expressions. + Expression { aggregate: &'a Aggregate }, } impl<'a> AggArgumentSource<'a> { @@ -176,10 +179,16 @@ impl<'a> AggArgumentSource<'a> { } } + /// Create a new [AggArgumentSource] that retrieves the values by evaluating `args` expressions. + pub fn new_from_expression(aggregate: &'a Aggregate) -> Self { + Self::Expression { aggregate } + } + pub fn aggregate(&self) -> &Aggregate { match self { AggArgumentSource::PseudoCursor { aggregate, .. } => aggregate, AggArgumentSource::Register { aggregate, .. } => aggregate, + AggArgumentSource::Expression { aggregate } => aggregate, } } @@ -187,22 +196,31 @@ impl<'a> AggArgumentSource<'a> { match self { AggArgumentSource::PseudoCursor { aggregate, .. } => &aggregate.func, AggArgumentSource::Register { aggregate, .. } => &aggregate.func, + AggArgumentSource::Expression { aggregate } => &aggregate.func, } } pub fn args(&self) -> &[ast::Expr] { match self { AggArgumentSource::PseudoCursor { aggregate, .. } => &aggregate.args, AggArgumentSource::Register { aggregate, .. } => &aggregate.args, + AggArgumentSource::Expression { aggregate } => &aggregate.args, } } pub fn num_args(&self) -> usize { match self { AggArgumentSource::PseudoCursor { aggregate, .. } => aggregate.args.len(), AggArgumentSource::Register { aggregate, .. } => aggregate.args.len(), + AggArgumentSource::Expression { aggregate } => aggregate.args.len(), } } /// Read the value of an aggregate function argument - pub fn translate(&self, program: &mut ProgramBuilder, arg_idx: usize) -> Result { + pub fn translate( + &self, + program: &mut ProgramBuilder, + referenced_tables: &TableReferences, + resolver: &Resolver, + arg_idx: usize, + ) -> Result { match self { AggArgumentSource::PseudoCursor { cursor_id, @@ -221,31 +239,47 @@ impl<'a> AggArgumentSource<'a> { src_reg_start: start_reg, .. } => Ok(*start_reg + arg_idx), + AggArgumentSource::Expression { aggregate } => { + let dest_reg = program.alloc_register(); + translate_expr( + program, + Some(referenced_tables), + &aggregate.args[arg_idx], + dest_reg, + resolver, + ) + } } } } /// Emits the bytecode for processing an aggregate step. -/// E.g. in `SELECT SUM(price) FROM t`, 'price' is evaluated for every row, and the result is added to the accumulator. /// -/// This is distinct from the final step, which is called after the main loop has finished processing +/// This is distinct from the final step, which is called after a single group has been entirely accumulated, /// and the actual result value of the aggregation is materialized. +/// +/// Ungrouped aggregation is a special case of grouped aggregation that involves a single group. +/// +/// Examples: +/// * In `SELECT SUM(price) FROM t`, `price` is evaluated for each row and added to the accumulator. +/// * In `SELECT product_category, SUM(price) FROM t GROUP BY product_category`, `price` is evaluated for +/// each row in the group and added to that group’s accumulator. pub fn translate_aggregation_step( program: &mut ProgramBuilder, referenced_tables: &TableReferences, - agg: &Aggregate, + agg_arg_source: AggArgumentSource, target_register: usize, resolver: &Resolver, ) -> Result { - let dest = match agg.func { + let num_args = agg_arg_source.num_args(); + let func = agg_arg_source.agg_func(); + let dest = match func { AggFunc::Avg => { - if agg.args.len() != 1 { + if num_args != 1 { crate::bail_parse_error!("avg bad number of arguments"); } - let expr = &agg.args[0]; - let expr_reg = program.alloc_register(); - let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?; - handle_distinct(program, agg, expr_reg); + let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, @@ -255,18 +289,16 @@ pub fn translate_aggregation_step( target_register } AggFunc::Count | AggFunc::Count0 => { - if agg.args.len() != 1 { + if num_args != 1 { crate::bail_parse_error!("count bad number of arguments"); } - let expr = &agg.args[0]; - let expr_reg = program.alloc_register(); - let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?; - handle_distinct(program, agg, expr_reg); + let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, delimiter: 0, - func: if matches!(agg.func, AggFunc::Count0) { + func: if matches!(func, AggFunc::Count0) { AggFunc::Count0 } else { AggFunc::Count @@ -275,18 +307,16 @@ pub fn translate_aggregation_step( target_register } AggFunc::GroupConcat => { - if agg.args.len() != 1 && agg.args.len() != 2 { + if num_args != 1 && num_args != 2 { crate::bail_parse_error!("group_concat bad number of arguments"); } - let expr_reg = program.alloc_register(); let delimiter_reg = program.alloc_register(); - let expr = &agg.args[0]; let delimiter_expr: ast::Expr; - if agg.args.len() == 2 { - match &agg.args[1] { + if num_args == 2 { + match &agg_arg_source.args()[1] { arg @ ast::Expr::Column { .. } => { delimiter_expr = arg.clone(); } @@ -299,8 +329,8 @@ pub fn translate_aggregation_step( delimiter_expr = ast::Expr::Literal(ast::Literal::String(String::from("\",\""))); } - translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?; - handle_distinct(program, agg, expr_reg); + let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); translate_expr( program, Some(referenced_tables), @@ -319,13 +349,12 @@ pub fn translate_aggregation_step( target_register } AggFunc::Max => { - if agg.args.len() != 1 { + if num_args != 1 { crate::bail_parse_error!("max bad number of arguments"); } - let expr = &agg.args[0]; - let expr_reg = program.alloc_register(); - let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?; - handle_distinct(program, agg, expr_reg); + let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); + let expr = &agg_arg_source.args()[0]; emit_collseq_if_needed(program, referenced_tables, expr); program.emit_insn(Insn::AggStep { acc_reg: target_register, @@ -336,13 +365,12 @@ pub fn translate_aggregation_step( target_register } AggFunc::Min => { - if agg.args.len() != 1 { + if num_args != 1 { crate::bail_parse_error!("min bad number of arguments"); } - let expr = &agg.args[0]; - let expr_reg = program.alloc_register(); - let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?; - handle_distinct(program, agg, expr_reg); + let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); + let expr = &agg_arg_source.args()[0]; emit_collseq_if_needed(program, referenced_tables, expr); program.emit_insn(Insn::AggStep { acc_reg: target_register, @@ -354,23 +382,12 @@ pub fn translate_aggregation_step( } #[cfg(feature = "json")] AggFunc::JsonGroupObject | AggFunc::JsonbGroupObject => { - if agg.args.len() != 2 { + if num_args != 2 { crate::bail_parse_error!("max bad number of arguments"); } - let expr = &agg.args[0]; - let expr_reg = program.alloc_register(); - let value_expr = &agg.args[1]; - let value_reg = program.alloc_register(); - - let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?; - handle_distinct(program, agg, expr_reg); - let _ = translate_expr( - program, - Some(referenced_tables), - value_expr, - value_reg, - resolver, - )?; + let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); + let value_reg = agg_arg_source.translate(program, referenced_tables, resolver, 1)?; program.emit_insn(Insn::AggStep { acc_reg: target_register, @@ -382,13 +399,11 @@ pub fn translate_aggregation_step( } #[cfg(feature = "json")] AggFunc::JsonGroupArray | AggFunc::JsonbGroupArray => { - if agg.args.len() != 1 { + if num_args != 1 { crate::bail_parse_error!("max bad number of arguments"); } - let expr = &agg.args[0]; - let expr_reg = program.alloc_register(); - let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?; - handle_distinct(program, agg, expr_reg); + let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, @@ -398,15 +413,13 @@ pub fn translate_aggregation_step( target_register } AggFunc::StringAgg => { - if agg.args.len() != 2 { + if num_args != 2 { crate::bail_parse_error!("string_agg bad number of arguments"); } - let expr_reg = program.alloc_register(); let delimiter_reg = program.alloc_register(); - let expr = &agg.args[0]; - let delimiter_expr = match &agg.args[1] { + let delimiter_expr = match &agg_arg_source.args()[1] { arg @ ast::Expr::Column { .. } => arg.clone(), ast::Expr::Literal(ast::Literal::String(s)) => { ast::Expr::Literal(ast::Literal::String(s.to_string())) @@ -414,7 +427,7 @@ pub fn translate_aggregation_step( _ => crate::bail_parse_error!("Incorrect delimiter parameter"), }; - translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?; + let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?; translate_expr( program, Some(referenced_tables), @@ -433,13 +446,11 @@ pub fn translate_aggregation_step( target_register } AggFunc::Sum => { - if agg.args.len() != 1 { + if num_args != 1 { crate::bail_parse_error!("sum bad number of arguments"); } - let expr = &agg.args[0]; - let expr_reg = program.alloc_register(); - let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?; - handle_distinct(program, agg, expr_reg); + let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, @@ -449,13 +460,11 @@ pub fn translate_aggregation_step( target_register } AggFunc::Total => { - if agg.args.len() != 1 { + if num_args != 1 { crate::bail_parse_error!("total bad number of arguments"); } - let expr = &agg.args[0]; - let expr_reg = program.alloc_register(); - let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?; - handle_distinct(program, agg, expr_reg); + let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, @@ -465,31 +474,24 @@ pub fn translate_aggregation_step( target_register } AggFunc::External(ref func) => { - let expr_reg = program.alloc_register(); let argc = func.agg_args().map_err(|_| { LimboError::ExtensionError( "External aggregate function called with wrong number of arguments".to_string(), ) })?; - if argc != agg.args.len() { + if argc != num_args { crate::bail_parse_error!( "External aggregate function called with wrong number of arguments" ); } + let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?; for i in 0..argc { if i != 0 { - let _ = program.alloc_register(); + let _ = agg_arg_source.translate(program, referenced_tables, resolver, i)?; } - let _ = translate_expr( - program, - Some(referenced_tables), - &agg.args[i], - expr_reg + i, - resolver, - )?; // invariant: distinct aggregates are only supported for single-argument functions if argc == 1 { - handle_distinct(program, agg, expr_reg + i); + handle_distinct(program, agg_arg_source.aggregate(), expr_reg + i); } } program.emit_insn(Insn::AggStep { diff --git a/core/translate/group_by.rs b/core/translate/group_by.rs index 715bb5c93..37c05cc45 100644 --- a/core/translate/group_by.rs +++ b/core/translate/group_by.rs @@ -1,18 +1,16 @@ use turso_parser::ast; use super::{ - aggregation::handle_distinct, - emitter::{Resolver, TranslateCtx}, + emitter::TranslateCtx, expr::{translate_condition_expr, translate_expr, ConditionMetadata}, order_by::order_by_sorter_insert, - plan::{Distinctness, GroupBy, SelectPlan, TableReferences}, + plan::{Distinctness, GroupBy, SelectPlan}, result_row::emit_select_result, }; -use crate::translate::aggregation::{emit_collseq_if_needed, AggArgumentSource}; +use crate::translate::aggregation::{translate_aggregation_step, AggArgumentSource}; use crate::translate::expr::{walk_expr, WalkControl}; use crate::translate::plan::ResultSetColumn; use crate::{ - function::AggFunc, schema::PseudoCursorType, translate::collate::CollationSeq, util::exprs_are_equivalent, @@ -21,7 +19,7 @@ use crate::{ insn::Insn, BranchOffset, }, - LimboError, Result, + Result, }; /// Labels needed for various jumps in GROUP BY handling. @@ -509,7 +507,7 @@ pub fn group_by_process_single_group( AggArgumentSource::new_from_registers(start_reg_aggs + offset, agg) } }; - translate_aggregation_step_groupby( + translate_aggregation_step( program, &plan.table_references, agg_arg_source, @@ -799,253 +797,3 @@ pub fn group_by_emit_row_phase<'a>( program.preassign_label_to_next_insn(labels.label_group_by_end); Ok(()) } - -/// Emits the bytecode for processing an aggregate step within a GROUP BY clause. -/// Eg. in `SELECT product_category, SUM(price) FROM t GROUP BY line_item`, 'price' is evaluated for every row -/// where the 'product_category' is the same, and the result is added to the accumulator for that category. -/// -/// This is distinct from the final step, which is called after a single group has been entirely accumulated, -/// and the actual result value of the aggregation is materialized. -pub fn translate_aggregation_step_groupby( - program: &mut ProgramBuilder, - referenced_tables: &TableReferences, - agg_arg_source: AggArgumentSource, - target_register: usize, - resolver: &Resolver, -) -> Result { - let num_args = agg_arg_source.num_args(); - let dest = match agg_arg_source.agg_func() { - AggFunc::Avg => { - if num_args != 1 { - crate::bail_parse_error!("avg bad number of arguments"); - } - let expr_reg = agg_arg_source.translate(program, 0)?; - handle_distinct(program, agg_arg_source.aggregate(), expr_reg); - program.emit_insn(Insn::AggStep { - acc_reg: target_register, - col: expr_reg, - delimiter: 0, - func: AggFunc::Avg, - }); - target_register - } - AggFunc::Count | AggFunc::Count0 => { - if num_args != 1 { - crate::bail_parse_error!("count bad number of arguments"); - } - let expr_reg = agg_arg_source.translate(program, 0)?; - handle_distinct(program, agg_arg_source.aggregate(), expr_reg); - program.emit_insn(Insn::AggStep { - acc_reg: target_register, - col: expr_reg, - delimiter: 0, - func: if matches!(agg_arg_source.agg_func(), AggFunc::Count0) { - AggFunc::Count0 - } else { - AggFunc::Count - }, - }); - target_register - } - AggFunc::GroupConcat => { - let num_args = agg_arg_source.num_args(); - if num_args != 1 && num_args != 2 { - crate::bail_parse_error!("group_concat bad number of arguments"); - } - - let delimiter_reg = program.alloc_register(); - - let delimiter_expr: ast::Expr; - - if num_args == 2 { - match &agg_arg_source.args()[1] { - arg @ ast::Expr::Column { .. } => { - delimiter_expr = arg.clone(); - } - ast::Expr::Literal(ast::Literal::String(s)) => { - delimiter_expr = ast::Expr::Literal(ast::Literal::String(s.to_string())); - } - _ => crate::bail_parse_error!("Incorrect delimiter parameter"), - }; - } else { - delimiter_expr = ast::Expr::Literal(ast::Literal::String(String::from("\",\""))); - } - - let expr_reg = agg_arg_source.translate(program, 0)?; - handle_distinct(program, agg_arg_source.aggregate(), expr_reg); - translate_expr( - program, - Some(referenced_tables), - &delimiter_expr, - delimiter_reg, - resolver, - )?; - - program.emit_insn(Insn::AggStep { - acc_reg: target_register, - col: expr_reg, - delimiter: delimiter_reg, - func: AggFunc::GroupConcat, - }); - - target_register - } - AggFunc::Max => { - if num_args != 1 { - crate::bail_parse_error!("max bad number of arguments"); - } - let expr_reg = agg_arg_source.translate(program, 0)?; - handle_distinct(program, agg_arg_source.aggregate(), expr_reg); - let expr = &agg_arg_source.args()[0]; - emit_collseq_if_needed(program, referenced_tables, expr); - program.emit_insn(Insn::AggStep { - acc_reg: target_register, - col: expr_reg, - delimiter: 0, - func: AggFunc::Max, - }); - target_register - } - AggFunc::Min => { - if num_args != 1 { - crate::bail_parse_error!("min bad number of arguments"); - } - let expr_reg = agg_arg_source.translate(program, 0)?; - handle_distinct(program, agg_arg_source.aggregate(), expr_reg); - let expr = &agg_arg_source.args()[0]; - emit_collseq_if_needed(program, referenced_tables, expr); - program.emit_insn(Insn::AggStep { - acc_reg: target_register, - col: expr_reg, - delimiter: 0, - func: AggFunc::Min, - }); - target_register - } - #[cfg(feature = "json")] - AggFunc::JsonGroupArray | AggFunc::JsonbGroupArray => { - if num_args != 1 { - crate::bail_parse_error!("min bad number of arguments"); - } - let expr_reg = agg_arg_source.translate(program, 0)?; - handle_distinct(program, agg_arg_source.aggregate(), expr_reg); - program.emit_insn(Insn::AggStep { - acc_reg: target_register, - col: expr_reg, - delimiter: 0, - func: AggFunc::JsonGroupArray, - }); - target_register - } - #[cfg(feature = "json")] - AggFunc::JsonGroupObject | AggFunc::JsonbGroupObject => { - if num_args != 2 { - crate::bail_parse_error!("max bad number of arguments"); - } - - let expr_reg = agg_arg_source.translate(program, 0)?; - handle_distinct(program, agg_arg_source.aggregate(), expr_reg); - let value_reg = agg_arg_source.translate(program, 1)?; - - program.emit_insn(Insn::AggStep { - acc_reg: target_register, - col: expr_reg, - delimiter: value_reg, - func: AggFunc::JsonGroupObject, - }); - target_register - } - AggFunc::StringAgg => { - if num_args != 2 { - crate::bail_parse_error!("string_agg bad number of arguments"); - } - - let delimiter_reg = program.alloc_register(); - - let delimiter_expr = match &agg_arg_source.args()[1] { - arg @ ast::Expr::Column { .. } => arg.clone(), - ast::Expr::Literal(ast::Literal::String(s)) => { - ast::Expr::Literal(ast::Literal::String(s.to_string())) - } - _ => crate::bail_parse_error!("Incorrect delimiter parameter"), - }; - - let expr_reg = agg_arg_source.translate(program, 0)?; - handle_distinct(program, agg_arg_source.aggregate(), expr_reg); - translate_expr( - program, - Some(referenced_tables), - &delimiter_expr, - delimiter_reg, - resolver, - )?; - - program.emit_insn(Insn::AggStep { - acc_reg: target_register, - col: expr_reg, - delimiter: delimiter_reg, - func: AggFunc::StringAgg, - }); - - target_register - } - AggFunc::Sum => { - if num_args != 1 { - crate::bail_parse_error!("sum bad number of arguments"); - } - let expr_reg = agg_arg_source.translate(program, 0)?; - handle_distinct(program, agg_arg_source.aggregate(), expr_reg); - program.emit_insn(Insn::AggStep { - acc_reg: target_register, - col: expr_reg, - delimiter: 0, - func: AggFunc::Sum, - }); - target_register - } - AggFunc::Total => { - if num_args != 1 { - crate::bail_parse_error!("total bad number of arguments"); - } - let expr_reg = agg_arg_source.translate(program, 0)?; - handle_distinct(program, agg_arg_source.aggregate(), expr_reg); - program.emit_insn(Insn::AggStep { - acc_reg: target_register, - col: expr_reg, - delimiter: 0, - func: AggFunc::Total, - }); - target_register - } - AggFunc::External(ref func) => { - let argc = func.agg_args().map_err(|_| { - LimboError::ExtensionError( - "External aggregate function called with wrong number of arguments".to_string(), - ) - })?; - if argc != num_args { - crate::bail_parse_error!( - "External aggregate function called with wrong number of arguments" - ); - } - let expr_reg = agg_arg_source.translate(program, 0)?; - for i in 0..argc { - if i != 0 { - let _ = agg_arg_source.translate(program, i)?; - } - // invariant: distinct aggregates are only supported for single-argument functions - if argc == 1 { - handle_distinct(program, agg_arg_source.aggregate(), expr_reg + i); - } - } - program.emit_insn(Insn::AggStep { - acc_reg: target_register, - col: expr_reg, - delimiter: 0, - func: AggFunc::External(func.clone()), - }); - target_register - } - }; - Ok(dest) -} diff --git a/core/translate/main_loop.rs b/core/translate/main_loop.rs index 2617c86d8..06d801705 100644 --- a/core/translate/main_loop.rs +++ b/core/translate/main_loop.rs @@ -19,7 +19,7 @@ use crate::{ }; use super::{ - aggregation::translate_aggregation_step, + aggregation::{translate_aggregation_step, AggArgumentSource}, emitter::{OperationMode, TranslateCtx}, expr::{ translate_condition_expr, translate_expr, translate_expr_no_constant_opt, @@ -868,7 +868,7 @@ fn emit_loop_source( translate_aggregation_step( program, &plan.table_references, - agg, + AggArgumentSource::new_from_expression(agg), reg, &t_ctx.resolver, )?;