diff --git a/core/translate/aggregation.rs b/core/translate/aggregation.rs index ec8ac0274..3fb1ff83f 100644 --- a/core/translate/aggregation.rs +++ b/core/translate/aggregation.rs @@ -41,6 +41,7 @@ pub fn emit_ungrouped_aggregation<'a>( .expr_to_reg_cache .push((&agg.original_expr, agg_start_reg + i)); } + t_ctx.resolver.enable_expr_to_reg_cache(); // This always emits a ResultRow because currently it can only be used for a single row result // Limit is None because we early exit on limit 0 and the max rows here is 1 diff --git a/core/translate/emitter.rs b/core/translate/emitter.rs index 00dd9c0bc..14c3b2c5e 100644 --- a/core/translate/emitter.rs +++ b/core/translate/emitter.rs @@ -3,7 +3,7 @@ use std::rc::Rc; -use limbo_sqlite3_parser::ast::{self}; +use limbo_sqlite3_parser::ast::{self, Expr}; use tracing::{instrument, Level}; use super::aggregation::emit_ungrouped_aggregation; @@ -15,7 +15,9 @@ use super::main_loop::{ close_loop, emit_loop, init_distinct, init_loop, open_loop, LeftJoinMetadata, LoopLabels, }; use super::order_by::{emit_order_by, init_order_by, SortMetadata}; -use super::plan::{JoinOrderMember, Operation, SelectPlan, TableReferences, UpdatePlan}; +use super::plan::{ + Distinctness, JoinOrderMember, Operation, SelectPlan, TableReferences, UpdatePlan, +}; use super::select::emit_simple_count; use super::subquery::emit_subqueries; use crate::error::SQLITE_CONSTRAINT_PRIMARYKEY; @@ -33,6 +35,7 @@ use crate::{Result, SymbolTable}; pub struct Resolver<'a> { pub schema: &'a Schema, pub symbol_table: &'a SymbolTable, + pub expr_to_reg_cache_enabled: bool, pub expr_to_reg_cache: Vec<(&'a ast::Expr, usize)>, } @@ -41,6 +44,7 @@ impl<'a> Resolver<'a> { Self { schema, symbol_table, + expr_to_reg_cache_enabled: false, expr_to_reg_cache: Vec::new(), } } @@ -55,11 +59,19 @@ impl<'a> Resolver<'a> { } } + pub(crate) fn enable_expr_to_reg_cache(&mut self) { + self.expr_to_reg_cache_enabled = true; + } + pub fn resolve_cached_expr_reg(&self, expr: &ast::Expr) -> Option { - self.expr_to_reg_cache - .iter() - .find(|(e, _)| exprs_are_equivalent(expr, e)) - .map(|(_, reg)| *reg) + if self.expr_to_reg_cache_enabled { + self.expr_to_reg_cache + .iter() + .find(|(e, _)| exprs_are_equivalent(expr, e)) + .map(|(_, reg)| *reg) + } else { + None + } } } @@ -125,6 +137,17 @@ pub struct TranslateCtx<'a> { // This vector holds the indexes of the result columns that we need to skip. pub result_columns_to_skip_in_orderby_sorter: Option>, pub resolver: Resolver<'a>, + /// A list of expressions that are not aggregates, along with a flag indicating + /// whether the expression should be included in the output for each group. + /// + /// Each entry is a tuple: + /// - `&'ast Expr`: the expression itself + /// - `bool`: `true` if the expression should be included in the output for each group, `false` otherwise. + /// + /// The order of expressions is **significant**: + /// - First: all `GROUP BY` expressions, in the order they appear in the `GROUP BY` clause. + /// - Then: remaining non-aggregate expressions that are not part of `GROUP BY`. + pub non_aggregate_expressions: Vec<(&'a Expr, bool)>, } impl<'a> TranslateCtx<'a> { @@ -150,6 +173,7 @@ impl<'a> TranslateCtx<'a> { result_column_indexes_in_orderby_sorter: (0..result_column_count).collect(), result_columns_to_skip_in_orderby_sorter: None, resolver: Resolver::new(schema, syms), + non_aggregate_expressions: Vec::new(), } } } @@ -280,14 +304,28 @@ pub fn emit_query<'a>( } if let Some(ref group_by) = plan.group_by { - init_group_by(program, t_ctx, group_by, &plan)?; + init_group_by( + program, + t_ctx, + group_by, + &plan, + &plan.result_columns, + &plan.order_by, + )?; } else if !plan.aggregates.is_empty() { // Aggregate registers need to be NULLed at the start because the same registers might be reused on another invocation of a subquery, // and if they are not NULLed, the 2nd invocation of the same subquery will have values left over from the first invocation. t_ctx.reg_agg_start = Some(program.alloc_registers_and_init_w_null(plan.aggregates.len())); } - init_distinct(program, plan); + let distinct_ctx = if let Distinctness::Distinct { .. } = &plan.distinctness { + Some(init_distinct(program, plan)) + } else { + None + }; + if let Distinctness::Distinct { ctx } = &mut plan.distinctness { + *ctx = distinct_ctx + } init_loop( program, t_ctx, diff --git a/core/translate/expr.rs b/core/translate/expr.rs index dbcdb47ff..851a26b5f 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -2661,186 +2661,195 @@ pub fn unwrap_parens_owned(expr: ast::Expr) -> Result<(ast::Expr, usize)> { } } -/// Recursively walks an immutable expression, applying a function to each sub-expression. -pub fn walk_expr<'a, F>(expr: &'a ast::Expr, func: &mut F) -> Result<()> -where - F: FnMut(&'a ast::Expr) -> Result<()>, -{ - func(expr)?; - match expr { - ast::Expr::Between { - lhs, start, end, .. - } => { - walk_expr(lhs, func)?; - walk_expr(start, func)?; - walk_expr(end, func)?; - } - ast::Expr::Binary(lhs, _, rhs) => { - walk_expr(lhs, func)?; - walk_expr(rhs, func)?; - } - ast::Expr::Case { - base, - when_then_pairs, - else_expr, - } => { - if let Some(base_expr) = base { - walk_expr(base_expr, func)?; - } - for (when_expr, then_expr) in when_then_pairs { - walk_expr(when_expr, func)?; - walk_expr(then_expr, func)?; - } - if let Some(else_expr) = else_expr { - walk_expr(else_expr, func)?; - } - } - ast::Expr::Cast { expr, .. } => { - walk_expr(expr, func)?; - } - ast::Expr::Collate(expr, _) => { - walk_expr(expr, func)?; - } - ast::Expr::Exists(_select) | ast::Expr::Subquery(_select) => { - // TODO: Walk through select statements if needed - } - ast::Expr::FunctionCall { - args, - order_by, - filter_over, - .. - } => { - if let Some(args) = args { - for arg in args { - walk_expr(arg, func)?; - } - } - if let Some(order_by) = order_by { - for sort_col in order_by { - walk_expr(&sort_col.expr, func)?; - } - } - if let Some(filter_over) = filter_over { - if let Some(filter_clause) = &filter_over.filter_clause { - walk_expr(filter_clause, func)?; - } - if let Some(over_clause) = &filter_over.over_clause { - match over_clause.as_ref() { - ast::Over::Window(window) => { - if let Some(partition_by) = &window.partition_by { - for part_expr in partition_by { - walk_expr(part_expr, func)?; - } - } - if let Some(order_by_clause) = &window.order_by { - for sort_col in order_by_clause { - walk_expr(&sort_col.expr, func)?; - } - } - if let Some(frame_clause) = &window.frame_clause { - walk_expr_frame_bound(&frame_clause.start, func)?; - if let Some(end_bound) = &frame_clause.end { - walk_expr_frame_bound(end_bound, func)?; - } - } - } - ast::Over::Name(_) => {} - } - } - } - } - ast::Expr::FunctionCallStar { filter_over, .. } => { - if let Some(filter_over) = filter_over { - if let Some(filter_clause) = &filter_over.filter_clause { - walk_expr(filter_clause, func)?; - } - if let Some(over_clause) = &filter_over.over_clause { - match over_clause.as_ref() { - ast::Over::Window(window) => { - if let Some(partition_by) = &window.partition_by { - for part_expr in partition_by { - walk_expr(part_expr, func)?; - } - } - if let Some(order_by_clause) = &window.order_by { - for sort_col in order_by_clause { - walk_expr(&sort_col.expr, func)?; - } - } - if let Some(frame_clause) = &window.frame_clause { - walk_expr_frame_bound(&frame_clause.start, func)?; - if let Some(end_bound) = &frame_clause.end { - walk_expr_frame_bound(end_bound, func)?; - } - } - } - ast::Over::Name(_) => {} - } - } - } - } - ast::Expr::InList { lhs, rhs, .. } => { - walk_expr(lhs, func)?; - if let Some(rhs_exprs) = rhs { - for expr in rhs_exprs { - walk_expr(expr, func)?; - } - } - } - ast::Expr::InSelect { lhs, rhs: _, .. } => { - walk_expr(lhs, func)?; - // TODO: Walk through select statements if needed - } - ast::Expr::InTable { lhs, args, .. } => { - walk_expr(lhs, func)?; - if let Some(arg_exprs) = args { - for expr in arg_exprs { - walk_expr(expr, func)?; - } - } - } - ast::Expr::IsNull(expr) | ast::Expr::NotNull(expr) => { - walk_expr(expr, func)?; - } - ast::Expr::Like { - lhs, rhs, escape, .. - } => { - walk_expr(lhs, func)?; - walk_expr(rhs, func)?; - if let Some(esc_expr) = escape { - walk_expr(esc_expr, func)?; - } - } - ast::Expr::Parenthesized(exprs) => { - for expr in exprs { - walk_expr(expr, func)?; - } - } - ast::Expr::Raise(_, expr) => { - if let Some(raise_expr) = expr { - walk_expr(raise_expr, func)?; - } - } - ast::Expr::Unary(_, expr) => { - walk_expr(expr, func)?; - } - ast::Expr::Id(_) - | ast::Expr::Column { .. } - | ast::Expr::RowId { .. } - | ast::Expr::Literal(_) - | ast::Expr::DoublyQualified(..) - | ast::Expr::Name(_) - | ast::Expr::Qualified(..) - | ast::Expr::Variable(_) => { - // No nested expressions - } - } - Ok(()) +pub enum WalkControl { + Continue, // Visit children + SkipChildren, // Skip children but continue walking siblings } -fn walk_expr_frame_bound<'a, F>(bound: &'a ast::FrameBound, func: &mut F) -> Result<()> +/// Recursively walks an immutable expression, applying a function to each sub-expression. +pub fn walk_expr<'a, F>(expr: &'a ast::Expr, func: &mut F) -> Result where - F: FnMut(&'a ast::Expr) -> Result<()>, + F: FnMut(&'a ast::Expr) -> Result, +{ + match func(expr)? { + WalkControl::Continue => { + match expr { + ast::Expr::Between { + lhs, start, end, .. + } => { + walk_expr(lhs, func)?; + walk_expr(start, func)?; + walk_expr(end, func)?; + } + ast::Expr::Binary(lhs, _, rhs) => { + walk_expr(lhs, func)?; + walk_expr(rhs, func)?; + } + ast::Expr::Case { + base, + when_then_pairs, + else_expr, + } => { + if let Some(base_expr) = base { + walk_expr(base_expr, func)?; + } + for (when_expr, then_expr) in when_then_pairs { + walk_expr(when_expr, func)?; + walk_expr(then_expr, func)?; + } + if let Some(else_expr) = else_expr { + walk_expr(else_expr, func)?; + } + } + ast::Expr::Cast { expr, .. } => { + walk_expr(expr, func)?; + } + ast::Expr::Collate(expr, _) => { + walk_expr(expr, func)?; + } + ast::Expr::Exists(_select) | ast::Expr::Subquery(_select) => { + // TODO: Walk through select statements if needed + } + ast::Expr::FunctionCall { + args, + order_by, + filter_over, + .. + } => { + if let Some(args) = args { + for arg in args { + walk_expr(arg, func)?; + } + } + if let Some(order_by) = order_by { + for sort_col in order_by { + walk_expr(&sort_col.expr, func)?; + } + } + if let Some(filter_over) = filter_over { + if let Some(filter_clause) = &filter_over.filter_clause { + walk_expr(filter_clause, func)?; + } + if let Some(over_clause) = &filter_over.over_clause { + match over_clause.as_ref() { + ast::Over::Window(window) => { + if let Some(partition_by) = &window.partition_by { + for part_expr in partition_by { + walk_expr(part_expr, func)?; + } + } + if let Some(order_by_clause) = &window.order_by { + for sort_col in order_by_clause { + walk_expr(&sort_col.expr, func)?; + } + } + if let Some(frame_clause) = &window.frame_clause { + walk_expr_frame_bound(&frame_clause.start, func)?; + if let Some(end_bound) = &frame_clause.end { + walk_expr_frame_bound(end_bound, func)?; + } + } + } + ast::Over::Name(_) => {} + } + } + } + } + ast::Expr::FunctionCallStar { filter_over, .. } => { + if let Some(filter_over) = filter_over { + if let Some(filter_clause) = &filter_over.filter_clause { + walk_expr(filter_clause, func)?; + } + if let Some(over_clause) = &filter_over.over_clause { + match over_clause.as_ref() { + ast::Over::Window(window) => { + if let Some(partition_by) = &window.partition_by { + for part_expr in partition_by { + walk_expr(part_expr, func)?; + } + } + if let Some(order_by_clause) = &window.order_by { + for sort_col in order_by_clause { + walk_expr(&sort_col.expr, func)?; + } + } + if let Some(frame_clause) = &window.frame_clause { + walk_expr_frame_bound(&frame_clause.start, func)?; + if let Some(end_bound) = &frame_clause.end { + walk_expr_frame_bound(end_bound, func)?; + } + } + } + ast::Over::Name(_) => {} + } + } + } + } + ast::Expr::InList { lhs, rhs, .. } => { + walk_expr(lhs, func)?; + if let Some(rhs_exprs) = rhs { + for expr in rhs_exprs { + walk_expr(expr, func)?; + } + } + } + ast::Expr::InSelect { lhs, rhs: _, .. } => { + walk_expr(lhs, func)?; + // TODO: Walk through select statements if needed + } + ast::Expr::InTable { lhs, args, .. } => { + walk_expr(lhs, func)?; + if let Some(arg_exprs) = args { + for expr in arg_exprs { + walk_expr(expr, func)?; + } + } + } + ast::Expr::IsNull(expr) | ast::Expr::NotNull(expr) => { + walk_expr(expr, func)?; + } + ast::Expr::Like { + lhs, rhs, escape, .. + } => { + walk_expr(lhs, func)?; + walk_expr(rhs, func)?; + if let Some(esc_expr) = escape { + walk_expr(esc_expr, func)?; + } + } + ast::Expr::Parenthesized(exprs) => { + for expr in exprs { + walk_expr(expr, func)?; + } + } + ast::Expr::Raise(_, expr) => { + if let Some(raise_expr) = expr { + walk_expr(raise_expr, func)?; + } + } + ast::Expr::Unary(_, expr) => { + walk_expr(expr, func)?; + } + ast::Expr::Id(_) + | ast::Expr::Column { .. } + | ast::Expr::RowId { .. } + | ast::Expr::Literal(_) + | ast::Expr::DoublyQualified(..) + | ast::Expr::Name(_) + | ast::Expr::Qualified(..) + | ast::Expr::Variable(_) => { + // No nested expressions + } + } + } + WalkControl::SkipChildren => return Ok(WalkControl::Continue), + }; + Ok(WalkControl::Continue) +} + +fn walk_expr_frame_bound<'a, F>(bound: &'a ast::FrameBound, func: &mut F) -> Result +where + F: FnMut(&'a ast::Expr) -> Result, { match bound { ast::FrameBound::Following(expr) | ast::FrameBound::Preceding(expr) => { @@ -2851,7 +2860,7 @@ where | ast::FrameBound::UnboundedPreceding => {} } - Ok(()) + Ok(WalkControl::Continue) } /// Recursively walks a mutable expression, applying a function to each sub-expression. diff --git a/core/translate/group_by.rs b/core/translate/group_by.rs index b3875b264..cb0b56076 100644 --- a/core/translate/group_by.rs +++ b/core/translate/group_by.rs @@ -2,6 +2,8 @@ use std::rc::Rc; use limbo_sqlite3_parser::ast; +use crate::translate::expr::{walk_expr, WalkControl}; +use crate::translate::plan::ResultSetColumn; use crate::{ function::AggFunc, schema::{Column, PseudoTable}, @@ -76,23 +78,24 @@ pub struct GroupByMetadata { pub row_source: GroupByRowSource, pub labels: GroupByLabels, pub registers: GroupByRegisters, - // Columns that not part of GROUP BY clause and not arguments of Aggregation function. - // Heavy calculation and needed in different functions, so it is reasonable to do it once and save. - pub non_group_by_non_agg_column_count: usize, } /// Initialize resources needed for GROUP BY processing -pub fn init_group_by( +pub fn init_group_by<'a>( program: &mut ProgramBuilder, - t_ctx: &mut TranslateCtx, - group_by: &GroupBy, + t_ctx: &mut TranslateCtx<'a>, + group_by: &'a GroupBy, plan: &SelectPlan, + result_columns: &'a Vec, + order_by: &'a Option>, ) -> Result<()> { - let non_aggregate_count = plan - .result_columns - .iter() - .filter(|rc| !rc.contains_aggregates) - .count(); + collect_non_aggregate_expressions( + &mut t_ctx.non_aggregate_expressions, + group_by, + plan, + result_columns, + order_by, + )?; let label_subrtn_acc_output = program.allocate_label(); let label_group_by_end_without_emitting_row = program.allocate_label(); @@ -112,7 +115,8 @@ pub fn init_group_by( // The following two blocks of registers should always be allocated contiguously, // because they are cleared in a contiguous block in the GROUP BYs clear accumulator subroutine. // START BLOCK - let reg_non_aggregate_exprs_acc = program.alloc_registers(non_aggregate_count); + let reg_non_aggregate_exprs_acc = + program.alloc_registers(t_ctx.non_aggregate_expressions.len()); if !plan.aggregates.is_empty() { // Aggregate registers need to be NULLed at the start because the same registers might be reused on another invocation of a subquery, // and if they are not NULLed, the 2nd invocation of the same subquery will have values left over from the first invocation. @@ -121,12 +125,11 @@ pub fn init_group_by( // END BLOCK let reg_sorter_key = program.alloc_register(); - let column_count = plan.group_by_sorter_column_count(); + let column_count = plan.agg_args_count() + t_ctx.non_aggregate_expressions.len(); let reg_group_by_source_cols_start = program.alloc_registers(column_count); let row_source = if let Some(sort_order) = group_by.sort_order.as_ref() { let sort_cursor = program.alloc_cursor_id(CursorType::Sorter); - let sorter_column_count = plan.group_by_sorter_column_count(); // Should work the same way as Order By /* * Terms of the ORDER BY clause that is part of a SELECT statement may be assigned a collating sequence using the COLLATE operator, @@ -160,21 +163,17 @@ pub fn init_group_by( program.emit_insn(Insn::SorterOpen { cursor_id: sort_cursor, - columns: sorter_column_count, + columns: column_count, order: sort_order.clone(), collations, }); - let pseudo_cursor = group_by_create_pseudo_table(program, sorter_column_count); + let pseudo_cursor = group_by_create_pseudo_table(program, column_count); GroupByRowSource::Sorter { pseudo_cursor, sort_cursor, reg_sorter_key, - column_register_mapping: group_by_create_column_register_mapping( - group_by, - reg_non_aggregate_exprs_acc, - plan, - ), - sorter_column_count, + sorter_column_count: column_count, + start_reg_dest: reg_non_aggregate_exprs_acc, } } else { GroupByRowSource::MainLoop { @@ -232,11 +231,72 @@ pub fn init_group_by( reg_subrtn_acc_clear_return_offset, reg_group_by_source_cols_start, }, - non_group_by_non_agg_column_count: plan.non_group_by_non_agg_column_count(), }); Ok(()) } +fn collect_non_aggregate_expressions<'a>( + non_aggregate_expressions: &mut Vec<(&'a ast::Expr, bool)>, + group_by: &'a GroupBy, + plan: &SelectPlan, + root_result_columns: &'a Vec, + order_by: &'a Option>, +) -> Result<()> { + let mut result_columns = Vec::new(); + for expr in root_result_columns + .iter() + .map(|col| &col.expr) + .chain(order_by.iter().flat_map(|o| o.iter().map(|(e, _)| e))) + .chain(group_by.having.iter().flatten()) + { + collect_result_columns(expr, plan, &mut result_columns)?; + } + + for group_expr in &group_by.exprs { + let in_result = result_columns + .iter() + .any(|expr| exprs_are_equivalent(expr, group_expr)); + non_aggregate_expressions.push((group_expr, in_result)); + } + for expr in result_columns { + let in_group_by = group_by + .exprs + .iter() + .any(|group_expr| exprs_are_equivalent(expr, group_expr)); + if !in_group_by { + non_aggregate_expressions.push((expr, true)); + } + } + Ok(()) +} + +fn collect_result_columns<'a>( + root_expr: &'a ast::Expr, + plan: &SelectPlan, + result_columns: &mut Vec<&'a ast::Expr>, +) -> Result<()> { + walk_expr(root_expr, &mut |expr: &ast::Expr| -> Result { + match expr { + ast::Expr::Column { table, .. } | ast::Expr::RowId { table, .. } => { + if plan + .table_references + .find_joined_table_by_internal_id(*table) + .is_some() + { + result_columns.push(expr); + } + } + _ => { + if plan.aggregates.iter().any(|a| a.original_expr == *expr) { + return Ok(WalkControl::SkipChildren); + } + } + }; + Ok(WalkControl::Continue) + })?; + Ok(()) +} + /// In case sorting is needed for GROUP BY, creates a pseudo table that matches /// the number of columns in the GROUP BY sorter. Rows are individually read /// from the sorter into this pseudo table and processed. @@ -338,9 +398,7 @@ pub enum GroupByRowSource { reg_sorter_key: usize, /// Number of columns in the GROUP BY sorter sorter_column_count: usize, - /// In case some result columns of the SELECT query are equivalent to GROUP BY members, - /// this mapping encodes their position. - column_register_mapping: Vec>, + start_reg_dest: usize, }, MainLoop { /// If GROUP BY rows are read directly in the main loop, start_reg is the first register @@ -454,17 +512,16 @@ impl<'a> GroupByAggArgumentSource<'a> { } /// Emits bytecode for processing a single GROUP BY group. -pub fn group_by_process_single_group( +pub fn group_by_process_single_group<'a>( program: &mut ProgramBuilder, - group_by: &GroupBy, - plan: &SelectPlan, - t_ctx: &TranslateCtx, + group_by: &'a GroupBy, + plan: &'a SelectPlan, + t_ctx: &mut TranslateCtx<'a>, ) -> Result<()> { let GroupByMetadata { registers, labels, row_source, - non_group_by_non_agg_column_count, .. } = t_ctx .meta_group_by @@ -549,7 +606,7 @@ pub fn group_by_process_single_group( // Process each aggregate function for the current row program.resolve_label(labels.label_grouping_agg_step, program.offset()); - let cursor_index = *non_group_by_non_agg_column_count + group_by.exprs.len(); // Skipping all columns in sorter that not an aggregation arguments + let cursor_index = t_ctx.non_aggregate_expressions.len(); // Skipping all columns in sorter that not an aggregation arguments let mut offset = 0; for (i, agg) in plan.aggregates.iter().enumerate() { let start_reg = t_ctx @@ -567,8 +624,7 @@ pub fn group_by_process_single_group( } GroupByRowSource::MainLoop { start_reg_src, .. } => { // Aggregation arguments are always placed in the registers that follow any scalars. - let start_reg_aggs = - start_reg_src + group_by.exprs.len() + plan.non_group_by_non_agg_column_count(); + let start_reg_aggs = start_reg_src + t_ctx.non_aggregate_expressions.len(); GroupByAggArgumentSource::new_from_registers(start_reg_aggs + offset, agg) } }; @@ -604,27 +660,32 @@ pub fn group_by_process_single_group( match row_source { GroupByRowSource::Sorter { pseudo_cursor, - column_register_mapping, + start_reg_dest, .. } => { - for (sorter_column_index, dest_reg) in column_register_mapping.iter().enumerate() { - if let Some(dest_reg) = dest_reg { - program.emit_column(*pseudo_cursor, sorter_column_index, *dest_reg); + let mut sorter_column_index = 0; + let mut next_reg = *start_reg_dest; + + for (expr, in_result) in t_ctx.non_aggregate_expressions.iter() { + if *in_result { + program.emit_column(*pseudo_cursor, sorter_column_index, next_reg); + t_ctx.resolver.expr_to_reg_cache.push((expr, next_reg)); + next_reg += 1; } + sorter_column_index += 1; } } GroupByRowSource::MainLoop { start_reg_dest, .. } => { // Re-translate all the non-aggregate expressions into destination registers. We cannot use the same registers as emitted // in the earlier part of the main loop, because they would be overwritten by the next group before the group results // are processed. - for (i, rc) in plan - .result_columns + for (i, expr) in t_ctx + .non_aggregate_expressions .iter() - .filter(|rc| !rc.contains_aggregates) + .filter_map(|(expr, in_result)| if *in_result { Some(expr) } else { None }) .enumerate() { let dest_reg = start_reg_dest + i; - let expr = &rc.expr; translate_expr( program, Some(&plan.table_references), @@ -632,6 +693,7 @@ pub fn group_by_process_single_group( dest_reg, &t_ctx.resolver, )?; + t_ctx.resolver.expr_to_reg_cache.push((expr, dest_reg)); } } } @@ -647,44 +709,6 @@ pub fn group_by_process_single_group( Ok(()) } -pub fn group_by_create_column_register_mapping( - group_by: &GroupBy, - reg_non_aggregate_exprs_acc: usize, - plan: &SelectPlan, -) -> Vec> { - // We have to know which group by expr present in resulting set - let group_by_expr_in_res_cols = group_by.exprs.iter().map(|expr| { - plan.result_columns - .iter() - .any(|e| exprs_are_equivalent(&e.expr, expr)) - }); - - let group_by_count = group_by.exprs.len(); - let non_group_by_non_agg_column_count = plan.non_group_by_non_agg_column_count(); - - // Create a map from sorter column index to result register - // This helps track where each column from the sorter should be stored - let mut column_register_mapping = - vec![None; group_by_count + non_group_by_non_agg_column_count]; - let mut next_reg = reg_non_aggregate_exprs_acc; - - // Map GROUP BY columns that are in the result set to registers - for (i, is_in_result) in group_by_expr_in_res_cols.clone().enumerate() { - if is_in_result { - column_register_mapping[i] = Some(next_reg); - next_reg += 1; - } - } - - // Handle other non-aggregate columns that aren't part of GROUP BY and not part of Aggregation function - for i in group_by_count..group_by_count + non_group_by_non_agg_column_count { - column_register_mapping[i] = Some(next_reg); - next_reg += 1; - } - - column_register_mapping -} - /// Emits the bytecode for processing the aggregation phase of a GROUP BY clause. /// This is called either when: /// 1. the main query execution loop has finished processing, @@ -731,10 +755,7 @@ pub fn group_by_emit_row_phase<'a>( ) -> Result<()> { let group_by = plan.group_by.as_ref().expect("group by not found"); let GroupByMetadata { - row_source, - labels, - registers, - .. + labels, registers, .. } = t_ctx .meta_group_by .as_ref() @@ -795,82 +816,14 @@ pub fn group_by_emit_row_phase<'a>( register: agg_result_reg, func: agg.func.clone(), }); - } - - // We have to know which group by expr present in resulting set - let group_by_expr_in_res_cols = group_by.exprs.iter().map(|expr| { - plan.result_columns - .iter() - .any(|e| exprs_are_equivalent(&e.expr, expr)) - }); - - // Map GROUP BY expressions to their registers in the result set - for (i, (expr, is_in_result)) in group_by - .exprs - .iter() - .zip(group_by_expr_in_res_cols) - .enumerate() - { - if is_in_result { - match row_source { - GroupByRowSource::Sorter { - column_register_mapping, - .. - } => { - if let Some(reg) = column_register_mapping.get(i).and_then(|opt| *opt) { - t_ctx.resolver.expr_to_reg_cache.push((expr, reg)); - } - } - GroupByRowSource::MainLoop { start_reg_dest, .. } => { - t_ctx - .resolver - .expr_to_reg_cache - .push((expr, *start_reg_dest + i)); - } - } - } - } - - // Map non-aggregate, non-GROUP BY columns to their registers - let non_agg_cols = plan - .result_columns - .iter() - .filter(|rc| !rc.contains_aggregates && !is_column_in_group_by(&rc.expr, &group_by.exprs)); - - for (idx, rc) in non_agg_cols.enumerate() { - let column_relative_idx = plan.group_by_col_count() + idx; - match &row_source { - GroupByRowSource::Sorter { - column_register_mapping, - .. - } => { - if let Some(reg) = column_register_mapping - .get(column_relative_idx) - .and_then(|opt| *opt) - { - t_ctx.resolver.expr_to_reg_cache.push((&rc.expr, reg)); - } - } - GroupByRowSource::MainLoop { start_reg_dest, .. } => { - t_ctx - .resolver - .expr_to_reg_cache - .push((&rc.expr, start_reg_dest + column_relative_idx)); - } - } - } - - // Map aggregate expressions to their result registers - for (i, agg) in plan.aggregates.iter().enumerate() { - let agg_start_reg = t_ctx - .reg_agg_start - .expect("aggregate registers must be initialized"); t_ctx .resolver .expr_to_reg_cache - .push((&agg.original_expr, agg_start_reg + i)); + .push((&agg.original_expr, agg_result_reg)); } + t_ctx.resolver.enable_expr_to_reg_cache(); + if let Some(having) = &group_by.having { for expr in having.iter() { let if_true_target = program.allocate_label(); @@ -930,7 +883,9 @@ pub fn group_by_emit_row_phase<'a>( // Reset all accumulator registers to NULL program.emit_insn(Insn::Null { dest: start_reg, - dest_end: Some(start_reg + plan.group_by_sorter_column_count() - 1), + dest_end: Some( + start_reg + t_ctx.non_aggregate_expressions.len() + plan.agg_args_count() - 1, + ), }); // Reopen ephemeral indexes for distinct aggregates (effectively clearing them). @@ -1181,9 +1136,3 @@ pub fn translate_aggregation_step_groupby( }; Ok(dest) } - -pub fn is_column_in_group_by(expr: &ast::Expr, group_by_exprs: &[ast::Expr]) -> bool { - group_by_exprs - .iter() - .any(|expr2| exprs_are_equivalent(expr, expr2)) -} diff --git a/core/translate/main_loop.rs b/core/translate/main_loop.rs index fd8ec644f..1968ac343 100644 --- a/core/translate/main_loop.rs +++ b/core/translate/main_loop.rs @@ -66,45 +66,41 @@ impl LoopLabels { } } -pub fn init_distinct(program: &mut ProgramBuilder, plan: &mut SelectPlan) { - if let Distinctness::Distinct { ctx } = &mut plan.distinctness { - assert!( - ctx.is_none(), - "distinctness context should not be allocated yet" - ); - let index_name = format!("distinct_{}", program.offset().to_offset_int()); // we don't really care about the name that much, just enough that we don't get name collisions - let index = Arc::new(Index { - name: index_name.clone(), - table_name: String::new(), - ephemeral: true, - root_page: 0, - columns: plan - .result_columns - .iter() - .enumerate() - .map(|(i, col)| IndexColumn { - name: col.expr.to_string(), - order: SortOrder::Asc, - pos_in_table: i, - collation: None, // FIXME: this should be determined based on the result column expression! - default: None, // FIXME: this should be determined based on the result column expression! - }) - .collect(), - unique: false, - has_rowid: false, - }); - let cursor_id = program.alloc_cursor_id(CursorType::BTreeIndex(index.clone())); - *ctx = Some(DistinctCtx { - cursor_id, - ephemeral_index_name: index_name, - label_on_conflict: program.allocate_label(), - }); +pub fn init_distinct(program: &mut ProgramBuilder, plan: &SelectPlan) -> DistinctCtx { + let index_name = format!("distinct_{}", program.offset().to_offset_int()); // we don't really care about the name that much, just enough that we don't get name collisions + let index = Arc::new(Index { + name: index_name.clone(), + table_name: String::new(), + ephemeral: true, + root_page: 0, + columns: plan + .result_columns + .iter() + .enumerate() + .map(|(i, col)| IndexColumn { + name: col.expr.to_string(), + order: SortOrder::Asc, + pos_in_table: i, + collation: None, // FIXME: this should be determined based on the result column expression! + default: None, // FIXME: this should be determined based on the result column expression! + }) + .collect(), + unique: false, + has_rowid: false, + }); + let cursor_id = program.alloc_cursor_id(CursorType::BTreeIndex(index.clone())); + let ctx = DistinctCtx { + cursor_id, + ephemeral_index_name: index_name, + label_on_conflict: program.allocate_label(), + }; - program.emit_insn(Insn::OpenEphemeral { - cursor_id, - is_table: false, - }); - } + program.emit_insn(Insn::OpenEphemeral { + cursor_id, + is_table: false, + }); + + return ctx; } /// Initialize resources needed for the source operators (tables, joins, etc) @@ -765,7 +761,6 @@ fn emit_loop_source<'a>( // 3) aggregate function arguments // - or if the rows produced by the loop are already sorted in the order required by the GROUP BY keys, // the group by comparisons are done directly inside the main loop. - let group_by = plan.group_by.as_ref().unwrap(); let aggregates = &plan.aggregates; let GroupByMetadata { @@ -777,9 +772,15 @@ fn emit_loop_source<'a>( let start_reg = registers.reg_group_by_source_cols_start; let mut cur_reg = start_reg; - // Step 1: Process GROUP BY columns first - // These will be the first columns in the sorter and serve as sort keys - for expr in group_by.exprs.iter() { + // Collect all non-aggregate expressions in the following order: + // 1. GROUP BY expressions. These serve as sort keys. + // 2. Remaining non-aggregate expressions that are not in GROUP BY. + // + // Example: + // SELECT col1, col2, SUM(col3) FROM table GROUP BY col1 + // - col1 is added first (from GROUP BY) + // - col2 is added second (non-aggregate, in SELECT, not in GROUP BY) + for (expr, _) in t_ctx.non_aggregate_expressions.iter() { let key_reg = cur_reg; cur_reg += 1; translate_expr( @@ -791,22 +792,7 @@ fn emit_loop_source<'a>( )?; } - // Step 2: Process columns that aren't part of GROUP BY and don't contain aggregates - // Example: SELECT col1, col2, SUM(col3) FROM table GROUP BY col1 - // Here col2 would be processed in this loop if it's in the result set - for expr in plan.non_group_by_non_agg_columns() { - let key_reg = cur_reg; - cur_reg += 1; - translate_expr( - program, - Some(&plan.table_references), - expr, - key_reg, - &t_ctx.resolver, - )?; - } - - // Step 3: Process arguments for all aggregate functions + // Step 2: Process arguments for all aggregate functions // For each aggregate, translate all its argument expressions for agg in aggregates.iter() { // For a query like: SELECT group_col, SUM(val1), AVG(val2) FROM table GROUP BY group_col diff --git a/core/translate/plan.rs b/core/translate/plan.rs index 9c1c9fea7..e5f8af284 100644 --- a/core/translate/plan.rs +++ b/core/translate/plan.rs @@ -5,7 +5,6 @@ use std::{cell::Cell, cmp::Ordering, rc::Rc, sync::Arc}; use crate::{ function::AggFunc, schema::{BTreeTable, Column, FromClauseSubquery, Index, Table}, - util::exprs_are_equivalent, vdbe::{ builder::{CursorKey, CursorType, ProgramBuilder}, insn::{IdxInsertFlags, Insn}, @@ -454,35 +453,6 @@ impl SelectPlan { self.aggregates.iter().map(|agg| agg.args.len()).sum() } - pub fn group_by_col_count(&self) -> usize { - self.group_by - .as_ref() - .map_or(0, |group_by| group_by.exprs.len()) - } - - pub fn non_group_by_non_agg_columns(&self) -> impl Iterator { - self.result_columns - .iter() - .filter(|c| { - !c.contains_aggregates - && !self.group_by.as_ref().map_or(false, |group_by| { - group_by - .exprs - .iter() - .any(|expr| exprs_are_equivalent(&c.expr, expr)) - }) - }) - .map(|c| &c.expr) - } - - pub fn non_group_by_non_agg_column_count(&self) -> usize { - self.non_group_by_non_agg_columns().count() - } - - pub fn group_by_sorter_column_count(&self) -> usize { - self.agg_args_count() + self.group_by_col_count() + self.non_group_by_non_agg_column_count() - } - /// Reference: https://github.com/sqlite/sqlite/blob/5db695197b74580c777b37ab1b787531f15f7f9f/src/select.c#L8613 /// /// Checks to see if the query is of the format `SELECT count(*) FROM ` diff --git a/core/translate/planner.rs b/core/translate/planner.rs index 77af53126..1bf6a7452 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -10,6 +10,7 @@ use super::{ select::prepare_select_plan, SymbolTable, }; +use crate::translate::expr::WalkControl; use crate::{ function::Func, schema::{Schema, Table}, @@ -26,13 +27,13 @@ pub const ROWID: &str = "rowid"; pub fn resolve_aggregates(top_level_expr: &Expr, aggs: &mut Vec) -> Result { let mut contains_aggregates = false; - walk_expr(top_level_expr, &mut |expr: &Expr| -> Result<()> { + walk_expr(top_level_expr, &mut |expr: &Expr| -> Result { if aggs .iter() .any(|a| exprs_are_equivalent(&a.original_expr, expr)) { contains_aggregates = true; - return Ok(()); + return Ok(WalkControl::Continue); } match expr { Expr::FunctionCall { @@ -97,7 +98,7 @@ pub fn resolve_aggregates(top_level_expr: &Expr, aggs: &mut Vec) -> R _ => {} } - Ok(()) + Ok(WalkControl::Continue) })?; Ok(contains_aggregates) @@ -639,7 +640,7 @@ pub fn table_mask_from_expr( table_references: &TableReferences, ) -> Result { let mut mask = TableMask::new(); - walk_expr(top_level_expr, &mut |expr: &Expr| -> Result<()> { + walk_expr(top_level_expr, &mut |expr: &Expr| -> Result { match expr { Expr::Column { table, .. } | Expr::RowId { table, .. } => { if let Some(table_idx) = table_references @@ -660,7 +661,7 @@ pub fn table_mask_from_expr( } _ => {} } - Ok(()) + Ok(WalkControl::Continue) })?; Ok(mask) @@ -671,7 +672,7 @@ pub fn determine_where_to_eval_expr<'a>( join_order: &[JoinOrderMember], ) -> Result { let mut eval_at: EvalAt = EvalAt::BeforeLoop; - walk_expr(top_level_expr, &mut |expr: &Expr| -> Result<()> { + walk_expr(top_level_expr, &mut |expr: &Expr| -> Result { match expr { Expr::Column { table, .. } | Expr::RowId { table, .. } => { let join_idx = join_order @@ -682,7 +683,7 @@ pub fn determine_where_to_eval_expr<'a>( } _ => {} } - Ok(()) + Ok(WalkControl::Continue) })?; Ok(eval_at) diff --git a/core/translate/subquery.rs b/core/translate/subquery.rs index 26c9df8bc..f6b526922 100644 --- a/core/translate/subquery.rs +++ b/core/translate/subquery.rs @@ -81,6 +81,7 @@ pub fn emit_subquery<'a>( reg_offset: None, reg_limit_offset_sum: None, resolver: Resolver::new(t_ctx.resolver.schema, t_ctx.resolver.symbol_table), + non_aggregate_expressions: Vec::new(), }; let subquery_body_end_label = program.allocate_label(); program.emit_insn(Insn::InitCoroutine { diff --git a/core/util.rs b/core/util.rs index b1489333e..49c13a9f4 100644 --- a/core/util.rs +++ b/core/util.rs @@ -1,3 +1,4 @@ +use crate::translate::expr::WalkControl; use crate::{ schema::{self, Column, Schema, Type}, translate::{collate::CollationSeq, expr::walk_expr, plan::JoinOrderMember}, @@ -589,7 +590,7 @@ pub fn can_pushdown_predicate( join_order: &[JoinOrderMember], ) -> Result { let mut can_pushdown = true; - walk_expr(top_level_expr, &mut |expr: &Expr| -> Result<()> { + walk_expr(top_level_expr, &mut |expr: &Expr| -> Result { match expr { Expr::Column { table, .. } | Expr::RowId { table, .. } => { let join_idx = join_order @@ -608,7 +609,7 @@ pub fn can_pushdown_predicate( } _ => {} }; - Ok(()) + Ok(WalkControl::Continue) })?; Ok(can_pushdown) diff --git a/testing/groupby.test b/testing/groupby.test old mode 100644 new mode 100755 index 5d093068f..1c7124ecd --- a/testing/groupby.test +++ b/testing/groupby.test @@ -199,6 +199,23 @@ do_execsql_test group_by_no_sorting_required { 2|113 3|97} +if {[info exists ::env(SQLITE_EXEC)] && ($::env(SQLITE_EXEC) eq "scripts/limbo-sqlite3-index-experimental" || $::env(SQLITE_EXEC) eq "sqlite3")} { + do_execsql_test_on_specific_db {:memory:} group_by_no_sorting_required_reordered_columns { + create table t0 (a INT, b INT, c INT); + create index a_b_idx on t0 (a, b); + insert into t0 values + (1,1,1), + (1,1,2), + (2,1,3), + (2,2,3), + (2,2,5); + + select c, b, a from t0 group by a, b; + } {1|1|1 + 3|1|2 + 3|2|2} +} + if {[info exists ::env(SQLITE_EXEC)] && ($::env(SQLITE_EXEC) eq "scripts/limbo-sqlite3-index-experimental" || $::env(SQLITE_EXEC) eq "sqlite3")} { do_execsql_test distinct_agg_functions { select first_name, sum(distinct age), count(distinct age), avg(distinct age) @@ -224,3 +241,55 @@ do_execsql_test_on_specific_db {:memory:} having_or { order by cnt desc } {Michael|2|37.5 Sarah|1|65.0} + +do_execsql_test complex_result_expression_containing_aggregate { + select + case when price > 70 then group_concat(name, ',') else '' end names + from products + group by price + order by price; +} { + + + + +sweatshirt +jeans +hat +accessories +cap,sneakers} + +do_execsql_test complex_result_expression_containing_aggregate_and_rowid { + select + case when rowid >= 5 then group_concat(name, ',') else '' end names + from products + group by rowid + order by rowid; +} { + + + +sweatshirt +shorts +jeans +sneakers +boots +coat +accessories} + +do_execsql_test complex_having_expression_containing_aggregate { + select group_concat(name, ',') from products group by price having (group_concat(name, ',') || price) like 'ca%'; +} {cap,sneakers} + +do_execsql_test complex_order_by_expression_containing_aggregate { + select group_concat(name, ',') from products group by price order by (group_concat(name, ',') || price); +} {accessories +boots +cap,sneakers +coat +hat +jeans +shirt +shorts +sweater +sweatshirt}