Merge 'Refactor RETURNING to support arbitrary expressions' from Jussi Saurio

Note: contains 1000 lines of TCL tests generated by Cursor :] The runtime
changes are smaller, and this actually deletes code in aggregate.
---
The main change is to support arbitrary expressions in RETURNING instead of
a special-cased subset, but it also, e.g., disallows `RETURNING TABLE.*`, which
is illegal syntax in SQLite.
The main idea is that we add the columns of the target table (the table
affected by INSERT/UPDATE/DELETE) into `expr_to_reg_cache` and then just
translate the RETURNING expressions as normal.

Closes #3942
This commit is contained in:
Jussi Saurio
2025-11-14 13:34:53 +02:00
committed by GitHub
11 changed files with 1242 additions and 328 deletions

View File

@@ -38,10 +38,10 @@ pub fn emit_ungrouped_aggregation<'a>(
// we need to call translate_expr on each result column, but replace the expr with a register copy in case any part of the
// result column expression matches a) a group by column or b) an aggregation result.
for (i, agg) in plan.aggregates.iter().enumerate() {
t_ctx
.resolver
.expr_to_reg_cache
.push((&agg.original_expr, agg_start_reg + i));
t_ctx.resolver.expr_to_reg_cache.push((
std::borrow::Cow::Borrowed(&agg.original_expr),
agg_start_reg + i,
));
}
t_ctx.resolver.enable_expr_to_reg_cache();

View File

