diff --git a/core/lib.rs b/core/lib.rs index a851e7281..ab35d7711 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -28,6 +28,7 @@ use sqlite3_parser::ast; use sqlite3_parser::{ast::Cmd, lexer::sql::Parser}; use std::cell::Cell; use std::collections::HashMap; +use std::num::NonZero; use std::sync::{Arc, OnceLock, RwLock}; use std::{cell::RefCell, rc::Rc}; use storage::btree::btree_init_page; @@ -474,12 +475,24 @@ impl Statement { Ok(Rows::new(stmt)) } - pub fn reset(&mut self) { - self.state.reset(); + pub fn parameter_count(&mut self) -> usize { + self.program.parameter_count() } - pub fn bind(&mut self, value: Value) { - self.state.bind(value.into()); + pub fn parameter_name(&self, index: NonZero) -> Option { + self.program.parameter_name(index) + } + + pub fn parameter_index(&self, name: impl AsRef) -> Option> { + self.program.parameter_index(name) + } + + pub fn bind_at(&mut self, index: NonZero, value: Value) { + self.state.bind_at(index, value.into()); + } + + pub fn reset(&mut self) { + self.state.reset(); } } diff --git a/core/translate/expr.rs b/core/translate/expr.rs index bde068a20..3bafa1b1c 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -1710,8 +1710,8 @@ pub fn translate_expr( } _ => todo!(), }, - ast::Expr::Variable(name) => { - let index = program.get_parameter_index(name); + ast::Expr::Variable(_) => { + let index = program.pop_index(); program.emit_insn(Insn::Variable { index, dest: target_register, diff --git a/core/translate/mod.rs b/core/translate/mod.rs index 37470cd56..1f2dc2737 100644 --- a/core/translate/mod.rs +++ b/core/translate/mod.rs @@ -32,12 +32,121 @@ use crate::vdbe::{builder::ProgramBuilder, insn::Insn, Program}; use crate::{bail_parse_error, Connection, LimboError, Result, SymbolTable}; use insert::translate_insert; use select::translate_select; +use sqlite3_parser::ast::fmt::TokenStream; use sqlite3_parser::ast::{self, fmt::ToTokens, PragmaName}; +use sqlite3_parser::dialect::TokenType; use std::cell::RefCell; use std::fmt::Display; +use std::num::NonZero; use std::rc::{Rc, Weak}; use std::str::FromStr; +#[derive(Clone, Debug)] +pub enum Parameter { + Anonymous(NonZero), + Indexed(NonZero), + Named(String, NonZero), +} + +impl PartialEq for Parameter { + fn eq(&self, other: &Self) -> bool { + self.index() == other.index() + } +} + +impl Parameter { + pub fn index(&self) -> NonZero { + match self { + Parameter::Anonymous(index) => *index, + Parameter::Indexed(index) => *index, + Parameter::Named(_, index) => *index, + } + } +} + +/// `?` or `$` Prepared statement arg placeholder(s) +#[derive(Debug)] +pub struct Parameters { + index: NonZero, + pub list: Vec, +} + +impl Parameters { + pub fn new() -> Self { + Self { + index: 1.try_into().unwrap(), + list: vec![], + } + } + + pub fn push(&mut self, value: Parameter) { + self.list.push(value); + } + + pub fn next_index(&mut self) -> NonZero { + let index = self.index; + self.index = self.index.checked_add(1).unwrap(); + index + } + + pub fn get(&mut self, index: usize) -> Option<&Parameter> { + self.list.get(index) + } +} + +// https://sqlite.org/lang_expr.html#parameters +impl TokenStream for Parameters { + type Error = std::convert::Infallible; + + fn append( + &mut self, + ty: TokenType, + value: Option<&str>, + ) -> std::result::Result<(), Self::Error> { + if ty == TokenType::TK_VARIABLE { + if let Some(variable) = value { + match variable.split_at(1) { + ("?", "") => { + let index = self.next_index(); + self.push(Parameter::Anonymous(index.try_into().unwrap())); + log::trace!("anonymous parameter at {index}"); + } + ("?", index) => { + let index: NonZero = index.parse().unwrap(); + if index > self.index { + self.index = index.checked_add(1).unwrap(); + } + self.push(Parameter::Indexed(index.try_into().unwrap())); + log::trace!("indexed parameter at {index}"); + } + (_, _) => { + match self.list.iter().find(|p| { + let Parameter::Named(name, _) = p else { + return false; + }; + name == variable + }) { + Some(t) => { + log::trace!("named parameter at {} as {}", t.index(), variable); + self.push(t.clone()); + } + None => { + let index = self.next_index(); + self.push(Parameter::Named( + variable.to_owned(), + index.try_into().unwrap(), + )); + log::trace!("named parameter at {index} as {variable}"); + } + } + } + } + } + } + Ok(()) + } +} + /// Translate SQL statement into bytecode program. pub fn translate( schema: &Schema, @@ -47,7 +156,14 @@ pub fn translate( connection: Weak, syms: &SymbolTable, ) -> Result { - let mut program = ProgramBuilder::new(); + let mut parameters = Parameters::new(); + stmt.to_tokens(&mut parameters).unwrap(); + + // dbg!(¶meters); + // dbg!(¶meters.list.clone().dedup()); + + let mut program = ProgramBuilder::new(parameters); + match stmt { ast::Stmt::AlterTable(_, _) => bail_parse_error!("ALTER TABLE not supported yet"), ast::Stmt::Analyze(_) => bail_parse_error!("ANALYZE not supported yet"), @@ -118,6 +234,7 @@ pub fn translate( )?; } } + Ok(program.build(database_header, connection)) } diff --git a/core/vdbe/builder.rs b/core/vdbe/builder.rs index 796782640..711b506a0 100644 --- a/core/vdbe/builder.rs +++ b/core/vdbe/builder.rs @@ -32,6 +32,8 @@ pub struct ProgramBuilder { comments: HashMap, named_parameters: HashMap>, next_free_parameter_index: NonZero, + parameters: crate::translate::Parameters, + parameter_index: usize, } #[derive(Debug, Clone)] @@ -49,7 +51,7 @@ impl CursorType { } impl ProgramBuilder { - pub fn new() -> Self { + pub fn new(parameters: crate::translate::Parameters) -> Self { Self { next_free_register: 1, next_free_label: 0, @@ -63,6 +65,8 @@ impl ProgramBuilder { seekrowid_emitted_bitmask: 0, comments: HashMap::new(), named_parameters: HashMap::new(), + parameters, + parameter_index: 0, } } @@ -336,6 +340,7 @@ impl ProgramBuilder { self.constant_insns.is_empty(), "constant_insns is not empty when build() is called, did you forget to call emit_constant_insns()?" ); + self.parameters.list.dedup(); Program { max_registers: self.next_free_register, insns: self.insns, @@ -344,29 +349,13 @@ impl ProgramBuilder { comments: self.comments, connection, auto_commit: true, + parameters: self.parameters.list, } } - fn next_parameter(&mut self) -> NonZero { - let index = self.next_free_parameter_index; - self.next_free_parameter_index.checked_add(1).unwrap(); - index - } - - pub fn get_parameter_index(&mut self, name: impl AsRef) -> NonZero { - let name = name.as_ref(); - - if name == "" { - return self.next_parameter(); - } - - match self.named_parameters.get(name) { - Some(index) => *index, - None => { - let index = self.next_parameter(); - self.named_parameters.insert(name.to_owned(), index); - index - } - } + pub fn pop_index(&mut self) -> NonZero { + let parameter = self.parameters.get(self.parameter_index).unwrap(); + self.parameter_index += 1; + return parameter.index(); } } diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index 50db5538a..5188a2329 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -201,7 +201,7 @@ pub struct ProgramState { ended_coroutine: HashMap, // flag to indicate that a coroutine has ended (key is the yield register) regex_cache: RegexCache, interrupted: bool, - parameters: Vec, + parameters: HashMap, OwnedValue>, } impl ProgramState { @@ -224,7 +224,7 @@ impl ProgramState { ended_coroutine: HashMap::new(), regex_cache: RegexCache::new(), interrupted: false, - parameters: Vec::new(), + parameters: HashMap::new(), } } @@ -244,12 +244,12 @@ impl ProgramState { self.interrupted } - pub fn bind(&mut self, value: OwnedValue) { - self.parameters.push(value); + pub fn bind_at(&mut self, index: NonZero, value: OwnedValue) { + self.parameters.insert(index, value); } pub fn get_parameter(&self, index: NonZero) -> Option<&OwnedValue> { - self.parameters.get(usize::from(index) - 1) + self.parameters.get(&index) } pub fn reset(&mut self) { @@ -277,6 +277,7 @@ pub struct Program { pub cursor_ref: Vec<(Option, CursorType)>, pub database_header: Rc>, pub comments: HashMap, + pub parameters: Vec, pub connection: Weak, pub auto_commit: bool, } @@ -300,6 +301,38 @@ impl Program { } } + pub fn parameter_count(&self) -> usize { + self.parameters.len() + } + + pub fn parameter_name(&self, index: NonZero) -> Option { + use crate::translate::Parameter; + self.parameters.iter().find_map(|p| match p { + Parameter::Anonymous(i) if *i == index => Some("?".to_string()), + Parameter::Indexed(i) if *i == index => Some(format!("?{i}")), + Parameter::Named(name, i) if *i == index => Some(name.to_owned()), + _ => None, + }) + } + + pub fn parameter_index(&self, name: impl AsRef) -> Option> { + use crate::translate::Parameter; + self.parameters + .iter() + .find_map(|p| { + let Parameter::Named(parameter_name, index) = p else { + return None; + }; + + if name.as_ref() == parameter_name { + return Some(index); + } + + None + }) + .copied() + } + pub fn step<'a>( &self, state: &'a mut ProgramState, diff --git a/test/src/lib.rs b/test/src/lib.rs index 1fa681ff9..68b4b10e5 100644 --- a/test/src/lib.rs +++ b/test/src/lib.rs @@ -40,7 +40,7 @@ impl TempDatabase { #[cfg(test)] mod tests { use super::*; - use limbo_core::{CheckpointStatus, Connection, Rows, StepResult, Value}; + use limbo_core::{CheckpointStatus, Connection, StepResult, Value}; use log::debug; #[ignore] @@ -576,16 +576,38 @@ mod tests { #[test] fn test_statement_bind() -> anyhow::Result<()> { let _ = env_logger::try_init(); - let tmp_db = TempDatabase::new("CREATE TABLE test (x INTEGER PRIMARY KEY);"); + let tmp_db = TempDatabase::new("create table test (i integer);"); let conn = tmp_db.connect_limbo(); - let mut stmt = conn.prepare("select ?")?; - stmt.bind(Value::Text(&"hello".to_string())); + + let mut stmt = conn.prepare("select ?, ?1, :named, ?4")?; + + stmt.bind_at(1.try_into().unwrap(), Value::Text(&"hello".to_string())); + + let i = stmt.parameter_index(":named").unwrap(); + stmt.bind_at(i, Value::Integer(42)); + + stmt.bind_at(4.try_into().unwrap(), Value::Float(0.5)); + + assert_eq!(stmt.parameter_count(), 3); + loop { match stmt.step()? { StepResult::Row(row) => { if let Value::Text(s) = row.values[0] { assert_eq!(s, "hello") } + + if let Value::Text(s) = row.values[1] { + assert_eq!(s, "hello") + } + + if let Value::Integer(s) = row.values[2] { + assert_eq!(s, 42) + } + + if let Value::Float(s) = row.values[3] { + assert_eq!(s, 0.5) + } } StepResult::IO => { tmp_db.io.run_once()?;