From 85315608995767c8a16b21d95c17e6eeb0b3e5e1 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Fri, 29 Aug 2025 22:12:46 -0400 Subject: [PATCH] Combine rewriting expressions in UPSERT into a single walk of the ast --- core/translate/insert.rs | 18 +- core/translate/upsert.rs | 460 +++++++++++++++++++++++---------------- 2 files changed, 280 insertions(+), 198 deletions(-) diff --git a/core/translate/insert.rs b/core/translate/insert.rs index 5416195da..dc4a0a0eb 100644 --- a/core/translate/insert.rs +++ b/core/translate/insert.rs @@ -14,8 +14,7 @@ use crate::translate::expr::{ }; use crate::translate::planner::ROWID; use crate::translate::upsert::{ - collect_set_clauses_for_upsert, emit_upsert, rewrite_excluded_in_expr, upsert_matches_index, - upsert_matches_pk, + collect_set_clauses_for_upsert, emit_upsert, upsert_matches_index, upsert_matches_pk, }; use crate::util::normalize_ident; use crate::vdbe::builder::ProgramBuilderOpts; @@ -414,15 +413,13 @@ pub fn translate_insert( ref mut sets, ref mut where_clause, } => { - let mut rewritten_sets = - collect_set_clauses_for_upsert(&table, sets, &insertion)?; - if let Some(expr) = where_clause.as_mut() { - rewrite_excluded_in_expr(expr, &insertion)?; - } + let mut rewritten_sets = collect_set_clauses_for_upsert(&table, sets)?; + emit_upsert( &mut program, schema, &table, + &insertion, cursor_id, insertion.key_register(), &mut rewritten_sets, @@ -536,11 +533,7 @@ pub fn translate_insert( ref mut where_clause, } => { let mut rewritten_sets = - collect_set_clauses_for_upsert(&table, sets, &insertion)?; - if let Some(expr) = where_clause.as_mut() { - rewrite_excluded_in_expr(expr, &insertion)?; - } - + collect_set_clauses_for_upsert(&table, sets)?; let conflict_rowid_reg = program.alloc_register(); program.emit_insn(Insn::IdxRowId { cursor_id: idx_cursor_id, @@ -550,6 +543,7 @@ pub fn translate_insert( &mut program, schema, &table, + &insertion, cursor_id, conflict_rowid_reg, &mut rewritten_sets, diff --git a/core/translate/upsert.rs b/core/translate/upsert.rs index 875ae9310..7e0141ed2 100644 --- a/core/translate/upsert.rs +++ b/core/translate/upsert.rs @@ -1,6 +1,6 @@ use std::{collections::HashMap, sync::Arc}; -use turso_parser::ast::{self, Expr, Upsert}; +use turso_parser::ast::{self, Upsert}; use crate::{ bail_parse_error, @@ -87,99 +87,6 @@ fn effective_collation_for_index_col(idx_col: &IndexColumn, table: &Table) -> St .unwrap_or_else(|| "binary".to_string()) } -/// Column names in the expressions of a DO UPDATE refer to the original unchanged value of the column, before the attempted INSERT. -/// To use the value that would have been inserted had the constraint not failed, add the special "excluded." table qualifier to the column name. -/// https://sqlite.org/lang_upsert.html -/// -/// Rewrite EXCLUDED.x to Expr::Register() -pub fn rewrite_excluded_in_expr(expr: &mut Expr, insertion: &Insertion) -> crate::Result<()> { - match expr { - Expr::Qualified(ns, col) if ns.as_str().eq_ignore_ascii_case("excluded") => { - let cname = match col { - ast::Name::Ident(s) => s.as_str(), - _ => return Ok(()), - }; - let Some(src) = insertion.get_col_mapping_by_name(cname) else { - bail_parse_error!("no such column in EXCLUDED: {}", cname); - }; - *expr = Expr::Register(src.register) - } - - Expr::Collate(inner, _) => rewrite_excluded_in_expr(inner, insertion)?, - Expr::Parenthesized(v) => { - for e in v { - rewrite_excluded_in_expr(e, insertion)? - } - } - Expr::Between { - lhs, start, end, .. - } => { - rewrite_excluded_in_expr(lhs, insertion)?; - rewrite_excluded_in_expr(start, insertion)?; - rewrite_excluded_in_expr(end, insertion)?; - } - Expr::Binary(l, _, r) => { - rewrite_excluded_in_expr(l, insertion)?; - rewrite_excluded_in_expr(r, insertion)?; - } - Expr::Case { - base, - when_then_pairs, - else_expr, - } => { - if let Some(b) = base { - rewrite_excluded_in_expr(b, insertion)? - }; - for (w, t) in when_then_pairs.iter_mut() { - rewrite_excluded_in_expr(w, insertion)?; - rewrite_excluded_in_expr(t, insertion)?; - } - if let Some(e) = else_expr { - rewrite_excluded_in_expr(e, insertion)? - } - } - Expr::Cast { expr: inner, .. } => rewrite_excluded_in_expr(inner, insertion)?, - Expr::FunctionCall { - args, - order_by, - filter_over, - .. - } => { - for a in args { - rewrite_excluded_in_expr(a, insertion)? - } - for sc in order_by { - rewrite_excluded_in_expr(&mut sc.expr, insertion)? - } - if let Some(ex) = &mut filter_over.filter_clause { - rewrite_excluded_in_expr(ex, insertion)? - } - } - Expr::InList { lhs, rhs, .. } => { - rewrite_excluded_in_expr(lhs, insertion)?; - for e in rhs { - rewrite_excluded_in_expr(e, insertion)? - } - } - Expr::InSelect { lhs, .. } => rewrite_excluded_in_expr(lhs, insertion)?, - Expr::InTable { lhs, .. } => rewrite_excluded_in_expr(lhs, insertion)?, - Expr::IsNull(inner) => rewrite_excluded_in_expr(inner, insertion)?, - Expr::Like { - lhs, rhs, escape, .. - } => { - rewrite_excluded_in_expr(lhs, insertion)?; - rewrite_excluded_in_expr(rhs, insertion)?; - if let Some(e) = escape { - rewrite_excluded_in_expr(e, insertion)? - } - } - Expr::NotNull(inner) => rewrite_excluded_in_expr(inner, insertion)?, - Expr::Unary(_, inner) => rewrite_excluded_in_expr(inner, insertion)?, - _ => {} - } - Ok(()) -} - /// Match ON CONFLICT target to the PRIMARY KEY, if any. /// If no target is specified, it is an automatic match for PRIMARY KEY pub fn upsert_matches_pk(upsert: &Upsert, table: &Table) -> bool { @@ -306,6 +213,7 @@ pub fn emit_upsert( program: &mut ProgramBuilder, schema: &Schema, table: &Table, + insertion: &Insertion, tbl_cursor_id: usize, conflict_rowid_reg: usize, set_pairs: &mut [(usize, Box)], @@ -362,14 +270,16 @@ pub fn emit_upsert( extra_amount: num_cols - 1, }); - // rewrite target-table refs -> registers from current snapshot - let rewrite_target = |e: &mut ast::Expr| { - rewrite_target_cols_to_current_row(e, table, current_start, conflict_rowid_reg); - }; - // WHERE predicate on the target row. If false or NULL, skip the UPDATE. if let Some(pred) = where_clause.as_mut() { - rewrite_target(pred); + rewrite_upsert_expr_in_place( + pred, + table, + table.get_name(), + current_start, + conflict_rowid_reg, + insertion, + )?; let pr = program.alloc_register(); translate_expr(program, None, pred, pr, resolver)?; program.emit_insn(Insn::IfNot { @@ -381,7 +291,14 @@ pub fn emit_upsert( // Evaluate each SET expression into the NEW row img for (col_idx, expr) in set_pairs.iter_mut() { - rewrite_target(expr); + rewrite_upsert_expr_in_place( + expr, + table, + table.get_name(), + current_start, + conflict_rowid_reg, + insertion, + )?; translate_expr_no_constant_opt( program, None, @@ -549,7 +466,6 @@ pub fn emit_upsert( pub fn collect_set_clauses_for_upsert( table: &Table, set_items: &mut [ast::Set], - insertion: &Insertion, ) -> crate::Result)>> { let lookup: HashMap = table .columns() @@ -572,8 +488,7 @@ pub fn collect_set_clauses_for_upsert( values.len() ); } - for (cn, mut e) in set.col_names.iter().zip(values.into_iter()) { - rewrite_excluded_in_expr(&mut e, insertion)?; + for (cn, e) in set.col_names.iter().zip(values.into_iter()) { let Some(idx) = lookup.get(&normalize_ident(cn.as_str())) else { bail_parse_error!("no such column: {}", cn); }; @@ -587,142 +502,315 @@ pub fn collect_set_clauses_for_upsert( Ok(out) } -/// Rewrite references to the target table's columns in an expression tree so that -/// they read from registers containing the CURRENT (pre-update) row snapshot. +/// Rewrite an UPSERT expression so that: +/// EXCLUDED.x -> Register(insertion.x) +/// t.x / x -> Register(CURRENT.x) when t == target table or unqualified +/// rowid -> Register(conflict_rowid_reg) /// -/// This matches SQLite's rule that unqualified column refs in the DO UPDATE arm -/// refer to the original row, not the would-be inserted values, which must use -/// `EXCLUDED.x` and are handled earlier by `rewrite_excluded_in_expr`. -fn rewrite_target_cols_to_current_row( - expr: &mut ast::Expr, +/// Only rewrites names in the current expression scope, does not enter subqueries. +fn rewrite_upsert_expr_in_place( + e: &mut ast::Expr, table: &Table, + table_name: &str, current_start: usize, conflict_rowid_reg: usize, -) { - // Helper: map a column name to (is_rowid, register) + insertion: &Insertion, +) -> crate::Result<()> { + use ast::Expr::*; + + // helper: return the CURRENT-row register for a column (including rowid alias) let col_reg = |name: &str| -> Option { if name.eq_ignore_ascii_case("rowid") { return Some(conflict_rowid_reg); } - let (idx, col) = table.get_column_by_name(&normalize_ident(name))?; - if col.is_rowid_alias { - // You loaded alias value into current_start + idx - return Some(current_start + idx); - } + let (idx, _c) = table.get_column_by_name(&normalize_ident(name))?; Some(current_start + idx) }; - match expr { - // tbl.col: only rewrite if it names this table - Expr::Qualified(left, col) => { - let q = left.as_str(); - if !q.eq_ignore_ascii_case("excluded") && q.eq_ignore_ascii_case(table.get_name()) { - if let ast::Name::Ident(c) = col { - if let Some(reg) = col_reg(c.as_str()) { - *expr = Expr::Register(reg); - } - } - } - } - Expr::Id(ast::Name::Ident(name)) => { - if let Some(reg) = col_reg(name.as_str()) { - *expr = Expr::Register(reg); - } - } - Expr::RowId { .. } => { - *expr = Expr::Register(conflict_rowid_reg); + match e { + // EXCLUDED.x -> insertion register + Qualified(ns, ast::Name::Ident(c)) if ns.as_str().eq_ignore_ascii_case("excluded") => { + let Some(reg) = insertion.get_col_mapping_by_name(c.as_str()) else { + bail_parse_error!("no such column in EXCLUDED: {}", c); + }; + *e = Register(reg.register); } - // Keep walking for composite expressions - Expr::Collate(inner, _) => { - rewrite_target_cols_to_current_row(inner, table, current_start, conflict_rowid_reg) - } - Expr::Parenthesized(v) => { - for e in v { - rewrite_target_cols_to_current_row(e, table, current_start, conflict_rowid_reg) + // t.x -> CURRENT, only if t matches the target table name (never "excluded") + Qualified(ns, ast::Name::Ident(c)) if ns.as_str().eq_ignore_ascii_case(table_name) => { + if let Some(reg) = col_reg(c.as_str()) { + *e = Register(reg); } } - Expr::Between { + // Unqualified column id -> CURRENT + Id(ast::Name::Ident(name)) => { + if let Some(reg) = col_reg(name.as_str()) { + *e = Register(reg); + } + } + RowId { .. } => { + *e = Register(conflict_rowid_reg); + } + Collate(inner, _) => rewrite_upsert_expr_in_place( + inner, + table, + table_name, + current_start, + conflict_rowid_reg, + insertion, + )?, + Parenthesized(v) => { + for ex in v { + rewrite_upsert_expr_in_place( + ex, + table, + table_name, + current_start, + conflict_rowid_reg, + insertion, + )?; + } + } + Between { lhs, start, end, .. } => { - rewrite_target_cols_to_current_row(lhs, table, current_start, conflict_rowid_reg); - rewrite_target_cols_to_current_row(start, table, current_start, conflict_rowid_reg); - rewrite_target_cols_to_current_row(end, table, current_start, conflict_rowid_reg); + rewrite_upsert_expr_in_place( + lhs, + table, + table_name, + current_start, + conflict_rowid_reg, + insertion, + )?; + rewrite_upsert_expr_in_place( + start, + table, + table_name, + current_start, + conflict_rowid_reg, + insertion, + )?; + rewrite_upsert_expr_in_place( + end, + table, + table_name, + current_start, + conflict_rowid_reg, + insertion, + )?; } - Expr::Binary(l, _, r) => { - rewrite_target_cols_to_current_row(l, table, current_start, conflict_rowid_reg); - rewrite_target_cols_to_current_row(r, table, current_start, conflict_rowid_reg); + Binary(l, _, r) => { + rewrite_upsert_expr_in_place( + l, + table, + table_name, + current_start, + conflict_rowid_reg, + insertion, + )?; + rewrite_upsert_expr_in_place( + r, + table, + table_name, + current_start, + conflict_rowid_reg, + insertion, + )?; } - Expr::Case { + Case { base, when_then_pairs, else_expr, } => { if let Some(b) = base { - rewrite_target_cols_to_current_row(b, table, current_start, conflict_rowid_reg) + rewrite_upsert_expr_in_place( + b, + table, + table_name, + current_start, + conflict_rowid_reg, + insertion, + )?; } for (w, t) in when_then_pairs.iter_mut() { - rewrite_target_cols_to_current_row(w, table, current_start, conflict_rowid_reg); - rewrite_target_cols_to_current_row(t, table, current_start, conflict_rowid_reg); + rewrite_upsert_expr_in_place( + w, + table, + table_name, + current_start, + conflict_rowid_reg, + insertion, + )?; + rewrite_upsert_expr_in_place( + t, + table, + table_name, + current_start, + conflict_rowid_reg, + insertion, + )?; } - if let Some(e) = else_expr { - rewrite_target_cols_to_current_row(e, table, current_start, conflict_rowid_reg) + if let Some(e2) = else_expr { + rewrite_upsert_expr_in_place( + e2, + table, + table_name, + current_start, + conflict_rowid_reg, + insertion, + )?; } } - Expr::Cast { expr: inner, .. } => { - rewrite_target_cols_to_current_row(inner, table, current_start, conflict_rowid_reg) + Cast { expr: inner, .. } => { + rewrite_upsert_expr_in_place( + inner, + table, + table_name, + current_start, + conflict_rowid_reg, + insertion, + )?; } - Expr::FunctionCall { + FunctionCall { args, order_by, filter_over, .. } => { for a in args { - rewrite_target_cols_to_current_row(a, table, current_start, conflict_rowid_reg) - } - for sc in order_by { - rewrite_target_cols_to_current_row( - &mut sc.expr, + rewrite_upsert_expr_in_place( + a, table, + table_name, current_start, conflict_rowid_reg, - ) + insertion, + )?; + } + for sc in order_by { + rewrite_upsert_expr_in_place( + &mut sc.expr, + table, + table_name, + current_start, + conflict_rowid_reg, + insertion, + )?; } if let Some(ref mut f) = &mut filter_over.filter_clause { - rewrite_target_cols_to_current_row(f, table, current_start, conflict_rowid_reg) + rewrite_upsert_expr_in_place( + f, + table, + table_name, + current_start, + conflict_rowid_reg, + insertion, + )?; } } - Expr::InList { lhs, rhs, .. } => { - rewrite_target_cols_to_current_row(lhs, table, current_start, conflict_rowid_reg); - for e in rhs { - rewrite_target_cols_to_current_row(e, table, current_start, conflict_rowid_reg) + InList { lhs, rhs, .. } => { + rewrite_upsert_expr_in_place( + lhs, + table, + table_name, + current_start, + conflict_rowid_reg, + insertion, + )?; + for ex in rhs { + rewrite_upsert_expr_in_place( + ex, + table, + table_name, + current_start, + conflict_rowid_reg, + insertion, + )?; } } - Expr::InSelect { lhs, .. } => { - rewrite_target_cols_to_current_row(lhs, table, current_start, conflict_rowid_reg) + InSelect { lhs, .. } => { + // rewrite only `lhs`, not the subselect + rewrite_upsert_expr_in_place( + lhs, + table, + table_name, + current_start, + conflict_rowid_reg, + insertion, + )?; } - Expr::InTable { lhs, .. } => { - rewrite_target_cols_to_current_row(lhs, table, current_start, conflict_rowid_reg) + InTable { lhs, .. } => { + rewrite_upsert_expr_in_place( + lhs, + table, + table_name, + current_start, + conflict_rowid_reg, + insertion, + )?; } - Expr::IsNull(inner) => { - rewrite_target_cols_to_current_row(inner, table, current_start, conflict_rowid_reg) + IsNull(inner) => { + rewrite_upsert_expr_in_place( + inner, + table, + table_name, + current_start, + conflict_rowid_reg, + insertion, + )?; } - Expr::Like { + Like { lhs, rhs, escape, .. } => { - rewrite_target_cols_to_current_row(lhs, table, current_start, conflict_rowid_reg); - rewrite_target_cols_to_current_row(rhs, table, current_start, conflict_rowid_reg); - if let Some(e) = escape { - rewrite_target_cols_to_current_row(e, table, current_start, conflict_rowid_reg) + rewrite_upsert_expr_in_place( + lhs, + table, + table_name, + current_start, + conflict_rowid_reg, + insertion, + )?; + rewrite_upsert_expr_in_place( + rhs, + table, + table_name, + current_start, + conflict_rowid_reg, + insertion, + )?; + if let Some(e3) = escape { + rewrite_upsert_expr_in_place( + e3, + table, + table_name, + current_start, + conflict_rowid_reg, + insertion, + )?; } } - Expr::NotNull(inner) => { - rewrite_target_cols_to_current_row(inner, table, current_start, conflict_rowid_reg) + NotNull(inner) => { + rewrite_upsert_expr_in_place( + inner, + table, + table_name, + current_start, + conflict_rowid_reg, + insertion, + )?; } - Expr::Unary(_, inner) => { - rewrite_target_cols_to_current_row(inner, table, current_start, conflict_rowid_reg) + Unary(_, inner) => { + rewrite_upsert_expr_in_place( + inner, + table, + table_name, + current_start, + conflict_rowid_reg, + insertion, + )?; } + _ => {} } + Ok(()) }