mirror of
https://github.com/aljazceru/turso.git
synced 2026-02-05 00:04:23 +01:00
triggers: add translation logic for UPDATE triggers
This commit is contained in:
@@ -639,7 +639,6 @@ fn emit_program_for_delete(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
pub fn emit_fk_child_decrement_on_delete(
|
||||
program: &mut ProgramBuilder,
|
||||
resolver: &Resolver,
|
||||
@@ -1297,6 +1296,18 @@ fn emit_program_for_update(
|
||||
UpdateRowSource::Normal
|
||||
});
|
||||
|
||||
let join_order = plan
|
||||
.table_references
|
||||
.joined_tables()
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, t)| JoinOrderMember {
|
||||
table_id: t.internal_id,
|
||||
original_idx: i,
|
||||
is_outer: false,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Initialize the main loop
|
||||
init_loop(
|
||||
program,
|
||||
@@ -1306,7 +1317,7 @@ fn emit_program_for_update(
|
||||
None,
|
||||
mode.clone(),
|
||||
&plan.where_clause,
|
||||
&[JoinOrderMember::default()],
|
||||
&join_order,
|
||||
&mut [],
|
||||
)?;
|
||||
|
||||
@@ -1340,7 +1351,7 @@ fn emit_program_for_update(
|
||||
program,
|
||||
&mut t_ctx,
|
||||
&plan.table_references,
|
||||
&[JoinOrderMember::default()],
|
||||
&join_order,
|
||||
&plan.where_clause,
|
||||
temp_cursor_id,
|
||||
mode.clone(),
|
||||
@@ -1378,7 +1389,7 @@ fn emit_program_for_update(
|
||||
program,
|
||||
&mut t_ctx,
|
||||
&plan.table_references,
|
||||
&[JoinOrderMember::default()],
|
||||
&join_order,
|
||||
mode.clone(),
|
||||
)?;
|
||||
|
||||
@@ -1390,6 +1401,168 @@ fn emit_program_for_update(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Helper function to evaluate SET expressions and read column values for UPDATE.
|
||||
/// This is invoked once for every UPDATE, but will be invoked again if there are
|
||||
/// any BEFORE UPDATE triggers that fired, because the triggers may have modified the row,
|
||||
/// in which case the previously read values are stale.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn emit_update_column_values<'a>(
|
||||
program: &mut ProgramBuilder,
|
||||
table_references: &mut TableReferences,
|
||||
set_clauses: &[(usize, Box<ast::Expr>)],
|
||||
cdc_update_alter_statement: Option<&str>,
|
||||
target_table: &Arc<JoinedTable>,
|
||||
target_table_cursor_id: usize,
|
||||
start: usize,
|
||||
col_len: usize,
|
||||
table_name: &str,
|
||||
has_direct_rowid_update: bool,
|
||||
has_user_provided_rowid: bool,
|
||||
rowid_set_clause_reg: Option<usize>,
|
||||
is_virtual: bool,
|
||||
index: &Option<(Arc<Index>, usize)>,
|
||||
cdc_updates_register: Option<usize>,
|
||||
t_ctx: &mut TranslateCtx<'a>,
|
||||
skip_set_clauses: bool,
|
||||
) -> crate::Result<()> {
|
||||
if has_direct_rowid_update {
|
||||
if let Some((_, expr)) = set_clauses.iter().find(|(i, _)| *i == ROWID_SENTINEL) {
|
||||
if !skip_set_clauses {
|
||||
let rowid_set_clause_reg = rowid_set_clause_reg.unwrap();
|
||||
translate_expr(
|
||||
program,
|
||||
Some(table_references),
|
||||
expr,
|
||||
rowid_set_clause_reg,
|
||||
&t_ctx.resolver,
|
||||
)?;
|
||||
program.emit_insn(Insn::MustBeInt {
|
||||
reg: rowid_set_clause_reg,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
for (idx, table_column) in target_table.table.columns().iter().enumerate() {
|
||||
let target_reg = start + idx;
|
||||
if let Some((col_idx, expr)) = set_clauses.iter().find(|(i, _)| *i == idx) {
|
||||
if !skip_set_clauses {
|
||||
// Skip if this is the sentinel value
|
||||
if *col_idx == ROWID_SENTINEL {
|
||||
continue;
|
||||
}
|
||||
if has_user_provided_rowid
|
||||
&& (table_column.primary_key() || table_column.is_rowid_alias())
|
||||
&& !is_virtual
|
||||
{
|
||||
let rowid_set_clause_reg = rowid_set_clause_reg.unwrap();
|
||||
translate_expr(
|
||||
program,
|
||||
Some(table_references),
|
||||
expr,
|
||||
rowid_set_clause_reg,
|
||||
&t_ctx.resolver,
|
||||
)?;
|
||||
|
||||
program.emit_insn(Insn::MustBeInt {
|
||||
reg: rowid_set_clause_reg,
|
||||
});
|
||||
|
||||
program.emit_null(target_reg, None);
|
||||
} else {
|
||||
translate_expr(
|
||||
program,
|
||||
Some(table_references),
|
||||
expr,
|
||||
target_reg,
|
||||
&t_ctx.resolver,
|
||||
)?;
|
||||
if table_column.notnull() {
|
||||
use crate::error::SQLITE_CONSTRAINT_NOTNULL;
|
||||
program.emit_insn(Insn::HaltIfNull {
|
||||
target_reg,
|
||||
err_code: SQLITE_CONSTRAINT_NOTNULL,
|
||||
description: format!(
|
||||
"{}.{}",
|
||||
table_name,
|
||||
table_column
|
||||
.name
|
||||
.as_ref()
|
||||
.expect("Column name must be present")
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(cdc_updates_register) = cdc_updates_register {
|
||||
let change_reg = cdc_updates_register + idx;
|
||||
let value_reg = cdc_updates_register + col_len + idx;
|
||||
program.emit_bool(true, change_reg);
|
||||
program.mark_last_insn_constant();
|
||||
let mut updated = false;
|
||||
if let Some(ddl_query_for_cdc_update) = cdc_update_alter_statement {
|
||||
if table_column.name.as_deref() == Some("sql") {
|
||||
program.emit_string8(ddl_query_for_cdc_update.to_string(), value_reg);
|
||||
updated = true;
|
||||
}
|
||||
}
|
||||
if !updated {
|
||||
program.emit_insn(Insn::Copy {
|
||||
src_reg: target_reg,
|
||||
dst_reg: value_reg,
|
||||
extra_amount: 0,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Column is not being updated, read it from the table
|
||||
let column_idx_in_index = index.as_ref().and_then(|(idx, _)| {
|
||||
idx.columns
|
||||
.iter()
|
||||
.position(|c| Some(&c.name) == table_column.name.as_ref())
|
||||
});
|
||||
|
||||
// don't emit null for pkey of virtual tables. they require first two args
|
||||
// before the 'record' to be explicitly non-null
|
||||
if table_column.is_rowid_alias() && !is_virtual {
|
||||
program.emit_null(target_reg, None);
|
||||
} else if is_virtual {
|
||||
program.emit_insn(Insn::VColumn {
|
||||
cursor_id: target_table_cursor_id,
|
||||
column: idx,
|
||||
dest: target_reg,
|
||||
});
|
||||
} else {
|
||||
let cursor_id = *index
|
||||
.as_ref()
|
||||
.and_then(|(_, id)| {
|
||||
if column_idx_in_index.is_some() {
|
||||
Some(id)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.unwrap_or(&target_table_cursor_id);
|
||||
program.emit_column_or_rowid(
|
||||
cursor_id,
|
||||
column_idx_in_index.unwrap_or(idx),
|
||||
target_reg,
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(cdc_updates_register) = cdc_updates_register {
|
||||
let change_bit_reg = cdc_updates_register + idx;
|
||||
let value_reg = cdc_updates_register + col_len + idx;
|
||||
program.emit_bool(false, change_bit_reg);
|
||||
program.mark_last_insn_constant();
|
||||
program.emit_null(value_reg, None);
|
||||
program.mark_last_insn_constant();
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(skip_all, level = Level::DEBUG)]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// Emits the instructions for the UPDATE loop.
|
||||
@@ -1534,134 +1707,155 @@ fn emit_update_insns<'a>(
|
||||
|
||||
let start = if is_virtual { beg + 2 } else { beg + 1 };
|
||||
|
||||
if has_direct_rowid_update {
|
||||
if let Some((_, expr)) = set_clauses.iter().find(|(i, _)| *i == ROWID_SENTINEL) {
|
||||
let rowid_set_clause_reg = rowid_set_clause_reg.unwrap();
|
||||
translate_expr(
|
||||
program,
|
||||
Some(table_references),
|
||||
expr,
|
||||
rowid_set_clause_reg,
|
||||
&t_ctx.resolver,
|
||||
)?;
|
||||
program.emit_insn(Insn::MustBeInt {
|
||||
reg: rowid_set_clause_reg,
|
||||
});
|
||||
}
|
||||
}
|
||||
for (idx, table_column) in target_table.table.columns().iter().enumerate() {
|
||||
let target_reg = start + idx;
|
||||
if let Some((col_idx, expr)) = set_clauses.iter().find(|(i, _)| *i == idx) {
|
||||
// Skip if this is the sentinel value
|
||||
if *col_idx == ROWID_SENTINEL {
|
||||
continue;
|
||||
}
|
||||
if has_user_provided_rowid
|
||||
&& (table_column.primary_key() || table_column.is_rowid_alias())
|
||||
&& !is_virtual
|
||||
{
|
||||
let rowid_set_clause_reg = rowid_set_clause_reg.unwrap();
|
||||
translate_expr(
|
||||
program,
|
||||
Some(table_references),
|
||||
expr,
|
||||
rowid_set_clause_reg,
|
||||
&t_ctx.resolver,
|
||||
)?;
|
||||
let skip_set_clauses = false;
|
||||
|
||||
program.emit_insn(Insn::MustBeInt {
|
||||
reg: rowid_set_clause_reg,
|
||||
});
|
||||
emit_update_column_values(
|
||||
program,
|
||||
table_references,
|
||||
set_clauses,
|
||||
cdc_update_alter_statement,
|
||||
&target_table,
|
||||
target_table_cursor_id,
|
||||
start,
|
||||
col_len,
|
||||
table_name,
|
||||
has_direct_rowid_update,
|
||||
has_user_provided_rowid,
|
||||
rowid_set_clause_reg,
|
||||
is_virtual,
|
||||
&index,
|
||||
cdc_updates_register,
|
||||
t_ctx,
|
||||
skip_set_clauses,
|
||||
)?;
|
||||
|
||||
program.emit_null(target_reg, None);
|
||||
// Fire BEFORE UPDATE triggers and preserve old_registers for AFTER triggers
|
||||
let preserved_old_registers: Option<Vec<usize>> =
|
||||
if let Some(btree_table) = target_table.table.btree() {
|
||||
let updated_column_indices: std::collections::HashSet<usize> =
|
||||
set_clauses.iter().map(|(col_idx, _)| *col_idx).collect();
|
||||
let relevant_before_update_triggers = get_relevant_triggers_type_and_time(
|
||||
t_ctx.resolver.schema,
|
||||
TriggerEvent::Update,
|
||||
TriggerTime::Before,
|
||||
Some(updated_column_indices.clone()),
|
||||
&btree_table,
|
||||
);
|
||||
// Read OLD row values for trigger context
|
||||
let old_registers: Vec<usize> = (0..col_len)
|
||||
.map(|i| {
|
||||
let reg = program.alloc_register();
|
||||
program.emit_column_or_rowid(target_table_cursor_id, i, reg);
|
||||
reg
|
||||
})
|
||||
.chain(std::iter::once(beg))
|
||||
.collect();
|
||||
let has_relevant_triggers = relevant_before_update_triggers.clone().count() > 0;
|
||||
if !has_relevant_triggers {
|
||||
Some(old_registers)
|
||||
} else {
|
||||
translate_expr(
|
||||
program,
|
||||
Some(table_references),
|
||||
expr,
|
||||
target_reg,
|
||||
&t_ctx.resolver,
|
||||
)?;
|
||||
if table_column.notnull() {
|
||||
use crate::error::SQLITE_CONSTRAINT_NOTNULL;
|
||||
program.emit_insn(Insn::HaltIfNull {
|
||||
target_reg,
|
||||
err_code: SQLITE_CONSTRAINT_NOTNULL,
|
||||
description: format!(
|
||||
"{}.{}",
|
||||
table_name,
|
||||
table_column
|
||||
.name
|
||||
.as_ref()
|
||||
.expect("Column name must be present")
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
// NEW row values are already in 'start' registers
|
||||
let new_registers = (0..col_len)
|
||||
.map(|i| start + i)
|
||||
.chain(std::iter::once(beg))
|
||||
.collect();
|
||||
|
||||
if let Some(cdc_updates_register) = cdc_updates_register {
|
||||
let change_reg = cdc_updates_register + idx;
|
||||
let value_reg = cdc_updates_register + col_len + idx;
|
||||
program.emit_bool(true, change_reg);
|
||||
program.mark_last_insn_constant();
|
||||
let mut updated = false;
|
||||
if let Some(ddl_query_for_cdc_update) = &cdc_update_alter_statement {
|
||||
if table_column.name.as_deref() == Some("sql") {
|
||||
program.emit_string8(ddl_query_for_cdc_update.to_string(), value_reg);
|
||||
updated = true;
|
||||
}
|
||||
let trigger_ctx = TriggerContext::new(
|
||||
btree_table.clone(),
|
||||
Some(new_registers),
|
||||
Some(old_registers.clone()), // Clone for AFTER trigger
|
||||
);
|
||||
|
||||
// Extract updated column indices for UPDATE OF trigger filtering
|
||||
|
||||
for trigger in relevant_before_update_triggers {
|
||||
fire_trigger(
|
||||
program,
|
||||
&mut t_ctx.resolver,
|
||||
trigger,
|
||||
&trigger_ctx,
|
||||
connection,
|
||||
)?;
|
||||
}
|
||||
if !updated {
|
||||
program.emit_insn(Insn::Copy {
|
||||
src_reg: target_reg,
|
||||
dst_reg: value_reg,
|
||||
extra_amount: 0,
|
||||
});
|
||||
|
||||
// BEFORE UPDATE Triggers may have altered the btree so we need to seek again.
|
||||
program.emit_insn(Insn::NotExists {
|
||||
cursor: target_table_cursor_id,
|
||||
rowid_reg: beg,
|
||||
target_pc: check_rowid_not_exists_label.expect(
|
||||
"check_rowid_not_exists_label must be set if there are BEFORE UPDATE triggers",
|
||||
),
|
||||
});
|
||||
|
||||
let has_relevant_after_triggers = get_relevant_triggers_type_and_time(
|
||||
t_ctx.resolver.schema,
|
||||
TriggerEvent::Update,
|
||||
TriggerTime::After,
|
||||
Some(updated_column_indices),
|
||||
&btree_table,
|
||||
)
|
||||
.clone()
|
||||
.count()
|
||||
> 0;
|
||||
if has_relevant_after_triggers {
|
||||
// Preserve pseudo-row 'OLD' for AFTER triggers by copying to new registers
|
||||
// (since registers might be overwritten during trigger execution)
|
||||
let preserved: Vec<usize> = old_registers
|
||||
.iter()
|
||||
.map(|old_reg| {
|
||||
let preserved_reg = program.alloc_register();
|
||||
program.emit_insn(Insn::Copy {
|
||||
src_reg: *old_reg,
|
||||
dst_reg: preserved_reg,
|
||||
extra_amount: 0,
|
||||
});
|
||||
preserved_reg
|
||||
})
|
||||
.collect();
|
||||
Some(preserved)
|
||||
} else {
|
||||
Some(old_registers)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let column_idx_in_index = index.as_ref().and_then(|(idx, _)| {
|
||||
idx.columns
|
||||
.iter()
|
||||
.position(|c| Some(&c.name) == table_column.name.as_ref())
|
||||
});
|
||||
None
|
||||
};
|
||||
|
||||
// don't emit null for pkey of virtual tables. they require first two args
|
||||
// before the 'record' to be explicitly non-null
|
||||
if table_column.is_rowid_alias() && !is_virtual {
|
||||
program.emit_null(target_reg, None);
|
||||
} else if is_virtual {
|
||||
program.emit_insn(Insn::VColumn {
|
||||
cursor_id: target_table_cursor_id,
|
||||
column: idx,
|
||||
dest: target_reg,
|
||||
});
|
||||
} else {
|
||||
let cursor_id = *index
|
||||
.as_ref()
|
||||
.and_then(|(_, id)| {
|
||||
if column_idx_in_index.is_some() {
|
||||
Some(id)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.unwrap_or(&target_table_cursor_id);
|
||||
program.emit_column_or_rowid(
|
||||
cursor_id,
|
||||
column_idx_in_index.unwrap_or(idx),
|
||||
target_reg,
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(cdc_updates_register) = cdc_updates_register {
|
||||
let change_bit_reg = cdc_updates_register + idx;
|
||||
let value_reg = cdc_updates_register + col_len + idx;
|
||||
program.emit_bool(false, change_bit_reg);
|
||||
program.mark_last_insn_constant();
|
||||
program.emit_null(value_reg, None);
|
||||
program.mark_last_insn_constant();
|
||||
}
|
||||
// If BEFORE UPDATE triggers fired, they may have modified the row being updated.
|
||||
// According to the SQLite documentation, the behavior in these cases is undefined:
|
||||
// https://sqlite.org/lang_createtrigger.html
|
||||
// However, based on fuzz testing and observations, the logic seems to be:
|
||||
// The values that are NOT referred to in SET clauses will be evaluated again,
|
||||
// and values in SET clauses are evaluated using the old values.
|
||||
// sqlite> create table t(c0,c1,c2);
|
||||
// sqlite> create trigger tu before update on t begin update t set c1=666, c2=666; end;
|
||||
// sqlite> insert into t values (1,1,1);
|
||||
// sqlite> update t set c0 = c1+1;
|
||||
// sqlite> select * from t;
|
||||
// 2|666|666
|
||||
if target_table.table.btree().is_some() {
|
||||
let before_update_triggers_fired = preserved_old_registers.is_some();
|
||||
let skip_set_clauses = true;
|
||||
if before_update_triggers_fired {
|
||||
emit_update_column_values(
|
||||
program,
|
||||
table_references,
|
||||
set_clauses,
|
||||
cdc_update_alter_statement,
|
||||
&target_table,
|
||||
target_table_cursor_id,
|
||||
start,
|
||||
col_len,
|
||||
table_name,
|
||||
has_direct_rowid_update,
|
||||
has_user_provided_rowid,
|
||||
rowid_set_clause_reg,
|
||||
is_virtual,
|
||||
&index,
|
||||
cdc_updates_register,
|
||||
t_ctx,
|
||||
skip_set_clauses,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2067,6 +2261,46 @@ fn emit_update_insns<'a>(
|
||||
table_name: target_table.identifier.clone(),
|
||||
});
|
||||
|
||||
// Fire AFTER UPDATE triggers
|
||||
if let Some(btree_table) = target_table.table.btree() {
|
||||
let updated_column_indices: std::collections::HashSet<usize> =
|
||||
set_clauses.iter().map(|(col_idx, _)| *col_idx).collect();
|
||||
let relevant_triggers = get_relevant_triggers_type_and_time(
|
||||
t_ctx.resolver.schema,
|
||||
TriggerEvent::Update,
|
||||
TriggerTime::After,
|
||||
Some(updated_column_indices),
|
||||
&btree_table,
|
||||
);
|
||||
let has_relevant_triggers = relevant_triggers.clone().count() > 0;
|
||||
if has_relevant_triggers {
|
||||
let new_rowid_reg = rowid_set_clause_reg.unwrap_or(beg);
|
||||
let new_registers_after = (0..col_len)
|
||||
.map(|i| start + i)
|
||||
.chain(std::iter::once(new_rowid_reg))
|
||||
.collect();
|
||||
|
||||
// Use preserved OLD registers from BEFORE trigger
|
||||
let old_registers_after = preserved_old_registers;
|
||||
|
||||
let trigger_ctx_after = TriggerContext::new(
|
||||
btree_table.clone(),
|
||||
Some(new_registers_after),
|
||||
old_registers_after, // OLD values preserved from BEFORE trigger
|
||||
);
|
||||
|
||||
for trigger in relevant_triggers {
|
||||
fire_trigger(
|
||||
program,
|
||||
&mut t_ctx.resolver,
|
||||
trigger,
|
||||
&trigger_ctx_after,
|
||||
connection,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Emit RETURNING results if specified
|
||||
if let Some(returning_columns) = &returning {
|
||||
if !returning_columns.is_empty() {
|
||||
|
||||
Reference in New Issue
Block a user