From 2fde9766052ddb27f097a4d9756a2c20af4a8ce5 Mon Sep 17 00:00:00 2001 From: Glauber Costa Date: Sat, 27 Sep 2025 12:03:59 -0300 Subject: [PATCH] Fix materialized views with complex expressions SQLite supports complex expressions in group by columns - because of course it does... So we need to make sure that a column is created for this expression if it doesn't exist already, and compute it, the same way we compute pre-projections in the filter operator. Fixes #3363 Fixes #3366 Fixes #3365 --- core/translate/logical.rs | 258 +++++++++++++++++++++++++++++++- testing/materialized_views.test | 79 ++++++++++ 2 files changed, 330 insertions(+), 7 deletions(-) diff --git a/core/translate/logical.rs b/core/translate/logical.rs index 9c84a47b5..d7d65da95 100644 --- a/core/translate/logical.rs +++ b/core/translate/logical.rs @@ -23,6 +23,7 @@ type PreprocessAggregateResult = ( Vec, // pre_projection_exprs Vec, // pre_projection_schema Vec, // modified_aggr_exprs + Vec, // modified_group_exprs ); /// Result type for parsing join conditions @@ -385,6 +386,23 @@ impl Display for Column { } } +/// Strip alias wrapper from an expression, returning the underlying expression. +/// This is useful when comparing expressions where one might be aliased and the other not, +/// such as when matching SELECT expressions with GROUP BY expressions. +/// +/// # Examples +/// ```ignore +/// let aliased = LogicalExpr::Alias { expr: Box::new(col_expr), alias: "my_alias".to_string() }; +/// let stripped = strip_alias(&aliased); +/// assert_eq!(stripped, &col_expr); +/// ``` +pub fn strip_alias(expr: &LogicalExpr) -> &LogicalExpr { + match expr { + LogicalExpr::Alias { expr, .. } => expr, + _ => expr, + } +} + /// Type alias for binary operators pub type BinaryOperator = ast::Operator; @@ -992,6 +1010,7 @@ impl<'a> LogicalPlanBuilder<'a> { let mut pre_projection_exprs = Vec::new(); let mut pre_projection_schema = Vec::new(); let mut modified_aggr_exprs = Vec::new(); + let mut modified_group_exprs = Vec::new(); let mut projected_col_counter = 0; // First, add all group by expressions to the pre-projection @@ -1006,6 +1025,8 @@ impl<'a> LogicalPlanBuilder<'a> { table: col.table.clone(), table_alias: None, }); + // Column references stay as-is in the modified group expressions + modified_group_exprs.push(expr.clone()); } else { // Complex group by expression - project it needs_pre_projection = true; @@ -1020,6 +1041,11 @@ impl<'a> LogicalPlanBuilder<'a> { table: None, table_alias: None, }); + // Replace complex expression with reference to projected column + modified_group_exprs.push(LogicalExpr::Column(Column { + name: proj_col_name, + table: None, + })); } } @@ -1092,6 +1118,7 @@ impl<'a> LogicalPlanBuilder<'a> { pre_projection_exprs, pre_projection_schema, modified_aggr_exprs, + modified_group_exprs, )) } @@ -1134,6 +1161,52 @@ impl<'a> LogicalPlanBuilder<'a> { let input_schema = input.schema(); let has_group_by = !group_exprs.is_empty(); + // First pass: build a map of aliases to expressions from the SELECT list + // and a vector of SELECT expressions for positional references + // This allows GROUP BY to reference SELECT aliases (e.g., GROUP BY year) + // or positions (e.g., GROUP BY 1) + let mut alias_to_expr = HashMap::new(); + let mut select_exprs = Vec::new(); + for col in columns { + if let ast::ResultColumn::Expr(expr, alias) = col { + let logical_expr = self.build_expr(expr, input_schema)?; + select_exprs.push(logical_expr.clone()); + + if let Some(alias) = alias { + let alias_name = match alias { + ast::As::As(name) | ast::As::Elided(name) => Self::name_to_string(name), + }; + alias_to_expr.insert(alias_name, logical_expr); + } + } + } + + // Resolve GROUP BY expressions: replace column references that match SELECT aliases + // or integer literals that represent positions + let group_exprs = group_exprs + .into_iter() + .map(|expr| { + // Check for positional reference (integer literal) + if let LogicalExpr::Literal(crate::types::Value::Integer(pos)) = &expr { + // SQLite uses 1-based indexing + if *pos > 0 && (*pos as usize) <= select_exprs.len() { + return select_exprs[(*pos as usize) - 1].clone(); + } + } + + // Check for alias reference (unqualified column name) + if let LogicalExpr::Column(col) = &expr { + if col.table.is_none() { + // Unqualified column - check if it matches an alias + if let Some(aliased_expr) = alias_to_expr.get(&col.name) { + return aliased_expr.clone(); + } + } + } + expr + }) + .collect::>(); + // Build aggregate expressions and projection expressions let mut aggr_exprs = Vec::new(); let mut projection_exprs = Vec::new(); @@ -1178,8 +1251,7 @@ impl<'a> LogicalPlanBuilder<'a> { } // Track aggregates we've already seen to avoid duplicates - let mut aggregate_map: std::collections::HashMap = - std::collections::HashMap::new(); + let mut aggregate_map: HashMap = HashMap::new(); for col in columns { match col { @@ -1264,6 +1336,23 @@ impl<'a> LogicalPlanBuilder<'a> { "Column '{col_name}' must appear in the GROUP BY clause or be used in an aggregate function" ))); } + + // If this expression matches a GROUP BY expression, replace it with a reference + // to the corresponding column in the aggregate output + let logical_expr_stripped = strip_alias(&logical_expr); + if let Some(group_idx) = group_exprs + .iter() + .position(|g| logical_expr_stripped == strip_alias(g)) + { + // Reference the GROUP BY column in the aggregate output by its name + let group_col_name = &aggregate_schema_columns[group_idx].name; + projection_exprs.push(LogicalExpr::Column(Column { + name: group_col_name.clone(), + table: None, + })); + } else { + projection_exprs.push(logical_expr); + } } else { // Without GROUP BY: only allow constant expressions // TODO: SQLite allows any column here and returns a value from an @@ -1274,8 +1363,8 @@ impl<'a> LogicalPlanBuilder<'a> { "Column '{col_name}' must be used in an aggregate function when using aggregates without GROUP BY" ))); } + projection_exprs.push(logical_expr); } - projection_exprs.push(logical_expr); } } _ => { @@ -1290,12 +1379,14 @@ impl<'a> LogicalPlanBuilder<'a> { } // Check if any aggregate functions have complex expressions as arguments + // or if GROUP BY has complex expressions // If so, we need to insert a projection before the aggregate let ( needs_pre_projection, pre_projection_exprs, pre_projection_schema, modified_aggr_exprs, + modified_group_exprs, ) = Self::preprocess_aggregate_expressions(&aggr_exprs, &group_exprs, input_schema)?; // Build the final schema for the projection @@ -1338,12 +1429,17 @@ impl<'a> LogicalPlanBuilder<'a> { Arc::new(input) }; - // Use modified aggregate expressions if we inserted a pre-projection + // Use modified aggregate and group expressions if we inserted a pre-projection let final_aggr_exprs = if needs_pre_projection { modified_aggr_exprs } else { aggr_exprs }; + let final_group_exprs = if needs_pre_projection { + modified_group_exprs + } else { + group_exprs + }; // Check if we need the outer projection // We need a projection if: @@ -1384,7 +1480,7 @@ impl<'a> LogicalPlanBuilder<'a> { // Create the aggregate node with its natural schema let aggregate_plan = LogicalPlan::Aggregate(Aggregate { input: aggregate_input, - group_expr: group_exprs, + group_expr: final_group_exprs, aggr_expr: final_aggr_exprs, schema: Arc::new(LogicalSchema::new(aggregate_schema_columns)), }); @@ -1957,6 +2053,14 @@ impl<'a> LogicalPlanBuilder<'a> { // 2. An aggregate function // 3. A grouping column (or expression involving only grouping columns) fn is_valid_in_group_by(expr: &LogicalExpr, group_exprs: &[LogicalExpr]) -> bool { + // First check if the entire expression appears in GROUP BY + // Strip aliases before comparing since SELECT might have aliases but GROUP BY might not + let expr_stripped = strip_alias(expr); + if group_exprs.iter().any(|g| expr_stripped == strip_alias(g)) { + return true; + } + + // If not, check recursively based on expression type match expr { LogicalExpr::Literal(_) => true, // Constants are always valid LogicalExpr::AggregateFunction { .. } => true, // Aggregates are valid @@ -1987,7 +2091,7 @@ impl<'a> LogicalPlanBuilder<'a> { // Returns the modified expression and a list of NEW (aggregate_expr, column_name) pairs fn extract_and_replace_aggregates_with_dedup( expr: LogicalExpr, - aggregate_map: &mut std::collections::HashMap, + aggregate_map: &mut HashMap, ) -> Result<(LogicalExpr, Vec<(LogicalExpr, String)>)> { let mut new_aggregates = Vec::new(); let mut counter = aggregate_map.len(); @@ -2004,7 +2108,7 @@ impl<'a> LogicalPlanBuilder<'a> { fn replace_aggregates_with_columns_dedup( expr: LogicalExpr, new_aggregates: &mut Vec<(LogicalExpr, String)>, - aggregate_map: &mut std::collections::HashMap, + aggregate_map: &mut HashMap, counter: &mut usize, ) -> Result { match expr { @@ -3962,4 +4066,144 @@ mod tests { } } } + + // Tests for strip_alias function + #[test] + fn test_strip_alias_with_alias() { + let inner_expr = LogicalExpr::Column(Column::new("test")); + let aliased = LogicalExpr::Alias { + expr: Box::new(inner_expr.clone()), + alias: "my_alias".to_string(), + }; + + let stripped = strip_alias(&aliased); + assert_eq!(stripped, &inner_expr); + } + + #[test] + fn test_strip_alias_without_alias() { + let expr = LogicalExpr::Column(Column::new("test")); + let stripped = strip_alias(&expr); + assert_eq!(stripped, &expr); + } + + #[test] + fn test_strip_alias_literal() { + let expr = LogicalExpr::Literal(Value::Integer(42)); + let stripped = strip_alias(&expr); + assert_eq!(stripped, &expr); + } + + #[test] + fn test_strip_alias_scalar_function() { + let expr = LogicalExpr::ScalarFunction { + fun: "substr".to_string(), + args: vec![ + LogicalExpr::Column(Column::new("name")), + LogicalExpr::Literal(Value::Integer(1)), + LogicalExpr::Literal(Value::Integer(4)), + ], + }; + let stripped = strip_alias(&expr); + assert_eq!(stripped, &expr); + } + + #[test] + fn test_strip_alias_nested_alias() { + // Test that strip_alias only removes the outermost alias + let inner_expr = LogicalExpr::Column(Column::new("test")); + let inner_alias = LogicalExpr::Alias { + expr: Box::new(inner_expr.clone()), + alias: "inner_alias".to_string(), + }; + let outer_alias = LogicalExpr::Alias { + expr: Box::new(inner_alias.clone()), + alias: "outer_alias".to_string(), + }; + + let stripped = strip_alias(&outer_alias); + assert_eq!(stripped, &inner_alias); + + // Stripping again should give us the inner expression + let double_stripped = strip_alias(stripped); + assert_eq!(double_stripped, &inner_expr); + } + + #[test] + fn test_strip_alias_comparison_with_alias() { + // Test that two expressions match when one has an alias and one doesn't + let base_expr = LogicalExpr::ScalarFunction { + fun: "substr".to_string(), + args: vec![ + LogicalExpr::Column(Column::new("orderdate")), + LogicalExpr::Literal(Value::Integer(1)), + LogicalExpr::Literal(Value::Integer(4)), + ], + }; + + let aliased_expr = LogicalExpr::Alias { + expr: Box::new(base_expr.clone()), + alias: "year".to_string(), + }; + + // Without strip_alias, they wouldn't match + assert_ne!(&aliased_expr, &base_expr); + + // With strip_alias, they should match + assert_eq!(strip_alias(&aliased_expr), &base_expr); + assert_eq!(strip_alias(&base_expr), &base_expr); + } + + #[test] + fn test_strip_alias_binary_expr() { + let expr = LogicalExpr::BinaryExpr { + left: Box::new(LogicalExpr::Column(Column::new("a"))), + op: BinaryOperator::Add, + right: Box::new(LogicalExpr::Literal(Value::Integer(1))), + }; + let stripped = strip_alias(&expr); + assert_eq!(stripped, &expr); + } + + #[test] + fn test_strip_alias_aggregate_function() { + let expr = LogicalExpr::AggregateFunction { + fun: AggFunc::Sum, + args: vec![LogicalExpr::Column(Column::new("amount"))], + distinct: false, + }; + let stripped = strip_alias(&expr); + assert_eq!(stripped, &expr); + } + + #[test] + fn test_strip_alias_comparison_multiple_expressions() { + // Test comparing a list of expressions with and without aliases + let expr1 = LogicalExpr::Column(Column::new("a")); + let expr2 = LogicalExpr::ScalarFunction { + fun: "substr".to_string(), + args: vec![ + LogicalExpr::Column(Column::new("b")), + LogicalExpr::Literal(Value::Integer(1)), + LogicalExpr::Literal(Value::Integer(4)), + ], + }; + + let aliased1 = LogicalExpr::Alias { + expr: Box::new(expr1.clone()), + alias: "col_a".to_string(), + }; + let aliased2 = LogicalExpr::Alias { + expr: Box::new(expr2.clone()), + alias: "year".to_string(), + }; + + let select_exprs = [aliased1, aliased2]; + let group_exprs = [expr1.clone(), expr2.clone()]; + + // Verify that stripping aliases allows matching + for (select_expr, group_expr) in select_exprs.iter().zip(group_exprs.iter()) { + assert_eq!(strip_alias(select_expr), group_expr); + } + } } diff --git a/testing/materialized_views.test b/testing/materialized_views.test index 911e50ef8..dd2652d7d 100755 --- a/testing/materialized_views.test +++ b/testing/materialized_views.test @@ -1612,3 +1612,82 @@ do_execsql_test_on_specific_db {:memory:} matview-join-swapped-columns { SELECT * FROM emp_dept ORDER BY name; } {Alice|Engineering Bob|Sales} + +do_execsql_test_on_specific_db {:memory:} matview-groupby-scalar-function { + CREATE TABLE orders(id INTEGER, orderdate TEXT, amount INTEGER); + INSERT INTO orders VALUES (1, '2020-01-15', 100); + INSERT INTO orders VALUES (2, '2020-06-10', 150); + INSERT INTO orders VALUES (3, '2021-03-20', 200); + + CREATE MATERIALIZED VIEW yearly_totals AS + SELECT substr(orderdate, 1, 4), sum(amount) + FROM orders + GROUP BY substr(orderdate, 1, 4); + + SELECT * FROM yearly_totals ORDER BY 1; +} {2020|250 +2021|200} + +do_execsql_test_on_specific_db {:memory:} matview-groupby-alias { + CREATE TABLE orders(id INTEGER, orderdate TEXT, amount INTEGER); + INSERT INTO orders VALUES (1, '2020-01-15', 100); + INSERT INTO orders VALUES (2, '2020-06-10', 150); + INSERT INTO orders VALUES (3, '2021-03-20', 200); + + CREATE MATERIALIZED VIEW yearly_totals AS + SELECT substr(orderdate, 1, 4) as year, sum(amount) as total + FROM orders + GROUP BY year; + + SELECT * FROM yearly_totals ORDER BY year; +} {2020|250 +2021|200} + +do_execsql_test_on_specific_db {:memory:} matview-groupby-position { + CREATE TABLE orders(id INTEGER, orderdate TEXT, amount INTEGER, nation TEXT); + INSERT INTO orders VALUES (1, '2020-01-15', 100, 'USA'); + INSERT INTO orders VALUES (2, '2020-06-10', 150, 'USA'); + INSERT INTO orders VALUES (3, '2021-03-20', 200, 'UK'); + + CREATE MATERIALIZED VIEW national_yearly AS + SELECT nation, substr(orderdate, 1, 4) as year, sum(amount) as total + FROM orders + GROUP BY 1, 2; + + SELECT * FROM national_yearly ORDER BY nation, year; +} {UK|2021|200 +USA|2020|250} + +do_execsql_test_on_specific_db {:memory:} matview-groupby-scalar-incremental { + CREATE TABLE orders(id INTEGER, orderdate TEXT, amount INTEGER); + INSERT INTO orders VALUES (1, '2020-01-15', 100); + + CREATE MATERIALIZED VIEW yearly_totals AS + SELECT substr(orderdate, 1, 4) as year, sum(amount) as total + FROM orders + GROUP BY year; + + SELECT * FROM yearly_totals; + INSERT INTO orders VALUES (2, '2020-06-10', 150); + SELECT * FROM yearly_totals; + INSERT INTO orders VALUES (3, '2021-03-20', 200); + SELECT * FROM yearly_totals ORDER BY year; +} {2020|100 +2020|250 +2020|250 +2021|200} + +do_execsql_test_on_specific_db {:memory:} matview-groupby-join-position { + CREATE TABLE t(a INTEGER); + CREATE TABLE u(a INTEGER); + INSERT INTO t VALUES (1), (2), (3); + INSERT INTO u VALUES (1), (1), (2); + + CREATE MATERIALIZED VIEW tujoingroup AS + SELECT t.a, count(u.a) as cnt + FROM t JOIN u ON t.a = u.a + GROUP BY 1; + + SELECT * FROM tujoingroup ORDER BY a; +} {1|2 +2|1}