feat: more parameter support

add `Statement::{parameter_index, parameter_name, parameter_count,
bind_at}`. some refactoring is still needed, this is quite a rough
iteration
This commit is contained in:
Levy A.
2025-01-14 22:15:24 -03:00
parent d3582a382f
commit 5de2694834
6 changed files with 212 additions and 38 deletions

View File

@@ -28,6 +28,7 @@ use sqlite3_parser::ast;
use sqlite3_parser::{ast::Cmd, lexer::sql::Parser};
use std::cell::Cell;
use std::collections::HashMap;
use std::num::NonZero;
use std::sync::{Arc, OnceLock, RwLock};
use std::{cell::RefCell, rc::Rc};
use storage::btree::btree_init_page;
@@ -474,12 +475,24 @@ impl Statement {
Ok(Rows::new(stmt))
}
pub fn reset(&mut self) {
self.state.reset();
pub fn parameter_count(&mut self) -> usize {
self.program.parameter_count()
}
pub fn bind(&mut self, value: Value) {
self.state.bind(value.into());
pub fn parameter_name(&self, index: NonZero<usize>) -> Option<String> {
self.program.parameter_name(index)
}
pub fn parameter_index(&self, name: impl AsRef<str>) -> Option<NonZero<usize>> {
self.program.parameter_index(name)
}
pub fn bind_at(&mut self, index: NonZero<usize>, value: Value) {
self.state.bind_at(index, value.into());
}
pub fn reset(&mut self) {
self.state.reset();
}
}

View File

@@ -1710,8 +1710,8 @@ pub fn translate_expr(
}
_ => todo!(),
},
ast::Expr::Variable(name) => {
let index = program.get_parameter_index(name);
ast::Expr::Variable(_) => {
let index = program.pop_index();
program.emit_insn(Insn::Variable {
index,
dest: target_register,

View File

@@ -32,12 +32,121 @@ use crate::vdbe::{builder::ProgramBuilder, insn::Insn, Program};
use crate::{bail_parse_error, Connection, LimboError, Result, SymbolTable};
use insert::translate_insert;
use select::translate_select;
use sqlite3_parser::ast::fmt::TokenStream;
use sqlite3_parser::ast::{self, fmt::ToTokens, PragmaName};
use sqlite3_parser::dialect::TokenType;
use std::cell::RefCell;
use std::fmt::Display;
use std::num::NonZero;
use std::rc::{Rc, Weak};
use std::str::FromStr;
#[derive(Clone, Debug)]
pub enum Parameter {
Anonymous(NonZero<usize>),
Indexed(NonZero<usize>),
Named(String, NonZero<usize>),
}
impl PartialEq for Parameter {
fn eq(&self, other: &Self) -> bool {
self.index() == other.index()
}
}
impl Parameter {
pub fn index(&self) -> NonZero<usize> {
match self {
Parameter::Anonymous(index) => *index,
Parameter::Indexed(index) => *index,
Parameter::Named(_, index) => *index,
}
}
}
/// `?` or `$` Prepared statement arg placeholder(s)
#[derive(Debug)]
pub struct Parameters {
index: NonZero<usize>,
pub list: Vec<Parameter>,
}
impl Parameters {
pub fn new() -> Self {
Self {
index: 1.try_into().unwrap(),
list: vec![],
}
}
pub fn push(&mut self, value: Parameter) {
self.list.push(value);
}
pub fn next_index(&mut self) -> NonZero<usize> {
let index = self.index;
self.index = self.index.checked_add(1).unwrap();
index
}
pub fn get(&mut self, index: usize) -> Option<&Parameter> {
self.list.get(index)
}
}
// https://sqlite.org/lang_expr.html#parameters
impl TokenStream for Parameters {
type Error = std::convert::Infallible;
fn append(
&mut self,
ty: TokenType,
value: Option<&str>,
) -> std::result::Result<(), Self::Error> {
if ty == TokenType::TK_VARIABLE {
if let Some(variable) = value {
match variable.split_at(1) {
("?", "") => {
let index = self.next_index();
self.push(Parameter::Anonymous(index.try_into().unwrap()));
log::trace!("anonymous parameter at {index}");
}
("?", index) => {
let index: NonZero<usize> = index.parse().unwrap();
if index > self.index {
self.index = index.checked_add(1).unwrap();
}
self.push(Parameter::Indexed(index.try_into().unwrap()));
log::trace!("indexed parameter at {index}");
}
(_, _) => {
match self.list.iter().find(|p| {
let Parameter::Named(name, _) = p else {
return false;
};
name == variable
}) {
Some(t) => {
log::trace!("named parameter at {} as {}", t.index(), variable);
self.push(t.clone());
}
None => {
let index = self.next_index();
self.push(Parameter::Named(
variable.to_owned(),
index.try_into().unwrap(),
));
log::trace!("named parameter at {index} as {variable}");
}
}
}
}
}
}
Ok(())
}
}
/// Translate SQL statement into bytecode program.
pub fn translate(
schema: &Schema,
@@ -47,7 +156,14 @@ pub fn translate(
connection: Weak<Connection>,
syms: &SymbolTable,
) -> Result<Program> {
let mut program = ProgramBuilder::new();
let mut parameters = Parameters::new();
stmt.to_tokens(&mut parameters).unwrap();
// dbg!(&parameters);
// dbg!(&parameters.list.clone().dedup());
let mut program = ProgramBuilder::new(parameters);
match stmt {
ast::Stmt::AlterTable(_, _) => bail_parse_error!("ALTER TABLE not supported yet"),
ast::Stmt::Analyze(_) => bail_parse_error!("ANALYZE not supported yet"),
@@ -118,6 +234,7 @@ pub fn translate(
)?;
}
}
Ok(program.build(database_header, connection))
}

View File

@@ -32,6 +32,8 @@ pub struct ProgramBuilder {
comments: HashMap<InsnReference, &'static str>,
named_parameters: HashMap<String, NonZero<usize>>,
next_free_parameter_index: NonZero<usize>,
parameters: crate::translate::Parameters,
parameter_index: usize,
}
#[derive(Debug, Clone)]
@@ -49,7 +51,7 @@ impl CursorType {
}
impl ProgramBuilder {
pub fn new() -> Self {
pub fn new(parameters: crate::translate::Parameters) -> Self {
Self {
next_free_register: 1,
next_free_label: 0,
@@ -63,6 +65,8 @@ impl ProgramBuilder {
seekrowid_emitted_bitmask: 0,
comments: HashMap::new(),
named_parameters: HashMap::new(),
parameters,
parameter_index: 0,
}
}
@@ -336,6 +340,7 @@ impl ProgramBuilder {
self.constant_insns.is_empty(),
"constant_insns is not empty when build() is called, did you forget to call emit_constant_insns()?"
);
self.parameters.list.dedup();
Program {
max_registers: self.next_free_register,
insns: self.insns,
@@ -344,29 +349,13 @@ impl ProgramBuilder {
comments: self.comments,
connection,
auto_commit: true,
parameters: self.parameters.list,
}
}
fn next_parameter(&mut self) -> NonZero<usize> {
let index = self.next_free_parameter_index;
self.next_free_parameter_index.checked_add(1).unwrap();
index
}
pub fn get_parameter_index(&mut self, name: impl AsRef<str>) -> NonZero<usize> {
let name = name.as_ref();
if name == "" {
return self.next_parameter();
}
match self.named_parameters.get(name) {
Some(index) => *index,
None => {
let index = self.next_parameter();
self.named_parameters.insert(name.to_owned(), index);
index
}
}
pub fn pop_index(&mut self) -> NonZero<usize> {
let parameter = self.parameters.get(self.parameter_index).unwrap();
self.parameter_index += 1;
return parameter.index();
}
}

View File

@@ -201,7 +201,7 @@ pub struct ProgramState {
ended_coroutine: HashMap<usize, bool>, // flag to indicate that a coroutine has ended (key is the yield register)
regex_cache: RegexCache,
interrupted: bool,
parameters: Vec<OwnedValue>,
parameters: HashMap<NonZero<usize>, OwnedValue>,
}
impl ProgramState {
@@ -224,7 +224,7 @@ impl ProgramState {
ended_coroutine: HashMap::new(),
regex_cache: RegexCache::new(),
interrupted: false,
parameters: Vec::new(),
parameters: HashMap::new(),
}
}
@@ -244,12 +244,12 @@ impl ProgramState {
self.interrupted
}
pub fn bind(&mut self, value: OwnedValue) {
self.parameters.push(value);
pub fn bind_at(&mut self, index: NonZero<usize>, value: OwnedValue) {
self.parameters.insert(index, value);
}
pub fn get_parameter(&self, index: NonZero<usize>) -> Option<&OwnedValue> {
self.parameters.get(usize::from(index) - 1)
self.parameters.get(&index)
}
pub fn reset(&mut self) {
@@ -277,6 +277,7 @@ pub struct Program {
pub cursor_ref: Vec<(Option<String>, CursorType)>,
pub database_header: Rc<RefCell<DatabaseHeader>>,
pub comments: HashMap<InsnReference, &'static str>,
pub parameters: Vec<crate::translate::Parameter>,
pub connection: Weak<Connection>,
pub auto_commit: bool,
}
@@ -300,6 +301,38 @@ impl Program {
}
}
pub fn parameter_count(&self) -> usize {
self.parameters.len()
}
pub fn parameter_name(&self, index: NonZero<usize>) -> Option<String> {
use crate::translate::Parameter;
self.parameters.iter().find_map(|p| match p {
Parameter::Anonymous(i) if *i == index => Some("?".to_string()),
Parameter::Indexed(i) if *i == index => Some(format!("?{i}")),
Parameter::Named(name, i) if *i == index => Some(name.to_owned()),
_ => None,
})
}
pub fn parameter_index(&self, name: impl AsRef<str>) -> Option<NonZero<usize>> {
use crate::translate::Parameter;
self.parameters
.iter()
.find_map(|p| {
let Parameter::Named(parameter_name, index) = p else {
return None;
};
if name.as_ref() == parameter_name {
return Some(index);
}
None
})
.copied()
}
pub fn step<'a>(
&self,
state: &'a mut ProgramState,

View File

@@ -40,7 +40,7 @@ impl TempDatabase {
#[cfg(test)]
mod tests {
use super::*;
use limbo_core::{CheckpointStatus, Connection, Rows, StepResult, Value};
use limbo_core::{CheckpointStatus, Connection, StepResult, Value};
use log::debug;
#[ignore]
@@ -576,16 +576,38 @@ mod tests {
#[test]
fn test_statement_bind() -> anyhow::Result<()> {
let _ = env_logger::try_init();
let tmp_db = TempDatabase::new("CREATE TABLE test (x INTEGER PRIMARY KEY);");
let tmp_db = TempDatabase::new("create table test (i integer);");
let conn = tmp_db.connect_limbo();
let mut stmt = conn.prepare("select ?")?;
stmt.bind(Value::Text(&"hello".to_string()));
let mut stmt = conn.prepare("select ?, ?1, :named, ?4")?;
stmt.bind_at(1.try_into().unwrap(), Value::Text(&"hello".to_string()));
let i = stmt.parameter_index(":named").unwrap();
stmt.bind_at(i, Value::Integer(42));
stmt.bind_at(4.try_into().unwrap(), Value::Float(0.5));
assert_eq!(stmt.parameter_count(), 3);
loop {
match stmt.step()? {
StepResult::Row(row) => {
if let Value::Text(s) = row.values[0] {
assert_eq!(s, "hello")
}
if let Value::Text(s) = row.values[1] {
assert_eq!(s, "hello")
}
if let Value::Integer(s) = row.values[2] {
assert_eq!(s, 42)
}
if let Value::Float(s) = row.values[3] {
assert_eq!(s, 0.5)
}
}
StepResult::IO => {
tmp_db.io.run_once()?;