Merge 'Fix handling of non-aggregate expressions' from Piotr Rżysko

This PR has two parts:
1. The first commit refactors how information about which registers
should be populated in the aggregation loop is calculated and
propagated. This simplification revealed a bug, which is addressed as
part of the same commit (see the included test).
2. The second commit fixes incorrect behavior for queries where complex
expressions include both aggregate and non-aggregate components. For
example, the following query previously produced incorrect results:
```sql
SELECT
  CASE WHEN c0 != 'x' THEN group_concat(c1, ',') ELSE 'x' END
FROM t0
GROUP BY c0;
```
In such cases, non-aggregate columns like `c0` were not available during
the result construction for each group, leading to incorrect evaluation.

Reviewed-by: Jussi Saurio <jussi.saurio@gmail.com>

Closes #1780
This commit is contained in:
Jussi Saurio
2025-06-20 21:56:35 +03:00
10 changed files with 470 additions and 445 deletions

View File

@@ -41,6 +41,7 @@ pub fn emit_ungrouped_aggregation<'a>(
.expr_to_reg_cache
.push((&agg.original_expr, agg_start_reg + i));
}
t_ctx.resolver.enable_expr_to_reg_cache();
// This always emits a ResultRow because currently it can only be used for a single row result
// Limit is None because we early exit on limit 0 and the max rows here is 1

View File

