Merge 'Unify handling of grouped and ungrouped aggregations' from Piotr Rżysko

The initial commits fix issues and plug gaps between ungrouped and
grouped aggregations.
The final commit consolidates the code that emits `AggStep` to prevent
future disparities between the two.

Reviewed-by: Preston Thorpe <preston@turso.tech>

Closes #2867
This commit is contained in:
Pekka Enberg
2025-09-02 09:11:40 +03:00
committed by GitHub
10 changed files with 351 additions and 484 deletions

View File

@@ -125,27 +125,161 @@ pub fn handle_distinct(program: &mut ProgramBuilder, agg: &Aggregate, agg_arg_re
});
}
/// Emits the bytecode for processing an aggregate step.
/// E.g. in `SELECT SUM(price) FROM t`, 'price' is evaluated for every row, and the result is added to the accumulator.
/// Enum representing the source of the aggregate function arguments
///
/// This is distinct from the final step, which is called after the main loop has finished processing
/// Aggregate arguments can come from different sources, depending on how the aggregation
/// is evaluated:
/// * In the common grouped case, the aggregate function arguments are first inserted
/// into a sorter in the main loop, and in the group by aggregation phase we read
/// the data from the sorter.
/// * In grouped cases where no sorting is required, arguments are retrieved directly
/// from registers allocated in the main loop.
/// * In ungrouped cases, arguments are computed directly from the `args` expressions.
pub enum AggArgumentSource<'a> {
    /// Arguments are read through a pseudo cursor that is positioned on a row
    /// coming out of the GROUP BY sorter.
    PseudoCursor {
        /// Cursor the argument columns are read from.
        cursor_id: usize,
        /// Column holding argument 0; argument `i` is read from `col_start + i`.
        col_start: usize,
        /// Register receiving argument 0; argument `i` lands in `dest_reg_start + i`.
        dest_reg_start: usize,
        aggregate: &'a Aggregate,
    },
    /// Arguments already sit in a contiguous run of registers that the main loop
    /// populated for this aggregate.
    Register {
        /// Register holding argument 0; argument `i` is found at `src_reg_start + i`.
        src_reg_start: usize,
        aggregate: &'a Aggregate,
    },
    /// Arguments are produced on demand by evaluating the aggregate's `args` expressions.
    Expression { aggregate: &'a Aggregate },
}

impl<'a> AggArgumentSource<'a> {
    /// Builds a source that pulls argument values out of the GROUP BY sorter through
    /// `cursor_id`, starting at column `col_start`.
    ///
    /// One destination register per argument is reserved from `program` up front.
    pub fn new_from_cursor(
        program: &mut ProgramBuilder,
        cursor_id: usize,
        col_start: usize,
        aggregate: &'a Aggregate,
    ) -> Self {
        Self::PseudoCursor {
            cursor_id,
            col_start,
            dest_reg_start: program.alloc_registers(aggregate.args.len()),
            aggregate,
        }
    }

    /// Builds a source whose argument values are already present in the registers
    /// starting at `src_reg_start`.
    pub fn new_from_registers(src_reg_start: usize, aggregate: &'a Aggregate) -> Self {
        Self::Register {
            src_reg_start,
            aggregate,
        }
    }

    /// Builds a source that evaluates the aggregate's `args` expressions directly.
    pub fn new_from_expression(aggregate: &'a Aggregate) -> Self {
        Self::Expression { aggregate }
    }

    /// The aggregate whose arguments this source provides.
    pub fn aggregate(&self) -> &Aggregate {
        match self {
            Self::PseudoCursor { aggregate, .. }
            | Self::Register { aggregate, .. }
            | Self::Expression { aggregate } => aggregate,
        }
    }

    /// The aggregate function being computed.
    pub fn agg_func(&self) -> &AggFunc {
        &self.aggregate().func
    }

    /// The argument expressions of the aggregate.
    pub fn args(&self) -> &[ast::Expr] {
        &self.aggregate().args
    }

    /// Number of arguments of the aggregate.
    pub fn num_args(&self) -> usize {
        self.aggregate().args.len()
    }

