Merge 'Initial implementation of Statement::bind_at' from Levy A.

Resolves #607.
- [x] Index parameters.
- [x] Named parameters.
- [x] Parameter count.
- [ ] More tests.
- [ ] Expose to Sqlite3 API.

Closes #675
This commit is contained in:
Pekka Enberg
2025-01-15 22:52:56 +02:00
13 changed files with 336 additions and 8 deletions

View File

@@ -1,3 +1,5 @@
use std::num::NonZero;
use thiserror::Error;
#[derive(Debug, Error, miette::Diagnostic)]
@@ -41,6 +43,8 @@ pub enum LimboError {
Constraint(String),
#[error("Extension error: {0}")]
ExtensionError(String),
#[error("Unbound parameter at index {0}")]
Unbound(NonZero<usize>),
}
#[macro_export]

View File

@@ -4,6 +4,7 @@ mod function;
mod io;
#[cfg(feature = "json")]
mod json;
mod parameters;
mod pseudo;
mod result;
mod schema;
@@ -28,6 +29,7 @@ use sqlite3_parser::ast;
use sqlite3_parser::{ast::Cmd, lexer::sql::Parser};
use std::cell::Cell;
use std::collections::HashMap;
use std::num::NonZero;
use std::sync::{Arc, OnceLock, RwLock};
use std::{cell::RefCell, rc::Rc};
use storage::btree::btree_init_page;
@@ -43,7 +45,7 @@ use util::parse_schema_rows;
pub use error::LimboError;
use translate::select::prepare_select_plan;
pub type Result<T> = std::result::Result<T, error::LimboError>;
pub type Result<T, E = error::LimboError> = std::result::Result<T, E>;
use crate::translate::optimizer::optimize_plan;
pub use io::OpenFlags;
@@ -386,6 +388,7 @@ impl Connection {
Rc::downgrade(self),
syms,
)?;
let mut state = vdbe::ProgramState::new(program.max_registers);
program.step(&mut state, self.pager.clone())?;
}
@@ -473,7 +476,18 @@ impl Statement {
Ok(Rows::new(stmt))
}
pub fn reset(&self) {}
pub fn parameters(&self) -> &parameters::Parameters {
&self.program.parameters
}
pub fn bind_at(&mut self, index: NonZero<usize>, value: Value) {
self.state.bind_at(index, value.into());
}
pub fn reset(&mut self) {
let state = vdbe::ProgramState::new(self.program.max_registers);
self.state = state
}
}
pub enum StepResult<'a> {

111
core/parameters.rs Normal file
View File

@@ -0,0 +1,111 @@
use std::num::NonZero;
#[derive(Clone, Debug)]
pub enum Parameter {
Anonymous(NonZero<usize>),
Indexed(NonZero<usize>),
Named(String, NonZero<usize>),
}
impl PartialEq for Parameter {
fn eq(&self, other: &Self) -> bool {
self.index() == other.index()
}
}
impl Parameter {
pub fn index(&self) -> NonZero<usize> {
match self {
Parameter::Anonymous(index) => *index,
Parameter::Indexed(index) => *index,
Parameter::Named(_, index) => *index,
}
}
}
#[derive(Debug)]
pub struct Parameters {
index: NonZero<usize>,
pub list: Vec<Parameter>,
}
impl Parameters {
pub fn new() -> Self {
Self {
index: 1.try_into().unwrap(),
list: vec![],
}
}
pub fn count(&self) -> usize {
let mut params = self.list.clone();
params.dedup();
params.len()
}
pub fn name(&self, index: NonZero<usize>) -> Option<String> {
self.list.iter().find_map(|p| match p {
Parameter::Anonymous(i) if *i == index => Some("?".to_string()),
Parameter::Indexed(i) if *i == index => Some(format!("?{i}")),
Parameter::Named(name, i) if *i == index => Some(name.to_owned()),
_ => None,
})
}
pub fn index(&self, name: impl AsRef<str>) -> Option<NonZero<usize>> {
self.list
.iter()
.find_map(|p| match p {
Parameter::Named(n, index) if n == name.as_ref() => Some(index),
_ => None,
})
.copied()
}
pub fn next_index(&mut self) -> NonZero<usize> {
let index = self.index;
self.index = self.index.checked_add(1).unwrap();
index
}
pub fn push(&mut self, name: impl AsRef<str>) -> NonZero<usize> {
match name.as_ref() {
"" => {
let index = self.next_index();
self.list.push(Parameter::Anonymous(index));
log::trace!("anonymous parameter at {index}");
index
}
name if name.starts_with(&['$', ':', '@', '#']) => {
match self
.list
.iter()
.find(|p| matches!(p, Parameter::Named(n, _) if name == n))
{
Some(t) => {
let index = t.index();
self.list.push(t.clone());
log::trace!("named parameter at {index} as {name}");
index
}
None => {
let index = self.next_index();
self.list.push(Parameter::Named(name.to_owned(), index));
log::trace!("named parameter at {index} as {name}");
index
}
}
}
index => {
// SAFETY: Garanteed from parser that the index is bigger that 0.
let index: NonZero<usize> = index.parse().unwrap();
if index > self.index {
self.index = index.checked_add(1).unwrap();
}
self.list.push(Parameter::Indexed(index));
log::trace!("indexed parameter at {index}");
index
}
}
}
}

