use std::{collections::HashMap, sync::Arc}; use limbo_sqlite3_parser::ast::{self, Expr, SortOrder}; use crate::{ schema::{Index, Schema}, util::exprs_are_equivalent, Result, }; use super::plan::{ DeletePlan, Direction, GroupBy, IterationDirection, Operation, Plan, Search, SelectPlan, TableReference, UpdatePlan, WhereTerm, }; pub fn optimize_plan(plan: &mut Plan, schema: &Schema) -> Result<()> { match plan { Plan::Select(plan) => optimize_select_plan(plan, schema), Plan::Delete(plan) => optimize_delete_plan(plan, schema), Plan::Update(plan) => optimize_update_plan(plan, schema), } } /** * Make a few passes over the plan to optimize it. * TODO: these could probably be done in less passes, * but having them separate makes them easier to understand */ fn optimize_select_plan(plan: &mut SelectPlan, schema: &Schema) -> Result<()> { optimize_subqueries(plan, schema)?; rewrite_exprs_select(plan)?; if let ConstantConditionEliminationResult::ImpossibleCondition = eliminate_constant_conditions(&mut plan.where_clause)? { plan.contains_constant_false_condition = true; return Ok(()); } use_indexes( &mut plan.table_references, &schema.indexes, &mut plan.where_clause, &mut plan.order_by, &plan.group_by, )?; eliminate_orderby_like_groupby(plan)?; Ok(()) } fn optimize_delete_plan(plan: &mut DeletePlan, schema: &Schema) -> Result<()> { rewrite_exprs_delete(plan)?; if let ConstantConditionEliminationResult::ImpossibleCondition = eliminate_constant_conditions(&mut plan.where_clause)? { plan.contains_constant_false_condition = true; return Ok(()); } use_indexes( &mut plan.table_references, &schema.indexes, &mut plan.where_clause, &mut plan.order_by, &None, )?; Ok(()) } fn optimize_update_plan(plan: &mut UpdatePlan, schema: &Schema) -> Result<()> { rewrite_exprs_update(plan)?; if let ConstantConditionEliminationResult::ImpossibleCondition = eliminate_constant_conditions(&mut plan.where_clause)? { plan.contains_constant_false_condition = true; return Ok(()); } use_indexes( &mut plan.table_references, &schema.indexes, &mut plan.where_clause, &mut plan.order_by, &None, )?; Ok(()) } fn optimize_subqueries(plan: &mut SelectPlan, schema: &Schema) -> Result<()> { for table in plan.table_references.iter_mut() { if let Operation::Subquery { plan, .. } = &mut table.op { optimize_select_plan(&mut *plan, schema)?; } } Ok(()) } fn eliminate_orderby_like_groupby(plan: &mut SelectPlan) -> Result<()> { if plan.order_by.is_none() | plan.group_by.is_none() { return Ok(()); } if plan.table_references.len() == 0 { return Ok(()); } let order_by_clauses = plan.order_by.as_mut().unwrap(); let group_by_clauses = plan.group_by.as_mut().unwrap(); let mut group_by_insert_position = 0; let mut order_index = 0; // This function optimizes query execution by eliminating duplicate expressions between ORDER BY and GROUP BY clauses // When the same column appears in both clauses, we can avoid redundant sorting operations // The function reorders GROUP BY expressions and removes redundant ORDER BY expressions to ensure consistent ordering while order_index < order_by_clauses.len() { let (order_expr, direction) = &order_by_clauses[order_index]; // Skip descending orders as they require separate sorting if matches!(direction, Direction::Descending) { order_index += 1; continue; } // Check if the current ORDER BY expression matches any expression in the GROUP BY clause if let Some(group_expr_position) = group_by_clauses .exprs .iter() .position(|expr| exprs_are_equivalent(expr, order_expr)) { // If we found a matching expression in GROUP BY, we need to ensure it's in the correct position // to preserve the ordering specified by ORDER BY clauses // Move the matching GROUP BY expression to the current insertion position // This effectively "bubbles up" the expression to maintain proper ordering if group_expr_position != group_by_insert_position { let mut current_position = group_expr_position; // Swap expressions to move the matching one to the correct position while current_position > group_by_insert_position { group_by_clauses .exprs .swap(current_position, current_position - 1); current_position -= 1; } } group_by_insert_position += 1; // Remove this expression from ORDER BY since it's now handled by GROUP BY order_by_clauses.remove(order_index); // Note: We don't increment order_index here because removal shifts all elements } else { // If not found in GROUP BY, move to next ORDER BY expression order_index += 1; } } if order_by_clauses.is_empty() { plan.order_by = None } Ok(()) } fn eliminate_unnecessary_orderby( table_references: &mut [TableReference], available_indexes: &HashMap>>, order_by: &mut Option>, group_by: &Option, ) -> Result<()> { let Some(order) = order_by else { return Ok(()); }; let Some(first_table_reference) = table_references.first_mut() else { return Ok(()); }; let Some(btree_table) = first_table_reference.btree() else { return Ok(()); }; // If GROUP BY clause is present, we can't rely on already ordered columns because GROUP BY reorders the data // This early return prevents the elimination of ORDER BY when GROUP BY exists, as sorting must be applied after grouping // And if ORDER BY clause duplicates GROUP BY we handle it later in fn eliminate_orderby_like_groupby if group_by.is_some() { return Ok(()); } let Operation::Scan { index, iter_dir, .. } = &mut first_table_reference.op else { return Ok(()); }; assert!( index.is_none(), "Nothing shouldve transformed the scan to use an index yet" ); // Special case: if ordering by just the rowid, we can remove the ORDER BY clause if order.len() == 1 && order[0].0.is_rowid_alias_of(0) { *iter_dir = match order[0].1 { Direction::Ascending => IterationDirection::Forwards, Direction::Descending => IterationDirection::Backwards, }; *order_by = None; return Ok(()); } // Find the best matching index for the ORDER BY columns let table_name = &btree_table.name; let mut best_index = (None, 0); for (_, indexes) in available_indexes.iter() { for index_candidate in indexes.iter().filter(|i| &i.table_name == table_name) { let matching_columns = index_candidate.columns.iter().enumerate().take_while(|(i, c)| { if let Some((Expr::Column { table, column, .. }, _)) = order.get(*i) { let col_idx_in_table = btree_table .columns .iter() .position(|tc| tc.name.as_ref() == Some(&c.name)); matches!(col_idx_in_table, Some(col_idx) if *table == 0 && *column == col_idx) } else { false } }).count(); if matching_columns > best_index.1 { best_index = (Some(index_candidate), matching_columns); } } } let Some(matching_index) = best_index.0 else { return Ok(()); }; let match_count = best_index.1; // If we found a matching index, use it for scanning *index = Some(matching_index.clone()); // If the order by direction matches the index direction, we can iterate the index in forwards order. // If they don't, we must iterate the index in backwards order. let index_direction = &matching_index.columns.first().as_ref().unwrap().order; *iter_dir = match (index_direction, order[0].1) { (SortOrder::Asc, Direction::Ascending) | (SortOrder::Desc, Direction::Descending) => { IterationDirection::Forwards } (SortOrder::Asc, Direction::Descending) | (SortOrder::Desc, Direction::Ascending) => { IterationDirection::Backwards } }; // If the index covers all ORDER BY columns, and one of the following applies: // - the ORDER BY directions exactly match the index orderings, // - the ORDER by directions are the exact opposite of the index orderings, // we can remove the ORDER BY clause. if match_count == order.len() { let full_match = { let mut all_match_forward = true; let mut all_match_reverse = true; for (i, (_, direction)) in order.iter().enumerate() { match (&matching_index.columns[i].order, direction) { (SortOrder::Asc, Direction::Ascending) | (SortOrder::Desc, Direction::Descending) => { all_match_reverse = false; } (SortOrder::Asc, Direction::Descending) | (SortOrder::Desc, Direction::Ascending) => { all_match_forward = false; } } } all_match_forward || all_match_reverse }; if full_match { *order_by = None; } } Ok(()) } /** * Use indexes where possible. * * When this function is called, condition expressions from both the actual WHERE clause and the JOIN clauses are in the where_clause vector. * If we find a condition that can be used to index scan, we pop it off from the where_clause vector and put it into a Search operation. * We put it there simply because it makes it a bit easier to track during translation. * * In this function we also try to eliminate ORDER BY clauses if there is an index that satisfies the ORDER BY clause. */ fn use_indexes( table_references: &mut [TableReference], available_indexes: &HashMap>>, where_clause: &mut Vec, order_by: &mut Option>, group_by: &Option, ) -> Result<()> { // Try to use indexes for eliminating ORDER BY clauses eliminate_unnecessary_orderby(table_references, available_indexes, order_by, group_by)?; // Try to use indexes for WHERE conditions 'outer: for (table_index, table_reference) in table_references.iter_mut().enumerate() { if let Operation::Scan { iter_dir, .. } = &table_reference.op { let mut i = 0; while i < where_clause.len() { let cond = where_clause.get_mut(i).unwrap(); if let Some(index_search) = try_extract_index_search_expression( cond, table_index, table_reference, available_indexes, *iter_dir, )? { where_clause.remove(i); table_reference.op = Operation::Search(index_search); continue 'outer; } i += 1; } } } Ok(()) } #[derive(Debug, PartialEq, Clone)] enum ConstantConditionEliminationResult { Continue, ImpossibleCondition, } /// Removes predicates that are always true. /// Returns a ConstantEliminationResult indicating whether any predicates are always false. /// This is used to determine whether the query can be aborted early. fn eliminate_constant_conditions( where_clause: &mut Vec, ) -> Result { let mut i = 0; while i < where_clause.len() { let predicate = &where_clause[i]; if predicate.expr.is_always_true()? { // true predicates can be removed since they don't affect the result where_clause.remove(i); } else if predicate.expr.is_always_false()? { // any false predicate in a list of conjuncts (AND-ed predicates) will make the whole list false, // except an outer join condition, because that just results in NULLs, not skipping the whole loop if predicate.from_outer_join { i += 1; continue; } where_clause.truncate(0); return Ok(ConstantConditionEliminationResult::ImpossibleCondition); } else { i += 1; } } Ok(ConstantConditionEliminationResult::Continue) } fn rewrite_exprs_select(plan: &mut SelectPlan) -> Result<()> { for rc in plan.result_columns.iter_mut() { rewrite_expr(&mut rc.expr)?; } for agg in plan.aggregates.iter_mut() { rewrite_expr(&mut agg.original_expr)?; } for cond in plan.where_clause.iter_mut() { rewrite_expr(&mut cond.expr)?; } if let Some(group_by) = &mut plan.group_by { for expr in group_by.exprs.iter_mut() { rewrite_expr(expr)?; } } if let Some(order_by) = &mut plan.order_by { for (expr, _) in order_by.iter_mut() { rewrite_expr(expr)?; } } Ok(()) } fn rewrite_exprs_delete(plan: &mut DeletePlan) -> Result<()> { for cond in plan.where_clause.iter_mut() { rewrite_expr(&mut cond.expr)?; } Ok(()) } fn rewrite_exprs_update(plan: &mut UpdatePlan) -> Result<()> { if let Some(rc) = plan.returning.as_mut() { for rc in rc.iter_mut() { rewrite_expr(&mut rc.expr)?; } } for (_, expr) in plan.set_clauses.iter_mut() { rewrite_expr(expr)?; } for cond in plan.where_clause.iter_mut() { rewrite_expr(&mut cond.expr)?; } if let Some(order_by) = &mut plan.order_by { for (expr, _) in order_by.iter_mut() { rewrite_expr(expr)?; } } Ok(()) } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum ConstantPredicate { AlwaysTrue, AlwaysFalse, } /** Helper trait for expressions that can be optimized Implemented for ast::Expr */ pub trait Optimizable { // if the expression is a constant expression e.g. '1', returns the constant condition fn check_constant(&self) -> Result>; fn is_always_true(&self) -> Result { Ok(self .check_constant()? .map_or(false, |c| c == ConstantPredicate::AlwaysTrue)) } fn is_always_false(&self) -> Result { Ok(self .check_constant()? .map_or(false, |c| c == ConstantPredicate::AlwaysFalse)) } fn is_rowid_alias_of(&self, table_index: usize) -> bool; fn check_index_scan( &mut self, table_index: usize, table_reference: &TableReference, available_indexes: &HashMap>>, ) -> Result>>; } impl Optimizable for ast::Expr { fn is_rowid_alias_of(&self, table_index: usize) -> bool { match self { Self::Column { table, is_rowid_alias, .. } => *is_rowid_alias && *table == table_index, _ => false, } } fn check_index_scan( &mut self, table_index: usize, table_reference: &TableReference, available_indexes: &HashMap>>, ) -> Result>> { match self { Self::Column { table, column, .. } => { if *table != table_index { return Ok(None); } let Some(available_indexes_for_table) = available_indexes.get(table_reference.table.get_name()) else { return Ok(None); }; let Some(column) = table_reference.table.get_column_at(*column) else { return Ok(None); }; for index in available_indexes_for_table.iter() { if let Some(name) = column.name.as_ref() { if &index.columns.first().unwrap().name == name { return Ok(Some(index.clone())); } } } Ok(None) } Self::Binary(lhs, op, rhs) => { // Only consider index scans for binary ops that are comparisons. // e.g. "t1.id = t2.id" is a valid index scan, but "t1.id + 1" is not. // // TODO/optimization: consider detecting index scan on e.g. table t1 in // "WHERE t1.id + 1 = t2.id" // here the Expr could be rewritten to "t1.id = t2.id - 1" // and then t1.id could be used as an index key. if !matches!( *op, ast::Operator::Equals | ast::Operator::Greater | ast::Operator::GreaterEquals | ast::Operator::Less | ast::Operator::LessEquals ) { return Ok(None); } let lhs_index = lhs.check_index_scan(table_index, &table_reference, available_indexes)?; if lhs_index.is_some() { return Ok(lhs_index); } let rhs_index = rhs.check_index_scan(table_index, &table_reference, available_indexes)?; if rhs_index.is_some() { // swap lhs and rhs let swapped_operator = match *op { ast::Operator::Equals => ast::Operator::Equals, ast::Operator::Greater => ast::Operator::Less, ast::Operator::GreaterEquals => ast::Operator::LessEquals, ast::Operator::Less => ast::Operator::Greater, ast::Operator::LessEquals => ast::Operator::GreaterEquals, _ => unreachable!(), }; let lhs_new = rhs.take_ownership(); let rhs_new = lhs.take_ownership(); *self = Self::Binary(Box::new(lhs_new), swapped_operator, Box::new(rhs_new)); return Ok(rhs_index); } Ok(None) } _ => Ok(None), } } fn check_constant(&self) -> Result> { match self { Self::Literal(lit) => match lit { ast::Literal::Numeric(b) => { if let Ok(int_value) = b.parse::() { return Ok(Some(if int_value == 0 { ConstantPredicate::AlwaysFalse } else { ConstantPredicate::AlwaysTrue })); } if let Ok(float_value) = b.parse::() { return Ok(Some(if float_value == 0.0 { ConstantPredicate::AlwaysFalse } else { ConstantPredicate::AlwaysTrue })); } Ok(None) } ast::Literal::String(s) => { let without_quotes = s.trim_matches('\''); if let Ok(int_value) = without_quotes.parse::() { return Ok(Some(if int_value == 0 { ConstantPredicate::AlwaysFalse } else { ConstantPredicate::AlwaysTrue })); } if let Ok(float_value) = without_quotes.parse::() { return Ok(Some(if float_value == 0.0 { ConstantPredicate::AlwaysFalse } else { ConstantPredicate::AlwaysTrue })); } Ok(Some(ConstantPredicate::AlwaysFalse)) } _ => Ok(None), }, Self::Unary(op, expr) => { if *op == ast::UnaryOperator::Not { let trivial = expr.check_constant()?; return Ok(trivial.map(|t| match t { ConstantPredicate::AlwaysTrue => ConstantPredicate::AlwaysFalse, ConstantPredicate::AlwaysFalse => ConstantPredicate::AlwaysTrue, })); } if *op == ast::UnaryOperator::Negative { let trivial = expr.check_constant()?; return Ok(trivial); } Ok(None) } Self::InList { lhs: _, not, rhs } => { if rhs.is_none() { return Ok(Some(if *not { ConstantPredicate::AlwaysTrue } else { ConstantPredicate::AlwaysFalse })); } let rhs = rhs.as_ref().unwrap(); if rhs.is_empty() { return Ok(Some(if *not { ConstantPredicate::AlwaysTrue } else { ConstantPredicate::AlwaysFalse })); } Ok(None) } Self::Binary(lhs, op, rhs) => { let lhs_trivial = lhs.check_constant()?; let rhs_trivial = rhs.check_constant()?; match op { ast::Operator::And => { if lhs_trivial == Some(ConstantPredicate::AlwaysFalse) || rhs_trivial == Some(ConstantPredicate::AlwaysFalse) { return Ok(Some(ConstantPredicate::AlwaysFalse)); } if lhs_trivial == Some(ConstantPredicate::AlwaysTrue) && rhs_trivial == Some(ConstantPredicate::AlwaysTrue) { return Ok(Some(ConstantPredicate::AlwaysTrue)); } Ok(None) } ast::Operator::Or => { if lhs_trivial == Some(ConstantPredicate::AlwaysTrue) || rhs_trivial == Some(ConstantPredicate::AlwaysTrue) { return Ok(Some(ConstantPredicate::AlwaysTrue)); } if lhs_trivial == Some(ConstantPredicate::AlwaysFalse) && rhs_trivial == Some(ConstantPredicate::AlwaysFalse) { return Ok(Some(ConstantPredicate::AlwaysFalse)); } Ok(None) } _ => Ok(None), } } _ => Ok(None), } } } fn opposite_cmp_op(op: ast::Operator) -> ast::Operator { match op { ast::Operator::Equals => ast::Operator::Equals, ast::Operator::Greater => ast::Operator::Less, ast::Operator::GreaterEquals => ast::Operator::LessEquals, ast::Operator::Less => ast::Operator::Greater, ast::Operator::LessEquals => ast::Operator::GreaterEquals, _ => panic!("unexpected operator: {:?}", op), } } pub fn try_extract_index_search_expression( cond: &mut WhereTerm, table_index: usize, table_reference: &TableReference, available_indexes: &HashMap>>, iter_dir: IterationDirection, ) -> Result> { if !cond.should_eval_at_loop(table_index) { return Ok(None); } match &mut cond.expr { ast::Expr::Binary(lhs, operator, rhs) => { if lhs.is_rowid_alias_of(table_index) { match operator { ast::Operator::Equals => { let rhs_owned = rhs.take_ownership(); return Ok(Some(Search::RowidEq { cmp_expr: WhereTerm { expr: rhs_owned, from_outer_join: cond.from_outer_join, eval_at: cond.eval_at, }, })); } ast::Operator::Greater | ast::Operator::GreaterEquals | ast::Operator::Less | ast::Operator::LessEquals => { let rhs_owned = rhs.take_ownership(); return Ok(Some(Search::RowidSearch { cmp_op: *operator, cmp_expr: WhereTerm { expr: rhs_owned, from_outer_join: cond.from_outer_join, eval_at: cond.eval_at, }, iter_dir, })); } _ => {} } } if rhs.is_rowid_alias_of(table_index) { match operator { ast::Operator::Equals => { let lhs_owned = lhs.take_ownership(); return Ok(Some(Search::RowidEq { cmp_expr: WhereTerm { expr: lhs_owned, from_outer_join: cond.from_outer_join, eval_at: cond.eval_at, }, })); } ast::Operator::Greater | ast::Operator::GreaterEquals | ast::Operator::Less | ast::Operator::LessEquals => { let lhs_owned = lhs.take_ownership(); return Ok(Some(Search::RowidSearch { cmp_op: opposite_cmp_op(*operator), cmp_expr: WhereTerm { expr: lhs_owned, from_outer_join: cond.from_outer_join, eval_at: cond.eval_at, }, iter_dir, })); } _ => {} } } if let Some(index_rc) = lhs.check_index_scan(table_index, &table_reference, available_indexes)? { match operator { ast::Operator::Equals | ast::Operator::Greater | ast::Operator::GreaterEquals | ast::Operator::Less | ast::Operator::LessEquals => { let rhs_owned = rhs.take_ownership(); return Ok(Some(Search::IndexSearch { index: index_rc, cmp_op: *operator, cmp_expr: WhereTerm { expr: rhs_owned, from_outer_join: cond.from_outer_join, eval_at: cond.eval_at, }, iter_dir, })); } _ => {} } } if let Some(index_rc) = rhs.check_index_scan(table_index, &table_reference, available_indexes)? { match operator { ast::Operator::Equals | ast::Operator::Greater | ast::Operator::GreaterEquals | ast::Operator::Less | ast::Operator::LessEquals => { let lhs_owned = lhs.take_ownership(); return Ok(Some(Search::IndexSearch { index: index_rc, cmp_op: opposite_cmp_op(*operator), cmp_expr: WhereTerm { expr: lhs_owned, from_outer_join: cond.from_outer_join, eval_at: cond.eval_at, }, iter_dir, })); } _ => {} } } Ok(None) } _ => Ok(None), } } fn rewrite_expr(expr: &mut ast::Expr) -> Result<()> { match expr { ast::Expr::Id(id) => { // Convert "true" and "false" to 1 and 0 if id.0.eq_ignore_ascii_case("true") { *expr = ast::Expr::Literal(ast::Literal::Numeric(1.to_string())); return Ok(()); } if id.0.eq_ignore_ascii_case("false") { *expr = ast::Expr::Literal(ast::Literal::Numeric(0.to_string())); return Ok(()); } Ok(()) } ast::Expr::Between { lhs, not, start, end, } => { // Convert `y NOT BETWEEN x AND z` to `x > y OR y > z` let (lower_op, upper_op) = if *not { (ast::Operator::Greater, ast::Operator::Greater) } else { // Convert `y BETWEEN x AND z` to `x <= y AND y <= z` (ast::Operator::LessEquals, ast::Operator::LessEquals) }; rewrite_expr(start)?; rewrite_expr(lhs)?; rewrite_expr(end)?; let start = start.take_ownership(); let lhs = lhs.take_ownership(); let end = end.take_ownership(); let lower_bound = ast::Expr::Binary(Box::new(start), lower_op, Box::new(lhs.clone())); let upper_bound = ast::Expr::Binary(Box::new(lhs), upper_op, Box::new(end)); if *not { *expr = ast::Expr::Binary( Box::new(lower_bound), ast::Operator::Or, Box::new(upper_bound), ); } else { *expr = ast::Expr::Binary( Box::new(lower_bound), ast::Operator::And, Box::new(upper_bound), ); } Ok(()) } ast::Expr::Parenthesized(ref mut exprs) => { for subexpr in exprs.iter_mut() { rewrite_expr(subexpr)?; } let exprs = std::mem::take(exprs); *expr = ast::Expr::Parenthesized(exprs); Ok(()) } // Process other expressions recursively ast::Expr::Binary(lhs, _, rhs) => { rewrite_expr(lhs)?; rewrite_expr(rhs)?; Ok(()) } ast::Expr::FunctionCall { args, .. } => { if let Some(args) = args { for arg in args.iter_mut() { rewrite_expr(arg)?; } } Ok(()) } ast::Expr::Unary(_, arg) => { rewrite_expr(arg)?; Ok(()) } _ => Ok(()), } } trait TakeOwnership { fn take_ownership(&mut self) -> Self; } impl TakeOwnership for ast::Expr { fn take_ownership(&mut self) -> Self { std::mem::replace(self, ast::Expr::Literal(ast::Literal::Null)) } }