Add several more rust tests for parameter binding

This commit is contained in:
PThorpe92
2025-05-07 22:49:07 -04:00
parent 56f5f47e86
commit 50f2621c12
7 changed files with 288 additions and 111 deletions

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"log"
"math"
"os"
"testing"
_ "github.com/tursodatabase/limbo"
@@ -16,6 +17,7 @@ var (
)
func TestMain(m *testing.M) {
log.SetOutput(os.Stdout)
conn, connErr = sql.Open("sqlite3", ":memory:")
if connErr != nil {
panic(connErr)
@@ -609,80 +611,80 @@ func TestJSONFunctions(t *testing.T) {
}
// TODO: make these pass, this is a separate issue
// func TestParameterOrdering(t *testing.T) {
// newConn, err := sql.Open("sqlite3", ":memory:")
// if err != nil {
// t.Fatalf("Error opening new connection: %v", err)
// }
// sql := "CREATE TABLE test (a,b,c);"
// newConn.Exec(sql)
//
// // Test inserting with parameters in a different order than
// // the table definition.
// sql = "INSERT INTO test (b, c ,a) VALUES (?, ?, ?);"
// expectedValues := []int{1, 2, 3}
// stmt, err := newConn.Prepare(sql)
// _, err = stmt.Exec(expectedValues[1], expectedValues[2], expectedValues[0])
// if err != nil {
// t.Fatalf("Error preparing statement: %v", err)
// }
// // check that the values are in the correct order
// query := "SELECT a,b,c FROM test;"
// rows, err := newConn.Query(query)
// if err != nil {
// t.Fatalf("Error executing query: %v", err)
// }
// for rows.Next() {
// var a, b, c int
// err := rows.Scan(&a, &b, &c)
// if err != nil {
// t.Fatal("Error scanning row: ", err)
// }
// result := []int{a, b, c}
// for i := range 3 {
// if result[i] != expectedValues[i] {
// fmt.Printf("RESULTS: %d, %d, %d\n", a, b, c)
// fmt.Printf("EXPECTED: %d, %d, %d\n", expectedValues[0], expectedValues[1], expectedValues[2])
// }
// }
// }
//
// -- part 2 --
// mixed parameters and regular values
// sql2 := "CREATE TABLE test2 (a,b,c);"
// newConn.Exec(sql2)
// expectedValues2 := []int{1, 2, 3}
//
// // Test inserting with parameters in a different order than
// // the table definition, with a mixed regular parameter included
// sql2 = "INSERT INTO test2 (a, b ,c) VALUES (1, ?, ?);"
// _, err = newConn.Exec(sql2, expectedValues2[1], expectedValues2[2])
// if err != nil {
// t.Fatalf("Error preparing statement: %v", err)
// }
// // check that the values are in the correct order
// query2 := "SELECT a,b,c FROM test2;"
// rows2, err := newConn.Query(query2)
// if err != nil {
// t.Fatalf("Error executing query: %v", err)
// }
// for rows2.Next() {
// var a, b, c int
// err := rows2.Scan(&a, &b, &c)
// if err != nil {
// t.Fatal("Error scanning row: ", err)
// }
// // result := []int{a, b, c}
//
// fmt.Printf("RESULTS: %d, %d, %d\n", a, b, c)
// fmt.Printf("EXPECTED: %d, %d, %d\n", expectedValues[0], expectedValues[1], expectedValues[2])
// for i := range 3 {
// if result[i] != expectedValues[i] {
// t.Fatalf("Expected %d, got %d", expectedValues[i], result[i])
// }
//}
//}
//}
func TestParameterOrdering(t *testing.T) {
newConn, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("Error opening new connection: %v", err)
}
sql := "CREATE TABLE test (a,b,c);"
newConn.Exec(sql)
// Test inserting with parameters in a different order than
// the table definition.
sql = "INSERT INTO test (b, c ,a) VALUES (?, ?, ?);"
expectedValues := []int{1, 2, 3}
stmt, err := newConn.Prepare(sql)
_, err = stmt.Exec(expectedValues[1], expectedValues[2], expectedValues[0])
if err != nil {
t.Fatalf("Error preparing statement: %v", err)
}
// check that the values are in the correct order
query := "SELECT a,b,c FROM test;"
rows, err := newConn.Query(query)
if err != nil {
t.Fatalf("Error executing query: %v", err)
}
for rows.Next() {
var a, b, c int
err := rows.Scan(&a, &b, &c)
if err != nil {
t.Fatal("Error scanning row: ", err)
}
result := []int{a, b, c}
for i := range 3 {
if result[i] != expectedValues[i] {
fmt.Printf("RESULTS: %d, %d, %d\n", a, b, c)
fmt.Printf("EXPECTED: %d, %d, %d\n", expectedValues[0], expectedValues[1], expectedValues[2])
}
}
}
// -- part 2 --
// mixed parameters and regular values
sql2 := "CREATE TABLE test2 (a,b,c);"
newConn.Exec(sql2)
expectedValues2 := []int{1, 2, 3}
// Test inserting with parameters in a different order than
// the table definition, with a mixed regular parameter included
sql2 = "INSERT INTO test2 (a, b ,c) VALUES (1, ?, ?);"
_, err = newConn.Exec(sql2, expectedValues2[1], expectedValues2[2])
if err != nil {
t.Fatalf("Error preparing statement: %v", err)
}
// check that the values are in the correct order
query2 := "SELECT a,b,c FROM test2;"
rows2, err := newConn.Query(query2)
if err != nil {
t.Fatalf("Error executing query: %v", err)
}
for rows2.Next() {
var a, b, c int
err := rows2.Scan(&a, &b, &c)
if err != nil {
t.Fatal("Error scanning row: ", err)
}
result := []int{a, b, c}
fmt.Printf("RESULTS: %d, %d, %d\n", a, b, c)
fmt.Printf("EXPECTED: %d, %d, %d\n", expectedValues[0], expectedValues[1], expectedValues[2])
for i := range 3 {
if result[i] != expectedValues[i] {
t.Fatalf("Expected %d, got %d", expectedValues[i], result[i])
}
}
}
}
// TODO: make this pass
// func TestUpdateParameters(t *testing.T) {

View File

@@ -636,7 +636,7 @@ impl Statement {
}
pub fn bind_at(&mut self, index: NonZero<usize>, value: OwnedValue) {
let internal = self.program.parameters.get_remapped_value(index);
let internal = self.program.parameters.get_remapped_index(index);
self.state.bind_at(internal, value);
}

View File

@@ -27,7 +27,12 @@ impl Parameter {
pub struct Parameters {
index: NonZero<usize>,
pub list: Vec<Parameter>,
remap: Vec<NonZero<usize>>,
remap: Option<Vec<NonZero<usize>>>,
// Indexes of the referenced insert values to maintain ordering of paramaters
param_positions: Option<Vec<(usize, NonZero<usize>)>>,
// For insert statements with multiple rows
current_insert_row_idx: usize,
current_col_value_idx: Option<usize>,
}
impl Default for Parameters {
@@ -41,7 +46,10 @@ impl Parameters {
Self {
index: 1.try_into().unwrap(),
list: vec![],
remap: vec![],
remap: None,
param_positions: None,
current_insert_row_idx: 0,
current_col_value_idx: None,
}
}
@@ -51,13 +59,43 @@ impl Parameters {
params.len()
}
pub fn set_parameter_remap(&mut self, remap: Vec<NonZero<usize>>) {
tracing::debug!("remap: {:?}", remap);
self.remap = remap;
pub fn set_value_index(&mut self, idx: usize) {
self.current_col_value_idx = Some(idx);
}
pub fn get_remapped_value(&self, idx: NonZero<usize>) -> NonZero<usize> {
let res = *self.remap.get(idx.get() - 1).unwrap_or(&idx);
/// Add a parameter position to the array used to build the remap.
pub fn push_parameter_position(&mut self, index: NonZero<usize>) {
if let Some(cur) = self.current_col_value_idx {
if let Some(positions) = self.param_positions.as_mut() {
positions.push((cur, index));
tracing::debug!("push parameter position: {:?}", positions);
}
}
}
/// Initialize the stored positions array at the start of an insert statement
pub fn init_parameter_remap(&mut self, cols: usize) {
self.param_positions = Some(Vec::with_capacity(cols));
}
/// Sorts the stored value indexes and builds and sets the remap array.
pub fn sort_and_build_remap(&mut self) {
self.remap = self.param_positions.as_mut().map(|positions| {
// sort by value_index
positions.sort_by_key(|(idx, _)| *idx);
tracing::debug!("param positions: {:?}", positions);
// collect the parameter indexes
positions.iter().map(|(_, idx)| *idx).collect::<Vec<_>>()
})
}
/// Returns the remapped index for a given parameter index or the original index if none is found
pub fn get_remapped_index(&self, idx: NonZero<usize>) -> NonZero<usize> {
let res = *self
.remap
.as_ref()
.map(|p| p.get(idx.get() - 1).unwrap_or(&idx))
.unwrap_or(&idx);
tracing::debug!("get_remapped_value: {idx}, value: {res}");
res
}

View File

@@ -2147,6 +2147,7 @@ pub fn translate_expr(
}
},
ast::Expr::Variable(name) => {
let index = program.parameters.push(name);
// Table t: (a,b,c)
// For 'insert' statements:
// INSERT INTO t (b,c,a) values (?,?,?)
@@ -2154,14 +2155,10 @@ pub fn translate_expr(
// the parameter was given for an insert statement. Then, we may end up with something
// like: insert into (b,c,a) values (22,?,?), in which case we will get a = 2, c = 1
// instead of previously we would have gotten a = 0, c = 1
// where it instead should be c = 0, a = 1. So all we can do is store the value index
// alongside the index into the parameters list, then during bind_at: we can translate
// where it instead should be c = 0, a = 1. So we store the value index
// alongside the index into the parameters list, allowing bind_at: we can translate
// this value into the proper order.
let index = program.parameters.push(name);
if let Some(ref mut indicies) = &mut program.param_positions {
// (value_index, parameter_index)
indicies.push((program.current_col_idx.unwrap_or(index.get()), index));
}
program.parameters.push_parameter_position(index);
program.emit_insn(Insn::Variable {
index,
dest: target_register,

View File

@@ -54,6 +54,9 @@ pub fn translate_insert(
None => crate::bail_corrupt_error!("Parse error: no such table: {}", table_name),
};
let resolver = Resolver::new(syms);
program
.parameters
.init_parameter_remap(table.columns().len());
if let Some(virtual_table) = &table.virtual_table() {
translate_virtual_table_insert(
&mut program,
@@ -589,7 +592,6 @@ fn populate_column_registers(
rowid_reg: usize,
resolver: &Resolver,
) -> Result<()> {
program.param_positions = Some(vec![]);
for (i, mapping) in column_mappings.iter().enumerate() {
let target_reg = column_registers_start + i;
@@ -606,7 +608,7 @@ fn populate_column_registers(
target_reg
};
// set the value index to make it available to the translator
program.current_col_idx = Some(value_index);
program.parameters.set_value_index(value_index);
translate_expr_no_constant_opt(
program,
None,
@@ -645,17 +647,7 @@ fn populate_column_registers(
}
}
}
// if there are any parameter positions, we sort them by the value_index position
// to ensure we are binding the parameters to the proper index later on
if let Some(ref mut params) = program.param_positions.as_mut() {
// sort the tuples by the value_index position, leaving the param_index in right order.
params.sort_by_key(|(val_pos, _)| *val_pos);
let remap = params
.iter()
.map(|(_, internal_idx)| *internal_idx)
.collect();
program.set_param_remap(remap);
}
program.parameters.sort_and_build_remap();
Ok(())
}

View File

@@ -39,9 +39,6 @@ pub struct ProgramBuilder {
pub parameters: Parameters,
pub result_columns: Vec<ResultSetColumn>,
pub table_references: Vec<TableReference>,
// Indexes of the referenced insert values to maintain ordering of paramaters
pub param_positions: Option<Vec<(usize, NonZero<usize>)>>,
pub current_col_idx: Option<usize>,
}
#[derive(Debug, Clone)]
@@ -99,8 +96,6 @@ impl ProgramBuilder {
parameters: Parameters::new(),
result_columns: Vec::new(),
table_references: Vec::new(),
param_positions: None,
current_col_idx: None,
}
}
@@ -113,10 +108,6 @@ impl ProgramBuilder {
span
}
pub fn set_param_remap(&mut self, remap: Vec<NonZero<usize>>) {
self.parameters.set_parameter_remap(remap);
}
/// End the current constant span. The last instruction that was emitted is the last
/// instruction in the span.
pub fn constant_span_end(&mut self, span_idx: usize) {

View File

@@ -246,3 +246,160 @@ fn test_insert_parameter_remap_all_params() -> anyhow::Result<()> {
assert_eq!(ins.parameters().count(), 4);
Ok(())
}
#[test]
fn test_insert_parameter_multiple_remap_backwards() -> anyhow::Result<()> {
// ─────────────────────── schema ──────────────────────────────
// Table a b c d
// INSERT lists: d , c , b , a
// VALUES list: ?1 , ?2 , ?3 , ?4
//
// Expected row on disk: a = ?1 , b = ?2 , c = ?3 , d = ?4
//
// The row should be (111, 222, 333, 444)
// ───────────────────────────────────────────────────────────────
let tmp_db = TempDatabase::new_with_rusqlite(
"create table test (a integer, b integer, c integer, d integer);",
);
let conn = tmp_db.connect_limbo();
let mut ins = conn.prepare("insert into test (d,c,b,a) values (?, ?, ?, ?);")?;
let values = [
OwnedValue::Integer(444), // ?1 → d
OwnedValue::Integer(333), // ?2 → c
OwnedValue::Integer(222), // ?3 → b
OwnedValue::Integer(111), // ?4 → a
];
for (i, value) in values.iter().enumerate() {
let idx = i + 1;
ins.bind_at(idx.try_into()?, value.clone());
}
// execute the insert (no rows returned)
loop {
match ins.step()? {
StepResult::IO => tmp_db.io.run_once()?,
StepResult::Done | StepResult::Interrupt => break,
StepResult::Busy => panic!("database busy"),
_ => {}
}
}
let mut sel = conn.prepare("select a, b, c, d from test;")?;
loop {
match sel.step()? {
StepResult::Row => {
let row = sel.row().unwrap();
// insert_index = 2
// A = 111
assert_eq!(
row.get::<&OwnedValue>(0).unwrap(),
&OwnedValue::Integer(111)
);
// insert_index = 4
// B = 444
assert_eq!(
row.get::<&OwnedValue>(1).unwrap(),
&OwnedValue::Integer(222)
);
// insert_index = 3
// C = 333
assert_eq!(
row.get::<&OwnedValue>(2).unwrap(),
&OwnedValue::Integer(333)
);
// insert_index = 1
// D = 999
assert_eq!(
row.get::<&OwnedValue>(3).unwrap(),
&OwnedValue::Integer(444)
);
}
StepResult::IO => tmp_db.io.run_once()?,
StepResult::Done | StepResult::Interrupt => break,
StepResult::Busy => panic!("database busy"),
}
}
assert_eq!(ins.parameters().count(), 4);
Ok(())
}
#[test]
fn test_insert_parameter_multiple_no_remap() -> anyhow::Result<()> {
// ─────────────────────── schema ──────────────────────────────
// Table a b c d
// INSERT lists: a , b , c , d
// VALUES list: ?1 , ?2 , ?3 , ?4
//
// Expected row on disk: a = ?1 , b = ?2 , c = ?3 , d = ?4
//
// The row should be (111, 222, 333, 444)
// ───────────────────────────────────────────────────────────────
let tmp_db = TempDatabase::new_with_rusqlite(
"create table test (a integer, b integer, c integer, d integer);",
);
let conn = tmp_db.connect_limbo();
let mut ins = conn.prepare("insert into test (a,b,c,d) values (?, ?, ?, ?);")?;
let values = [
OwnedValue::Integer(111), // ?1 → a
OwnedValue::Integer(222), // ?2 → b
OwnedValue::Integer(333), // ?3 → c
OwnedValue::Integer(444), // ?4 → d
];
for (i, value) in values.iter().enumerate() {
let idx = i + 1;
ins.bind_at(idx.try_into()?, value.clone());
}
// execute the insert (no rows returned)
loop {
match ins.step()? {
StepResult::IO => tmp_db.io.run_once()?,
StepResult::Done | StepResult::Interrupt => break,
StepResult::Busy => panic!("database busy"),
_ => {}
}
}
let mut sel = conn.prepare("select a, b, c, d from test;")?;
loop {
match sel.step()? {
StepResult::Row => {
let row = sel.row().unwrap();
// insert_index = 2
// A = 111
assert_eq!(
row.get::<&OwnedValue>(0).unwrap(),
&OwnedValue::Integer(111)
);
// insert_index = 4
// B = 444
assert_eq!(
row.get::<&OwnedValue>(1).unwrap(),
&OwnedValue::Integer(222)
);
// insert_index = 3
// C = 333
assert_eq!(
row.get::<&OwnedValue>(2).unwrap(),
&OwnedValue::Integer(333)
);
// insert_index = 1
// D = 999
assert_eq!(
row.get::<&OwnedValue>(3).unwrap(),
&OwnedValue::Integer(444)
);
}
StepResult::IO => tmp_db.io.run_once()?,
StepResult::Done | StepResult::Interrupt => break,
StepResult::Busy => panic!("database busy"),
}
}
assert_eq!(ins.parameters().count(), 4);
Ok(())
}