From 16a97d3b9849f6b9f918a567ca8fa3ee7fcb4b1e Mon Sep 17 00:00:00 2001 From: Jussi Saurio Date: Sat, 1 Feb 2025 22:31:27 +0200 Subject: [PATCH] planner.rs: refactor from/join + where parsing logic - use new TableReference and JoinAwareConditionExpr - add utilities for determining at which loop depth a WHERE condition should be evaluated, now that "operators" do not carry condition expressions inside them anymore. --- core/translate/planner.rs | 254 ++++++++++++++++++++------------------ 1 file changed, 133 insertions(+), 121 deletions(-) diff --git a/core/translate/planner.rs b/core/translate/planner.rs index f66c96ce3..043ef0e35 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -1,5 +1,8 @@ use super::{ - plan::{Aggregate, Plan, SelectQueryType, SourceOperator, TableReference, TableReferenceType}, + plan::{ + Aggregate, JoinAwareConditionExpr, JoinInfo, Operation, Plan, SelectQueryType, + TableReference, + }, select::prepare_select_plan, SymbolTable, }; @@ -14,21 +17,6 @@ use sqlite3_parser::ast::{self, Expr, FromClause, JoinType, Limit, UnaryOperator pub const ROWID: &str = "rowid"; -pub struct OperatorIdCounter { - id: usize, -} - -impl OperatorIdCounter { - pub fn new() -> Self { - Self { id: 1 } - } - pub fn get_next_id(&mut self) -> usize { - let id = self.id; - self.id += 1; - id - } -} - pub fn resolve_aggregates(expr: &Expr, aggs: &mut Vec) -> bool { if aggs .iter() @@ -140,10 +128,9 @@ pub fn bind_column_references(expr: &mut Expr, referenced_tables: &[TableReferen } Expr::Qualified(tbl, id) => { let normalized_table_name = normalize_ident(tbl.0.as_str()); - let matching_tbl_idx = referenced_tables.iter().position(|t| { - t.table_identifier - .eq_ignore_ascii_case(&normalized_table_name) - }); + let matching_tbl_idx = referenced_tables + .iter() + .position(|t| t.identifier.eq_ignore_ascii_case(&normalized_table_name)); if matching_tbl_idx.is_none() { crate::bail_parse_error!("Table {} not found", normalized_table_name); } @@ -273,10 +260,9 @@ pub fn bind_column_references(expr: &mut Expr, referenced_tables: &[TableReferen fn parse_from_clause_table( schema: &Schema, table: ast::SelectTable, - operator_id_counter: &mut OperatorIdCounter, cur_table_index: usize, syms: &SymbolTable, -) -> Result<(TableReference, SourceOperator)> { +) -> Result { match table { ast::SelectTable::Table(qualified_name, maybe_alias, _) => { let normalized_qualified_name = normalize_ident(qualified_name.name.0.as_str()); @@ -289,21 +275,12 @@ fn parse_from_clause_table( ast::As::Elided(id) => id, }) .map(|a| a.0); - let table_reference = TableReference { + Ok(TableReference { + op: Operation::Scan { iter_dir: None }, table: Table::BTree(table.clone()), - table_identifier: alias.unwrap_or(normalized_qualified_name), - table_index: cur_table_index, - reference_type: TableReferenceType::BTreeTable, - }; - Ok(( - table_reference.clone(), - SourceOperator::Scan { - table_reference, - predicates: None, - id: operator_id_counter.get_next_id(), - iter_dir: None, - }, - )) + identifier: alias.unwrap_or(normalized_qualified_name), + join_info: None, + }) } ast::SelectTable::Select(subselect, maybe_alias) => { let Plan::Select(mut subplan) = prepare_select_plan(schema, *subselect, syms)? else { @@ -319,17 +296,8 @@ fn parse_from_clause_table( ast::As::Elided(id) => id.0.clone(), }) .unwrap_or(format!("subquery_{}", cur_table_index)); - let table_reference = - TableReference::new_subquery(identifier.clone(), cur_table_index, &subplan); - Ok(( - table_reference.clone(), - SourceOperator::Subquery { - id: operator_id_counter.get_next_id(), - table_reference, - plan: Box::new(subplan), - predicates: None, - }, - )) + let table_reference = TableReference::new_subquery(identifier, subplan, None); + Ok(table_reference) } _ => todo!(), } @@ -338,99 +306,124 @@ fn parse_from_clause_table( pub fn parse_from( schema: &Schema, mut from: Option, - operator_id_counter: &mut OperatorIdCounter, syms: &SymbolTable, -) -> Result<(SourceOperator, Vec)> { + out_where_clause: &mut Vec, +) -> Result> { if from.as_ref().and_then(|f| f.select.as_ref()).is_none() { - return Ok(( - SourceOperator::Nothing { - id: operator_id_counter.get_next_id(), - }, - vec![], - )); + return Ok(vec![]); } - let mut table_index = 0; let mut tables = vec![]; let mut from_owned = std::mem::take(&mut from).unwrap(); let select_owned = *std::mem::take(&mut from_owned.select).unwrap(); let joins_owned = std::mem::take(&mut from_owned.joins).unwrap_or_default(); - let (table_reference, mut operator) = - parse_from_clause_table(schema, select_owned, operator_id_counter, table_index, syms)?; - + let table_reference = parse_from_clause_table(schema, select_owned, 0, syms)?; tables.push(table_reference); - table_index += 1; for join in joins_owned.into_iter() { - let JoinParseResult { - source_operator: right, - is_outer_join: outer, - using, - predicates, - } = parse_join( - schema, - join, - operator_id_counter, - &mut tables, - table_index, - syms, - )?; - operator = SourceOperator::Join { - left: Box::new(operator), - right: Box::new(right), - predicates, - outer, - using, - id: operator_id_counter.get_next_id(), - }; - table_index += 1; + parse_join(schema, join, syms, &mut tables, out_where_clause)?; } - Ok((operator, tables)) + Ok(tables) } pub fn parse_where( where_clause: Option, - referenced_tables: &[TableReference], -) -> Result>> { + table_references: &[TableReference], + out_where_clause: &mut Vec, +) -> Result<()> { if let Some(where_expr) = where_clause { let mut predicates = vec![]; break_predicate_at_and_boundaries(where_expr, &mut predicates); for expr in predicates.iter_mut() { - bind_column_references(expr, referenced_tables)?; + bind_column_references(expr, table_references)?; } - Ok(Some(predicates)) + for expr in predicates { + let eval_at_loop = get_rightmost_table_referenced_in_expr(&expr)?; + out_where_clause.push(JoinAwareConditionExpr { + expr, + from_outer_join: false, + eval_at_loop, + }); + } + Ok(()) } else { - Ok(None) + Ok(()) } } -struct JoinParseResult { - source_operator: SourceOperator, - is_outer_join: bool, - using: Option, - predicates: Option>, +/** + Returns the rightmost table index that is referenced in the given AST expression. + Rightmost = innermost loop. + This is used to determine where we should evaluate a given condition expression, + and it needs to be the rightmost table referenced in the expression, because otherwise + the condition would be evaluated before a row is read from that table. +*/ +fn get_rightmost_table_referenced_in_expr<'a>(predicate: &'a ast::Expr) -> Result { + let mut max_table_idx = 0; + match predicate { + ast::Expr::Binary(e1, _, e2) => { + max_table_idx = max_table_idx.max(get_rightmost_table_referenced_in_expr(e1)?); + max_table_idx = max_table_idx.max(get_rightmost_table_referenced_in_expr(e2)?); + } + ast::Expr::Column { table, .. } => { + max_table_idx = max_table_idx.max(*table); + } + ast::Expr::Id(_) => { + /* Id referring to column will already have been rewritten as an Expr::Column */ + /* we only get here with literal 'true' or 'false' etc */ + } + ast::Expr::Qualified(_, _) => { + unreachable!("Qualified should be resolved to a Column before optimizer") + } + ast::Expr::Literal(_) => {} + ast::Expr::Like { lhs, rhs, .. } => { + max_table_idx = max_table_idx.max(get_rightmost_table_referenced_in_expr(lhs)?); + max_table_idx = max_table_idx.max(get_rightmost_table_referenced_in_expr(rhs)?); + } + ast::Expr::FunctionCall { + args: Some(args), .. + } => { + for arg in args { + max_table_idx = max_table_idx.max(get_rightmost_table_referenced_in_expr(arg)?); + } + } + ast::Expr::InList { lhs, rhs, .. } => { + max_table_idx = max_table_idx.max(get_rightmost_table_referenced_in_expr(lhs)?); + if let Some(rhs_list) = rhs { + for rhs_expr in rhs_list { + max_table_idx = + max_table_idx.max(get_rightmost_table_referenced_in_expr(rhs_expr)?); + } + } + } + _ => {} + } + + Ok(max_table_idx) } fn parse_join( schema: &Schema, join: ast::JoinedSelectTable, - operator_id_counter: &mut OperatorIdCounter, - tables: &mut Vec, - table_index: usize, syms: &SymbolTable, -) -> Result { + tables: &mut Vec, + out_where_clause: &mut Vec, +) -> Result<()> { let ast::JoinedSelectTable { operator: join_operator, table, constraint, } = join; - let (table_reference, source_operator) = - parse_from_clause_table(schema, table, operator_id_counter, table_index, syms)?; - - tables.push(table_reference); + let cur_table_index = tables.len(); + tables.push(parse_from_clause_table( + schema, + table, + cur_table_index, + syms, + )?); let (outer, natural) = match join_operator { ast::JoinOperator::TypedJoin(Some(join_type)) => { @@ -442,23 +435,21 @@ fn parse_join( }; let mut using = None; - let mut predicates = None; if natural && constraint.is_some() { crate::bail_parse_error!("NATURAL JOIN cannot be combined with ON or USING clause"); } let constraint = if natural { + assert!(tables.len() >= 2); + let rightmost_table = tables.last().unwrap(); // NATURAL JOIN is first transformed into a USING join with the common columns - let left_tables = &tables[..table_index]; - assert!(!left_tables.is_empty()); - let right_table = &tables[table_index]; - let right_cols = &right_table.columns(); + let right_cols = rightmost_table.columns(); let mut distinct_names: Option = None; // TODO: O(n^2) maybe not great for large tables or big multiway joins for right_col in right_cols.iter() { let mut found_match = false; - for left_table in left_tables.iter() { + for left_table in tables.iter().take(tables.len() - 1) { for left_col in left_table.columns().iter() { if left_col.name == right_col.name { if let Some(distinct_names) = distinct_names.as_mut() { @@ -495,16 +486,28 @@ fn parse_join( for predicate in preds.iter_mut() { bind_column_references(predicate, tables)?; } - predicates = Some(preds); + for pred in preds { + let cur_table_idx = tables.len() - 1; + let eval_at_loop = if outer { + cur_table_idx + } else { + get_rightmost_table_referenced_in_expr(&pred)? + }; + out_where_clause.push(JoinAwareConditionExpr { + expr: pred, + from_outer_join: outer, + eval_at_loop, + }); + } } ast::JoinConstraint::Using(distinct_names) => { // USING join is replaced with a list of equality predicates - let mut using_predicates = vec![]; for distinct_name in distinct_names.iter() { let name_normalized = normalize_ident(distinct_name.0.as_str()); - let left_tables = &tables[..table_index]; + let cur_table_idx = tables.len() - 1; + let left_tables = &tables[..cur_table_idx]; assert!(!left_tables.is_empty()); - let right_table = &tables[table_index]; + let right_table = tables.last().unwrap(); let mut left_col = None; for (left_table_idx, left_table) in left_tables.iter().enumerate() { left_col = left_table @@ -536,7 +539,7 @@ fn parse_join( } let (left_table_idx, left_col_idx, left_col) = left_col.unwrap(); let (right_col_idx, right_col) = right_col.unwrap(); - using_predicates.push(Expr::Binary( + let expr = Expr::Binary( Box::new(Expr::Column { database: None, table: left_table_idx, @@ -546,24 +549,33 @@ fn parse_join( ast::Operator::Equals, Box::new(Expr::Column { database: None, - table: right_table.table_index, + table: cur_table_idx, column: right_col_idx, is_rowid_alias: right_col.is_rowid_alias, }), - )); + ); + let eval_at_loop = if outer { + cur_table_idx + } else { + get_rightmost_table_referenced_in_expr(&expr)? + }; + out_where_clause.push(JoinAwareConditionExpr { + expr, + from_outer_join: outer, + eval_at_loop, + }); } - predicates = Some(using_predicates); using = Some(distinct_names); } } } - Ok(JoinParseResult { - source_operator, - is_outer_join: outer, - using, - predicates, - }) + assert!(tables.len() >= 2); + let last_idx = tables.len() - 1; + let rightmost_table = tables.get_mut(last_idx).unwrap(); + rightmost_table.join_info = Some(JoinInfo { outer, using }); + + Ok(()) } pub fn parse_limit(limit: Limit) -> Result<(Option, Option)> {