mirror of
https://github.com/aljazceru/turso.git
synced 2026-01-09 11:14:20 +01:00
Merge 'Anonymous params fix' from Nikita Sivukhin
This PR auto-assign ids for anonymous variables straight into parser.
Otherwise - it's pretty easy to mess up with traversal order in the core
code and assign ids incorrectly.
For example, before the fix, following code worked incorrectly because
parameter values were assigned first to conflict clause instead of
values:
```rs
let mut stmt = conn.prepare("INSERT INTO test VALUES (?, ?), (?, ?) ON CONFLICT DO UPDATE SET v = ?")?;
stmt.bind_at(1.try_into()?, Value::Integer(1));
stmt.bind_at(2.try_into()?, Value::Integer(20));
stmt.bind_at(3.try_into()?, Value::Integer(3));
stmt.bind_at(4.try_into()?, Value::Integer(40));
stmt.bind_at(5.try_into()?, Value::Integer(66));
```
Closes #3455
This commit is contained in:
@@ -1,10 +1,7 @@
|
||||
use std::num::NonZero;
|
||||
|
||||
pub const PARAM_PREFIX: &str = "__param_";
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum Parameter {
|
||||
Anonymous(NonZero<usize>),
|
||||
Indexed(NonZero<usize>),
|
||||
Named(String, NonZero<usize>),
|
||||
}
|
||||
@@ -18,7 +15,6 @@ impl PartialEq for Parameter {
|
||||
impl Parameter {
|
||||
pub fn index(&self) -> NonZero<usize> {
|
||||
match self {
|
||||
Parameter::Anonymous(index) => *index,
|
||||
Parameter::Indexed(index) => *index,
|
||||
Parameter::Named(_, index) => *index,
|
||||
}
|
||||
@@ -27,7 +23,7 @@ impl Parameter {
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Parameters {
|
||||
index: NonZero<usize>,
|
||||
next_index: NonZero<usize>,
|
||||
pub list: Vec<Parameter>,
|
||||
}
|
||||
|
||||
@@ -40,7 +36,7 @@ impl Default for Parameters {
|
||||
impl Parameters {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
index: 1.try_into().unwrap(),
|
||||
next_index: 1.try_into().unwrap(),
|
||||
list: vec![],
|
||||
}
|
||||
}
|
||||
@@ -53,7 +49,6 @@ impl Parameters {
|
||||
|
||||
pub fn name(&self, index: NonZero<usize>) -> Option<String> {
|
||||
self.list.iter().find_map(|p| match p {
|
||||
Parameter::Anonymous(i) if *i == index => Some("?".to_string()),
|
||||
Parameter::Indexed(i) if *i == index => Some(format!("?{i}")),
|
||||
Parameter::Named(name, i) if *i == index => Some(name.to_owned()),
|
||||
_ => None,
|
||||
@@ -71,24 +66,13 @@ impl Parameters {
|
||||
}
|
||||
|
||||
pub fn next_index(&mut self) -> NonZero<usize> {
|
||||
let index = self.index;
|
||||
self.index = self.index.checked_add(1).unwrap();
|
||||
let index = self.next_index;
|
||||
self.next_index = self.next_index.checked_add(1).unwrap();
|
||||
index
|
||||
}
|
||||
|
||||
pub fn push(&mut self, name: impl AsRef<str>) -> NonZero<usize> {
|
||||
match name.as_ref() {
|
||||
param if param.is_empty() || param.starts_with(PARAM_PREFIX) => {
|
||||
let index = self.next_index();
|
||||
let use_idx = if let Some(idx) = param.strip_prefix(PARAM_PREFIX) {
|
||||
idx.parse().unwrap()
|
||||
} else {
|
||||
index
|
||||
};
|
||||
self.list.push(Parameter::Anonymous(use_idx));
|
||||
tracing::trace!("anonymous parameter at {use_idx}");
|
||||
use_idx
|
||||
}
|
||||
name if name.starts_with(['$', ':', '@', '#']) => {
|
||||
match self
|
||||
.list
|
||||
@@ -112,8 +96,8 @@ impl Parameters {
|
||||
index => {
|
||||
// SAFETY: Guaranteed from parser that the index is bigger than 0.
|
||||
let index: NonZero<usize> = index.parse().unwrap();
|
||||
if index > self.index {
|
||||
self.index = index.checked_add(1).unwrap();
|
||||
if index >= self.next_index {
|
||||
self.next_index = index.checked_add(1).unwrap();
|
||||
}
|
||||
self.list.push(Parameter::Indexed(index));
|
||||
tracing::trace!("indexed parameter at {index}");
|
||||
|
||||
@@ -10,7 +10,6 @@ use super::plan::TableReferences;
|
||||
use crate::function::JsonFunc;
|
||||
use crate::function::{Func, FuncCtx, MathFuncArity, ScalarFunc, VectorFunc};
|
||||
use crate::functions::datetime;
|
||||
use crate::parameters::PARAM_PREFIX;
|
||||
use crate::schema::{affinity, Affinity, Table, Type};
|
||||
use crate::translate::optimizer::TakeOwnership;
|
||||
use crate::translate::plan::ResultSetColumn;
|
||||
@@ -3244,26 +3243,25 @@ where
|
||||
Ok(WalkControl::Continue)
|
||||
}
|
||||
|
||||
/// Context needed to walk all expressions in a INSERT|UPDATE|SELECT|DELETE body,
|
||||
/// in the order they are encountered, to ensure that the parameters are rewritten from
|
||||
/// anonymous ("?") to our internal named scheme so when the columns are re-ordered we are able
|
||||
/// to bind the proper parameter values.
|
||||
pub struct ParamState {
|
||||
/// ALWAYS starts at 1
|
||||
pub next_param_idx: usize,
|
||||
// flag which allow or forbid usage of parameters during translation of AST to the program
|
||||
//
|
||||
// for example, parameters are not allowed in the partial index definition
|
||||
// so tursodb set allowed to false when it parsed WHERE clause of partial index definition
|
||||
pub allowed: bool,
|
||||
}
|
||||
|
||||
impl Default for ParamState {
|
||||
fn default() -> Self {
|
||||
Self { next_param_idx: 1 }
|
||||
Self { allowed: true }
|
||||
}
|
||||
}
|
||||
impl ParamState {
|
||||
pub fn is_valid(&self) -> bool {
|
||||
self.next_param_idx > 0
|
||||
self.allowed
|
||||
}
|
||||
pub fn disallow() -> Self {
|
||||
Self { next_param_idx: 0 }
|
||||
Self { allowed: false }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3296,16 +3294,10 @@ pub fn bind_and_rewrite_expr<'a>(
|
||||
top_level_expr,
|
||||
&mut |expr: &mut ast::Expr| -> Result<WalkControl> {
|
||||
match expr {
|
||||
// Rewrite anonymous variables in encounter order.
|
||||
ast::Expr::Variable(var) if var.is_empty() => {
|
||||
ast::Expr::Variable(_) => {
|
||||
if !param_state.is_valid() {
|
||||
crate::bail_parse_error!("Parameters are not allowed in this context");
|
||||
}
|
||||
*expr = ast::Expr::Variable(format!(
|
||||
"{}{}",
|
||||
PARAM_PREFIX, param_state.next_param_idx
|
||||
));
|
||||
param_state.next_param_idx += 1;
|
||||
}
|
||||
ast::Expr::Between {
|
||||
lhs,
|
||||
|
||||
@@ -140,6 +140,10 @@ pub struct Parser<'a> {
|
||||
/// The current token being processed
|
||||
current_token: Token<'a>,
|
||||
peekable: bool,
|
||||
|
||||
/// Last assigned id of positional variable
|
||||
/// Parser tracks that in order to properly auto-assign variable ids in correct order for anonymous parameters '?'
|
||||
last_variable_id: u32,
|
||||
}
|
||||
|
||||
impl<'a> Iterator for Parser<'a> {
|
||||
@@ -165,6 +169,25 @@ impl<'a> Parser<'a> {
|
||||
value: &input[..0],
|
||||
token_type: None,
|
||||
},
|
||||
last_variable_id: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn create_variable(&mut self, token: &[u8]) -> Result<Expr> {
|
||||
if token.is_empty() {
|
||||
// Rewrite anonymous variables in encounter order
|
||||
self.last_variable_id += 1;
|
||||
Ok(Expr::Variable(format!("{}", self.last_variable_id)))
|
||||
} else if matches!(token[0], b':' | b'@' | b'$' | b'#') {
|
||||
Ok(Expr::Variable(from_bytes(token)))
|
||||
} else {
|
||||
let variable_str = std::str::from_utf8(token)
|
||||
.map_err(|e| Error::Custom(format!("non-utf8 positional variable id: {e}")))?;
|
||||
let variable_id = variable_str
|
||||
.parse::<u32>()
|
||||
.map_err(|e| Error::Custom(format!("non-integer positional variable id: {e}")))?;
|
||||
self.last_variable_id = variable_id;
|
||||
Ok(Expr::Variable(from_bytes(token)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1344,7 +1367,7 @@ impl<'a> Parser<'a> {
|
||||
}
|
||||
TK_VARIABLE => {
|
||||
let tok = eat_assert!(self, TK_VARIABLE);
|
||||
Ok(Box::new(Expr::Variable(from_bytes(tok.value))))
|
||||
Ok(Box::new(self.create_variable(tok.value)?))
|
||||
}
|
||||
TK_CAST => {
|
||||
eat_assert!(self, TK_CAST);
|
||||
@@ -4387,7 +4410,7 @@ mod tests {
|
||||
select: OneSelect::Select {
|
||||
distinctness: None,
|
||||
columns: vec![ResultColumn::Expr(
|
||||
Box::new(Expr::Variable("".to_owned())),
|
||||
Box::new(Expr::Variable("1".to_owned())),
|
||||
None,
|
||||
)],
|
||||
from: None,
|
||||
|
||||
@@ -827,3 +827,52 @@ fn test_offset_limit_bind() -> anyhow::Result<()> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_upsert_parameters_order() -> anyhow::Result<()> {
|
||||
let tmp_db = TempDatabase::new_with_rusqlite(
|
||||
"CREATE TABLE test (k INTEGER PRIMARY KEY, v INTEGER);",
|
||||
false,
|
||||
);
|
||||
let conn = tmp_db.connect_limbo();
|
||||
|
||||
conn.execute("INSERT INTO test VALUES (1, 2), (3, 4)")?;
|
||||
let mut stmt =
|
||||
conn.prepare("INSERT INTO test VALUES (?, ?), (?, ?) ON CONFLICT DO UPDATE SET v = ?")?;
|
||||
stmt.bind_at(1.try_into()?, Value::Integer(1));
|
||||
stmt.bind_at(2.try_into()?, Value::Integer(20));
|
||||
stmt.bind_at(3.try_into()?, Value::Integer(3));
|
||||
stmt.bind_at(4.try_into()?, Value::Integer(40));
|
||||
stmt.bind_at(5.try_into()?, Value::Integer(66));
|
||||
while let StepResult::Row | StepResult::IO = stmt.step()? {
|
||||
stmt.run_once()?;
|
||||
}
|
||||
|
||||
let mut rows = Vec::new();
|
||||
let mut stmt = conn.prepare("SELECT * FROM test")?;
|
||||
loop {
|
||||
match stmt.step()? {
|
||||
StepResult::Row => {
|
||||
let row = stmt.row().unwrap();
|
||||
rows.push(row.get_values().cloned().collect::<Vec<_>>());
|
||||
}
|
||||
StepResult::IO => stmt.run_once()?,
|
||||
_ => break,
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
rows,
|
||||
vec![
|
||||
vec![
|
||||
turso_core::Value::Integer(1),
|
||||
turso_core::Value::Integer(66)
|
||||
],
|
||||
vec![
|
||||
turso_core::Value::Integer(3),
|
||||
turso_core::Value::Integer(66)
|
||||
]
|
||||
]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user