From 3f80e41e7aad116ea79d475cd7c2430c79165872 Mon Sep 17 00:00:00 2001 From: jussisaurio Date: Thu, 28 Nov 2024 23:48:21 +0200 Subject: [PATCH 1/2] support HAVING --- core/translate/emitter.rs | 66 ++++++++++++++++++++++----------- core/translate/expr.rs | 1 - core/translate/plan.rs | 9 ++++- core/translate/planner.rs | 77 ++++++++++++++++++++++++++------------- core/types.rs | 4 +- core/vdbe/mod.rs | 39 ++++++++++++++++---- testing/groupby.test | 33 ++++++++++++++++- 7 files changed, 169 insertions(+), 60 deletions(-) diff --git a/core/translate/emitter.rs b/core/translate/emitter.rs index 8548b28b1..a9bb236cd 100644 --- a/core/translate/emitter.rs +++ b/core/translate/emitter.rs @@ -20,7 +20,7 @@ use super::expr::{ ConditionMetadata, }; use super::optimizer::Optimizable; -use super::plan::{Aggregate, BTreeTableReference, Direction, Plan}; +use super::plan::{Aggregate, BTreeTableReference, Direction, GroupBy, Plan}; use super::plan::{ResultSetColumn, SourceOperator}; // Metadata for handling LEFT JOIN operations @@ -284,7 +284,7 @@ fn init_order_by( /// Initialize resources needed for GROUP BY processing fn init_group_by( program: &mut ProgramBuilder, - group_by: &Vec, + group_by: &GroupBy, aggregates: &Vec, metadata: &mut Metadata, ) -> Result<()> { @@ -296,8 +296,8 @@ fn init_group_by( let abort_flag_register = program.alloc_register(); let data_in_accumulator_indicator_register = program.alloc_register(); - let group_exprs_comparison_register = program.alloc_registers(group_by.len()); - let group_exprs_accumulator_register = program.alloc_registers(group_by.len()); + let group_exprs_comparison_register = program.alloc_registers(group_by.exprs.len()); + let group_exprs_accumulator_register = program.alloc_registers(group_by.exprs.len()); let agg_exprs_start_reg = program.alloc_registers(num_aggs); let sorter_key_register = program.alloc_register(); @@ -306,12 +306,12 @@ fn init_group_by( let mut order = Vec::new(); const ASCENDING: i64 = 0; - for _ in group_by.iter() { + for _ in group_by.exprs.iter() { order.push(OwnedValue::Integer(ASCENDING)); } program.emit_insn(Insn::SorterOpen { cursor_id: sort_cursor, - columns: aggregates.len() + group_by.len(), + columns: aggregates.len() + group_by.exprs.len(), order: OwnedRecord::new(order), }); @@ -327,8 +327,8 @@ fn init_group_by( ); program.emit_insn(Insn::Null { dest: group_exprs_comparison_register, - dest_end: if group_by.len() > 1 { - Some(group_exprs_comparison_register + group_by.len() - 1) + dest_end: if group_by.exprs.len() > 1 { + Some(group_exprs_comparison_register + group_by.exprs.len() - 1) } else { None }, @@ -787,7 +787,7 @@ fn open_loop( /// - a ResultRow (there is none of the above, so the loop emits a result row directly) pub enum InnerLoopEmitTarget<'a> { GroupBySorter { - group_by: &'a Vec, + group_by: &'a GroupBy, aggregates: &'a Vec, }, OrderBySorter { @@ -883,7 +883,7 @@ fn inner_loop_source_emit( group_by, aggregates, } => { - let sort_keys_count = group_by.len(); + let sort_keys_count = group_by.exprs.len(); let aggregate_arguments_count = aggregates.iter().map(|agg| agg.args.len()).sum::(); let column_count = sort_keys_count + aggregate_arguments_count; @@ -891,7 +891,7 @@ fn inner_loop_source_emit( let mut cur_reg = start_reg; // The group by sorter rows will contain the grouping keys first. They are also the sort keys. - for expr in group_by.iter() { + for expr in group_by.exprs.iter() { let key_reg = cur_reg; cur_reg += 1; translate_expr(program, Some(referenced_tables), expr, key_reg, None)?; @@ -1127,7 +1127,7 @@ fn close_loop( fn group_by_emit( program: &mut ProgramBuilder, result_columns: &Vec, - group_by: &Vec, + group_by: &GroupBy, order_by: Option<&Vec<(ast::Expr, Direction)>>, aggregates: &Vec, limit: Option, @@ -1156,7 +1156,7 @@ fn group_by_emit( // all group by columns and all arguments of agg functions are in the sorter. // the sort keys are the group by columns (the aggregation within groups is done based on how long the sort keys remain the same) let sorter_column_count = - group_by.len() + aggregates.iter().map(|agg| agg.args.len()).sum::(); + group_by.exprs.len() + aggregates.iter().map(|agg| agg.args.len()).sum::(); // sorter column names do not matter let pseudo_columns = (0..sorter_column_count) .map(|i| Column { @@ -1197,8 +1197,8 @@ fn group_by_emit( }); // Read the group by columns from the pseudo cursor - let groups_start_reg = program.alloc_registers(group_by.len()); - for i in 0..group_by.len() { + let groups_start_reg = program.alloc_registers(group_by.exprs.len()); + for i in 0..group_by.exprs.len() { let sorter_column_index = i; let group_reg = groups_start_reg + i; program.emit_insn(Insn::Column { @@ -1212,7 +1212,7 @@ fn group_by_emit( program.emit_insn(Insn::Compare { start_reg_a: comparison_register, start_reg_b: groups_start_reg, - count: group_by.len(), + count: group_by.exprs.len(), }); let agg_step_label = program.allocate_label(); @@ -1235,7 +1235,7 @@ fn group_by_emit( program.emit_insn(Insn::Move { source_reg: groups_start_reg, dest_reg: comparison_register, - count: group_by.len(), + count: group_by.exprs.len(), }); program.add_comment( @@ -1272,7 +1272,7 @@ fn group_by_emit( // Accumulate the values into the aggregations program.resolve_label(agg_step_label, program.offset()); let start_reg = metadata.aggregation_start_register.unwrap(); - let mut cursor_index = group_by.len(); + let mut cursor_index = group_by.exprs.len(); for (i, agg) in aggregates.iter().enumerate() { let agg_result_reg = start_reg + i; translate_aggregation_groupby( @@ -1301,7 +1301,7 @@ fn group_by_emit( ); // Read the group by columns for a finished group - for i in 0..group_by.len() { + for i in 0..group_by.exprs.len() { let key_reg = group_exprs_start_register + i; let sorter_column_index = i; program.emit_insn(Insn::Column { @@ -1369,6 +1369,11 @@ fn group_by_emit( }, termination_label, ); + let group_by_end_without_emitting_row_label = program.allocate_label(); + program.defer_label_resolution( + group_by_end_without_emitting_row_label, + program.offset() as usize, + ); program.emit_insn(Insn::Return { return_reg: group_by_metadata.subroutine_accumulator_output_return_offset_register, }); @@ -1390,14 +1395,31 @@ fn group_by_emit( // and the agg results in (agg_start_reg..agg_start_reg + aggregates.len() - 1) // we need to call translate_expr on each result column, but replace the expr with a register copy in case any part of the // result column expression matches a) a group by column or b) an aggregation result. - let mut precomputed_exprs_to_register = Vec::with_capacity(aggregates.len() + group_by.len()); - for (i, expr) in group_by.iter().enumerate() { + let mut precomputed_exprs_to_register = + Vec::with_capacity(aggregates.len() + group_by.exprs.len()); + for (i, expr) in group_by.exprs.iter().enumerate() { precomputed_exprs_to_register.push((expr, group_exprs_start_register + i)); } for (i, agg) in aggregates.iter().enumerate() { precomputed_exprs_to_register.push((&agg.original_expr, agg_start_reg + i)); } + if let Some(having) = &group_by.having { + for expr in having.iter() { + translate_condition_expr( + program, + referenced_tables, + expr, + ConditionMetadata { + jump_if_condition_is_true: false, + jump_target_when_false: group_by_end_without_emitting_row_label, + jump_target_when_true: i64::MAX, // unused + }, + Some(&precomputed_exprs_to_register), + )?; + } + } + match order_by { None => { emit_select_result( @@ -1433,7 +1455,7 @@ fn group_by_emit( let start_reg = group_by_metadata.group_exprs_accumulator_register; program.emit_insn(Insn::Null { dest: start_reg, - dest_end: Some(start_reg + group_by.len() + aggregates.len() - 1), + dest_end: Some(start_reg + group_by.exprs.len() + aggregates.len() - 1), }); program.emit_insn(Insn::Integer { diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 6c0b4437d..ac497183a 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -871,7 +871,6 @@ pub fn translate_expr( for arg in args.iter() { let reg = program.alloc_register(); start_reg = Some(start_reg.unwrap_or(reg)); - translate_expr( program, referenced_tables, diff --git a/core/translate/plan.rs b/core/translate/plan.rs index ef5d97948..8c402a6a0 100644 --- a/core/translate/plan.rs +++ b/core/translate/plan.rs @@ -19,6 +19,13 @@ pub struct ResultSetColumn { pub contains_aggregates: bool, } +#[derive(Debug)] +pub struct GroupBy { + pub exprs: Vec, + /// having clause split into a vec at 'AND' boundaries. + pub having: Option>, +} + #[derive(Debug)] pub struct Plan { /// A tree of sources (tables). @@ -28,7 +35,7 @@ pub struct Plan { /// where clause split into a vec at 'AND' boundaries. pub where_clause: Option>, /// group by clause - pub group_by: Option>, + pub group_by: Option, /// order by clause pub order_by: Option>, /// all the aggregates collected from the result columns, order by, and (TODO) having clauses diff --git a/core/translate/planner.rs b/core/translate/planner.rs index 51706f108..8a0d88890 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -1,5 +1,5 @@ use super::plan::{ - Aggregate, BTreeTableReference, Direction, Plan, ResultSetColumn, SourceOperator, + Aggregate, BTreeTableReference, Direction, GroupBy, Plan, ResultSetColumn, SourceOperator, }; use crate::{function::Func, schema::Schema, util::normalize_ident, Result}; use sqlite3_parser::ast::{self, FromClause, JoinType, ResultColumn}; @@ -19,9 +19,9 @@ impl OperatorIdCounter { } } -fn resolve_aggregates(expr: &ast::Expr, aggs: &mut Vec) { +fn resolve_aggregates(expr: &ast::Expr, aggs: &mut Vec) -> bool { if aggs.iter().any(|a| a.original_expr == *expr) { - return; + return true; } match expr { ast::Expr::FunctionCall { name, args, .. } => { @@ -31,17 +31,22 @@ fn resolve_aggregates(expr: &ast::Expr, aggs: &mut Vec) { 0 }; match Func::resolve_function(normalize_ident(name.0.as_str()).as_str(), args_count) { - Ok(Func::Agg(f)) => aggs.push(Aggregate { - func: f, - args: args.clone().unwrap_or_default(), - original_expr: expr.clone(), - }), + Ok(Func::Agg(f)) => { + aggs.push(Aggregate { + func: f, + args: args.clone().unwrap_or_default(), + original_expr: expr.clone(), + }); + true + } _ => { + let mut contains_aggregates = false; if let Some(args) = args { for arg in args.iter() { - resolve_aggregates(arg, aggs); + contains_aggregates |= resolve_aggregates(arg, aggs); } } + contains_aggregates } } } @@ -53,15 +58,20 @@ fn resolve_aggregates(expr: &ast::Expr, aggs: &mut Vec) { func: f, args: vec![], original_expr: expr.clone(), - }) + }); + true + } else { + false } } ast::Expr::Binary(lhs, _, rhs) => { - resolve_aggregates(lhs, aggs); - resolve_aggregates(rhs, aggs); + let mut contains_aggregates = false; + contains_aggregates |= resolve_aggregates(lhs, aggs); + contains_aggregates |= resolve_aggregates(rhs, aggs); + contains_aggregates } // TODO: handle other expressions that may contain aggregates - _ => {} + _ => false, } } @@ -340,10 +350,8 @@ pub fn prepare_select_plan<'a>(schema: &Schema, select: ast::Select) -> Result