@@ -3,7 +3,7 @@
use std::rc::Rc;
use limbo_sqlite3_parser::ast::{self};
use limbo_sqlite3_parser::ast::{self, Expr};
use tracing::{instrument, Level};
use super::aggregation::emit_ungrouped_aggregation;
@@ -15,7 +15,9 @@ use super::main_loop::{
close_loop, emit_loop, init_distinct, init_loop, open_loop, LeftJoinMetadata, LoopLabels,
};
use super::order_by::{emit_order_by, init_order_by, SortMetadata};
use super::plan::{JoinOrderMember, Operation, SelectPlan, TableReferences, UpdatePlan};
use super::plan::{
Distinctness, JoinOrderMember, Operation, SelectPlan, TableReferences, UpdatePlan,
};
use super::select::emit_simple_count;
use super::subquery::emit_subqueries;
use crate::error::SQLITE_CONSTRAINT_PRIMARYKEY;
@@ -33,6 +35,7 @@ use crate::{Result, SymbolTable};
pub struct Resolver<'a> {
pub schema: &'a Schema,
pub symbol_table: &'a SymbolTable,
pub expr_to_reg_cache_enabled: bool,
pub expr_to_reg_cache: Vec<(&'a ast::Expr, usize)>,
}
@@ -41,6 +44,7 @@ impl<'a> Resolver<'a> {
Self {
schema,
symbol_table,
expr_to_reg_cache_enabled: false,
expr_to_reg_cache: Vec::new(),
}
}
@@ -55,11 +59,19 @@ impl<'a> Resolver<'a> {
}
}
/// Turns on lookups in `expr_to_reg_cache`.
///
/// The flag starts out `false`; until this is called, `resolve_cached_expr_reg`
/// returns `None` even if matching entries have already been pushed into the cache.
pub(crate) fn enable_expr_to_reg_cache(&mut self) {
self.expr_to_reg_cache_enabled = true;
}
pub fn resolve_cached_expr_reg(&self, expr: &ast::Expr) -> Option<usize> {
self.expr_to_reg_cache
.iter()
.find(|(e, _)| exprs_are_equivalent(expr, e))
.map(|(_, reg)| *reg)
if self.expr_to_reg_cache_enabled {
self.expr_to_reg_cache
.iter()
.find(|(e, _)| exprs_are_equivalent(expr, e))
.map(|(_, reg)| *reg)
} else {
None
}
}
}
@@ -125,6 +137,17 @@ pub struct TranslateCtx<'a> {
// This vector holds the indexes of the result columns that we need to skip.
pub result_columns_to_skip_in_orderby_sorter: Option<Vec<usize>>,
pub resolver: Resolver<'a>,
/// A list of expressions that are not aggregates, along with a flag indicating
/// whether the expression should be included in the output for each group.
///
/// Each entry is a tuple:
/// - `&'a Expr`: the expression itself
/// - `bool`: `true` if the expression should be included in the output for each group, `false` otherwise.
///
/// The order of expressions is **significant**:
/// - First: all `GROUP BY` expressions, in the order they appear in the `GROUP BY` clause.
/// - Then: remaining non-aggregate expressions that are not part of `GROUP BY`.
pub non_aggregate_expressions: Vec<(&'a Expr, bool)>,
}
impl<'a> TranslateCtx<'a> {
@@ -150,6 +173,7 @@ impl<'a> TranslateCtx<'a> {
result_column_indexes_in_orderby_sorter: (0..result_column_count).collect(),
result_columns_to_skip_in_orderby_sorter: None,
resolver: Resolver::new(schema, syms),
non_aggregate_expressions: Vec::new(),
}
}
}
@@ -280,14 +304,28 @@ pub fn emit_query<'a>(
}
if let Some(ref group_by) = plan.group_by {
init_group_by(program, t_ctx, group_by, &plan)?;
init_group_by(
program,
t_ctx,
group_by,
&plan,
&plan.result_columns,
&plan.order_by,
)?;
} else if !plan.aggregates.is_empty() {
// Aggregate registers need to be NULLed at the start because the same registers might be reused on another invocation of a subquery,
// and if they are not NULLed, the 2nd invocation of the same subquery will have values left over from the first invocation.
t_ctx.reg_agg_start = Some(program.alloc_registers_and_init_w_null(plan.aggregates.len()));
}
init_distinct(program, plan);
let distinct_ctx = if let Distinctness::Distinct { .. } = &plan.distinctness {
Some(init_distinct(program, plan))
} else {
None
};
if let Distinctness::Distinct { ctx } = &mut plan.distinctness {
*ctx = distinct_ctx
}
init_loop(
program,
t_ctx,

View File

@@ -2661,186 +2661,195 @@ pub fn unwrap_parens_owned(expr: ast::Expr) -> Result<(ast::Expr, usize)> {
}
}
/// Recursively walks an immutable expression, applying a function to each sub-expression.
pub fn walk_expr<'a, F>(expr: &'a ast::Expr, func: &mut F) -> Result<()>
where
F: FnMut(&'a ast::Expr) -> Result<()>,
{
func(expr)?;
match expr {
ast::Expr::Between {
lhs, start, end, ..
} => {
walk_expr(lhs, func)?;
walk_expr(start, func)?;
walk_expr(end, func)?;
}
ast::Expr::Binary(lhs, _, rhs) => {
walk_expr(lhs, func)?;
walk_expr(rhs, func)?;
}
ast::Expr::Case {
base,
when_then_pairs,
else_expr,
} => {
if let Some(base_expr) = base {
walk_expr(base_expr, func)?;
}
for (when_expr, then_expr) in when_then_pairs {
walk_expr(when_expr, func)?;
walk_expr(then_expr, func)?;
}
if let Some(else_expr) = else_expr {
walk_expr(else_expr, func)?;
}
}
ast::Expr::Cast { expr, .. } => {
walk_expr(expr, func)?;
}
ast::Expr::Collate(expr, _) => {
walk_expr(expr, func)?;
}
ast::Expr::Exists(_select) | ast::Expr::Subquery(_select) => {
// TODO: Walk through select statements if needed
}
ast::Expr::FunctionCall {
args,
order_by,
filter_over,
..
} => {
if let Some(args) = args {
for arg in args {
walk_expr(arg, func)?;
}
}
if let Some(order_by) = order_by {
for sort_col in order_by {
walk_expr(&sort_col.expr, func)?;
}
}
if let Some(filter_over) = filter_over {
if let Some(filter_clause) = &filter_over.filter_clause {
walk_expr(filter_clause, func)?;
}
if let Some(over_clause) = &filter_over.over_clause {
match over_clause.as_ref() {
ast::Over::Window(window) => {
if let Some(partition_by) = &window.partition_by {
for part_expr in partition_by {
walk_expr(part_expr, func)?;
}
}
if let Some(order_by_clause) = &window.order_by {
for sort_col in order_by_clause {
walk_expr(&sort_col.expr, func)?;
}
}
if let Some(frame_clause) = &window.frame_clause {
walk_expr_frame_bound(&frame_clause.start, func)?;
if let Some(end_bound) = &frame_clause.end {
walk_expr_frame_bound(end_bound, func)?;
}
}
}
ast::Over::Name(_) => {}
}
}
}
}
ast::Expr::FunctionCallStar { filter_over, .. } => {
if let Some(filter_over) = filter_over {
if let Some(filter_clause) = &filter_over.filter_clause {
walk_expr(filter_clause, func)?;
}
if let Some(over_clause) = &filter_over.over_clause {
match over_clause.as_ref() {
ast::Over::Window(window) => {
if let Some(partition_by) = &window.partition_by {
for part_expr in partition_by {
walk_expr(part_expr, func)?;
}
}
if let Some(order_by_clause) = &window.order_by {
for sort_col in order_by_clause {
walk_expr(&sort_col.expr, func)?;
}
}
if let Some(frame_clause) = &window.frame_clause {
walk_expr_frame_bound(&frame_clause.start, func)?;
if let Some(end_bound) = &frame_clause.end {
walk_expr_frame_bound(end_bound, func)?;
}
}
}
ast::Over::Name(_) => {}
}
}
}
}
ast::Expr::InList { lhs, rhs, .. } => {
walk_expr(lhs, func)?;
if let Some(rhs_exprs) = rhs {
for expr in rhs_exprs {
walk_expr(expr, func)?;
}
}
}
ast::Expr::InSelect { lhs, rhs: _, .. } => {
walk_expr(lhs, func)?;
// TODO: Walk through select statements if needed
}
ast::Expr::InTable { lhs, args, .. } => {
walk_expr(lhs, func)?;
if let Some(arg_exprs) = args {
for expr in arg_exprs {
walk_expr(expr, func)?;
}
}
}
ast::Expr::IsNull(expr) | ast::Expr::NotNull(expr) => {
walk_expr(expr, func)?;
}
ast::Expr::Like {
lhs, rhs, escape, ..
} => {
walk_expr(lhs, func)?;
walk_expr(rhs, func)?;
if let Some(esc_expr) = escape {
walk_expr(esc_expr, func)?;
}
}
ast::Expr::Parenthesized(exprs) => {
for expr in exprs {
walk_expr(expr, func)?;
}
}
ast::Expr::Raise(_, expr) => {
if let Some(raise_expr) = expr {
walk_expr(raise_expr, func)?;
}
}
ast::Expr::Unary(_, expr) => {
walk_expr(expr, func)?;
}
ast::Expr::Id(_)
| ast::Expr::Column { .. }
| ast::Expr::RowId { .. }
| ast::Expr::Literal(_)
| ast::Expr::DoublyQualified(..)
| ast::Expr::Name(_)
| ast::Expr::Qualified(..)
| ast::Expr::Variable(_) => {
// No nested expressions
}
}
Ok(())
/// Value returned by the visitor callback passed to [`walk_expr`], telling the walker
/// whether to descend into the expression that was just visited.
pub enum WalkControl {
Continue, // Descend into the children of the expression that was just visited
SkipChildren, // Do not descend into its children; the walk still proceeds to sibling expressions
}
fn walk_expr_frame_bound<'a, F>(bound: &'a ast::FrameBound, func: &mut F) -> Result<()>
/// Recursively walks an immutable expression, applying a function to each sub-expression.
pub fn walk_expr<'a, F>(expr: &'a ast::Expr, func: &mut F) -> Result<WalkControl>
where
F: FnMut(&'a ast::Expr) -> Result<()>,
F: FnMut(&'a ast::Expr) -> Result<WalkControl>,
{
match func(expr)? {
WalkControl::Continue => {
match expr {
ast::Expr::Between {
lhs, start, end, ..
} => {
walk_expr(lhs, func)?;
walk_expr(start, func)?;
walk_expr(end, func)?;
}
ast::Expr::Binary(lhs, _, rhs) => {
walk_expr(lhs, func)?;
walk_expr(rhs, func)?;
}
ast::Expr::Case {
base,
when_then_pairs,
else_expr,
} => {
if let Some(base_expr) = base {
walk_expr(base_expr, func)?;
}
for (when_expr, then_expr) in when_then_pairs {
walk_expr(when_expr, func)?;
walk_expr(then_expr, func)?;
}
if let Some(else_expr) = else_expr {
walk_expr(else_expr, func)?;
}
}
ast::Expr::Cast { expr, .. } => {
walk_expr(expr, func)?;
}
ast::Expr::Collate(expr, _) => {
walk_expr(expr, func)?;
}
ast::Expr::Exists(_select) | ast::Expr::Subquery(_select) => {
// TODO: Walk through select statements if needed
}
ast::Expr::FunctionCall {
args,
order_by,
filter_over,
..
} => {
if let Some(args) = args {
for arg in args {
walk_expr(arg, func)?;
}
}
if let Some(order_by) = order_by {
for sort_col in order_by {
walk_expr(&sort_col.expr, func)?;
}
}
if let Some(filter_over) = filter_over {
if let Some(filter_clause) = &filter_over.filter_clause {
walk_expr(filter_clause, func)?;
}
if let Some(over_clause) = &filter_over.over_clause {
match over_clause.as_ref() {
ast::Over::Window(window) => {
if let Some(partition_by) = &window.partition_by {
for part_expr in partition_by {
walk_expr(part_expr, func)?;
}
}
if let Some(order_by_clause) = &window.order_by {
for sort_col in order_by_clause {
walk_expr(&sort_col.expr, func)?;
}
}
if let Some(frame_clause) = &window.frame_clause {
walk_expr_frame_bound(&frame_clause.start, func)?;
if let Some(end_bound) = &frame_clause.end {
walk_expr_frame_bound(end_bound, func)?;
}
}
}
ast::Over::Name(_) => {}
}
}
}
}
ast::Expr::FunctionCallStar { filter_over, .. } => {
if let Some(filter_over) = filter_over {
if let Some(filter_clause) = &filter_over.filter_clause {
walk_expr(filter_clause, func)?;
}
if let Some(over_clause) = &filter_over.over_clause {
match over_clause.as_ref() {
ast::Over::Window(window) => {
if let Some(partition_by) = &window.partition_by {
for part_expr in partition_by {
walk_expr(part_expr, func)?;
}
}
if let Some(order_by_clause) = &window.order_by {
for sort_col in order_by_clause {
walk_expr(&sort_col.expr, func)?;
}
}
if let Some(frame_clause) = &window.frame_clause {
walk_expr_frame_bound(&frame_clause.start, func)?;
if let Some(end_bound) = &frame_clause.end {
walk_expr_frame_bound(end_bound, func)?;
}
}
}
ast::Over::Name(_) => {}
}
}
}
}
ast::Expr::InList { lhs, rhs, .. } => {
walk_expr(lhs, func)?;
if let Some(rhs_exprs) = rhs {
for expr in rhs_exprs {
walk_expr(expr, func)?;
}
}
}
ast::Expr::InSelect { lhs, rhs: _, .. } => {
walk_expr(lhs, func)?;
// TODO: Walk through select statements if needed
}
ast::Expr::InTable { lhs, args, .. } => {
walk_expr(lhs, func)?;
if let Some(arg_exprs) = args {
for expr in arg_exprs {
walk_expr(expr, func)?;
}
}
}
ast::Expr::IsNull(expr) | ast::Expr::NotNull(expr) => {
walk_expr(expr, func)?;
}
ast::Expr::Like {
lhs, rhs, escape, ..
} => {
walk_expr(lhs, func)?;
walk_expr(rhs, func)?;
if let Some(esc_expr) = escape {
walk_expr(esc_expr, func)?;
}
}
ast::Expr::Parenthesized(exprs) => {
for expr in exprs {
walk_expr(expr, func)?;
}
}
ast::Expr::Raise(_, expr) => {
if let Some(raise_expr) = expr {
walk_expr(raise_expr, func)?;
}
}
ast::Expr::Unary(_, expr) => {
walk_expr(expr, func)?;
}
ast::Expr::Id(_)
| ast::Expr::Column { .. }
| ast::Expr::RowId { .. }
| ast::Expr::Literal(_)
| ast::Expr::DoublyQualified(..)
| ast::Expr::Name(_)
| ast::Expr::Qualified(..)
| ast::Expr::Variable(_) => {
// No nested expressions
}
}
}
WalkControl::SkipChildren => return Ok(WalkControl::Continue),
};
Ok(WalkControl::Continue)
}
fn walk_expr_frame_bound<'a, F>(bound: &'a ast::FrameBound, func: &mut F) -> Result<WalkControl>
where
F: FnMut(&'a ast::Expr) -> Result<WalkControl>,
{
match bound {
ast::FrameBound::Following(expr) | ast::FrameBound::Preceding(expr) => {
@@ -2851,7 +2860,7 @@ where
| ast::FrameBound::UnboundedPreceding => {}
}
Ok(())
Ok(WalkControl::Continue)
}
/// Recursively walks a mutable expression, applying a function to each sub-expression.

View File

@@ -2,6 +2,8 @@ use std::rc::Rc;
use limbo_sqlite3_parser::ast;
use crate::translate::expr::{walk_expr, WalkControl};
use crate::translate::plan::ResultSetColumn;
use crate::{
function::AggFunc,
schema::{Column, PseudoTable},
@@ -76,23 +78,24 @@ pub struct GroupByMetadata {
pub row_source: GroupByRowSource,
pub labels: GroupByLabels,
pub registers: GroupByRegisters,
// Columns that not part of GROUP BY clause and not arguments of Aggregation function.
// Heavy calculation and needed in different functions, so it is reasonable to do it once and save.
pub non_group_by_non_agg_column_count: usize,
}
/// Initialize resources needed for GROUP BY processing
pub fn init_group_by(
pub fn init_group_by<'a>(
program: &mut ProgramBuilder,
t_ctx: &mut TranslateCtx,
group_by: &GroupBy,
t_ctx: &mut TranslateCtx<'a>,
group_by: &'a GroupBy,
plan: &SelectPlan,
result_columns: &'a Vec<ResultSetColumn>,
order_by: &'a Option<Vec<(ast::Expr, ast::SortOrder)>>,
) -> Result<()> {
let non_aggregate_count = plan
.result_columns
.iter()
.filter(|rc| !rc.contains_aggregates)
.count();
collect_non_aggregate_expressions(
&mut t_ctx.non_aggregate_expressions,
group_by,
plan,
result_columns,
order_by,
)?;
let label_subrtn_acc_output = program.allocate_label();
let label_group_by_end_without_emitting_row = program.allocate_label();
@@ -112,7 +115,8 @@ pub fn init_group_by(
// The following two blocks of registers should always be allocated contiguously,
// because they are cleared in a contiguous block in the GROUP BYs clear accumulator subroutine.
// START BLOCK
let reg_non_aggregate_exprs_acc = program.alloc_registers(non_aggregate_count);
let reg_non_aggregate_exprs_acc =
program.alloc_registers(t_ctx.non_aggregate_expressions.len());
if !plan.aggregates.is_empty() {
// Aggregate registers need to be NULLed at the start because the same registers might be reused on another invocation of a subquery,
// and if they are not NULLed, the 2nd invocation of the same subquery will have values left over from the first invocation.
@@ -121,12 +125,11 @@ pub fn init_group_by(
// END BLOCK
let reg_sorter_key = program.alloc_register();
let column_count = plan.group_by_sorter_column_count();
let column_count = plan.agg_args_count() + t_ctx.non_aggregate_expressions.len();
let reg_group_by_source_cols_start = program.alloc_registers(column_count);
let row_source = if let Some(sort_order) = group_by.sort_order.as_ref() {
let sort_cursor = program.alloc_cursor_id(CursorType::Sorter);
let sorter_column_count = plan.group_by_sorter_column_count();
// Should work the same way as Order By
/*
* Terms of the ORDER BY clause that is part of a SELECT statement may be assigned a collating sequence using the COLLATE operator,
@@ -160,21 +163,17 @@ pub fn init_group_by(
program.emit_insn(Insn::SorterOpen {
cursor_id: sort_cursor,
columns: sorter_column_count,
columns: column_count,
order: sort_order.clone(),
collations,
});
let pseudo_cursor = group_by_create_pseudo_table(program, sorter_column_count);
let pseudo_cursor = group_by_create_pseudo_table(program, column_count);
GroupByRowSource::Sorter {
pseudo_cursor,
sort_cursor,
reg_sorter_key,
column_register_mapping: group_by_create_column_register_mapping(
group_by,
reg_non_aggregate_exprs_acc,
plan,
),
sorter_column_count,
sorter_column_count: column_count,
start_reg_dest: reg_non_aggregate_exprs_acc,
}
} else {
GroupByRowSource::MainLoop {
@@ -232,11 +231,72 @@ pub fn init_group_by(
reg_subrtn_acc_clear_return_offset,
reg_group_by_source_cols_start,
},
non_group_by_non_agg_column_count: plan.non_group_by_non_agg_column_count(),
});
Ok(())
}
/// Fills `non_aggregate_expressions` with every non-aggregate expression the GROUP BY
/// machinery has to carry, paired with a flag saying whether it is referenced by the
/// output.
///
/// Column references are gathered (via `collect_result_columns`) from the SELECT list,
/// the ORDER BY terms and the HAVING terms, in that order. The output vector is then
/// appended to in two passes:
/// 1. every GROUP BY expression, in clause order, flagged `true` only if an equivalent
///    expression was found among the gathered references;
/// 2. every gathered reference that has no equivalent GROUP BY expression, always
///    flagged `true`.
///
/// Returns any error produced while walking the expressions.
fn collect_non_aggregate_expressions<'a>(
    non_aggregate_expressions: &mut Vec<(&'a ast::Expr, bool)>,
    group_by: &'a GroupBy,
    plan: &SelectPlan,
    root_result_columns: &'a Vec<ResultSetColumn>,
    order_by: &'a Option<Vec<(ast::Expr, ast::SortOrder)>>,
) -> Result<()> {
    // The three places a non-aggregate column reference can come from.
    let select_exprs = root_result_columns.iter().map(|rc| &rc.expr);
    let order_by_exprs = order_by.iter().flatten().map(|(sort_expr, _)| sort_expr);
    let having_exprs = group_by.having.iter().flatten();

    let mut referenced: Vec<&'a ast::Expr> = Vec::new();
    for root in select_exprs.chain(order_by_exprs).chain(having_exprs) {
        collect_result_columns(root, plan, &mut referenced)?;
    }

    // Pass 1: GROUP BY keys come first, in clause order.
    non_aggregate_expressions.extend(group_by.exprs.iter().map(|key| {
        let used_in_output = referenced
            .iter()
            .any(|candidate| exprs_are_equivalent(candidate, key));
        (key, used_in_output)
    }));

    // Pass 2: whatever is referenced but is not itself a GROUP BY key.
    non_aggregate_expressions.extend(
        referenced
            .into_iter()
            .filter(|candidate| {
                !group_by
                    .exprs
                    .iter()
                    .any(|key| exprs_are_equivalent(candidate, key))
            })
            .map(|candidate| (candidate, true)),
    );

    Ok(())
}
/// Walks `root_expr` and appends to `result_columns` every `Column` / `RowId` node whose
/// table id is found among the plan's joined tables.
///
/// Any other node that equals the `original_expr` of one of the plan's aggregates is not
/// descended into, so column references inside an aggregate call are not collected.
fn collect_result_columns<'a>(
    root_expr: &'a ast::Expr,
    plan: &SelectPlan,
    result_columns: &mut Vec<&'a ast::Expr>,
) -> Result<()> {
    walk_expr(root_expr, &mut |expr: &ast::Expr| -> Result<WalkControl> {
        // Column-like leaves: record them when they belong to a joined table.
        if let ast::Expr::Column { table, .. } | ast::Expr::RowId { table, .. } = expr {
            if plan
                .table_references
                .find_joined_table_by_internal_id(*table)
                .is_some()
            {
                result_columns.push(expr);
            }
            return Ok(WalkControl::Continue);
        }

        // Everything else: stop at aggregate boundaries, otherwise keep descending.
        let is_aggregate = plan
            .aggregates
            .iter()
            .any(|agg| agg.original_expr == *expr);
        Ok(if is_aggregate {
            WalkControl::SkipChildren
        } else {
            WalkControl::Continue
        })
    })
    .map(|_| ())
}
/// In case sorting is needed for GROUP BY, creates a pseudo table that matches
/// the number of columns in the GROUP BY sorter. Rows are individually read
/// from the sorter into this pseudo table and processed.
@@ -338,9 +398,7 @@ pub enum GroupByRowSource {
reg_sorter_key: usize,
/// Number of columns in the GROUP BY sorter
sorter_column_count: usize,
/// In case some result columns of the SELECT query are equivalent to GROUP BY members,
/// this mapping encodes their position.
column_register_mapping: Vec<Option<usize>>,
start_reg_dest: usize,
},
MainLoop {
/// If GROUP BY rows are read directly in the main loop, start_reg is the first register
@@ -454,17 +512,16 @@ impl<'a> GroupByAggArgumentSource<'a> {
}
/// Emits bytecode for processing a single GROUP BY group.
pub fn group_by_process_single_group(
pub fn group_by_process_single_group<'a>(
program: &mut ProgramBuilder,
group_by: &GroupBy,
plan: &SelectPlan,
t_ctx: &TranslateCtx,
group_by: &'a GroupBy,
plan: &'a SelectPlan,
t_ctx: &mut TranslateCtx<'a>,
) -> Result<()> {
let GroupByMetadata {
registers,
labels,
row_source,
non_group_by_non_agg_column_count,
..
} = t_ctx
.meta_group_by
@@ -549,7 +606,7 @@ pub fn group_by_process_single_group(
// Process each aggregate function for the current row
program.resolve_label(labels.label_grouping_agg_step, program.offset());
let cursor_index = *non_group_by_non_agg_column_count + group_by.exprs.len(); // Skipping all columns in sorter that not an aggregation arguments
let cursor_index = t_ctx.non_aggregate_expressions.len(); // Skipping all columns in sorter that not an aggregation arguments
let mut offset = 0;
for (i, agg) in plan.aggregates.iter().enumerate() {
let start_reg = t_ctx
@@ -567,8 +624,7 @@ pub fn group_by_process_single_group(
}
GroupByRowSource::MainLoop { start_reg_src, .. } => {
// Aggregation arguments are always placed in the registers that follow any scalars.
let start_reg_aggs =
start_reg_src + group_by.exprs.len() + plan.non_group_by_non_agg_column_count();
let start_reg_aggs = start_reg_src + t_ctx.non_aggregate_expressions.len();
GroupByAggArgumentSource::new_from_registers(start_reg_aggs + offset, agg)
}
};
@@ -604,27 +660,32 @@ pub fn group_by_process_single_group(
match row_source {
GroupByRowSource::Sorter {
pseudo_cursor,
column_register_mapping,
start_reg_dest,
..
} => {
for (sorter_column_index, dest_reg) in column_register_mapping.iter().enumerate() {
if let Some(dest_reg) = dest_reg {
program.emit_column(*pseudo_cursor, sorter_column_index, *dest_reg);
let mut sorter_column_index = 0;
let mut next_reg = *start_reg_dest;
for (expr, in_result) in t_ctx.non_aggregate_expressions.iter() {
if *in_result {
program.emit_column(*pseudo_cursor, sorter_column_index, next_reg);
t_ctx.resolver.expr_to_reg_cache.push((expr, next_reg));
next_reg += 1;
}
sorter_column_index += 1;
}
}
GroupByRowSource::MainLoop { start_reg_dest, .. } => {
// Re-translate all the non-aggregate expressions into destination registers. We cannot use the same registers as emitted
// in the earlier part of the main loop, because they would be overwritten by the next group before the group results
// are processed.
for (i, rc) in plan
.result_columns
for (i, expr) in t_ctx
.non_aggregate_expressions
.iter()
.filter(|rc| !rc.contains_aggregates)
.filter_map(|(expr, in_result)| if *in_result { Some(expr) } else { None })
.enumerate()
{
let dest_reg = start_reg_dest + i;
let expr = &rc.expr;
translate_expr(
program,
Some(&plan.table_references),
@@ -632,6 +693,7 @@ pub fn group_by_process_single_group(
dest_reg,
&t_ctx.resolver,
)?;
t_ctx.resolver.expr_to_reg_cache.push((expr, dest_reg));
}
}
}
@@ -647,44 +709,6 @@ pub fn group_by_process_single_group(
Ok(())
}
pub fn group_by_create_column_register_mapping(
group_by: &GroupBy,
reg_non_aggregate_exprs_acc: usize,
plan: &SelectPlan,
) -> Vec<Option<usize>> {
// We have to know which group by expr present in resulting set
let group_by_expr_in_res_cols = group_by.exprs.iter().map(|expr| {
plan.result_columns
.iter()
.any(|e| exprs_are_equivalent(&e.expr, expr))
});
let group_by_count = group_by.exprs.len();
let non_group_by_non_agg_column_count = plan.non_group_by_non_agg_column_count();
// Create a map from sorter column index to result register
// This helps track where each column from the sorter should be stored
let mut column_register_mapping =
vec![None; group_by_count + non_group_by_non_agg_column_count];
let mut next_reg = reg_non_aggregate_exprs_acc;
// Map GROUP BY columns that are in the result set to registers
for (i, is_in_result) in group_by_expr_in_res_cols.clone().enumerate() {
if is_in_result {
column_register_mapping[i] = Some(next_reg);
next_reg += 1;
}
}
// Handle other non-aggregate columns that aren't part of GROUP BY and not part of Aggregation function
for i in group_by_count..group_by_count + non_group_by_non_agg_column_count {
column_register_mapping[i] = Some(next_reg);
next_reg += 1;
}
column_register_mapping
}
/// Emits the bytecode for processing the aggregation phase of a GROUP BY clause.
/// This is called either when:
/// 1. the main query execution loop has finished processing,
@@ -731,10 +755,7 @@ pub fn group_by_emit_row_phase<'a>(
) -> Result<()> {
let group_by = plan.group_by.as_ref().expect("group by not found");
let GroupByMetadata {
row_source,
labels,
registers,
..
labels, registers, ..
} = t_ctx
.meta_group_by
.as_ref()
@@ -795,82 +816,14 @@ pub fn group_by_emit_row_phase<'a>(
register: agg_result_reg,
func: agg.func.clone(),
});
}
// We have to know which group by expr present in resulting set
let group_by_expr_in_res_cols = group_by.exprs.iter().map(|expr| {
plan.result_columns
.iter()
.any(|e| exprs_are_equivalent(&e.expr, expr))
});
// Map GROUP BY expressions to their registers in the result set
for (i, (expr, is_in_result)) in group_by
.exprs
.iter()
.zip(group_by_expr_in_res_cols)
.enumerate()
{
if is_in_result {
match row_source {
GroupByRowSource::Sorter {
column_register_mapping,
..
} => {
if let Some(reg) = column_register_mapping.get(i).and_then(|opt| *opt) {
t_ctx.resolver.expr_to_reg_cache.push((expr, reg));
}
}
GroupByRowSource::MainLoop { start_reg_dest, .. } => {
t_ctx
.resolver
.expr_to_reg_cache
.push((expr, *start_reg_dest + i));
}
}
}
}
// Map non-aggregate, non-GROUP BY columns to their registers
let non_agg_cols = plan
.result_columns
.iter()
.filter(|rc| !rc.contains_aggregates && !is_column_in_group_by(&rc.expr, &group_by.exprs));
for (idx, rc) in non_agg_cols.enumerate() {
let column_relative_idx = plan.group_by_col_count() + idx;
match &row_source {
GroupByRowSource::Sorter {
column_register_mapping,
..
} => {
if let Some(reg) = column_register_mapping
.get(column_relative_idx)
.and_then(|opt| *opt)
{
t_ctx.resolver.expr_to_reg_cache.push((&rc.expr, reg));
}
}
GroupByRowSource::MainLoop { start_reg_dest, .. } => {
t_ctx
.resolver
.expr_to_reg_cache
.push((&rc.expr, start_reg_dest + column_relative_idx));
}
}
}
// Map aggregate expressions to their result registers
for (i, agg) in plan.aggregates.iter().enumerate() {
let agg_start_reg = t_ctx
.reg_agg_start
.expect("aggregate registers must be initialized");
t_ctx
.resolver
.expr_to_reg_cache
.push((&agg.original_expr, agg_start_reg + i));
.push((&agg.original_expr, agg_result_reg));
}
t_ctx.resolver.enable_expr_to_reg_cache();
if let Some(having) = &group_by.having {
for expr in having.iter() {
let if_true_target = program.allocate_label();
@@ -930,7 +883,9 @@ pub fn group_by_emit_row_phase<'a>(
// Reset all accumulator registers to NULL
program.emit_insn(Insn::Null {
dest: start_reg,
dest_end: Some(start_reg + plan.group_by_sorter_column_count() - 1),
dest_end: Some(
start_reg + t_ctx.non_aggregate_expressions.len() + plan.agg_args_count() - 1,
),
});
// Reopen ephemeral indexes for distinct aggregates (effectively clearing them).
@@ -1181,9 +1136,3 @@ pub fn translate_aggregation_step_groupby(
};
Ok(dest)
}
pub fn is_column_in_group_by(expr: &ast::Expr, group_by_exprs: &[ast::Expr]) -> bool {
group_by_exprs
.iter()
.any(|expr2| exprs_are_equivalent(expr, expr2))
}

View File

@@ -66,45 +66,41 @@ impl LoopLabels {
}
}
pub fn init_distinct(program: &mut ProgramBuilder, plan: &mut SelectPlan) {
if let Distinctness::Distinct { ctx } = &mut plan.distinctness {
assert!(
ctx.is_none(),
"distinctness context should not be allocated yet"
);
let index_name = format!("distinct_{}", program.offset().to_offset_int()); // we don't really care about the name that much, just enough that we don't get name collisions
let index = Arc::new(Index {
name: index_name.clone(),
table_name: String::new(),
ephemeral: true,
root_page: 0,
columns: plan
.result_columns
.iter()
.enumerate()
.map(|(i, col)| IndexColumn {
name: col.expr.to_string(),
order: SortOrder::Asc,
pos_in_table: i,
collation: None, // FIXME: this should be determined based on the result column expression!
default: None, // FIXME: this should be determined based on the result column expression!
})
.collect(),
unique: false,
has_rowid: false,
});
let cursor_id = program.alloc_cursor_id(CursorType::BTreeIndex(index.clone()));
*ctx = Some(DistinctCtx {
cursor_id,
ephemeral_index_name: index_name,
label_on_conflict: program.allocate_label(),
});
pub fn init_distinct(program: &mut ProgramBuilder, plan: &SelectPlan) -> DistinctCtx {
let index_name = format!("distinct_{}", program.offset().to_offset_int()); // we don't really care about the name that much, just enough that we don't get name collisions
let index = Arc::new(Index {
name: index_name.clone(),
table_name: String::new(),
ephemeral: true,
root_page: 0,
columns: plan
.result_columns
.iter()
.enumerate()
.map(|(i, col)| IndexColumn {
name: col.expr.to_string(),
order: SortOrder::Asc,
pos_in_table: i,
collation: None, // FIXME: this should be determined based on the result column expression!
default: None, // FIXME: this should be determined based on the result column expression!
})
.collect(),
unique: false,
has_rowid: false,
});
let cursor_id = program.alloc_cursor_id(CursorType::BTreeIndex(index.clone()));
let ctx = DistinctCtx {
cursor_id,
ephemeral_index_name: index_name,
label_on_conflict: program.allocate_label(),
};
program.emit_insn(Insn::OpenEphemeral {
cursor_id,
is_table: false,
});
}
program.emit_insn(Insn::OpenEphemeral {
cursor_id,
is_table: false,
});
return ctx;
}
/// Initialize resources needed for the source operators (tables, joins, etc)
@@ -765,7 +761,6 @@ fn emit_loop_source<'a>(
// 3) aggregate function arguments
// - or if the rows produced by the loop are already sorted in the order required by the GROUP BY keys,
// the group by comparisons are done directly inside the main loop.
let group_by = plan.group_by.as_ref().unwrap();
let aggregates = &plan.aggregates;
let GroupByMetadata {
@@ -777,9 +772,15 @@ fn emit_loop_source<'a>(
let start_reg = registers.reg_group_by_source_cols_start;
let mut cur_reg = start_reg;
// Step 1: Process GROUP BY columns first
// These will be the first columns in the sorter and serve as sort keys
for expr in group_by.exprs.iter() {
// Collect all non-aggregate expressions in the following order:
// 1. GROUP BY expressions. These serve as sort keys.
// 2. Remaining non-aggregate expressions that are not in GROUP BY.
//
// Example:
// SELECT col1, col2, SUM(col3) FROM table GROUP BY col1
// - col1 is added first (from GROUP BY)
// - col2 is added second (non-aggregate, in SELECT, not in GROUP BY)
for (expr, _) in t_ctx.non_aggregate_expressions.iter() {
let key_reg = cur_reg;
cur_reg += 1;
translate_expr(
@@ -791,22 +792,7 @@ fn emit_loop_source<'a>(
)?;
}
// Step 2: Process columns that aren't part of GROUP BY and don't contain aggregates
// Example: SELECT col1, col2, SUM(col3) FROM table GROUP BY col1
// Here col2 would be processed in this loop if it's in the result set
for expr in plan.non_group_by_non_agg_columns() {
let key_reg = cur_reg;
cur_reg += 1;
translate_expr(
program,
Some(&plan.table_references),
expr,
key_reg,
&t_ctx.resolver,
)?;
}
// Step 3: Process arguments for all aggregate functions
// Step 2: Process arguments for all aggregate functions
// For each aggregate, translate all its argument expressions
for agg in aggregates.iter() {
// For a query like: SELECT group_col, SUM(val1), AVG(val2) FROM table GROUP BY group_col

View File

@@ -5,7 +5,6 @@ use std::{cell::Cell, cmp::Ordering, rc::Rc, sync::Arc};
use crate::{
function::AggFunc,
schema::{BTreeTable, Column, FromClauseSubquery, Index, Table},
util::exprs_are_equivalent,
vdbe::{
builder::{CursorKey, CursorType, ProgramBuilder},
insn::{IdxInsertFlags, Insn},
@@ -454,35 +453,6 @@ impl SelectPlan {
self.aggregates.iter().map(|agg| agg.args.len()).sum()
}
pub fn group_by_col_count(&self) -> usize {
self.group_by
.as_ref()
.map_or(0, |group_by| group_by.exprs.len())
}
pub fn non_group_by_non_agg_columns(&self) -> impl Iterator<Item = &ast::Expr> {
self.result_columns
.iter()
.filter(|c| {
!c.contains_aggregates
&& !self.group_by.as_ref().map_or(false, |group_by| {
group_by
.exprs
.iter()
.any(|expr| exprs_are_equivalent(&c.expr, expr))
})
})
.map(|c| &c.expr)
}
pub fn non_group_by_non_agg_column_count(&self) -> usize {
self.non_group_by_non_agg_columns().count()
}
pub fn group_by_sorter_column_count(&self) -> usize {
self.agg_args_count() + self.group_by_col_count() + self.non_group_by_non_agg_column_count()
}
/// Reference: https://github.com/sqlite/sqlite/blob/5db695197b74580c777b37ab1b787531f15f7f9f/src/select.c#L8613
///
/// Checks to see if the query is of the format `SELECT count(*) FROM <tbl>`

View File

@@ -10,6 +10,7 @@ use super::{
select::prepare_select_plan,
SymbolTable,
};
use crate::translate::expr::WalkControl;
use crate::{
function::Func,
schema::{Schema, Table},
@@ -26,13 +27,13 @@ pub const ROWID: &str = "rowid";
pub fn resolve_aggregates(top_level_expr: &Expr, aggs: &mut Vec<Aggregate>) -> Result<bool> {
let mut contains_aggregates = false;
walk_expr(top_level_expr, &mut |expr: &Expr| -> Result<()> {
walk_expr(top_level_expr, &mut |expr: &Expr| -> Result<WalkControl> {
if aggs
.iter()
.any(|a| exprs_are_equivalent(&a.original_expr, expr))
{
contains_aggregates = true;
return Ok(());
return Ok(WalkControl::Continue);
}
match expr {
Expr::FunctionCall {
@@ -97,7 +98,7 @@ pub fn resolve_aggregates(top_level_expr: &Expr, aggs: &mut Vec<Aggregate>) -> R
_ => {}
}
Ok(())
Ok(WalkControl::Continue)
})?;
Ok(contains_aggregates)
@@ -639,7 +640,7 @@ pub fn table_mask_from_expr(
table_references: &TableReferences,
) -> Result<TableMask> {
let mut mask = TableMask::new();
walk_expr(top_level_expr, &mut |expr: &Expr| -> Result<()> {
walk_expr(top_level_expr, &mut |expr: &Expr| -> Result<WalkControl> {
match expr {
Expr::Column { table, .. } | Expr::RowId { table, .. } => {
if let Some(table_idx) = table_references
@@ -660,7 +661,7 @@ pub fn table_mask_from_expr(
}
_ => {}
}
Ok(())
Ok(WalkControl::Continue)
})?;
Ok(mask)
@@ -671,7 +672,7 @@ pub fn determine_where_to_eval_expr<'a>(
join_order: &[JoinOrderMember],
) -> Result<EvalAt> {
let mut eval_at: EvalAt = EvalAt::BeforeLoop;
walk_expr(top_level_expr, &mut |expr: &Expr| -> Result<()> {
walk_expr(top_level_expr, &mut |expr: &Expr| -> Result<WalkControl> {
match expr {
Expr::Column { table, .. } | Expr::RowId { table, .. } => {
let join_idx = join_order
@@ -682,7 +683,7 @@ pub fn determine_where_to_eval_expr<'a>(
}
_ => {}
}
Ok(())
Ok(WalkControl::Continue)
})?;
Ok(eval_at)

View File

@@ -81,6 +81,7 @@ pub fn emit_subquery<'a>(
reg_offset: None,
reg_limit_offset_sum: None,
resolver: Resolver::new(t_ctx.resolver.schema, t_ctx.resolver.symbol_table),
non_aggregate_expressions: Vec::new(),
};
let subquery_body_end_label = program.allocate_label();
program.emit_insn(Insn::InitCoroutine {

View File

@@ -1,3 +1,4 @@
use crate::translate::expr::WalkControl;
use crate::{
schema::{self, Column, Schema, Type},
translate::{collate::CollationSeq, expr::walk_expr, plan::JoinOrderMember},
@@ -589,7 +590,7 @@ pub fn can_pushdown_predicate(
join_order: &[JoinOrderMember],
) -> Result<bool> {
let mut can_pushdown = true;
walk_expr(top_level_expr, &mut |expr: &Expr| -> Result<()> {
walk_expr(top_level_expr, &mut |expr: &Expr| -> Result<WalkControl> {
match expr {
Expr::Column { table, .. } | Expr::RowId { table, .. } => {
let join_idx = join_order
@@ -608,7 +609,7 @@ pub fn can_pushdown_predicate(
}
_ => {}
};
Ok(())
Ok(WalkControl::Continue)
})?;
Ok(can_pushdown)

69
testing/groupby.test Normal file → Executable file
View File

@@ -199,6 +199,23 @@ do_execsql_test group_by_no_sorting_required {
2|113
3|97}
# Runs only when SQLITE_EXEC is set to the index-experimental limbo script or to sqlite3.
# The GROUP BY list (a, b) has the same column order as index a_b_idx, while the
# select list requests the columns in a different order (c, b, a).
if {[info exists ::env(SQLITE_EXEC)] && ($::env(SQLITE_EXEC) eq "scripts/limbo-sqlite3-index-experimental" || $::env(SQLITE_EXEC) eq "sqlite3")} {
do_execsql_test_on_specific_db {:memory:} group_by_no_sorting_required_reordered_columns {
create table t0 (a INT, b INT, c INT);
create index a_b_idx on t0 (a, b);
insert into t0 values
(1,1,1),
(1,1,2),
(2,1,3),
(2,2,3),
(2,2,5);
select c, b, a from t0 group by a, b;
} {1|1|1
3|1|2
3|2|2}
}
if {[info exists ::env(SQLITE_EXEC)] && ($::env(SQLITE_EXEC) eq "scripts/limbo-sqlite3-index-experimental" || $::env(SQLITE_EXEC) eq "sqlite3")} {
do_execsql_test distinct_agg_functions {
select first_name, sum(distinct age), count(distinct age), avg(distinct age)
@@ -224,3 +241,55 @@ do_execsql_test_on_specific_db {:memory:} having_or {
order by cnt desc
} {Michael|2|37.5
Sarah|1|65.0}
# The result column is a CASE expression that mixes a non-aggregate column (price)
# with an aggregate (group_concat); rows are grouped by and ordered by price.
do_execsql_test complex_result_expression_containing_aggregate {
select
case when price > 70 then group_concat(name, ',') else '<undisclosed>' end names
from products
group by price
order by price;
} {<undisclosed>
<undisclosed>
<undisclosed>
<undisclosed>
<undisclosed>
sweatshirt
jeans
hat
accessories
cap,sneakers}
# The result column is a CASE expression that mixes rowid with an aggregate
# (group_concat); rows are grouped by and ordered by rowid.
do_execsql_test complex_result_expression_containing_aggregate_and_rowid {
select
case when rowid >= 5 then group_concat(name, ',') else '<undisclosed>' end names
from products
group by rowid
order by rowid;
} {<undisclosed>
<undisclosed>
<undisclosed>
<undisclosed>
sweatshirt
shorts
jeans
sneakers
boots
coat
accessories}
# The HAVING predicate concatenates an aggregate (group_concat) with the grouping
# column price before applying LIKE.
do_execsql_test complex_having_expression_containing_aggregate {
select group_concat(name, ',') from products group by price having (group_concat(name, ',') || price) like 'ca%';
} {cap,sneakers}
# The ORDER BY key concatenates an aggregate (group_concat) with the grouping
# column price.
do_execsql_test complex_order_by_expression_containing_aggregate {
select group_concat(name, ',') from products group by price order by (group_concat(name, ',') || price);
} {accessories
boots
cap,sneakers
coat
hat
jeans
shirt
shorts
sweater
sweatshirt}