diff --git a/core/translate/optimizer.rs b/core/translate/optimizer.rs index f71cb9728..7c29b8834 100644 --- a/core/translate/optimizer.rs +++ b/core/translate/optimizer.rs @@ -797,6 +797,47 @@ enum BinaryExprSide { Rhs, } +/// Recursively unwrap parentheses from an expression +/// e.g. (((t.x > 5))) -> t.x > 5 +fn unwrap_parens(expr: T) -> Result +where + T: UnwrapParens, +{ + expr.unwrap_parens() +} + +trait UnwrapParens { + fn unwrap_parens(self) -> Result + where + Self: Sized; +} + +impl UnwrapParens for &ast::Expr { + fn unwrap_parens(self) -> Result { + match self { + ast::Expr::Column { .. } => Ok(self), + ast::Expr::Parenthesized(exprs) => match exprs.len() { + 1 => unwrap_parens(exprs.first().unwrap()), + _ => crate::bail_parse_error!("expected single expression in parentheses"), + }, + _ => Ok(self), + } + } +} + +impl UnwrapParens for ast::Expr { + fn unwrap_parens(self) -> Result { + match self { + ast::Expr::Column { .. } => Ok(self), + ast::Expr::Parenthesized(mut exprs) => match exprs.len() { + 1 => unwrap_parens(exprs.pop().unwrap()), + _ => crate::bail_parse_error!("expected single expression in parentheses"), + }, + _ => Ok(self), + } + } +} + /// Get the position of a column in an index /// For example, if there is an index on table T(x,y) then y's position in the index is 1. fn get_column_position_in_index( @@ -804,20 +845,20 @@ fn get_column_position_in_index( table_index: usize, table_reference: &TableReference, index: &Arc, -) -> Option { - let ast::Expr::Column { table, column, .. } = expr else { - return None; +) -> Result> { + let ast::Expr::Column { table, column, .. } = unwrap_parens(expr)? else { + return Ok(None); }; if *table != table_index { - return None; + return Ok(None); } let Some(column) = table_reference.table.get_column_at(*column) else { - return None; + return Ok(None); }; - index + Ok(index .columns .iter() - .position(|col| Some(&col.name) == column.name.as_ref()) + .position(|col| Some(&col.name) == column.name.as_ref())) } /// Find all [IndexConstraint]s for a given WHERE clause @@ -839,7 +880,7 @@ fn find_index_constraints( continue; } // Skip terms that are not binary comparisons - let ast::Expr::Binary(lhs, operator, rhs) = &term.expr else { + let ast::Expr::Binary(lhs, operator, rhs) = unwrap_parens(&term.expr)? else { continue; }; // Only consider index scans for binary ops that are comparisons @@ -868,7 +909,7 @@ fn find_index_constraints( // Check if lhs is a column that is in the i'th position of the index if Some(position_in_index) - == get_column_position_in_index(lhs, table_index, table_reference, index) + == get_column_position_in_index(lhs, table_index, table_reference, index)? { out_constraints.push(IndexConstraint { operator: *operator, @@ -879,7 +920,7 @@ fn find_index_constraints( } // Check if rhs is a column that is in the i'th position of the index if Some(position_in_index) - == get_column_position_in_index(rhs, table_index, table_reference, index) + == get_column_position_in_index(rhs, table_index, table_reference, index)? { out_constraints.push(IndexConstraint { operator: opposite_cmp_op(*operator), // swap the operator since e.g. if condition is 5 >= x, we want to use x <= 5 @@ -931,7 +972,8 @@ pub fn build_seek_def_from_index_constraints( // Extract the other expression from the binary WhereTerm (i.e. the one being compared to the index column) let (idx, side) = constraint.position_in_where_clause; let where_term = &mut where_clause[idx]; - let ast::Expr::Binary(lhs, _, rhs) = where_term.expr.take_ownership() else { + let ast::Expr::Binary(lhs, _, rhs) = unwrap_parens(where_term.expr.take_ownership())? + else { crate::bail_parse_error!("expected binary expression"); }; let cmp_expr = if side == BinaryExprSide::Lhs {