diff --git a/core/translate/main_loop.rs b/core/translate/main_loop.rs index ce1a11d2d..7ae7abd62 100644 --- a/core/translate/main_loop.rs +++ b/core/translate/main_loop.rs @@ -295,7 +295,7 @@ pub fn open_loop( t_ctx: &mut TranslateCtx, tables: &[TableReference], join_order: &[JoinOrderMember], - predicates: &mut [WhereTerm], + predicates: &[WhereTerm], ) -> Result<()> { for (join_index, join) in join_order.iter().enumerate() { let table_index = join.original_idx; @@ -406,7 +406,7 @@ pub fn open_loop( &t_ctx.resolver, )?; if cinfo.usable && usage.omit { - predicates[pred_idx].consumed = true; + predicates[pred_idx].consumed.set(true); } } } diff --git a/core/translate/optimizer/join.rs b/core/translate/optimizer/join.rs index 23889b749..b74e746ab 100644 --- a/core/translate/optimizer/join.rs +++ b/core/translate/optimizer/join.rs @@ -494,7 +494,7 @@ fn generate_join_bitmasks(table_number_max_exclusive: usize, how_many: usize) -> #[cfg(test)] mod tests { - use std::{rc::Rc, sync::Arc}; + use std::{cell::Cell, rc::Rc, sync::Arc}; use limbo_sqlite3_parser::ast::{self, Expr, Operator, SortOrder, TableInternalId}; @@ -1332,7 +1332,7 @@ mod tests { WhereTerm { expr: Expr::Binary(Box::new(lhs), op, Box::new(rhs)), from_outer_join: None, - consumed: false, + consumed: Cell::new(false), } } diff --git a/core/translate/optimizer/lift_common_subexpressions.rs b/core/translate/optimizer/lift_common_subexpressions.rs index b7302cb7e..af06f0c37 100644 --- a/core/translate/optimizer/lift_common_subexpressions.rs +++ b/core/translate/optimizer/lift_common_subexpressions.rs @@ -1,3 +1,5 @@ +use std::cell::Cell; + use limbo_sqlite3_parser::ast::{Expr, Operator}; use crate::{ @@ -112,7 +114,7 @@ pub(crate) fn lift_common_subexpressions_from_binary_or_terms( if found_non_empty_or_branches { // If we found an empty OR branch, we can remove the entire OR term. // E.g. (a AND b) OR (a) OR (a AND c) just becomes a. - where_clause[i].consumed = true; + where_clause[i].consumed.set(true); } else { assert!(new_or_operands_for_original_term.len() > 1); // Update the original WhereTerm's expression with the new OR structure (without common parts). @@ -124,7 +126,7 @@ pub(crate) fn lift_common_subexpressions_from_binary_or_terms( where_clause.push(WhereTerm { expr: common_expr_to_add, from_outer_join: term_from_outer_join, - consumed: false, + consumed: Cell::new(false), }); } @@ -255,7 +257,7 @@ mod tests { let mut where_clause = vec![WhereTerm { expr: or_expr, from_outer_join: None, - consumed: false, + consumed: Cell::new(false), }]; lift_common_subexpressions_from_binary_or_terms(&mut where_clause)?; @@ -266,7 +268,7 @@ mod tests { // 3. b = 1 let nonconsumed_terms = where_clause .iter() - .filter(|term| !term.consumed) + .filter(|term| !term.consumed.get()) .collect::>(); assert_eq!(nonconsumed_terms.len(), 3); assert_eq!( @@ -354,7 +356,7 @@ mod tests { let mut where_clause = vec![WhereTerm { expr: or_expr, from_outer_join: None, - consumed: false, + consumed: Cell::new(false), }]; lift_common_subexpressions_from_binary_or_terms(&mut where_clause)?; @@ -364,7 +366,7 @@ mod tests { // 2. a = 1 let nonconsumed_terms = where_clause .iter() - .filter(|term| !term.consumed) + .filter(|term| !term.consumed.get()) .collect::>(); assert_eq!(nonconsumed_terms.len(), 2); assert_eq!( @@ -421,7 +423,7 @@ mod tests { let mut where_clause = vec![WhereTerm { expr: or_expr.clone(), from_outer_join: None, - consumed: false, + consumed: Cell::new(false), }]; lift_common_subexpressions_from_binary_or_terms(&mut where_clause)?; @@ -429,7 +431,7 @@ mod tests { // Should remain unchanged since no common terms let nonconsumed_terms = where_clause .iter() - .filter(|term| !term.consumed) + .filter(|term| !term.consumed.get()) .collect::>(); assert_eq!(nonconsumed_terms.len(), 1); assert_eq!(nonconsumed_terms[0].expr, or_expr); @@ -488,7 +490,7 @@ mod tests { let mut where_clause = vec![WhereTerm { expr: or_expr, from_outer_join: Some(TableInternalId::default()), // Set from_outer_join - consumed: false, + consumed: Cell::new(false), }]; lift_common_subexpressions_from_binary_or_terms(&mut where_clause)?; @@ -496,7 +498,7 @@ mod tests { // Should have 2 terms, both with from_outer_join set let nonconsumed_terms = where_clause .iter() - .filter(|term| !term.consumed) + .filter(|term| !term.consumed.get()) .collect::>(); assert_eq!(nonconsumed_terms.len(), 2); assert_eq!( @@ -540,7 +542,7 @@ mod tests { let mut where_clause = vec![WhereTerm { expr: single_expr.clone(), from_outer_join: None, - consumed: false, + consumed: Cell::new(false), }]; lift_common_subexpressions_from_binary_or_terms(&mut where_clause)?; @@ -548,7 +550,7 @@ mod tests { // Should remain unchanged let nonconsumed_terms = where_clause .iter() - .filter(|term| !term.consumed) + .filter(|term| !term.consumed.get()) .collect::>(); assert_eq!(nonconsumed_terms.len(), 1); assert_eq!(nonconsumed_terms[0].expr, single_expr); @@ -593,14 +595,14 @@ mod tests { let mut where_clause = vec![WhereTerm { expr: or_expr, from_outer_join: None, - consumed: false, + consumed: Cell::new(false), }]; lift_common_subexpressions_from_binary_or_terms(&mut where_clause)?; let nonconsumed_terms = where_clause .iter() - .filter(|term| !term.consumed) + .filter(|term| !term.consumed.get()) .collect::>(); assert_eq!(nonconsumed_terms.len(), 1); assert_eq!(nonconsumed_terms[0].expr, a_expr); diff --git a/core/translate/optimizer/mod.rs b/core/translate/optimizer/mod.rs index 496d5ad38..bcb09a0a0 100644 --- a/core/translate/optimizer/mod.rs +++ b/core/translate/optimizer/mod.rs @@ -293,11 +293,13 @@ fn optimize_table_access( let constraint = &constraints_per_table[table_idx].constraints[cref.constraint_vec_pos]; assert!( - !where_clause[constraint.where_clause_pos.0].consumed, + !where_clause[constraint.where_clause_pos.0].consumed.get(), "trying to consume a where clause term twice: {:?}", where_clause[constraint.where_clause_pos.0] ); - where_clause[constraint.where_clause_pos.0].consumed = true; + where_clause[constraint.where_clause_pos.0] + .consumed + .set(true); } if let Some(index) = &access_method.index { table_references[table_idx].op = Operation::Search(Search::Seek { @@ -355,7 +357,7 @@ fn eliminate_constant_conditions( let predicate = &where_clause[i]; if predicate.expr.is_always_true()? { // true predicates can be removed since they don't affect the result - where_clause[i].consumed = true; + where_clause[i].consumed.set(true); i += 1; } else if predicate.expr.is_always_false()? { // any false predicate in a list of conjuncts (AND-ed predicates) will make the whole list false, @@ -366,7 +368,7 @@ fn eliminate_constant_conditions( } where_clause .iter_mut() - .for_each(|term| term.consumed = true); + .for_each(|term| term.consumed.set(true)); return Ok(ConstantConditionEliminationResult::ImpossibleCondition); } else { i += 1; diff --git a/core/translate/plan.rs b/core/translate/plan.rs index 0cb071d7e..08946086a 100644 --- a/core/translate/plan.rs +++ b/core/translate/plan.rs @@ -2,6 +2,7 @@ use core::fmt; use limbo_ext::{ConstraintInfo, ConstraintOp}; use limbo_sqlite3_parser::ast::{self, SortOrder}; use std::{ + cell::Cell, cmp::Ordering, fmt::{Display, Formatter}, rc::Rc, @@ -97,12 +98,16 @@ pub struct WhereTerm { /// A term may have been consumed e.g. if: /// - it has been converted into a constraint in a seek key /// - it has been removed due to being trivially true or false - pub consumed: bool, + /// + /// FIXME: this can be made into a simple `bool` once we move the virtual table constraint resolution + /// code out of `init_loop()`, because that's the only place that requires a mutable reference to the where clause + /// that causes problems to other code that needs immutable references to the where clause. + pub consumed: Cell, } impl WhereTerm { pub fn should_eval_before_loop(&self, join_order: &[JoinOrderMember]) -> bool { - if self.consumed { + if self.consumed.get() { return false; } let Ok(eval_at) = self.eval_at(join_order) else { @@ -112,7 +117,7 @@ impl WhereTerm { } pub fn should_eval_at_loop(&self, loop_idx: usize, join_order: &[JoinOrderMember]) -> bool { - if self.consumed { + if self.consumed.get() { return false; } let Ok(eval_at) = self.eval_at(join_order) else { diff --git a/core/translate/planner.rs b/core/translate/planner.rs index 5aaf98f01..28c7c1b0e 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -1,3 +1,5 @@ +use std::cell::Cell; + use super::{ expr::walk_expr, plan::{ @@ -514,7 +516,7 @@ pub fn parse_where( out_where_clause.push(WhereTerm { expr, from_outer_join: None, - consumed: false, + consumed: Cell::new(false), }); } Ok(()) @@ -775,7 +777,7 @@ fn parse_join<'a>( } else { None }, - consumed: false, + consumed: Cell::new(false), }); } } @@ -849,7 +851,7 @@ fn parse_join<'a>( } else { None }, - consumed: false, + consumed: Cell::new(false), }); } using = Some(distinct_names);