diff --git a/core/translate/insert.rs b/core/translate/insert.rs index 6339c2788..9c942d83e 100644 --- a/core/translate/insert.rs +++ b/core/translate/insert.rs @@ -17,7 +17,7 @@ use crate::translate::expr::{ use crate::translate::plan::TableReferences; use crate::translate::planner::ROWID; use crate::translate::upsert::{ - collect_set_clauses_for_upsert, emit_upsert, upsert_matches_index, upsert_matches_pk, + collect_set_clauses_for_upsert, emit_upsert, resolve_upsert_target, ResolvedUpsertTarget, }; use crate::util::normalize_ident; use crate::vdbe::builder::ProgramBuilderOpts; @@ -168,6 +168,12 @@ pub fn translate_insert( } upsert_opt = upsert.as_deref().cloned(); } + // resolve the constrained target for UPSERT if specified + let resolved_upsert = if let Some(upsert) = &upsert_opt { + Some(resolve_upsert_target(schema, &table, upsert)?) + } else { + None + }; let halt_label = program.allocate_label(); let loop_start_label = program.allocate_label(); @@ -438,8 +444,13 @@ pub fn translate_insert( // Conflict on rowid: attempt to route through UPSERT if it targets the PK, otherwise raise constraint. // emit Halt for every case *except* when upsert handles the conflict 'emit_halt: { - if let Some(ref mut upsert) = upsert_opt.as_mut() { - if upsert_matches_pk(upsert, &table) { + if let (Some(ref mut upsert), Some(ref target)) = + (upsert_opt.as_mut(), resolved_upsert.as_ref()) + { + if matches!( + target, + ResolvedUpsertTarget::CatchAll | ResolvedUpsertTarget::PrimaryKey + ) { match upsert.do_clause { UpsertDo::Nothing => { program.emit_insn(Insn::Goto { @@ -451,7 +462,6 @@ pub fn translate_insert( ref mut where_clause, } => { let mut rewritten_sets = collect_set_clauses_for_upsert(&table, sets)?; - emit_upsert( &mut program, schema, @@ -590,11 +600,16 @@ pub fn translate_insert( accum }, ); - // again, emit halt for every case *except* when upsert handles the conflict 'emit_halt: { - if let Some(ref mut upsert) = upsert_opt.as_mut() { - if upsert_matches_index(upsert, index, &table) { + if let (Some(ref mut upsert), Some(ref target)) = + (upsert_opt.as_mut(), resolved_upsert.as_ref()) + { + if match target { + ResolvedUpsertTarget::CatchAll => true, + ResolvedUpsertTarget::Index(tgt) => Arc::ptr_eq(tgt, index), + ResolvedUpsertTarget::PrimaryKey => false, + } { match upsert.do_clause { UpsertDo::Nothing => { program.emit_insn(Insn::Goto { diff --git a/core/translate/upsert.rs b/core/translate/upsert.rs index b10a24b45..6cf520ad8 100644 --- a/core/translate/upsert.rs +++ b/core/translate/upsert.rs @@ -2,7 +2,9 @@ use std::{collections::HashMap, sync::Arc}; use turso_parser::ast::{self, Upsert}; +use crate::error::SQLITE_CONSTRAINT_PRIMARYKEY; use crate::translate::expr::WalkControl; +use crate::vdbe::insn::CmpInsFlags; use crate::{ bail_parse_error, error::SQLITE_CONSTRAINT_NOTNULL, @@ -185,6 +187,42 @@ pub fn upsert_matches_index(upsert: &Upsert, index: &Index, table: &Table) -> bo need.is_empty() } +#[derive(Clone)] +pub enum ResolvedUpsertTarget { + // ON CONFLICT DO + CatchAll, + // ON CONFLICT(pk) DO + PrimaryKey, + // matched this non-partial UNIQUE index + Index(Arc), +} + +pub fn resolve_upsert_target( + schema: &Schema, + table: &Table, + upsert: &Upsert, +) -> crate::Result { + // Omitted target, catch-all + if upsert.index.is_none() { + return Ok(ResolvedUpsertTarget::CatchAll); + } + + // Targeted: must match PK or a non-partial UNIQUE index. + if upsert_matches_pk(upsert, table) { + return Ok(ResolvedUpsertTarget::PrimaryKey); + } + + for idx in schema.get_indices(table.get_name()) { + if idx.unique && idx.where_clause.is_none() && upsert_matches_index(upsert, idx, table) { + return Ok(ResolvedUpsertTarget::Index(Arc::clone(idx))); + } + } + // Match SQLite’s error text: + crate::bail_parse_error!( + "ON CONFLICT clause does not match any PRIMARY KEY or UNIQUE constraint" + ); +} + #[allow(clippy::too_many_arguments)] /// Emit the bytecode to implement the `DO UPDATE` arm of an UPSERT. /// @@ -336,6 +374,51 @@ pub fn emit_upsert( .expect("index exists"); let k = idx_meta.columns.len(); + let (before_pred_reg, new_pred_reg) = if let Some(where_clause) = &idx_meta.where_clause + { + // BEFORE image predicate + let mut before_where = where_clause.as_ref().clone(); + rewrite_partial_index_where_for_image( + &mut before_where, + table, + before_start.expect("before_start must exist for index maintenance"), + conflict_rowid_reg, + )?; + let before_reg = program.alloc_register(); + translate_expr_no_constant_opt( + program, + None, + &before_where, + before_reg, + resolver, + NoConstantOptReason::RegisterReuse, + )?; + + // NEW image predicate + let mut new_where = where_clause.as_ref().clone(); + rewrite_partial_index_where_for_image( + &mut new_where, + table, + new_start, + conflict_rowid_reg, + )?; + let new_reg = program.alloc_register(); + translate_expr(program, None, &new_where, new_reg, resolver)?; + + (Some(before_reg), Some(new_reg)) + } else { + (None, None) + }; + let maybe_skip_del = before_pred_reg.map(|r| { + let lbl = program.allocate_label(); + program.emit_insn(Insn::IfNot { + reg: r, + target_pc: lbl, + jump_if_null: true, + }); + lbl + }); + let del = program.alloc_registers(k + 1); for (i, ic) in idx_meta.columns.iter().enumerate() { let (ci, _) = table.get_column_by_name(&ic.name).unwrap(); @@ -357,6 +440,22 @@ pub fn emit_upsert( raise_error_if_no_matching_entry: false, }); + // resolve skipping the delete if it was false/NULL + if let Some(label) = maybe_skip_del { + program.resolve_label(label, program.offset()); + } + + // if NEW does not satisfy partial index, skip the insert + let maybe_skip_ins = new_pred_reg.map(|r| { + let lbl = program.allocate_label(); + program.emit_insn(Insn::IfNot { + reg: r, + target_pc: lbl, + jump_if_null: true, + }); + lbl + }); + let ins = program.alloc_registers(k + 1); for (i, ic) in idx_meta.columns.iter().enumerate() { let (ci, _) = table.get_column_by_name(&ic.name).unwrap(); @@ -380,6 +479,55 @@ pub fn emit_upsert( index_name: Some((*idx_name).clone()), affinity_str: None, }); + + // If unique, perform NoConflict + self-check before IdxInsert + if idx_meta.unique { + let ok_lbl = program.allocate_label(); + program.emit_insn(Insn::NoConflict { + cursor_id: *idx_cid, + target_pc: ok_lbl, + record_reg: ins, + num_regs: k, + }); + + // If there’s a hit, skip it if it’s self, otherwise raise constraint + let hit_rowid = program.alloc_register(); + program.emit_insn(Insn::IdxRowId { + cursor_id: *idx_cid, + dest: hit_rowid, + }); + program.emit_insn(Insn::Eq { + lhs: conflict_rowid_reg, + rhs: hit_rowid, + target_pc: ok_lbl, + flags: CmpInsFlags::default(), + collation: program.curr_collation(), + }); + let mut description = String::with_capacity( + table.get_name().len() + + idx_meta + .columns + .iter() + .map(|c| c.name.len() + 2) + .sum::(), + ); + description.push_str(table.get_name()); + description.push_str(".("); + description.push_str( + &idx_meta + .columns + .iter() + .map(|c| c.name.as_str()) + .collect::>() + .join(", "), + ); + description.push(')'); + program.emit_insn(Insn::Halt { + err_code: SQLITE_CONSTRAINT_PRIMARYKEY, + description, + }); + program.preassign_label_to_next_insn(ok_lbl); + } program.emit_insn(Insn::IdxInsert { cursor_id: *idx_cid, record_reg: rec, @@ -387,6 +535,9 @@ pub fn emit_upsert( unpacked_count: Some((k + 1) as u16), flags: IdxInsertFlags::new().nchange(true), }); + if let Some(lbl) = maybe_skip_ins { + program.resolve_label(lbl, program.offset()); + } } } @@ -569,3 +720,52 @@ fn rewrite_upsert_expr_in_place( }, ) } + +/// Rewrite partial-index WHERE to read from a contiguous row image starting at `base_start`. +/// Maps rowid (and the rowid-alias column) to `rowid_reg`... Very similar to the above method +/// but simpler because there is no EXCLUDED or table name to consider. +fn rewrite_partial_index_where_for_image( + expr: &mut ast::Expr, + table: &Table, + base_start: usize, + rowid_reg: usize, +) -> crate::Result { + walk_expr_mut( + expr, + &mut |e: &mut ast::Expr| -> crate::Result { + match e { + ast::Expr::Id(n) => { + let nm = normalize_ident(n.as_str()); + if nm.eq_ignore_ascii_case("rowid") { + *e = ast::Expr::Register(rowid_reg); + } else if let Some((col_idx, _)) = table.get_column_by_name(&nm) { + let col = &table.columns()[col_idx]; + *e = ast::Expr::Register(if col.is_rowid_alias { + rowid_reg + } else { + base_start + col_idx + }); + } + } + ast::Expr::Qualified(_, cn) | ast::Expr::DoublyQualified(_, _, cn) => { + let nm = normalize_ident(cn.as_str()); + if nm.eq_ignore_ascii_case("rowid") { + *e = ast::Expr::Register(rowid_reg); + } else if let Some((col_idx, _)) = table.get_column_by_name(&nm) { + let col = &table.columns()[col_idx]; + *e = ast::Expr::Register(if col.is_rowid_alias { + rowid_reg + } else { + base_start + col_idx + }); + } + } + ast::Expr::RowId { .. } => { + *e = ast::Expr::Register(rowid_reg); + } + _ => {} + } + Ok(WalkControl::Continue) + }, + ) +}