Fix INSERT/UPSERT to properly handle and/or reject partial indexes

This commit is contained in:
PThorpe92
2025-09-20 18:32:03 -04:00
parent 51fb801d87
commit 62ee68e4dd
2 changed files with 222 additions and 7 deletions

View File

@@ -17,7 +17,7 @@ use crate::translate::expr::{
use crate::translate::plan::TableReferences;
use crate::translate::planner::ROWID;
use crate::translate::upsert::{
collect_set_clauses_for_upsert, emit_upsert, upsert_matches_index, upsert_matches_pk,
collect_set_clauses_for_upsert, emit_upsert, resolve_upsert_target, ResolvedUpsertTarget,
};
use crate::util::normalize_ident;
use crate::vdbe::builder::ProgramBuilderOpts;
@@ -168,6 +168,12 @@ pub fn translate_insert(
}
upsert_opt = upsert.as_deref().cloned();
}
// resolve the constrained target for UPSERT if specified
let resolved_upsert = if let Some(upsert) = &upsert_opt {
Some(resolve_upsert_target(schema, &table, upsert)?)
} else {
None
};
let halt_label = program.allocate_label();
let loop_start_label = program.allocate_label();
@@ -438,8 +444,13 @@ pub fn translate_insert(
// Conflict on rowid: attempt to route through UPSERT if it targets the PK, otherwise raise constraint.
// emit Halt for every case *except* when upsert handles the conflict
'emit_halt: {
if let Some(ref mut upsert) = upsert_opt.as_mut() {
if upsert_matches_pk(upsert, &table) {
if let (Some(ref mut upsert), Some(ref target)) =
(upsert_opt.as_mut(), resolved_upsert.as_ref())
{
if matches!(
target,
ResolvedUpsertTarget::CatchAll | ResolvedUpsertTarget::PrimaryKey
) {
match upsert.do_clause {
UpsertDo::Nothing => {
program.emit_insn(Insn::Goto {
@@ -451,7 +462,6 @@ pub fn translate_insert(
ref mut where_clause,
} => {
let mut rewritten_sets = collect_set_clauses_for_upsert(&table, sets)?;
emit_upsert(
&mut program,
schema,
@@ -590,11 +600,16 @@ pub fn translate_insert(
accum
},
);
// again, emit halt for every case *except* when upsert handles the conflict
'emit_halt: {
if let Some(ref mut upsert) = upsert_opt.as_mut() {
if upsert_matches_index(upsert, index, &table) {
if let (Some(ref mut upsert), Some(ref target)) =
(upsert_opt.as_mut(), resolved_upsert.as_ref())
{
if match target {
ResolvedUpsertTarget::CatchAll => true,
ResolvedUpsertTarget::Index(tgt) => Arc::ptr_eq(tgt, index),
ResolvedUpsertTarget::PrimaryKey => false,
} {
match upsert.do_clause {
UpsertDo::Nothing => {
program.emit_insn(Insn::Goto {

View File

@@ -2,7 +2,9 @@ use std::{collections::HashMap, sync::Arc};
use turso_parser::ast::{self, Upsert};
use crate::error::SQLITE_CONSTRAINT_PRIMARYKEY;
use crate::translate::expr::WalkControl;
use crate::vdbe::insn::CmpInsFlags;
use crate::{
bail_parse_error,
error::SQLITE_CONSTRAINT_NOTNULL,
@@ -185,6 +187,42 @@ pub fn upsert_matches_index(upsert: &Upsert, index: &Index, table: &Table) -> bo
need.is_empty()
}
#[derive(Clone)]
pub enum ResolvedUpsertTarget {
// ON CONFLICT DO
CatchAll,
// ON CONFLICT(pk) DO
PrimaryKey,
// matched this non-partial UNIQUE index
Index(Arc<Index>),
}
pub fn resolve_upsert_target(
schema: &Schema,
table: &Table,
upsert: &Upsert,
) -> crate::Result<ResolvedUpsertTarget> {
// Omitted target, catch-all
if upsert.index.is_none() {
return Ok(ResolvedUpsertTarget::CatchAll);
}
// Targeted: must match PK or a non-partial UNIQUE index.
if upsert_matches_pk(upsert, table) {
return Ok(ResolvedUpsertTarget::PrimaryKey);
}
for idx in schema.get_indices(table.get_name()) {
if idx.unique && idx.where_clause.is_none() && upsert_matches_index(upsert, idx, table) {
return Ok(ResolvedUpsertTarget::Index(Arc::clone(idx)));
}
}
// Match SQLites error text:
crate::bail_parse_error!(
"ON CONFLICT clause does not match any PRIMARY KEY or UNIQUE constraint"
);
}
#[allow(clippy::too_many_arguments)]
/// Emit the bytecode to implement the `DO UPDATE` arm of an UPSERT.
///
@@ -336,6 +374,51 @@ pub fn emit_upsert(
.expect("index exists");
let k = idx_meta.columns.len();
let (before_pred_reg, new_pred_reg) = if let Some(where_clause) = &idx_meta.where_clause
{
// BEFORE image predicate
let mut before_where = where_clause.as_ref().clone();
rewrite_partial_index_where_for_image(
&mut before_where,
table,
before_start.expect("before_start must exist for index maintenance"),
conflict_rowid_reg,
)?;
let before_reg = program.alloc_register();
translate_expr_no_constant_opt(
program,
None,
&before_where,
before_reg,
resolver,
NoConstantOptReason::RegisterReuse,
)?;
// NEW image predicate
let mut new_where = where_clause.as_ref().clone();
rewrite_partial_index_where_for_image(
&mut new_where,
table,
new_start,
conflict_rowid_reg,
)?;
let new_reg = program.alloc_register();
translate_expr(program, None, &new_where, new_reg, resolver)?;
(Some(before_reg), Some(new_reg))
} else {
(None, None)
};
let maybe_skip_del = before_pred_reg.map(|r| {
let lbl = program.allocate_label();
program.emit_insn(Insn::IfNot {
reg: r,
target_pc: lbl,
jump_if_null: true,
});
lbl
});
let del = program.alloc_registers(k + 1);
for (i, ic) in idx_meta.columns.iter().enumerate() {
let (ci, _) = table.get_column_by_name(&ic.name).unwrap();
@@ -357,6 +440,22 @@ pub fn emit_upsert(
raise_error_if_no_matching_entry: false,
});
// resolve skipping the delete if it was false/NULL
if let Some(label) = maybe_skip_del {
program.resolve_label(label, program.offset());
}
// if NEW does not satisfy partial index, skip the insert
let maybe_skip_ins = new_pred_reg.map(|r| {
let lbl = program.allocate_label();
program.emit_insn(Insn::IfNot {
reg: r,
target_pc: lbl,
jump_if_null: true,
});
lbl
});
let ins = program.alloc_registers(k + 1);
for (i, ic) in idx_meta.columns.iter().enumerate() {
let (ci, _) = table.get_column_by_name(&ic.name).unwrap();
@@ -380,6 +479,55 @@ pub fn emit_upsert(
index_name: Some((*idx_name).clone()),
affinity_str: None,
});
// If unique, perform NoConflict + self-check before IdxInsert
if idx_meta.unique {
let ok_lbl = program.allocate_label();
program.emit_insn(Insn::NoConflict {
cursor_id: *idx_cid,
target_pc: ok_lbl,
record_reg: ins,
num_regs: k,
});
// If theres a hit, skip it if its self, otherwise raise constraint
let hit_rowid = program.alloc_register();
program.emit_insn(Insn::IdxRowId {
cursor_id: *idx_cid,
dest: hit_rowid,
});
program.emit_insn(Insn::Eq {
lhs: conflict_rowid_reg,
rhs: hit_rowid,
target_pc: ok_lbl,
flags: CmpInsFlags::default(),
collation: program.curr_collation(),
});
let mut description = String::with_capacity(
table.get_name().len()
+ idx_meta
.columns
.iter()
.map(|c| c.name.len() + 2)
.sum::<usize>(),
);
description.push_str(table.get_name());
description.push_str(".(");
description.push_str(
&idx_meta
.columns
.iter()
.map(|c| c.name.as_str())
.collect::<Vec<_>>()
.join(", "),
);
description.push(')');
program.emit_insn(Insn::Halt {
err_code: SQLITE_CONSTRAINT_PRIMARYKEY,
description,
});
program.preassign_label_to_next_insn(ok_lbl);
}
program.emit_insn(Insn::IdxInsert {
cursor_id: *idx_cid,
record_reg: rec,
@@ -387,6 +535,9 @@ pub fn emit_upsert(
unpacked_count: Some((k + 1) as u16),
flags: IdxInsertFlags::new().nchange(true),
});
if let Some(lbl) = maybe_skip_ins {
program.resolve_label(lbl, program.offset());
}
}
}
@@ -569,3 +720,52 @@ fn rewrite_upsert_expr_in_place(
},
)
}
/// Rewrite partial-index WHERE to read from a contiguous row image starting at `base_start`.
/// Maps rowid (and the rowid-alias column) to `rowid_reg`... Very similar to the above method
/// but simpler because there is no EXCLUDED or table name to consider.
fn rewrite_partial_index_where_for_image(
expr: &mut ast::Expr,
table: &Table,
base_start: usize,
rowid_reg: usize,
) -> crate::Result<WalkControl> {
walk_expr_mut(
expr,
&mut |e: &mut ast::Expr| -> crate::Result<WalkControl> {
match e {
ast::Expr::Id(n) => {
let nm = normalize_ident(n.as_str());
if nm.eq_ignore_ascii_case("rowid") {
*e = ast::Expr::Register(rowid_reg);
} else if let Some((col_idx, _)) = table.get_column_by_name(&nm) {
let col = &table.columns()[col_idx];
*e = ast::Expr::Register(if col.is_rowid_alias {
rowid_reg
} else {
base_start + col_idx
});
}
}
ast::Expr::Qualified(_, cn) | ast::Expr::DoublyQualified(_, _, cn) => {
let nm = normalize_ident(cn.as_str());
if nm.eq_ignore_ascii_case("rowid") {
*e = ast::Expr::Register(rowid_reg);
} else if let Some((col_idx, _)) = table.get_column_by_name(&nm) {
let col = &table.columns()[col_idx];
*e = ast::Expr::Register(if col.is_rowid_alias {
rowid_reg
} else {
base_start + col_idx
});
}
}
ast::Expr::RowId { .. } => {
*e = ast::Expr::Register(rowid_reg);
}
_ => {}
}
Ok(WalkControl::Continue)
},
)
}