mirror of
https://github.com/aljazceru/turso.git
synced 2025-12-26 04:24:21 +01:00
Merge 'Fix materialized views with complex expressions' from Glauber Costa
SQLite supports complex expressions in group by columns - because of course it does... So we need to make sure that a column is created for this expression if it doesn't exist already, and compute it, the same way we compute pre- projections in the filter operator. Fixes #3363 Fixes #3366 Fixes #3365 Closes #3429
This commit is contained in:
@@ -23,6 +23,7 @@ type PreprocessAggregateResult = (
|
||||
Vec<LogicalExpr>, // pre_projection_exprs
|
||||
Vec<ColumnInfo>, // pre_projection_schema
|
||||
Vec<LogicalExpr>, // modified_aggr_exprs
|
||||
Vec<LogicalExpr>, // modified_group_exprs
|
||||
);
|
||||
|
||||
/// Result type for parsing join conditions
|
||||
@@ -385,6 +386,23 @@ impl Display for Column {
|
||||
}
|
||||
}
|
||||
|
||||
/// Strip alias wrapper from an expression, returning the underlying expression.
|
||||
/// This is useful when comparing expressions where one might be aliased and the other not,
|
||||
/// such as when matching SELECT expressions with GROUP BY expressions.
|
||||
///
|
||||
/// # Examples
|
||||
/// ```ignore
|
||||
/// let aliased = LogicalExpr::Alias { expr: Box::new(col_expr), alias: "my_alias".to_string() };
|
||||
/// let stripped = strip_alias(&aliased);
|
||||
/// assert_eq!(stripped, &col_expr);
|
||||
/// ```
|
||||
pub fn strip_alias(expr: &LogicalExpr) -> &LogicalExpr {
|
||||
match expr {
|
||||
LogicalExpr::Alias { expr, .. } => expr,
|
||||
_ => expr,
|
||||
}
|
||||
}
|
||||
|
||||
/// Type alias for binary operators
|
||||
pub type BinaryOperator = ast::Operator;
|
||||
|
||||
@@ -992,6 +1010,7 @@ impl<'a> LogicalPlanBuilder<'a> {
|
||||
let mut pre_projection_exprs = Vec::new();
|
||||
let mut pre_projection_schema = Vec::new();
|
||||
let mut modified_aggr_exprs = Vec::new();
|
||||
let mut modified_group_exprs = Vec::new();
|
||||
let mut projected_col_counter = 0;
|
||||
|
||||
// First, add all group by expressions to the pre-projection
|
||||
@@ -1006,6 +1025,8 @@ impl<'a> LogicalPlanBuilder<'a> {
|
||||
table: col.table.clone(),
|
||||
table_alias: None,
|
||||
});
|
||||
// Column references stay as-is in the modified group expressions
|
||||
modified_group_exprs.push(expr.clone());
|
||||
} else {
|
||||
// Complex group by expression - project it
|
||||
needs_pre_projection = true;
|
||||
@@ -1020,6 +1041,11 @@ impl<'a> LogicalPlanBuilder<'a> {
|
||||
table: None,
|
||||
table_alias: None,
|
||||
});
|
||||
// Replace complex expression with reference to projected column
|
||||
modified_group_exprs.push(LogicalExpr::Column(Column {
|
||||
name: proj_col_name,
|
||||
table: None,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1092,6 +1118,7 @@ impl<'a> LogicalPlanBuilder<'a> {
|
||||
pre_projection_exprs,
|
||||
pre_projection_schema,
|
||||
modified_aggr_exprs,
|
||||
modified_group_exprs,
|
||||
))
|
||||
}
|
||||
|
||||
@@ -1134,6 +1161,52 @@ impl<'a> LogicalPlanBuilder<'a> {
|
||||
let input_schema = input.schema();
|
||||
let has_group_by = !group_exprs.is_empty();
|
||||
|
||||
// First pass: build a map of aliases to expressions from the SELECT list
|
||||
// and a vector of SELECT expressions for positional references
|
||||
// This allows GROUP BY to reference SELECT aliases (e.g., GROUP BY year)
|
||||
// or positions (e.g., GROUP BY 1)
|
||||
let mut alias_to_expr = HashMap::new();
|
||||
let mut select_exprs = Vec::new();
|
||||
for col in columns {
|
||||
if let ast::ResultColumn::Expr(expr, alias) = col {
|
||||
let logical_expr = self.build_expr(expr, input_schema)?;
|
||||
select_exprs.push(logical_expr.clone());
|
||||
|
||||
if let Some(alias) = alias {
|
||||
let alias_name = match alias {
|
||||
ast::As::As(name) | ast::As::Elided(name) => Self::name_to_string(name),
|
||||
};
|
||||
alias_to_expr.insert(alias_name, logical_expr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve GROUP BY expressions: replace column references that match SELECT aliases
|
||||
// or integer literals that represent positions
|
||||
let group_exprs = group_exprs
|
||||
.into_iter()
|
||||
.map(|expr| {
|
||||
// Check for positional reference (integer literal)
|
||||
if let LogicalExpr::Literal(crate::types::Value::Integer(pos)) = &expr {
|
||||
// SQLite uses 1-based indexing
|
||||
if *pos > 0 && (*pos as usize) <= select_exprs.len() {
|
||||
return select_exprs[(*pos as usize) - 1].clone();
|
||||
}
|
||||
}
|
||||
|
||||
// Check for alias reference (unqualified column name)
|
||||
if let LogicalExpr::Column(col) = &expr {
|
||||
if col.table.is_none() {
|
||||
// Unqualified column - check if it matches an alias
|
||||
if let Some(aliased_expr) = alias_to_expr.get(&col.name) {
|
||||
return aliased_expr.clone();
|
||||
}
|
||||
}
|
||||
}
|
||||
expr
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Build aggregate expressions and projection expressions
|
||||
let mut aggr_exprs = Vec::new();
|
||||
let mut projection_exprs = Vec::new();
|
||||
@@ -1178,8 +1251,7 @@ impl<'a> LogicalPlanBuilder<'a> {
|
||||
}
|
||||
|
||||
// Track aggregates we've already seen to avoid duplicates
|
||||
let mut aggregate_map: std::collections::HashMap<String, String> =
|
||||
std::collections::HashMap::new();
|
||||
let mut aggregate_map: HashMap<String, String> = HashMap::new();
|
||||
|
||||
for col in columns {
|
||||
match col {
|
||||
@@ -1264,6 +1336,23 @@ impl<'a> LogicalPlanBuilder<'a> {
|
||||
"Column '{col_name}' must appear in the GROUP BY clause or be used in an aggregate function"
|
||||
)));
|
||||
}
|
||||
|
||||
// If this expression matches a GROUP BY expression, replace it with a reference
|
||||
// to the corresponding column in the aggregate output
|
||||
let logical_expr_stripped = strip_alias(&logical_expr);
|
||||
if let Some(group_idx) = group_exprs
|
||||
.iter()
|
||||
.position(|g| logical_expr_stripped == strip_alias(g))
|
||||
{
|
||||
// Reference the GROUP BY column in the aggregate output by its name
|
||||
let group_col_name = &aggregate_schema_columns[group_idx].name;
|
||||
projection_exprs.push(LogicalExpr::Column(Column {
|
||||
name: group_col_name.clone(),
|
||||
table: None,
|
||||
}));
|
||||
} else {
|
||||
projection_exprs.push(logical_expr);
|
||||
}
|
||||
} else {
|
||||
// Without GROUP BY: only allow constant expressions
|
||||
// TODO: SQLite allows any column here and returns a value from an
|
||||
@@ -1274,8 +1363,8 @@ impl<'a> LogicalPlanBuilder<'a> {
|
||||
"Column '{col_name}' must be used in an aggregate function when using aggregates without GROUP BY"
|
||||
)));
|
||||
}
|
||||
projection_exprs.push(logical_expr);
|
||||
}
|
||||
projection_exprs.push(logical_expr);
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
@@ -1290,12 +1379,14 @@ impl<'a> LogicalPlanBuilder<'a> {
|
||||
}
|
||||
|
||||
// Check if any aggregate functions have complex expressions as arguments
|
||||
// or if GROUP BY has complex expressions
|
||||
// If so, we need to insert a projection before the aggregate
|
||||
let (
|
||||
needs_pre_projection,
|
||||
pre_projection_exprs,
|
||||
pre_projection_schema,
|
||||
modified_aggr_exprs,
|
||||
modified_group_exprs,
|
||||
) = Self::preprocess_aggregate_expressions(&aggr_exprs, &group_exprs, input_schema)?;
|
||||
|
||||
// Build the final schema for the projection
|
||||
@@ -1338,12 +1429,17 @@ impl<'a> LogicalPlanBuilder<'a> {
|
||||
Arc::new(input)
|
||||
};
|
||||
|
||||
// Use modified aggregate expressions if we inserted a pre-projection
|
||||
// Use modified aggregate and group expressions if we inserted a pre-projection
|
||||
let final_aggr_exprs = if needs_pre_projection {
|
||||
modified_aggr_exprs
|
||||
} else {
|
||||
aggr_exprs
|
||||
};
|
||||
let final_group_exprs = if needs_pre_projection {
|
||||
modified_group_exprs
|
||||
} else {
|
||||
group_exprs
|
||||
};
|
||||
|
||||
// Check if we need the outer projection
|
||||
// We need a projection if:
|
||||
@@ -1384,7 +1480,7 @@ impl<'a> LogicalPlanBuilder<'a> {
|
||||
// Create the aggregate node with its natural schema
|
||||
let aggregate_plan = LogicalPlan::Aggregate(Aggregate {
|
||||
input: aggregate_input,
|
||||
group_expr: group_exprs,
|
||||
group_expr: final_group_exprs,
|
||||
aggr_expr: final_aggr_exprs,
|
||||
schema: Arc::new(LogicalSchema::new(aggregate_schema_columns)),
|
||||
});
|
||||
@@ -1957,6 +2053,14 @@ impl<'a> LogicalPlanBuilder<'a> {
|
||||
// 2. An aggregate function
|
||||
// 3. A grouping column (or expression involving only grouping columns)
|
||||
fn is_valid_in_group_by(expr: &LogicalExpr, group_exprs: &[LogicalExpr]) -> bool {
|
||||
// First check if the entire expression appears in GROUP BY
|
||||
// Strip aliases before comparing since SELECT might have aliases but GROUP BY might not
|
||||
let expr_stripped = strip_alias(expr);
|
||||
if group_exprs.iter().any(|g| expr_stripped == strip_alias(g)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// If not, check recursively based on expression type
|
||||
match expr {
|
||||
LogicalExpr::Literal(_) => true, // Constants are always valid
|
||||
LogicalExpr::AggregateFunction { .. } => true, // Aggregates are valid
|
||||
@@ -1987,7 +2091,7 @@ impl<'a> LogicalPlanBuilder<'a> {
|
||||
// Returns the modified expression and a list of NEW (aggregate_expr, column_name) pairs
|
||||
fn extract_and_replace_aggregates_with_dedup(
|
||||
expr: LogicalExpr,
|
||||
aggregate_map: &mut std::collections::HashMap<String, String>,
|
||||
aggregate_map: &mut HashMap<String, String>,
|
||||
) -> Result<(LogicalExpr, Vec<(LogicalExpr, String)>)> {
|
||||
let mut new_aggregates = Vec::new();
|
||||
let mut counter = aggregate_map.len();
|
||||
@@ -2004,7 +2108,7 @@ impl<'a> LogicalPlanBuilder<'a> {
|
||||
fn replace_aggregates_with_columns_dedup(
|
||||
expr: LogicalExpr,
|
||||
new_aggregates: &mut Vec<(LogicalExpr, String)>,
|
||||
aggregate_map: &mut std::collections::HashMap<String, String>,
|
||||
aggregate_map: &mut HashMap<String, String>,
|
||||
counter: &mut usize,
|
||||
) -> Result<LogicalExpr> {
|
||||
match expr {
|
||||
@@ -3962,4 +4066,144 @@ mod tests {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Tests for strip_alias function
|
||||
#[test]
|
||||
fn test_strip_alias_with_alias() {
|
||||
let inner_expr = LogicalExpr::Column(Column::new("test"));
|
||||
let aliased = LogicalExpr::Alias {
|
||||
expr: Box::new(inner_expr.clone()),
|
||||
alias: "my_alias".to_string(),
|
||||
};
|
||||
|
||||
let stripped = strip_alias(&aliased);
|
||||
assert_eq!(stripped, &inner_expr);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_strip_alias_without_alias() {
|
||||
let expr = LogicalExpr::Column(Column::new("test"));
|
||||
let stripped = strip_alias(&expr);
|
||||
assert_eq!(stripped, &expr);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_strip_alias_literal() {
|
||||
let expr = LogicalExpr::Literal(Value::Integer(42));
|
||||
let stripped = strip_alias(&expr);
|
||||
assert_eq!(stripped, &expr);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_strip_alias_scalar_function() {
|
||||
let expr = LogicalExpr::ScalarFunction {
|
||||
fun: "substr".to_string(),
|
||||
args: vec![
|
||||
LogicalExpr::Column(Column::new("name")),
|
||||
LogicalExpr::Literal(Value::Integer(1)),
|
||||
LogicalExpr::Literal(Value::Integer(4)),
|
||||
],
|
||||
};
|
||||
let stripped = strip_alias(&expr);
|
||||
assert_eq!(stripped, &expr);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_strip_alias_nested_alias() {
|
||||
// Test that strip_alias only removes the outermost alias
|
||||
let inner_expr = LogicalExpr::Column(Column::new("test"));
|
||||
let inner_alias = LogicalExpr::Alias {
|
||||
expr: Box::new(inner_expr.clone()),
|
||||
alias: "inner_alias".to_string(),
|
||||
};
|
||||
let outer_alias = LogicalExpr::Alias {
|
||||
expr: Box::new(inner_alias.clone()),
|
||||
alias: "outer_alias".to_string(),
|
||||
};
|
||||
|
||||
let stripped = strip_alias(&outer_alias);
|
||||
assert_eq!(stripped, &inner_alias);
|
||||
|
||||
// Stripping again should give us the inner expression
|
||||
let double_stripped = strip_alias(stripped);
|
||||
assert_eq!(double_stripped, &inner_expr);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_strip_alias_comparison_with_alias() {
|
||||
// Test that two expressions match when one has an alias and one doesn't
|
||||
let base_expr = LogicalExpr::ScalarFunction {
|
||||
fun: "substr".to_string(),
|
||||
args: vec![
|
||||
LogicalExpr::Column(Column::new("orderdate")),
|
||||
LogicalExpr::Literal(Value::Integer(1)),
|
||||
LogicalExpr::Literal(Value::Integer(4)),
|
||||
],
|
||||
};
|
||||
|
||||
let aliased_expr = LogicalExpr::Alias {
|
||||
expr: Box::new(base_expr.clone()),
|
||||
alias: "year".to_string(),
|
||||
};
|
||||
|
||||
// Without strip_alias, they wouldn't match
|
||||
assert_ne!(&aliased_expr, &base_expr);
|
||||
|
||||
// With strip_alias, they should match
|
||||
assert_eq!(strip_alias(&aliased_expr), &base_expr);
|
||||
assert_eq!(strip_alias(&base_expr), &base_expr);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_strip_alias_binary_expr() {
|
||||
let expr = LogicalExpr::BinaryExpr {
|
||||
left: Box::new(LogicalExpr::Column(Column::new("a"))),
|
||||
op: BinaryOperator::Add,
|
||||
right: Box::new(LogicalExpr::Literal(Value::Integer(1))),
|
||||
};
|
||||
let stripped = strip_alias(&expr);
|
||||
assert_eq!(stripped, &expr);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_strip_alias_aggregate_function() {
|
||||
let expr = LogicalExpr::AggregateFunction {
|
||||
fun: AggFunc::Sum,
|
||||
args: vec![LogicalExpr::Column(Column::new("amount"))],
|
||||
distinct: false,
|
||||
};
|
||||
let stripped = strip_alias(&expr);
|
||||
assert_eq!(stripped, &expr);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_strip_alias_comparison_multiple_expressions() {
|
||||
// Test comparing a list of expressions with and without aliases
|
||||
let expr1 = LogicalExpr::Column(Column::new("a"));
|
||||
let expr2 = LogicalExpr::ScalarFunction {
|
||||
fun: "substr".to_string(),
|
||||
args: vec![
|
||||
LogicalExpr::Column(Column::new("b")),
|
||||
LogicalExpr::Literal(Value::Integer(1)),
|
||||
LogicalExpr::Literal(Value::Integer(4)),
|
||||
],
|
||||
};
|
||||
|
||||
let aliased1 = LogicalExpr::Alias {
|
||||
expr: Box::new(expr1.clone()),
|
||||
alias: "col_a".to_string(),
|
||||
};
|
||||
let aliased2 = LogicalExpr::Alias {
|
||||
expr: Box::new(expr2.clone()),
|
||||
alias: "year".to_string(),
|
||||
};
|
||||
|
||||
let select_exprs = [aliased1, aliased2];
|
||||
let group_exprs = [expr1.clone(), expr2.clone()];
|
||||
|
||||
// Verify that stripping aliases allows matching
|
||||
for (select_expr, group_expr) in select_exprs.iter().zip(group_exprs.iter()) {
|
||||
assert_eq!(strip_alias(select_expr), group_expr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1612,3 +1612,82 @@ do_execsql_test_on_specific_db {:memory:} matview-join-swapped-columns {
|
||||
SELECT * FROM emp_dept ORDER BY name;
|
||||
} {Alice|Engineering
|
||||
Bob|Sales}
|
||||
|
||||
do_execsql_test_on_specific_db {:memory:} matview-groupby-scalar-function {
|
||||
CREATE TABLE orders(id INTEGER, orderdate TEXT, amount INTEGER);
|
||||
INSERT INTO orders VALUES (1, '2020-01-15', 100);
|
||||
INSERT INTO orders VALUES (2, '2020-06-10', 150);
|
||||
INSERT INTO orders VALUES (3, '2021-03-20', 200);
|
||||
|
||||
CREATE MATERIALIZED VIEW yearly_totals AS
|
||||
SELECT substr(orderdate, 1, 4), sum(amount)
|
||||
FROM orders
|
||||
GROUP BY substr(orderdate, 1, 4);
|
||||
|
||||
SELECT * FROM yearly_totals ORDER BY 1;
|
||||
} {2020|250
|
||||
2021|200}
|
||||
|
||||
do_execsql_test_on_specific_db {:memory:} matview-groupby-alias {
|
||||
CREATE TABLE orders(id INTEGER, orderdate TEXT, amount INTEGER);
|
||||
INSERT INTO orders VALUES (1, '2020-01-15', 100);
|
||||
INSERT INTO orders VALUES (2, '2020-06-10', 150);
|
||||
INSERT INTO orders VALUES (3, '2021-03-20', 200);
|
||||
|
||||
CREATE MATERIALIZED VIEW yearly_totals AS
|
||||
SELECT substr(orderdate, 1, 4) as year, sum(amount) as total
|
||||
FROM orders
|
||||
GROUP BY year;
|
||||
|
||||
SELECT * FROM yearly_totals ORDER BY year;
|
||||
} {2020|250
|
||||
2021|200}
|
||||
|
||||
do_execsql_test_on_specific_db {:memory:} matview-groupby-position {
|
||||
CREATE TABLE orders(id INTEGER, orderdate TEXT, amount INTEGER, nation TEXT);
|
||||
INSERT INTO orders VALUES (1, '2020-01-15', 100, 'USA');
|
||||
INSERT INTO orders VALUES (2, '2020-06-10', 150, 'USA');
|
||||
INSERT INTO orders VALUES (3, '2021-03-20', 200, 'UK');
|
||||
|
||||
CREATE MATERIALIZED VIEW national_yearly AS
|
||||
SELECT nation, substr(orderdate, 1, 4) as year, sum(amount) as total
|
||||
FROM orders
|
||||
GROUP BY 1, 2;
|
||||
|
||||
SELECT * FROM national_yearly ORDER BY nation, year;
|
||||
} {UK|2021|200
|
||||
USA|2020|250}
|
||||
|
||||
do_execsql_test_on_specific_db {:memory:} matview-groupby-scalar-incremental {
|
||||
CREATE TABLE orders(id INTEGER, orderdate TEXT, amount INTEGER);
|
||||
INSERT INTO orders VALUES (1, '2020-01-15', 100);
|
||||
|
||||
CREATE MATERIALIZED VIEW yearly_totals AS
|
||||
SELECT substr(orderdate, 1, 4) as year, sum(amount) as total
|
||||
FROM orders
|
||||
GROUP BY year;
|
||||
|
||||
SELECT * FROM yearly_totals;
|
||||
INSERT INTO orders VALUES (2, '2020-06-10', 150);
|
||||
SELECT * FROM yearly_totals;
|
||||
INSERT INTO orders VALUES (3, '2021-03-20', 200);
|
||||
SELECT * FROM yearly_totals ORDER BY year;
|
||||
} {2020|100
|
||||
2020|250
|
||||
2020|250
|
||||
2021|200}
|
||||
|
||||
do_execsql_test_on_specific_db {:memory:} matview-groupby-join-position {
|
||||
CREATE TABLE t(a INTEGER);
|
||||
CREATE TABLE u(a INTEGER);
|
||||
INSERT INTO t VALUES (1), (2), (3);
|
||||
INSERT INTO u VALUES (1), (1), (2);
|
||||
|
||||
CREATE MATERIALIZED VIEW tujoingroup AS
|
||||
SELECT t.a, count(u.a) as cnt
|
||||
FROM t JOIN u ON t.a = u.a
|
||||
GROUP BY 1;
|
||||
|
||||
SELECT * FROM tujoingroup ORDER BY a;
|
||||
} {1|2
|
||||
2|1}
|
||||
|
||||
Reference in New Issue
Block a user