Combine rewriting expressions in UPSERT into a single walk of the ast

This commit is contained in:
PThorpe92
2025-08-29 22:12:46 -04:00
parent 0fc603830b
commit 8531560899
2 changed files with 280 additions and 198 deletions

View File

@@ -14,8 +14,7 @@ use crate::translate::expr::{
};
use crate::translate::planner::ROWID;
use crate::translate::upsert::{
collect_set_clauses_for_upsert, emit_upsert, rewrite_excluded_in_expr, upsert_matches_index,
upsert_matches_pk,
collect_set_clauses_for_upsert, emit_upsert, upsert_matches_index, upsert_matches_pk,
};
use crate::util::normalize_ident;
use crate::vdbe::builder::ProgramBuilderOpts;
@@ -414,15 +413,13 @@ pub fn translate_insert(
ref mut sets,
ref mut where_clause,
} => {
let mut rewritten_sets =
collect_set_clauses_for_upsert(&table, sets, &insertion)?;
if let Some(expr) = where_clause.as_mut() {
rewrite_excluded_in_expr(expr, &insertion)?;
}
let mut rewritten_sets = collect_set_clauses_for_upsert(&table, sets)?;
emit_upsert(
&mut program,
schema,
&table,
&insertion,
cursor_id,
insertion.key_register(),
&mut rewritten_sets,
@@ -536,11 +533,7 @@ pub fn translate_insert(
ref mut where_clause,
} => {
let mut rewritten_sets =
collect_set_clauses_for_upsert(&table, sets, &insertion)?;
if let Some(expr) = where_clause.as_mut() {
rewrite_excluded_in_expr(expr, &insertion)?;
}
collect_set_clauses_for_upsert(&table, sets)?;
let conflict_rowid_reg = program.alloc_register();
program.emit_insn(Insn::IdxRowId {
cursor_id: idx_cursor_id,
@@ -550,6 +543,7 @@ pub fn translate_insert(
&mut program,
schema,
&table,
&insertion,
cursor_id,
conflict_rowid_reg,
&mut rewritten_sets,

View File

@@ -1,6 +1,6 @@
use std::{collections::HashMap, sync::Arc};
use turso_parser::ast::{self, Expr, Upsert};
use turso_parser::ast::{self, Upsert};
use crate::{
bail_parse_error,
@@ -87,99 +87,6 @@ fn effective_collation_for_index_col(idx_col: &IndexColumn, table: &Table) -> St
.unwrap_or_else(|| "binary".to_string())
}
/// Column names in the expressions of a DO UPDATE refer to the original unchanged value of the column, before the attempted INSERT.
/// To use the value that would have been inserted had the constraint not failed, add the special "excluded." table qualifier to the column name.
/// https://sqlite.org/lang_upsert.html
///
/// Rewrite EXCLUDED.x to Expr::Register(<reg of x from insertion>)
pub fn rewrite_excluded_in_expr(expr: &mut Expr, insertion: &Insertion) -> crate::Result<()> {
match expr {
Expr::Qualified(ns, col) if ns.as_str().eq_ignore_ascii_case("excluded") => {
let cname = match col {
ast::Name::Ident(s) => s.as_str(),
_ => return Ok(()),
};
let Some(src) = insertion.get_col_mapping_by_name(cname) else {
bail_parse_error!("no such column in EXCLUDED: {}", cname);
};
*expr = Expr::Register(src.register)
}
Expr::Collate(inner, _) => rewrite_excluded_in_expr(inner, insertion)?,
Expr::Parenthesized(v) => {
for e in v {
rewrite_excluded_in_expr(e, insertion)?
}
}
Expr::Between {
lhs, start, end, ..
} => {
rewrite_excluded_in_expr(lhs, insertion)?;
rewrite_excluded_in_expr(start, insertion)?;
rewrite_excluded_in_expr(end, insertion)?;
}
Expr::Binary(l, _, r) => {
rewrite_excluded_in_expr(l, insertion)?;
rewrite_excluded_in_expr(r, insertion)?;
}
Expr::Case {
base,
when_then_pairs,
else_expr,
} => {
if let Some(b) = base {
rewrite_excluded_in_expr(b, insertion)?
};
for (w, t) in when_then_pairs.iter_mut() {
rewrite_excluded_in_expr(w, insertion)?;
rewrite_excluded_in_expr(t, insertion)?;
}
if let Some(e) = else_expr {
rewrite_excluded_in_expr(e, insertion)?
}
}
Expr::Cast { expr: inner, .. } => rewrite_excluded_in_expr(inner, insertion)?,
Expr::FunctionCall {
args,
order_by,
filter_over,
..
} => {
for a in args {
rewrite_excluded_in_expr(a, insertion)?
}
for sc in order_by {
rewrite_excluded_in_expr(&mut sc.expr, insertion)?
}
if let Some(ex) = &mut filter_over.filter_clause {
rewrite_excluded_in_expr(ex, insertion)?
}
}
Expr::InList { lhs, rhs, .. } => {
rewrite_excluded_in_expr(lhs, insertion)?;
for e in rhs {
rewrite_excluded_in_expr(e, insertion)?
}
}
Expr::InSelect { lhs, .. } => rewrite_excluded_in_expr(lhs, insertion)?,
Expr::InTable { lhs, .. } => rewrite_excluded_in_expr(lhs, insertion)?,
Expr::IsNull(inner) => rewrite_excluded_in_expr(inner, insertion)?,
Expr::Like {
lhs, rhs, escape, ..
} => {
rewrite_excluded_in_expr(lhs, insertion)?;
rewrite_excluded_in_expr(rhs, insertion)?;
if let Some(e) = escape {
rewrite_excluded_in_expr(e, insertion)?
}
}
Expr::NotNull(inner) => rewrite_excluded_in_expr(inner, insertion)?,
Expr::Unary(_, inner) => rewrite_excluded_in_expr(inner, insertion)?,
_ => {}
}
Ok(())
}
/// Match ON CONFLICT target to the PRIMARY KEY, if any.
/// If no target is specified, it is an automatic match for PRIMARY KEY
pub fn upsert_matches_pk(upsert: &Upsert, table: &Table) -> bool {
@@ -306,6 +213,7 @@ pub fn emit_upsert(
program: &mut ProgramBuilder,
schema: &Schema,
table: &Table,
insertion: &Insertion,
tbl_cursor_id: usize,
conflict_rowid_reg: usize,
set_pairs: &mut [(usize, Box<ast::Expr>)],
@@ -362,14 +270,16 @@ pub fn emit_upsert(
extra_amount: num_cols - 1,
});
// rewrite target-table refs -> registers from current snapshot
let rewrite_target = |e: &mut ast::Expr| {
rewrite_target_cols_to_current_row(e, table, current_start, conflict_rowid_reg);
};
// WHERE predicate on the target row. If false or NULL, skip the UPDATE.
if let Some(pred) = where_clause.as_mut() {
rewrite_target(pred);
rewrite_upsert_expr_in_place(
pred,
table,
table.get_name(),
current_start,
conflict_rowid_reg,
insertion,
)?;
let pr = program.alloc_register();
translate_expr(program, None, pred, pr, resolver)?;
program.emit_insn(Insn::IfNot {
@@ -381,7 +291,14 @@ pub fn emit_upsert(
// Evaluate each SET expression into the NEW row img
for (col_idx, expr) in set_pairs.iter_mut() {
rewrite_target(expr);
rewrite_upsert_expr_in_place(
expr,
table,
table.get_name(),
current_start,
conflict_rowid_reg,
insertion,
)?;
translate_expr_no_constant_opt(
program,
None,
@@ -549,7 +466,6 @@ pub fn emit_upsert(
pub fn collect_set_clauses_for_upsert(
table: &Table,
set_items: &mut [ast::Set],
insertion: &Insertion,
) -> crate::Result<Vec<(usize, Box<ast::Expr>)>> {
let lookup: HashMap<String, usize> = table
.columns()
@@ -572,8 +488,7 @@ pub fn collect_set_clauses_for_upsert(
values.len()
);
}
for (cn, mut e) in set.col_names.iter().zip(values.into_iter()) {
rewrite_excluded_in_expr(&mut e, insertion)?;
for (cn, e) in set.col_names.iter().zip(values.into_iter()) {
let Some(idx) = lookup.get(&normalize_ident(cn.as_str())) else {
bail_parse_error!("no such column: {}", cn);
};
@@ -587,142 +502,315 @@ pub fn collect_set_clauses_for_upsert(
Ok(out)
}
/// Rewrite references to the target table's columns in an expression tree so that
/// they read from registers containing the CURRENT (pre-update) row snapshot.
/// Rewrite an UPSERT expression so that:
/// EXCLUDED.x -> Register(insertion.x)
/// t.x / x -> Register(CURRENT.x) when t == target table or unqualified
/// rowid -> Register(conflict_rowid_reg)
///
/// This matches SQLite's rule that unqualified column refs in the DO UPDATE arm
/// refer to the original row, not the would-be inserted values, which must use
/// `EXCLUDED.x` and are handled earlier by `rewrite_excluded_in_expr`.
fn rewrite_target_cols_to_current_row(
expr: &mut ast::Expr,
/// Only rewrites names in the current expression scope, does not enter subqueries.
fn rewrite_upsert_expr_in_place(
e: &mut ast::Expr,
table: &Table,
table_name: &str,
current_start: usize,
conflict_rowid_reg: usize,
) {
// Helper: map a column name to (is_rowid, register)
insertion: &Insertion,
) -> crate::Result<()> {
use ast::Expr::*;
// helper: return the CURRENT-row register for a column (including rowid alias)
let col_reg = |name: &str| -> Option<usize> {
if name.eq_ignore_ascii_case("rowid") {
return Some(conflict_rowid_reg);
}
let (idx, col) = table.get_column_by_name(&normalize_ident(name))?;
if col.is_rowid_alias {
// You loaded alias value into current_start + idx
return Some(current_start + idx);
}
let (idx, _c) = table.get_column_by_name(&normalize_ident(name))?;
Some(current_start + idx)
};
match expr {
// tbl.col: only rewrite if it names this table
Expr::Qualified(left, col) => {
let q = left.as_str();
if !q.eq_ignore_ascii_case("excluded") && q.eq_ignore_ascii_case(table.get_name()) {
if let ast::Name::Ident(c) = col {
if let Some(reg) = col_reg(c.as_str()) {
*expr = Expr::Register(reg);
}
}
}
}
Expr::Id(ast::Name::Ident(name)) => {
if let Some(reg) = col_reg(name.as_str()) {
*expr = Expr::Register(reg);
}
}
Expr::RowId { .. } => {
*expr = Expr::Register(conflict_rowid_reg);
match e {
// EXCLUDED.x -> insertion register
Qualified(ns, ast::Name::Ident(c)) if ns.as_str().eq_ignore_ascii_case("excluded") => {
let Some(reg) = insertion.get_col_mapping_by_name(c.as_str()) else {
bail_parse_error!("no such column in EXCLUDED: {}", c);
};
*e = Register(reg.register);
}
// Keep walking for composite expressions
Expr::Collate(inner, _) => {
rewrite_target_cols_to_current_row(inner, table, current_start, conflict_rowid_reg)
}
Expr::Parenthesized(v) => {
for e in v {
rewrite_target_cols_to_current_row(e, table, current_start, conflict_rowid_reg)
// t.x -> CURRENT, only if t matches the target table name (never "excluded")
Qualified(ns, ast::Name::Ident(c)) if ns.as_str().eq_ignore_ascii_case(table_name) => {
if let Some(reg) = col_reg(c.as_str()) {
*e = Register(reg);
}
}
Expr::Between {
// Unqualified column id -> CURRENT
Id(ast::Name::Ident(name)) => {
if let Some(reg) = col_reg(name.as_str()) {
*e = Register(reg);
}
}
RowId { .. } => {
*e = Register(conflict_rowid_reg);
}
Collate(inner, _) => rewrite_upsert_expr_in_place(
inner,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?,
Parenthesized(v) => {
for ex in v {
rewrite_upsert_expr_in_place(
ex,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
}
}
Between {
lhs, start, end, ..
} => {
rewrite_target_cols_to_current_row(lhs, table, current_start, conflict_rowid_reg);
rewrite_target_cols_to_current_row(start, table, current_start, conflict_rowid_reg);
rewrite_target_cols_to_current_row(end, table, current_start, conflict_rowid_reg);
rewrite_upsert_expr_in_place(
lhs,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
rewrite_upsert_expr_in_place(
start,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
rewrite_upsert_expr_in_place(
end,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
}
Expr::Binary(l, _, r) => {
rewrite_target_cols_to_current_row(l, table, current_start, conflict_rowid_reg);
rewrite_target_cols_to_current_row(r, table, current_start, conflict_rowid_reg);
Binary(l, _, r) => {
rewrite_upsert_expr_in_place(
l,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
rewrite_upsert_expr_in_place(
r,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
}
Expr::Case {
Case {
base,
when_then_pairs,
else_expr,
} => {
if let Some(b) = base {
rewrite_target_cols_to_current_row(b, table, current_start, conflict_rowid_reg)
rewrite_upsert_expr_in_place(
b,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
}
for (w, t) in when_then_pairs.iter_mut() {
rewrite_target_cols_to_current_row(w, table, current_start, conflict_rowid_reg);
rewrite_target_cols_to_current_row(t, table, current_start, conflict_rowid_reg);
rewrite_upsert_expr_in_place(
w,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
rewrite_upsert_expr_in_place(
t,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
}
if let Some(e) = else_expr {
rewrite_target_cols_to_current_row(e, table, current_start, conflict_rowid_reg)
if let Some(e2) = else_expr {
rewrite_upsert_expr_in_place(
e2,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
}
}
Expr::Cast { expr: inner, .. } => {
rewrite_target_cols_to_current_row(inner, table, current_start, conflict_rowid_reg)
Cast { expr: inner, .. } => {
rewrite_upsert_expr_in_place(
inner,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
}
Expr::FunctionCall {
FunctionCall {
args,
order_by,
filter_over,
..
} => {
for a in args {
rewrite_target_cols_to_current_row(a, table, current_start, conflict_rowid_reg)
}
for sc in order_by {
rewrite_target_cols_to_current_row(
&mut sc.expr,
rewrite_upsert_expr_in_place(
a,
table,
table_name,
current_start,
conflict_rowid_reg,
)
insertion,
)?;
}
for sc in order_by {
rewrite_upsert_expr_in_place(
&mut sc.expr,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
}
if let Some(ref mut f) = &mut filter_over.filter_clause {
rewrite_target_cols_to_current_row(f, table, current_start, conflict_rowid_reg)
rewrite_upsert_expr_in_place(
f,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
}
}
Expr::InList { lhs, rhs, .. } => {
rewrite_target_cols_to_current_row(lhs, table, current_start, conflict_rowid_reg);
for e in rhs {
rewrite_target_cols_to_current_row(e, table, current_start, conflict_rowid_reg)
InList { lhs, rhs, .. } => {
rewrite_upsert_expr_in_place(
lhs,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
for ex in rhs {
rewrite_upsert_expr_in_place(
ex,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
}
}
Expr::InSelect { lhs, .. } => {
rewrite_target_cols_to_current_row(lhs, table, current_start, conflict_rowid_reg)
InSelect { lhs, .. } => {
// rewrite only `lhs`, not the subselect
rewrite_upsert_expr_in_place(
lhs,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
}
Expr::InTable { lhs, .. } => {
rewrite_target_cols_to_current_row(lhs, table, current_start, conflict_rowid_reg)
InTable { lhs, .. } => {
rewrite_upsert_expr_in_place(
lhs,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
}
Expr::IsNull(inner) => {
rewrite_target_cols_to_current_row(inner, table, current_start, conflict_rowid_reg)
IsNull(inner) => {
rewrite_upsert_expr_in_place(
inner,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
}
Expr::Like {
Like {
lhs, rhs, escape, ..
} => {
rewrite_target_cols_to_current_row(lhs, table, current_start, conflict_rowid_reg);
rewrite_target_cols_to_current_row(rhs, table, current_start, conflict_rowid_reg);
if let Some(e) = escape {
rewrite_target_cols_to_current_row(e, table, current_start, conflict_rowid_reg)
rewrite_upsert_expr_in_place(
lhs,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
rewrite_upsert_expr_in_place(
rhs,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
if let Some(e3) = escape {
rewrite_upsert_expr_in_place(
e3,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
}
}
Expr::NotNull(inner) => {
rewrite_target_cols_to_current_row(inner, table, current_start, conflict_rowid_reg)
NotNull(inner) => {
rewrite_upsert_expr_in_place(
inner,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
}
Expr::Unary(_, inner) => {
rewrite_target_cols_to_current_row(inner, table, current_start, conflict_rowid_reg)
Unary(_, inner) => {
rewrite_upsert_expr_in_place(
inner,
table,
table_name,
current_start,
conflict_rowid_reg,
insertion,
)?;
}
_ => {}
}
Ok(())
}