refactor: use walk_expr_mut() in rewrite_expr()

This commit is contained in:
Jussi Saurio
2025-05-23 16:09:27 +03:00
parent 362347c474
commit fbfd2b2c38

View File

@@ -12,7 +12,7 @@ use order::{compute_order_target, plan_satisfies_order_target, EliminatesSort};
use crate::{
parameters::PARAM_PREFIX,
schema::{Index, IndexColumn, Schema, Table},
translate::plan::TerminationKey,
translate::{expr::walk_expr_mut, plan::TerminationKey},
types::SeekOp,
Result,
};
@@ -1242,133 +1242,68 @@ fn build_seek_def(
})
}
pub fn rewrite_expr(expr: &mut ast::Expr, param_idx: &mut usize) -> Result<()> {
match expr {
ast::Expr::Id(id) => {
// Convert "true" and "false" to 1 and 0
if id.0.eq_ignore_ascii_case("true") {
*expr = ast::Expr::Literal(ast::Literal::Numeric(1.to_string()));
return Ok(());
}
if id.0.eq_ignore_ascii_case("false") {
*expr = ast::Expr::Literal(ast::Literal::Numeric(0.to_string()));
return Ok(());
}
Ok(())
}
ast::Expr::Variable(var) => {
if var.is_empty() {
// rewrite anonymous variables only, ensure that the `param_idx` starts at 1 and
// all the expressions are rewritten in the order they come in the statement
*expr = ast::Expr::Variable(format!("{}{param_idx}", PARAM_PREFIX));
*param_idx += 1;
}
Ok(())
}
ast::Expr::Between {
lhs,
not,
start,
end,
} => {
// Convert `y NOT BETWEEN x AND z` to `x > y OR y > z`
let (lower_op, upper_op) = if *not {
(ast::Operator::Greater, ast::Operator::Greater)
} else {
// Convert `y BETWEEN x AND z` to `x <= y AND y <= z`
(ast::Operator::LessEquals, ast::Operator::LessEquals)
};
rewrite_expr(start, param_idx)?;
rewrite_expr(lhs, param_idx)?;
rewrite_expr(end, param_idx)?;
let start = start.take_ownership();
let lhs = lhs.take_ownership();
let end = end.take_ownership();
let lower_bound = ast::Expr::Binary(Box::new(start), lower_op, Box::new(lhs.clone()));
let upper_bound = ast::Expr::Binary(Box::new(lhs), upper_op, Box::new(end));
if *not {
*expr = ast::Expr::Binary(
Box::new(lower_bound),
ast::Operator::Or,
Box::new(upper_bound),
);
} else {
*expr = ast::Expr::Binary(
Box::new(lower_bound),
ast::Operator::And,
Box::new(upper_bound),
);
}
Ok(())
}
ast::Expr::Parenthesized(ref mut exprs) => {
for subexpr in exprs.iter_mut() {
rewrite_expr(subexpr, param_idx)?;
}
let exprs = std::mem::take(exprs);
*expr = ast::Expr::Parenthesized(exprs);
Ok(())
}
// Process other expressions recursively
ast::Expr::Binary(lhs, _, rhs) => {
rewrite_expr(lhs, param_idx)?;
rewrite_expr(rhs, param_idx)?;
Ok(())
}
ast::Expr::Like {
lhs, rhs, escape, ..
} => {
rewrite_expr(lhs, param_idx)?;
rewrite_expr(rhs, param_idx)?;
if let Some(escape) = escape {
rewrite_expr(escape, param_idx)?;
}
Ok(())
}
ast::Expr::Case {
base,
when_then_pairs,
else_expr,
} => {
if let Some(base) = base {
rewrite_expr(base, param_idx)?;
}
for (lhs, rhs) in when_then_pairs.iter_mut() {
rewrite_expr(lhs, param_idx)?;
rewrite_expr(rhs, param_idx)?;
}
if let Some(else_expr) = else_expr {
rewrite_expr(else_expr, param_idx)?;
}
Ok(())
}
ast::Expr::InList { lhs, rhs, .. } => {
rewrite_expr(lhs, param_idx)?;
if let Some(rhs) = rhs {
for expr in rhs.iter_mut() {
rewrite_expr(expr, param_idx)?;
pub fn rewrite_expr(top_level_expr: &mut ast::Expr, param_idx: &mut usize) -> Result<()> {
walk_expr_mut(top_level_expr, &mut |expr: &mut ast::Expr| -> Result<()> {
match expr {
ast::Expr::Id(id) => {
// Convert "true" and "false" to 1 and 0
if id.0.eq_ignore_ascii_case("true") {
*expr = ast::Expr::Literal(ast::Literal::Numeric(1.to_string()));
return Ok(());
}
if id.0.eq_ignore_ascii_case("false") {
*expr = ast::Expr::Literal(ast::Literal::Numeric(0.to_string()));
}
}
Ok(())
}
ast::Expr::FunctionCall { args, .. } => {
if let Some(args) = args {
for arg in args.iter_mut() {
rewrite_expr(arg, param_idx)?;
ast::Expr::Variable(var) => {
if var.is_empty() {
// rewrite anonymous variables only, ensure that the `param_idx` starts at 1 and
// all the expressions are rewritten in the order they come in the statement
*expr = ast::Expr::Variable(format!("{}{param_idx}", PARAM_PREFIX));
*param_idx += 1;
}
}
Ok(())
ast::Expr::Between {
lhs,
not,
start,
end,
} => {
// Convert `y NOT BETWEEN x AND z` to `x > y OR y > z`
let (lower_op, upper_op) = if *not {
(ast::Operator::Greater, ast::Operator::Greater)
} else {
// Convert `y BETWEEN x AND z` to `x <= y AND y <= z`
(ast::Operator::LessEquals, ast::Operator::LessEquals)
};
let start = start.take_ownership();
let lhs = lhs.take_ownership();
let end = end.take_ownership();
let lower_bound =
ast::Expr::Binary(Box::new(start), lower_op, Box::new(lhs.clone()));
let upper_bound = ast::Expr::Binary(Box::new(lhs), upper_op, Box::new(end));
if *not {
*expr = ast::Expr::Binary(
Box::new(lower_bound),
ast::Operator::Or,
Box::new(upper_bound),
);
} else {
*expr = ast::Expr::Binary(
Box::new(lower_bound),
ast::Operator::And,
Box::new(upper_bound),
);
}
}
_ => {}
}
ast::Expr::Unary(_, arg) => {
rewrite_expr(arg, param_idx)?;
Ok(())
}
_ => Ok(()),
}
Ok(())
})
}
trait TakeOwnership {