diff --git a/core/parameters.rs b/core/parameters.rs index bb29c5008..6abe5478c 100644 --- a/core/parameters.rs +++ b/core/parameters.rs @@ -1,10 +1,7 @@ use std::num::NonZero; -pub const PARAM_PREFIX: &str = "__param_"; - #[derive(Clone, Debug)] pub enum Parameter { - Anonymous(NonZero), Indexed(NonZero), Named(String, NonZero), } @@ -18,7 +15,6 @@ impl PartialEq for Parameter { impl Parameter { pub fn index(&self) -> NonZero { match self { - Parameter::Anonymous(index) => *index, Parameter::Indexed(index) => *index, Parameter::Named(_, index) => *index, } @@ -27,7 +23,7 @@ impl Parameter { #[derive(Debug)] pub struct Parameters { - index: NonZero, + next_index: NonZero, pub list: Vec, } @@ -40,7 +36,7 @@ impl Default for Parameters { impl Parameters { pub fn new() -> Self { Self { - index: 1.try_into().unwrap(), + next_index: 1.try_into().unwrap(), list: vec![], } } @@ -53,7 +49,6 @@ impl Parameters { pub fn name(&self, index: NonZero) -> Option { self.list.iter().find_map(|p| match p { - Parameter::Anonymous(i) if *i == index => Some("?".to_string()), Parameter::Indexed(i) if *i == index => Some(format!("?{i}")), Parameter::Named(name, i) if *i == index => Some(name.to_owned()), _ => None, @@ -71,24 +66,13 @@ impl Parameters { } pub fn next_index(&mut self) -> NonZero { - let index = self.index; - self.index = self.index.checked_add(1).unwrap(); + let index = self.next_index; + self.next_index = self.next_index.checked_add(1).unwrap(); index } pub fn push(&mut self, name: impl AsRef) -> NonZero { match name.as_ref() { - param if param.is_empty() || param.starts_with(PARAM_PREFIX) => { - let index = self.next_index(); - let use_idx = if let Some(idx) = param.strip_prefix(PARAM_PREFIX) { - idx.parse().unwrap() - } else { - index - }; - self.list.push(Parameter::Anonymous(use_idx)); - tracing::trace!("anonymous parameter at {use_idx}"); - use_idx - } name if name.starts_with(['$', ':', '@', '#']) => { match self .list @@ -112,8 +96,8 @@ impl Parameters { index => { // SAFETY: Guaranteed from parser that the index is bigger than 0. let index: NonZero = index.parse().unwrap(); - if index > self.index { - self.index = index.checked_add(1).unwrap(); + if index >= self.next_index { + self.next_index = index.checked_add(1).unwrap(); } self.list.push(Parameter::Indexed(index)); tracing::trace!("indexed parameter at {index}"); diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 50d416aed..c7558562d 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -10,7 +10,6 @@ use super::plan::TableReferences; use crate::function::JsonFunc; use crate::function::{Func, FuncCtx, MathFuncArity, ScalarFunc, VectorFunc}; use crate::functions::datetime; -use crate::parameters::PARAM_PREFIX; use crate::schema::{affinity, Affinity, Table, Type}; use crate::translate::optimizer::TakeOwnership; use crate::translate::plan::ResultSetColumn; @@ -3244,26 +3243,25 @@ where Ok(WalkControl::Continue) } -/// Context needed to walk all expressions in a INSERT|UPDATE|SELECT|DELETE body, -/// in the order they are encountered, to ensure that the parameters are rewritten from -/// anonymous ("?") to our internal named scheme so when the columns are re-ordered we are able -/// to bind the proper parameter values. pub struct ParamState { - /// ALWAYS starts at 1 - pub next_param_idx: usize, + // flag which allow or forbid usage of parameters during translation of AST to the program + // + // for example, parameters are not allowed in the partial index definition + // so tursodb set allowed to false when it parsed WHERE clause of partial index definition + pub allowed: bool, } impl Default for ParamState { fn default() -> Self { - Self { next_param_idx: 1 } + Self { allowed: true } } } impl ParamState { pub fn is_valid(&self) -> bool { - self.next_param_idx > 0 + self.allowed } pub fn disallow() -> Self { - Self { next_param_idx: 0 } + Self { allowed: false } } } @@ -3296,16 +3294,10 @@ pub fn bind_and_rewrite_expr<'a>( top_level_expr, &mut |expr: &mut ast::Expr| -> Result { match expr { - // Rewrite anonymous variables in encounter order. - ast::Expr::Variable(var) if var.is_empty() => { + ast::Expr::Variable(_) => { if !param_state.is_valid() { crate::bail_parse_error!("Parameters are not allowed in this context"); } - *expr = ast::Expr::Variable(format!( - "{}{}", - PARAM_PREFIX, param_state.next_param_idx - )); - param_state.next_param_idx += 1; } ast::Expr::Between { lhs, diff --git a/parser/src/parser.rs b/parser/src/parser.rs index 87ed5cabf..b4188df44 100644 --- a/parser/src/parser.rs +++ b/parser/src/parser.rs @@ -140,6 +140,10 @@ pub struct Parser<'a> { /// The current token being processed current_token: Token<'a>, peekable: bool, + + /// Last assigned id of positional variable + /// Parser tracks that in order to properly auto-assign variable ids in correct order for anonymous parameters '?' + last_variable_id: u32, } impl<'a> Iterator for Parser<'a> { @@ -165,6 +169,25 @@ impl<'a> Parser<'a> { value: &input[..0], token_type: None, }, + last_variable_id: 0, + } + } + + fn create_variable(&mut self, token: &[u8]) -> Result { + if token.is_empty() { + // Rewrite anonymous variables in encounter order + self.last_variable_id += 1; + Ok(Expr::Variable(format!("{}", self.last_variable_id))) + } else if matches!(token[0], b':' | b'@' | b'$' | b'#') { + Ok(Expr::Variable(from_bytes(token))) + } else { + let variable_str = std::str::from_utf8(token) + .map_err(|e| Error::Custom(format!("non-utf8 positional variable id: {e}")))?; + let variable_id = variable_str + .parse::() + .map_err(|e| Error::Custom(format!("non-integer positional variable id: {e}")))?; + self.last_variable_id = variable_id; + Ok(Expr::Variable(from_bytes(token))) } } @@ -1344,7 +1367,7 @@ impl<'a> Parser<'a> { } TK_VARIABLE => { let tok = eat_assert!(self, TK_VARIABLE); - Ok(Box::new(Expr::Variable(from_bytes(tok.value)))) + Ok(Box::new(self.create_variable(tok.value)?)) } TK_CAST => { eat_assert!(self, TK_CAST); @@ -4387,7 +4410,7 @@ mod tests { select: OneSelect::Select { distinctness: None, columns: vec![ResultColumn::Expr( - Box::new(Expr::Variable("".to_owned())), + Box::new(Expr::Variable("1".to_owned())), None, )], from: None, diff --git a/tests/integration/query_processing/test_read_path.rs b/tests/integration/query_processing/test_read_path.rs index cbc53cfbe..a0b72cd60 100644 --- a/tests/integration/query_processing/test_read_path.rs +++ b/tests/integration/query_processing/test_read_path.rs @@ -827,3 +827,52 @@ fn test_offset_limit_bind() -> anyhow::Result<()> { Ok(()) } + +#[test] +fn test_upsert_parameters_order() -> anyhow::Result<()> { + let tmp_db = TempDatabase::new_with_rusqlite( + "CREATE TABLE test (k INTEGER PRIMARY KEY, v INTEGER);", + false, + ); + let conn = tmp_db.connect_limbo(); + + conn.execute("INSERT INTO test VALUES (1, 2), (3, 4)")?; + let mut stmt = + conn.prepare("INSERT INTO test VALUES (?, ?), (?, ?) ON CONFLICT DO UPDATE SET v = ?")?; + stmt.bind_at(1.try_into()?, Value::Integer(1)); + stmt.bind_at(2.try_into()?, Value::Integer(20)); + stmt.bind_at(3.try_into()?, Value::Integer(3)); + stmt.bind_at(4.try_into()?, Value::Integer(40)); + stmt.bind_at(5.try_into()?, Value::Integer(66)); + while let StepResult::Row | StepResult::IO = stmt.step()? { + stmt.run_once()?; + } + + let mut rows = Vec::new(); + let mut stmt = conn.prepare("SELECT * FROM test")?; + loop { + match stmt.step()? { + StepResult::Row => { + let row = stmt.row().unwrap(); + rows.push(row.get_values().cloned().collect::>()); + } + StepResult::IO => stmt.run_once()?, + _ => break, + } + } + + assert_eq!( + rows, + vec![ + vec![ + turso_core::Value::Integer(1), + turso_core::Value::Integer(66) + ], + vec![ + turso_core::Value::Integer(3), + turso_core::Value::Integer(66) + ] + ] + ); + Ok(()) +}