View File

@@ -1710,7 +1710,14 @@ pub fn translate_expr(
}
_ => todo!(),
},
ast::Expr::Variable(_) => todo!(),
ast::Expr::Variable(name) => {
let index = program.parameters.push(name);
program.emit_insn(Insn::Variable {
index,
dest: target_register,
});
Ok(target_register)
}
}
}

View File

@@ -32,8 +32,7 @@ use crate::vdbe::{builder::ProgramBuilder, insn::Insn, Program};
use crate::{bail_parse_error, Connection, LimboError, Result, SymbolTable};
use insert::translate_insert;
use select::translate_select;
use sqlite3_parser::ast::fmt::ToTokens;
use sqlite3_parser::ast::{self, PragmaName};
use sqlite3_parser::ast::{self, fmt::ToTokens, PragmaName};
use std::cell::RefCell;
use std::fmt::Display;
use std::rc::{Rc, Weak};
@@ -49,6 +48,7 @@ pub fn translate(
syms: &SymbolTable,
) -> Result<Program> {
let mut program = ProgramBuilder::new();
match stmt {
ast::Stmt::AlterTable(_, _) => bail_parse_error!("ALTER TABLE not supported yet"),
ast::Stmt::Analyze(_) => bail_parse_error!("ANALYZE not supported yet"),
@@ -119,6 +119,7 @@ pub fn translate(
)?;
}
}
Ok(program.build(database_header, connection))
}

View File

@@ -269,7 +269,7 @@ pub fn bind_column_references(
bind_column_references(expr, referenced_tables)?;
Ok(())
}
ast::Expr::Variable(_) => todo!(),
ast::Expr::Variable(_) => Ok(()),
}
}

View File

@@ -336,6 +336,18 @@ impl std::ops::DivAssign<OwnedValue> for OwnedValue {
}
}
impl From<Value<'_>> for OwnedValue {
fn from(value: Value<'_>) -> Self {
match value {
Value::Null => OwnedValue::Null,
Value::Integer(i) => OwnedValue::Integer(i),
Value::Float(f) => OwnedValue::Float(f),
Value::Text(s) => OwnedValue::Text(LimboText::new(Rc::new(s.to_owned()))),
Value::Blob(b) => OwnedValue::Blob(Rc::new(b.to_owned())),
}
}
}
pub fn to_value(value: &OwnedValue) -> Value<'_> {
match value {
OwnedValue::Null => Value::Null,

View File

@@ -5,13 +5,13 @@ use std::{
};
use crate::{
parameters::Parameters,
schema::{BTreeTable, Index, PseudoTable},
storage::sqlite3_ondisk::DatabaseHeader,
Connection,
};
use super::{BranchOffset, CursorID, Insn, InsnReference, Program};
#[allow(dead_code)]
pub struct ProgramBuilder {
next_free_register: usize,
@@ -29,6 +29,7 @@ pub struct ProgramBuilder {
seekrowid_emitted_bitmask: u64,
// map of instruction index to manual comment (used in EXPLAIN)
comments: HashMap<InsnReference, &'static str>,
pub parameters: Parameters,
}
#[derive(Debug, Clone)]
@@ -58,6 +59,7 @@ impl ProgramBuilder {
label_to_resolved_offset: HashMap::new(),
seekrowid_emitted_bitmask: 0,
comments: HashMap::new(),
parameters: Parameters::new(),
}
}
@@ -331,6 +333,7 @@ impl ProgramBuilder {
self.constant_insns.is_empty(),
"constant_insns is not empty when build() is called, did you forget to call emit_constant_insns()?"
);
self.parameters.list.dedup();
Program {
max_registers: self.next_free_register,
insns: self.insns,
@@ -339,6 +342,7 @@ impl ProgramBuilder {
comments: self.comments,
connection,
auto_commit: true,
parameters: self.parameters,
}
}
}

