From fbfd2b2c384df2b38927f7fa1984cbf4ea015234 Mon Sep 17 00:00:00 2001 From: Jussi Saurio Date: Fri, 23 May 2025 16:09:27 +0300 Subject: [PATCH] refactor: use walk_expr_mut() in rewrite_expr() --- core/translate/optimizer/mod.rs | 181 ++++++++++---------------------- 1 file changed, 58 insertions(+), 123 deletions(-) diff --git a/core/translate/optimizer/mod.rs b/core/translate/optimizer/mod.rs index 130d595ff..9faad7888 100644 --- a/core/translate/optimizer/mod.rs +++ b/core/translate/optimizer/mod.rs @@ -12,7 +12,7 @@ use order::{compute_order_target, plan_satisfies_order_target, EliminatesSort}; use crate::{ parameters::PARAM_PREFIX, schema::{Index, IndexColumn, Schema, Table}, - translate::plan::TerminationKey, + translate::{expr::walk_expr_mut, plan::TerminationKey}, types::SeekOp, Result, }; @@ -1242,133 +1242,68 @@ fn build_seek_def( }) } -pub fn rewrite_expr(expr: &mut ast::Expr, param_idx: &mut usize) -> Result<()> { - match expr { - ast::Expr::Id(id) => { - // Convert "true" and "false" to 1 and 0 - if id.0.eq_ignore_ascii_case("true") { - *expr = ast::Expr::Literal(ast::Literal::Numeric(1.to_string())); - return Ok(()); - } - if id.0.eq_ignore_ascii_case("false") { - *expr = ast::Expr::Literal(ast::Literal::Numeric(0.to_string())); - return Ok(()); - } - Ok(()) - } - ast::Expr::Variable(var) => { - if var.is_empty() { - // rewrite anonymous variables only, ensure that the `param_idx` starts at 1 and - // all the expressions are rewritten in the order they come in the statement - *expr = ast::Expr::Variable(format!("{}{param_idx}", PARAM_PREFIX)); - *param_idx += 1; - } - Ok(()) - } - ast::Expr::Between { - lhs, - not, - start, - end, - } => { - // Convert `y NOT BETWEEN x AND z` to `x > y OR y > z` - let (lower_op, upper_op) = if *not { - (ast::Operator::Greater, ast::Operator::Greater) - } else { - // Convert `y BETWEEN x AND z` to `x <= y AND y <= z` - (ast::Operator::LessEquals, ast::Operator::LessEquals) - }; - - rewrite_expr(start, param_idx)?; - rewrite_expr(lhs, param_idx)?; - rewrite_expr(end, param_idx)?; - - let start = start.take_ownership(); - let lhs = lhs.take_ownership(); - let end = end.take_ownership(); - - let lower_bound = ast::Expr::Binary(Box::new(start), lower_op, Box::new(lhs.clone())); - let upper_bound = ast::Expr::Binary(Box::new(lhs), upper_op, Box::new(end)); - - if *not { - *expr = ast::Expr::Binary( - Box::new(lower_bound), - ast::Operator::Or, - Box::new(upper_bound), - ); - } else { - *expr = ast::Expr::Binary( - Box::new(lower_bound), - ast::Operator::And, - Box::new(upper_bound), - ); - } - Ok(()) - } - ast::Expr::Parenthesized(ref mut exprs) => { - for subexpr in exprs.iter_mut() { - rewrite_expr(subexpr, param_idx)?; - } - let exprs = std::mem::take(exprs); - *expr = ast::Expr::Parenthesized(exprs); - Ok(()) - } - // Process other expressions recursively - ast::Expr::Binary(lhs, _, rhs) => { - rewrite_expr(lhs, param_idx)?; - rewrite_expr(rhs, param_idx)?; - Ok(()) - } - ast::Expr::Like { - lhs, rhs, escape, .. - } => { - rewrite_expr(lhs, param_idx)?; - rewrite_expr(rhs, param_idx)?; - if let Some(escape) = escape { - rewrite_expr(escape, param_idx)?; - } - Ok(()) - } - ast::Expr::Case { - base, - when_then_pairs, - else_expr, - } => { - if let Some(base) = base { - rewrite_expr(base, param_idx)?; - } - for (lhs, rhs) in when_then_pairs.iter_mut() { - rewrite_expr(lhs, param_idx)?; - rewrite_expr(rhs, param_idx)?; - } - if let Some(else_expr) = else_expr { - rewrite_expr(else_expr, param_idx)?; - } - Ok(()) - } - ast::Expr::InList { lhs, rhs, .. } => { - rewrite_expr(lhs, param_idx)?; - if let Some(rhs) = rhs { - for expr in rhs.iter_mut() { - rewrite_expr(expr, param_idx)?; +pub fn rewrite_expr(top_level_expr: &mut ast::Expr, param_idx: &mut usize) -> Result<()> { + walk_expr_mut(top_level_expr, &mut |expr: &mut ast::Expr| -> Result<()> { + match expr { + ast::Expr::Id(id) => { + // Convert "true" and "false" to 1 and 0 + if id.0.eq_ignore_ascii_case("true") { + *expr = ast::Expr::Literal(ast::Literal::Numeric(1.to_string())); + return Ok(()); + } + if id.0.eq_ignore_ascii_case("false") { + *expr = ast::Expr::Literal(ast::Literal::Numeric(0.to_string())); } } - Ok(()) - } - ast::Expr::FunctionCall { args, .. } => { - if let Some(args) = args { - for arg in args.iter_mut() { - rewrite_expr(arg, param_idx)?; + ast::Expr::Variable(var) => { + if var.is_empty() { + // rewrite anonymous variables only, ensure that the `param_idx` starts at 1 and + // all the expressions are rewritten in the order they come in the statement + *expr = ast::Expr::Variable(format!("{}{param_idx}", PARAM_PREFIX)); + *param_idx += 1; } } - Ok(()) + ast::Expr::Between { + lhs, + not, + start, + end, + } => { + // Convert `y NOT BETWEEN x AND z` to `x > y OR y > z` + let (lower_op, upper_op) = if *not { + (ast::Operator::Greater, ast::Operator::Greater) + } else { + // Convert `y BETWEEN x AND z` to `x <= y AND y <= z` + (ast::Operator::LessEquals, ast::Operator::LessEquals) + }; + + let start = start.take_ownership(); + let lhs = lhs.take_ownership(); + let end = end.take_ownership(); + + let lower_bound = + ast::Expr::Binary(Box::new(start), lower_op, Box::new(lhs.clone())); + let upper_bound = ast::Expr::Binary(Box::new(lhs), upper_op, Box::new(end)); + + if *not { + *expr = ast::Expr::Binary( + Box::new(lower_bound), + ast::Operator::Or, + Box::new(upper_bound), + ); + } else { + *expr = ast::Expr::Binary( + Box::new(lower_bound), + ast::Operator::And, + Box::new(upper_bound), + ); + } + } + _ => {} } - ast::Expr::Unary(_, arg) => { - rewrite_expr(arg, param_idx)?; - Ok(()) - } - _ => Ok(()), - } + + Ok(()) + }) } trait TakeOwnership {