    /// Makes argument `arg_idx` available in a register and returns that register.
    ///
    /// * `PseudoCursor`: emits a column read from the cursor into the reserved register.
    /// * `Register`: emits nothing; the value is already in place.
    /// * `Expression`: allocates a fresh register and translates the argument expression.
    pub fn translate(
        &self,
        program: &mut ProgramBuilder,
        referenced_tables: &TableReferences,
        resolver: &Resolver,
        arg_idx: usize,
    ) -> Result<usize> {
        let arg_reg = match self {
            Self::PseudoCursor {
                cursor_id,
                col_start,
                dest_reg_start,
                ..
            } => {
                let dest_reg = dest_reg_start + arg_idx;
                program.emit_column_or_rowid(*cursor_id, *col_start + arg_idx, dest_reg);
                dest_reg
            }
            Self::Register { src_reg_start, .. } => *src_reg_start + arg_idx,
            Self::Expression { aggregate } => {
                let dest_reg = program.alloc_register();
                translate_expr(
                    program,
                    Some(referenced_tables),
                    &aggregate.args[arg_idx],
                    dest_reg,
                    resolver,
                )?
            }
        };
        Ok(arg_reg)
    }
}
/// Emits the bytecode for processing an aggregate step.
///
/// This is distinct from the final step, which is called after a single group has been entirely accumulated,
/// and the actual result value of the aggregation is materialized.
///
/// Ungrouped aggregation is a special case of grouped aggregation that involves a single group.
///
/// Examples:
/// * In `SELECT SUM(price) FROM t`, `price` is evaluated for each row and added to the accumulator.
/// * In `SELECT product_category, SUM(price) FROM t GROUP BY product_category`, `price` is evaluated for
/// each row in the group and added to that group's accumulator.
pub fn translate_aggregation_step(
program: &mut ProgramBuilder,
referenced_tables: &TableReferences,
agg: &Aggregate,
agg_arg_source: AggArgumentSource,
target_register: usize,
resolver: &Resolver,
) -> Result<usize> {
let dest = match agg.func {
let num_args = agg_arg_source.num_args();
let func = agg_arg_source.agg_func();
let dest = match func {
AggFunc::Avg => {
if agg.args.len() != 1 {
if num_args != 1 {
crate::bail_parse_error!("avg bad number of arguments");
}
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
@@ -155,20 +289,16 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::Count | AggFunc::Count0 => {
let expr_reg = if agg.args.is_empty() {
program.alloc_register()
} else {
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
expr_reg
};
handle_distinct(program, agg, expr_reg);
if num_args != 1 {
crate::bail_parse_error!("count bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: if matches!(agg.func, AggFunc::Count0) {
func: if matches!(func, AggFunc::Count0) {
AggFunc::Count0
} else {
AggFunc::Count
@@ -177,18 +307,16 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::GroupConcat => {
if agg.args.len() != 1 && agg.args.len() != 2 {
if num_args != 1 && num_args != 2 {
crate::bail_parse_error!("group_concat bad number of arguments");
}
let expr_reg = program.alloc_register();
let delimiter_reg = program.alloc_register();
let expr = &agg.args[0];
let delimiter_expr: ast::Expr;
if agg.args.len() == 2 {
match &agg.args[1] {
if num_args == 2 {
match &agg_arg_source.args()[1] {
arg @ ast::Expr::Column { .. } => {
delimiter_expr = arg.clone();
}
@@ -201,8 +329,8 @@ pub fn translate_aggregation_step(
delimiter_expr = ast::Expr::Literal(ast::Literal::String(String::from("\",\"")));
}
translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
translate_expr(
program,
Some(referenced_tables),
@@ -221,13 +349,12 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::Max => {
if agg.args.len() != 1 {
if num_args != 1 {
crate::bail_parse_error!("max bad number of arguments");
}
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
let expr = &agg_arg_source.args()[0];
emit_collseq_if_needed(program, referenced_tables, expr);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
@@ -238,13 +365,12 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::Min => {
if agg.args.len() != 1 {
if num_args != 1 {
crate::bail_parse_error!("min bad number of arguments");
}
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
let expr = &agg_arg_source.args()[0];
emit_collseq_if_needed(program, referenced_tables, expr);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
@@ -256,23 +382,12 @@ pub fn translate_aggregation_step(
}
#[cfg(feature = "json")]
AggFunc::JsonGroupObject | AggFunc::JsonbGroupObject => {
if agg.args.len() != 2 {
if num_args != 2 {
crate::bail_parse_error!("max bad number of arguments");
}
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let value_expr = &agg.args[1];
let value_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let _ = translate_expr(
program,
Some(referenced_tables),
value_expr,
value_reg,
resolver,
)?;
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
let value_reg = agg_arg_source.translate(program, referenced_tables, resolver, 1)?;
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
@@ -284,13 +399,11 @@ pub fn translate_aggregation_step(
}
#[cfg(feature = "json")]
AggFunc::JsonGroupArray | AggFunc::JsonbGroupArray => {
if agg.args.len() != 1 {
if num_args != 1 {
crate::bail_parse_error!("max bad number of arguments");
}
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
@@ -300,15 +413,13 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::StringAgg => {
if agg.args.len() != 2 {
if num_args != 2 {
crate::bail_parse_error!("string_agg bad number of arguments");
}
let expr_reg = program.alloc_register();
let delimiter_reg = program.alloc_register();
let expr = &agg.args[0];
let delimiter_expr = match &agg.args[1] {
let delimiter_expr = match &agg_arg_source.args()[1] {
arg @ ast::Expr::Column { .. } => arg.clone(),
ast::Expr::Literal(ast::Literal::String(s)) => {
ast::Expr::Literal(ast::Literal::String(s.to_string()))
@@ -316,7 +427,7 @@ pub fn translate_aggregation_step(
_ => crate::bail_parse_error!("Incorrect delimiter parameter"),
};
translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
translate_expr(
program,
Some(referenced_tables),
@@ -335,13 +446,11 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::Sum => {
if agg.args.len() != 1 {
if num_args != 1 {
crate::bail_parse_error!("sum bad number of arguments");
}
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
@@ -351,13 +460,11 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::Total => {
if agg.args.len() != 1 {
if num_args != 1 {
crate::bail_parse_error!("total bad number of arguments");
}
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
@@ -367,31 +474,24 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::External(ref func) => {
let expr_reg = program.alloc_register();
let argc = func.agg_args().map_err(|_| {
LimboError::ExtensionError(
"External aggregate function called with wrong number of arguments".to_string(),
)
})?;
if argc != agg.args.len() {
if argc != num_args {
crate::bail_parse_error!(
"External aggregate function called with wrong number of arguments"
);
}
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
for i in 0..argc {
if i != 0 {
let _ = program.alloc_register();
let _ = agg_arg_source.translate(program, referenced_tables, resolver, i)?;
}
let _ = translate_expr(
program,
Some(referenced_tables),
&agg.args[i],
expr_reg + i,
resolver,
)?;
// invariant: distinct aggregates are only supported for single-argument functions
if argc == 1 {
handle_distinct(program, agg, expr_reg + i);
handle_distinct(program, agg_arg_source.aggregate(), expr_reg + i);
}
}
program.emit_insn(Insn::AggStep {

View File

@@ -1,9 +1,16 @@
use turso_parser::ast;
use super::{
emitter::TranslateCtx,
expr::{translate_condition_expr, translate_expr, ConditionMetadata},
order_by::order_by_sorter_insert,
plan::{Distinctness, GroupBy, SelectPlan},
result_row::emit_select_result,
};
use crate::translate::aggregation::{translate_aggregation_step, AggArgumentSource};
use crate::translate::expr::{walk_expr, WalkControl};
use crate::translate::plan::ResultSetColumn;
use crate::{
function::AggFunc,
schema::PseudoCursorType,
translate::collate::CollationSeq,
util::exprs_are_equivalent,
@@ -15,15 +22,6 @@ use crate::{
Result,
};
use super::{
aggregation::handle_distinct,
emitter::{Resolver, TranslateCtx},
expr::{translate_condition_expr, translate_expr, ConditionMetadata},
order_by::order_by_sorter_insert,
plan::{Aggregate, Distinctness, GroupBy, SelectPlan, TableReferences},
result_row::emit_select_result,
};
/// Labels needed for various jumps in GROUP BY handling.
#[derive(Debug)]
pub struct GroupByLabels {
@@ -394,102 +392,6 @@ pub enum GroupByRowSource {
},
}
/// Enum representing the source of the aggregate function arguments
/// emitted for a group by aggregation.
///
/// In the common case, the aggregate function arguments are first inserted
/// into a sorter in the main loop, and in the group by aggregation phase
/// we read the data from the sorter.
///
/// In the alternative case, no sorting is required for group by,
/// and the aggregate function arguments are retrieved directly from
/// registers allocated in the main loop.
pub enum GroupByAggArgumentSource<'a> {
    /// The aggregate function arguments are retrieved from a pseudo cursor
    /// which reads from the GROUP BY sorter.
    PseudoCursor {
        /// Cursor the argument columns are read from.
        cursor_id: usize,
        /// Column holding argument 0; argument `i` is read from `col_start + i`.
        col_start: usize,
        /// Register receiving argument 0; argument `i` lands in `dest_reg_start + i`.
        dest_reg_start: usize,
        aggregate: &'a Aggregate,
    },
    /// The aggregate function arguments are retrieved from a contiguous block of registers
    /// allocated in the main loop for that given aggregate function.
    Register {
        /// Register holding argument 0; argument `i` is found at `src_reg_start + i`.
        src_reg_start: usize,
        aggregate: &'a Aggregate,
    },
}

impl<'a> GroupByAggArgumentSource<'a> {
    /// Create a new [GroupByAggArgumentSource] that retrieves the values from a GROUP BY sorter.
    ///
    /// One destination register per argument is reserved from `program` up front.
    pub fn new_from_cursor(
        program: &mut ProgramBuilder,
        cursor_id: usize,
        col_start: usize,
        aggregate: &'a Aggregate,
    ) -> Self {
        let dest_reg_start = program.alloc_registers(aggregate.args.len());
        Self::PseudoCursor {
            cursor_id,
            col_start,
            dest_reg_start,
            aggregate,
        }
    }

    /// Create a new [GroupByAggArgumentSource] that retrieves the values directly from an already
    /// populated register or registers.
    pub fn new_from_registers(src_reg_start: usize, aggregate: &'a Aggregate) -> Self {
        Self::Register {
            src_reg_start,
            aggregate,
        }
    }

    /// The aggregate whose arguments this source provides.
    pub fn aggregate(&self) -> &Aggregate {
        match self {
            Self::PseudoCursor { aggregate, .. } | Self::Register { aggregate, .. } => aggregate,
        }
    }

    /// The aggregate function being computed.
    pub fn agg_func(&self) -> &AggFunc {
        &self.aggregate().func
    }

    /// The argument expressions of the aggregate.
    pub fn args(&self) -> &[ast::Expr] {
        &self.aggregate().args
    }

    /// Number of arguments of the aggregate.
    pub fn num_args(&self) -> usize {
        self.aggregate().args.len()
    }

    /// Read the value of an aggregate function argument either from sorter data or directly
    /// from a register, returning the register that holds argument `arg_idx`.
    pub fn translate(&self, program: &mut ProgramBuilder, arg_idx: usize) -> Result<usize> {
        match self {
            Self::PseudoCursor {
                cursor_id,
                col_start,
                dest_reg_start,
                ..
            } => {
                // Every argument occupies its own column in the sorter row, so the source
                // column must be offset by `arg_idx` just like the destination register.
                // Reading `col_start` for all arguments would hand argument 0's value to
                // every multi-argument aggregate (e.g. the value of `json_group_object`).
                let dest_reg = dest_reg_start + arg_idx;
                program.emit_column_or_rowid(*cursor_id, *col_start + arg_idx, dest_reg);
                Ok(dest_reg)
            }
            Self::Register { src_reg_start, .. } => Ok(*src_reg_start + arg_idx),
        }
    }
}
/// Emits bytecode for processing a single GROUP BY group.
pub fn group_by_process_single_group(
program: &mut ProgramBuilder,
@@ -593,21 +495,19 @@ pub fn group_by_process_single_group(
.expect("aggregate registers must be initialized");
let agg_result_reg = start_reg + i;
let agg_arg_source = match &row_source {
GroupByRowSource::Sorter { pseudo_cursor, .. } => {
GroupByAggArgumentSource::new_from_cursor(
program,
*pseudo_cursor,
cursor_index + offset,
agg,
)
}
GroupByRowSource::Sorter { pseudo_cursor, .. } => AggArgumentSource::new_from_cursor(
program,
*pseudo_cursor,
cursor_index + offset,
agg,
),
GroupByRowSource::MainLoop { start_reg_src, .. } => {
// Aggregation arguments are always placed in the registers that follow any scalars.
let start_reg_aggs = start_reg_src + t_ctx.non_aggregate_expressions.len();
GroupByAggArgumentSource::new_from_registers(start_reg_aggs + offset, agg)
AggArgumentSource::new_from_registers(start_reg_aggs + offset, agg)
}
};
translate_aggregation_step_groupby(
translate_aggregation_step(
program,
&plan.table_references,
agg_arg_source,
@@ -897,220 +797,3 @@ pub fn group_by_emit_row_phase<'a>(
program.preassign_label_to_next_insn(labels.label_group_by_end);
Ok(())
}
/// Emits the bytecode for processing an aggregate step within a GROUP BY clause.
/// Eg. in `SELECT product_category, SUM(price) FROM t GROUP BY line_item`, 'price' is evaluated for every row
/// where the 'product_category' is the same, and the result is added to the accumulator for that category.
///
/// This is distinct from the final step, which is called after a single group has been entirely accumulated,
/// and the actual result value of the aggregation is materialized.
pub fn translate_aggregation_step_groupby(
program: &mut ProgramBuilder,
referenced_tables: &TableReferences,
agg_arg_source: GroupByAggArgumentSource,
target_register: usize,
resolver: &Resolver,
) -> Result<usize> {
let num_args = agg_arg_source.num_args();
let dest = match agg_arg_source.agg_func() {
AggFunc::Avg => {
if num_args != 1 {
crate::bail_parse_error!("avg bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::Avg,
});
target_register
}
AggFunc::Count | AggFunc::Count0 => {
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: if matches!(agg_arg_source.agg_func(), AggFunc::Count0) {
AggFunc::Count0
} else {
AggFunc::Count
},
});
target_register
}
AggFunc::GroupConcat => {
let num_args = agg_arg_source.num_args();
if num_args != 1 && num_args != 2 {
crate::bail_parse_error!("group_concat bad number of arguments");
}
let delimiter_reg = program.alloc_register();
let delimiter_expr: ast::Expr;
if num_args == 2 {
match &agg_arg_source.args()[1] {
arg @ ast::Expr::Column { .. } => {
delimiter_expr = arg.clone();
}
ast::Expr::Literal(ast::Literal::String(s)) => {
delimiter_expr = ast::Expr::Literal(ast::Literal::String(s.to_string()));
}
_ => crate::bail_parse_error!("Incorrect delimiter parameter"),
};
} else {
delimiter_expr = ast::Expr::Literal(ast::Literal::String(String::from("\",\"")));
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
translate_expr(
program,
Some(referenced_tables),
&delimiter_expr,
delimiter_reg,
resolver,
)?;
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: delimiter_reg,
func: AggFunc::GroupConcat,
});
target_register
}
AggFunc::Max => {
if num_args != 1 {
crate::bail_parse_error!("max bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::Max,
});
target_register
}
AggFunc::Min => {
if num_args != 1 {
crate::bail_parse_error!("min bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::Min,
});
target_register
}
#[cfg(feature = "json")]
AggFunc::JsonGroupArray | AggFunc::JsonbGroupArray => {
if num_args != 1 {
crate::bail_parse_error!("min bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::JsonGroupArray,
});
target_register
}
#[cfg(feature = "json")]
AggFunc::JsonGroupObject | AggFunc::JsonbGroupObject => {
if num_args != 2 {
crate::bail_parse_error!("max bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
let value_reg = agg_arg_source.translate(program, 1)?;
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: value_reg,
func: AggFunc::JsonGroupObject,
});
target_register
}
AggFunc::StringAgg => {
if num_args != 2 {
crate::bail_parse_error!("string_agg bad number of arguments");
}
let delimiter_reg = program.alloc_register();
let delimiter_expr = match &agg_arg_source.args()[1] {
arg @ ast::Expr::Column { .. } => arg.clone(),
ast::Expr::Literal(ast::Literal::String(s)) => {
ast::Expr::Literal(ast::Literal::String(s.to_string()))
}
_ => crate::bail_parse_error!("Incorrect delimiter parameter"),
};
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
translate_expr(
program,
Some(referenced_tables),
&delimiter_expr,
delimiter_reg,
resolver,
)?;
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: delimiter_reg,
func: AggFunc::StringAgg,
});
target_register
}
AggFunc::Sum => {
if num_args != 1 {
crate::bail_parse_error!("sum bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::Sum,
});
target_register
}
AggFunc::Total => {
if num_args != 1 {
crate::bail_parse_error!("total bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::Total,
});
target_register
}
AggFunc::External(_) => {
todo!("External aggregate functions are not yet supported in GROUP BY");
}
};
Ok(dest)
}

View File

@@ -19,7 +19,7 @@ use crate::{
};
use super::{
aggregation::translate_aggregation_step,
aggregation::{translate_aggregation_step, AggArgumentSource},
emitter::{OperationMode, TranslateCtx},
expr::{
translate_condition_expr, translate_expr, translate_expr_no_constant_opt,
@@ -868,7 +868,7 @@ fn emit_loop_source(
translate_aggregation_step(
program,
&plan.table_references,
agg,
AggArgumentSource::new_from_expression(agg),
reg,
&t_ctx.resolver,
)?;

View File

@@ -1048,6 +1048,24 @@ pub struct Aggregate {
}
impl Aggregate {
pub fn new(func: AggFunc, args: &[Box<Expr>], expr: &Expr, distinctness: Distinctness) -> Self {
let agg_args = if args.is_empty() {
// The AggStep instruction requires at least one argument. For functions that accept
// zero arguments (e.g. COUNT()), we insert a dummy literal so that AggStep remains valid.
// This does not cause ambiguity: the resolver has already verified that the function
// takes zero arguments, so the dummy value will be ignored.
vec![Expr::Literal(ast::Literal::Numeric("1".to_string()))]
} else {
args.iter().map(|arg| *arg.clone()).collect()
};
Aggregate {
func,
args: agg_args,
original_expr: expr.clone(),
distinctness,
}
}
pub fn is_distinct(&self) -> bool {
self.distinctness.is_distinct()
}

View File

@@ -73,12 +73,7 @@ pub fn resolve_aggregates(
"DISTINCT aggregate functions must have exactly one argument"
);
}
aggs.push(Aggregate {
func: f,
args: args.iter().map(|arg| *arg.clone()).collect(),
original_expr: expr.clone(),
distinctness,
});
aggs.push(Aggregate::new(f, args, expr, distinctness));
contains_aggregates = true;
}
_ => {
@@ -95,12 +90,7 @@ pub fn resolve_aggregates(
);
}
if let Ok(Func::Agg(f)) = Func::resolve_function(name.as_str(), 0) {
aggs.push(Aggregate {
func: f,
args: vec![],
original_expr: expr.clone(),
distinctness: Distinctness::NonDistinct,
});
aggs.push(Aggregate::new(f, &[], expr, Distinctness::NonDistinct));
contains_aggregates = true;
}
}

View File

@@ -371,27 +371,7 @@ fn prepare_one_select_plan(
}
match Func::resolve_function(name.as_str(), args_count) {
Ok(Func::Agg(f)) => {
let agg_args = match (args.is_empty(), &f) {
(true, crate::function::AggFunc::Count0) => {
// COUNT() case
vec![ast::Expr::Literal(ast::Literal::Numeric(
"1".to_string(),
))
.into()]
}
(true, _) => crate::bail_parse_error!(
"Aggregate function {} requires arguments",
name.as_str()
),
(false, _) => args.clone(),
};
let agg = Aggregate {
func: f,
args: agg_args.iter().map(|arg| *arg.clone()).collect(),
original_expr: *expr.clone(),
distinctness,
};
let agg = Aggregate::new(f, args, expr, distinctness);
aggregate_expressions.push(agg);
plan.result_columns.push(ResultSetColumn {
alias: maybe_alias.as_ref().map(|alias| match alias {
@@ -446,15 +426,12 @@ fn prepare_one_select_plan(
contains_aggregates,
});
} else {
let agg = Aggregate {
func: AggFunc::External(f.func.clone().into()),
args: args
.iter()
.map(|arg| *arg.clone())
.collect(),
original_expr: *expr.clone(),
let agg = Aggregate::new(
AggFunc::External(f.func.clone().into()),
args,
expr,
distinctness,
};
);
aggregate_expressions.push(agg);
plan.result_columns.push(ResultSetColumn {
alias: maybe_alias.as_ref().map(|alias| {
@@ -488,14 +465,8 @@ fn prepare_one_select_plan(
}
match Func::resolve_function(name.as_str(), 0) {
Ok(Func::Agg(f)) => {
let agg = Aggregate {
func: f,
args: vec![ast::Expr::Literal(ast::Literal::Numeric(
"1".to_string(),
))],
original_expr: *expr.clone(),
distinctness: Distinctness::NonDistinct,
};
let agg =
Aggregate::new(f, &[], expr, Distinctness::NonDistinct);
aggregate_expressions.push(agg);
plan.result_columns.push(ResultSetColumn {
alias: maybe_alias.as_ref().map(|alias| match alias {

View File

@@ -143,4 +143,35 @@ do_execsql_test select-agg-json-array-object {
do_execsql_test select-distinct-agg-functions {
SELECT sum(distinct age), count(distinct age), avg(distinct age) FROM users;
} {5050|100|50.5}
} {5050|100|50.5}
do_execsql_test select-json-group-object {
select price,
json_group_object(cast (id as text), name)
from products
group by price
order by price;
} {1.0|{"9":"boots"}
18.0|{"3":"shirt"}
25.0|{"4":"sweater"}
33.0|{"10":"coat"}
70.0|{"6":"shorts"}
74.0|{"5":"sweatshirt"}
78.0|{"7":"jeans"}
79.0|{"1":"hat"}
81.0|{"11":"accessories"}
82.0|{"2":"cap","8":"sneakers"}}
do_execsql_test select-json-group-object-no-sorting-required {
select age,
json_group_object(cast (id as text), first_name)
from users
where first_name like 'Am%'
group by age
order by age
limit 5;
} {1|{"6737":"Amy"}
2|{"2297":"Amy","3580":"Amanda"}
3|{"3437":"Amanda"}
5|{"2378":"Amy","3227":"Amy","5605":"Amanda"}
7|{"2454":"Amber"}}

View File

@@ -7,22 +7,22 @@ from cli_tests.test_turso_cli import TestTursoShell
sqlite_exec = "./scripts/limbo-sqlite3"
sqlite_flags = os.getenv("SQLITE_FLAGS", "-q").split(" ")
test_data = """CREATE TABLE numbers ( id INTEGER PRIMARY KEY, value FLOAT NOT NULL);
INSERT INTO numbers (value) VALUES (1.0);
INSERT INTO numbers (value) VALUES (2.0);
INSERT INTO numbers (value) VALUES (3.0);
INSERT INTO numbers (value) VALUES (4.0);
INSERT INTO numbers (value) VALUES (5.0);
INSERT INTO numbers (value) VALUES (6.0);
INSERT INTO numbers (value) VALUES (7.0);
CREATE TABLE test (value REAL, percent REAL);
INSERT INTO test values (10, 25);
INSERT INTO test values (20, 25);
INSERT INTO test values (30, 25);
INSERT INTO test values (40, 25);
INSERT INTO test values (50, 25);
INSERT INTO test values (60, 25);
INSERT INTO test values (70, 25);
test_data = """CREATE TABLE numbers ( id INTEGER PRIMARY KEY, value FLOAT NOT NULL, category TEXT DEFAULT 'A');
INSERT INTO numbers (value, category) VALUES (1.0, 'A');
INSERT INTO numbers (value, category) VALUES (2.0, 'A');
INSERT INTO numbers (value, category) VALUES (3.0, 'A');
INSERT INTO numbers (value, category) VALUES (4.0, 'B');
INSERT INTO numbers (value, category) VALUES (5.0, 'B');
INSERT INTO numbers (value, category) VALUES (6.0, 'B');
INSERT INTO numbers (value, category) VALUES (7.0, 'B');
CREATE TABLE test (value REAL, percent REAL, category TEXT);
INSERT INTO test values (10, 25, 'A');
INSERT INTO test values (20, 25, 'A');
INSERT INTO test values (30, 25, 'B');
INSERT INTO test values (40, 25, 'C');
INSERT INTO test values (50, 25, 'C');
INSERT INTO test values (60, 25, 'C');
INSERT INTO test values (70, 25, 'D');
"""
@@ -174,6 +174,39 @@ def test_aggregates():
limbo.quit()
def test_grouped_aggregates():
    """Check the percentile extension's aggregates when combined with GROUP BY.

    Loads the percentile extension into a shell seeded with ``test_data`` and
    verifies, query by query, the newline-separated per-category results.
    """
    limbo = TestTursoShell(init_commands=test_data)
    extension_path = "./target/debug/liblimbo_percentile"
    limbo.execute_dot(f".load {extension_path}")

    # (query, expected shell output, description) - executed in this order.
    cases = (
        (
            "SELECT median(value) FROM numbers GROUP BY category;",
            "2.0\n5.5",
            "median aggregate function works",
        ),
        (
            "SELECT percentile(value, percent) FROM test GROUP BY category;",
            "12.5\n30.0\n45.0\n70.0",
            "grouped aggregate percentile function with 2 arguments works",
        ),
        (
            "SELECT percentile(value, 55) FROM test GROUP BY category;",
            "15.5\n30.0\n51.0\n70.0",
            "grouped aggregate percentile function with 1 argument works",
        ),
        (
            "SELECT percentile_cont(value, 0.25) FROM test GROUP BY category;",
            "12.5\n30.0\n45.0\n70.0",
            "grouped aggregate percentile_cont function works",
        ),
        (
            "SELECT percentile_disc(value, 0.55) FROM test GROUP BY category;",
            "10.0\n30.0\n50.0\n70.0",
            "grouped aggregate percentile_disc function works",
        ),
    )
    for query, expected, description in cases:
        # Bind `expected` as a default so each validator keeps its own value.
        limbo.run_test_fn(query, lambda res, expected=expected: expected == res, description)

    limbo.quit()
# Encoders and decoders
def validate_url_encode(a):
return a == "%2Fhello%3Ftext%3D%28%E0%B2%A0_%E0%B2%A0%29"
@@ -770,6 +803,7 @@ def main():
test_regexp()
test_uuid()
test_aggregates()
test_grouped_aggregates()
test_crypto()
test_series()
test_ipaddr()

View File

@@ -74,3 +74,31 @@ do_execsql_test_on_specific_db {:memory:} collate_aggregation_explicit_nocase {
insert into fruits(name) values ('Apple') ,('banana') ,('CHERRY');
select max(name collate nocase) from fruits;
} {CHERRY}
do_execsql_test_on_specific_db {:memory:} collate_grouped_aggregation_default_binary {
create table fruits(name collate binary, category text);
insert into fruits(name, category) values ('Apple', 'A'), ('banana', 'A'), ('CHERRY', 'B'), ('blueberry', 'B');
select max(name) from fruits group by category;
} {banana
blueberry}
do_execsql_test_on_specific_db {:memory:} collate_grouped_aggregation_default_nocase {
create table fruits(name collate nocase, category text);
insert into fruits(name, category) values ('Apple', 'A'), ('banana', 'A'), ('CHERRY', 'B'), ('blueberry', 'B');
select max(name) from fruits group by category;
} {banana
CHERRY}
do_execsql_test_on_specific_db {:memory:} collate_grouped_aggregation_explicit_binary {
create table fruits(name collate nocase, category text);
insert into fruits(name, category) values ('Apple', 'A'), ('banana', 'A'), ('CHERRY', 'B'), ('blueberry', 'B');
select max(name collate binary) from fruits group by category;
} {banana
blueberry}
# An explicit NOCASE collation on the aggregate argument overrides the column's BINARY collation.
do_execsql_test_on_specific_db {:memory:} collate_grouped_aggregation_explicit_nocase {
create table fruits(name collate binary, category text);
insert into fruits(name, category) values ('Apple', 'A'), ('banana', 'A'), ('CHERRY', 'B'), ('blueberry', 'B');
select max(name collate nocase) from fruits group by category;
} {banana
CHERRY}

View File

@@ -145,6 +145,18 @@ do_execsql_test group_by_count_star {
select u.first_name, count(*) from users u group by u.first_name limit 1;
} {Aaron|41}
do_execsql_test group_by_count_star_in_expression {
select u.first_name, count(*) % 3 from users u group by u.first_name order by u.first_name limit 3;
} {Aaron|2
Abigail|1
Adam|0}
do_execsql_test group_by_count_no_args_in_expression {
select u.first_name, count() % 3 from users u group by u.first_name order by u.first_name limit 3;
} {Aaron|2
Abigail|1
Adam|0}
do_execsql_test having {
select u.first_name, round(avg(u.age)) from users u group by u.first_name having avg(u.age) > 97 order by avg(u.age) desc limit 5;
} {Nina|100.0