Merge 'Refactor UPSERT to use wal_expr_mut to walk AST.' from Preston Thorpe

Working on https://github.com/tursodatabase/turso/issues/2964 I came
upon `walk_expr_mut`, I don't think it existed last time I really spent
much time in the translator. So quickly went back and cleaned this up.

Reviewed-by: Jussi Saurio <jussi.saurio@gmail.com>

Closes #3044
This commit is contained in:
Preston Thorpe
2025-09-12 06:45:13 -04:00
committed by GitHub

View File

@@ -11,7 +11,7 @@ use crate::{
emit_cdc_full_record, emit_cdc_insns, emit_cdc_patch_record, OperationMode, Resolver,
},
expr::{
emit_returning_results, translate_expr, translate_expr_no_constant_opt,
emit_returning_results, translate_expr, translate_expr_no_constant_opt, walk_expr_mut,
NoConstantOptReason, ReturningValueRegisters,
},
insert::{Insertion, ROWID_COLUMN},
@@ -525,7 +525,7 @@ fn rewrite_upsert_expr_in_place(
conflict_rowid_reg: usize,
insertion: &Insertion,
) -> crate::Result<()> {
use ast::Expr::*;
use ast::Expr;
// helper: return the CURRENT-row register for a column (including rowid alias)
let col_reg = |name: &str| -> Option<usize> {
@@ -535,291 +535,34 @@ fn rewrite_upsert_expr_in_place(
let (idx, _c) = table.get_column_by_name(&normalize_ident(name))?;
Some(current_start + idx)
};
match e {
// EXCLUDED.x -> insertion register
Qualified(ns, ast::Name::Ident(c)) if ns.as_str().eq_ignore_ascii_case("excluded") => {
let Some(reg) = insertion.get_col_mapping_by_name(c.as_str()) else {
bail_parse_error!("no such column in EXCLUDED: {}", c);
};
*e = Register(reg.register);
}
// t.x -> CURRENT, only if t matches the target table name (never "excluded")
Qualified(ns, ast::Name::Ident(c)) if ns.as_str().eq_ignore_ascii_case(table_name) => {
if let Some(reg) = col_reg(c.as_str()) {
*e = Register(reg);
walk_expr_mut(e, &mut |expr: &mut ast::Expr| -> crate::Result<()> {
match expr {
// EXCLUDED.x -> insertion register
Expr::Qualified(ns, ast::Name::Ident(c)) => {
if ns.as_str().eq_ignore_ascii_case("excluded") {
let Some(reg) = insertion.get_col_mapping_by_name(c.as_str()) else {
bail_parse_error!("no such column in EXCLUDED: {}", c);
};
*expr = Expr::Register(reg.register);
} else if ns.as_str().eq_ignore_ascii_case(table_name)
// t.x -> CURRENT, only if t matches the target table name (never "excluded")
{
if let Some(reg) = col_reg(c.as_str()) {
*expr = Expr::Register(reg);
}
}
}
}
// Unqualified column id -> CURRENT
Id(ast::Name::Ident(name)) => {
if let Some(reg) = col_reg(name.as_str()) {
*e = Register(reg);
// Unqualified column id -> CURRENT
Expr::Id(ast::Name::Ident(name)) => {
if let Some(reg) = col_reg(name.as_str()) {
*expr = Expr::Register(reg);
}
}
}
RowId { .. } => {
*e = Register(conflict_rowid_reg);
}
Collate(inner, _) => rewrite_upsert_expr_in_place(
inner,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?,
Parenthesized(v) => {
for ex in v {
rewrite_upsert_expr_in_place(
ex,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
Expr::RowId { .. } => {
*expr = Expr::Register(conflict_rowid_reg);
}
_ => {}
}
Between {
lhs, start, end, ..
} => {
rewrite_upsert_expr_in_place(
lhs,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
rewrite_upsert_expr_in_place(
start,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
rewrite_upsert_expr_in_place(
end,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
}
Binary(l, _, r) => {
rewrite_upsert_expr_in_place(
l,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
rewrite_upsert_expr_in_place(
r,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
}
Case {
base,
when_then_pairs,
else_expr,
} => {
if let Some(b) = base {
rewrite_upsert_expr_in_place(
b,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
}
for (w, t) in when_then_pairs.iter_mut() {
rewrite_upsert_expr_in_place(
w,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
rewrite_upsert_expr_in_place(
t,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
}
if let Some(e2) = else_expr {
rewrite_upsert_expr_in_place(
e2,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
}
}
Cast { expr: inner, .. } => {
rewrite_upsert_expr_in_place(
inner,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
}
FunctionCall {
args,
order_by,
filter_over,
..
} => {
for a in args {
rewrite_upsert_expr_in_place(
a,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
}
for sc in order_by {
rewrite_upsert_expr_in_place(
&mut sc.expr,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
}
if let Some(ref mut f) = &mut filter_over.filter_clause {
rewrite_upsert_expr_in_place(
f,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
}
}
InList { lhs, rhs, .. } => {
rewrite_upsert_expr_in_place(
lhs,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
for ex in rhs {
rewrite_upsert_expr_in_place(
ex,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
}
}
InSelect { lhs, .. } => {
// rewrite only `lhs`, not the subselect
rewrite_upsert_expr_in_place(
lhs,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
}
InTable { lhs, .. } => {
rewrite_upsert_expr_in_place(
lhs,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
}
IsNull(inner) => {
rewrite_upsert_expr_in_place(
inner,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
}
Like {
lhs, rhs, escape, ..
} => {
rewrite_upsert_expr_in_place(
lhs,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
rewrite_upsert_expr_in_place(
rhs,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
if let Some(e3) = escape {
rewrite_upsert_expr_in_place(
e3,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
}
}
NotNull(inner) => {
rewrite_upsert_expr_in_place(
inner,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
}
Unary(_, inner) => {
rewrite_upsert_expr_in_place(
inner,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
}
_ => {}
}
Ok(())
Ok(())
})
}