diff --git a/core/translate/insert.rs b/core/translate/insert.rs index cfc376712..af25ad6e1 100644 --- a/core/translate/insert.rs +++ b/core/translate/insert.rs @@ -19,7 +19,7 @@ use crate::{ use crate::{Result, SymbolTable, VirtualTable}; use super::emitter::Resolver; -use super::expr::{translate_expr_no_constant_opt, NoConstantOptReason}; +use super::expr::{translate_expr, translate_expr_no_constant_opt, NoConstantOptReason}; use super::optimizer::rewrite_expr; use super::plan::QueryDestination; use super::select::translate_select; @@ -32,7 +32,7 @@ pub fn translate_insert( on_conflict: Option, tbl_name: QualifiedName, columns: Option, - body: InsertBody, + mut body: InsertBody, _returning: Option>, syms: &SymbolTable, mut program: ProgramBuilder, @@ -101,10 +101,22 @@ pub fn translate_insert( .collect::>(); let root_page = btree_table.root_page; - let inserting_multiple_rows = match &body { - InsertBody::Select(select, _) => match select.body.select.as_ref() { - OneSelect::Values(values) => values.len() > 1, - OneSelect::Select(..) => true, + let mut values: Option> = None; + let inserting_multiple_rows = match &mut body { + InsertBody::Select(select, _) => match select.body.select.as_mut() { + // TODO see how to avoid clone + OneSelect::Values(values_expr) if values_expr.len() <= 1 => { + if values_expr.is_empty() { + crate::bail_parse_error!("no values to insert"); + } + let mut param_idx = 1; + for expr in values_expr.iter_mut().flat_map(|v| v.iter_mut()) { + rewrite_expr(expr, &mut param_idx)?; + } + values = values_expr.pop(); + false + } + _ => true, }, InsertBody::DefaultValues => false, }; @@ -112,25 +124,16 @@ pub fn translate_insert( let halt_label = program.allocate_label(); let loop_start_label = program.allocate_label(); - let (num_values, value) = match body { + let mut yield_reg_opt = None; + let num_values = match body { // TODO: upsert InsertBody::Select(select, _) => { // Simple Common case of INSERT INTO VALUES (...) if matches!(select.body.select.as_ref(), OneSelect::Values(values) if values.len() <= 1) { - let OneSelect::Values(mut values) = *select.body.select else { - unreachable!(); - }; - if values.is_empty() { - crate::bail_parse_error!("no values to insert"); - } - let mut param_idx = 1; - for expr in values.iter_mut().flat_map(|v| v.iter_mut()) { - rewrite_expr(expr, &mut param_idx)?; - } - - (values[0].len(), values.pop()) + values.as_ref().unwrap().len() } else { + // Multiple rows - use coroutine for value population let yield_reg = program.alloc_register(); let jump_on_definition_label = program.allocate_label(); let start_offset_label = program.allocate_label(); @@ -176,10 +179,11 @@ pub fn translate_insert( end_offset: halt_label, }); - (result.num_result_cols, None) + yield_reg_opt = Some(yield_reg); + result.num_result_cols } } - InsertBody::DefaultValues => (0, None), + InsertBody::DefaultValues => 0, }; let column_mappings = resolve_columns_for_insert(&table, &columns, num_values)?; @@ -209,8 +213,15 @@ pub fn translate_insert( let record_register = program.alloc_register(); - // Multiple rows - use coroutine for value population - if !inserting_multiple_rows { + if inserting_multiple_rows { + populate_record_multiple_rows( + &mut program, + &column_mappings, + column_registers_start, + yield_reg_opt.unwrap() + 1, + &resolver, + )?; + } else { // Single row - populate registers directly program.emit_insn(Insn::OpenWrite { cursor_id, @@ -220,10 +231,9 @@ pub fn translate_insert( populate_column_registers( &mut program, - &value.unwrap(), + &values.unwrap(), &column_mappings, column_registers_start, - false, rowid_reg, &resolver, )?; @@ -469,7 +479,7 @@ fn resolve_columns_for_insert<'a>( columns: &Option, num_values: usize, ) -> Result>> { - let table_columns = &table.columns(); + let table_columns = table.columns(); // Case 1: No columns specified - map values to columns in order if columns.is_none() { if num_values > table_columns.len() { @@ -589,6 +599,51 @@ fn resolve_indicies_for_insert( Ok(index_col_mappings) } +fn populate_record_multiple_rows( + program: &mut ProgramBuilder, + column_mappings: &[ColumnMapping], + column_registers_start: usize, + yield_reg: usize, + resolver: &Resolver, +) -> Result<()> { + let mut value_index_seen = 0; + for (i, mapping) in column_mappings.iter().enumerate() { + let target_reg = column_registers_start + i; + + if mapping.value_index.is_some() { + program.emit_insn(Insn::Copy { + src_reg: yield_reg + value_index_seen, + dst_reg: target_reg, + amount: 0, + }); + value_index_seen += 1; + continue; + } + + if mapping.column.is_rowid_alias { + program.emit_insn(Insn::SoftNull { reg: target_reg }); + } else if let Some(default_expr) = mapping.default_value { + translate_expr(program, None, default_expr, target_reg, resolver)?; + } else { + // Column was not specified as has no DEFAULT - use NULL if it is nullable, otherwise error + // Rowid alias columns can be NULL because we will autogenerate a rowid in that case. + let is_nullable = !mapping.column.primary_key || mapping.column.is_rowid_alias; + if is_nullable { + program.emit_insn(Insn::Null { + dest: target_reg, + dest_end: None, + }); + } else { + crate::bail_parse_error!( + "column {} is not nullable", + mapping.column.name.as_ref().expect("column name is None") + ); + } + } + } + Ok(()) +} + /// Populates the column registers with values for a single row #[allow(clippy::too_many_arguments)] fn populate_column_registers( @@ -596,7 +651,6 @@ fn populate_column_registers( value: &[Expr], column_mappings: &[ColumnMapping], column_registers_start: usize, - inserting_multiple_rows: bool, rowid_reg: usize, resolver: &Resolver, ) -> Result<()> { @@ -608,8 +662,7 @@ fn populate_column_registers( // When inserting a single row, SQLite writes the value provided for the rowid alias column (INTEGER PRIMARY KEY) // directly into the rowid register and writes a NULL into the rowid alias column. Not sure why this only happens // in the single row case, but let's copy it. - let write_directly_to_rowid_reg = - mapping.column.is_rowid_alias && !inserting_multiple_rows; + let write_directly_to_rowid_reg = mapping.column.is_rowid_alias; let reg = if write_directly_to_rowid_reg { rowid_reg } else { @@ -695,7 +748,6 @@ fn translate_virtual_table_insert( &value, &column_mappings, values_reg, - false, registers_start, resolver, )?;