modified tests as we do not have rollback yet. Also correctly raise a contraint error on primary keys only

This commit is contained in:
pedrocarlo
2025-04-30 14:26:50 -03:00
parent 3aaf4206b7
commit 758dfff2fe
2 changed files with 113 additions and 21 deletions

View File

@@ -742,11 +742,45 @@ fn emit_update_insns(
cursor_id,
dest: beg,
});
// Check if rowid was provided (through INTEGER PRIMARY KEY as a rowid alias)
let rowid_alias_index = table_ref.columns().iter().position(|c| c.is_rowid_alias);
let rowid_set_clause_reg = if rowid_alias_index.is_some() {
Some(program.alloc_register())
} else {
None
};
let has_user_provided_rowid = {
if let Some(index) = rowid_alias_index {
plan.set_clauses
.iter()
.position(|(idx, _)| *idx == index)
.is_some()
} else {
false
}
};
let check_rowid_not_exists_label = if has_user_provided_rowid {
Some(program.allocate_label())
} else {
None
};
if has_user_provided_rowid {
program.emit_insn(Insn::NotExists {
cursor: cursor_id,
rowid_reg: beg,
target_pc: check_rowid_not_exists_label.unwrap(),
});
} else {
// if no rowid, we're done
program.emit_insn(Insn::IsNull {
reg: beg,
target_pc: t_ctx.label_main_loop_end.unwrap(),
});
}
if is_virtual {
program.emit_insn(Insn::Copy {
src_reg: beg,
@@ -775,7 +809,6 @@ fn emit_update_insns(
let rowid_reg = beg;
let idx_cols_start_reg = beg + 1;
// copy each index column from the table's column registers into these scratch regs
for (i, col) in index.columns.iter().enumerate() {
// copy from the table's column register over to the index's scratch register
@@ -906,9 +939,28 @@ fn emit_update_insns(
// we scan a column at a time, loading either the column's values, or the new value
// from the Set expression, into registers so we can emit a MakeRecord and update the row.
let start = if is_virtual { beg + 2 } else { beg + 1 };
for idx in 0..table_ref.columns().len() {
for (idx, table_column) in table_ref.columns().iter().enumerate() {
let target_reg = start + idx;
if let Some((_, expr)) = plan.set_clauses.iter().find(|(i, _)| *i == idx) {
if has_user_provided_rowid
&& (table_column.primary_key || table_column.is_rowid_alias)
&& !is_virtual
{
let rowid_set_clause_reg = rowid_set_clause_reg.unwrap();
translate_expr(
program,
Some(&plan.table_references),
expr,
rowid_set_clause_reg,
&t_ctx.resolver,
)?;
program.emit_insn(Insn::MustBeInt {
reg: rowid_set_clause_reg,
});
program.emit_null(target_reg, None);
} else {
translate_expr(
program,
Some(&plan.table_references),
@@ -916,8 +968,9 @@ fn emit_update_insns(
target_reg,
&t_ctx.resolver,
)?;
}
// if let Some(rowid_reg) = rowid_set_clause_reg {}
} else {
let table_column = table_ref.table.columns().get(idx).unwrap();
let column_idx_in_index = index.as_ref().and_then(|(idx, _)| {
idx.columns
.iter()
@@ -961,6 +1014,42 @@ fn emit_update_insns(
table_reference: Rc::clone(&btree_table),
});
}
if has_user_provided_rowid {
let record_label = program.allocate_label();
let idx = rowid_alias_index.unwrap();
let target_reg = rowid_set_clause_reg.unwrap();
program.emit_insn(Insn::Eq {
lhs: target_reg,
rhs: beg,
target_pc: record_label,
flags: CmpInsFlags::default(),
});
program.emit_insn(Insn::NotExists {
cursor: cursor_id,
rowid_reg: target_reg,
target_pc: record_label,
});
program.emit_insn(Insn::Halt {
err_code: SQLITE_CONSTRAINT_PRIMARYKEY,
description: format!(
"{}.{}",
table_ref.table.get_name(),
&table_ref
.columns()
.get(idx)
.unwrap()
.name
.as_ref()
.map_or("", |v| v)
),
});
program.preassign_label_to_next_insn(record_label);
}
let record_reg = program.alloc_register();
program.emit_insn(Insn::MakeRecord {
start_reg: start,
@@ -1012,7 +1101,7 @@ fn emit_update_insns(
program.emit_insn(Insn::Insert {
cursor: cursor_id,
key_reg: beg,
key_reg: rowid_set_clause_reg.unwrap_or(beg),
record_reg,
flag: 0,
table_name: table_ref.identifier.clone(),
@@ -1035,5 +1124,10 @@ fn emit_update_insns(
})
}
// TODO(pthorpe): handle RETURNING clause
if let Some(label) = check_rowid_not_exists_label {
program.preassign_label_to_next_insn(label);
}
Ok(())
}

View File

@@ -304,9 +304,8 @@ def generate_test(col_amount: int, primary_keys: int) -> ConstraintTest:
update_errors = []
if len(insert_stmts) > 1:
update_errors = [
table.generate_update() for _ in table.columns if col.primary_key
]
# TODO: As we have no rollback we just generate one update statement
update_errors = [table.generate_update()]
return ConstraintTest(
table=table,
@@ -327,7 +326,6 @@ def custom_test_1() -> ConstraintTest:
"INSERT INTO users VALUES (2, 'bob');",
]
update_stmts = [
"UPDATE users SET id = 3;",
"UPDATE users SET id = 2, username = 'bob' WHERE id == 1;",
]
return ConstraintTest(