diff --git a/core/translate/emitter.rs b/core/translate/emitter.rs index f5243c9c5..c378885f7 100644 --- a/core/translate/emitter.rs +++ b/core/translate/emitter.rs @@ -742,11 +742,45 @@ fn emit_update_insns( cursor_id, dest: beg, }); - // if no rowid, we're done - program.emit_insn(Insn::IsNull { - reg: beg, - target_pc: t_ctx.label_main_loop_end.unwrap(), - }); + + // Check if rowid was provided (through INTEGER PRIMARY KEY as a rowid alias) + let rowid_alias_index = table_ref.columns().iter().position(|c| c.is_rowid_alias); + let rowid_set_clause_reg = if rowid_alias_index.is_some() { + Some(program.alloc_register()) + } else { + None + }; + let has_user_provided_rowid = { + if let Some(index) = rowid_alias_index { + plan.set_clauses + .iter() + .position(|(idx, _)| *idx == index) + .is_some() + } else { + false + } + }; + + let check_rowid_not_exists_label = if has_user_provided_rowid { + Some(program.allocate_label()) + } else { + None + }; + + if has_user_provided_rowid { + program.emit_insn(Insn::NotExists { + cursor: cursor_id, + rowid_reg: beg, + target_pc: check_rowid_not_exists_label.unwrap(), + }); + } else { + // if no rowid, we're done + program.emit_insn(Insn::IsNull { + reg: beg, + target_pc: t_ctx.label_main_loop_end.unwrap(), + }); + } + if is_virtual { program.emit_insn(Insn::Copy { src_reg: beg, @@ -775,9 +809,8 @@ fn emit_update_insns( let rowid_reg = beg; let idx_cols_start_reg = beg + 1; - // copy each index column from the table's column registers into these scratch regs - for (i, col) in index.columns.iter().enumerate(){ + for (i, col) in index.columns.iter().enumerate() { // copy from the table's column register over to the index's scratch register program.emit_insn(Insn::Copy { @@ -906,18 +939,38 @@ fn emit_update_insns( // we scan a column at a time, loading either the column's values, or the new value // from the Set expression, into registers so we can emit a MakeRecord and update the row. let start = if is_virtual { beg + 2 } else { beg + 1 }; - for idx in 0..table_ref.columns().len() { + for (idx, table_column) in table_ref.columns().iter().enumerate() { let target_reg = start + idx; if let Some((_, expr)) = plan.set_clauses.iter().find(|(i, _)| *i == idx) { - translate_expr( - program, - Some(&plan.table_references), - expr, - target_reg, - &t_ctx.resolver, - )?; + if has_user_provided_rowid + && (table_column.primary_key || table_column.is_rowid_alias) + && !is_virtual + { + let rowid_set_clause_reg = rowid_set_clause_reg.unwrap(); + translate_expr( + program, + Some(&plan.table_references), + expr, + rowid_set_clause_reg, + &t_ctx.resolver, + )?; + + program.emit_insn(Insn::MustBeInt { + reg: rowid_set_clause_reg, + }); + + program.emit_null(target_reg, None); + } else { + translate_expr( + program, + Some(&plan.table_references), + expr, + target_reg, + &t_ctx.resolver, + )?; + } + // if let Some(rowid_reg) = rowid_set_clause_reg {} } else { - let table_column = table_ref.table.columns().get(idx).unwrap(); let column_idx_in_index = index.as_ref().and_then(|(idx, _)| { idx.columns .iter() @@ -961,6 +1014,42 @@ fn emit_update_insns( table_reference: Rc::clone(&btree_table), }); } + + if has_user_provided_rowid { + let record_label = program.allocate_label(); + let idx = rowid_alias_index.unwrap(); + let target_reg = rowid_set_clause_reg.unwrap(); + program.emit_insn(Insn::Eq { + lhs: target_reg, + rhs: beg, + target_pc: record_label, + flags: CmpInsFlags::default(), + }); + + program.emit_insn(Insn::NotExists { + cursor: cursor_id, + rowid_reg: target_reg, + target_pc: record_label, + }); + + program.emit_insn(Insn::Halt { + err_code: SQLITE_CONSTRAINT_PRIMARYKEY, + description: format!( + "{}.{}", + table_ref.table.get_name(), + &table_ref + .columns() + .get(idx) + .unwrap() + .name + .as_ref() + .map_or("", |v| v) + ), + }); + + program.preassign_label_to_next_insn(record_label); + } + let record_reg = program.alloc_register(); program.emit_insn(Insn::MakeRecord { start_reg: start, @@ -1012,7 +1101,7 @@ fn emit_update_insns( program.emit_insn(Insn::Insert { cursor: cursor_id, - key_reg: beg, + key_reg: rowid_set_clause_reg.unwrap_or(beg), record_reg, flag: 0, table_name: table_ref.identifier.clone(), @@ -1035,5 +1124,10 @@ fn emit_update_insns( }) } // TODO(pthorpe): handle RETURNING clause + + if let Some(label) = check_rowid_not_exists_label { + program.preassign_label_to_next_insn(label); + } + Ok(()) } diff --git a/testing/cli_tests/constraint.py b/testing/cli_tests/constraint.py index 9d5003840..36181fa67 100644 --- a/testing/cli_tests/constraint.py +++ b/testing/cli_tests/constraint.py @@ -304,9 +304,8 @@ def generate_test(col_amount: int, primary_keys: int) -> ConstraintTest: update_errors = [] if len(insert_stmts) > 1: - update_errors = [ - table.generate_update() for _ in table.columns if col.primary_key - ] + # TODO: As we have no rollback we just generate one update statement + update_errors = [table.generate_update()] return ConstraintTest( table=table, @@ -327,7 +326,6 @@ def custom_test_1() -> ConstraintTest: "INSERT INTO users VALUES (2, 'bob');", ] update_stmts = [ - "UPDATE users SET id = 3;", "UPDATE users SET id = 2, username = 'bob' WHERE id == 1;", ] return ConstraintTest(