diff --git a/core/translate/delete.rs b/core/translate/delete.rs index 8741e1e62..57e1dec52 100644 --- a/core/translate/delete.rs +++ b/core/translate/delete.rs @@ -43,7 +43,7 @@ pub fn prepare_delete_plan( let mut where_predicates = vec![]; // Parse the WHERE clause - parse_where(where_clause, &table_references, &mut where_predicates)?; + parse_where(where_clause, &table_references, None, &mut where_predicates)?; // Parse the LIMIT/OFFSET clause let (resolved_limit, resolved_offset) = limit.map_or(Ok((None, None)), |l| parse_limit(*l))?; diff --git a/core/translate/planner.rs b/core/translate/planner.rs index d753e0205..1dc44b8a8 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -1,5 +1,8 @@ use super::{ - plan::{Aggregate, JoinInfo, Operation, Plan, SelectQueryType, TableReference, WhereTerm}, + plan::{ + Aggregate, JoinInfo, Operation, Plan, ResultSetColumn, SelectQueryType, TableReference, + WhereTerm, + }, select::prepare_select_plan, SymbolTable, }; @@ -78,7 +81,11 @@ pub fn resolve_aggregates(expr: &Expr, aggs: &mut Vec) -> bool { } } -pub fn bind_column_references(expr: &mut Expr, referenced_tables: &[TableReference]) -> Result<()> { +pub fn bind_column_references( + expr: &mut Expr, + referenced_tables: &[TableReference], + result_columns: Option<&[ResultSetColumn]>, +) -> Result<()> { match expr { Expr::Id(id) => { // true and false are special constants that are effectively aliases for 1 and 0 @@ -111,17 +118,25 @@ pub fn bind_column_references(expr: &mut Expr, referenced_tables: &[TableReferen match_result = Some((tbl_idx, col_idx.unwrap(), col.is_rowid_alias)); } } - if match_result.is_none() { - crate::bail_parse_error!("Column {} not found", id.0); + if let Some((tbl_idx, col_idx, is_rowid_alias)) = match_result { + *expr = Expr::Column { + database: None, // TODO: support different databases + table: tbl_idx, + column: col_idx, + is_rowid_alias, + }; + return Ok(()); } - let (tbl_idx, col_idx, is_rowid_alias) = match_result.unwrap(); - *expr = Expr::Column { - database: None, // TODO: support different databases - table: tbl_idx, - column: col_idx, - is_rowid_alias, - }; - Ok(()) + + if let Some(result_columns) = result_columns { + for result_column in result_columns.iter() { + if result_column.name == normalized_id { + *expr = result_column.expr.clone(); + return Ok(()); + } + } + } + crate::bail_parse_error!("Column {} not found", id.0); } Expr::Qualified(tbl, id) => { let normalized_table_name = normalize_ident(tbl.0.as_str()); @@ -164,14 +179,14 @@ pub fn bind_column_references(expr: &mut Expr, referenced_tables: &[TableReferen start, end, } => { - bind_column_references(lhs, referenced_tables)?; - bind_column_references(start, referenced_tables)?; - bind_column_references(end, referenced_tables)?; + bind_column_references(lhs, referenced_tables, result_columns)?; + bind_column_references(start, referenced_tables, result_columns)?; + bind_column_references(end, referenced_tables, result_columns)?; Ok(()) } Expr::Binary(expr, _operator, expr1) => { - bind_column_references(expr, referenced_tables)?; - bind_column_references(expr1, referenced_tables)?; + bind_column_references(expr, referenced_tables, result_columns)?; + bind_column_references(expr1, referenced_tables, result_columns)?; Ok(()) } Expr::Case { @@ -180,19 +195,23 @@ pub fn bind_column_references(expr: &mut Expr, referenced_tables: &[TableReferen else_expr, } => { if let Some(base) = base { - bind_column_references(base, referenced_tables)?; + bind_column_references(base, referenced_tables, result_columns)?; } for (when, then) in when_then_pairs { - bind_column_references(when, referenced_tables)?; - bind_column_references(then, referenced_tables)?; + bind_column_references(when, referenced_tables, result_columns)?; + bind_column_references(then, referenced_tables, result_columns)?; } if let Some(else_expr) = else_expr { - bind_column_references(else_expr, referenced_tables)?; + bind_column_references(else_expr, referenced_tables, result_columns)?; } Ok(()) } - Expr::Cast { expr, type_name: _ } => bind_column_references(expr, referenced_tables), - Expr::Collate(expr, _string) => bind_column_references(expr, referenced_tables), + Expr::Cast { expr, type_name: _ } => { + bind_column_references(expr, referenced_tables, result_columns) + } + Expr::Collate(expr, _string) => { + bind_column_references(expr, referenced_tables, result_columns) + } Expr::FunctionCall { name: _, distinctness: _, @@ -202,7 +221,7 @@ pub fn bind_column_references(expr: &mut Expr, referenced_tables: &[TableReferen } => { if let Some(args) = args { for arg in args { - bind_column_references(arg, referenced_tables)?; + bind_column_references(arg, referenced_tables, result_columns)?; } } Ok(()) @@ -213,10 +232,10 @@ pub fn bind_column_references(expr: &mut Expr, referenced_tables: &[TableReferen Expr::Exists(_) => todo!(), Expr::FunctionCallStar { .. } => Ok(()), Expr::InList { lhs, not: _, rhs } => { - bind_column_references(lhs, referenced_tables)?; + bind_column_references(lhs, referenced_tables, result_columns)?; if let Some(rhs) = rhs { for arg in rhs { - bind_column_references(arg, referenced_tables)?; + bind_column_references(arg, referenced_tables, result_columns)?; } } Ok(()) @@ -224,30 +243,30 @@ pub fn bind_column_references(expr: &mut Expr, referenced_tables: &[TableReferen Expr::InSelect { .. } => todo!(), Expr::InTable { .. } => todo!(), Expr::IsNull(expr) => { - bind_column_references(expr, referenced_tables)?; + bind_column_references(expr, referenced_tables, result_columns)?; Ok(()) } Expr::Like { lhs, rhs, .. } => { - bind_column_references(lhs, referenced_tables)?; - bind_column_references(rhs, referenced_tables)?; + bind_column_references(lhs, referenced_tables, result_columns)?; + bind_column_references(rhs, referenced_tables, result_columns)?; Ok(()) } Expr::Literal(_) => Ok(()), Expr::Name(_) => todo!(), Expr::NotNull(expr) => { - bind_column_references(expr, referenced_tables)?; + bind_column_references(expr, referenced_tables, result_columns)?; Ok(()) } Expr::Parenthesized(expr) => { for e in expr.iter_mut() { - bind_column_references(e, referenced_tables)?; + bind_column_references(e, referenced_tables, result_columns)?; } Ok(()) } Expr::Raise(_, _) => todo!(), Expr::Subquery(_) => todo!(), Expr::Unary(_, expr) => { - bind_column_references(expr, referenced_tables)?; + bind_column_references(expr, referenced_tables, result_columns)?; Ok(()) } Expr::Variable(_) => Ok(()), @@ -328,13 +347,14 @@ pub fn parse_from( pub fn parse_where( where_clause: Option, table_references: &[TableReference], + result_columns: Option<&[ResultSetColumn]>, out_where_clause: &mut Vec, ) -> Result<()> { if let Some(where_expr) = where_clause { let mut predicates = vec![]; break_predicate_at_and_boundaries(where_expr, &mut predicates); for expr in predicates.iter_mut() { - bind_column_references(expr, table_references)?; + bind_column_references(expr, table_references, result_columns)?; } for expr in predicates { let eval_at_loop = get_rightmost_table_referenced_in_expr(&expr)?; @@ -481,7 +501,7 @@ fn parse_join( let mut preds = vec![]; break_predicate_at_and_boundaries(expr, &mut preds); for predicate in preds.iter_mut() { - bind_column_references(predicate, tables)?; + bind_column_references(predicate, tables, None)?; } for pred in preds { let cur_table_idx = tables.len() - 1; diff --git a/core/translate/select.rs b/core/translate/select.rs index 6a9250296..81955c887 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -62,9 +62,6 @@ pub fn prepare_select_plan( query_type: SelectQueryType::TopLevel, }; - // Parse the actual WHERE clause and add its conditions to the plan WHERE clause that already contains the join conditions. - parse_where(where_clause, &plan.table_references, &mut plan.where_clause)?; - let mut aggregate_expressions = Vec::new(); for (result_column_idx, column) in columns.iter_mut().enumerate() { match column { @@ -97,7 +94,11 @@ pub fn prepare_select_plan( } } ResultColumn::Expr(ref mut expr, maybe_alias) => { - bind_column_references(expr, &plan.table_references)?; + bind_column_references( + expr, + &plan.table_references, + Some(&plan.result_columns), + )?; match expr { ast::Expr::FunctionCall { name, @@ -255,9 +256,22 @@ pub fn prepare_select_plan( } } } + + // Parse the actual WHERE clause and add its conditions to the plan WHERE clause that already contains the join conditions. + parse_where( + where_clause, + &plan.table_references, + Some(&plan.result_columns), + &mut plan.where_clause, + )?; + if let Some(mut group_by) = group_by { for expr in group_by.exprs.iter_mut() { - bind_column_references(expr, &plan.table_references)?; + bind_column_references( + expr, + &plan.table_references, + Some(&plan.result_columns), + )?; } plan.group_by = Some(GroupBy { @@ -266,7 +280,11 @@ pub fn prepare_select_plan( let mut predicates = vec![]; break_predicate_at_and_boundaries(having, &mut predicates); for expr in predicates.iter_mut() { - bind_column_references(expr, &plan.table_references)?; + bind_column_references( + expr, + &plan.table_references, + Some(&plan.result_columns), + )?; let contains_aggregates = resolve_aggregates(expr, &mut aggregate_expressions); if !contains_aggregates { @@ -312,7 +330,11 @@ pub fn prepare_select_plan( o.expr }; - bind_column_references(&mut expr, &plan.table_references)?; + bind_column_references( + &mut expr, + &plan.table_references, + Some(&plan.result_columns), + )?; resolve_aggregates(&expr, &mut plan.aggregates); key.push(( diff --git a/testing/groupby.test b/testing/groupby.test index 350405208..c370c3df2 100644 --- a/testing/groupby.test +++ b/testing/groupby.test @@ -174,4 +174,10 @@ do_execsql_test having_with_multiple_conditions { David|165|53.0 Robert|159|51.0 Jennifer|151|51.0 -John|145|50.0} \ No newline at end of file +John|145|50.0} + +# Wanda = 9, Whitney = 11, William = 111 +do_execsql_test column_alias_in_group_by_order_by_having { + select first_name as fn, count(1) as fn_count from users where fn in ('Wanda', 'Whitney', 'William') group by fn having fn_count > 10 order by fn_count; +} {Whitney|11 +William|111}