{ - let cur_agg_count = aggregate_expressions.len(); - resolve_aggregates(&expr, &mut aggregate_expressions); let contains_aggregates = - cur_agg_count != aggregate_expressions.len(); + resolve_aggregates(&expr, &mut aggregate_expressions); plan.result_columns.push(ResultSetColumn { expr: expr.clone(), contains_aggregates, @@ -380,10 +388,8 @@ pub fn prepare_select_plan<'a>(schema: &Schema, select: ast::Select) -> Result

{ - let cur_agg_count = aggregate_expressions.len(); - resolve_aggregates(expr, &mut aggregate_expressions); let contains_aggregates = - cur_agg_count != aggregate_expressions.len(); + resolve_aggregates(expr, &mut aggregate_expressions); plan.result_columns.push(ResultSetColumn { expr: expr.clone(), contains_aggregates, @@ -393,18 +399,37 @@ pub fn prepare_select_plan<'a>(schema: &Schema, select: ast::Select) -> Result

for OwnedValue { (OwnedValue::Null, _) => Some(std::cmp::Ordering::Less), (_, OwnedValue::Null) => Some(std::cmp::Ordering::Greater), (OwnedValue::Agg(a), OwnedValue::Agg(b)) => a.partial_cmp(b), - _ => None, + (OwnedValue::Agg(a), other) => a.final_value().partial_cmp(other), + (other, OwnedValue::Agg(b)) => other.partial_cmp(b.final_value()), + other => todo!("{:?}", other), } } } diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index 906f43f73..a42b449cd 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -2189,7 +2189,12 @@ impl Program { } ScalarFunc::Round => { let reg_value = state.registers[*start_reg].clone(); - let precision_value = state.registers.get(*start_reg + 1).cloned(); + assert!(arg_count == 1 || arg_count == 2); + let precision_value = if arg_count > 1 { + Some(state.registers[*start_reg + 1].clone()) + } else { + None + }; let result = exec_round(®_value, precision_value); state.registers[*dest] = result; } @@ -2554,7 +2559,10 @@ fn exec_concat(registers: &[OwnedValue]) -> OwnedValue { OwnedValue::Text(text) => result.push_str(text), OwnedValue::Integer(i) => result.push_str(&i.to_string()), OwnedValue::Float(f) => result.push_str(&f.to_string()), - _ => continue, + OwnedValue::Agg(aggctx) => result.push_str(&aggctx.final_value().to_string()), + OwnedValue::Null => continue, + OwnedValue::Blob(_) => todo!("TODO concat blob"), + OwnedValue::Record(_) => unreachable!(), } } OwnedValue::Text(Rc::new(result)) @@ -2909,20 +2917,27 @@ fn exec_unicode(reg: &OwnedValue) -> OwnedValue { } } +fn _to_float(reg: &OwnedValue) -> f64 { + match reg { + OwnedValue::Text(x) => x.parse().unwrap_or(0.0), + OwnedValue::Integer(x) => *x as f64, + OwnedValue::Float(x) => *x, + _ => 0.0, + } +} + fn exec_round(reg: &OwnedValue, precision: Option) -> OwnedValue { let precision = match precision { Some(OwnedValue::Text(x)) => x.parse().unwrap_or(0.0), Some(OwnedValue::Integer(x)) => x as f64, Some(OwnedValue::Float(x)) => x, - None => 0.0, - _ => return OwnedValue::Null, + Some(OwnedValue::Null) => return OwnedValue::Null, + _ => 0.0, }; let reg = match reg { - OwnedValue::Text(x) => x.parse().unwrap_or(0.0), - OwnedValue::Integer(x) => *x as f64, - OwnedValue::Float(x) => *x, - _ => return reg.to_owned(), + OwnedValue::Agg(ctx) => _to_float(ctx.final_value()), + _ => _to_float(reg), }; let precision = if precision < 1.0 { 0.0 } else { precision }; @@ -3763,6 +3778,14 @@ mod tests { let precision_val = OwnedValue::Integer(1); let expected_val = OwnedValue::Float(123.0); assert_eq!(exec_round(&input_val, Some(precision_val)), expected_val); + + let input_val = OwnedValue::Float(100.123); + let expected_val = OwnedValue::Float(100.0); + assert_eq!(exec_round(&input_val, None), expected_val); + + let input_val = OwnedValue::Float(100.123); + let expected_val = OwnedValue::Null; + assert_eq!(exec_round(&input_val, Some(OwnedValue::Null)), expected_val); } #[test] diff --git a/testing/groupby.test b/testing/groupby.test index b3d873110..34cf802af 100644 --- a/testing/groupby.test +++ b/testing/groupby.test @@ -130,4 +130,35 @@ do_execsql_test group_by_function_expression_ridiculous { do_execsql_test group_by_count_star { select u.first_name, count(*) from users u group by u.first_name limit 1; -} {Aaron|41} \ No newline at end of file +} {Aaron|41} + +do_execsql_test having { + select u.first_name, round(avg(u.age)) from users u group by u.first_name having avg(u.age) > 97 order by avg(u.age) desc limit 5; +} {Nina|100.0 +Kurt|99.0 +Selena|98.0} + +do_execsql_test having_with_binary_cond { + select u.first_name, sum(u.age) from users u group by u.first_name having sum(u.age) + 1000 = 9109; +} {Robert|8109} + +do_execsql_test having_with_scalar_fn_over_aggregate { + select u.first_name, concat(count(1), ' people with this name') from users u group by u.first_name having count(1) > 50 order by count(1) asc limit 5; +} {"Angela|51 people with this name +Justin|51 people with this name +Rachel|52 people with this name +Susan|52 people with this name +Jeffrey|54 people with this name"} + +do_execsql_test having_with_multiple_conditions { + select u.first_name, count(*), round(avg(u.age)) as avg_age + from users u + group by u.first_name + having count(*) > 40 and avg(u.age) > 40 + order by count(*) desc, avg(u.age) desc + limit 5; +} {Michael|228|49.0 +David|165|53.0 +Robert|159|51.0 +Jennifer|151|51.0 +John|145|50.0} \ No newline at end of file From 3e9883bfbdb0408feeb35562a2c7cd8cff2dc4d1 Mon Sep 17 00:00:00 2001 From: jussisaurio Date: Sat, 30 Nov 2024 10:06:37 +0200 Subject: [PATCH 2/2] update COMPAT --- COMPAT.md | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/COMPAT.md b/COMPAT.md index e909c794e..b106b9be9 100644 --- a/COMPAT.md +++ b/COMPAT.md @@ -2,11 +2,16 @@ This document describes the SQLite compatibility status of Limbo: -* [Limitations](#limitations) -* [SQL statements](#sql-statements) -* [SQL functions](#sql-functions) -* [SQLite API](#sqlite-api) -* [SQLite VDBE opcodes](#sqlite-vdbe-opcodes) +- [SQLite Compatibility](#sqlite-compatibility) + - [Limitations](#limitations) + - [SQL statements](#sql-statements) + - [SQL functions](#sql-functions) + - [Scalar functions](#scalar-functions) + - [Aggregate functions](#aggregate-functions) + - [Date and time functions](#date-and-time-functions) + - [JSON functions](#json-functions) + - [SQLite API](#sqlite-api) + - [SQLite VDBE opcodes](#sqlite-vdbe-opcodes) ## Limitations @@ -51,6 +56,7 @@ This document describes the SQLite compatibility status of Limbo: | SELECT ... LIMIT | Yes | | | SELECT ... ORDER BY | Partial | | | SELECT ... GROUP BY | Partial | | +| SELECT ... HAVING | Partial | | | SELECT ... JOIN | Partial | | | SELECT ... CROSS JOIN | Partial | | | SELECT ... INNER JOIN | Partial | |