Use consistent imports of ast::Expr in upsert

This commit is contained in:
PThorpe92
2025-08-29 21:13:03 -04:00
parent e175516319
commit 0fc603830b

View File

@@ -549,7 +549,7 @@ pub fn emit_upsert(
pub fn collect_set_clauses_for_upsert(
table: &Table,
set_items: &mut [ast::Set],
insertion: &Insertion, // for EXCLUDED.*
insertion: &Insertion,
) -> crate::Result<Vec<(usize, Box<ast::Expr>)>> {
let lookup: HashMap<String, usize> = table
.columns()
@@ -599,8 +599,6 @@ fn rewrite_target_cols_to_current_row(
current_start: usize,
conflict_rowid_reg: usize,
) {
use ast::Expr::*;
// Helper: map a column name to (is_rowid, register)
let col_reg = |name: &str| -> Option<usize> {
if name.eq_ignore_ascii_case("rowid") {
@@ -616,46 +614,46 @@ fn rewrite_target_cols_to_current_row(
match expr {
// tbl.col: only rewrite if it names this table
ast::Expr::Qualified(left, col) => {
Expr::Qualified(left, col) => {
let q = left.as_str();
if !q.eq_ignore_ascii_case("excluded") && q.eq_ignore_ascii_case(table.get_name()) {
if let ast::Name::Ident(c) = col {
if let Some(reg) = col_reg(c.as_str()) {
*expr = Register(reg);
*expr = Expr::Register(reg);
}
}
}
}
ast::Expr::Id(ast::Name::Ident(name)) => {
Expr::Id(ast::Name::Ident(name)) => {
if let Some(reg) = col_reg(name.as_str()) {
*expr = Register(reg);
*expr = Expr::Register(reg);
}
}
ast::Expr::RowId { .. } => {
*expr = Register(conflict_rowid_reg);
Expr::RowId { .. } => {
*expr = Expr::Register(conflict_rowid_reg);
}
// Keep walking for composite expressions
Collate(inner, _) => {
Expr::Collate(inner, _) => {
rewrite_target_cols_to_current_row(inner, table, current_start, conflict_rowid_reg)
}
Parenthesized(v) => {
Expr::Parenthesized(v) => {
for e in v {
rewrite_target_cols_to_current_row(e, table, current_start, conflict_rowid_reg)
}
}
Between {
Expr::Between {
lhs, start, end, ..
} => {
rewrite_target_cols_to_current_row(lhs, table, current_start, conflict_rowid_reg);
rewrite_target_cols_to_current_row(start, table, current_start, conflict_rowid_reg);
rewrite_target_cols_to_current_row(end, table, current_start, conflict_rowid_reg);
}
Binary(l, _, r) => {
Expr::Binary(l, _, r) => {
rewrite_target_cols_to_current_row(l, table, current_start, conflict_rowid_reg);
rewrite_target_cols_to_current_row(r, table, current_start, conflict_rowid_reg);
}
Case {
Expr::Case {
base,
when_then_pairs,
else_expr,
@@ -671,10 +669,10 @@ fn rewrite_target_cols_to_current_row(
rewrite_target_cols_to_current_row(e, table, current_start, conflict_rowid_reg)
}
}
Cast { expr: inner, .. } => {
Expr::Cast { expr: inner, .. } => {
rewrite_target_cols_to_current_row(inner, table, current_start, conflict_rowid_reg)
}
FunctionCall {
Expr::FunctionCall {
args,
order_by,
filter_over,
@@ -695,22 +693,22 @@ fn rewrite_target_cols_to_current_row(
rewrite_target_cols_to_current_row(f, table, current_start, conflict_rowid_reg)
}
}
InList { lhs, rhs, .. } => {
Expr::InList { lhs, rhs, .. } => {
rewrite_target_cols_to_current_row(lhs, table, current_start, conflict_rowid_reg);
for e in rhs {
rewrite_target_cols_to_current_row(e, table, current_start, conflict_rowid_reg)
}
}
InSelect { lhs, .. } => {
Expr::InSelect { lhs, .. } => {
rewrite_target_cols_to_current_row(lhs, table, current_start, conflict_rowid_reg)
}
InTable { lhs, .. } => {
Expr::InTable { lhs, .. } => {
rewrite_target_cols_to_current_row(lhs, table, current_start, conflict_rowid_reg)
}
IsNull(inner) => {
Expr::IsNull(inner) => {
rewrite_target_cols_to_current_row(inner, table, current_start, conflict_rowid_reg)
}
Like {
Expr::Like {
lhs, rhs, escape, ..
} => {
rewrite_target_cols_to_current_row(lhs, table, current_start, conflict_rowid_reg);
@@ -719,10 +717,10 @@ fn rewrite_target_cols_to_current_row(
rewrite_target_cols_to_current_row(e, table, current_start, conflict_rowid_reg)
}
}
NotNull(inner) => {
Expr::NotNull(inner) => {
rewrite_target_cols_to_current_row(inner, table, current_start, conflict_rowid_reg)
}
Unary(_, inner) => {
Expr::Unary(_, inner) => {
rewrite_target_cols_to_current_row(inner, table, current_start, conflict_rowid_reg)
}
_ => {}