diff --git a/core/translate/aggregation.rs b/core/translate/aggregation.rs index ad902097e..7a1de9776 100644 --- a/core/translate/aggregation.rs +++ b/core/translate/aggregation.rs @@ -125,27 +125,161 @@ pub fn handle_distinct(program: &mut ProgramBuilder, agg: &Aggregate, agg_arg_re }); } -/// Emits the bytecode for processing an aggregate step. -/// E.g. in `SELECT SUM(price) FROM t`, 'price' is evaluated for every row, and the result is added to the accumulator. +/// Enum representing the source of the aggregate function arguments /// -/// This is distinct from the final step, which is called after the main loop has finished processing +/// Aggregate arguments can come from different sources, depending on how the aggregation +/// is evaluated: +/// * In the common grouped case, the aggregate function arguments are first inserted +/// into a sorter in the main loop, and in the group by aggregation phase we read +/// the data from the sorter. +/// * In grouped cases where no sorting is required, arguments are retrieved directly +/// from registers allocated in the main loop. +/// * In ungrouped cases, arguments are computed directly from the `args` expressions. +pub enum AggArgumentSource<'a> { + /// The aggregate function arguments are retrieved from a pseudo cursor + /// which reads from the GROUP BY sorter. + PseudoCursor { + cursor_id: usize, + col_start: usize, + dest_reg_start: usize, + aggregate: &'a Aggregate, + }, + /// The aggregate function arguments are retrieved from a contiguous block of registers + /// allocated in the main loop for that given aggregate function. + Register { + src_reg_start: usize, + aggregate: &'a Aggregate, + }, + /// The aggregate function arguments are retrieved by evaluating expressions. + Expression { aggregate: &'a Aggregate }, +} + +impl<'a> AggArgumentSource<'a> { + /// Create a new [AggArgumentSource] that retrieves the values from a GROUP BY sorter. + pub fn new_from_cursor( + program: &mut ProgramBuilder, + cursor_id: usize, + col_start: usize, + aggregate: &'a Aggregate, + ) -> Self { + let dest_reg_start = program.alloc_registers(aggregate.args.len()); + Self::PseudoCursor { + cursor_id, + col_start, + dest_reg_start, + aggregate, + } + } + /// Create a new [AggArgumentSource] that retrieves the values directly from an already + /// populated register or registers. + pub fn new_from_registers(src_reg_start: usize, aggregate: &'a Aggregate) -> Self { + Self::Register { + src_reg_start, + aggregate, + } + } + + /// Create a new [AggArgumentSource] that retrieves the values by evaluating `args` expressions. + pub fn new_from_expression(aggregate: &'a Aggregate) -> Self { + Self::Expression { aggregate } + } + + pub fn aggregate(&self) -> &Aggregate { + match self { + AggArgumentSource::PseudoCursor { aggregate, .. } => aggregate, + AggArgumentSource::Register { aggregate, .. } => aggregate, + AggArgumentSource::Expression { aggregate } => aggregate, + } + } + + pub fn agg_func(&self) -> &AggFunc { + match self { + AggArgumentSource::PseudoCursor { aggregate, .. } => &aggregate.func, + AggArgumentSource::Register { aggregate, .. } => &aggregate.func, + AggArgumentSource::Expression { aggregate } => &aggregate.func, + } + } + pub fn args(&self) -> &[ast::Expr] { + match self { + AggArgumentSource::PseudoCursor { aggregate, .. } => &aggregate.args, + AggArgumentSource::Register { aggregate, .. } => &aggregate.args, + AggArgumentSource::Expression { aggregate } => &aggregate.args, + } + } + pub fn num_args(&self) -> usize { + match self { + AggArgumentSource::PseudoCursor { aggregate, .. } => aggregate.args.len(), + AggArgumentSource::Register { aggregate, .. } => aggregate.args.len(), + AggArgumentSource::Expression { aggregate } => aggregate.args.len(), + } + } + /// Read the value of an aggregate function argument + pub fn translate( + &self, + program: &mut ProgramBuilder, + referenced_tables: &TableReferences, + resolver: &Resolver, + arg_idx: usize, + ) -> Result { + match self { + AggArgumentSource::PseudoCursor { + cursor_id, + col_start, + dest_reg_start, + .. + } => { + program.emit_column_or_rowid( + *cursor_id, + *col_start + arg_idx, + dest_reg_start + arg_idx, + ); + Ok(dest_reg_start + arg_idx) + } + AggArgumentSource::Register { + src_reg_start: start_reg, + .. + } => Ok(*start_reg + arg_idx), + AggArgumentSource::Expression { aggregate } => { + let dest_reg = program.alloc_register(); + translate_expr( + program, + Some(referenced_tables), + &aggregate.args[arg_idx], + dest_reg, + resolver, + ) + } + } + } +} + +/// Emits the bytecode for processing an aggregate step. +/// +/// This is distinct from the final step, which is called after a single group has been entirely accumulated, /// and the actual result value of the aggregation is materialized. +/// +/// Ungrouped aggregation is a special case of grouped aggregation that involves a single group. +/// +/// Examples: +/// * In `SELECT SUM(price) FROM t`, `price` is evaluated for each row and added to the accumulator. +/// * In `SELECT product_category, SUM(price) FROM t GROUP BY product_category`, `price` is evaluated for +/// each row in the group and added to that group’s accumulator. pub fn translate_aggregation_step( program: &mut ProgramBuilder, referenced_tables: &TableReferences, - agg: &Aggregate, + agg_arg_source: AggArgumentSource, target_register: usize, resolver: &Resolver, ) -> Result { - let dest = match agg.func { + let num_args = agg_arg_source.num_args(); + let func = agg_arg_source.agg_func(); + let dest = match func { AggFunc::Avg => { - if agg.args.len() != 1 { + if num_args != 1 { crate::bail_parse_error!("avg bad number of arguments"); } - let expr = &agg.args[0]; - let expr_reg = program.alloc_register(); - let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?; - handle_distinct(program, agg, expr_reg); + let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, @@ -155,20 +289,16 @@ pub fn translate_aggregation_step( target_register } AggFunc::Count | AggFunc::Count0 => { - let expr_reg = if agg.args.is_empty() { - program.alloc_register() - } else { - let expr = &agg.args[0]; - let expr_reg = program.alloc_register(); - let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?; - expr_reg - }; - handle_distinct(program, agg, expr_reg); + if num_args != 1 { + crate::bail_parse_error!("count bad number of arguments"); + } + let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, delimiter: 0, - func: if matches!(agg.func, AggFunc::Count0) { + func: if matches!(func, AggFunc::Count0) { AggFunc::Count0 } else { AggFunc::Count @@ -177,18 +307,16 @@ pub fn translate_aggregation_step( target_register } AggFunc::GroupConcat => { - if agg.args.len() != 1 && agg.args.len() != 2 { + if num_args != 1 && num_args != 2 { crate::bail_parse_error!("group_concat bad number of arguments"); } - let expr_reg = program.alloc_register(); let delimiter_reg = program.alloc_register(); - let expr = &agg.args[0]; let delimiter_expr: ast::Expr; - if agg.args.len() == 2 { - match &agg.args[1] { + if num_args == 2 { + match &agg_arg_source.args()[1] { arg @ ast::Expr::Column { .. } => { delimiter_expr = arg.clone(); } @@ -201,8 +329,8 @@ pub fn translate_aggregation_step( delimiter_expr = ast::Expr::Literal(ast::Literal::String(String::from("\",\""))); } - translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?; - handle_distinct(program, agg, expr_reg); + let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); translate_expr( program, Some(referenced_tables), @@ -221,13 +349,12 @@ pub fn translate_aggregation_step( target_register } AggFunc::Max => { - if agg.args.len() != 1 { + if num_args != 1 { crate::bail_parse_error!("max bad number of arguments"); } - let expr = &agg.args[0]; - let expr_reg = program.alloc_register(); - let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?; - handle_distinct(program, agg, expr_reg); + let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); + let expr = &agg_arg_source.args()[0]; emit_collseq_if_needed(program, referenced_tables, expr); program.emit_insn(Insn::AggStep { acc_reg: target_register, @@ -238,13 +365,12 @@ pub fn translate_aggregation_step( target_register } AggFunc::Min => { - if agg.args.len() != 1 { + if num_args != 1 { crate::bail_parse_error!("min bad number of arguments"); } - let expr = &agg.args[0]; - let expr_reg = program.alloc_register(); - let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?; - handle_distinct(program, agg, expr_reg); + let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); + let expr = &agg_arg_source.args()[0]; emit_collseq_if_needed(program, referenced_tables, expr); program.emit_insn(Insn::AggStep { acc_reg: target_register, @@ -256,23 +382,12 @@ pub fn translate_aggregation_step( } #[cfg(feature = "json")] AggFunc::JsonGroupObject | AggFunc::JsonbGroupObject => { - if agg.args.len() != 2 { + if num_args != 2 { crate::bail_parse_error!("max bad number of arguments"); } - let expr = &agg.args[0]; - let expr_reg = program.alloc_register(); - let value_expr = &agg.args[1]; - let value_reg = program.alloc_register(); - - let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?; - handle_distinct(program, agg, expr_reg); - let _ = translate_expr( - program, - Some(referenced_tables), - value_expr, - value_reg, - resolver, - )?; + let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); + let value_reg = agg_arg_source.translate(program, referenced_tables, resolver, 1)?; program.emit_insn(Insn::AggStep { acc_reg: target_register, @@ -284,13 +399,11 @@ pub fn translate_aggregation_step( } #[cfg(feature = "json")] AggFunc::JsonGroupArray | AggFunc::JsonbGroupArray => { - if agg.args.len() != 1 { + if num_args != 1 { crate::bail_parse_error!("max bad number of arguments"); } - let expr = &agg.args[0]; - let expr_reg = program.alloc_register(); - let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?; - handle_distinct(program, agg, expr_reg); + let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, @@ -300,15 +413,13 @@ pub fn translate_aggregation_step( target_register } AggFunc::StringAgg => { - if agg.args.len() != 2 { + if num_args != 2 { crate::bail_parse_error!("string_agg bad number of arguments"); } - let expr_reg = program.alloc_register(); let delimiter_reg = program.alloc_register(); - let expr = &agg.args[0]; - let delimiter_expr = match &agg.args[1] { + let delimiter_expr = match &agg_arg_source.args()[1] { arg @ ast::Expr::Column { .. } => arg.clone(), ast::Expr::Literal(ast::Literal::String(s)) => { ast::Expr::Literal(ast::Literal::String(s.to_string())) @@ -316,7 +427,7 @@ pub fn translate_aggregation_step( _ => crate::bail_parse_error!("Incorrect delimiter parameter"), }; - translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?; + let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?; translate_expr( program, Some(referenced_tables), @@ -335,13 +446,11 @@ pub fn translate_aggregation_step( target_register } AggFunc::Sum => { - if agg.args.len() != 1 { + if num_args != 1 { crate::bail_parse_error!("sum bad number of arguments"); } - let expr = &agg.args[0]; - let expr_reg = program.alloc_register(); - let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?; - handle_distinct(program, agg, expr_reg); + let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, @@ -351,13 +460,11 @@ pub fn translate_aggregation_step( target_register } AggFunc::Total => { - if agg.args.len() != 1 { + if num_args != 1 { crate::bail_parse_error!("total bad number of arguments"); } - let expr = &agg.args[0]; - let expr_reg = program.alloc_register(); - let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?; - handle_distinct(program, agg, expr_reg); + let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?; + handle_distinct(program, agg_arg_source.aggregate(), expr_reg); program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, @@ -367,31 +474,24 @@ pub fn translate_aggregation_step( target_register } AggFunc::External(ref func) => { - let expr_reg = program.alloc_register(); let argc = func.agg_args().map_err(|_| { LimboError::ExtensionError( "External aggregate function called with wrong number of arguments".to_string(), ) })?; - if argc != agg.args.len() { + if argc != num_args { crate::bail_parse_error!( "External aggregate function called with wrong number of arguments" ); } + let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?; for i in 0..argc { if i != 0 { - let _ = program.alloc_register(); + let _ = agg_arg_source.translate(program, referenced_tables, resolver, i)?; } - let _ = translate_expr( - program, - Some(referenced_tables), - &agg.args[i], - expr_reg + i, - resolver, - )?; // invariant: distinct aggregates are only supported for single-argument functions if argc == 1 { - handle_distinct(program, agg, expr_reg + i); + handle_distinct(program, agg_arg_source.aggregate(), expr_reg + i); } } program.emit_insn(Insn::AggStep { diff --git a/core/translate/group_by.rs b/core/translate/group_by.rs index 0a524f348..37c05cc45 100644 --- a/core/translate/group_by.rs +++ b/core/translate/group_by.rs @@ -1,9 +1,16 @@ use turso_parser::ast; +use super::{ + emitter::TranslateCtx, + expr::{translate_condition_expr, translate_expr, ConditionMetadata}, + order_by::order_by_sorter_insert, + plan::{Distinctness, GroupBy, SelectPlan}, + result_row::emit_select_result, +}; +use crate::translate::aggregation::{translate_aggregation_step, AggArgumentSource}; use crate::translate::expr::{walk_expr, WalkControl}; use crate::translate::plan::ResultSetColumn; use crate::{ - function::AggFunc, schema::PseudoCursorType, translate::collate::CollationSeq, util::exprs_are_equivalent, @@ -15,15 +22,6 @@ use crate::{ Result, }; -use super::{ - aggregation::handle_distinct, - emitter::{Resolver, TranslateCtx}, - expr::{translate_condition_expr, translate_expr, ConditionMetadata}, - order_by::order_by_sorter_insert, - plan::{Aggregate, Distinctness, GroupBy, SelectPlan, TableReferences}, - result_row::emit_select_result, -}; - /// Labels needed for various jumps in GROUP BY handling. #[derive(Debug)] pub struct GroupByLabels { @@ -394,102 +392,6 @@ pub enum GroupByRowSource { }, } -/// Enum representing the source of the aggregate function arguments -/// emitted for a group by aggregation. -/// In the common case, the aggregate function arguments are first inserted -/// into a sorter in the main loop, and in the group by aggregation phase -/// we read the data from the sorter. -/// -/// In the alternative case, no sorting is required for group by, -/// and the aggregate function arguments are retrieved directly from -/// registers allocated in the main loop. -pub enum GroupByAggArgumentSource<'a> { - /// The aggregate function arguments are retrieved from a pseudo cursor - /// which reads from the GROUP BY sorter. - PseudoCursor { - cursor_id: usize, - col_start: usize, - dest_reg_start: usize, - aggregate: &'a Aggregate, - }, - /// The aggregate function arguments are retrieved from a contiguous block of registers - /// allocated in the main loop for that given aggregate function. - Register { - src_reg_start: usize, - aggregate: &'a Aggregate, - }, -} - -impl<'a> GroupByAggArgumentSource<'a> { - /// Create a new [GroupByAggArgumentSource] that retrieves the values from a GROUP BY sorter. - pub fn new_from_cursor( - program: &mut ProgramBuilder, - cursor_id: usize, - col_start: usize, - aggregate: &'a Aggregate, - ) -> Self { - let dest_reg_start = program.alloc_registers(aggregate.args.len()); - Self::PseudoCursor { - cursor_id, - col_start, - dest_reg_start, - aggregate, - } - } - /// Create a new [GroupByAggArgumentSource] that retrieves the values directly from an already - /// populated register or registers. - pub fn new_from_registers(src_reg_start: usize, aggregate: &'a Aggregate) -> Self { - Self::Register { - src_reg_start, - aggregate, - } - } - - pub fn aggregate(&self) -> &Aggregate { - match self { - GroupByAggArgumentSource::PseudoCursor { aggregate, .. } => aggregate, - GroupByAggArgumentSource::Register { aggregate, .. } => aggregate, - } - } - - pub fn agg_func(&self) -> &AggFunc { - match self { - GroupByAggArgumentSource::PseudoCursor { aggregate, .. } => &aggregate.func, - GroupByAggArgumentSource::Register { aggregate, .. } => &aggregate.func, - } - } - pub fn args(&self) -> &[ast::Expr] { - match self { - GroupByAggArgumentSource::PseudoCursor { aggregate, .. } => &aggregate.args, - GroupByAggArgumentSource::Register { aggregate, .. } => &aggregate.args, - } - } - pub fn num_args(&self) -> usize { - match self { - GroupByAggArgumentSource::PseudoCursor { aggregate, .. } => aggregate.args.len(), - GroupByAggArgumentSource::Register { aggregate, .. } => aggregate.args.len(), - } - } - /// Read the value of an aggregate function argument either from sorter data or directly from a register. - pub fn translate(&self, program: &mut ProgramBuilder, arg_idx: usize) -> Result { - match self { - GroupByAggArgumentSource::PseudoCursor { - cursor_id, - col_start, - dest_reg_start, - .. - } => { - program.emit_column_or_rowid(*cursor_id, *col_start, dest_reg_start + arg_idx); - Ok(dest_reg_start + arg_idx) - } - GroupByAggArgumentSource::Register { - src_reg_start: start_reg, - .. - } => Ok(*start_reg + arg_idx), - } - } -} - /// Emits bytecode for processing a single GROUP BY group. pub fn group_by_process_single_group( program: &mut ProgramBuilder, @@ -593,21 +495,19 @@ pub fn group_by_process_single_group( .expect("aggregate registers must be initialized"); let agg_result_reg = start_reg + i; let agg_arg_source = match &row_source { - GroupByRowSource::Sorter { pseudo_cursor, .. } => { - GroupByAggArgumentSource::new_from_cursor( - program, - *pseudo_cursor, - cursor_index + offset, - agg, - ) - } + GroupByRowSource::Sorter { pseudo_cursor, .. } => AggArgumentSource::new_from_cursor( + program, + *pseudo_cursor, + cursor_index + offset, + agg, + ), GroupByRowSource::MainLoop { start_reg_src, .. } => { // Aggregation arguments are always placed in the registers that follow any scalars. let start_reg_aggs = start_reg_src + t_ctx.non_aggregate_expressions.len(); - GroupByAggArgumentSource::new_from_registers(start_reg_aggs + offset, agg) + AggArgumentSource::new_from_registers(start_reg_aggs + offset, agg) } }; - translate_aggregation_step_groupby( + translate_aggregation_step( program, &plan.table_references, agg_arg_source, @@ -897,220 +797,3 @@ pub fn group_by_emit_row_phase<'a>( program.preassign_label_to_next_insn(labels.label_group_by_end); Ok(()) } - -/// Emits the bytecode for processing an aggregate step within a GROUP BY clause. -/// Eg. in `SELECT product_category, SUM(price) FROM t GROUP BY line_item`, 'price' is evaluated for every row -/// where the 'product_category' is the same, and the result is added to the accumulator for that category. -/// -/// This is distinct from the final step, which is called after a single group has been entirely accumulated, -/// and the actual result value of the aggregation is materialized. -pub fn translate_aggregation_step_groupby( - program: &mut ProgramBuilder, - referenced_tables: &TableReferences, - agg_arg_source: GroupByAggArgumentSource, - target_register: usize, - resolver: &Resolver, -) -> Result { - let num_args = agg_arg_source.num_args(); - let dest = match agg_arg_source.agg_func() { - AggFunc::Avg => { - if num_args != 1 { - crate::bail_parse_error!("avg bad number of arguments"); - } - let expr_reg = agg_arg_source.translate(program, 0)?; - handle_distinct(program, agg_arg_source.aggregate(), expr_reg); - program.emit_insn(Insn::AggStep { - acc_reg: target_register, - col: expr_reg, - delimiter: 0, - func: AggFunc::Avg, - }); - target_register - } - AggFunc::Count | AggFunc::Count0 => { - let expr_reg = agg_arg_source.translate(program, 0)?; - handle_distinct(program, agg_arg_source.aggregate(), expr_reg); - program.emit_insn(Insn::AggStep { - acc_reg: target_register, - col: expr_reg, - delimiter: 0, - func: if matches!(agg_arg_source.agg_func(), AggFunc::Count0) { - AggFunc::Count0 - } else { - AggFunc::Count - }, - }); - target_register - } - AggFunc::GroupConcat => { - let num_args = agg_arg_source.num_args(); - if num_args != 1 && num_args != 2 { - crate::bail_parse_error!("group_concat bad number of arguments"); - } - - let delimiter_reg = program.alloc_register(); - - let delimiter_expr: ast::Expr; - - if num_args == 2 { - match &agg_arg_source.args()[1] { - arg @ ast::Expr::Column { .. } => { - delimiter_expr = arg.clone(); - } - ast::Expr::Literal(ast::Literal::String(s)) => { - delimiter_expr = ast::Expr::Literal(ast::Literal::String(s.to_string())); - } - _ => crate::bail_parse_error!("Incorrect delimiter parameter"), - }; - } else { - delimiter_expr = ast::Expr::Literal(ast::Literal::String(String::from("\",\""))); - } - - let expr_reg = agg_arg_source.translate(program, 0)?; - handle_distinct(program, agg_arg_source.aggregate(), expr_reg); - translate_expr( - program, - Some(referenced_tables), - &delimiter_expr, - delimiter_reg, - resolver, - )?; - - program.emit_insn(Insn::AggStep { - acc_reg: target_register, - col: expr_reg, - delimiter: delimiter_reg, - func: AggFunc::GroupConcat, - }); - - target_register - } - AggFunc::Max => { - if num_args != 1 { - crate::bail_parse_error!("max bad number of arguments"); - } - let expr_reg = agg_arg_source.translate(program, 0)?; - handle_distinct(program, agg_arg_source.aggregate(), expr_reg); - program.emit_insn(Insn::AggStep { - acc_reg: target_register, - col: expr_reg, - delimiter: 0, - func: AggFunc::Max, - }); - target_register - } - AggFunc::Min => { - if num_args != 1 { - crate::bail_parse_error!("min bad number of arguments"); - } - let expr_reg = agg_arg_source.translate(program, 0)?; - handle_distinct(program, agg_arg_source.aggregate(), expr_reg); - program.emit_insn(Insn::AggStep { - acc_reg: target_register, - col: expr_reg, - delimiter: 0, - func: AggFunc::Min, - }); - target_register - } - #[cfg(feature = "json")] - AggFunc::JsonGroupArray | AggFunc::JsonbGroupArray => { - if num_args != 1 { - crate::bail_parse_error!("min bad number of arguments"); - } - let expr_reg = agg_arg_source.translate(program, 0)?; - handle_distinct(program, agg_arg_source.aggregate(), expr_reg); - program.emit_insn(Insn::AggStep { - acc_reg: target_register, - col: expr_reg, - delimiter: 0, - func: AggFunc::JsonGroupArray, - }); - target_register - } - #[cfg(feature = "json")] - AggFunc::JsonGroupObject | AggFunc::JsonbGroupObject => { - if num_args != 2 { - crate::bail_parse_error!("max bad number of arguments"); - } - - let expr_reg = agg_arg_source.translate(program, 0)?; - handle_distinct(program, agg_arg_source.aggregate(), expr_reg); - let value_reg = agg_arg_source.translate(program, 1)?; - - program.emit_insn(Insn::AggStep { - acc_reg: target_register, - col: expr_reg, - delimiter: value_reg, - func: AggFunc::JsonGroupObject, - }); - target_register - } - AggFunc::StringAgg => { - if num_args != 2 { - crate::bail_parse_error!("string_agg bad number of arguments"); - } - - let delimiter_reg = program.alloc_register(); - - let delimiter_expr = match &agg_arg_source.args()[1] { - arg @ ast::Expr::Column { .. } => arg.clone(), - ast::Expr::Literal(ast::Literal::String(s)) => { - ast::Expr::Literal(ast::Literal::String(s.to_string())) - } - _ => crate::bail_parse_error!("Incorrect delimiter parameter"), - }; - - let expr_reg = agg_arg_source.translate(program, 0)?; - handle_distinct(program, agg_arg_source.aggregate(), expr_reg); - translate_expr( - program, - Some(referenced_tables), - &delimiter_expr, - delimiter_reg, - resolver, - )?; - - program.emit_insn(Insn::AggStep { - acc_reg: target_register, - col: expr_reg, - delimiter: delimiter_reg, - func: AggFunc::StringAgg, - }); - - target_register - } - AggFunc::Sum => { - if num_args != 1 { - crate::bail_parse_error!("sum bad number of arguments"); - } - let expr_reg = agg_arg_source.translate(program, 0)?; - handle_distinct(program, agg_arg_source.aggregate(), expr_reg); - program.emit_insn(Insn::AggStep { - acc_reg: target_register, - col: expr_reg, - delimiter: 0, - func: AggFunc::Sum, - }); - target_register - } - AggFunc::Total => { - if num_args != 1 { - crate::bail_parse_error!("total bad number of arguments"); - } - let expr_reg = agg_arg_source.translate(program, 0)?; - handle_distinct(program, agg_arg_source.aggregate(), expr_reg); - program.emit_insn(Insn::AggStep { - acc_reg: target_register, - col: expr_reg, - delimiter: 0, - func: AggFunc::Total, - }); - target_register - } - AggFunc::External(_) => { - todo!("External aggregate functions are not yet supported in GROUP BY"); - } - }; - Ok(dest) -} diff --git a/core/translate/main_loop.rs b/core/translate/main_loop.rs index 2617c86d8..06d801705 100644 --- a/core/translate/main_loop.rs +++ b/core/translate/main_loop.rs @@ -19,7 +19,7 @@ use crate::{ }; use super::{ - aggregation::translate_aggregation_step, + aggregation::{translate_aggregation_step, AggArgumentSource}, emitter::{OperationMode, TranslateCtx}, expr::{ translate_condition_expr, translate_expr, translate_expr_no_constant_opt, @@ -868,7 +868,7 @@ fn emit_loop_source( translate_aggregation_step( program, &plan.table_references, - agg, + AggArgumentSource::new_from_expression(agg), reg, &t_ctx.resolver, )?; diff --git a/core/translate/plan.rs b/core/translate/plan.rs index eba50ce89..e43cdbd76 100644 --- a/core/translate/plan.rs +++ b/core/translate/plan.rs @@ -1048,6 +1048,24 @@ pub struct Aggregate { } impl Aggregate { + pub fn new(func: AggFunc, args: &[Box], expr: &Expr, distinctness: Distinctness) -> Self { + let agg_args = if args.is_empty() { + // The AggStep instruction requires at least one argument. For functions that accept + // zero arguments (e.g. COUNT()), we insert a dummy literal so that AggStep remains valid. + // This does not cause ambiguity: the resolver has already verified that the function + // takes zero arguments, so the dummy value will be ignored. + vec![Expr::Literal(ast::Literal::Numeric("1".to_string()))] + } else { + args.iter().map(|arg| *arg.clone()).collect() + }; + Aggregate { + func, + args: agg_args, + original_expr: expr.clone(), + distinctness, + } + } + pub fn is_distinct(&self) -> bool { self.distinctness.is_distinct() } diff --git a/core/translate/planner.rs b/core/translate/planner.rs index 1799ed42b..522256a25 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -73,12 +73,7 @@ pub fn resolve_aggregates( "DISTINCT aggregate functions must have exactly one argument" ); } - aggs.push(Aggregate { - func: f, - args: args.iter().map(|arg| *arg.clone()).collect(), - original_expr: expr.clone(), - distinctness, - }); + aggs.push(Aggregate::new(f, args, expr, distinctness)); contains_aggregates = true; } _ => { @@ -95,12 +90,7 @@ pub fn resolve_aggregates( ); } if let Ok(Func::Agg(f)) = Func::resolve_function(name.as_str(), 0) { - aggs.push(Aggregate { - func: f, - args: vec![], - original_expr: expr.clone(), - distinctness: Distinctness::NonDistinct, - }); + aggs.push(Aggregate::new(f, &[], expr, Distinctness::NonDistinct)); contains_aggregates = true; } } diff --git a/core/translate/select.rs b/core/translate/select.rs index 8641e2347..59239094e 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -371,27 +371,7 @@ fn prepare_one_select_plan( } match Func::resolve_function(name.as_str(), args_count) { Ok(Func::Agg(f)) => { - let agg_args = match (args.is_empty(), &f) { - (true, crate::function::AggFunc::Count0) => { - // COUNT() case - vec![ast::Expr::Literal(ast::Literal::Numeric( - "1".to_string(), - )) - .into()] - } - (true, _) => crate::bail_parse_error!( - "Aggregate function {} requires arguments", - name.as_str() - ), - (false, _) => args.clone(), - }; - - let agg = Aggregate { - func: f, - args: agg_args.iter().map(|arg| *arg.clone()).collect(), - original_expr: *expr.clone(), - distinctness, - }; + let agg = Aggregate::new(f, args, expr, distinctness); aggregate_expressions.push(agg); plan.result_columns.push(ResultSetColumn { alias: maybe_alias.as_ref().map(|alias| match alias { @@ -446,15 +426,12 @@ fn prepare_one_select_plan( contains_aggregates, }); } else { - let agg = Aggregate { - func: AggFunc::External(f.func.clone().into()), - args: args - .iter() - .map(|arg| *arg.clone()) - .collect(), - original_expr: *expr.clone(), + let agg = Aggregate::new( + AggFunc::External(f.func.clone().into()), + args, + expr, distinctness, - }; + ); aggregate_expressions.push(agg); plan.result_columns.push(ResultSetColumn { alias: maybe_alias.as_ref().map(|alias| { @@ -488,14 +465,8 @@ fn prepare_one_select_plan( } match Func::resolve_function(name.as_str(), 0) { Ok(Func::Agg(f)) => { - let agg = Aggregate { - func: f, - args: vec![ast::Expr::Literal(ast::Literal::Numeric( - "1".to_string(), - ))], - original_expr: *expr.clone(), - distinctness: Distinctness::NonDistinct, - }; + let agg = + Aggregate::new(f, &[], expr, Distinctness::NonDistinct); aggregate_expressions.push(agg); plan.result_columns.push(ResultSetColumn { alias: maybe_alias.as_ref().map(|alias| match alias { diff --git a/testing/agg-functions.test b/testing/agg-functions.test index 9becf56a4..13b4600d7 100755 --- a/testing/agg-functions.test +++ b/testing/agg-functions.test @@ -143,4 +143,35 @@ do_execsql_test select-agg-json-array-object { do_execsql_test select-distinct-agg-functions { SELECT sum(distinct age), count(distinct age), avg(distinct age) FROM users; -} {5050|100|50.5} \ No newline at end of file +} {5050|100|50.5} + +do_execsql_test select-json-group-object { + select price, + json_group_object(cast (id as text), name) + from products + group by price + order by price; +} {1.0|{"9":"boots"} +18.0|{"3":"shirt"} +25.0|{"4":"sweater"} +33.0|{"10":"coat"} +70.0|{"6":"shorts"} +74.0|{"5":"sweatshirt"} +78.0|{"7":"jeans"} +79.0|{"1":"hat"} +81.0|{"11":"accessories"} +82.0|{"2":"cap","8":"sneakers"}} + +do_execsql_test select-json-group-object-no-sorting-required { + select age, + json_group_object(cast (id as text), first_name) + from users + where first_name like 'Am%' + group by age + order by age + limit 5; +} {1|{"6737":"Amy"} +2|{"2297":"Amy","3580":"Amanda"} +3|{"3437":"Amanda"} +5|{"2378":"Amy","3227":"Amy","5605":"Amanda"} +7|{"2454":"Amber"}} diff --git a/testing/cli_tests/extensions.py b/testing/cli_tests/extensions.py index 8b0e3c3a5..8ce7341f0 100755 --- a/testing/cli_tests/extensions.py +++ b/testing/cli_tests/extensions.py @@ -7,22 +7,22 @@ from cli_tests.test_turso_cli import TestTursoShell sqlite_exec = "./scripts/limbo-sqlite3" sqlite_flags = os.getenv("SQLITE_FLAGS", "-q").split(" ") -test_data = """CREATE TABLE numbers ( id INTEGER PRIMARY KEY, value FLOAT NOT NULL); -INSERT INTO numbers (value) VALUES (1.0); -INSERT INTO numbers (value) VALUES (2.0); -INSERT INTO numbers (value) VALUES (3.0); -INSERT INTO numbers (value) VALUES (4.0); -INSERT INTO numbers (value) VALUES (5.0); -INSERT INTO numbers (value) VALUES (6.0); -INSERT INTO numbers (value) VALUES (7.0); -CREATE TABLE test (value REAL, percent REAL); -INSERT INTO test values (10, 25); -INSERT INTO test values (20, 25); -INSERT INTO test values (30, 25); -INSERT INTO test values (40, 25); -INSERT INTO test values (50, 25); -INSERT INTO test values (60, 25); -INSERT INTO test values (70, 25); +test_data = """CREATE TABLE numbers ( id INTEGER PRIMARY KEY, value FLOAT NOT NULL, category TEXT DEFAULT 'A'); +INSERT INTO numbers (value, category) VALUES (1.0, 'A'); +INSERT INTO numbers (value, category) VALUES (2.0, 'A'); +INSERT INTO numbers (value, category) VALUES (3.0, 'A'); +INSERT INTO numbers (value, category) VALUES (4.0, 'B'); +INSERT INTO numbers (value, category) VALUES (5.0, 'B'); +INSERT INTO numbers (value, category) VALUES (6.0, 'B'); +INSERT INTO numbers (value, category) VALUES (7.0, 'B'); +CREATE TABLE test (value REAL, percent REAL, category TEXT); +INSERT INTO test values (10, 25, 'A'); +INSERT INTO test values (20, 25, 'A'); +INSERT INTO test values (30, 25, 'B'); +INSERT INTO test values (40, 25, 'C'); +INSERT INTO test values (50, 25, 'C'); +INSERT INTO test values (60, 25, 'C'); +INSERT INTO test values (70, 25, 'D'); """ @@ -174,6 +174,39 @@ def test_aggregates(): limbo.quit() +def test_grouped_aggregates(): + limbo = TestTursoShell(init_commands=test_data) + extension_path = "./target/debug/liblimbo_percentile" + limbo.execute_dot(f".load {extension_path}") + + limbo.run_test_fn( + "SELECT median(value) FROM numbers GROUP BY category;", + lambda res: "2.0\n5.5" == res, + "median aggregate function works", + ) + limbo.run_test_fn( + "SELECT percentile(value, percent) FROM test GROUP BY category;", + lambda res: "12.5\n30.0\n45.0\n70.0" == res, + "grouped aggregate percentile function with 2 arguments works", + ) + limbo.run_test_fn( + "SELECT percentile(value, 55) FROM test GROUP BY category;", + lambda res: "15.5\n30.0\n51.0\n70.0" == res, + "grouped aggregate percentile function with 1 argument works", + ) + limbo.run_test_fn( + "SELECT percentile_cont(value, 0.25) FROM test GROUP BY category;", + lambda res: "12.5\n30.0\n45.0\n70.0" == res, + "grouped aggregate percentile_cont function works", + ) + limbo.run_test_fn( + "SELECT percentile_disc(value, 0.55) FROM test GROUP BY category;", + lambda res: "10.0\n30.0\n50.0\n70.0" == res, + "grouped aggregate percentile_disc function works", + ) + limbo.quit() + + # Encoders and decoders def validate_url_encode(a): return a == "%2Fhello%3Ftext%3D%28%E0%B2%A0_%E0%B2%A0%29" @@ -770,6 +803,7 @@ def main(): test_regexp() test_uuid() test_aggregates() + test_grouped_aggregates() test_crypto() test_series() test_ipaddr() diff --git a/testing/collate.test b/testing/collate.test index 2b9aa3ff3..d985286b9 100755 --- a/testing/collate.test +++ b/testing/collate.test @@ -74,3 +74,31 @@ do_execsql_test_on_specific_db {:memory:} collate_aggregation_explicit_nocase { insert into fruits(name) values ('Apple') ,('banana') ,('CHERRY'); select max(name collate nocase) from fruits; } {CHERRY} + +do_execsql_test_on_specific_db {:memory:} collate_grouped_aggregation_default_binary { + create table fruits(name collate binary, category text); + insert into fruits(name, category) values ('Apple', 'A'), ('banana', 'A'), ('CHERRY', 'B'), ('blueberry', 'B'); + select max(name) from fruits group by category; +} {banana +blueberry} + +do_execsql_test_on_specific_db {:memory:} collate_grouped_aggregation_default_nocase { + create table fruits(name collate nocase, category text); + insert into fruits(name, category) values ('Apple', 'A'), ('banana', 'A'), ('CHERRY', 'B'), ('blueberry', 'B'); + select max(name) from fruits group by category; +} {banana +CHERRY} + +do_execsql_test_on_specific_db {:memory:} collate_grouped_aggregation_explicit_binary { + create table fruits(name collate nocase, category text); + insert into fruits(name, category) values ('Apple', 'A'), ('banana', 'A'), ('CHERRY', 'B'), ('blueberry', 'B'); + select max(name collate binary) from fruits group by category; +} {banana +blueberry} + +do_execsql_test_on_specific_db {:memory:} collate_groupped_aggregation_explicit_nocase { + create table fruits(name collate binary, category text); + insert into fruits(name, category) values ('Apple', 'A'), ('banana', 'A'), ('CHERRY', 'B'), ('blueberry', 'B'); + select max(name collate nocase) from fruits group by category; +} {banana +CHERRY} diff --git a/testing/groupby.test b/testing/groupby.test index c159b9892..dc5418110 100755 --- a/testing/groupby.test +++ b/testing/groupby.test @@ -145,6 +145,18 @@ do_execsql_test group_by_count_star { select u.first_name, count(*) from users u group by u.first_name limit 1; } {Aaron|41} +do_execsql_test group_by_count_star_in_expression { + select u.first_name, count(*) % 3 from users u group by u.first_name order by u.first_name limit 3; +} {Aaron|2 +Abigail|1 +Adam|0} + +do_execsql_test group_by_count_no_args_in_expression { + select u.first_name, count() % 3 from users u group by u.first_name order by u.first_name limit 3; +} {Aaron|2 +Abigail|1 +Adam|0} + do_execsql_test having { select u.first_name, round(avg(u.age)) from users u group by u.first_name having avg(u.age) > 97 order by avg(u.age) desc limit 5; } {Nina|100.0