From 08c8c655e91e2e87fdf40f9dc7200abf4217c313 Mon Sep 17 00:00:00 2001 From: "Levy A." Date: Tue, 14 Jan 2025 06:35:02 -0300 Subject: [PATCH 1/5] feat: initial implementation of `Statement::bind` --- core/error.rs | 4 ++ core/lib.rs | 10 +++- core/parameters.rs | 111 ++++++++++++++++++++++++++++++++++++++ core/translate/expr.rs | 9 +++- core/translate/mod.rs | 3 +- core/translate/planner.rs | 2 +- core/types.rs | 12 +++++ core/vdbe/builder.rs | 28 ++++++++++ core/vdbe/explain.rs | 9 ++++ core/vdbe/insn.rs | 10 ++++ core/vdbe/mod.rs | 22 ++++++++ test/src/lib.rs | 27 +++++++++- 12 files changed, 241 insertions(+), 6 deletions(-) create mode 100644 core/parameters.rs diff --git a/core/error.rs b/core/error.rs index e3e176b79..ca495eb99 100644 --- a/core/error.rs +++ b/core/error.rs @@ -1,3 +1,5 @@ +use std::num::NonZero; + use thiserror::Error; #[derive(Debug, Error, miette::Diagnostic)] @@ -41,6 +43,8 @@ pub enum LimboError { Constraint(String), #[error("Extension error: {0}")] ExtensionError(String), + #[error("Unbound parameter at index {0}")] + Unbound(NonZero), } #[macro_export] diff --git a/core/lib.rs b/core/lib.rs index 906a1675f..1020c495e 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -386,6 +386,7 @@ impl Connection { Rc::downgrade(self), syms, )?; + let mut state = vdbe::ProgramState::new(program.max_registers); program.step(&mut state, self.pager.clone())?; } @@ -473,7 +474,14 @@ impl Statement { Ok(Rows::new(stmt)) } - pub fn reset(&self) {} + pub fn reset(&self) { + self.state.reset(); + } + + pub fn bind(&mut self, value: Value) { + self.state.bind(value.into()); + } + } pub enum StepResult<'a> { diff --git a/core/parameters.rs b/core/parameters.rs new file mode 100644 index 000000000..9bfdf7f63 --- /dev/null +++ b/core/parameters.rs @@ -0,0 +1,111 @@ +use std::num::NonZero; + +#[derive(Clone, Debug)] +pub enum Parameter { + Anonymous(NonZero), + Indexed(NonZero), + Named(String, NonZero), +} + +impl PartialEq for Parameter { + fn eq(&self, other: &Self) -> bool { + self.index() == other.index() + } +} + +impl Parameter { + pub fn index(&self) -> NonZero { + match self { + Parameter::Anonymous(index) => *index, + Parameter::Indexed(index) => *index, + Parameter::Named(_, index) => *index, + } + } +} + +#[derive(Debug)] +pub struct Parameters { + index: NonZero, + pub list: Vec, +} + +impl Parameters { + pub fn new() -> Self { + Self { + index: 1.try_into().unwrap(), + list: vec![], + } + } + + pub fn count(&self) -> usize { + let mut params = self.list.clone(); + params.dedup(); + params.len() + } + + pub fn name(&self, index: NonZero) -> Option { + self.list.iter().find_map(|p| match p { + Parameter::Anonymous(i) if *i == index => Some("?".to_string()), + Parameter::Indexed(i) if *i == index => Some(format!("?{i}")), + Parameter::Named(name, i) if *i == index => Some(name.to_owned()), + _ => None, + }) + } + + pub fn index(&self, name: impl AsRef) -> Option> { + self.list + .iter() + .find_map(|p| match p { + Parameter::Named(n, index) if n == name.as_ref() => Some(index), + _ => None, + }) + .copied() + } + + pub fn next_index(&mut self) -> NonZero { + let index = self.index; + self.index = self.index.checked_add(1).unwrap(); + index + } + + pub fn push(&mut self, name: impl AsRef) -> NonZero { + match name.as_ref() { + "" => { + let index = self.next_index(); + self.list.push(Parameter::Anonymous(index)); + log::trace!("anonymous parameter at {index}"); + index + } + name if name.starts_with(&['$', ':', '@', '#']) => { + match self + .list + .iter() + .find(|p| matches!(p, Parameter::Named(n, _) if name == n)) + { + Some(t) => { + let index = t.index(); + self.list.push(t.clone()); + log::trace!("named parameter at {index} as {name}"); + index + } + None => { + let index = self.next_index(); + self.list.push(Parameter::Named(name.to_owned(), index)); + log::trace!("named parameter at {index} as {name}"); + index + } + } + } + index => { + // SAFETY: Garanteed from parser that the index is bigger that 0. + let index: NonZero = index.parse().unwrap(); + if index > self.index { + self.index = index.checked_add(1).unwrap(); + } + self.list.push(Parameter::Indexed(index)); + log::trace!("indexed parameter at {index}"); + index + } + } + } +} diff --git a/core/translate/expr.rs b/core/translate/expr.rs index d5cd7983e..bde068a20 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -1710,7 +1710,14 @@ pub fn translate_expr( } _ => todo!(), }, - ast::Expr::Variable(_) => todo!(), + ast::Expr::Variable(name) => { + let index = program.get_parameter_index(name); + program.emit_insn(Insn::Variable { + index, + dest: target_register, + }); + Ok(target_register) + } } } diff --git a/core/translate/mod.rs b/core/translate/mod.rs index 9a990eb36..37470cd56 100644 --- a/core/translate/mod.rs +++ b/core/translate/mod.rs @@ -32,8 +32,7 @@ use crate::vdbe::{builder::ProgramBuilder, insn::Insn, Program}; use crate::{bail_parse_error, Connection, LimboError, Result, SymbolTable}; use insert::translate_insert; use select::translate_select; -use sqlite3_parser::ast::fmt::ToTokens; -use sqlite3_parser::ast::{self, PragmaName}; +use sqlite3_parser::ast::{self, fmt::ToTokens, PragmaName}; use std::cell::RefCell; use std::fmt::Display; use std::rc::{Rc, Weak}; diff --git a/core/translate/planner.rs b/core/translate/planner.rs index d15a2fab6..8e2c10146 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -269,7 +269,7 @@ pub fn bind_column_references( bind_column_references(expr, referenced_tables)?; Ok(()) } - ast::Expr::Variable(_) => todo!(), + ast::Expr::Variable(_) => Ok(()), } } diff --git a/core/types.rs b/core/types.rs index 81d9d5328..c2bb0be9a 100644 --- a/core/types.rs +++ b/core/types.rs @@ -336,6 +336,18 @@ impl std::ops::DivAssign for OwnedValue { } } +impl From> for OwnedValue { + fn from(value: Value<'_>) -> Self { + match value { + Value::Null => OwnedValue::Null, + Value::Integer(i) => OwnedValue::Integer(i), + Value::Float(f) => OwnedValue::Float(f), + Value::Text(s) => OwnedValue::Text(LimboText::new(Rc::new(s.to_owned()))), + Value::Blob(b) => OwnedValue::Blob(Rc::new(b.to_owned())), + } + } +} + pub fn to_value(value: &OwnedValue) -> Value<'_> { match value { OwnedValue::Null => Value::Null, diff --git a/core/vdbe/builder.rs b/core/vdbe/builder.rs index 7acc4be6f..c796ee5f9 100644 --- a/core/vdbe/builder.rs +++ b/core/vdbe/builder.rs @@ -1,6 +1,7 @@ use std::{ cell::RefCell, collections::HashMap, + num::NonZero, rc::{Rc, Weak}, }; @@ -29,6 +30,8 @@ pub struct ProgramBuilder { seekrowid_emitted_bitmask: u64, // map of instruction index to manual comment (used in EXPLAIN) comments: HashMap, + named_parameters: HashMap>, + next_free_parameter_index: NonZero, } #[derive(Debug, Clone)] @@ -51,6 +54,7 @@ impl ProgramBuilder { next_free_register: 1, next_free_label: 0, next_free_cursor_id: 0, + next_free_parameter_index: 1.into(), insns: Vec::new(), next_insn_label: None, cursor_ref: Vec::new(), @@ -58,6 +62,7 @@ impl ProgramBuilder { label_to_resolved_offset: HashMap::new(), seekrowid_emitted_bitmask: 0, comments: HashMap::new(), + named_parameters: HashMap::new(), } } @@ -341,4 +346,27 @@ impl ProgramBuilder { auto_commit: true, } } + + fn next_parameter(&mut self) -> NonZero { + let index = self.next_free_parameter_index; + self.next_free_parameter_index.checked_add(1).unwrap(); + index + } + + pub fn get_parameter_index(&mut self, name: impl AsRef) -> NonZero { + let name = name.as_ref(); + + if name == "" { + return self.next_parameter(); + } + + match self.named_parameters.get(name) { + Some(index) => *index, + None => { + let index = self.next_parameter(); + self.named_parameters.insert(name.to_owned(), index); + index + } + } + } } diff --git a/core/vdbe/explain.rs b/core/vdbe/explain.rs index 7c3de8364..22b154809 100644 --- a/core/vdbe/explain.rs +++ b/core/vdbe/explain.rs @@ -1062,6 +1062,15 @@ pub fn insn_to_str( 0, format!("r[{}]=r[{}] << r[{}]", dest, lhs, rhs), ), + Insn::Variable { index, dest } => ( + "Variable", + usize::from(*index) as i32, + *dest as i32, + 0, + OwnedValue::build_text(Rc::new("".to_string())), + 0, + format!("r[{}]=parameter({})", *dest, *index), + ), }; format!( "{:<4} {:<17} {:<4} {:<4} {:<4} {:<13} {:<2} {}", diff --git a/core/vdbe/insn.rs b/core/vdbe/insn.rs index 225a9edaf..3066a399c 100644 --- a/core/vdbe/insn.rs +++ b/core/vdbe/insn.rs @@ -1,3 +1,5 @@ +use std::num::NonZero; + use super::{AggFunc, BranchOffset, CursorID, FuncCtx, PageIdx}; use crate::types::{OwnedRecord, OwnedValue}; use limbo_macros::Description; @@ -487,18 +489,26 @@ pub enum Insn { db: usize, where_clause: String, }, + // Place the result of lhs >> rhs in dest register. ShiftRight { lhs: usize, rhs: usize, dest: usize, }, + // Place the result of lhs << rhs in dest register. ShiftLeft { lhs: usize, rhs: usize, dest: usize, }, + + /// Get parameter variable. + Variable { + index: NonZero, + dest: usize, + }, } fn cast_text_to_numerical(value: &str) -> OwnedValue { diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index 01ddd55df..80732cf9d 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -55,6 +55,7 @@ use sorter::Sorter; use std::borrow::BorrowMut; use std::cell::RefCell; use std::collections::{BTreeMap, HashMap}; +use std::num::NonZero; use std::rc::{Rc, Weak}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -200,6 +201,7 @@ pub struct ProgramState { ended_coroutine: HashMap, // flag to indicate that a coroutine has ended (key is the yield register) regex_cache: RegexCache, interrupted: bool, + parameters: Vec, } impl ProgramState { @@ -222,6 +224,7 @@ impl ProgramState { ended_coroutine: HashMap::new(), regex_cache: RegexCache::new(), interrupted: false, + parameters: Vec::new(), } } @@ -240,6 +243,18 @@ impl ProgramState { pub fn is_interrupted(&self) -> bool { self.interrupted } + + pub fn bind(&mut self, value: OwnedValue) { + self.parameters.push(value); + } + + pub fn get_parameter(&self, index: NonZero) -> Option<&OwnedValue> { + self.parameters.get(usize::from(index) - 1) + } + + pub fn reset(&self) { + self.parameters.clear(); + } } macro_rules! must_be_btree_cursor { @@ -2182,6 +2197,13 @@ impl Program { exec_shift_left(&state.registers[*lhs], &state.registers[*rhs]); state.pc += 1; } + Insn::Variable { index, dest } => { + state.registers[*dest] = state + .get_parameter(*index) + .ok_or(LimboError::Unbound(*index))? + .clone(); + state.pc += 1; + } } } } diff --git a/test/src/lib.rs b/test/src/lib.rs index 9aa9116d5..1fa681ff9 100644 --- a/test/src/lib.rs +++ b/test/src/lib.rs @@ -40,7 +40,7 @@ impl TempDatabase { #[cfg(test)] mod tests { use super::*; - use limbo_core::{CheckpointStatus, Connection, StepResult, Value}; + use limbo_core::{CheckpointStatus, Connection, Rows, StepResult, Value}; use log::debug; #[ignore] @@ -572,4 +572,29 @@ mod tests { do_flush(&conn, &tmp_db)?; Ok(()) } + + #[test] + fn test_statement_bind() -> anyhow::Result<()> { + let _ = env_logger::try_init(); + let tmp_db = TempDatabase::new("CREATE TABLE test (x INTEGER PRIMARY KEY);"); + let conn = tmp_db.connect_limbo(); + let mut stmt = conn.prepare("select ?")?; + stmt.bind(Value::Text(&"hello".to_string())); + loop { + match stmt.step()? { + StepResult::Row(row) => { + if let Value::Text(s) = row.values[0] { + assert_eq!(s, "hello") + } + } + StepResult::IO => { + tmp_db.io.run_once()?; + } + StepResult::Interrupt => break, + StepResult::Done => break, + StepResult::Busy => panic!("Database is busy"), + }; + } + Ok(()) + } } From 6e0ce3dd01498ff3beac32851684e1f2c2e37c7d Mon Sep 17 00:00:00 2001 From: "Levy A." Date: Tue, 14 Jan 2025 06:39:11 -0300 Subject: [PATCH 2/5] chore: cargo fmt --- core/lib.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/core/lib.rs b/core/lib.rs index 1020c495e..a032dea1e 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -481,7 +481,6 @@ impl Statement { pub fn bind(&mut self, value: Value) { self.state.bind(value.into()); } - } pub enum StepResult<'a> { From d3582a382f6334413b42fb0b8db6f643e5a85b1a Mon Sep 17 00:00:00 2001 From: "Levy A." Date: Tue, 14 Jan 2025 06:49:17 -0300 Subject: [PATCH 3/5] fix: small bugs --- core/lib.rs | 2 +- core/vdbe/builder.rs | 2 +- core/vdbe/mod.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/lib.rs b/core/lib.rs index a032dea1e..a851e7281 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -474,7 +474,7 @@ impl Statement { Ok(Rows::new(stmt)) } - pub fn reset(&self) { + pub fn reset(&mut self) { self.state.reset(); } diff --git a/core/vdbe/builder.rs b/core/vdbe/builder.rs index c796ee5f9..796782640 100644 --- a/core/vdbe/builder.rs +++ b/core/vdbe/builder.rs @@ -54,7 +54,7 @@ impl ProgramBuilder { next_free_register: 1, next_free_label: 0, next_free_cursor_id: 0, - next_free_parameter_index: 1.into(), + next_free_parameter_index: 1.try_into().unwrap(), insns: Vec::new(), next_insn_label: None, cursor_ref: Vec::new(), diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index 80732cf9d..50db5538a 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -252,7 +252,7 @@ impl ProgramState { self.parameters.get(usize::from(index) - 1) } - pub fn reset(&self) { + pub fn reset(&mut self) { self.parameters.clear(); } } From 5de2694834a55c5fa0f4f898e9680296cfa1c513 Mon Sep 17 00:00:00 2001 From: "Levy A." Date: Tue, 14 Jan 2025 22:15:24 -0300 Subject: [PATCH 4/5] feat: more parameter support add `Statement::{parameter_index, parameter_name, parameter_count, bind_at}`. some refactoring is still needed, this is quite a rough iteration --- core/lib.rs | 21 ++++++-- core/translate/expr.rs | 4 +- core/translate/mod.rs | 119 ++++++++++++++++++++++++++++++++++++++++- core/vdbe/builder.rs | 33 ++++-------- core/vdbe/mod.rs | 43 +++++++++++++-- test/src/lib.rs | 30 +++++++++-- 6 files changed, 212 insertions(+), 38 deletions(-) diff --git a/core/lib.rs b/core/lib.rs index a851e7281..ab35d7711 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -28,6 +28,7 @@ use sqlite3_parser::ast; use sqlite3_parser::{ast::Cmd, lexer::sql::Parser}; use std::cell::Cell; use std::collections::HashMap; +use std::num::NonZero; use std::sync::{Arc, OnceLock, RwLock}; use std::{cell::RefCell, rc::Rc}; use storage::btree::btree_init_page; @@ -474,12 +475,24 @@ impl Statement { Ok(Rows::new(stmt)) } - pub fn reset(&mut self) { - self.state.reset(); + pub fn parameter_count(&mut self) -> usize { + self.program.parameter_count() } - pub fn bind(&mut self, value: Value) { - self.state.bind(value.into()); + pub fn parameter_name(&self, index: NonZero) -> Option { + self.program.parameter_name(index) + } + + pub fn parameter_index(&self, name: impl AsRef) -> Option> { + self.program.parameter_index(name) + } + + pub fn bind_at(&mut self, index: NonZero, value: Value) { + self.state.bind_at(index, value.into()); + } + + pub fn reset(&mut self) { + self.state.reset(); } } diff --git a/core/translate/expr.rs b/core/translate/expr.rs index bde068a20..3bafa1b1c 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -1710,8 +1710,8 @@ pub fn translate_expr( } _ => todo!(), }, - ast::Expr::Variable(name) => { - let index = program.get_parameter_index(name); + ast::Expr::Variable(_) => { + let index = program.pop_index(); program.emit_insn(Insn::Variable { index, dest: target_register, diff --git a/core/translate/mod.rs b/core/translate/mod.rs index 37470cd56..1f2dc2737 100644 --- a/core/translate/mod.rs +++ b/core/translate/mod.rs @@ -32,12 +32,121 @@ use crate::vdbe::{builder::ProgramBuilder, insn::Insn, Program}; use crate::{bail_parse_error, Connection, LimboError, Result, SymbolTable}; use insert::translate_insert; use select::translate_select; +use sqlite3_parser::ast::fmt::TokenStream; use sqlite3_parser::ast::{self, fmt::ToTokens, PragmaName}; +use sqlite3_parser::dialect::TokenType; use std::cell::RefCell; use std::fmt::Display; +use std::num::NonZero; use std::rc::{Rc, Weak}; use std::str::FromStr; +#[derive(Clone, Debug)] +pub enum Parameter { + Anonymous(NonZero), + Indexed(NonZero), + Named(String, NonZero), +} + +impl PartialEq for Parameter { + fn eq(&self, other: &Self) -> bool { + self.index() == other.index() + } +} + +impl Parameter { + pub fn index(&self) -> NonZero { + match self { + Parameter::Anonymous(index) => *index, + Parameter::Indexed(index) => *index, + Parameter::Named(_, index) => *index, + } + } +} + +/// `?` or `$` Prepared statement arg placeholder(s) +#[derive(Debug)] +pub struct Parameters { + index: NonZero, + pub list: Vec, +} + +impl Parameters { + pub fn new() -> Self { + Self { + index: 1.try_into().unwrap(), + list: vec![], + } + } + + pub fn push(&mut self, value: Parameter) { + self.list.push(value); + } + + pub fn next_index(&mut self) -> NonZero { + let index = self.index; + self.index = self.index.checked_add(1).unwrap(); + index + } + + pub fn get(&mut self, index: usize) -> Option<&Parameter> { + self.list.get(index) + } +} + +// https://sqlite.org/lang_expr.html#parameters +impl TokenStream for Parameters { + type Error = std::convert::Infallible; + + fn append( + &mut self, + ty: TokenType, + value: Option<&str>, + ) -> std::result::Result<(), Self::Error> { + if ty == TokenType::TK_VARIABLE { + if let Some(variable) = value { + match variable.split_at(1) { + ("?", "") => { + let index = self.next_index(); + self.push(Parameter::Anonymous(index.try_into().unwrap())); + log::trace!("anonymous parameter at {index}"); + } + ("?", index) => { + let index: NonZero = index.parse().unwrap(); + if index > self.index { + self.index = index.checked_add(1).unwrap(); + } + self.push(Parameter::Indexed(index.try_into().unwrap())); + log::trace!("indexed parameter at {index}"); + } + (_, _) => { + match self.list.iter().find(|p| { + let Parameter::Named(name, _) = p else { + return false; + }; + name == variable + }) { + Some(t) => { + log::trace!("named parameter at {} as {}", t.index(), variable); + self.push(t.clone()); + } + None => { + let index = self.next_index(); + self.push(Parameter::Named( + variable.to_owned(), + index.try_into().unwrap(), + )); + log::trace!("named parameter at {index} as {variable}"); + } + } + } + } + } + } + Ok(()) + } +} + /// Translate SQL statement into bytecode program. pub fn translate( schema: &Schema, @@ -47,7 +156,14 @@ pub fn translate( connection: Weak, syms: &SymbolTable, ) -> Result { - let mut program = ProgramBuilder::new(); + let mut parameters = Parameters::new(); + stmt.to_tokens(&mut parameters).unwrap(); + + // dbg!(¶meters); + // dbg!(¶meters.list.clone().dedup()); + + let mut program = ProgramBuilder::new(parameters); + match stmt { ast::Stmt::AlterTable(_, _) => bail_parse_error!("ALTER TABLE not supported yet"), ast::Stmt::Analyze(_) => bail_parse_error!("ANALYZE not supported yet"), @@ -118,6 +234,7 @@ pub fn translate( )?; } } + Ok(program.build(database_header, connection)) } diff --git a/core/vdbe/builder.rs b/core/vdbe/builder.rs index 796782640..711b506a0 100644 --- a/core/vdbe/builder.rs +++ b/core/vdbe/builder.rs @@ -32,6 +32,8 @@ pub struct ProgramBuilder { comments: HashMap, named_parameters: HashMap>, next_free_parameter_index: NonZero, + parameters: crate::translate::Parameters, + parameter_index: usize, } #[derive(Debug, Clone)] @@ -49,7 +51,7 @@ impl CursorType { } impl ProgramBuilder { - pub fn new() -> Self { + pub fn new(parameters: crate::translate::Parameters) -> Self { Self { next_free_register: 1, next_free_label: 0, @@ -63,6 +65,8 @@ impl ProgramBuilder { seekrowid_emitted_bitmask: 0, comments: HashMap::new(), named_parameters: HashMap::new(), + parameters, + parameter_index: 0, } } @@ -336,6 +340,7 @@ impl ProgramBuilder { self.constant_insns.is_empty(), "constant_insns is not empty when build() is called, did you forget to call emit_constant_insns()?" ); + self.parameters.list.dedup(); Program { max_registers: self.next_free_register, insns: self.insns, @@ -344,29 +349,13 @@ impl ProgramBuilder { comments: self.comments, connection, auto_commit: true, + parameters: self.parameters.list, } } - fn next_parameter(&mut self) -> NonZero { - let index = self.next_free_parameter_index; - self.next_free_parameter_index.checked_add(1).unwrap(); - index - } - - pub fn get_parameter_index(&mut self, name: impl AsRef) -> NonZero { - let name = name.as_ref(); - - if name == "" { - return self.next_parameter(); - } - - match self.named_parameters.get(name) { - Some(index) => *index, - None => { - let index = self.next_parameter(); - self.named_parameters.insert(name.to_owned(), index); - index - } - } + pub fn pop_index(&mut self) -> NonZero { + let parameter = self.parameters.get(self.parameter_index).unwrap(); + self.parameter_index += 1; + return parameter.index(); } } diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index 50db5538a..5188a2329 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -201,7 +201,7 @@ pub struct ProgramState { ended_coroutine: HashMap, // flag to indicate that a coroutine has ended (key is the yield register) regex_cache: RegexCache, interrupted: bool, - parameters: Vec, + parameters: HashMap, OwnedValue>, } impl ProgramState { @@ -224,7 +224,7 @@ impl ProgramState { ended_coroutine: HashMap::new(), regex_cache: RegexCache::new(), interrupted: false, - parameters: Vec::new(), + parameters: HashMap::new(), } } @@ -244,12 +244,12 @@ impl ProgramState { self.interrupted } - pub fn bind(&mut self, value: OwnedValue) { - self.parameters.push(value); + pub fn bind_at(&mut self, index: NonZero, value: OwnedValue) { + self.parameters.insert(index, value); } pub fn get_parameter(&self, index: NonZero) -> Option<&OwnedValue> { - self.parameters.get(usize::from(index) - 1) + self.parameters.get(&index) } pub fn reset(&mut self) { @@ -277,6 +277,7 @@ pub struct Program { pub cursor_ref: Vec<(Option, CursorType)>, pub database_header: Rc>, pub comments: HashMap, + pub parameters: Vec, pub connection: Weak, pub auto_commit: bool, } @@ -300,6 +301,38 @@ impl Program { } } + pub fn parameter_count(&self) -> usize { + self.parameters.len() + } + + pub fn parameter_name(&self, index: NonZero) -> Option { + use crate::translate::Parameter; + self.parameters.iter().find_map(|p| match p { + Parameter::Anonymous(i) if *i == index => Some("?".to_string()), + Parameter::Indexed(i) if *i == index => Some(format!("?{i}")), + Parameter::Named(name, i) if *i == index => Some(name.to_owned()), + _ => None, + }) + } + + pub fn parameter_index(&self, name: impl AsRef) -> Option> { + use crate::translate::Parameter; + self.parameters + .iter() + .find_map(|p| { + let Parameter::Named(parameter_name, index) = p else { + return None; + }; + + if name.as_ref() == parameter_name { + return Some(index); + } + + None + }) + .copied() + } + pub fn step<'a>( &self, state: &'a mut ProgramState, diff --git a/test/src/lib.rs b/test/src/lib.rs index 1fa681ff9..68b4b10e5 100644 --- a/test/src/lib.rs +++ b/test/src/lib.rs @@ -40,7 +40,7 @@ impl TempDatabase { #[cfg(test)] mod tests { use super::*; - use limbo_core::{CheckpointStatus, Connection, Rows, StepResult, Value}; + use limbo_core::{CheckpointStatus, Connection, StepResult, Value}; use log::debug; #[ignore] @@ -576,16 +576,38 @@ mod tests { #[test] fn test_statement_bind() -> anyhow::Result<()> { let _ = env_logger::try_init(); - let tmp_db = TempDatabase::new("CREATE TABLE test (x INTEGER PRIMARY KEY);"); + let tmp_db = TempDatabase::new("create table test (i integer);"); let conn = tmp_db.connect_limbo(); - let mut stmt = conn.prepare("select ?")?; - stmt.bind(Value::Text(&"hello".to_string())); + + let mut stmt = conn.prepare("select ?, ?1, :named, ?4")?; + + stmt.bind_at(1.try_into().unwrap(), Value::Text(&"hello".to_string())); + + let i = stmt.parameter_index(":named").unwrap(); + stmt.bind_at(i, Value::Integer(42)); + + stmt.bind_at(4.try_into().unwrap(), Value::Float(0.5)); + + assert_eq!(stmt.parameter_count(), 3); + loop { match stmt.step()? { StepResult::Row(row) => { if let Value::Text(s) = row.values[0] { assert_eq!(s, "hello") } + + if let Value::Text(s) = row.values[1] { + assert_eq!(s, "hello") + } + + if let Value::Integer(s) = row.values[2] { + assert_eq!(s, 42) + } + + if let Value::Float(s) = row.values[3] { + assert_eq!(s, 0.5) + } } StepResult::IO => { tmp_db.io.run_once()?; From 9b8722f38e7c88d507d957f5908e8b035c706af4 Mon Sep 17 00:00:00 2001 From: "Levy A." Date: Wed, 15 Jan 2025 16:33:33 -0300 Subject: [PATCH 5/5] refactor: more well rounded implementation `?0` parameters are now handled by the parser. --- core/lib.rs | 18 +-- core/translate/expr.rs | 4 +- core/translate/mod.rs | 117 +------------------ core/vdbe/builder.rs | 23 +--- core/vdbe/mod.rs | 34 +----- test/src/lib.rs | 95 +++++++++++++-- vendored/sqlite3-parser/src/lexer/sql/mod.rs | 7 +- 7 files changed, 109 insertions(+), 189 deletions(-) diff --git a/core/lib.rs b/core/lib.rs index ab35d7711..04caf1c71 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -4,6 +4,7 @@ mod function; mod io; #[cfg(feature = "json")] mod json; +mod parameters; mod pseudo; mod result; mod schema; @@ -44,7 +45,7 @@ use util::parse_schema_rows; pub use error::LimboError; use translate::select::prepare_select_plan; -pub type Result = std::result::Result; +pub type Result = std::result::Result; use crate::translate::optimizer::optimize_plan; pub use io::OpenFlags; @@ -475,16 +476,8 @@ impl Statement { Ok(Rows::new(stmt)) } - pub fn parameter_count(&mut self) -> usize { - self.program.parameter_count() - } - - pub fn parameter_name(&self, index: NonZero) -> Option { - self.program.parameter_name(index) - } - - pub fn parameter_index(&self, name: impl AsRef) -> Option> { - self.program.parameter_index(name) + pub fn parameters(&self) -> ¶meters::Parameters { + &self.program.parameters } pub fn bind_at(&mut self, index: NonZero, value: Value) { @@ -492,7 +485,8 @@ impl Statement { } pub fn reset(&mut self) { - self.state.reset(); + let state = vdbe::ProgramState::new(self.program.max_registers); + self.state = state } } diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 3bafa1b1c..cd9b6b28c 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -1710,8 +1710,8 @@ pub fn translate_expr( } _ => todo!(), }, - ast::Expr::Variable(_) => { - let index = program.pop_index(); + ast::Expr::Variable(name) => { + let index = program.parameters.push(name); program.emit_insn(Insn::Variable { index, dest: target_register, diff --git a/core/translate/mod.rs b/core/translate/mod.rs index 1f2dc2737..20a514e5d 100644 --- a/core/translate/mod.rs +++ b/core/translate/mod.rs @@ -32,121 +32,12 @@ use crate::vdbe::{builder::ProgramBuilder, insn::Insn, Program}; use crate::{bail_parse_error, Connection, LimboError, Result, SymbolTable}; use insert::translate_insert; use select::translate_select; -use sqlite3_parser::ast::fmt::TokenStream; use sqlite3_parser::ast::{self, fmt::ToTokens, PragmaName}; -use sqlite3_parser::dialect::TokenType; use std::cell::RefCell; use std::fmt::Display; -use std::num::NonZero; use std::rc::{Rc, Weak}; use std::str::FromStr; -#[derive(Clone, Debug)] -pub enum Parameter { - Anonymous(NonZero), - Indexed(NonZero), - Named(String, NonZero), -} - -impl PartialEq for Parameter { - fn eq(&self, other: &Self) -> bool { - self.index() == other.index() - } -} - -impl Parameter { - pub fn index(&self) -> NonZero { - match self { - Parameter::Anonymous(index) => *index, - Parameter::Indexed(index) => *index, - Parameter::Named(_, index) => *index, - } - } -} - -/// `?` or `$` Prepared statement arg placeholder(s) -#[derive(Debug)] -pub struct Parameters { - index: NonZero, - pub list: Vec, -} - -impl Parameters { - pub fn new() -> Self { - Self { - index: 1.try_into().unwrap(), - list: vec![], - } - } - - pub fn push(&mut self, value: Parameter) { - self.list.push(value); - } - - pub fn next_index(&mut self) -> NonZero { - let index = self.index; - self.index = self.index.checked_add(1).unwrap(); - index - } - - pub fn get(&mut self, index: usize) -> Option<&Parameter> { - self.list.get(index) - } -} - -// https://sqlite.org/lang_expr.html#parameters -impl TokenStream for Parameters { - type Error = std::convert::Infallible; - - fn append( - &mut self, - ty: TokenType, - value: Option<&str>, - ) -> std::result::Result<(), Self::Error> { - if ty == TokenType::TK_VARIABLE { - if let Some(variable) = value { - match variable.split_at(1) { - ("?", "") => { - let index = self.next_index(); - self.push(Parameter::Anonymous(index.try_into().unwrap())); - log::trace!("anonymous parameter at {index}"); - } - ("?", index) => { - let index: NonZero = index.parse().unwrap(); - if index > self.index { - self.index = index.checked_add(1).unwrap(); - } - self.push(Parameter::Indexed(index.try_into().unwrap())); - log::trace!("indexed parameter at {index}"); - } - (_, _) => { - match self.list.iter().find(|p| { - let Parameter::Named(name, _) = p else { - return false; - }; - name == variable - }) { - Some(t) => { - log::trace!("named parameter at {} as {}", t.index(), variable); - self.push(t.clone()); - } - None => { - let index = self.next_index(); - self.push(Parameter::Named( - variable.to_owned(), - index.try_into().unwrap(), - )); - log::trace!("named parameter at {index} as {variable}"); - } - } - } - } - } - } - Ok(()) - } -} - /// Translate SQL statement into bytecode program. pub fn translate( schema: &Schema, @@ -156,13 +47,7 @@ pub fn translate( connection: Weak, syms: &SymbolTable, ) -> Result { - let mut parameters = Parameters::new(); - stmt.to_tokens(&mut parameters).unwrap(); - - // dbg!(¶meters); - // dbg!(¶meters.list.clone().dedup()); - - let mut program = ProgramBuilder::new(parameters); + let mut program = ProgramBuilder::new(); match stmt { ast::Stmt::AlterTable(_, _) => bail_parse_error!("ALTER TABLE not supported yet"), diff --git a/core/vdbe/builder.rs b/core/vdbe/builder.rs index 711b506a0..ab643c105 100644 --- a/core/vdbe/builder.rs +++ b/core/vdbe/builder.rs @@ -1,18 +1,17 @@ use std::{ cell::RefCell, collections::HashMap, - num::NonZero, rc::{Rc, Weak}, }; use crate::{ + parameters::Parameters, schema::{BTreeTable, Index, PseudoTable}, storage::sqlite3_ondisk::DatabaseHeader, Connection, }; use super::{BranchOffset, CursorID, Insn, InsnReference, Program}; - #[allow(dead_code)] pub struct ProgramBuilder { next_free_register: usize, @@ -30,10 +29,7 @@ pub struct ProgramBuilder { seekrowid_emitted_bitmask: u64, // map of instruction index to manual comment (used in EXPLAIN) comments: HashMap, - named_parameters: HashMap>, - next_free_parameter_index: NonZero, - parameters: crate::translate::Parameters, - parameter_index: usize, + pub parameters: Parameters, } #[derive(Debug, Clone)] @@ -51,12 +47,11 @@ impl CursorType { } impl ProgramBuilder { - pub fn new(parameters: crate::translate::Parameters) -> Self { + pub fn new() -> Self { Self { next_free_register: 1, next_free_label: 0, next_free_cursor_id: 0, - next_free_parameter_index: 1.try_into().unwrap(), insns: Vec::new(), next_insn_label: None, cursor_ref: Vec::new(), @@ -64,9 +59,7 @@ impl ProgramBuilder { label_to_resolved_offset: HashMap::new(), seekrowid_emitted_bitmask: 0, comments: HashMap::new(), - named_parameters: HashMap::new(), - parameters, - parameter_index: 0, + parameters: Parameters::new(), } } @@ -349,13 +342,7 @@ impl ProgramBuilder { comments: self.comments, connection, auto_commit: true, - parameters: self.parameters.list, + parameters: self.parameters, } } - - pub fn pop_index(&mut self) -> NonZero { - let parameter = self.parameters.get(self.parameter_index).unwrap(); - self.parameter_index += 1; - return parameter.index(); - } } diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index 5188a2329..43160eb40 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -277,7 +277,7 @@ pub struct Program { pub cursor_ref: Vec<(Option, CursorType)>, pub database_header: Rc>, pub comments: HashMap, - pub parameters: Vec, + pub parameters: crate::parameters::Parameters, pub connection: Weak, pub auto_commit: bool, } @@ -301,38 +301,6 @@ impl Program { } } - pub fn parameter_count(&self) -> usize { - self.parameters.len() - } - - pub fn parameter_name(&self, index: NonZero) -> Option { - use crate::translate::Parameter; - self.parameters.iter().find_map(|p| match p { - Parameter::Anonymous(i) if *i == index => Some("?".to_string()), - Parameter::Indexed(i) if *i == index => Some(format!("?{i}")), - Parameter::Named(name, i) if *i == index => Some(name.to_owned()), - _ => None, - }) - } - - pub fn parameter_index(&self, name: impl AsRef) -> Option> { - use crate::translate::Parameter; - self.parameters - .iter() - .find_map(|p| { - let Parameter::Named(parameter_name, index) = p else { - return None; - }; - - if name.as_ref() == parameter_name { - return Some(index); - } - - None - }) - .copied() - } - pub fn step<'a>( &self, state: &'a mut ProgramState, diff --git a/test/src/lib.rs b/test/src/lib.rs index 68b4b10e5..e9a368e07 100644 --- a/test/src/lib.rs +++ b/test/src/lib.rs @@ -573,22 +573,99 @@ mod tests { Ok(()) } + #[test] + fn test_statement_reset() -> anyhow::Result<()> { + let _ = env_logger::try_init(); + let tmp_db = TempDatabase::new("create table test (i integer);"); + let conn = tmp_db.connect_limbo(); + + conn.execute("insert into test values (1)")?; + conn.execute("insert into test values (2)")?; + + let mut stmt = conn.prepare("select * from test")?; + + loop { + match stmt.step()? { + StepResult::Row(row) => { + assert_eq!(row.values[0], Value::Integer(1)); + break; + } + StepResult::IO => tmp_db.io.run_once()?, + _ => break, + } + } + + stmt.reset(); + + loop { + match stmt.step()? { + StepResult::Row(row) => { + assert_eq!(row.values[0], Value::Integer(1)); + break; + } + StepResult::IO => tmp_db.io.run_once()?, + _ => break, + } + } + + Ok(()) + } + + #[test] + fn test_statement_reset_bind() -> anyhow::Result<()> { + let _ = env_logger::try_init(); + let tmp_db = TempDatabase::new("create table test (i integer);"); + let conn = tmp_db.connect_limbo(); + + let mut stmt = conn.prepare("select ?")?; + + stmt.bind_at(1.try_into().unwrap(), Value::Integer(1)); + + loop { + match stmt.step()? { + StepResult::Row(row) => { + assert_eq!(row.values[0], Value::Integer(1)); + } + StepResult::IO => tmp_db.io.run_once()?, + _ => break, + } + } + + stmt.reset(); + + stmt.bind_at(1.try_into().unwrap(), Value::Integer(2)); + + loop { + match stmt.step()? { + StepResult::Row(row) => { + assert_eq!(row.values[0], Value::Integer(2)); + } + StepResult::IO => tmp_db.io.run_once()?, + _ => break, + } + } + + Ok(()) + } + #[test] fn test_statement_bind() -> anyhow::Result<()> { let _ = env_logger::try_init(); let tmp_db = TempDatabase::new("create table test (i integer);"); let conn = tmp_db.connect_limbo(); - let mut stmt = conn.prepare("select ?, ?1, :named, ?4")?; + let mut stmt = conn.prepare("select ?, ?1, :named, ?3, ?4")?; stmt.bind_at(1.try_into().unwrap(), Value::Text(&"hello".to_string())); - let i = stmt.parameter_index(":named").unwrap(); + let i = stmt.parameters().index(":named").unwrap(); stmt.bind_at(i, Value::Integer(42)); + stmt.bind_at(3.try_into().unwrap(), Value::Blob(&vec![0x1, 0x2, 0x3])); + stmt.bind_at(4.try_into().unwrap(), Value::Float(0.5)); - assert_eq!(stmt.parameter_count(), 3); + assert_eq!(stmt.parameters().count(), 4); loop { match stmt.step()? { @@ -601,12 +678,16 @@ mod tests { assert_eq!(s, "hello") } - if let Value::Integer(s) = row.values[2] { - assert_eq!(s, 42) + if let Value::Integer(i) = row.values[2] { + assert_eq!(i, 42) } - if let Value::Float(s) = row.values[3] { - assert_eq!(s, 0.5) + if let Value::Blob(v) = row.values[3] { + assert_eq!(v, &vec![0x1 as u8, 0x2, 0x3]) + } + + if let Value::Float(f) = row.values[4] { + assert_eq!(f, 0.5) } } StepResult::IO => { diff --git a/vendored/sqlite3-parser/src/lexer/sql/mod.rs b/vendored/sqlite3-parser/src/lexer/sql/mod.rs index 72cca0b97..fa98282cc 100644 --- a/vendored/sqlite3-parser/src/lexer/sql/mod.rs +++ b/vendored/sqlite3-parser/src/lexer/sql/mod.rs @@ -441,7 +441,12 @@ impl Splitter for Tokenizer { // do not include the '?' in the token Ok((Some((&data[1..=i], TK_VARIABLE)), i + 1)) } - None => Ok((Some((&data[1..], TK_VARIABLE)), data.len())), + None => { + if !data[1..].is_empty() && data[1..].iter().all(|ch| *ch == b'0') { + return Err(Error::BadVariableName(None, None)); + } + Ok((Some((&data[1..], TK_VARIABLE)), data.len())) + } } } b'$' | b'@' | b'#' | b':' => {