planner.rs: refactor from/join + where parsing logic

- use new TableReference and JoinAwareConditionExpr
- add utilities for determining at which loop depth a
  WHERE condition should be evaluated, now that "operators"
  do not carry condition expressions inside them anymore.
This commit is contained in:
Jussi Saurio
2025-02-01 22:31:27 +02:00
parent e63256f657
commit 16a97d3b98

View File

@@ -1,5 +1,8 @@
use super::{
plan::{Aggregate, Plan, SelectQueryType, SourceOperator, TableReference, TableReferenceType},
plan::{
Aggregate, JoinAwareConditionExpr, JoinInfo, Operation, Plan, SelectQueryType,
TableReference,
},
select::prepare_select_plan,
SymbolTable,
};
@@ -14,21 +17,6 @@ use sqlite3_parser::ast::{self, Expr, FromClause, JoinType, Limit, UnaryOperator
pub const ROWID: &str = "rowid";
/// Hands out sequential operator ids, starting at 1.
#[derive(Debug)]
pub struct OperatorIdCounter {
    /// The id that the next call to [`OperatorIdCounter::get_next_id`] returns.
    id: usize,
}

impl OperatorIdCounter {
    /// Creates a counter whose first issued id is 1.
    pub fn new() -> Self {
        Self { id: 1 }
    }

    /// Returns the current id and advances the counter by one.
    pub fn get_next_id(&mut self) -> usize {
        let id = self.id;
        self.id += 1;
        id
    }
}

impl Default for OperatorIdCounter {
    /// Same as [`OperatorIdCounter::new`]: the first issued id is 1.
    fn default() -> Self {
        Self::new()
    }
}
pub fn resolve_aggregates(expr: &Expr, aggs: &mut Vec<Aggregate>) -> bool {
if aggs
.iter()
@@ -140,10 +128,9 @@ pub fn bind_column_references(expr: &mut Expr, referenced_tables: &[TableReferen
}
Expr::Qualified(tbl, id) => {
let normalized_table_name = normalize_ident(tbl.0.as_str());
let matching_tbl_idx = referenced_tables.iter().position(|t| {
t.table_identifier
.eq_ignore_ascii_case(&normalized_table_name)
});
let matching_tbl_idx = referenced_tables
.iter()
.position(|t| t.identifier.eq_ignore_ascii_case(&normalized_table_name));
if matching_tbl_idx.is_none() {
crate::bail_parse_error!("Table {} not found", normalized_table_name);
}
@@ -273,10 +260,9 @@ pub fn bind_column_references(expr: &mut Expr, referenced_tables: &[TableReferen
fn parse_from_clause_table(
schema: &Schema,
table: ast::SelectTable,
operator_id_counter: &mut OperatorIdCounter,
cur_table_index: usize,
syms: &SymbolTable,
) -> Result<(TableReference, SourceOperator)> {
) -> Result<TableReference> {
match table {
ast::SelectTable::Table(qualified_name, maybe_alias, _) => {
let normalized_qualified_name = normalize_ident(qualified_name.name.0.as_str());
@@ -289,21 +275,12 @@ fn parse_from_clause_table(
ast::As::Elided(id) => id,
})
.map(|a| a.0);
let table_reference = TableReference {
Ok(TableReference {
op: Operation::Scan { iter_dir: None },
table: Table::BTree(table.clone()),
table_identifier: alias.unwrap_or(normalized_qualified_name),
table_index: cur_table_index,
reference_type: TableReferenceType::BTreeTable,
};
Ok((
table_reference.clone(),
SourceOperator::Scan {
table_reference,
predicates: None,
id: operator_id_counter.get_next_id(),
iter_dir: None,
},
))
identifier: alias.unwrap_or(normalized_qualified_name),
join_info: None,
})
}
ast::SelectTable::Select(subselect, maybe_alias) => {
let Plan::Select(mut subplan) = prepare_select_plan(schema, *subselect, syms)? else {
@@ -319,17 +296,8 @@ fn parse_from_clause_table(
ast::As::Elided(id) => id.0.clone(),
})
.unwrap_or(format!("subquery_{}", cur_table_index));
let table_reference =
TableReference::new_subquery(identifier.clone(), cur_table_index, &subplan);
Ok((
table_reference.clone(),
SourceOperator::Subquery {
id: operator_id_counter.get_next_id(),
table_reference,
plan: Box::new(subplan),
predicates: None,
},
))
let table_reference = TableReference::new_subquery(identifier, subplan, None);
Ok(table_reference)
}
_ => todo!(),
}
@@ -338,99 +306,124 @@ fn parse_from_clause_table(
/// Parses the FROM clause (first table plus any joins) into a list of table references.
///
/// NOTE(review): this span is a rendered diff, so the removed and the added variant of
/// several lines (signature, early return, join loop, final return) appear side by side.
/// In the added variant the function returns an empty list when there is no FROM clause
/// or it has no `select`, parses the first table with table index 0, and then hands every
/// join to `parse_join` together with `tables` and `out_where_clause`.
pub fn parse_from(
schema: &Schema,
mut from: Option<FromClause>,
operator_id_counter: &mut OperatorIdCounter,
syms: &SymbolTable,
) -> Result<(SourceOperator, Vec<TableReference>)> {
out_where_clause: &mut Vec<JoinAwareConditionExpr>,
) -> Result<Vec<TableReference>> {
// Nothing to parse: no FROM clause, or a FROM clause without a `select`.
if from.as_ref().and_then(|f| f.select.as_ref()).is_none() {
return Ok((
SourceOperator::Nothing {
id: operator_id_counter.get_next_id(),
},
vec![],
));
return Ok(vec![]);
}
let mut table_index = 0;
let mut tables = vec![];
// Move the pieces out of `from` so they can be consumed by value below.
let mut from_owned = std::mem::take(&mut from).unwrap();
let select_owned = *std::mem::take(&mut from_owned.select).unwrap();
let joins_owned = std::mem::take(&mut from_owned.joins).unwrap_or_default();
let (table_reference, mut operator) =
parse_from_clause_table(schema, select_owned, operator_id_counter, table_index, syms)?;
// The first table of the FROM clause always gets index 0.
let table_reference = parse_from_clause_table(schema, select_owned, 0, syms)?;
tables.push(table_reference);
table_index += 1;
for join in joins_owned.into_iter() {
let JoinParseResult {
source_operator: right,
is_outer_join: outer,
using,
predicates,
} = parse_join(
schema,
join,
operator_id_counter,
&mut tables,
table_index,
syms,
)?;
operator = SourceOperator::Join {
left: Box::new(operator),
right: Box::new(right),
predicates,
outer,
using,
id: operator_id_counter.get_next_id(),
};
table_index += 1;
parse_join(schema, join, syms, &mut tables, out_where_clause)?;
}
Ok((operator, tables))
Ok(tables)
}
/// Breaks the optional WHERE expression into separate predicates (via
/// `break_predicate_at_and_boundaries`) and binds their column references.
///
/// NOTE(review): this span is a rendered diff, so the removed and the added variant of
/// several lines (signature, bind call, return values) appear side by side.
/// In the added variant every predicate is pushed to `out_where_clause` with
/// `from_outer_join: false` and `eval_at_loop` set to the result of
/// `get_rightmost_table_referenced_in_expr`; without a WHERE clause nothing is pushed.
pub fn parse_where(
where_clause: Option<Expr>,
referenced_tables: &[TableReference],
) -> Result<Option<Vec<Expr>>> {
table_references: &[TableReference],
out_where_clause: &mut Vec<JoinAwareConditionExpr>,
) -> Result<()> {
if let Some(where_expr) = where_clause {
let mut predicates = vec![];
break_predicate_at_and_boundaries(where_expr, &mut predicates);
for expr in predicates.iter_mut() {
bind_column_references(expr, referenced_tables)?;
bind_column_references(expr, table_references)?;
}
Ok(Some(predicates))
for expr in predicates {
// Computed after binding, so column references already carry table indexes.
let eval_at_loop = get_rightmost_table_referenced_in_expr(&expr)?;
out_where_clause.push(JoinAwareConditionExpr {
expr,
from_outer_join: false,
eval_at_loop,
});
}
Ok(())
} else {
Ok(None)
Ok(())
}
}
struct JoinParseResult {
source_operator: SourceOperator,
is_outer_join: bool,
using: Option<ast::DistinctNames>,
predicates: Option<Vec<Expr>>,
/**
Returns the rightmost table index that is referenced in the given AST expression.
Rightmost = innermost loop.
This is used to determine where we should evaluate a given condition expression,
and it needs to be the rightmost table referenced in the expression, because otherwise
the condition would be evaluated before a row is read from that table.

Wrapper expressions (unary operators, parentheses, IS NULL / NOT NULL, BETWEEN, CASE,
CAST, COLLATE, the LIKE escape operand) are searched recursively, so a column nested
inside them still pushes the condition to the loop of its table.
Expressions that reference no table at all yield 0.
The function currently never returns an error; the `Result` is kept for its callers.
*/
fn get_rightmost_table_referenced_in_expr(predicate: &ast::Expr) -> Result<usize> {
    /// Largest table index referenced by any of `exprs`; 0 when `exprs` is empty.
    fn rightmost_of<'e>(exprs: impl IntoIterator<Item = &'e ast::Expr>) -> Result<usize> {
        let mut max_table_idx = 0;
        for expr in exprs {
            max_table_idx = max_table_idx.max(get_rightmost_table_referenced_in_expr(expr)?);
        }
        Ok(max_table_idx)
    }

    match predicate {
        ast::Expr::Column { table, .. } => Ok(*table),
        // An Id referring to a column will already have been rewritten as an Expr::Column;
        // we only get here with literal 'true' or 'false' etc.
        ast::Expr::Id(_) | ast::Expr::Literal(_) => Ok(0),
        ast::Expr::Qualified(_, _) => {
            unreachable!("Qualified should be resolved to a Column before optimizer")
        }
        ast::Expr::Binary(lhs, _, rhs) => rightmost_of([&**lhs, &**rhs]),
        ast::Expr::Like {
            lhs, rhs, escape, ..
        } => rightmost_of(
            std::iter::once(&**lhs)
                .chain(std::iter::once(&**rhs))
                .chain(escape.as_deref()),
        ),
        ast::Expr::Between {
            lhs, start, end, ..
        } => rightmost_of([&**lhs, &**start, &**end]),
        ast::Expr::InList { lhs, rhs, .. } => {
            // `rhs` is optional; an absent list contributes nothing.
            rightmost_of(std::iter::once(&**lhs).chain(rhs.iter().flatten()))
        }
        ast::Expr::FunctionCall {
            args: Some(args), ..
        } => rightmost_of(args),
        ast::Expr::Parenthesized(exprs) => rightmost_of(exprs),
        ast::Expr::Case {
            base,
            when_then_pairs,
            else_expr,
        } => rightmost_of(
            base.as_deref()
                .into_iter()
                .chain(when_then_pairs.iter().flat_map(|(when, then)| [when, then]))
                .chain(else_expr.as_deref()),
        ),
        // Single-operand wrappers: the answer is whatever the wrapped expression references.
        ast::Expr::Unary(_, inner)
        | ast::Expr::IsNull(inner)
        | ast::Expr::NotNull(inner)
        | ast::Expr::Collate(inner, _)
        | ast::Expr::Cast { expr: inner, .. } => get_rightmost_table_referenced_in_expr(inner),
        // Every other expression kind is treated as referencing no table, as before.
        _ => Ok(0),
    }
}
fn parse_join(
schema: &Schema,
join: ast::JoinedSelectTable,
operator_id_counter: &mut OperatorIdCounter,
tables: &mut Vec<TableReference>,
table_index: usize,
syms: &SymbolTable,
) -> Result<JoinParseResult> {
tables: &mut Vec<TableReference>,
out_where_clause: &mut Vec<JoinAwareConditionExpr>,
) -> Result<()> {
let ast::JoinedSelectTable {
operator: join_operator,
table,
constraint,
} = join;
let (table_reference, source_operator) =
parse_from_clause_table(schema, table, operator_id_counter, table_index, syms)?;
tables.push(table_reference);
let cur_table_index = tables.len();
tables.push(parse_from_clause_table(
schema,
table,
cur_table_index,
syms,
)?);
let (outer, natural) = match join_operator {
ast::JoinOperator::TypedJoin(Some(join_type)) => {
@@ -442,23 +435,21 @@ fn parse_join(
};
let mut using = None;
let mut predicates = None;
if natural && constraint.is_some() {
crate::bail_parse_error!("NATURAL JOIN cannot be combined with ON or USING clause");
}
let constraint = if natural {
assert!(tables.len() >= 2);
let rightmost_table = tables.last().unwrap();
// NATURAL JOIN is first transformed into a USING join with the common columns
let left_tables = &tables[..table_index];
assert!(!left_tables.is_empty());
let right_table = &tables[table_index];
let right_cols = &right_table.columns();
let right_cols = rightmost_table.columns();
let mut distinct_names: Option<ast::DistinctNames> = None;
// TODO: O(n^2) maybe not great for large tables or big multiway joins
for right_col in right_cols.iter() {
let mut found_match = false;
for left_table in left_tables.iter() {
for left_table in tables.iter().take(tables.len() - 1) {
for left_col in left_table.columns().iter() {
if left_col.name == right_col.name {
if let Some(distinct_names) = distinct_names.as_mut() {
@@ -495,16 +486,28 @@ fn parse_join(
for predicate in preds.iter_mut() {
bind_column_references(predicate, tables)?;
}
predicates = Some(preds);
for pred in preds {
let cur_table_idx = tables.len() - 1;
let eval_at_loop = if outer {
cur_table_idx
} else {
get_rightmost_table_referenced_in_expr(&pred)?
};
out_where_clause.push(JoinAwareConditionExpr {
expr: pred,
from_outer_join: outer,
eval_at_loop,
});
}
}
ast::JoinConstraint::Using(distinct_names) => {
// USING join is replaced with a list of equality predicates
let mut using_predicates = vec![];
for distinct_name in distinct_names.iter() {
let name_normalized = normalize_ident(distinct_name.0.as_str());
let left_tables = &tables[..table_index];
let cur_table_idx = tables.len() - 1;
let left_tables = &tables[..cur_table_idx];
assert!(!left_tables.is_empty());
let right_table = &tables[table_index];
let right_table = tables.last().unwrap();
let mut left_col = None;
for (left_table_idx, left_table) in left_tables.iter().enumerate() {
left_col = left_table
@@ -536,7 +539,7 @@ fn parse_join(
}
let (left_table_idx, left_col_idx, left_col) = left_col.unwrap();
let (right_col_idx, right_col) = right_col.unwrap();
using_predicates.push(Expr::Binary(
let expr = Expr::Binary(
Box::new(Expr::Column {
database: None,
table: left_table_idx,
@@ -546,24 +549,33 @@ fn parse_join(
ast::Operator::Equals,
Box::new(Expr::Column {
database: None,
table: right_table.table_index,
table: cur_table_idx,
column: right_col_idx,
is_rowid_alias: right_col.is_rowid_alias,
}),
));
);
let eval_at_loop = if outer {
cur_table_idx
} else {
get_rightmost_table_referenced_in_expr(&expr)?
};
out_where_clause.push(JoinAwareConditionExpr {
expr,
from_outer_join: outer,
eval_at_loop,
});
}
predicates = Some(using_predicates);
using = Some(distinct_names);
}
}
}
Ok(JoinParseResult {
source_operator,
is_outer_join: outer,
using,
predicates,
})
assert!(tables.len() >= 2);
let last_idx = tables.len() - 1;
let rightmost_table = tables.get_mut(last_idx).unwrap();
rightmost_table.join_info = Some(JoinInfo { outer, using });
Ok(())
}
pub fn parse_limit(limit: Limit) -> Result<(Option<isize>, Option<isize>)> {