From 0593a99f0ec1e533f90668898020f6ed60702759 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Tue, 13 May 2025 12:49:16 -0400 Subject: [PATCH 1/6] Remove insertCtx from parameters and replace fix with expr rewriting --- core/parameters.rs | 83 ++++--------------------------------- core/translate/insert.rs | 33 +++++---------- core/translate/mod.rs | 4 +- core/translate/optimizer.rs | 57 +++++++++++++++---------- core/translate/plan.rs | 2 +- 5 files changed, 57 insertions(+), 122 deletions(-) diff --git a/core/parameters.rs b/core/parameters.rs index a3dab16c8..ff0f99033 100644 --- a/core/parameters.rs +++ b/core/parameters.rs @@ -1,6 +1,7 @@ -use super::ast; use std::num::NonZero; +pub const PARAM_PREFIX: &str = "__param_"; + #[derive(Clone, Debug)] pub enum Parameter { Anonymous(NonZero), @@ -24,52 +25,10 @@ impl Parameter { } } -#[derive(Debug)] -struct InsertContext { - param_positions: Vec, - current_col_value_idx: usize, -} - -impl InsertContext { - fn new(param_positions: Vec) -> Self { - Self { - param_positions, - current_col_value_idx: 0, - } - } - - /// Find the relevant parameter index needed for the current value index of insert stmt - /// Example for table t (a,b,c): - /// `insert into t (c,a,b) values (?,?,?)` - /// - /// col a -> value_index 1 - /// col b -> value_index 2 - /// col c -> value_index 0 - /// - /// however translation will always result in parameters 1, 2, 3 - /// because columns are translated in the table order so `col a` gets - /// translated first, translate_expr calls parameters.push and always gets index 1. - /// - /// Instead, we created an array representing all the value_index's that are type - /// Expr::Variable, in the case above would be [1, 2, 0], and stored it in insert_ctx. - /// That array can be used to look up the necessary parameter index by searching for the value - /// index in the array and returning the index of that value + 1. - /// value_index-> [1, 2, 0] - /// param index-> |0, 1, 2| - fn get_insert_param_index(&self) -> Option> { - self.param_positions - .iter() - .position(|param| param.eq(&self.current_col_value_idx)) - .map(|p| NonZero::new(p + 1).unwrap()) - } -} - #[derive(Debug)] pub struct Parameters { index: NonZero, pub list: Vec, - // Context for reordering parameters during insert statements - insert_ctx: Option, } impl Default for Parameters { @@ -83,7 +42,6 @@ impl Parameters { Self { index: 1.try_into().unwrap(), list: vec![], - insert_ctx: None, } } @@ -93,18 +51,6 @@ impl Parameters { params.len() } - /// Begin preparing for an Insert statement by providing the array of values from the Insert body. - pub fn init_insert_parameters(&mut self, values: &[Vec]) { - self.insert_ctx = Some(InsertContext::new(expected_param_indicies(values))); - } - - /// Set the value index for the column currently being translated for an Insert stmt. - pub fn set_insert_value_index(&mut self, idx: usize) { - if let Some(ctx) = &mut self.insert_ctx { - ctx.current_col_value_idx = idx; - } - } - pub fn name(&self, index: NonZero) -> Option { self.list.iter().find_map(|p| match p { Parameter::Anonymous(i) if *i == index => Some("?".to_string()), @@ -136,11 +82,7 @@ impl Parameters { let index = self.next_index(); self.list.push(Parameter::Anonymous(index)); tracing::trace!("anonymous parameter at {index}"); - if let Some(idx) = &self.insert_ctx { - idx.get_insert_param_index().unwrap_or(index) - } else { - index - } + index } name if name.starts_with(['$', ':', '@', '#']) => { match self @@ -163,8 +105,12 @@ impl Parameters { } } index => { - // SAFETY: Guaranteed from parser that the index is bigger than 0. - let index: NonZero = index.parse().unwrap(); + let index: NonZero = if let Some(idx) = index.strip_prefix(PARAM_PREFIX) { + idx.parse().unwrap() + } else { + // SAFETY: Guaranteed from parser that the index is bigger than 0. + index.parse().unwrap() + }; if index > self.index { self.index = index.checked_add(1).unwrap(); } @@ -175,14 +121,3 @@ impl Parameters { } } } - -/// Gather all the expected indicies of all Expr::Variable -/// in the provided array of insert values. -pub fn expected_param_indicies(cols: &[Vec]) -> Vec { - cols.iter() - .flat_map(|col| col.iter()) - .enumerate() - .filter(|(_, col)| matches!(col, ast::Expr::Variable(_))) - .map(|(i, _)| i) - .collect::>() -} diff --git a/core/translate/insert.rs b/core/translate/insert.rs index 233d623c9..c7cff3a88 100644 --- a/core/translate/insert.rs +++ b/core/translate/insert.rs @@ -1,4 +1,4 @@ -use std::ops::Deref; +use std::ops::{Deref, DerefMut}; use std::rc::Rc; use limbo_sqlite3_parser::ast::{ @@ -22,6 +22,7 @@ use crate::{Result, SymbolTable, VirtualTable}; use super::emitter::Resolver; use super::expr::{translate_expr_no_constant_opt, NoConstantOptReason}; +use super::optimizer::rewrite_expr; #[allow(clippy::too_many_arguments)] pub fn translate_insert( @@ -31,7 +32,7 @@ pub fn translate_insert( on_conflict: &Option, tbl_name: &QualifiedName, columns: &Option, - body: &InsertBody, + body: &mut InsertBody, _returning: &Option>, syms: &SymbolTable, ) -> Result { @@ -100,14 +101,16 @@ pub fn translate_insert( .collect::>(); let root_page = btree_table.root_page; let values = match body { - InsertBody::Select(select, _) => match &select.body.select.deref() { - OneSelect::Values(values) => values, + InsertBody::Select(ref mut select, _) => match select.body.select.as_mut() { + OneSelect::Values(ref mut values) => values, _ => todo!(), }, - InsertBody::DefaultValues => &vec![vec![]], + InsertBody::DefaultValues => &mut vec![vec![]], }; - // prepare parameters by tracking the number of variables we will be binding to values later on - program.parameters.init_insert_parameters(values); + let mut param_idx = 1; + for mut expr in values.iter_mut().flat_map(|v| v.iter_mut()) { + rewrite_expr(&mut expr, &mut param_idx)?; + } let column_mappings = resolve_columns_for_insert(&table, columns, values)?; let index_col_mappings = resolve_indicies_for_insert(schema, table.as_ref(), &column_mappings)?; @@ -154,9 +157,8 @@ pub fn translate_insert( program.resolve_label(start_offset_label, program.offset()); - for (i, value) in values.iter().enumerate() { + for value in values.iter() { populate_column_registers( - i, &mut program, value, &column_mappings, @@ -194,7 +196,6 @@ pub fn translate_insert( }); populate_column_registers( - 0, &mut program, &values[0], &column_mappings, @@ -586,7 +587,6 @@ fn resolve_indicies_for_insert( /// Populates the column registers with values for a single row #[allow(clippy::too_many_arguments)] fn populate_column_registers( - row_idx: usize, program: &mut ProgramBuilder, value: &[Expr], column_mappings: &[ColumnMapping], @@ -610,14 +610,6 @@ fn populate_column_registers( } else { target_reg }; - // We need the 'parameters' to be aware of the value_index of the current row - // so it can map it to the correct parameter index in the Variable opcode - // but we need to make sure the value_index is not overwritten if this is a multi-row - // insert. For 'insert into t values: (?,?), (?,?);' - // value_index should be (1,2),(3,4) instead of (1,2),(1,2), so multiply by col length - program - .parameters - .set_insert_value_index(value_index + (column_mappings.len() * row_idx)); translate_expr_no_constant_opt( program, None, @@ -681,8 +673,6 @@ fn translate_virtual_table_insert( InsertBody::DefaultValues => &vec![], _ => crate::bail_parse_error!("Unsupported INSERT body for virtual tables"), }; - // initiate parameters by tracking the number of variables we will be binding to values - program.parameters.init_insert_parameters(values); let table = Table::Virtual(virtual_table.clone()); let column_mappings = resolve_columns_for_insert(&table, columns, values)?; let registers_start = program.alloc_registers(2); @@ -701,7 +691,6 @@ fn translate_virtual_table_insert( let values_reg = program.alloc_registers(column_mappings.len()); populate_column_registers( - 0, program, &values[0], &column_mappings, diff --git a/core/translate/mod.rs b/core/translate/mod.rs index 492c33b13..fa0b6d343 100644 --- a/core/translate/mod.rs +++ b/core/translate/mod.rs @@ -198,7 +198,7 @@ pub fn translate( or_conflict, tbl_name, columns, - body, + mut body, returning, } = *insert; change_cnt_on = true; @@ -209,7 +209,7 @@ pub fn translate( &or_conflict, &tbl_name, &columns, - &body, + &mut body, &returning, syms, )? diff --git a/core/translate/optimizer.rs b/core/translate/optimizer.rs index 1409bd9af..75b5ee834 100644 --- a/core/translate/optimizer.rs +++ b/core/translate/optimizer.rs @@ -3,6 +3,7 @@ use std::{cmp::Ordering, collections::HashMap, sync::Arc}; use limbo_sqlite3_parser::ast::{self, Expr, SortOrder}; use crate::{ + parameters::PARAM_PREFIX, schema::{Index, IndexColumn, Schema}, translate::plan::TerminationKey, types::SeekOp, @@ -416,23 +417,24 @@ fn eliminate_constant_conditions( } fn rewrite_exprs_select(plan: &mut SelectPlan) -> Result<()> { + let mut param_count = 1; for rc in plan.result_columns.iter_mut() { - rewrite_expr(&mut rc.expr)?; + rewrite_expr(&mut rc.expr, &mut param_count)?; } for agg in plan.aggregates.iter_mut() { - rewrite_expr(&mut agg.original_expr)?; + rewrite_expr(&mut agg.original_expr, &mut param_count)?; } for cond in plan.where_clause.iter_mut() { - rewrite_expr(&mut cond.expr)?; + rewrite_expr(&mut cond.expr, &mut param_count)?; } if let Some(group_by) = &mut plan.group_by { for expr in group_by.exprs.iter_mut() { - rewrite_expr(expr)?; + rewrite_expr(expr, &mut param_count)?; } } if let Some(order_by) = &mut plan.order_by { for (expr, _) in order_by.iter_mut() { - rewrite_expr(expr)?; + rewrite_expr(expr, &mut param_count)?; } } @@ -440,27 +442,29 @@ fn rewrite_exprs_select(plan: &mut SelectPlan) -> Result<()> { } fn rewrite_exprs_delete(plan: &mut DeletePlan) -> Result<()> { + let mut param_idx = 1; for cond in plan.where_clause.iter_mut() { - rewrite_expr(&mut cond.expr)?; + rewrite_expr(&mut cond.expr, &mut param_idx)?; } Ok(()) } fn rewrite_exprs_update(plan: &mut UpdatePlan) -> Result<()> { - if let Some(rc) = plan.returning.as_mut() { - for rc in rc.iter_mut() { - rewrite_expr(&mut rc.expr)?; - } - } + let mut param_idx = 1; for (_, expr) in plan.set_clauses.iter_mut() { - rewrite_expr(expr)?; + rewrite_expr(expr, &mut param_idx)?; } for cond in plan.where_clause.iter_mut() { - rewrite_expr(&mut cond.expr)?; + rewrite_expr(&mut cond.expr, &mut param_idx)?; } if let Some(order_by) = &mut plan.order_by { for (expr, _) in order_by.iter_mut() { - rewrite_expr(expr)?; + rewrite_expr(expr, &mut param_idx)?; + } + } + if let Some(rc) = plan.returning.as_mut() { + for rc in rc.iter_mut() { + rewrite_expr(&mut rc.expr, &mut param_idx)?; } } Ok(()) @@ -1856,7 +1860,7 @@ pub fn try_extract_rowid_search_expression( } } -fn rewrite_expr(expr: &mut ast::Expr) -> Result<()> { +pub fn rewrite_expr(expr: &mut ast::Expr, param_idx: &mut usize) -> Result<()> { match expr { ast::Expr::Id(id) => { // Convert "true" and "false" to 1 and 0 @@ -1870,6 +1874,13 @@ fn rewrite_expr(expr: &mut ast::Expr) -> Result<()> { } Ok(()) } + ast::Expr::Variable(var) => { + if var.is_empty() { + *expr = ast::Expr::Variable(format!("{}{param_idx}", PARAM_PREFIX)); + *param_idx += 1; + } + Ok(()) + } ast::Expr::Between { lhs, not, @@ -1884,9 +1895,9 @@ fn rewrite_expr(expr: &mut ast::Expr) -> Result<()> { (ast::Operator::LessEquals, ast::Operator::LessEquals) }; - rewrite_expr(start)?; - rewrite_expr(lhs)?; - rewrite_expr(end)?; + rewrite_expr(start, param_idx)?; + rewrite_expr(lhs, param_idx)?; + rewrite_expr(end, param_idx)?; let start = start.take_ownership(); let lhs = lhs.take_ownership(); @@ -1912,7 +1923,7 @@ fn rewrite_expr(expr: &mut ast::Expr) -> Result<()> { } ast::Expr::Parenthesized(ref mut exprs) => { for subexpr in exprs.iter_mut() { - rewrite_expr(subexpr)?; + rewrite_expr(subexpr, param_idx)?; } let exprs = std::mem::take(exprs); *expr = ast::Expr::Parenthesized(exprs); @@ -1920,20 +1931,20 @@ fn rewrite_expr(expr: &mut ast::Expr) -> Result<()> { } // Process other expressions recursively ast::Expr::Binary(lhs, _, rhs) => { - rewrite_expr(lhs)?; - rewrite_expr(rhs)?; + rewrite_expr(lhs, param_idx)?; + rewrite_expr(rhs, param_idx)?; Ok(()) } ast::Expr::FunctionCall { args, .. } => { if let Some(args) = args { for arg in args.iter_mut() { - rewrite_expr(arg)?; + rewrite_expr(arg, param_idx)?; } } Ok(()) } ast::Expr::Unary(_, arg) => { - rewrite_expr(arg)?; + rewrite_expr(arg, param_idx)?; Ok(()) } _ => Ok(()), diff --git a/core/translate/plan.rs b/core/translate/plan.rs index e34efb29b..a422b6fed 100644 --- a/core/translate/plan.rs +++ b/core/translate/plan.rs @@ -379,7 +379,7 @@ impl SelectPlan { name: limbo_sqlite3_parser::ast::Id("count".to_string()), filter_over: None, }; - let result_col_expr = &self.result_columns.get(0).unwrap().expr; + let result_col_expr = &self.result_columns.first().unwrap().expr; if *result_col_expr != count && *result_col_expr != count_star { return false; } From e91d17f06e95b4fa6bfb69fde9764c6078d3811d Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Tue, 13 May 2025 12:50:10 -0400 Subject: [PATCH 2/6] Add tests for parameter binding for update, select and delete queries --- .../query_processing/test_read_path.rs | 280 ++++++++++++++++++ 1 file changed, 280 insertions(+) diff --git a/tests/integration/query_processing/test_read_path.rs b/tests/integration/query_processing/test_read_path.rs index f7bbe5cad..fd34cb059 100644 --- a/tests/integration/query_processing/test_read_path.rs +++ b/tests/integration/query_processing/test_read_path.rs @@ -479,3 +479,283 @@ fn test_insert_parameter_multiple_row() -> anyhow::Result<()> { assert_eq!(ins.parameters().count(), 8); Ok(()) } + +#[test] +fn test_bind_parameters_update_query() -> anyhow::Result<()> { + let tmp_db = TempDatabase::new_with_rusqlite("create table test (a integer, b text);"); + let conn = tmp_db.connect_limbo(); + let mut ins = conn.prepare("insert into test (a, b) values (3, 'test1');")?; + loop { + match ins.step()? { + StepResult::IO => tmp_db.io.run_once()?, + StepResult::Done | StepResult::Interrupt => break, + StepResult::Busy => panic!("database busy"), + _ => {} + } + } + let mut ins = conn.prepare("update test set a = ? where b = ?;")?; + ins.bind_at(1.try_into()?, OwnedValue::Integer(222)); + ins.bind_at(2.try_into()?, OwnedValue::build_text("test1")); + loop { + match ins.step()? { + StepResult::IO => tmp_db.io.run_once()?, + StepResult::Done | StepResult::Interrupt => break, + StepResult::Busy => panic!("database busy"), + _ => {} + } + } + + let mut sel = conn.prepare("select a, b from test;")?; + loop { + match sel.step()? { + StepResult::Row => { + let row = sel.row().unwrap(); + assert_eq!( + row.get::<&OwnedValue>(0).unwrap(), + &OwnedValue::Integer(222) + ); + assert_eq!( + row.get::<&OwnedValue>(1).unwrap(), + &OwnedValue::build_text("test1"), + ); + } + StepResult::IO => tmp_db.io.run_once()?, + StepResult::Done | StepResult::Interrupt => break, + StepResult::Busy => panic!("database busy"), + } + } + assert_eq!(ins.parameters().count(), 2); + Ok(()) +} + +#[test] +fn test_bind_parameters_update_query_multiple_where() -> anyhow::Result<()> { + let tmp_db = TempDatabase::new_with_rusqlite( + "create table test (a integer, b text, c integer, d integer);", + ); + let conn = tmp_db.connect_limbo(); + let mut ins = conn.prepare("insert into test (a, b, c, d) values (3, 'test1', 4, 5);")?; + loop { + match ins.step()? { + StepResult::IO => tmp_db.io.run_once()?, + StepResult::Done | StepResult::Interrupt => break, + StepResult::Busy => panic!("database busy"), + _ => {} + } + } + let mut ins = conn.prepare("update test set a = ? where b = ? and c = 4 and d = ?;")?; + ins.bind_at(1.try_into()?, OwnedValue::Integer(222)); + ins.bind_at(2.try_into()?, OwnedValue::build_text("test1")); + ins.bind_at(3.try_into()?, OwnedValue::Integer(5)); + loop { + match ins.step()? { + StepResult::IO => tmp_db.io.run_once()?, + StepResult::Done | StepResult::Interrupt => break, + StepResult::Busy => panic!("database busy"), + _ => {} + } + } + + let mut sel = conn.prepare("select a, b, c, d from test;")?; + loop { + match sel.step()? { + StepResult::Row => { + let row = sel.row().unwrap(); + assert_eq!( + row.get::<&OwnedValue>(0).unwrap(), + &OwnedValue::Integer(222) + ); + assert_eq!( + row.get::<&OwnedValue>(1).unwrap(), + &OwnedValue::build_text("test1"), + ); + assert_eq!(row.get::<&OwnedValue>(2).unwrap(), &OwnedValue::Integer(4)); + assert_eq!(row.get::<&OwnedValue>(3).unwrap(), &OwnedValue::Integer(5)); + } + StepResult::IO => tmp_db.io.run_once()?, + StepResult::Done | StepResult::Interrupt => break, + StepResult::Busy => panic!("database busy"), + } + } + assert_eq!(ins.parameters().count(), 3); + Ok(()) +} + +#[test] +fn test_bind_parameters_update_rowid_alias() -> anyhow::Result<()> { + let tmp_db = + TempDatabase::new_with_rusqlite("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT);"); + let conn = tmp_db.connect_limbo(); + let mut ins = conn.prepare("insert into test (id, name) values (1, 'test');")?; + loop { + match ins.step()? { + StepResult::IO => tmp_db.io.run_once()?, + StepResult::Done | StepResult::Interrupt => break, + StepResult::Busy => panic!("database busy"), + _ => {} + } + } + + let mut sel = conn.prepare("select id, name from test;")?; + loop { + match sel.step()? { + StepResult::Row => { + let row = sel.row().unwrap(); + assert_eq!(row.get::<&OwnedValue>(0).unwrap(), &OwnedValue::Integer(1)); + assert_eq!( + row.get::<&OwnedValue>(1).unwrap(), + &OwnedValue::build_text("test"), + ); + } + StepResult::IO => tmp_db.io.run_once()?, + StepResult::Done | StepResult::Interrupt => break, + StepResult::Busy => panic!("database busy"), + } + } + let mut ins = conn.prepare("update test set name = ? where id = ?;")?; + ins.bind_at(1.try_into()?, OwnedValue::build_text("updated")); + ins.bind_at(2.try_into()?, OwnedValue::Integer(1)); + loop { + match ins.step()? { + StepResult::IO => tmp_db.io.run_once()?, + StepResult::Done | StepResult::Interrupt => break, + StepResult::Busy => panic!("database busy"), + _ => {} + } + } + + let mut sel = conn.prepare("select id, name from test;")?; + loop { + match sel.step()? { + StepResult::Row => { + let row = sel.row().unwrap(); + assert_eq!(row.get::<&OwnedValue>(0).unwrap(), &OwnedValue::Integer(1)); + assert_eq!( + row.get::<&OwnedValue>(1).unwrap(), + &OwnedValue::build_text("updated"), + ); + } + StepResult::IO => tmp_db.io.run_once()?, + StepResult::Done | StepResult::Interrupt => break, + StepResult::Busy => panic!("database busy"), + } + } + assert_eq!(ins.parameters().count(), 2); + Ok(()) +} + +#[test] +fn test_bind_parameters_update_rowid_alias_seek_rowid() -> anyhow::Result<()> { + let tmp_db = TempDatabase::new_with_rusqlite( + "CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT, age integer);", + ); + let conn = tmp_db.connect_limbo(); + conn.execute("insert into test (id, name, age) values (1, 'test', 4);")?; + conn.execute("insert into test (id, name, age) values (2, 'test', 11);")?; + + let mut sel = conn.prepare("select id, name, age from test;")?; + let mut i = 0; + loop { + match sel.step()? { + StepResult::Row => { + let row = sel.row().unwrap(); + assert_eq!( + row.get::<&OwnedValue>(0).unwrap(), + &OwnedValue::Integer(if i == 0 { 1 } else { 2 }) + ); + assert_eq!( + row.get::<&OwnedValue>(1).unwrap(), + &OwnedValue::build_text("test"), + ); + assert_eq!( + row.get::<&OwnedValue>(2).unwrap(), + &OwnedValue::Integer(if i == 0 { 4 } else { 11 }) + ); + } + StepResult::IO => tmp_db.io.run_once()?, + StepResult::Done | StepResult::Interrupt => break, + StepResult::Busy => panic!("database busy"), + } + i += 1; + } + let mut ins = conn.prepare("update test set name = ? where id < ? AND age between ? and ?;")?; + ins.bind_at(1.try_into()?, OwnedValue::build_text("updated")); + ins.bind_at(2.try_into()?, OwnedValue::Integer(2)); + ins.bind_at(3.try_into()?, OwnedValue::Integer(3)); + ins.bind_at(4.try_into()?, OwnedValue::Integer(5)); + loop { + match ins.step()? { + StepResult::IO => tmp_db.io.run_once()?, + StepResult::Done | StepResult::Interrupt => break, + StepResult::Busy => panic!("database busy"), + _ => {} + } + } + + let mut sel = conn.prepare("select name from test;")?; + let mut i = 0; + loop { + match sel.step()? { + StepResult::Row => { + let row = sel.row().unwrap(); + assert_eq!( + row.get::<&OwnedValue>(0).unwrap(), + &OwnedValue::build_text(if i == 0 { "updated" } else { "test" }), + ); + } + StepResult::IO => tmp_db.io.run_once()?, + StepResult::Done | StepResult::Interrupt => break, + StepResult::Busy => panic!("database busy"), + } + i += 1; + } + + assert_eq!(ins.parameters().count(), 4); + Ok(()) +} + +#[test] +fn test_bind_parameters_delete_rowid_alias_seek_out_of_order() -> anyhow::Result<()> { + let tmp_db = TempDatabase::new_with_rusqlite( + "CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT, age integer);", + ); + let conn = tmp_db.connect_limbo(); + conn.execute("insert into test (id, name, age) values (1, 'correct', 4);")?; + conn.execute("insert into test (id, name, age) values (5, 'test', 11);")?; + + let mut ins = + conn.prepare("delete from test where age between ? and ? AND id > ? AND name = ?;")?; + ins.bind_at(1.try_into()?, OwnedValue::Integer(10)); + ins.bind_at(2.try_into()?, OwnedValue::Integer(12)); + ins.bind_at(3.try_into()?, OwnedValue::Integer(4)); + ins.bind_at(4.try_into()?, OwnedValue::build_text("test")); + loop { + match ins.step()? { + StepResult::IO => tmp_db.io.run_once()?, + StepResult::Done | StepResult::Interrupt => break, + StepResult::Busy => panic!("database busy"), + _ => {} + } + } + + let mut sel = conn.prepare("select name from test;")?; + let mut i = 0; + loop { + match sel.step()? { + StepResult::Row => { + let row = sel.row().unwrap(); + assert_eq!( + row.get::<&OwnedValue>(0).unwrap(), + &OwnedValue::build_text("correct"), + ); + } + StepResult::IO => tmp_db.io.run_once()?, + StepResult::Done | StepResult::Interrupt => break, + StepResult::Busy => panic!("database busy"), + } + i += 1; + } + assert_eq!(i, 1); + assert_eq!(ins.parameters().count(), 4); + Ok(()) +} From 16ac6ab9189b17b6f5f568217fbfb865ec2dede9 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Tue, 13 May 2025 14:33:11 -0400 Subject: [PATCH 3/6] Fix parameter push method to re-convert anonymous parameters --- core/parameters.rs | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/core/parameters.rs b/core/parameters.rs index ff0f99033..a7009f65e 100644 --- a/core/parameters.rs +++ b/core/parameters.rs @@ -105,18 +105,21 @@ impl Parameters { } } index => { - let index: NonZero = if let Some(idx) = index.strip_prefix(PARAM_PREFIX) { - idx.parse().unwrap() + if let Some(idx) = index.strip_prefix(PARAM_PREFIX) { + let idx: NonZero = idx.parse().unwrap(); + self.next_index(); + self.list.push(Parameter::Anonymous(idx)); + idx } else { // SAFETY: Guaranteed from parser that the index is bigger than 0. - index.parse().unwrap() - }; - if index > self.index { - self.index = index.checked_add(1).unwrap(); + let index: NonZero = index.parse().unwrap(); + if index > self.index { + self.index = index.checked_add(1).unwrap(); + } + self.list.push(Parameter::Indexed(index)); + tracing::trace!("indexed parameter at {index}"); + index } - self.list.push(Parameter::Indexed(index)); - tracing::trace!("indexed parameter at {index}"); - index } } } From 94aa9cd99dfc756ccc9da5b51f760245bfc62678 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Tue, 13 May 2025 14:33:45 -0400 Subject: [PATCH 4/6] Add cases to rewrite_expr in the optimizer --- core/translate/optimizer.rs | 38 +++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/core/translate/optimizer.rs b/core/translate/optimizer.rs index 75b5ee834..38cdb6343 100644 --- a/core/translate/optimizer.rs +++ b/core/translate/optimizer.rs @@ -1876,6 +1876,8 @@ pub fn rewrite_expr(expr: &mut ast::Expr, param_idx: &mut usize) -> Result<()> { } ast::Expr::Variable(var) => { if var.is_empty() { + // rewrite anonymous variables only, ensure that the `param_idx` starts at 1 and + // all the expressions are rewritten in the order they come in the statement *expr = ast::Expr::Variable(format!("{}{param_idx}", PARAM_PREFIX)); *param_idx += 1; } @@ -1935,6 +1937,42 @@ pub fn rewrite_expr(expr: &mut ast::Expr, param_idx: &mut usize) -> Result<()> { rewrite_expr(rhs, param_idx)?; Ok(()) } + ast::Expr::Like { + lhs, rhs, escape, .. + } => { + rewrite_expr(lhs, param_idx)?; + rewrite_expr(rhs, param_idx)?; + if let Some(escape) = escape { + rewrite_expr(escape, param_idx)?; + } + Ok(()) + } + ast::Expr::Case { + base, + when_then_pairs, + else_expr, + } => { + if let Some(base) = base { + rewrite_expr(base, param_idx)?; + } + for (lhs, rhs) in when_then_pairs.iter_mut() { + rewrite_expr(lhs, param_idx)?; + rewrite_expr(rhs, param_idx)?; + } + if let Some(else_expr) = else_expr { + rewrite_expr(else_expr, param_idx)?; + } + Ok(()) + } + ast::Expr::InList { lhs, rhs, .. } => { + rewrite_expr(lhs, param_idx)?; + if let Some(rhs) = rhs { + for expr in rhs.iter_mut() { + rewrite_expr(expr, param_idx)?; + } + } + Ok(()) + } ast::Expr::FunctionCall { args, .. } => { if let Some(args) = args { for arg in args.iter_mut() { From 2f255524bd4f436869869f713272715f8bfd4ae2 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Tue, 13 May 2025 14:34:22 -0400 Subject: [PATCH 5/6] Remove unused import and unnecessary mut annotations in insert.rs --- core/translate/insert.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/translate/insert.rs b/core/translate/insert.rs index c7cff3a88..9250cc1c7 100644 --- a/core/translate/insert.rs +++ b/core/translate/insert.rs @@ -1,4 +1,4 @@ -use std::ops::{Deref, DerefMut}; +use std::ops::Deref; use std::rc::Rc; use limbo_sqlite3_parser::ast::{ @@ -108,8 +108,8 @@ pub fn translate_insert( InsertBody::DefaultValues => &mut vec![vec![]], }; let mut param_idx = 1; - for mut expr in values.iter_mut().flat_map(|v| v.iter_mut()) { - rewrite_expr(&mut expr, &mut param_idx)?; + for expr in values.iter_mut().flat_map(|v| v.iter_mut()) { + rewrite_expr(expr, &mut param_idx)?; } let column_mappings = resolve_columns_for_insert(&table, columns, values)?; From a0b2b6e85d68140d67c45abf26333509e04346b8 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Tue, 13 May 2025 14:42:12 -0400 Subject: [PATCH 6/6] Consolidate match case in parameters push to handle all anonymous params in one case --- core/parameters.rs | 34 ++++++++++++++++------------------ 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/core/parameters.rs b/core/parameters.rs index a7009f65e..bb29c5008 100644 --- a/core/parameters.rs +++ b/core/parameters.rs @@ -78,11 +78,16 @@ impl Parameters { pub fn push(&mut self, name: impl AsRef) -> NonZero { match name.as_ref() { - "" => { + param if param.is_empty() || param.starts_with(PARAM_PREFIX) => { let index = self.next_index(); - self.list.push(Parameter::Anonymous(index)); - tracing::trace!("anonymous parameter at {index}"); - index + let use_idx = if let Some(idx) = param.strip_prefix(PARAM_PREFIX) { + idx.parse().unwrap() + } else { + index + }; + self.list.push(Parameter::Anonymous(use_idx)); + tracing::trace!("anonymous parameter at {use_idx}"); + use_idx } name if name.starts_with(['$', ':', '@', '#']) => { match self @@ -105,21 +110,14 @@ impl Parameters { } } index => { - if let Some(idx) = index.strip_prefix(PARAM_PREFIX) { - let idx: NonZero = idx.parse().unwrap(); - self.next_index(); - self.list.push(Parameter::Anonymous(idx)); - idx - } else { - // SAFETY: Guaranteed from parser that the index is bigger than 0. - let index: NonZero = index.parse().unwrap(); - if index > self.index { - self.index = index.checked_add(1).unwrap(); - } - self.list.push(Parameter::Indexed(index)); - tracing::trace!("indexed parameter at {index}"); - index + // SAFETY: Guaranteed from parser that the index is bigger than 0. + let index: NonZero = index.parse().unwrap(); + if index > self.index { + self.index = index.checked_add(1).unwrap(); } + self.list.push(Parameter::Indexed(index)); + tracing::trace!("indexed parameter at {index}"); + index } } }