From eee7fa5f950469f3d34f558e824c03ef1d356510 Mon Sep 17 00:00:00 2001 From: Jussi Saurio Date: Thu, 13 Nov 2025 09:35:38 +0200 Subject: [PATCH] Refactor RETURNING to support arbitrary expressions Previously, RETURNING was a bit of a hack since it had a special translate_expr_for_returning() function that only supported a subset of expressions. Instead, we can store the columns of the target table of the INSERT/UPDATE/DELETE we are RETURNING from in `Resolver::expr_to_reg_cache` and make those columns point to the registers that hold the OLD/NEW column values (depending on the operation). --- core/translate/emitter.rs | 110 ++++++++------- core/translate/expr.rs | 284 ++++++++++---------------------------- core/translate/insert.rs | 85 +++++++----- core/translate/mod.rs | 6 +- core/translate/update.rs | 9 +- core/translate/upsert.rs | 21 +-- 6 files changed, 205 insertions(+), 310 deletions(-) diff --git a/core/translate/emitter.rs b/core/translate/emitter.rs index 2a9ff3508..cd962562d 100644 --- a/core/translate/emitter.rs +++ b/core/translate/emitter.rs @@ -1,7 +1,6 @@ // This module contains code for emitting bytecode instructions for SQL query execution. // It handles translating high-level SQL operations into low-level bytecode that can be executed by the virtual machine. 
-use std::collections::HashSet; use std::num::NonZeroUsize; use std::sync::Arc; @@ -24,18 +23,20 @@ use super::select::emit_simple_count; use super::subquery::emit_from_clause_subqueries; use crate::error::SQLITE_CONSTRAINT_PRIMARYKEY; use crate::function::Func; -use crate::schema::{BTreeTable, Column, Schema, Table, ROWID_SENTINEL}; +use crate::schema::{BTreeTable, Column, Index, Schema, Table, ROWID_SENTINEL}; use crate::translate::compound_select::emit_program_for_compound_select; use crate::translate::expr::{ emit_returning_results, translate_expr_no_constant_opt, walk_expr_mut, NoConstantOptReason, - ReturningValueRegisters, WalkControl, + WalkControl, }; use crate::translate::fkeys::{ build_index_affinity_string, emit_fk_child_update_counters, emit_fk_delete_parent_existence_checks, emit_guarded_fk_decrement, emit_parent_key_change_checks, open_read_index, open_read_table, stabilize_new_row_for_fk, }; -use crate::translate::plan::{DeletePlan, EvalAt, JoinedTable, Plan, QueryDestination, Search}; +use crate::translate::plan::{ + DeletePlan, EvalAt, JoinedTable, Plan, QueryDestination, ResultSetColumn, Search, +}; use crate::translate::planner::ROWID_STRS; use crate::translate::subquery::emit_non_from_clause_subquery; use crate::translate::values::emit_values; @@ -665,12 +666,12 @@ pub fn emit_fk_child_decrement_on_delete( Ok(()) } -fn emit_delete_insns( +fn emit_delete_insns<'a>( connection: &Arc, program: &mut ProgramBuilder, - t_ctx: &mut TranslateCtx, + t_ctx: &mut TranslateCtx<'a>, table_references: &mut TableReferences, - result_columns: &[super::plan::ResultSetColumn], + result_columns: &'a [super::plan::ResultSetColumn], ) -> Result<()> { // we can either use this obviously safe raw pointer or we can clone it let table_reference: *const JoinedTable = table_references.joined_tables().first().unwrap(); @@ -873,13 +874,14 @@ fn emit_delete_insns( } // Emit RETURNING results using the values we just read - let value_registers = ReturningValueRegisters { - 
rowid_register: rowid_reg, - columns_start_register: columns_start_reg, - num_columns: cols_len, - }; - - emit_returning_results(program, result_columns, &value_registers)?; + emit_returning_results( + program, + table_references, + result_columns, + columns_start_reg, + rowid_reg, + &mut t_ctx.resolver, + )?; } program.emit_insn(Insn::Delete { @@ -1042,8 +1044,13 @@ fn emit_program_for_update( // Emit update instructions emit_update_insns( connection, - &mut plan, - &t_ctx, + &mut plan.table_references, + &plan.set_clauses, + plan.cdc_update_alter_statement.as_deref(), + &plan.indexes_to_update, + plan.returning.as_ref(), + plan.ephemeral_plan.as_ref(), + &mut t_ctx, program, index_cursors, iteration_cursor_id, @@ -1077,10 +1084,16 @@ fn emit_program_for_update( /// `target_table_cursor_id` is the cursor id of the table that is being updated. /// /// `target_table` is the table that is being updated. -fn emit_update_insns( +#[allow(clippy::too_many_arguments)] +fn emit_update_insns<'a>( connection: &Arc, - plan: &mut UpdatePlan, - t_ctx: &TranslateCtx, + table_references: &mut TableReferences, + set_clauses: &[(usize, Box)], + cdc_update_alter_statement: Option<&str>, + indexes_to_update: &[Arc], + returning: Option<&'a Vec>, + ephemeral_plan: Option<&SelectPlan>, + t_ctx: &mut TranslateCtx<'a>, program: &mut ProgramBuilder, index_cursors: Vec<(usize, usize)>, iteration_cursor_id: usize, @@ -1089,7 +1102,7 @@ fn emit_update_insns( ) -> crate::Result<()> { let internal_id = target_table.internal_id; let loop_labels = t_ctx.labels_main_loop.first().unwrap(); - let source_table = plan.table_references.joined_tables().first().unwrap(); + let source_table = table_references.joined_tables().first().unwrap(); let (index, is_virtual) = match &source_table.op { Operation::Scan(Scan::BTreeTable { index, .. 
}) => ( index.as_ref().map(|index| { @@ -1138,13 +1151,10 @@ fn emit_update_insns( .iter() .position(|c| c.is_rowid_alias()); - let has_direct_rowid_update = plan - .set_clauses - .iter() - .any(|(idx, _)| *idx == ROWID_SENTINEL); + let has_direct_rowid_update = set_clauses.iter().any(|(idx, _)| *idx == ROWID_SENTINEL); let has_user_provided_rowid = if let Some(index) = rowid_alias_index { - plan.set_clauses.iter().any(|(idx, _)| *idx == index) + set_clauses.iter().any(|(idx, _)| *idx == index) } else { has_direct_rowid_update }; @@ -1210,11 +1220,11 @@ fn emit_update_insns( let start = if is_virtual { beg + 2 } else { beg + 1 }; if has_direct_rowid_update { - if let Some((_, expr)) = plan.set_clauses.iter().find(|(i, _)| *i == ROWID_SENTINEL) { + if let Some((_, expr)) = set_clauses.iter().find(|(i, _)| *i == ROWID_SENTINEL) { let rowid_set_clause_reg = rowid_set_clause_reg.unwrap(); translate_expr( program, - Some(&plan.table_references), + Some(table_references), expr, rowid_set_clause_reg, &t_ctx.resolver, @@ -1226,7 +1236,7 @@ fn emit_update_insns( } for (idx, table_column) in target_table.table.columns().iter().enumerate() { let target_reg = start + idx; - if let Some((col_idx, expr)) = plan.set_clauses.iter().find(|(i, _)| *i == idx) { + if let Some((col_idx, expr)) = set_clauses.iter().find(|(i, _)| *i == idx) { // Skip if this is the sentinel value if *col_idx == ROWID_SENTINEL { continue; @@ -1238,7 +1248,7 @@ fn emit_update_insns( let rowid_set_clause_reg = rowid_set_clause_reg.unwrap(); translate_expr( program, - Some(&plan.table_references), + Some(table_references), expr, rowid_set_clause_reg, &t_ctx.resolver, @@ -1252,7 +1262,7 @@ fn emit_update_insns( } else { translate_expr( program, - Some(&plan.table_references), + Some(table_references), expr, target_reg, &t_ctx.resolver, @@ -1280,9 +1290,9 @@ fn emit_update_insns( program.emit_bool(true, change_reg); program.mark_last_insn_constant(); let mut updated = false; - if let 
Some(ddl_query_for_cdc_update) = &plan.cdc_update_alter_statement { + if let Some(ddl_query_for_cdc_update) = &cdc_update_alter_statement { if table_column.name.as_deref() == Some("sql") { - program.emit_string8(ddl_query_for_cdc_update.clone(), value_reg); + program.emit_string8(ddl_query_for_cdc_update.to_string(), value_reg); updated = true; } } @@ -1346,7 +1356,7 @@ fn emit_update_insns( stabilize_new_row_for_fk( program, &table_btree, - &plan.set_clauses, + set_clauses, target_table_cursor_id, start, rowid_new_reg, @@ -1362,11 +1372,10 @@ fn emit_update_insns( target_table_cursor_id, start, rowid_new_reg, - &plan - .set_clauses + &set_clauses .iter() .map(|(i, _)| *i) - .collect::>(), + .collect::>(), )?; } // Parent-side checks: @@ -1382,19 +1391,19 @@ fn emit_update_insns( program, &t_ctx.resolver, &table_btree, - plan.indexes_to_update.iter(), + indexes_to_update.iter(), target_table_cursor_id, beg, start, rowid_new_reg, rowid_set_clause_reg, - &plan.set_clauses, + set_clauses, )?; } } } - for (index, (idx_cursor_id, record_reg)) in plan.indexes_to_update.iter().zip(&index_cursors) { + for (index, (idx_cursor_id, record_reg)) in indexes_to_update.iter().zip(&index_cursors) { // We need to know whether or not the OLD values satisfied the predicate on the // partial index, so we can know whether or not to delete the old index entry, // as well as whether or not the NEW values satisfy the predicate, to determine whether @@ -1403,12 +1412,12 @@ fn emit_update_insns( // This means that we need to bind the column references to a copy of the index Expr, // so we can emit Insn::Column instructions and refer to the old values. 
let where_clause = index - .bind_where_expr(Some(&mut plan.table_references), connection) + .bind_where_expr(Some(table_references), connection) .expect("where clause to exist"); let old_satisfied_reg = program.alloc_register(); translate_expr_no_constant_opt( program, - Some(&plan.table_references), + Some(table_references), &where_clause, old_satisfied_reg, &t_ctx.resolver, @@ -1744,15 +1753,16 @@ fn emit_update_insns( }); // Emit RETURNING results if specified - if let Some(returning_columns) = &plan.returning { + if let Some(returning_columns) = &returning { if !returning_columns.is_empty() { - let value_registers = ReturningValueRegisters { - rowid_register: rowid_set_clause_reg.unwrap_or(beg), - columns_start_register: start, - num_columns: col_len, - }; - - emit_returning_results(program, returning_columns, &value_registers)?; + emit_returning_results( + program, + table_references, + returning_columns, + start, + rowid_set_clause_reg.unwrap_or(beg), + &mut t_ctx.resolver, + )?; } } @@ -1814,7 +1824,7 @@ fn emit_update_insns( emit_cdc_insns( program, &t_ctx.resolver, - OperationMode::UPDATE(if plan.ephemeral_plan.is_some() { + OperationMode::UPDATE(if ephemeral_plan.is_some() { UpdateRowSource::PrebuiltEphemeralTable { ephemeral_table_cursor_id: iteration_cursor_id, target_table: target_table.clone(), diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 2f0f9a4a9..fbda8897d 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use tracing::{instrument, Level}; -use turso_parser::ast::{self, As, Expr, SubqueryType, UnaryOperator}; +use turso_parser::ast::{self, Expr, SubqueryType, UnaryOperator}; use super::emitter::Resolver; use super::optimizer::Optimizable; @@ -22,7 +22,7 @@ use crate::vdbe::{ insn::{CmpInsFlags, Insn}, BranchOffset, }; -use crate::{Result, Value}; +use crate::{turso_assert, Result, Value}; use super::collate::CollationSeq; @@ -34,16 +34,6 @@ pub struct ConditionMetadata { pub 
jump_target_when_null: BranchOffset, } -/// Container for register locations of values that can be referenced in RETURNING expressions -pub struct ReturningValueRegisters { - /// Register containing the rowid/primary key - pub rowid_register: usize, - /// Starting register for column values (in column order) - pub columns_start_register: usize, - /// Number of columns available - pub num_columns: usize, -} - #[instrument(skip_all, level = Level::DEBUG)] fn emit_cond_jump(program: &mut ProgramBuilder, cond_meta: ConditionMetadata, reg: usize) { if cond_meta.jump_if_condition_is_true { @@ -4143,106 +4133,6 @@ pub fn compare_affinity( } } -/// Evaluate a RETURNING expression using register-based evaluation instead of cursor-based. -/// This is used for RETURNING clauses where we have register values instead of cursor data. -pub fn translate_expr_for_returning( - program: &mut ProgramBuilder, - expr: &Expr, - value_registers: &ReturningValueRegisters, - target_register: usize, -) -> Result { - match expr { - Expr::Column { - column, - is_rowid_alias, - .. - } => { - if *is_rowid_alias { - // For rowid references, copy from the rowid register - program.emit_insn(Insn::Copy { - src_reg: value_registers.rowid_register, - dst_reg: target_register, - extra_amount: 0, - }); - } else { - // For regular column references, copy from the appropriate column register - let column_idx = *column; - if column_idx < value_registers.num_columns { - let column_reg = value_registers.columns_start_register + column_idx; - program.emit_insn(Insn::Copy { - src_reg: column_reg, - dst_reg: target_register, - extra_amount: 0, - }); - } else { - crate::bail_parse_error!("Column index out of bounds in RETURNING clause"); - } - } - Ok(target_register) - } - Expr::RowId { .. 
} => { - // For ROWID expressions, copy from the rowid register - program.emit_insn(Insn::Copy { - src_reg: value_registers.rowid_register, - dst_reg: target_register, - extra_amount: 0, - }); - Ok(target_register) - } - Expr::Literal(literal) => emit_literal(program, literal, target_register), - Expr::Binary(lhs, op, rhs) => { - let lhs_reg = program.alloc_register(); - let rhs_reg = program.alloc_register(); - - // Recursively evaluate left-hand side - translate_expr_for_returning(program, lhs, value_registers, lhs_reg)?; - - // Recursively evaluate right-hand side - translate_expr_for_returning(program, rhs, value_registers, rhs_reg)?; - - // Use the shared emit_binary_insn function - emit_binary_insn( - program, - op, - lhs_reg, - rhs_reg, - target_register, - lhs, - rhs, - None, // No table references needed for RETURNING - None, // No condition metadata needed for RETURNING - )?; - - Ok(target_register) - } - Expr::FunctionCall { name, args, .. } => { - // Evaluate arguments into registers - let mut arg_regs = Vec::new(); - for arg in args.iter() { - let arg_reg = program.alloc_register(); - translate_expr_for_returning(program, arg, value_registers, arg_reg)?; - arg_regs.push(arg_reg); - } - - // Resolve and call the function using shared helper - let func = Func::resolve_function(name.as_str(), arg_regs.len())?; - let func_ctx = FuncCtx { - func, - arg_count: arg_regs.len(), - }; - - emit_function_call(program, func_ctx, &arg_regs, target_register)?; - Ok(target_register) - } - _ => { - crate::bail_parse_error!( - "Unsupported expression type in RETURNING clause: {:?}", - expr - ); - } - } -} - /// Emit literal values - shared between regular and RETURNING expression evaluation pub fn emit_literal( program: &mut ProgramBuilder, @@ -4363,56 +4253,40 @@ pub fn emit_function_call( /// with proper column binding and alias handling. 
pub fn process_returning_clause( returning: &mut [ast::ResultColumn], - table: &Table, - table_name: &str, - program: &mut ProgramBuilder, + table_references: &mut TableReferences, connection: &std::sync::Arc, -) -> Result<( - Vec, - super::plan::TableReferences, -)> { - use super::plan::{ColumnUsedMask, JoinedTable, Operation, ResultSetColumn, TableReferences}; +) -> Result> { + let mut result_columns = Vec::with_capacity(returning.len()); - let mut result_columns = vec![]; - - let internal_id = program.table_reference_counter.next(); - let mut table_references = TableReferences::new( - vec![JoinedTable { - table: match table { - Table::Virtual(vtab) => Table::Virtual(vtab.clone()), - Table::BTree(btree_table) => Table::BTree(btree_table.clone()), - _ => unreachable!(), - }, - identifier: table_name.to_string(), - internal_id, - op: Operation::default_scan_for(table), - join_info: None, - col_used_mask: ColumnUsedMask::default(), - database_id: 0, - }], - vec![], - ); + let alias_to_string = |alias: &ast::As| match alias { + ast::As::Elided(alias) => alias.as_str().to_string(), + ast::As::As(alias) => alias.as_str().to_string(), + }; for rc in returning.iter_mut() { match rc { ast::ResultColumn::Expr(expr, alias) => { bind_and_rewrite_expr( expr, - Some(&mut table_references), + Some(table_references), None, connection, BindingBehavior::TryResultColumnsFirst, )?; - let column_alias = determine_column_alias(expr, alias, table); - result_columns.push(ResultSetColumn { - expr: *expr.clone(), - alias: column_alias, + expr: expr.as_ref().clone(), + alias: alias.as_ref().map(alias_to_string), contains_aggregates: false, }); } ast::ResultColumn::Star => { + let table = table_references + .joined_tables() + .first() + .expect("RETURNING clause must reference at least one table"); + let internal_id = table.internal_id; + // Handle RETURNING * by expanding to all table columns // Use the shared internal_id for all columns for (column_index, column) in 
table.columns().iter().enumerate() { @@ -4420,7 +4294,7 @@ pub fn process_returning_clause( database: None, table: internal_id, column: column_index, - is_rowid_alias: false, + is_rowid_alias: column.is_rowid_alias(), }; result_columns.push(ResultSetColumn { @@ -4430,91 +4304,81 @@ pub fn process_returning_clause( }); } } - ast::ResultColumn::TableStar(_table_name) => { - // Handle RETURNING table.* by expanding to all table columns - // For single table RETURNING, this is equivalent to * - for (column_index, column) in table.columns().iter().enumerate() { - let column_expr = Expr::Column { - database: None, - table: internal_id, - column: column_index, - is_rowid_alias: false, - }; - - result_columns.push(ResultSetColumn { - expr: column_expr, - alias: column.name.clone(), - contains_aggregates: false, - }); - } + ast::ResultColumn::TableStar(_) => { + crate::bail_parse_error!("RETURNING may not use \"TABLE.*\" wildcards"); } } } - Ok((result_columns, table_references)) -} - -/// Determine the appropriate alias for a RETURNING column expression -fn determine_column_alias( - expr: &Expr, - explicit_alias: &Option, - table: &Table, -) -> Option { - // First check for explicit alias - if let Some(As::As(name)) = explicit_alias { - return Some(name.as_str().to_string()); - } - - // For ROWID expressions, use "rowid" as the alias - if let Expr::RowId { .. } = expr { - return Some("rowid".to_string()); - } - - // For column references, always use the column name from the table - if let Expr::Column { - column, - is_rowid_alias, - .. 
- } = expr - { - if let Some(name) = table - .columns() - .get(*column) - .and_then(|col| col.name.clone()) - { - return Some(name); - } else if *is_rowid_alias { - // If it's a rowid alias, return "rowid" - return Some("rowid".to_string()); - } else { - return None; - } - } - - // For other expressions, use the expression string representation - Some(expr.to_string()) + Ok(result_columns) } /// Emit bytecode to evaluate RETURNING expressions and produce result rows. -/// This function handles the actual evaluation of expressions using the values -/// from the DML operation. -pub(crate) fn emit_returning_results( +/// RETURNING result expressions are otherwise evaluated as normal, but the columns of the target table +/// are added to [Resolver::expr_to_reg_cache], meaning a reference to e.g tbl.col will effectively +/// refer to a register where the OLD/NEW value of tbl.col is stored after an INSERT/UPDATE/DELETE. +pub(crate) fn emit_returning_results<'a>( program: &mut ProgramBuilder, + table_references: &TableReferences, result_columns: &[super::plan::ResultSetColumn], - value_registers: &ReturningValueRegisters, + reg_columns_start: usize, + rowid_reg: usize, + resolver: &mut Resolver<'a>, ) -> Result<()> { if result_columns.is_empty() { return Ok(()); } + turso_assert!(table_references.joined_tables().len() == 1, "RETURNING is only used with INSERT, UPDATE, or DELETE statements, which target a single table"); + let table = table_references.joined_tables().first().unwrap(); + + resolver.enable_expr_to_reg_cache(); + let expr = Expr::RowId { + database: None, + table: table.internal_id, + }; + let cache_len = resolver.expr_to_reg_cache.len(); + resolver + .expr_to_reg_cache + .push((std::borrow::Cow::Owned(expr), rowid_reg)); + for (i, column) in table.columns().iter().enumerate() { + let reg = if column.is_rowid_alias() { + rowid_reg + } else { + reg_columns_start + i + }; + let expr = Expr::Column { + database: None, + table: table.internal_id, + column: i, + 
is_rowid_alias: column.is_rowid_alias(), + }; + resolver + .expr_to_reg_cache + .push((std::borrow::Cow::Owned(expr), reg)); + } + let result_start_reg = program.alloc_registers(result_columns.len()); for (i, result_column) in result_columns.iter().enumerate() { let reg = result_start_reg + i; - - translate_expr_for_returning(program, &result_column.expr, value_registers, reg)?; + translate_expr_no_constant_opt( + program, + Some(table_references), + &result_column.expr, + reg, + resolver, + NoConstantOptReason::RegisterReuse, + )?; } + // Bit of a hack: this is required in case of e.g. INSERT ... ON CONFLICT DO UPDATE ... RETURNING + // where the result column values may either be the ones that were inserted, or the ones that were updated, + // depending on the row in question. + // meaning: emit_returning_results() may be called twice during translation and the cached expression values + // must be distinct for each call. + resolver.expr_to_reg_cache.truncate(cache_len); + program.emit_insn(Insn::ResultRow { start_reg: result_start_reg, count: result_columns.len(), diff --git a/core/translate/insert.rs b/core/translate/insert.rs index 3593415dd..a336f330c 100644 --- a/core/translate/insert.rs +++ b/core/translate/insert.rs @@ -13,13 +13,15 @@ use crate::translate::emitter::{ }; use crate::translate::expr::{ bind_and_rewrite_expr, emit_returning_results, process_returning_clause, walk_expr_mut, - BindingBehavior, ReturningValueRegisters, WalkControl, + BindingBehavior, WalkControl, }; use crate::translate::fkeys::{ build_index_affinity_string, emit_fk_violation, emit_guarded_fk_decrement, index_probe, open_read_index, open_read_table, }; -use crate::translate::plan::{ResultSetColumn, TableReferences}; +use crate::translate::plan::{ + ColumnUsedMask, JoinedTable, Operation, ResultSetColumn, TableReferences, +}; use crate::translate::planner::ROWID_STRS; use crate::translate::upsert::{ collect_set_clauses_for_upsert, emit_upsert, resolve_upsert_target, 
ResolvedUpsertTarget, @@ -134,7 +136,7 @@ pub struct InsertEmitCtx<'a> { impl<'a> InsertEmitCtx<'a> { fn new( program: &mut ProgramBuilder, - resolver: &'a Resolver, + resolver: &Resolver, table: &'a Arc, on_conflict: Option, cdc_table: Option<(usize, Arc)>, @@ -181,7 +183,7 @@ impl<'a> InsertEmitCtx<'a> { #[allow(clippy::too_many_arguments)] pub fn translate_insert( - resolver: &Resolver, + resolver: &mut Resolver, on_conflict: Option, tbl_name: QualifiedName, columns: Vec, @@ -240,14 +242,26 @@ pub fn translate_insert( let cdc_table = prepare_cdc_if_necessary(&mut program, resolver.schema, table.get_name())?; + let mut table_references = TableReferences::new( + vec![JoinedTable { + table: Table::BTree( + table + .btree() + .expect("we shouldn't have got here without a BTree table"), + ), + identifier: table_name.to_string(), + internal_id: program.table_reference_counter.next(), + op: Operation::default_scan_for(&table), + join_info: None, + col_used_mask: ColumnUsedMask::default(), + database_id: 0, + }], + vec![], + ); + // Process RETURNING clause using shared module - let (mut result_columns, _) = process_returning_clause( - &mut returning, - &table, - table_name.as_str(), - &mut program, - connection, - )?; + let mut result_columns = + process_returning_clause(&mut returning, &mut table_references, connection)?; let has_fks = fk_enabled && (resolver.schema.has_child_fks(table_name.as_str()) || resolver @@ -460,32 +474,37 @@ pub fn translate_insert( // Emit RETURNING results if specified if !result_columns.is_empty() { - let value_registers = ReturningValueRegisters { - rowid_register: insertion.key_register(), - columns_start_register: insertion.first_col_register(), - num_columns: table.columns().len(), - }; - - emit_returning_results(&mut program, &result_columns, &value_registers)?; + emit_returning_results( + &mut program, + &table_references, + &result_columns, + insertion.first_col_register(), + insertion.key_register(), + resolver, + )?; } 
program.emit_insn(Insn::Goto { target_pc: ctx.row_done_label, }); - - resolve_upserts( - &mut program, - resolver, - &mut upsert_actions, - &ctx, - &insertion, - &table, - &mut result_columns, - connection, - )?; + if !upsert_actions.is_empty() { + resolve_upserts( + &mut program, + resolver, + &mut upsert_actions, + &ctx, + &insertion, + &table, + &mut result_columns, + connection, + &table_references, + )?; + } emit_epilogue(&mut program, &ctx, inserting_multiple_rows); program.set_needs_stmt_subtransactions(true); + program.result_columns = result_columns; + program.table_references.extend(table_references); Ok(program) } @@ -544,7 +563,7 @@ fn emit_commit_phase( let idx_cursor_id = ctx .idx_cursors .iter() - .find(|(name, _, _)| *name == &index.name) + .find(|(name, _, _)| name == &index.name) .map(|(_, _, c_id)| *c_id) .expect("no cursor found for index"); @@ -739,13 +758,14 @@ fn emit_rowid_generation( #[allow(clippy::too_many_arguments)] fn resolve_upserts( program: &mut ProgramBuilder, - resolver: &Resolver, + resolver: &mut Resolver, upsert_actions: &mut [(ResolvedUpsertTarget, BranchOffset, Box)], ctx: &InsertEmitCtx, insertion: &Insertion, table: &Table, result_columns: &mut [ResultSetColumn], connection: &Arc, + table_references: &TableReferences, ) -> Result<()> { for (_, label, upsert) in upsert_actions { program.preassign_label_to_next_insn(*label); @@ -768,6 +788,7 @@ fn resolve_upserts( resolver, result_columns, connection, + table_references, )?; } else { // UpsertDo::Nothing case @@ -1708,7 +1729,7 @@ fn emit_preflight_constraint_checks( let idx_cursor_id = ctx .idx_cursors .iter() - .find(|(name, _, _)| *name == &index.name) + .find(|(name, _, _)| name == &index.name) .map(|(_, _, c_id)| *c_id) .expect("no cursor found for index"); diff --git a/core/translate/mod.rs b/core/translate/mod.rs index 5d9b902ae..76e0cdddb 100644 --- a/core/translate/mod.rs +++ b/core/translate/mod.rs @@ -93,14 +93,14 @@ pub fn translate( ); program.prologue(); - let 
resolver = Resolver::new(schema, syms); + let mut resolver = Resolver::new(schema, syms); program = match stmt { // There can be no nesting with pragma, so lift it up here ast::Stmt::Pragma { name, body } => { pragma::translate_pragma(&resolver, &name, body, pager, connection.clone(), program)? } - stmt => translate_inner(stmt, &resolver, program, &connection, input)?, + stmt => translate_inner(stmt, &mut resolver, program, &connection, input)?, }; program.epilogue(schema); @@ -113,7 +113,7 @@ pub fn translate( /// Translate SQL statement into bytecode program. pub fn translate_inner( stmt: ast::Stmt, - resolver: &Resolver, + resolver: &mut Resolver, program: ProgramBuilder, connection: &Arc, input: &str, diff --git a/core/translate/update.rs b/core/translate/update.rs index 635e41051..281cb32f3 100644 --- a/core/translate/update.rs +++ b/core/translate/update.rs @@ -266,13 +266,8 @@ pub fn prepare_update_plan( } } - let (result_columns, _table_references) = process_returning_clause( - &mut body.returning, - &table, - body.tbl_name.name.as_str(), - program, - connection, - )?; + let result_columns = + process_returning_clause(&mut body.returning, &mut table_references, connection)?; let order_by = body .order_by diff --git a/core/translate/upsert.rs b/core/translate/upsert.rs index e37a03f0d..69ced5b81 100644 --- a/core/translate/upsert.rs +++ b/core/translate/upsert.rs @@ -10,6 +10,7 @@ use crate::translate::emitter::UpdateRowSource; use crate::translate::expr::{walk_expr, WalkControl}; use crate::translate::fkeys::{emit_fk_child_update_counters, emit_parent_key_change_checks}; use crate::translate::insert::{format_unique_violation_desc, InsertEmitCtx}; +use crate::translate::plan::TableReferences; use crate::translate::planner::ROWID_STRS; use crate::vdbe::insn::CmpInsFlags; use crate::Connection; @@ -23,7 +24,7 @@ use crate::{ }, expr::{ emit_returning_results, translate_expr, translate_expr_no_constant_opt, walk_expr_mut, - NoConstantOptReason, 
ReturningValueRegisters, + NoConstantOptReason, }, insert::Insertion, plan::ResultSetColumn, @@ -332,6 +333,7 @@ pub fn resolve_upsert_target( /// Semantics reference: https://sqlite.org/lang_upsert.html /// Column references in the DO UPDATE expressions refer to the original /// (unchanged) row. To refer to would-be inserted values, use `excluded.x`. +#[allow(clippy::too_many_arguments)] pub fn emit_upsert( program: &mut ProgramBuilder, table: &Table, @@ -339,9 +341,10 @@ pub fn emit_upsert( insertion: &Insertion, set_pairs: &mut [(usize, Box)], where_clause: &mut Option>, - resolver: &Resolver, + resolver: &mut Resolver, returning: &mut [ResultSetColumn], connection: &Arc, + table_references: &TableReferences, ) -> crate::Result<()> { // Seek & snapshot CURRENT program.emit_insn(Insn::SeekRowid { @@ -823,12 +826,14 @@ pub fn emit_upsert( // RETURNING from NEW image + final rowid if !returning.is_empty() { - let regs = ReturningValueRegisters { - rowid_register: new_rowid_reg.unwrap_or(ctx.conflict_rowid_reg), - columns_start_register: new_start, - num_columns: num_cols, - }; - emit_returning_results(program, returning, &regs)?; + emit_returning_results( + program, + table_references, + returning, + new_start, + new_rowid_reg.unwrap_or(ctx.conflict_rowid_reg), + resolver, + )?; } program.emit_insn(Insn::Goto {