View File

@@ -1062,6 +1062,15 @@ pub fn insn_to_str(
0,
format!("r[{}]=r[{}] << r[{}]", dest, lhs, rhs),
),
Insn::Variable { index, dest } => (
"Variable",
usize::from(*index) as i32,
*dest as i32,
0,
OwnedValue::build_text(Rc::new("".to_string())),
0,
format!("r[{}]=parameter({})", *dest, *index),
),
};
format!(
"{:<4} {:<17} {:<4} {:<4} {:<4} {:<13} {:<2} {}",

View File

@@ -1,3 +1,5 @@
use std::num::NonZero;
use super::{AggFunc, BranchOffset, CursorID, FuncCtx, PageIdx};
use crate::types::{OwnedRecord, OwnedValue};
use limbo_macros::Description;
@@ -487,18 +489,26 @@ pub enum Insn {
db: usize,
where_clause: String,
},
// Place the result of lhs >> rhs in dest register.
ShiftRight {
lhs: usize,
rhs: usize,
dest: usize,
},
// Place the result of lhs << rhs in dest register.
ShiftLeft {
lhs: usize,
rhs: usize,
dest: usize,
},
/// Get parameter variable.
Variable {
index: NonZero<usize>,
dest: usize,
},
}
fn cast_text_to_numerical(value: &str) -> OwnedValue {

View File

@@ -55,6 +55,7 @@ use sorter::Sorter;
use std::borrow::BorrowMut;
use std::cell::RefCell;
use std::collections::{BTreeMap, HashMap};
use std::num::NonZero;
use std::rc::{Rc, Weak};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
@@ -200,6 +201,7 @@ pub struct ProgramState {
ended_coroutine: HashMap<usize, bool>, // flag to indicate that a coroutine has ended (key is the yield register)
regex_cache: RegexCache,
interrupted: bool,
parameters: HashMap<NonZero<usize>, OwnedValue>,
}
impl ProgramState {
@@ -222,6 +224,7 @@ impl ProgramState {
ended_coroutine: HashMap::new(),
regex_cache: RegexCache::new(),
interrupted: false,
parameters: HashMap::new(),
}
}
@@ -240,6 +243,18 @@ impl ProgramState {
pub fn is_interrupted(&self) -> bool {
self.interrupted
}
pub fn bind_at(&mut self, index: NonZero<usize>, value: OwnedValue) {
self.parameters.insert(index, value);
}
pub fn get_parameter(&self, index: NonZero<usize>) -> Option<&OwnedValue> {
self.parameters.get(&index)
}
pub fn reset(&mut self) {
self.parameters.clear();
}
}
macro_rules! must_be_btree_cursor {
@@ -262,6 +277,7 @@ pub struct Program {
pub cursor_ref: Vec<(Option<String>, CursorType)>,
pub database_header: Rc<RefCell<DatabaseHeader>>,
pub comments: HashMap<InsnReference, &'static str>,
pub parameters: crate::parameters::Parameters,
pub connection: Weak<Connection>,
pub auto_commit: bool,
}
@@ -2182,6 +2198,13 @@ impl Program {
exec_shift_left(&state.registers[*lhs], &state.registers[*rhs]);
state.pc += 1;
}
Insn::Variable { index, dest } => {
state.registers[*dest] = state
.get_parameter(*index)
.ok_or(LimboError::Unbound(*index))?
.clone();
state.pc += 1;
}
}
}
}

