diff --git a/COMPAT.md b/COMPAT.md index ff73f6790..6890c7ba5 100644 --- a/COMPAT.md +++ b/COMPAT.md @@ -77,7 +77,7 @@ Turso aims to be fully compatible with SQLite, with opt-in features not supporte | REINDEX | No | | | RELEASE SAVEPOINT | No | | | REPLACE | No | | -| RETURNING clause | No | | +| RETURNING clause | Partial | DELETE is missing | | ROLLBACK TRANSACTION | Yes | | | SAVEPOINT | No | | | SELECT | Yes | | diff --git a/core/translate/delete.rs b/core/translate/delete.rs index 9c05dd736..02bb345b8 100644 --- a/core/translate/delete.rs +++ b/core/translate/delete.rs @@ -6,15 +6,17 @@ use crate::translate::planner::{parse_limit, parse_where}; use crate::vdbe::builder::{ProgramBuilder, ProgramBuilderOpts, TableRefIdCounter}; use crate::{schema::Schema, Result, SymbolTable}; use std::sync::Arc; -use turso_sqlite3_parser::ast::{Expr, Limit, QualifiedName}; +use turso_sqlite3_parser::ast::{Expr, Limit, QualifiedName, ResultColumn}; use super::plan::{ColumnUsedMask, IterationDirection, JoinedTable, TableReferences}; +#[allow(clippy::too_many_arguments)] pub fn translate_delete( schema: &Schema, tbl_name: &QualifiedName, where_clause: Option>, limit: Option>, + returning: Option>, syms: &SymbolTable, mut program: ProgramBuilder, connection: &Arc, @@ -26,11 +28,22 @@ pub fn translate_delete( "DELETE for table with indexes is disabled by default. Run with `--experimental-indexes` to enable this feature." ); } + + // FIXME: SQLite's delete using Returning is complex. It scans the table in read mode first, building + // the result set, and only after that it opens the table for writing and deletes the rows. It + // also uses a couple of instructions that we don't implement yet (i.e.: RowSetAdd, RowSetRead, + // RowSetTest). So for now I'll just defer it altogether. + if returning.is_some() { + crate::bail_parse_error!("RETURNING currently not implemented for DELETE statements."); + } + let result_columns = vec![]; + let mut delete_plan = prepare_delete_plan( schema, tbl_name, where_clause, limit, + result_columns, &mut program.table_reference_counter, connection, )?; @@ -53,6 +66,7 @@ pub fn prepare_delete_plan( tbl_name: &QualifiedName, where_clause: Option>, limit: Option>, + result_columns: Vec, table_ref_counter: &mut TableRefIdCounter, connection: &Arc, ) -> Result { @@ -99,7 +113,7 @@ pub fn prepare_delete_plan( let plan = DeletePlan { table_references, - result_columns: vec![], + result_columns, where_clause: where_predicates, order_by: None, limit: resolved_limit, diff --git a/core/translate/emitter.rs b/core/translate/emitter.rs index 3ab881813..552317067 100644 --- a/core/translate/emitter.rs +++ b/core/translate/emitter.rs @@ -24,6 +24,7 @@ use crate::error::SQLITE_CONSTRAINT_PRIMARYKEY; use crate::function::Func; use crate::schema::{Schema, Table}; use crate::translate::compound_select::emit_program_for_compound_select; +use crate::translate::expr::{emit_returning_results, ReturningValueRegisters}; use crate::translate::plan::{DeletePlan, Plan, QueryDestination, Search}; use crate::translate::values::emit_values; use crate::util::exprs_are_equivalent; @@ -453,7 +454,12 @@ fn emit_program_for_delete( None, )?; - emit_delete_insns(program, &mut t_ctx, &plan.table_references)?; + emit_delete_insns( + program, + &mut t_ctx, + &plan.table_references, + &plan.result_columns, + )?; // Clean up and close the main execution loop close_loop( @@ -476,6 +482,7 @@ fn emit_delete_insns( program: &mut ProgramBuilder, t_ctx: &mut TranslateCtx, table_references: &TableReferences, + result_columns: &[super::plan::ResultSetColumn], ) -> Result<()> { let table_reference = table_references.joined_tables().first().unwrap(); if table_reference @@ -607,6 +614,33 @@ fn emit_delete_insns( )?; } + // Emit RETURNING results if specified (must be before DELETE) + if !result_columns.is_empty() { + // Get rowid for RETURNING + let rowid_reg = program.alloc_register(); + program.emit_insn(Insn::RowId { + cursor_id: main_table_cursor_id, + dest: rowid_reg, + }); + + // Allocate registers for column values + let columns_start_reg = program.alloc_registers(table_reference.columns().len()); + + // Read all column values from the row to be deleted + for (i, _column) in table_reference.columns().iter().enumerate() { + program.emit_column(main_table_cursor_id, i, columns_start_reg + i); + } + + // Emit RETURNING results using the values we just read + let value_registers = ReturningValueRegisters { + rowid_register: rowid_reg, + columns_start_register: columns_start_reg, + num_columns: table_reference.columns().len(), + }; + + emit_returning_results(program, result_columns, &value_registers)?; + } + program.emit_insn(Insn::Delete { cursor_id: main_table_cursor_id, }); @@ -1170,6 +1204,19 @@ fn emit_update_insns( table_name: table_ref.identifier.clone(), }); + // Emit RETURNING results if specified + if let Some(returning_columns) = &plan.returning { + if !returning_columns.is_empty() { + let value_registers = ReturningValueRegisters { + rowid_register: rowid_set_clause_reg.unwrap_or(beg), + columns_start_register: start, + num_columns: table_ref.columns().len(), + }; + + emit_returning_results(program, returning_columns, &value_registers)?; + } + } + // create full CDC record after update if necessary let cdc_after_reg = if program.capture_data_changes_mode().has_after() { Some(emit_cdc_patch_record( diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 649fa529f..ec6664861 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -1,5 +1,5 @@ use tracing::{instrument, Level}; -use turso_sqlite3_parser::ast::{self, Expr, UnaryOperator}; +use turso_sqlite3_parser::ast::{self, As, Expr, UnaryOperator}; use super::emitter::Resolver; use super::optimizer::Optimizable; @@ -27,6 +27,16 @@ pub struct ConditionMetadata { pub jump_target_when_false: BranchOffset, } +/// Container for register locations of values that can be referenced in RETURNING expressions +pub struct ReturningValueRegisters { + /// Register containing the rowid/primary key + pub rowid_register: usize, + /// Starting register for column values (in column order) + pub columns_start_register: usize, + /// Number of columns available + pub num_columns: usize, +} + #[instrument(skip_all, level = Level::DEBUG)] fn emit_cond_jump(program: &mut ProgramBuilder, cond_meta: ConditionMetadata, reg: usize) { if cond_meta.jump_if_condition_is_true { @@ -708,12 +718,10 @@ pub fn translate_expr( )?; } } - program.emit_insn(Insn::Function { - constant_mask: 0, - start_reg: regs, - dest: target_register, - func: func_ctx, - }); + + // Use shared function call helper + let arg_registers: Vec = (regs..regs + args_count).collect(); + emit_function_call(program, func_ctx, &arg_registers, target_register)?; Ok(target_register) } @@ -874,36 +882,24 @@ pub fn translate_expr( let args = expect_arguments_exact!(args, 1, vector_func); let start_reg = program.alloc_register(); translate_expr(program, referenced_tables, &args[0], start_reg, resolver)?; - program.emit_insn(Insn::Function { - constant_mask: 0, - start_reg, - dest: target_register, - func: func_ctx, - }); + + emit_function_call(program, func_ctx, &[start_reg], target_register)?; Ok(target_register) } VectorFunc::Vector64 => { let args = expect_arguments_exact!(args, 1, vector_func); let start_reg = program.alloc_register(); translate_expr(program, referenced_tables, &args[0], start_reg, resolver)?; - program.emit_insn(Insn::Function { - constant_mask: 0, - start_reg, - dest: target_register, - func: func_ctx, - }); + + emit_function_call(program, func_ctx, &[start_reg], target_register)?; Ok(target_register) } VectorFunc::VectorExtract => { let args = expect_arguments_exact!(args, 1, vector_func); let start_reg = program.alloc_register(); translate_expr(program, referenced_tables, &args[0], start_reg, resolver)?; - program.emit_insn(Insn::Function { - constant_mask: 0, - start_reg, - dest: target_register, - func: func_ctx, - }); + + emit_function_call(program, func_ctx, &[start_reg], target_register)?; Ok(target_register) } VectorFunc::VectorDistanceCos => { @@ -911,12 +907,8 @@ pub fn translate_expr( let regs = program.alloc_registers(2); translate_expr(program, referenced_tables, &args[0], regs, resolver)?; translate_expr(program, referenced_tables, &args[1], regs + 1, resolver)?; - program.emit_insn(Insn::Function { - constant_mask: 0, - start_reg: regs, - dest: target_register, - func: func_ctx, - }); + + emit_function_call(program, func_ctx, &[regs, regs + 1], target_register)?; Ok(target_register) } VectorFunc::VectorDistanceEuclidean => { @@ -924,12 +916,8 @@ pub fn translate_expr( let regs = program.alloc_registers(2); translate_expr(program, referenced_tables, &args[0], regs, resolver)?; translate_expr(program, referenced_tables, &args[1], regs + 1, resolver)?; - program.emit_insn(Insn::Function { - constant_mask: 0, - start_reg: regs, - dest: target_register, - func: func_ctx, - }); + + emit_function_call(program, func_ctx, &[regs, regs + 1], target_register)?; Ok(target_register) } }, @@ -2089,79 +2077,7 @@ pub fn translate_expr( } Ok(target_register) } - ast::Expr::Literal(lit) => match lit { - ast::Literal::Numeric(val) => { - match parse_numeric_literal(val)? { - Value::Integer(int_value) => { - program.emit_insn(Insn::Integer { - value: int_value, - dest: target_register, - }); - } - Value::Float(real_value) => { - program.emit_insn(Insn::Real { - value: real_value, - dest: target_register, - }); - } - _ => unreachable!(), - } - Ok(target_register) - } - ast::Literal::String(s) => { - program.emit_insn(Insn::String8 { - value: sanitize_string(s), - dest: target_register, - }); - Ok(target_register) - } - ast::Literal::Blob(s) => { - let bytes = s - .as_bytes() - .chunks_exact(2) - .map(|pair| { - // We assume that sqlite3-parser has already validated that - // the input is valid hex string, thus unwrap is safe. - let hex_byte = std::str::from_utf8(pair).unwrap(); - u8::from_str_radix(hex_byte, 16).unwrap() - }) - .collect(); - program.emit_insn(Insn::Blob { - value: bytes, - dest: target_register, - }); - Ok(target_register) - } - ast::Literal::Keyword(_) => todo!(), - ast::Literal::Null => { - program.emit_insn(Insn::Null { - dest: target_register, - dest_end: None, - }); - Ok(target_register) - } - ast::Literal::CurrentDate => { - program.emit_insn(Insn::String8 { - value: datetime::exec_date(&[]).to_string(), - dest: target_register, - }); - Ok(target_register) - } - ast::Literal::CurrentTime => { - program.emit_insn(Insn::String8 { - value: datetime::exec_time(&[]).to_string(), - dest: target_register, - }); - Ok(target_register) - } - ast::Literal::CurrentTimestamp => { - program.emit_insn(Insn::String8 { - value: datetime::exec_datetime_full(&[]).to_string(), - dest: target_register, - }); - Ok(target_register) - } - }, + ast::Expr::Literal(lit) => emit_literal(program, lit, target_register), ast::Expr::Name(_) => todo!(), ast::Expr::NotNull(expr) => { let reg = program.alloc_register(); @@ -3237,3 +3153,371 @@ pub fn compare_affinity( } } } + +/// Evaluate a RETURNING expression using register-based evaluation instead of cursor-based. +/// This is used for RETURNING clauses where we have register values instead of cursor data. +pub fn translate_expr_for_returning( + program: &mut ProgramBuilder, + expr: &Expr, + value_registers: &ReturningValueRegisters, + target_register: usize, +) -> Result { + match expr { + Expr::Column { + column, + is_rowid_alias, + .. + } => { + if *is_rowid_alias { + // For rowid references, copy from the rowid register + program.emit_insn(Insn::Copy { + src_reg: value_registers.rowid_register, + dst_reg: target_register, + extra_amount: 0, + }); + } else { + // For regular column references, copy from the appropriate column register + let column_idx = *column; + if column_idx < value_registers.num_columns { + let column_reg = value_registers.columns_start_register + column_idx; + program.emit_insn(Insn::Copy { + src_reg: column_reg, + dst_reg: target_register, + extra_amount: 0, + }); + } else { + crate::bail_parse_error!("Column index out of bounds in RETURNING clause"); + } + } + Ok(target_register) + } + Expr::RowId { .. } => { + // For ROWID expressions, copy from the rowid register + program.emit_insn(Insn::Copy { + src_reg: value_registers.rowid_register, + dst_reg: target_register, + extra_amount: 0, + }); + Ok(target_register) + } + Expr::Literal(literal) => emit_literal(program, literal, target_register), + Expr::Binary(lhs, op, rhs) => { + let lhs_reg = program.alloc_register(); + let rhs_reg = program.alloc_register(); + + // Recursively evaluate left-hand side + translate_expr_for_returning(program, lhs, value_registers, lhs_reg)?; + + // Recursively evaluate right-hand side + translate_expr_for_returning(program, rhs, value_registers, rhs_reg)?; + + // Use the shared emit_binary_insn function + emit_binary_insn( + program, + op, + lhs_reg, + rhs_reg, + target_register, + lhs, + rhs, + None, // No table references needed for RETURNING + )?; + + Ok(target_register) + } + Expr::FunctionCall { name, args, .. } => { + // Evaluate arguments into registers + let mut arg_regs = Vec::new(); + if let Some(args) = args { + for arg in args.iter() { + let arg_reg = program.alloc_register(); + translate_expr_for_returning(program, arg, value_registers, arg_reg)?; + arg_regs.push(arg_reg); + } + } + + // Resolve and call the function using shared helper + let func = Func::resolve_function(name.as_str(), arg_regs.len())?; + let func_ctx = FuncCtx { + func, + arg_count: arg_regs.len(), + }; + + emit_function_call(program, func_ctx, &arg_regs, target_register)?; + Ok(target_register) + } + _ => { + crate::bail_parse_error!( + "Unsupported expression type in RETURNING clause: {:?}", + expr + ); + } + } +} + +/// Emit literal values - shared between regular and RETURNING expression evaluation +pub fn emit_literal( + program: &mut ProgramBuilder, + literal: &ast::Literal, + target_register: usize, +) -> Result { + match literal { + ast::Literal::Numeric(val) => { + match parse_numeric_literal(val)? { + Value::Integer(int_value) => { + program.emit_insn(Insn::Integer { + value: int_value, + dest: target_register, + }); + } + Value::Float(real_value) => { + program.emit_insn(Insn::Real { + value: real_value, + dest: target_register, + }); + } + _ => unreachable!(), + } + Ok(target_register) + } + ast::Literal::String(s) => { + program.emit_insn(Insn::String8 { + value: sanitize_string(s), + dest: target_register, + }); + Ok(target_register) + } + ast::Literal::Blob(s) => { + let bytes = s + .as_bytes() + .chunks_exact(2) + .map(|pair| { + // We assume that sqlite3-parser has already validated that + // the input is valid hex string, thus unwrap is safe. + let hex_byte = std::str::from_utf8(pair).unwrap(); + u8::from_str_radix(hex_byte, 16).unwrap() + }) + .collect(); + program.emit_insn(Insn::Blob { + value: bytes, + dest: target_register, + }); + Ok(target_register) + } + ast::Literal::Keyword(_) => todo!(), + ast::Literal::Null => { + program.emit_insn(Insn::Null { + dest: target_register, + dest_end: None, + }); + Ok(target_register) + } + ast::Literal::CurrentDate => { + program.emit_insn(Insn::String8 { + value: datetime::exec_date(&[]).to_string(), + dest: target_register, + }); + Ok(target_register) + } + ast::Literal::CurrentTime => { + program.emit_insn(Insn::String8 { + value: datetime::exec_time(&[]).to_string(), + dest: target_register, + }); + Ok(target_register) + } + ast::Literal::CurrentTimestamp => { + program.emit_insn(Insn::String8 { + value: datetime::exec_datetime_full(&[]).to_string(), + dest: target_register, + }); + Ok(target_register) + } + } +} + +/// Emit a function call instruction with pre-allocated argument registers +/// This is shared between different function call contexts +pub fn emit_function_call( + program: &mut ProgramBuilder, + func_ctx: FuncCtx, + arg_registers: &[usize], + target_register: usize, +) -> Result<()> { + let start_reg = if arg_registers.is_empty() { + target_register // If no arguments, use target register as start + } else { + arg_registers[0] // Use first argument register as start + }; + + program.emit_insn(Insn::Function { + constant_mask: 0, + start_reg, + dest: target_register, + func: func_ctx, + }); + + Ok(()) +} + +/// Process a RETURNING clause, converting ResultColumn expressions into ResultSetColumn structures +/// with proper column binding and alias handling. +pub fn process_returning_clause( + returning: &mut [ast::ResultColumn], + table: &Table, + table_name: &str, + program: &mut ProgramBuilder, + connection: &std::sync::Arc, +) -> Result<( + Vec, + super::plan::TableReferences, +)> { + use super::plan::{ + ColumnUsedMask, IterationDirection, JoinedTable, Operation, ResultSetColumn, + TableReferences, + }; + use super::planner::bind_column_references; + + let mut result_columns = vec![]; + + let internal_id = program.table_reference_counter.next(); + let mut table_references = TableReferences::new( + vec![JoinedTable { + table: match table { + Table::Virtual(vtab) => Table::Virtual(vtab.clone()), + Table::BTree(btree_table) => Table::BTree(btree_table.clone()), + _ => unreachable!(), + }, + identifier: table_name.to_string(), + internal_id, + op: Operation::Scan { + iter_dir: IterationDirection::Forwards, + index: None, + }, + join_info: None, + col_used_mask: ColumnUsedMask::default(), + database_id: 0, + }], + vec![], + ); + + for rc in returning.iter_mut() { + match rc { + ast::ResultColumn::Expr(expr, alias) => { + let column_alias = determine_column_alias(expr, alias, table); + + bind_column_references(expr, &mut table_references, None, connection)?; + + result_columns.push(ResultSetColumn { + expr: expr.clone(), + alias: column_alias, + contains_aggregates: false, + }); + } + ast::ResultColumn::Star => { + // Handle RETURNING * by expanding to all table columns + // Use the shared internal_id for all columns + for (column_index, column) in table.columns().iter().enumerate() { + let column_expr = Expr::Column { + database: None, + table: internal_id, + column: column_index, + is_rowid_alias: false, + }; + + result_columns.push(ResultSetColumn { + expr: column_expr, + alias: column.name.clone(), + contains_aggregates: false, + }); + } + } + ast::ResultColumn::TableStar(_table_name) => { + // Handle RETURNING table.* by expanding to all table columns + // For single table RETURNING, this is equivalent to * + for (column_index, column) in table.columns().iter().enumerate() { + let column_expr = Expr::Column { + database: None, + table: internal_id, + column: column_index, + is_rowid_alias: false, + }; + + result_columns.push(ResultSetColumn { + expr: column_expr, + alias: column.name.clone(), + contains_aggregates: false, + }); + } + } + } + } + + Ok((result_columns, table_references)) +} + +/// Determine the appropriate alias for a RETURNING column expression +fn determine_column_alias( + expr: &Expr, + explicit_alias: &Option, + table: &Table, +) -> Option { + // First check for explicit alias + if let Some(As::As(name)) = explicit_alias { + return Some(name.to_string()); + } + + // For ROWID expressions, use "rowid" as the alias + if let Expr::RowId { .. } = expr { + return Some("rowid".to_string()); + } + + // For column references, use special handling + if let Expr::Column { + column, + is_rowid_alias, + .. + } = expr + { + if *is_rowid_alias { + return Some("rowid".to_string()); + } else { + // Get the column name from the table + return table + .columns() + .get(*column) + .and_then(|col| col.name.clone()); + } + } + + // For other expressions, use the expression string representation + Some(expr.to_string()) +} + +/// Emit bytecode to evaluate RETURNING expressions and produce result rows. +/// This function handles the actual evaluation of expressions using the values +/// from the DML operation. +pub(crate) fn emit_returning_results( + program: &mut ProgramBuilder, + result_columns: &[super::plan::ResultSetColumn], + value_registers: &ReturningValueRegisters, +) -> Result<()> { + if result_columns.is_empty() { + return Ok(()); + } + + let result_start_reg = program.alloc_registers(result_columns.len()); + + for (i, result_column) in result_columns.iter().enumerate() { + let reg = result_start_reg + i; + + translate_expr_for_returning(program, &result_column.expr, value_registers, reg)?; + } + + program.emit_insn(Insn::ResultRow { + start_reg: result_start_reg, + count: result_columns.len(), + }); + + Ok(()) +} diff --git a/core/translate/insert.rs b/core/translate/insert.rs index 3671c90bb..4b1cd9aef 100644 --- a/core/translate/insert.rs +++ b/core/translate/insert.rs @@ -8,6 +8,10 @@ use turso_sqlite3_parser::ast::{ use crate::error::{SQLITE_CONSTRAINT_NOTNULL, SQLITE_CONSTRAINT_PRIMARYKEY}; use crate::schema::{self, IndexColumn, Table}; use crate::translate::emitter::{emit_cdc_insns, emit_cdc_patch_record, OperationMode}; +use crate::translate::expr::{ + emit_returning_results, process_returning_clause, ReturningValueRegisters, +}; +use crate::translate::plan::TableReferences; use crate::translate::planner::ROWID; use crate::util::normalize_ident; use crate::vdbe::builder::ProgramBuilderOpts; @@ -42,7 +46,7 @@ pub fn translate_insert( tbl_name: QualifiedName, columns: Option, mut body: InsertBody, - _returning: Option>, + mut returning: Option>, syms: &SymbolTable, mut program: ProgramBuilder, connection: &Arc, @@ -140,6 +144,24 @@ pub fn translate_insert( None }; + // Process RETURNING clause using shared module + let (result_columns, _) = if let Some(returning) = &mut returning { + process_returning_clause( + returning, + &table, + table_name.as_str(), + &mut program, + connection, + )? + } else { + (vec![], TableReferences::new(vec![], vec![])) + }; + + // Set up the program to return result columns if RETURNING is specified + if !result_columns.is_empty() { + program.result_columns = result_columns.clone(); + } + let mut yield_reg_opt = None; let mut temp_table_ctx = None; let (num_values, cursor_id) = match body { @@ -579,6 +601,17 @@ pub fn translate_insert( )?; } + // Emit RETURNING results if specified + if !result_columns.is_empty() { + let value_registers = ReturningValueRegisters { + rowid_register: rowid_and_columns_start_register, + columns_start_register, + num_columns: table.columns().len(), + }; + + emit_returning_results(&mut program, &result_columns, &value_registers)?; + } + if inserting_multiple_rows { if let Some(temp_table_ctx) = temp_table_ctx { program.emit_insn(Insn::Next { diff --git a/core/translate/mod.rs b/core/translate/mod.rs index 14cc8dd32..3281ce41b 100644 --- a/core/translate/mod.rs +++ b/core/translate/mod.rs @@ -156,6 +156,7 @@ pub fn translate_inner( tbl_name, where_clause, limit, + returning, .. } = *delete; translate_delete( @@ -163,6 +164,7 @@ pub fn translate_inner( &tbl_name, where_clause, limit, + returning, syms, program, connection, diff --git a/core/translate/update.rs b/core/translate/update.rs index f3636cadc..135b08868 100644 --- a/core/translate/update.rs +++ b/core/translate/update.rs @@ -13,9 +13,10 @@ use crate::{ vdbe::builder::{ProgramBuilder, ProgramBuilderOpts}, SymbolTable, }; -use turso_sqlite3_parser::ast::{self, Expr, ResultColumn, SortOrder, Update}; +use turso_sqlite3_parser::ast::{Expr, SortOrder, Update}; use super::emitter::emit_program; +use super::expr::process_returning_clause; use super::optimizer::optimize_plan; use super::plan::{ ColumnUsedMask, IterationDirection, JoinedTable, Plan, ResultSetColumn, TableReferences, @@ -171,27 +172,21 @@ pub fn prepare_update_plan( } } - let mut result_columns = vec![]; - if let Some(returning) = &mut body.returning { - for rc in returning.iter_mut() { - if let ResultColumn::Expr(expr, alias) = rc { - bind_column_references(expr, &mut table_references, None, connection)?; - result_columns.push(ResultSetColumn { - expr: expr.clone(), - alias: alias.as_ref().and_then(|a| { - if let ast::As::As(name) = a { - Some(name.to_string()) - } else { - None - } - }), - contains_aggregates: false, - }); - } else { - bail_parse_error!("Only expressions are allowed in RETURNING clause"); - } - } - } + let (result_columns, _table_references) = if let Some(returning) = &mut body.returning { + process_returning_clause( + returning, + &table, + body.tbl_name.name.as_str(), + program, + connection, + )? + } else { + ( + vec![], + crate::translate::plan::TableReferences::new(vec![], vec![]), + ) + }; + let order_by = body.order_by.as_ref().map(|order| { order .iter() @@ -342,7 +337,11 @@ pub fn prepare_update_plan( table_references, set_clauses, where_clause, - returning: Some(result_columns), + returning: if result_columns.is_empty() { + None + } else { + Some(result_columns) + }, order_by, limit, offset, diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index c79042aff..0e15079fb 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -6521,7 +6521,12 @@ impl Value { pub fn exec_length(&self) -> Self { match self { - Value::Text(_) | Value::Integer(_) | Value::Float(_) => { + Value::Text(t) => { + // Count Unicode scalar values (characters) + Value::Integer(t.as_str().chars().count() as i64) + } + Value::Integer(_) | Value::Float(_) => { + // For numbers, SQLite returns the length of the string representation Value::Integer(self.to_string().chars().count() as i64) } Value::Blob(blob) => Value::Integer(blob.len() as i64), diff --git a/testing/insert.test b/testing/insert.test index e2d1b327b..11fd9326f 100755 --- a/testing/insert.test +++ b/testing/insert.test @@ -517,3 +517,65 @@ do_execsql_test_in_memory_error_content insert-explicit-rowid-conflict { insert into t(rowid, x) values (1, 1); insert into t(rowid, x) values (1, 2); } {UNIQUE constraint failed: t.rowid (19)} + +# RETURNING clause tests +do_execsql_test_on_specific_db {:memory:} returning-basic-column { + CREATE TABLE test (id INTEGER, name TEXT, value REAL); + INSERT INTO test (id, name, value) VALUES (1, 'test', 10.5) RETURNING id; +} {1} + +do_execsql_test_on_specific_db {:memory:} returning-multiple-columns { + CREATE TABLE test (id INTEGER, name TEXT, value REAL); + INSERT INTO test (id, name, value) VALUES (1, 'test', 10.5) RETURNING id, name; +} {1|test} + +do_execsql_test_on_specific_db {:memory:} returning-all-columns { + CREATE TABLE test (id INTEGER, name TEXT, value REAL); + INSERT INTO test (id, name, value) VALUES (1, 'test', 10.5) RETURNING *; +} {1|test|10.5} + +do_execsql_test_on_specific_db {:memory:} returning-literal { + CREATE TABLE test (id INTEGER); + INSERT INTO test (id) VALUES (1) RETURNING 42; +} {42} + +do_execsql_test_on_specific_db {:memory:} returning-arithmetic { + CREATE TABLE test (id INTEGER, value INTEGER); + INSERT INTO test (id, value) VALUES (1, 10) RETURNING 2 * value; +} {20} + +do_execsql_test_on_specific_db {:memory:} returning-complex-expression { + CREATE TABLE test (id INTEGER, x INTEGER, y INTEGER); + INSERT INTO test (id, x, y) VALUES (1, 5, 3) RETURNING x + y * 2; +} {11} + +do_execsql_test_on_specific_db {:memory:} returning-function-call { + CREATE TABLE test (id INTEGER, name TEXT); + INSERT INTO test (id, name) VALUES (1, 'hello') RETURNING upper(name); +} {HELLO} + +do_execsql_test_on_specific_db {:memory:} returning-mixed-expressions { + CREATE TABLE test (id INTEGER, name TEXT, value INTEGER); + INSERT INTO test (id, name, value) VALUES (1, 'test', 10) RETURNING id, upper(name), value * 3; +} {1|TEST|30} + +do_execsql_test_on_specific_db {:memory:} returning-multiple-rows { + CREATE TABLE test (id INTEGER, name TEXT); + INSERT INTO test (id, name) VALUES (1, 'first'), (2, 'second') RETURNING id, name; +} {1|first +2|second} + +do_execsql_test_on_specific_db {:memory:} returning-with-autoincrement { + CREATE TABLE test (id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT); + INSERT INTO test (name) VALUES ('test') RETURNING id, name; +} {1|test} + +do_execsql_test_on_specific_db {:memory:} returning-rowid { + CREATE TABLE test (name TEXT); + INSERT INTO test (name) VALUES ('test') RETURNING rowid, name; +} {1|test} + +do_execsql_test_on_specific_db {:memory:} returning-null-values { + CREATE TABLE test (id INTEGER, name TEXT, value INTEGER); + INSERT INTO test (id, name, value) VALUES (1, NULL, NULL) RETURNING id, name, value; +} {1||} diff --git a/testing/update.test b/testing/update.test index 8c225aba1..5183fc514 100755 --- a/testing/update.test +++ b/testing/update.test @@ -269,4 +269,79 @@ do_execsql_test_on_specific_db {:memory:} update-single-rowid { INSERT INTO t VALUES (1); UPDATE t SET x = 2 WHERE x = 1; SELECT * FROM t; -} {2} \ No newline at end of file +} {2} + +# RETURNING clause tests for UPDATE +do_execsql_test_on_specific_db {:memory:} update-returning-basic-column { + CREATE TABLE test (id INTEGER, name TEXT, value REAL); + INSERT INTO test (id, name, value) VALUES (1, 'test', 10.5); + UPDATE test SET value = 20.5 WHERE id = 1 RETURNING id; +} {1} + +do_execsql_test_on_specific_db {:memory:} update-returning-multiple-columns { + CREATE TABLE test (id INTEGER, name TEXT, value REAL); + INSERT INTO test (id, name, value) VALUES (1, 'test', 10.5); + UPDATE test SET value = 20.5 WHERE id = 1 RETURNING id, name, value; +} {1|test|20.5} + +do_execsql_test_on_specific_db {:memory:} update-returning-all-columns { + CREATE TABLE test (id INTEGER, name TEXT, value REAL); + INSERT INTO test (id, name, value) VALUES (1, 'test', 10.5); + UPDATE test SET value = 20.5 WHERE id = 1 RETURNING *; +} {1|test|20.5} + +do_execsql_test_on_specific_db {:memory:} update-returning-literal { + CREATE TABLE test (id INTEGER, value INTEGER); + INSERT INTO test (id, value) VALUES (1, 10); + UPDATE test SET value = 20 WHERE id = 1 RETURNING 42; +} {42} + +do_execsql_test_on_specific_db {:memory:} update-returning-arithmetic { + CREATE TABLE test (id INTEGER, value INTEGER); + INSERT INTO test (id, value) VALUES (1, 10); + UPDATE test SET value = 20 WHERE id = 1 RETURNING 2 * value; +} {40} + +do_execsql_test_on_specific_db {:memory:} update-returning-complex-expression { + CREATE TABLE test (id INTEGER, x INTEGER, y INTEGER); + INSERT INTO test (id, x, y) VALUES (1, 5, 3); + UPDATE test SET x = 8 WHERE id = 1 RETURNING x + y * 2; +} {14} + +do_execsql_test_on_specific_db {:memory:} update-returning-function-call { + CREATE TABLE test (id INTEGER, name TEXT); + INSERT INTO test (id, name) VALUES (1, 'hello'); + UPDATE test SET name = 'world' WHERE id = 1 RETURNING upper(name); +} {WORLD} + +do_execsql_test_on_specific_db {:memory:} update-returning-mixed-expressions { + CREATE TABLE test (id INTEGER, name TEXT, value INTEGER); + INSERT INTO test (id, name, value) VALUES (1, 'test', 10); + UPDATE test SET name = 'updated', value = 30 WHERE id = 1 RETURNING id, upper(name), value * 2; +} {1|UPDATED|60} + +do_execsql_test_on_specific_db {:memory:} update-returning-multiple-rows { + CREATE TABLE test (id INTEGER, name TEXT); + INSERT INTO test (id, name) VALUES (1, 'first'), (2, 'second'); + UPDATE test SET name = 'updated' RETURNING id, name; +} {1|updated +2|updated} + +do_execsql_test_on_specific_db {:memory:} update-returning-with-where { + CREATE TABLE test (id INTEGER, name TEXT, active INTEGER); + INSERT INTO test (id, name, active) VALUES (1, 'first', 1), (2, 'second', 0), (3, 'third', 1); + UPDATE test SET name = 'updated' WHERE active = 1 RETURNING id, name; +} {1|updated +3|updated} + +do_execsql_test_on_specific_db {:memory:} update-returning-old-vs-new-values { + CREATE TABLE test (id INTEGER, counter INTEGER); + INSERT INTO test (id, counter) VALUES (1, 5); + UPDATE test SET counter = counter + 10 WHERE id = 1 RETURNING id, counter; +} {1|15} + +do_execsql_test_on_specific_db {:memory:} update-returning-null-values { + CREATE TABLE test (id INTEGER, name TEXT, value INTEGER); + INSERT INTO test (id, name, value) VALUES (1, 'test', 10); + UPDATE test SET name = NULL, value = NULL WHERE id = 1 RETURNING id, name, value; +} {1||} \ No newline at end of file