@@ -1,7 +1,6 @@
// This module contains code for emitting bytecode instructions for SQL query execution.
// It handles translating high-level SQL operations into low-level bytecode that can be executed by the virtual machine.
use std::collections::HashSet;
use std::num::NonZeroUsize;
use std::sync::Arc;
@@ -24,18 +23,20 @@ use super::select::emit_simple_count;
use super::subquery::emit_from_clause_subqueries;
use crate::error::SQLITE_CONSTRAINT_PRIMARYKEY;
use crate::function::Func;
use crate::schema::{BTreeTable, Column, Schema, Table, ROWID_SENTINEL};
use crate::schema::{BTreeTable, Column, Index, Schema, Table, ROWID_SENTINEL};
use crate::translate::compound_select::emit_program_for_compound_select;
use crate::translate::expr::{
emit_returning_results, translate_expr_no_constant_opt, walk_expr_mut, NoConstantOptReason,
ReturningValueRegisters, WalkControl,
WalkControl,
};
use crate::translate::fkeys::{
build_index_affinity_string, emit_fk_child_update_counters,
emit_fk_delete_parent_existence_checks, emit_guarded_fk_decrement,
emit_parent_key_change_checks, open_read_index, open_read_table, stabilize_new_row_for_fk,
};
use crate::translate::plan::{DeletePlan, EvalAt, JoinedTable, Plan, QueryDestination, Search};
use crate::translate::plan::{
DeletePlan, EvalAt, JoinedTable, Plan, QueryDestination, ResultSetColumn, Search,
};
use crate::translate::planner::ROWID_STRS;
use crate::translate::subquery::emit_non_from_clause_subquery;
use crate::translate::values::emit_values;
@@ -51,7 +52,7 @@ pub struct Resolver<'a> {
pub schema: &'a Schema,
pub symbol_table: &'a SymbolTable,
pub expr_to_reg_cache_enabled: bool,
pub expr_to_reg_cache: Vec<(&'a ast::Expr, usize)>,
pub expr_to_reg_cache: Vec<(std::borrow::Cow<'a, ast::Expr>, usize)>,
}
impl<'a> Resolver<'a> {
@@ -665,12 +666,12 @@ pub fn emit_fk_child_decrement_on_delete(
Ok(())
}
fn emit_delete_insns(
fn emit_delete_insns<'a>(
connection: &Arc<Connection>,
program: &mut ProgramBuilder,
t_ctx: &mut TranslateCtx,
t_ctx: &mut TranslateCtx<'a>,
table_references: &mut TableReferences,
result_columns: &[super::plan::ResultSetColumn],
result_columns: &'a [super::plan::ResultSetColumn],
) -> Result<()> {
// we can either use this obviously safe raw pointer or we can clone it
let table_reference: *const JoinedTable = table_references.joined_tables().first().unwrap();
@@ -873,13 +874,14 @@ fn emit_delete_insns(
}
// Emit RETURNING results using the values we just read
let value_registers = ReturningValueRegisters {
rowid_register: rowid_reg,
columns_start_register: columns_start_reg,
num_columns: cols_len,
};
emit_returning_results(program, result_columns, &value_registers)?;
emit_returning_results(
program,
table_references,
result_columns,
columns_start_reg,
rowid_reg,
&mut t_ctx.resolver,
)?;
}
program.emit_insn(Insn::Delete {
@@ -1042,8 +1044,13 @@ fn emit_program_for_update(
// Emit update instructions
emit_update_insns(
connection,
&mut plan,
&t_ctx,
&mut plan.table_references,
&plan.set_clauses,
plan.cdc_update_alter_statement.as_deref(),
&plan.indexes_to_update,
plan.returning.as_ref(),
plan.ephemeral_plan.as_ref(),
&mut t_ctx,
program,
index_cursors,
iteration_cursor_id,
@@ -1077,10 +1084,16 @@ fn emit_program_for_update(
/// `target_table_cursor_id` is the cursor id of the table that is being updated.
///
/// `target_table` is the table that is being updated.
fn emit_update_insns(
#[allow(clippy::too_many_arguments)]
fn emit_update_insns<'a>(
connection: &Arc<Connection>,
plan: &mut UpdatePlan,
t_ctx: &TranslateCtx,
table_references: &mut TableReferences,
set_clauses: &[(usize, Box<ast::Expr>)],
cdc_update_alter_statement: Option<&str>,
indexes_to_update: &[Arc<Index>],
returning: Option<&'a Vec<ResultSetColumn>>,
ephemeral_plan: Option<&SelectPlan>,
t_ctx: &mut TranslateCtx<'a>,
program: &mut ProgramBuilder,
index_cursors: Vec<(usize, usize)>,
iteration_cursor_id: usize,
@@ -1089,7 +1102,7 @@ fn emit_update_insns(
) -> crate::Result<()> {
let internal_id = target_table.internal_id;
let loop_labels = t_ctx.labels_main_loop.first().unwrap();
let source_table = plan.table_references.joined_tables().first().unwrap();
let source_table = table_references.joined_tables().first().unwrap();
let (index, is_virtual) = match &source_table.op {
Operation::Scan(Scan::BTreeTable { index, .. }) => (
index.as_ref().map(|index| {
@@ -1138,13 +1151,10 @@ fn emit_update_insns(
.iter()
.position(|c| c.is_rowid_alias());
let has_direct_rowid_update = plan
.set_clauses
.iter()
.any(|(idx, _)| *idx == ROWID_SENTINEL);
let has_direct_rowid_update = set_clauses.iter().any(|(idx, _)| *idx == ROWID_SENTINEL);
let has_user_provided_rowid = if let Some(index) = rowid_alias_index {
plan.set_clauses.iter().any(|(idx, _)| *idx == index)
set_clauses.iter().any(|(idx, _)| *idx == index)
} else {
has_direct_rowid_update
};
@@ -1210,11 +1220,11 @@ fn emit_update_insns(
let start = if is_virtual { beg + 2 } else { beg + 1 };
if has_direct_rowid_update {
if let Some((_, expr)) = plan.set_clauses.iter().find(|(i, _)| *i == ROWID_SENTINEL) {
if let Some((_, expr)) = set_clauses.iter().find(|(i, _)| *i == ROWID_SENTINEL) {
let rowid_set_clause_reg = rowid_set_clause_reg.unwrap();
translate_expr(
program,
Some(&plan.table_references),
Some(table_references),
expr,
rowid_set_clause_reg,
&t_ctx.resolver,
@@ -1226,7 +1236,7 @@ fn emit_update_insns(
}
for (idx, table_column) in target_table.table.columns().iter().enumerate() {
let target_reg = start + idx;
if let Some((col_idx, expr)) = plan.set_clauses.iter().find(|(i, _)| *i == idx) {
if let Some((col_idx, expr)) = set_clauses.iter().find(|(i, _)| *i == idx) {
// Skip if this is the sentinel value
if *col_idx == ROWID_SENTINEL {
continue;
@@ -1238,7 +1248,7 @@ fn emit_update_insns(
let rowid_set_clause_reg = rowid_set_clause_reg.unwrap();
translate_expr(
program,
Some(&plan.table_references),
Some(table_references),
expr,
rowid_set_clause_reg,
&t_ctx.resolver,
@@ -1252,7 +1262,7 @@ fn emit_update_insns(
} else {
translate_expr(
program,
Some(&plan.table_references),
Some(table_references),
expr,
target_reg,
&t_ctx.resolver,
@@ -1280,9 +1290,9 @@ fn emit_update_insns(
program.emit_bool(true, change_reg);
program.mark_last_insn_constant();
let mut updated = false;
if let Some(ddl_query_for_cdc_update) = &plan.cdc_update_alter_statement {
if let Some(ddl_query_for_cdc_update) = &cdc_update_alter_statement {
if table_column.name.as_deref() == Some("sql") {
program.emit_string8(ddl_query_for_cdc_update.clone(), value_reg);
program.emit_string8(ddl_query_for_cdc_update.to_string(), value_reg);
updated = true;
}
}
@@ -1346,7 +1356,7 @@ fn emit_update_insns(
stabilize_new_row_for_fk(
program,
&table_btree,
&plan.set_clauses,
set_clauses,
target_table_cursor_id,
start,
rowid_new_reg,
@@ -1362,11 +1372,10 @@ fn emit_update_insns(
target_table_cursor_id,
start,
rowid_new_reg,
&plan
.set_clauses
&set_clauses
.iter()
.map(|(i, _)| *i)
.collect::<HashSet<_>>(),
.collect::<std::collections::HashSet<_>>(),
)?;
}
// Parent-side checks:
@@ -1382,19 +1391,19 @@ fn emit_update_insns(
program,
&t_ctx.resolver,
&table_btree,
plan.indexes_to_update.iter(),
indexes_to_update.iter(),
target_table_cursor_id,
beg,
start,
rowid_new_reg,
rowid_set_clause_reg,
&plan.set_clauses,
set_clauses,
)?;
}
}
}
for (index, (idx_cursor_id, record_reg)) in plan.indexes_to_update.iter().zip(&index_cursors) {
for (index, (idx_cursor_id, record_reg)) in indexes_to_update.iter().zip(&index_cursors) {
// We need to know whether or not the OLD values satisfied the predicate on the
// partial index, so we can know whether or not to delete the old index entry,
// as well as whether or not the NEW values satisfy the predicate, to determine whether
@@ -1403,12 +1412,12 @@ fn emit_update_insns(
// This means that we need to bind the column references to a copy of the index Expr,
// so we can emit Insn::Column instructions and refer to the old values.
let where_clause = index
.bind_where_expr(Some(&mut plan.table_references), connection)
.bind_where_expr(Some(table_references), connection)
.expect("where clause to exist");
let old_satisfied_reg = program.alloc_register();
translate_expr_no_constant_opt(
program,
Some(&plan.table_references),
Some(table_references),
&where_clause,
old_satisfied_reg,
&t_ctx.resolver,
@@ -1744,15 +1753,16 @@ fn emit_update_insns(
});
// Emit RETURNING results if specified
if let Some(returning_columns) = &plan.returning {
if let Some(returning_columns) = &returning {
if !returning_columns.is_empty() {
let value_registers = ReturningValueRegisters {
rowid_register: rowid_set_clause_reg.unwrap_or(beg),
columns_start_register: start,
num_columns: col_len,
};
emit_returning_results(program, returning_columns, &value_registers)?;
emit_returning_results(
program,
table_references,
returning_columns,
start,
rowid_set_clause_reg.unwrap_or(beg),
&mut t_ctx.resolver,
)?;
}
}
@@ -1814,7 +1824,7 @@ fn emit_update_insns(
emit_cdc_insns(
program,
&t_ctx.resolver,
OperationMode::UPDATE(if plan.ephemeral_plan.is_some() {
OperationMode::UPDATE(if ephemeral_plan.is_some() {
UpdateRowSource::PrebuiltEphemeralTable {
ephemeral_table_cursor_id: iteration_cursor_id,
target_table: target_table.clone(),

View File

@@ -1,7 +1,7 @@
use std::sync::Arc;
use tracing::{instrument, Level};
use turso_parser::ast::{self, As, Expr, SubqueryType, UnaryOperator};
use turso_parser::ast::{self, Expr, SubqueryType, UnaryOperator};
use super::emitter::Resolver;
use super::optimizer::Optimizable;
@@ -22,7 +22,7 @@ use crate::vdbe::{
insn::{CmpInsFlags, Insn},
BranchOffset,
};
use crate::{Result, Value};
use crate::{turso_assert, Result, Value};
use super::collate::CollationSeq;
@@ -34,16 +34,6 @@ pub struct ConditionMetadata {
pub jump_target_when_null: BranchOffset,
}
/// Container for register locations of values that can be referenced in RETURNING expressions
pub struct ReturningValueRegisters {
/// Register containing the rowid/primary key
pub rowid_register: usize,
/// Starting register for column values (in column order)
pub columns_start_register: usize,
/// Number of columns available
pub num_columns: usize,
}
#[instrument(skip_all, level = Level::DEBUG)]
fn emit_cond_jump(program: &mut ProgramBuilder, cond_meta: ConditionMetadata, reg: usize) {
if cond_meta.jump_if_condition_is_true {
@@ -4143,106 +4133,6 @@ pub fn compare_affinity(
}
}
/// Evaluate a RETURNING expression using register-based evaluation instead of cursor-based.
/// This is used for RETURNING clauses where we have register values instead of cursor data.
pub fn translate_expr_for_returning(
program: &mut ProgramBuilder,
expr: &Expr,
value_registers: &ReturningValueRegisters,
target_register: usize,
) -> Result<usize> {
match expr {
Expr::Column {
column,
is_rowid_alias,
..
} => {
if *is_rowid_alias {
// For rowid references, copy from the rowid register
program.emit_insn(Insn::Copy {
src_reg: value_registers.rowid_register,
dst_reg: target_register,
extra_amount: 0,
});
} else {
// For regular column references, copy from the appropriate column register
let column_idx = *column;
if column_idx < value_registers.num_columns {
let column_reg = value_registers.columns_start_register + column_idx;
program.emit_insn(Insn::Copy {
src_reg: column_reg,
dst_reg: target_register,
extra_amount: 0,
});
} else {
crate::bail_parse_error!("Column index out of bounds in RETURNING clause");
}
}
Ok(target_register)
}
Expr::RowId { .. } => {
// For ROWID expressions, copy from the rowid register
program.emit_insn(Insn::Copy {
src_reg: value_registers.rowid_register,
dst_reg: target_register,
extra_amount: 0,
});
Ok(target_register)
}
Expr::Literal(literal) => emit_literal(program, literal, target_register),
Expr::Binary(lhs, op, rhs) => {
let lhs_reg = program.alloc_register();
let rhs_reg = program.alloc_register();
// Recursively evaluate left-hand side
translate_expr_for_returning(program, lhs, value_registers, lhs_reg)?;
// Recursively evaluate right-hand side
translate_expr_for_returning(program, rhs, value_registers, rhs_reg)?;
// Use the shared emit_binary_insn function
emit_binary_insn(
program,
op,
lhs_reg,
rhs_reg,
target_register,
lhs,
rhs,
None, // No table references needed for RETURNING
None, // No condition metadata needed for RETURNING
)?;
Ok(target_register)
}
Expr::FunctionCall { name, args, .. } => {
// Evaluate arguments into registers
let mut arg_regs = Vec::new();
for arg in args.iter() {
let arg_reg = program.alloc_register();
translate_expr_for_returning(program, arg, value_registers, arg_reg)?;
arg_regs.push(arg_reg);
}
// Resolve and call the function using shared helper
let func = Func::resolve_function(name.as_str(), arg_regs.len())?;
let func_ctx = FuncCtx {
func,
arg_count: arg_regs.len(),
};
emit_function_call(program, func_ctx, &arg_regs, target_register)?;
Ok(target_register)
}
_ => {
crate::bail_parse_error!(
"Unsupported expression type in RETURNING clause: {:?}",
expr
);
}
}
}
/// Emit literal values - shared between regular and RETURNING expression evaluation
pub fn emit_literal(
program: &mut ProgramBuilder,
@@ -4363,56 +4253,40 @@ pub fn emit_function_call(
/// with proper column binding and alias handling.
pub fn process_returning_clause(
returning: &mut [ast::ResultColumn],
table: &Table,
table_name: &str,
program: &mut ProgramBuilder,
table_references: &mut TableReferences,
connection: &std::sync::Arc<crate::Connection>,
) -> Result<(
Vec<super::plan::ResultSetColumn>,
super::plan::TableReferences,
)> {
use super::plan::{ColumnUsedMask, JoinedTable, Operation, ResultSetColumn, TableReferences};
) -> Result<Vec<ResultSetColumn>> {
let mut result_columns = Vec::with_capacity(returning.len());
let mut result_columns = vec![];
let internal_id = program.table_reference_counter.next();
let mut table_references = TableReferences::new(
vec![JoinedTable {
table: match table {
Table::Virtual(vtab) => Table::Virtual(vtab.clone()),
Table::BTree(btree_table) => Table::BTree(btree_table.clone()),
_ => unreachable!(),
},
identifier: table_name.to_string(),
internal_id,
op: Operation::default_scan_for(table),
join_info: None,
col_used_mask: ColumnUsedMask::default(),
database_id: 0,
}],
vec![],
);
let alias_to_string = |alias: &ast::As| match alias {
ast::As::Elided(alias) => alias.as_str().to_string(),
ast::As::As(alias) => alias.as_str().to_string(),
};
for rc in returning.iter_mut() {
match rc {
ast::ResultColumn::Expr(expr, alias) => {
bind_and_rewrite_expr(
expr,
Some(&mut table_references),
Some(table_references),
None,
connection,
BindingBehavior::TryResultColumnsFirst,
)?;
let column_alias = determine_column_alias(expr, alias, table);
result_columns.push(ResultSetColumn {
expr: *expr.clone(),
alias: column_alias,
expr: expr.as_ref().clone(),
alias: alias.as_ref().map(alias_to_string),
contains_aggregates: false,
});
}
ast::ResultColumn::Star => {
let table = table_references
.joined_tables()
.first()
.expect("RETURNING clause must reference at least one table");
let internal_id = table.internal_id;
// Handle RETURNING * by expanding to all table columns
// Use the shared internal_id for all columns
for (column_index, column) in table.columns().iter().enumerate() {
@@ -4420,7 +4294,7 @@ pub fn process_returning_clause(
database: None,
table: internal_id,
column: column_index,
is_rowid_alias: false,
is_rowid_alias: column.is_rowid_alias(),
};
result_columns.push(ResultSetColumn {
@@ -4430,91 +4304,81 @@ pub fn process_returning_clause(
});
}
}
ast::ResultColumn::TableStar(_table_name) => {
// Handle RETURNING table.* by expanding to all table columns
// For single table RETURNING, this is equivalent to *
for (column_index, column) in table.columns().iter().enumerate() {
let column_expr = Expr::Column {
database: None,
table: internal_id,
column: column_index,
is_rowid_alias: false,
};
result_columns.push(ResultSetColumn {
expr: column_expr,
alias: column.name.clone(),
contains_aggregates: false,
});
}
ast::ResultColumn::TableStar(_) => {
crate::bail_parse_error!("RETURNING may not use \"TABLE.*\" wildcards");
}
}
}
Ok((result_columns, table_references))
}
/// Determine the appropriate alias for a RETURNING column expression
fn determine_column_alias(
expr: &Expr,
explicit_alias: &Option<ast::As>,
table: &Table,
) -> Option<String> {
// First check for explicit alias
if let Some(As::As(name)) = explicit_alias {
return Some(name.as_str().to_string());
}
// For ROWID expressions, use "rowid" as the alias
if let Expr::RowId { .. } = expr {
return Some("rowid".to_string());
}
// For column references, always use the column name from the table
if let Expr::Column {
column,
is_rowid_alias,
..
} = expr
{
if let Some(name) = table
.columns()
.get(*column)
.and_then(|col| col.name.clone())
{
return Some(name);
} else if *is_rowid_alias {
// If it's a rowid alias, return "rowid"
return Some("rowid".to_string());
} else {
return None;
}
}
// For other expressions, use the expression string representation
Some(expr.to_string())
Ok(result_columns)
}
/// Emit bytecode to evaluate RETURNING expressions and produce result rows.
/// This function handles the actual evaluation of expressions using the values
/// from the DML operation.
pub(crate) fn emit_returning_results(
/// RETURNING result expressions are otherwise evaluated as normal, but the columns of the target table
/// are added to [Resolver::expr_to_reg_cache], meaning a reference to e.g. tbl.col will effectively
/// refer to a register where the OLD/NEW value of tbl.col is stored after an INSERT/UPDATE/DELETE.
pub(crate) fn emit_returning_results<'a>(
program: &mut ProgramBuilder,
table_references: &TableReferences,
result_columns: &[super::plan::ResultSetColumn],
value_registers: &ReturningValueRegisters,
reg_columns_start: usize,
rowid_reg: usize,
resolver: &mut Resolver<'a>,
) -> Result<()> {
if result_columns.is_empty() {
return Ok(());
}
turso_assert!(table_references.joined_tables().len() == 1, "RETURNING is only used with INSERT, UPDATE, or DELETE statements, which target a single table");
let table = table_references.joined_tables().first().unwrap();
resolver.enable_expr_to_reg_cache();
let expr = Expr::RowId {
database: None,
table: table.internal_id,
};
let cache_len = resolver.expr_to_reg_cache.len();
resolver
.expr_to_reg_cache
.push((std::borrow::Cow::Owned(expr), rowid_reg));
for (i, column) in table.columns().iter().enumerate() {
let reg = if column.is_rowid_alias() {
rowid_reg
} else {
reg_columns_start + i
};
let expr = Expr::Column {
database: None,
table: table.internal_id,
column: i,
is_rowid_alias: column.is_rowid_alias(),
};
resolver
.expr_to_reg_cache
.push((std::borrow::Cow::Owned(expr), reg));
}
let result_start_reg = program.alloc_registers(result_columns.len());
for (i, result_column) in result_columns.iter().enumerate() {
let reg = result_start_reg + i;
translate_expr_for_returning(program, &result_column.expr, value_registers, reg)?;
translate_expr_no_constant_opt(
program,
Some(table_references),
&result_column.expr,
reg,
resolver,
NoConstantOptReason::RegisterReuse,
)?;
}
// Bit of a hack: this is required in case of e.g. INSERT ... ON CONFLICT DO UPDATE ... RETURNING
// where the result column values may either be the ones that were inserted, or the ones that were updated,
// depending on the row in question.
// meaning: emit_returning_results() may be called twice during translation and the cached expression values
// must be distinct for each call.
resolver.expr_to_reg_cache.truncate(cache_len);
program.emit_insn(Insn::ResultRow {
start_reg: result_start_reg,
count: result_columns.len(),

View File

@@ -646,7 +646,10 @@ pub fn group_by_process_single_group(
{
if *in_result {
program.emit_column_or_rowid(*pseudo_cursor, sorter_column_index, next_reg);
t_ctx.resolver.expr_to_reg_cache.push((expr, next_reg));
t_ctx
.resolver
.expr_to_reg_cache
.push((std::borrow::Cow::Borrowed(expr), next_reg));
next_reg += 1;
}
}
@@ -669,7 +672,10 @@ pub fn group_by_process_single_group(
dest_reg,
&t_ctx.resolver,
)?;
t_ctx.resolver.expr_to_reg_cache.push((expr, dest_reg));
t_ctx
.resolver
.expr_to_reg_cache
.push((std::borrow::Cow::Borrowed(expr), dest_reg));
}
}
}
@@ -792,10 +798,10 @@ pub fn group_by_emit_row_phase<'a>(
register: agg_result_reg,
func: agg.func.clone(),
});
t_ctx
.resolver
.expr_to_reg_cache
.push((&agg.original_expr, agg_result_reg));
t_ctx.resolver.expr_to_reg_cache.push((
std::borrow::Cow::Borrowed(&agg.original_expr),
agg_result_reg,
));
}
t_ctx.resolver.enable_expr_to_reg_cache();

View File

@@ -13,13 +13,15 @@ use crate::translate::emitter::{
};
use crate::translate::expr::{
bind_and_rewrite_expr, emit_returning_results, process_returning_clause, walk_expr_mut,
BindingBehavior, ReturningValueRegisters, WalkControl,
BindingBehavior, WalkControl,
};
use crate::translate::fkeys::{
build_index_affinity_string, emit_fk_violation, emit_guarded_fk_decrement, index_probe,
open_read_index, open_read_table,
};
use crate::translate::plan::{ResultSetColumn, TableReferences};
use crate::translate::plan::{
ColumnUsedMask, JoinedTable, Operation, ResultSetColumn, TableReferences,
};
use crate::translate::planner::ROWID_STRS;
use crate::translate::upsert::{
collect_set_clauses_for_upsert, emit_upsert, resolve_upsert_target, ResolvedUpsertTarget,
@@ -94,7 +96,7 @@ pub struct InsertEmitCtx<'a> {
/// Index cursors we need to populate for this table
/// (idx name, root_page, idx cursor id)
pub idx_cursors: Vec<(&'a String, i64, usize)>,
pub idx_cursors: Vec<(String, i64, usize)>,
/// Context for if the insert values are materialized first
/// into a temporary table
@@ -134,7 +136,7 @@ pub struct InsertEmitCtx<'a> {
impl<'a> InsertEmitCtx<'a> {
fn new(
program: &mut ProgramBuilder,
resolver: &'a Resolver,
resolver: &Resolver,
table: &'a Arc<BTreeTable>,
on_conflict: Option<ResolveType>,
cdc_table: Option<(usize, Arc<BTreeTable>)>,
@@ -146,7 +148,7 @@ impl<'a> InsertEmitCtx<'a> {
let mut idx_cursors = Vec::new();
for idx in indices {
idx_cursors.push((
&idx.name,
idx.name.clone(),
idx.root_page,
program.alloc_cursor_index(None, idx)?,
));
@@ -181,7 +183,7 @@ impl<'a> InsertEmitCtx<'a> {
#[allow(clippy::too_many_arguments)]
pub fn translate_insert(
resolver: &Resolver,
resolver: &mut Resolver,
on_conflict: Option<ResolveType>,
tbl_name: QualifiedName,
columns: Vec<ast::Name>,
@@ -240,14 +242,26 @@ pub fn translate_insert(
let cdc_table = prepare_cdc_if_necessary(&mut program, resolver.schema, table.get_name())?;
let mut table_references = TableReferences::new(
vec![JoinedTable {
table: Table::BTree(
table
.btree()
.expect("we shouldn't have got here without a BTree table"),
),
identifier: table_name.to_string(),
internal_id: program.table_reference_counter.next(),
op: Operation::default_scan_for(&table),
join_info: None,
col_used_mask: ColumnUsedMask::default(),
database_id: 0,
}],
vec![],
);
// Process RETURNING clause using shared module
let (mut result_columns, _) = process_returning_clause(
&mut returning,
&table,
table_name.as_str(),
&mut program,
connection,
)?;
let mut result_columns =
process_returning_clause(&mut returning, &mut table_references, connection)?;
let has_fks = fk_enabled
&& (resolver.schema.has_child_fks(table_name.as_str())
|| resolver
@@ -460,32 +474,37 @@ pub fn translate_insert(
// Emit RETURNING results if specified
if !result_columns.is_empty() {
let value_registers = ReturningValueRegisters {
rowid_register: insertion.key_register(),
columns_start_register: insertion.first_col_register(),
num_columns: table.columns().len(),
};
emit_returning_results(&mut program, &result_columns, &value_registers)?;
emit_returning_results(
&mut program,
&table_references,
&result_columns,
insertion.first_col_register(),
insertion.key_register(),
resolver,
)?;
}
program.emit_insn(Insn::Goto {
target_pc: ctx.row_done_label,
});
resolve_upserts(
&mut program,
resolver,
&mut upsert_actions,
&ctx,
&insertion,
&table,
&mut result_columns,
connection,
)?;
if !upsert_actions.is_empty() {
resolve_upserts(
&mut program,
resolver,
&mut upsert_actions,
&ctx,
&insertion,
&table,
&mut result_columns,
connection,
&table_references,
)?;
}
emit_epilogue(&mut program, &ctx, inserting_multiple_rows);
program.set_needs_stmt_subtransactions(true);
program.result_columns = result_columns;
program.table_references.extend(table_references);
Ok(program)
}
@@ -544,7 +563,7 @@ fn emit_commit_phase(
let idx_cursor_id = ctx
.idx_cursors
.iter()
.find(|(name, _, _)| *name == &index.name)
.find(|(name, _, _)| name == &index.name)
.map(|(_, _, c_id)| *c_id)
.expect("no cursor found for index");
@@ -739,13 +758,14 @@ fn emit_rowid_generation(
#[allow(clippy::too_many_arguments)]
fn resolve_upserts(
program: &mut ProgramBuilder,
resolver: &Resolver,
resolver: &mut Resolver,
upsert_actions: &mut [(ResolvedUpsertTarget, BranchOffset, Box<Upsert>)],
ctx: &InsertEmitCtx,
insertion: &Insertion,
table: &Table,
result_columns: &mut [ResultSetColumn],
connection: &Arc<crate::Connection>,
table_references: &TableReferences,
) -> Result<()> {
for (_, label, upsert) in upsert_actions {
program.preassign_label_to_next_insn(*label);
@@ -768,6 +788,7 @@ fn resolve_upserts(
resolver,
result_columns,
connection,
table_references,
)?;
} else {
// UpsertDo::Nothing case
@@ -1708,7 +1729,7 @@ fn emit_preflight_constraint_checks(
let idx_cursor_id = ctx
.idx_cursors
.iter()
.find(|(name, _, _)| *name == &index.name)
.find(|(name, _, _)| name == &index.name)
.map(|(_, _, c_id)| *c_id)
.expect("no cursor found for index");

View File

@@ -93,14 +93,14 @@ pub fn translate(
);
program.prologue();
let resolver = Resolver::new(schema, syms);
let mut resolver = Resolver::new(schema, syms);
program = match stmt {
// There can be no nesting with pragma, so lift it up here
ast::Stmt::Pragma { name, body } => {
pragma::translate_pragma(&resolver, &name, body, pager, connection.clone(), program)?
}
stmt => translate_inner(stmt, &resolver, program, &connection, input)?,
stmt => translate_inner(stmt, &mut resolver, program, &connection, input)?,
};
program.epilogue(schema);
@@ -113,7 +113,7 @@ pub fn translate(
/// Translate SQL statement into bytecode program.
pub fn translate_inner(
stmt: ast::Stmt,
resolver: &Resolver,
resolver: &mut Resolver,
program: ProgramBuilder,
connection: &Arc<Connection>,
input: &str,

View File

@@ -266,13 +266,8 @@ pub fn prepare_update_plan(
}
}
let (result_columns, _table_references) = process_returning_clause(
&mut body.returning,
&table,
body.tbl_name.name.as_str(),
program,
connection,
)?;
let result_columns =
process_returning_clause(&mut body.returning, &mut table_references, connection)?;
let order_by = body
.order_by

View File

@@ -10,6 +10,7 @@ use crate::translate::emitter::UpdateRowSource;
use crate::translate::expr::{walk_expr, WalkControl};
use crate::translate::fkeys::{emit_fk_child_update_counters, emit_parent_key_change_checks};
use crate::translate::insert::{format_unique_violation_desc, InsertEmitCtx};
use crate::translate::plan::TableReferences;
use crate::translate::planner::ROWID_STRS;
use crate::vdbe::insn::CmpInsFlags;
use crate::Connection;
@@ -23,7 +24,7 @@ use crate::{
},
expr::{
emit_returning_results, translate_expr, translate_expr_no_constant_opt, walk_expr_mut,
NoConstantOptReason, ReturningValueRegisters,
NoConstantOptReason,
},
insert::Insertion,
plan::ResultSetColumn,
@@ -332,6 +333,7 @@ pub fn resolve_upsert_target(
/// Semantics reference: https://sqlite.org/lang_upsert.html
/// Column references in the DO UPDATE expressions refer to the original
/// (unchanged) row. To refer to would-be inserted values, use `excluded.x`.
#[allow(clippy::too_many_arguments)]
pub fn emit_upsert(
program: &mut ProgramBuilder,
table: &Table,
@@ -339,9 +341,10 @@ pub fn emit_upsert(
insertion: &Insertion,
set_pairs: &mut [(usize, Box<ast::Expr>)],
where_clause: &mut Option<Box<ast::Expr>>,
resolver: &Resolver,
resolver: &mut Resolver,
returning: &mut [ResultSetColumn],
connection: &Arc<Connection>,
table_references: &TableReferences,
) -> crate::Result<()> {
// Seek & snapshot CURRENT
program.emit_insn(Insn::SeekRowid {
@@ -823,12 +826,14 @@ pub fn emit_upsert(
// RETURNING from NEW image + final rowid
if !returning.is_empty() {
let regs = ReturningValueRegisters {
rowid_register: new_rowid_reg.unwrap_or(ctx.conflict_rowid_reg),
columns_start_register: new_start,
num_columns: num_cols,
};
emit_returning_results(program, returning, &regs)?;
emit_returning_results(
program,
table_references,
returning,
new_start,
new_rowid_reg.unwrap_or(ctx.conflict_rowid_reg),
resolver,
)?;
}
program.emit_insn(Insn::Goto {

View File

@@ -526,10 +526,10 @@ pub fn init_window<'a>(
let reg_acc_start = program.alloc_registers(window_function_count);
let reg_acc_result_start = program.alloc_registers(window_function_count);
for (i, func) in window.functions.iter().enumerate() {
t_ctx
.resolver
.expr_to_reg_cache
.push((&func.original_expr, reg_acc_result_start + i));
t_ctx.resolver.expr_to_reg_cache.push((
std::borrow::Cow::Borrowed(&func.original_expr),
reg_acc_result_start + i,
));
}
// The same approach applies to expressions referencing the subquery (columns).
@@ -543,7 +543,7 @@ pub fn init_window<'a>(
t_ctx
.resolver
.expr_to_reg_cache
.push((expr, reg_col_start + i));
.push((std::borrow::Cow::Borrowed(expr), reg_col_start + i));
}
t_ctx.meta_window = Some(WindowMetadata {

View File

@@ -48,3 +48,4 @@ source $testdir/upsert.test
source $testdir/window.test
source $testdir/partial_idx.test
source $testdir/foreign_keys.test
source $testdir/returning.test

1012
testing/returning.test Executable file

File diff suppressed because it is too large Load Diff