diff --git a/core/error.rs b/core/error.rs index e3e176b79..ca495eb99 100644 --- a/core/error.rs +++ b/core/error.rs @@ -1,3 +1,5 @@ +use std::num::NonZero; + use thiserror::Error; #[derive(Debug, Error, miette::Diagnostic)] @@ -41,6 +43,8 @@ pub enum LimboError { Constraint(String), #[error("Extension error: {0}")] ExtensionError(String), + #[error("Unbound parameter at index {0}")] + Unbound(NonZero), } #[macro_export] diff --git a/core/lib.rs b/core/lib.rs index 906a1675f..04caf1c71 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -4,6 +4,7 @@ mod function; mod io; #[cfg(feature = "json")] mod json; +mod parameters; mod pseudo; mod result; mod schema; @@ -28,6 +29,7 @@ use sqlite3_parser::ast; use sqlite3_parser::{ast::Cmd, lexer::sql::Parser}; use std::cell::Cell; use std::collections::HashMap; +use std::num::NonZero; use std::sync::{Arc, OnceLock, RwLock}; use std::{cell::RefCell, rc::Rc}; use storage::btree::btree_init_page; @@ -43,7 +45,7 @@ use util::parse_schema_rows; pub use error::LimboError; use translate::select::prepare_select_plan; -pub type Result = std::result::Result; +pub type Result = std::result::Result; use crate::translate::optimizer::optimize_plan; pub use io::OpenFlags; @@ -386,6 +388,7 @@ impl Connection { Rc::downgrade(self), syms, )?; + let mut state = vdbe::ProgramState::new(program.max_registers); program.step(&mut state, self.pager.clone())?; } @@ -473,7 +476,18 @@ impl Statement { Ok(Rows::new(stmt)) } - pub fn reset(&self) {} + pub fn parameters(&self) -> ¶meters::Parameters { + &self.program.parameters + } + + pub fn bind_at(&mut self, index: NonZero, value: Value) { + self.state.bind_at(index, value.into()); + } + + pub fn reset(&mut self) { + let state = vdbe::ProgramState::new(self.program.max_registers); + self.state = state + } } pub enum StepResult<'a> { diff --git a/core/parameters.rs b/core/parameters.rs new file mode 100644 index 000000000..9bfdf7f63 --- /dev/null +++ b/core/parameters.rs @@ -0,0 +1,111 @@ +use std::num::NonZero; + +#[derive(Clone, Debug)] +pub enum Parameter { + Anonymous(NonZero), + Indexed(NonZero), + Named(String, NonZero), +} + +impl PartialEq for Parameter { + fn eq(&self, other: &Self) -> bool { + self.index() == other.index() + } +} + +impl Parameter { + pub fn index(&self) -> NonZero { + match self { + Parameter::Anonymous(index) => *index, + Parameter::Indexed(index) => *index, + Parameter::Named(_, index) => *index, + } + } +} + +#[derive(Debug)] +pub struct Parameters { + index: NonZero, + pub list: Vec, +} + +impl Parameters { + pub fn new() -> Self { + Self { + index: 1.try_into().unwrap(), + list: vec![], + } + } + + pub fn count(&self) -> usize { + let mut params = self.list.clone(); + params.dedup(); + params.len() + } + + pub fn name(&self, index: NonZero) -> Option { + self.list.iter().find_map(|p| match p { + Parameter::Anonymous(i) if *i == index => Some("?".to_string()), + Parameter::Indexed(i) if *i == index => Some(format!("?{i}")), + Parameter::Named(name, i) if *i == index => Some(name.to_owned()), + _ => None, + }) + } + + pub fn index(&self, name: impl AsRef) -> Option> { + self.list + .iter() + .find_map(|p| match p { + Parameter::Named(n, index) if n == name.as_ref() => Some(index), + _ => None, + }) + .copied() + } + + pub fn next_index(&mut self) -> NonZero { + let index = self.index; + self.index = self.index.checked_add(1).unwrap(); + index + } + + pub fn push(&mut self, name: impl AsRef) -> NonZero { + match name.as_ref() { + "" => { + let index = self.next_index(); + self.list.push(Parameter::Anonymous(index)); + log::trace!("anonymous parameter at {index}"); + index + } + name if name.starts_with(&['$', ':', '@', '#']) => { + match self + .list + .iter() + .find(|p| matches!(p, Parameter::Named(n, _) if name == n)) + { + Some(t) => { + let index = t.index(); + self.list.push(t.clone()); + log::trace!("named parameter at {index} as {name}"); + index + } + None => { + let index = self.next_index(); + self.list.push(Parameter::Named(name.to_owned(), index)); + log::trace!("named parameter at {index} as {name}"); + index + } + } + } + index => { + // SAFETY: Garanteed from parser that the index is bigger that 0. + let index: NonZero = index.parse().unwrap(); + if index > self.index { + self.index = index.checked_add(1).unwrap(); + } + self.list.push(Parameter::Indexed(index)); + log::trace!("indexed parameter at {index}"); + index + } + } + } +} diff --git a/core/translate/expr.rs b/core/translate/expr.rs index d5cd7983e..cd9b6b28c 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -1710,7 +1710,14 @@ pub fn translate_expr( } _ => todo!(), }, - ast::Expr::Variable(_) => todo!(), + ast::Expr::Variable(name) => { + let index = program.parameters.push(name); + program.emit_insn(Insn::Variable { + index, + dest: target_register, + }); + Ok(target_register) + } } } diff --git a/core/translate/mod.rs b/core/translate/mod.rs index 9a990eb36..20a514e5d 100644 --- a/core/translate/mod.rs +++ b/core/translate/mod.rs @@ -32,8 +32,7 @@ use crate::vdbe::{builder::ProgramBuilder, insn::Insn, Program}; use crate::{bail_parse_error, Connection, LimboError, Result, SymbolTable}; use insert::translate_insert; use select::translate_select; -use sqlite3_parser::ast::fmt::ToTokens; -use sqlite3_parser::ast::{self, PragmaName}; +use sqlite3_parser::ast::{self, fmt::ToTokens, PragmaName}; use std::cell::RefCell; use std::fmt::Display; use std::rc::{Rc, Weak}; @@ -49,6 +48,7 @@ pub fn translate( syms: &SymbolTable, ) -> Result { let mut program = ProgramBuilder::new(); + match stmt { ast::Stmt::AlterTable(_, _) => bail_parse_error!("ALTER TABLE not supported yet"), ast::Stmt::Analyze(_) => bail_parse_error!("ANALYZE not supported yet"), @@ -119,6 +119,7 @@ pub fn translate( )?; } } + Ok(program.build(database_header, connection)) } diff --git a/core/translate/planner.rs b/core/translate/planner.rs index d15a2fab6..8e2c10146 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -269,7 +269,7 @@ pub fn bind_column_references( bind_column_references(expr, referenced_tables)?; Ok(()) } - ast::Expr::Variable(_) => todo!(), + ast::Expr::Variable(_) => Ok(()), } } diff --git a/core/types.rs b/core/types.rs index 81d9d5328..c2bb0be9a 100644 --- a/core/types.rs +++ b/core/types.rs @@ -336,6 +336,18 @@ impl std::ops::DivAssign for OwnedValue { } } +impl From> for OwnedValue { + fn from(value: Value<'_>) -> Self { + match value { + Value::Null => OwnedValue::Null, + Value::Integer(i) => OwnedValue::Integer(i), + Value::Float(f) => OwnedValue::Float(f), + Value::Text(s) => OwnedValue::Text(LimboText::new(Rc::new(s.to_owned()))), + Value::Blob(b) => OwnedValue::Blob(Rc::new(b.to_owned())), + } + } +} + pub fn to_value(value: &OwnedValue) -> Value<'_> { match value { OwnedValue::Null => Value::Null, diff --git a/core/vdbe/builder.rs b/core/vdbe/builder.rs index 7acc4be6f..ab643c105 100644 --- a/core/vdbe/builder.rs +++ b/core/vdbe/builder.rs @@ -5,13 +5,13 @@ use std::{ }; use crate::{ + parameters::Parameters, schema::{BTreeTable, Index, PseudoTable}, storage::sqlite3_ondisk::DatabaseHeader, Connection, }; use super::{BranchOffset, CursorID, Insn, InsnReference, Program}; - #[allow(dead_code)] pub struct ProgramBuilder { next_free_register: usize, @@ -29,6 +29,7 @@ pub struct ProgramBuilder { seekrowid_emitted_bitmask: u64, // map of instruction index to manual comment (used in EXPLAIN) comments: HashMap, + pub parameters: Parameters, } #[derive(Debug, Clone)] @@ -58,6 +59,7 @@ impl ProgramBuilder { label_to_resolved_offset: HashMap::new(), seekrowid_emitted_bitmask: 0, comments: HashMap::new(), + parameters: Parameters::new(), } } @@ -331,6 +333,7 @@ impl ProgramBuilder { self.constant_insns.is_empty(), "constant_insns is not empty when build() is called, did you forget to call emit_constant_insns()?" ); + self.parameters.list.dedup(); Program { max_registers: self.next_free_register, insns: self.insns, @@ -339,6 +342,7 @@ impl ProgramBuilder { comments: self.comments, connection, auto_commit: true, + parameters: self.parameters, } } } diff --git a/core/vdbe/explain.rs b/core/vdbe/explain.rs index 7c3de8364..22b154809 100644 --- a/core/vdbe/explain.rs +++ b/core/vdbe/explain.rs @@ -1062,6 +1062,15 @@ pub fn insn_to_str( 0, format!("r[{}]=r[{}] << r[{}]", dest, lhs, rhs), ), + Insn::Variable { index, dest } => ( + "Variable", + usize::from(*index) as i32, + *dest as i32, + 0, + OwnedValue::build_text(Rc::new("".to_string())), + 0, + format!("r[{}]=parameter({})", *dest, *index), + ), }; format!( "{:<4} {:<17} {:<4} {:<4} {:<4} {:<13} {:<2} {}", diff --git a/core/vdbe/insn.rs b/core/vdbe/insn.rs index 225a9edaf..3066a399c 100644 --- a/core/vdbe/insn.rs +++ b/core/vdbe/insn.rs @@ -1,3 +1,5 @@ +use std::num::NonZero; + use super::{AggFunc, BranchOffset, CursorID, FuncCtx, PageIdx}; use crate::types::{OwnedRecord, OwnedValue}; use limbo_macros::Description; @@ -487,18 +489,26 @@ pub enum Insn { db: usize, where_clause: String, }, + // Place the result of lhs >> rhs in dest register. ShiftRight { lhs: usize, rhs: usize, dest: usize, }, + // Place the result of lhs << rhs in dest register. ShiftLeft { lhs: usize, rhs: usize, dest: usize, }, + + /// Get parameter variable. + Variable { + index: NonZero, + dest: usize, + }, } fn cast_text_to_numerical(value: &str) -> OwnedValue { diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index 01ddd55df..43160eb40 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -55,6 +55,7 @@ use sorter::Sorter; use std::borrow::BorrowMut; use std::cell::RefCell; use std::collections::{BTreeMap, HashMap}; +use std::num::NonZero; use std::rc::{Rc, Weak}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -200,6 +201,7 @@ pub struct ProgramState { ended_coroutine: HashMap, // flag to indicate that a coroutine has ended (key is the yield register) regex_cache: RegexCache, interrupted: bool, + parameters: HashMap, OwnedValue>, } impl ProgramState { @@ -222,6 +224,7 @@ impl ProgramState { ended_coroutine: HashMap::new(), regex_cache: RegexCache::new(), interrupted: false, + parameters: HashMap::new(), } } @@ -240,6 +243,18 @@ impl ProgramState { pub fn is_interrupted(&self) -> bool { self.interrupted } + + pub fn bind_at(&mut self, index: NonZero, value: OwnedValue) { + self.parameters.insert(index, value); + } + + pub fn get_parameter(&self, index: NonZero) -> Option<&OwnedValue> { + self.parameters.get(&index) + } + + pub fn reset(&mut self) { + self.parameters.clear(); + } } macro_rules! must_be_btree_cursor { @@ -262,6 +277,7 @@ pub struct Program { pub cursor_ref: Vec<(Option, CursorType)>, pub database_header: Rc>, pub comments: HashMap, + pub parameters: crate::parameters::Parameters, pub connection: Weak, pub auto_commit: bool, } @@ -2182,6 +2198,13 @@ impl Program { exec_shift_left(&state.registers[*lhs], &state.registers[*rhs]); state.pc += 1; } + Insn::Variable { index, dest } => { + state.registers[*dest] = state + .get_parameter(*index) + .ok_or(LimboError::Unbound(*index))? + .clone(); + state.pc += 1; + } } } } diff --git a/test/src/lib.rs b/test/src/lib.rs index 9aa9116d5..e9a368e07 100644 --- a/test/src/lib.rs +++ b/test/src/lib.rs @@ -572,4 +572,132 @@ mod tests { do_flush(&conn, &tmp_db)?; Ok(()) } + + #[test] + fn test_statement_reset() -> anyhow::Result<()> { + let _ = env_logger::try_init(); + let tmp_db = TempDatabase::new("create table test (i integer);"); + let conn = tmp_db.connect_limbo(); + + conn.execute("insert into test values (1)")?; + conn.execute("insert into test values (2)")?; + + let mut stmt = conn.prepare("select * from test")?; + + loop { + match stmt.step()? { + StepResult::Row(row) => { + assert_eq!(row.values[0], Value::Integer(1)); + break; + } + StepResult::IO => tmp_db.io.run_once()?, + _ => break, + } + } + + stmt.reset(); + + loop { + match stmt.step()? { + StepResult::Row(row) => { + assert_eq!(row.values[0], Value::Integer(1)); + break; + } + StepResult::IO => tmp_db.io.run_once()?, + _ => break, + } + } + + Ok(()) + } + + #[test] + fn test_statement_reset_bind() -> anyhow::Result<()> { + let _ = env_logger::try_init(); + let tmp_db = TempDatabase::new("create table test (i integer);"); + let conn = tmp_db.connect_limbo(); + + let mut stmt = conn.prepare("select ?")?; + + stmt.bind_at(1.try_into().unwrap(), Value::Integer(1)); + + loop { + match stmt.step()? { + StepResult::Row(row) => { + assert_eq!(row.values[0], Value::Integer(1)); + } + StepResult::IO => tmp_db.io.run_once()?, + _ => break, + } + } + + stmt.reset(); + + stmt.bind_at(1.try_into().unwrap(), Value::Integer(2)); + + loop { + match stmt.step()? { + StepResult::Row(row) => { + assert_eq!(row.values[0], Value::Integer(2)); + } + StepResult::IO => tmp_db.io.run_once()?, + _ => break, + } + } + + Ok(()) + } + + #[test] + fn test_statement_bind() -> anyhow::Result<()> { + let _ = env_logger::try_init(); + let tmp_db = TempDatabase::new("create table test (i integer);"); + let conn = tmp_db.connect_limbo(); + + let mut stmt = conn.prepare("select ?, ?1, :named, ?3, ?4")?; + + stmt.bind_at(1.try_into().unwrap(), Value::Text(&"hello".to_string())); + + let i = stmt.parameters().index(":named").unwrap(); + stmt.bind_at(i, Value::Integer(42)); + + stmt.bind_at(3.try_into().unwrap(), Value::Blob(&vec![0x1, 0x2, 0x3])); + + stmt.bind_at(4.try_into().unwrap(), Value::Float(0.5)); + + assert_eq!(stmt.parameters().count(), 4); + + loop { + match stmt.step()? { + StepResult::Row(row) => { + if let Value::Text(s) = row.values[0] { + assert_eq!(s, "hello") + } + + if let Value::Text(s) = row.values[1] { + assert_eq!(s, "hello") + } + + if let Value::Integer(i) = row.values[2] { + assert_eq!(i, 42) + } + + if let Value::Blob(v) = row.values[3] { + assert_eq!(v, &vec![0x1 as u8, 0x2, 0x3]) + } + + if let Value::Float(f) = row.values[4] { + assert_eq!(f, 0.5) + } + } + StepResult::IO => { + tmp_db.io.run_once()?; + } + StepResult::Interrupt => break, + StepResult::Done => break, + StepResult::Busy => panic!("Database is busy"), + }; + } + Ok(()) + } } diff --git a/vendored/sqlite3-parser/src/lexer/sql/mod.rs b/vendored/sqlite3-parser/src/lexer/sql/mod.rs index 72cca0b97..fa98282cc 100644 --- a/vendored/sqlite3-parser/src/lexer/sql/mod.rs +++ b/vendored/sqlite3-parser/src/lexer/sql/mod.rs @@ -441,7 +441,12 @@ impl Splitter for Tokenizer { // do not include the '?' in the token Ok((Some((&data[1..=i], TK_VARIABLE)), i + 1)) } - None => Ok((Some((&data[1..], TK_VARIABLE)), data.len())), + None => { + if !data[1..].is_empty() && data[1..].iter().all(|ch| *ch == b'0') { + return Err(Error::BadVariableName(None, None)); + } + Ok((Some((&data[1..], TK_VARIABLE)), data.len())) + } } } b'$' | b'@' | b'#' | b':' => {