Merge 'Fix materialized views with complex expressions' from Glauber Costa

SQLite supports complex expressions in group by columns - because of
course it does...
So we need to make sure that a column is created for this expression if
it doesn't exist already, and compute it, the same way we compute pre-
projections in the filter operator.
Fixes #3363
Fixes #3366
Fixes #3365

Closes #3429
This commit is contained in:
Pekka Enberg
2025-09-30 07:34:51 +03:00
committed by GitHub
2 changed files with 330 additions and 7 deletions

View File

@@ -23,6 +23,7 @@ type PreprocessAggregateResult = (
Vec<LogicalExpr>, // pre_projection_exprs
Vec<ColumnInfo>, // pre_projection_schema
Vec<LogicalExpr>, // modified_aggr_exprs
Vec<LogicalExpr>, // modified_group_exprs
);
/// Result type for parsing join conditions
@@ -385,6 +386,23 @@ impl Display for Column {
}
}
/// Strip alias wrapper from an expression, returning the underlying expression.
/// This is useful when comparing expressions where one might be aliased and the other not,
/// such as when matching SELECT expressions with GROUP BY expressions.
///
/// # Examples
/// ```ignore
/// let aliased = LogicalExpr::Alias { expr: Box::new(col_expr), alias: "my_alias".to_string() };
/// let stripped = strip_alias(&aliased);
/// assert_eq!(stripped, &col_expr);
/// ```
pub fn strip_alias(expr: &LogicalExpr) -> &LogicalExpr {
match expr {
LogicalExpr::Alias { expr, .. } => expr,
_ => expr,
}
}
/// Type alias for binary operators
pub type BinaryOperator = ast::Operator;
@@ -992,6 +1010,7 @@ impl<'a> LogicalPlanBuilder<'a> {
let mut pre_projection_exprs = Vec::new();
let mut pre_projection_schema = Vec::new();
let mut modified_aggr_exprs = Vec::new();
let mut modified_group_exprs = Vec::new();
let mut projected_col_counter = 0;
// First, add all group by expressions to the pre-projection
@@ -1006,6 +1025,8 @@ impl<'a> LogicalPlanBuilder<'a> {
table: col.table.clone(),
table_alias: None,
});
// Column references stay as-is in the modified group expressions
modified_group_exprs.push(expr.clone());
} else {
// Complex group by expression - project it
needs_pre_projection = true;
@@ -1020,6 +1041,11 @@ impl<'a> LogicalPlanBuilder<'a> {
table: None,
table_alias: None,
});
// Replace complex expression with reference to projected column
modified_group_exprs.push(LogicalExpr::Column(Column {
name: proj_col_name,
table: None,
}));
}
}
@@ -1092,6 +1118,7 @@ impl<'a> LogicalPlanBuilder<'a> {
pre_projection_exprs,
pre_projection_schema,
modified_aggr_exprs,
modified_group_exprs,
))
}
@@ -1134,6 +1161,52 @@ impl<'a> LogicalPlanBuilder<'a> {
let input_schema = input.schema();
let has_group_by = !group_exprs.is_empty();
// First pass: build a map of aliases to expressions from the SELECT list
// and a vector of SELECT expressions for positional references
// This allows GROUP BY to reference SELECT aliases (e.g., GROUP BY year)
// or positions (e.g., GROUP BY 1)
let mut alias_to_expr = HashMap::new();
let mut select_exprs = Vec::new();
for col in columns {
if let ast::ResultColumn::Expr(expr, alias) = col {
let logical_expr = self.build_expr(expr, input_schema)?;
select_exprs.push(logical_expr.clone());
if let Some(alias) = alias {
let alias_name = match alias {
ast::As::As(name) | ast::As::Elided(name) => Self::name_to_string(name),
};
alias_to_expr.insert(alias_name, logical_expr);
}
}
}
// Resolve GROUP BY expressions: replace column references that match SELECT aliases
// or integer literals that represent positions
let group_exprs = group_exprs
.into_iter()
.map(|expr| {
// Check for positional reference (integer literal)
if let LogicalExpr::Literal(crate::types::Value::Integer(pos)) = &expr {
// SQLite uses 1-based indexing
if *pos > 0 && (*pos as usize) <= select_exprs.len() {
return select_exprs[(*pos as usize) - 1].clone();
}
}
// Check for alias reference (unqualified column name)
if let LogicalExpr::Column(col) = &expr {
if col.table.is_none() {
// Unqualified column - check if it matches an alias
if let Some(aliased_expr) = alias_to_expr.get(&col.name) {
return aliased_expr.clone();
}
}
}
expr
})
.collect::<Vec<_>>();
// Build aggregate expressions and projection expressions
let mut aggr_exprs = Vec::new();
let mut projection_exprs = Vec::new();
@@ -1178,8 +1251,7 @@ impl<'a> LogicalPlanBuilder<'a> {
}
// Track aggregates we've already seen to avoid duplicates
let mut aggregate_map: std::collections::HashMap<String, String> =
std::collections::HashMap::new();
let mut aggregate_map: HashMap<String, String> = HashMap::new();
for col in columns {
match col {
@@ -1264,6 +1336,23 @@ impl<'a> LogicalPlanBuilder<'a> {
"Column '{col_name}' must appear in the GROUP BY clause or be used in an aggregate function"
)));
}
// If this expression matches a GROUP BY expression, replace it with a reference
// to the corresponding column in the aggregate output
let logical_expr_stripped = strip_alias(&logical_expr);
if let Some(group_idx) = group_exprs
.iter()
.position(|g| logical_expr_stripped == strip_alias(g))
{
// Reference the GROUP BY column in the aggregate output by its name
let group_col_name = &aggregate_schema_columns[group_idx].name;
projection_exprs.push(LogicalExpr::Column(Column {
name: group_col_name.clone(),
table: None,
}));
} else {
projection_exprs.push(logical_expr);
}
} else {
// Without GROUP BY: only allow constant expressions
// TODO: SQLite allows any column here and returns a value from an
@@ -1274,8 +1363,8 @@ impl<'a> LogicalPlanBuilder<'a> {
"Column '{col_name}' must be used in an aggregate function when using aggregates without GROUP BY"
)));
}
projection_exprs.push(logical_expr);
}
projection_exprs.push(logical_expr);
}
}
_ => {
@@ -1290,12 +1379,14 @@ impl<'a> LogicalPlanBuilder<'a> {
}
// Check if any aggregate functions have complex expressions as arguments
// or if GROUP BY has complex expressions
// If so, we need to insert a projection before the aggregate
let (
needs_pre_projection,
pre_projection_exprs,
pre_projection_schema,
modified_aggr_exprs,
modified_group_exprs,
) = Self::preprocess_aggregate_expressions(&aggr_exprs, &group_exprs, input_schema)?;
// Build the final schema for the projection
@@ -1338,12 +1429,17 @@ impl<'a> LogicalPlanBuilder<'a> {
Arc::new(input)
};
// Use modified aggregate expressions if we inserted a pre-projection
// Use modified aggregate and group expressions if we inserted a pre-projection
let final_aggr_exprs = if needs_pre_projection {
modified_aggr_exprs
} else {
aggr_exprs
};
let final_group_exprs = if needs_pre_projection {
modified_group_exprs
} else {
group_exprs
};
// Check if we need the outer projection
// We need a projection if:
@@ -1384,7 +1480,7 @@ impl<'a> LogicalPlanBuilder<'a> {
// Create the aggregate node with its natural schema
let aggregate_plan = LogicalPlan::Aggregate(Aggregate {
input: aggregate_input,
group_expr: group_exprs,
group_expr: final_group_exprs,
aggr_expr: final_aggr_exprs,
schema: Arc::new(LogicalSchema::new(aggregate_schema_columns)),
});
@@ -1957,6 +2053,14 @@ impl<'a> LogicalPlanBuilder<'a> {
// 2. An aggregate function
// 3. A grouping column (or expression involving only grouping columns)
fn is_valid_in_group_by(expr: &LogicalExpr, group_exprs: &[LogicalExpr]) -> bool {
// First check if the entire expression appears in GROUP BY
// Strip aliases before comparing since SELECT might have aliases but GROUP BY might not
let expr_stripped = strip_alias(expr);
if group_exprs.iter().any(|g| expr_stripped == strip_alias(g)) {
return true;
}
// If not, check recursively based on expression type
match expr {
LogicalExpr::Literal(_) => true, // Constants are always valid
LogicalExpr::AggregateFunction { .. } => true, // Aggregates are valid
@@ -1987,7 +2091,7 @@ impl<'a> LogicalPlanBuilder<'a> {
// Returns the modified expression and a list of NEW (aggregate_expr, column_name) pairs
fn extract_and_replace_aggregates_with_dedup(
expr: LogicalExpr,
aggregate_map: &mut std::collections::HashMap<String, String>,
aggregate_map: &mut HashMap<String, String>,
) -> Result<(LogicalExpr, Vec<(LogicalExpr, String)>)> {
let mut new_aggregates = Vec::new();
let mut counter = aggregate_map.len();
@@ -2004,7 +2108,7 @@ impl<'a> LogicalPlanBuilder<'a> {
fn replace_aggregates_with_columns_dedup(
expr: LogicalExpr,
new_aggregates: &mut Vec<(LogicalExpr, String)>,
aggregate_map: &mut std::collections::HashMap<String, String>,
aggregate_map: &mut HashMap<String, String>,
counter: &mut usize,
) -> Result<LogicalExpr> {
match expr {
@@ -3962,4 +4066,144 @@ mod tests {
}
}
}
// Tests for strip_alias function
#[test]
fn test_strip_alias_with_alias() {
let inner_expr = LogicalExpr::Column(Column::new("test"));
let aliased = LogicalExpr::Alias {
expr: Box::new(inner_expr.clone()),
alias: "my_alias".to_string(),
};
let stripped = strip_alias(&aliased);
assert_eq!(stripped, &inner_expr);
}
#[test]
fn test_strip_alias_without_alias() {
let expr = LogicalExpr::Column(Column::new("test"));
let stripped = strip_alias(&expr);
assert_eq!(stripped, &expr);
}
#[test]
fn test_strip_alias_literal() {
let expr = LogicalExpr::Literal(Value::Integer(42));
let stripped = strip_alias(&expr);
assert_eq!(stripped, &expr);
}
#[test]
fn test_strip_alias_scalar_function() {
let expr = LogicalExpr::ScalarFunction {
fun: "substr".to_string(),
args: vec![
LogicalExpr::Column(Column::new("name")),
LogicalExpr::Literal(Value::Integer(1)),
LogicalExpr::Literal(Value::Integer(4)),
],
};
let stripped = strip_alias(&expr);
assert_eq!(stripped, &expr);
}
#[test]
fn test_strip_alias_nested_alias() {
// Test that strip_alias only removes the outermost alias
let inner_expr = LogicalExpr::Column(Column::new("test"));
let inner_alias = LogicalExpr::Alias {
expr: Box::new(inner_expr.clone()),
alias: "inner_alias".to_string(),
};
let outer_alias = LogicalExpr::Alias {
expr: Box::new(inner_alias.clone()),
alias: "outer_alias".to_string(),
};
let stripped = strip_alias(&outer_alias);
assert_eq!(stripped, &inner_alias);
// Stripping again should give us the inner expression
let double_stripped = strip_alias(stripped);
assert_eq!(double_stripped, &inner_expr);
}
#[test]
fn test_strip_alias_comparison_with_alias() {
// Test that two expressions match when one has an alias and one doesn't
let base_expr = LogicalExpr::ScalarFunction {
fun: "substr".to_string(),
args: vec![
LogicalExpr::Column(Column::new("orderdate")),
LogicalExpr::Literal(Value::Integer(1)),
LogicalExpr::Literal(Value::Integer(4)),
],
};
let aliased_expr = LogicalExpr::Alias {
expr: Box::new(base_expr.clone()),
alias: "year".to_string(),
};
// Without strip_alias, they wouldn't match
assert_ne!(&aliased_expr, &base_expr);
// With strip_alias, they should match
assert_eq!(strip_alias(&aliased_expr), &base_expr);
assert_eq!(strip_alias(&base_expr), &base_expr);
}
#[test]
fn test_strip_alias_binary_expr() {
let expr = LogicalExpr::BinaryExpr {
left: Box::new(LogicalExpr::Column(Column::new("a"))),
op: BinaryOperator::Add,
right: Box::new(LogicalExpr::Literal(Value::Integer(1))),
};
let stripped = strip_alias(&expr);
assert_eq!(stripped, &expr);
}
#[test]
fn test_strip_alias_aggregate_function() {
let expr = LogicalExpr::AggregateFunction {
fun: AggFunc::Sum,
args: vec![LogicalExpr::Column(Column::new("amount"))],
distinct: false,
};
let stripped = strip_alias(&expr);
assert_eq!(stripped, &expr);
}
#[test]
fn test_strip_alias_comparison_multiple_expressions() {
// Test comparing a list of expressions with and without aliases
let expr1 = LogicalExpr::Column(Column::new("a"));
let expr2 = LogicalExpr::ScalarFunction {
fun: "substr".to_string(),
args: vec![
LogicalExpr::Column(Column::new("b")),
LogicalExpr::Literal(Value::Integer(1)),
LogicalExpr::Literal(Value::Integer(4)),
],
};
let aliased1 = LogicalExpr::Alias {
expr: Box::new(expr1.clone()),
alias: "col_a".to_string(),
};
let aliased2 = LogicalExpr::Alias {
expr: Box::new(expr2.clone()),
alias: "year".to_string(),
};
let select_exprs = [aliased1, aliased2];
let group_exprs = [expr1.clone(), expr2.clone()];
// Verify that stripping aliases allows matching
for (select_expr, group_expr) in select_exprs.iter().zip(group_exprs.iter()) {
assert_eq!(strip_alias(select_expr), group_expr);
}
}
}

View File

@@ -1612,3 +1612,82 @@ do_execsql_test_on_specific_db {:memory:} matview-join-swapped-columns {
SELECT * FROM emp_dept ORDER BY name;
} {Alice|Engineering
Bob|Sales}
do_execsql_test_on_specific_db {:memory:} matview-groupby-scalar-function {
CREATE TABLE orders(id INTEGER, orderdate TEXT, amount INTEGER);
INSERT INTO orders VALUES (1, '2020-01-15', 100);
INSERT INTO orders VALUES (2, '2020-06-10', 150);
INSERT INTO orders VALUES (3, '2021-03-20', 200);
CREATE MATERIALIZED VIEW yearly_totals AS
SELECT substr(orderdate, 1, 4), sum(amount)
FROM orders
GROUP BY substr(orderdate, 1, 4);
SELECT * FROM yearly_totals ORDER BY 1;
} {2020|250
2021|200}
do_execsql_test_on_specific_db {:memory:} matview-groupby-alias {
CREATE TABLE orders(id INTEGER, orderdate TEXT, amount INTEGER);
INSERT INTO orders VALUES (1, '2020-01-15', 100);
INSERT INTO orders VALUES (2, '2020-06-10', 150);
INSERT INTO orders VALUES (3, '2021-03-20', 200);
CREATE MATERIALIZED VIEW yearly_totals AS
SELECT substr(orderdate, 1, 4) as year, sum(amount) as total
FROM orders
GROUP BY year;
SELECT * FROM yearly_totals ORDER BY year;
} {2020|250
2021|200}
do_execsql_test_on_specific_db {:memory:} matview-groupby-position {
CREATE TABLE orders(id INTEGER, orderdate TEXT, amount INTEGER, nation TEXT);
INSERT INTO orders VALUES (1, '2020-01-15', 100, 'USA');
INSERT INTO orders VALUES (2, '2020-06-10', 150, 'USA');
INSERT INTO orders VALUES (3, '2021-03-20', 200, 'UK');
CREATE MATERIALIZED VIEW national_yearly AS
SELECT nation, substr(orderdate, 1, 4) as year, sum(amount) as total
FROM orders
GROUP BY 1, 2;
SELECT * FROM national_yearly ORDER BY nation, year;
} {UK|2021|200
USA|2020|250}
do_execsql_test_on_specific_db {:memory:} matview-groupby-scalar-incremental {
CREATE TABLE orders(id INTEGER, orderdate TEXT, amount INTEGER);
INSERT INTO orders VALUES (1, '2020-01-15', 100);
CREATE MATERIALIZED VIEW yearly_totals AS
SELECT substr(orderdate, 1, 4) as year, sum(amount) as total
FROM orders
GROUP BY year;
SELECT * FROM yearly_totals;
INSERT INTO orders VALUES (2, '2020-06-10', 150);
SELECT * FROM yearly_totals;
INSERT INTO orders VALUES (3, '2021-03-20', 200);
SELECT * FROM yearly_totals ORDER BY year;
} {2020|100
2020|250
2020|250
2021|200}
do_execsql_test_on_specific_db {:memory:} matview-groupby-join-position {
CREATE TABLE t(a INTEGER);
CREATE TABLE u(a INTEGER);
INSERT INTO t VALUES (1), (2), (3);
INSERT INTO u VALUES (1), (1), (2);
CREATE MATERIALIZED VIEW tujoingroup AS
SELECT t.a, count(u.a) as cnt
FROM t JOIN u ON t.a = u.a
GROUP BY 1;
SELECT * FROM tujoingroup ORDER BY a;
} {1|2
2|1}