View File

@@ -572,4 +572,132 @@ mod tests {
do_flush(&conn, &tmp_db)?;
Ok(())
}
#[test]
fn test_statement_reset() -> anyhow::Result<()> {
let _ = env_logger::try_init();
let tmp_db = TempDatabase::new("create table test (i integer);");
let conn = tmp_db.connect_limbo();
conn.execute("insert into test values (1)")?;
conn.execute("insert into test values (2)")?;
let mut stmt = conn.prepare("select * from test")?;
loop {
match stmt.step()? {
StepResult::Row(row) => {
assert_eq!(row.values[0], Value::Integer(1));
break;
}
StepResult::IO => tmp_db.io.run_once()?,
_ => break,
}
}
stmt.reset();
loop {
match stmt.step()? {
StepResult::Row(row) => {
assert_eq!(row.values[0], Value::Integer(1));
break;
}
StepResult::IO => tmp_db.io.run_once()?,
_ => break,
}
}
Ok(())
}
#[test]
fn test_statement_reset_bind() -> anyhow::Result<()> {
let _ = env_logger::try_init();
let tmp_db = TempDatabase::new("create table test (i integer);");
let conn = tmp_db.connect_limbo();
let mut stmt = conn.prepare("select ?")?;
stmt.bind_at(1.try_into().unwrap(), Value::Integer(1));
loop {
match stmt.step()? {
StepResult::Row(row) => {
assert_eq!(row.values[0], Value::Integer(1));
}
StepResult::IO => tmp_db.io.run_once()?,
_ => break,
}
}
stmt.reset();
stmt.bind_at(1.try_into().unwrap(), Value::Integer(2));
loop {
match stmt.step()? {
StepResult::Row(row) => {
assert_eq!(row.values[0], Value::Integer(2));
}
StepResult::IO => tmp_db.io.run_once()?,
_ => break,
}
}
Ok(())
}
#[test]
fn test_statement_bind() -> anyhow::Result<()> {
let _ = env_logger::try_init();
let tmp_db = TempDatabase::new("create table test (i integer);");
let conn = tmp_db.connect_limbo();
let mut stmt = conn.prepare("select ?, ?1, :named, ?3, ?4")?;
stmt.bind_at(1.try_into().unwrap(), Value::Text(&"hello".to_string()));
let i = stmt.parameters().index(":named").unwrap();
stmt.bind_at(i, Value::Integer(42));
stmt.bind_at(3.try_into().unwrap(), Value::Blob(&vec![0x1, 0x2, 0x3]));
stmt.bind_at(4.try_into().unwrap(), Value::Float(0.5));
assert_eq!(stmt.parameters().count(), 4);
loop {
match stmt.step()? {
StepResult::Row(row) => {
if let Value::Text(s) = row.values[0] {
assert_eq!(s, "hello")
}
if let Value::Text(s) = row.values[1] {
assert_eq!(s, "hello")
}
if let Value::Integer(i) = row.values[2] {
assert_eq!(i, 42)
}
if let Value::Blob(v) = row.values[3] {
assert_eq!(v, &vec![0x1 as u8, 0x2, 0x3])
}
if let Value::Float(f) = row.values[4] {
assert_eq!(f, 0.5)
}
}
StepResult::IO => {
tmp_db.io.run_once()?;
}
StepResult::Interrupt => break,
StepResult::Done => break,
StepResult::Busy => panic!("Database is busy"),
};
}
Ok(())
}
}

View File

@@ -441,7 +441,12 @@ impl Splitter for Tokenizer {
// do not include the '?' in the token
Ok((Some((&data[1..=i], TK_VARIABLE)), i + 1))
}
None => Ok((Some((&data[1..], TK_VARIABLE)), data.len())),
None => {
if !data[1..].is_empty() && data[1..].iter().all(|ch| *ch == b'0') {
return Err(Error::BadVariableName(None, None));
}
Ok((Some((&data[1..], TK_VARIABLE)), data.len()))
}
}
}
b'$' | b'@' | b'#' | b':' => {