diff --git a/Makefile b/Makefile index 06afa0e5d..db3c3acdb 100644 --- a/Makefile +++ b/Makefile @@ -66,7 +66,7 @@ uv-sync: uv sync --all-packages .PHONE: uv-sync -test: limbo uv-sync test-compat test-vector test-sqlite3 test-shell test-extensions test-memory test-write test-update +test: limbo uv-sync test-compat test-vector test-sqlite3 test-shell test-extensions test-memory test-write test-update test-constraint .PHONY: test test-extensions: limbo uv-sync @@ -109,6 +109,10 @@ test-update: limbo uv-sync SQLITE_EXEC=$(SQLITE_EXEC) uv run --project limbo_test test-update .PHONY: test-update +test-constraint: limbo uv-sync + SQLITE_EXEC=$(SQLITE_EXEC) uv run --project limbo_test test-constraint +.PHONY: test-constraint + bench-vfs: uv-sync cargo build --release uv run --project limbo_test bench-vfs "$(SQL)" "$(N)" diff --git a/core/schema.rs b/core/schema.rs index dd09671ab..42c619693 100644 --- a/core/schema.rs +++ b/core/schema.rs @@ -81,6 +81,14 @@ impl Schema { .map_or_else(|| &[] as &[Arc], |v| v.as_slice()) } + pub fn get_index(&self, table_name: &str, index_name: &str) -> Option<&Arc> { + let name = normalize_ident(table_name); + self.indexes + .get(&name)? 
+ .iter() + .find(|index| index.name == index_name) + } + pub fn remove_indices_for_table(&mut self, table_name: &str) { let name = normalize_ident(table_name); self.indexes.remove(&name); diff --git a/core/translate/insert.rs b/core/translate/insert.rs index 4ca7e6fca..b17d19110 100644 --- a/core/translate/insert.rs +++ b/core/translate/insert.rs @@ -283,19 +283,7 @@ pub fn translate_insert( } _ => (), } - // Create and insert the record - program.emit_insn(Insn::MakeRecord { - start_reg: column_registers_start, - count: num_cols, - dest_reg: record_register, - }); - program.emit_insn(Insn::Insert { - cursor: cursor_id, - key_reg: rowid_reg, - record_reg: record_register, - flag: 0, - }); for index_col_mapping in index_col_mappings.iter() { // find which cursor we opened earlier for this index let idx_cursor_id = idx_cursors @@ -332,6 +320,49 @@ pub fn translate_insert( dest_reg: record_reg, }); + let index = schema + .get_index(&table_name.0, &index_col_mapping.idx_name) + .expect("index should be present"); + + if index.unique { + let label_idx_insert = program.allocate_label(); + program.emit_insn(Insn::NoConflict { + cursor_id: idx_cursor_id, + target_pc: label_idx_insert, + record_reg: idx_start_reg, + num_regs: num_cols, + }); + let column_names = index_col_mapping.columns.iter().enumerate().fold( + String::with_capacity(50), + |mut accum, (idx, (index, _))| { + if idx > 0 { + accum.push_str(", "); + } + + accum.push_str(&btree_table.name); + accum.push('.'); + + let name = btree_table + .columns + .get(*index) + .unwrap() + .name + .as_ref() + .expect("column name is None"); + accum.push_str(name); + + accum + }, + ); + + program.emit_insn(Insn::Halt { + err_code: SQLITE_CONSTRAINT_PRIMARYKEY, + description: column_names, + }); + + program.resolve_label(label_idx_insert, program.offset()); + } + // now do the actual index insertion using the unpacked registers program.emit_insn(Insn::IdxInsert { cursor_id: idx_cursor_id, @@ -342,6 +373,21 @@ pub fn 
translate_insert( flags: IdxInsertFlags::new(), }); } + + // Create and insert the record + program.emit_insn(Insn::MakeRecord { + start_reg: column_registers_start, + count: num_cols, + dest_reg: record_register, + }); + + program.emit_insn(Insn::Insert { + cursor: cursor_id, + key_reg: rowid_reg, + record_reg: record_register, + flag: 0, + }); + if inserting_multiple_rows { // For multiple rows, loop back program.emit_insn(Insn::Goto { @@ -472,7 +518,7 @@ fn resolve_columns_for_insert<'a>( /// Represents how a column in an index should be populated during an INSERT. /// Similar to ColumnMapping above but includes the index name, as well as multiple /// possible value indices for each. -#[derive(Default)] +#[derive(Debug, Default)] struct IndexColMapping { idx_name: String, columns: Vec<(usize, IndexColumn)>, diff --git a/core/vdbe/builder.rs b/core/vdbe/builder.rs index 05fdc4938..66b2143bb 100644 --- a/core/vdbe/builder.rs +++ b/core/vdbe/builder.rs @@ -440,6 +440,9 @@ impl ProgramBuilder { Insn::VFilter { pc_if_empty, .. } => { resolve(pc_if_empty, "VFilter"); } + Insn::NoConflict { target_pc, .. 
} => { + resolve(target_pc, "NoConflict"); + } _ => {} } } diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index fc7bb1833..928f7f94a 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -3894,6 +3894,60 @@ pub fn op_soft_null( Ok(InsnFunctionStepResult::Step) } +pub fn op_no_conflict( + program: &Program, + state: &mut ProgramState, + insn: &Insn, + pager: &Rc, + mv_store: Option<&Rc>, +) -> Result { + let Insn::NoConflict { + cursor_id, + target_pc, + record_reg, + num_regs, + } = insn + else { + unreachable!("unexpected Insn {:?}", insn) + }; + let mut cursor_ref = state.get_cursor(*cursor_id); + let cursor = cursor_ref.as_btree_mut(); + + let record = if *num_regs == 0 { + let record = match &state.registers[*record_reg] { + Register::Record(r) => r, + _ => { + return Err(LimboError::InternalError( + "NoConflict: expected a record in the register".into(), + )); + } + }; + record + } else { + &make_record(&state.registers, record_reg, num_regs) + }; + // If there is at least one NULL in the index record, there cannot be a conflict so we can immediately jump. 
+ let contains_nulls = record + .get_values() + .iter() + .any(|val| matches!(val, RefValue::Null)); + + if contains_nulls { + drop(cursor_ref); + state.pc = target_pc.to_offset_int(); + return Ok(InsnFunctionStepResult::Step); + } + + let conflict = return_if_io!(cursor.seek(SeekKey::IndexKey(record), SeekOp::EQ)); + drop(cursor_ref); + if !conflict { + state.pc = target_pc.to_offset_int(); + } else { + state.pc += 1; + } + Ok(InsnFunctionStepResult::Step) +} + pub fn op_not_exists( program: &Program, state: &mut ProgramState, diff --git a/core/vdbe/explain.rs b/core/vdbe/explain.rs index 96afc5d17..eadb5a0d9 100644 --- a/core/vdbe/explain.rs +++ b/core/vdbe/explain.rs @@ -569,13 +569,13 @@ pub fn insn_to_str( ), Insn::Halt { err_code, - description: _, + description, } => ( "Halt", *err_code as i32, 0, 0, - OwnedValue::build_text(""), + OwnedValue::build_text(&description), 0, "".to_string(), ), @@ -1068,6 +1068,20 @@ pub fn insn_to_str( 0, "".to_string(), ), + Insn::NoConflict { + cursor_id, + target_pc, + record_reg, + num_regs, + } => ( + "NoConflict", + *cursor_id as i32, + target_pc.to_debug_int(), + *record_reg as i32, + OwnedValue::build_text(&format!("{num_regs}")), + 0, + format!("key=r[{}]", record_reg), + ), Insn::NotExists { cursor, rowid_reg, diff --git a/core/vdbe/insn.rs b/core/vdbe/insn.rs index 56f44bd2b..6f310f746 100644 --- a/core/vdbe/insn.rs +++ b/core/vdbe/insn.rs @@ -664,6 +664,18 @@ pub enum Insn { reg: usize, }, + /// If P4==0 then register P3 holds a blob constructed by [MakeRecord](https://sqlite.org/opcode.html#MakeRecord). If P4>0 then register P3 is the first of P4 registers that form an unpacked record.\ + /// + /// Cursor P1 is on an index btree. If the record identified by P3 and P4 contains any NULL value, jump immediately to P2. If all terms of the record are not-NULL then a check is done to determine if any row in the P1 index btree has a matching key prefix. If there are no matches, jump immediately to P2. 
If there is a match, fall through and leave the P1 cursor pointing to the matching row.\ + /// + /// This opcode is similar to [NotFound](https://sqlite.org/opcode.html#NotFound) with the exceptions that the branch is always taken if any part of the search key input is NULL. + NoConflict { + cursor_id: CursorID, // P1 index cursor + target_pc: BranchOffset, // P2 jump target + record_reg: usize, + num_regs: usize, + }, + NotExists { cursor: CursorID, rowid_reg: usize, @@ -922,6 +934,7 @@ impl Insn { Insn::NewRowid { .. } => execute::op_new_rowid, Insn::MustBeInt { .. } => execute::op_must_be_int, Insn::SoftNull { .. } => execute::op_soft_null, + Insn::NoConflict { .. } => execute::op_no_conflict, Insn::NotExists { .. } => execute::op_not_exists, Insn::OffsetLimit { .. } => execute::op_offset_limit, Insn::OpenWrite { .. } => execute::op_open_write, diff --git a/testing/cli_tests/constraint.py b/testing/cli_tests/constraint.py new file mode 100644 index 000000000..65758745b --- /dev/null +++ b/testing/cli_tests/constraint.py @@ -0,0 +1,371 @@ +#!/usr/bin/env python3 + +# Eventually extract these tests to be in the fuzzing integration tests +import os +from faker import Faker +from faker.providers.lorem.en_US import Provider as P +from cli_tests.test_limbo_cli import TestLimboShell +from pydantic import BaseModel +from cli_tests import console +from enum import Enum +import random +import sqlite3 + +sqlite_flags = os.getenv("SQLITE_FLAGS", "-q").split(" ") + + +keywords = [ + "ABORT", + "ACTION", + "ADD", + "AFTER", + "ALL", + "ALTER", + "ALWAYS", + "ANALYZE", + "AND", + "AS", + "ASC", + "ATTACH", + "AUTOINCREMENT", + "BEFORE", + "BEGIN", + "BETWEEN", + "BY", + "CASCADE", + "CASE", + "CAST", + "CHECK", + "COLLATE", + "COLUMN", + "COMMIT", + "CONFLICT", + "CONSTRAINT", + "CREATE", + "CROSS", + "CURRENT", + "CURRENT_DATE", + "CURRENT_TIME", + "CURRENT_TIMESTAMP", + "DATABASE", + "DEFAULT", + "DEFERRABLE", + "DEFERRED", + "DELETE", + "DESC", + "DETACH", + "DISTINCT", + 
"DO", + "DROP", + "EACH", + "ELSE", + "END", + "ESCAPE", + "EXCEPT", + "EXCLUDE", + "EXCLUSIVE", + "EXISTS", + "EXPLAIN", + "FAIL", + "FILTER", + "FIRST", + "FOLLOWING", + "FOR", + "FOREIGN", + "FROM", + "FULL", + "GENERATED", + "GLOB", + "GROUP", + "GROUPS", + "HAVING", + "IF", + "IGNORE", + "IMMEDIATE", + "IN", + "INDEX", + "INDEXED", + "INITIALLY", + "INNER", + "INSERT", + "INSTEAD", + "INTERSECT", + "INTO", + "IS", + "ISNULL", + "JOIN", + "KEY", + "LAST", + "LEFT", + "LIKE", + "LIMIT", + "MATCH", + "MATERIALIZED", + "NATURAL", + "NO", + "NOT", + "NOTHING", + "NOTNULL", + "NULL", + "NULLS", + "OF", + "OFFSET", + "ON", + "OR", + "ORDER", + "OTHERS", + "OUTER", + "OVER", + "PARTITION", + "PLAN", + "PRAGMA", + "PRECEDING", + "PRIMARY", + "QUERY", + "RAISE", + "RANGE", + "RECURSIVE", + "REFERENCES", + "REGEXP", + "REINDEX", + "RELEASE", + "RENAME", + "REPLACE", + "RESTRICT", + "RETURNING", + "RIGHT", + "ROLLBACK", + "ROW", + "ROWS", + "SAVEPOINT", + "SELECT", + "SET", + "TABLE", + "TEMP", + "TEMPORARY", + "THEN", + "TIES", + "TO", + "TRANSACTION", + "TRIGGER", + "UNBOUNDED", + "UNION", + "UNIQUE", + "UPDATE", + "USING", + "VACUUM", + "VALUES", + "VIEW", + "VIRTUAL", + "WHEN", + "WHERE", + "WINDOW", + "WITH", + "WITHOUT", +] +P.word_list = tuple(word for word in P.word_list if word.upper() not in keywords) +del P +fake: Faker = Faker(locale="en_US").unique +Faker.seed(0) + + +class ColumnType(Enum): + blob = "blob" + integer = "integer" + real = "real" + text = "text" + + def generate(self, faker: Faker) -> str: + match self.value: + case "blob": + blob = sqlite3.Binary(faker.binary(length=4)).hex() + return f"x'{blob}'" + case "integer": + return str(faker.pyint()) + case "real": + return str(faker.pyfloat()) + case "text": + return f"'{faker.text(max_nb_chars=20)}'" + + def __str__(self) -> str: + return self.value.upper() + + +class Column(BaseModel): + name: str + col_type: ColumnType + primary_key: bool + + def generate(faker: Faker) -> "Column": + name = 
faker.word().replace(" ", "_") + return Column( + name=name, + col_type=Faker().enum(ColumnType), + primary_key=False, + ) + + def __str__(self) -> str: + return f"{self.name} {str(self.col_type)}" + + +class Table(BaseModel): + columns: list[Column] + name: str + + def create_table(self) -> str: + accum = f"CREATE TABLE {self.name} " + col_strings = [str(col) for col in self.columns] + + pk_columns = [col.name for col in self.columns if col.primary_key] + primary_key_stmt = "PRIMARY KEY (" + ", ".join(pk_columns) + ")" + col_strings.append(primary_key_stmt) + + accum = accum + "(" + ", ".join(col_strings) + ");" + + return accum + + def generate_insert(self) -> str: + vals = [col.col_type.generate(fake) for col in self.columns] + vals = ", ".join(vals) + + return f"INSERT INTO {self.name} VALUES ({vals});" + + +class ConstraintTest(BaseModel): + table: Table + db_path: str = "testing/constraint.db" + insert_stmts: list[str] + insert_errors: list[str] + + def run( + self, + limbo: TestLimboShell, + ): + big_stmt = [self.table.create_table()] + for insert_stmt in self.insert_stmts: + big_stmt.append(insert_stmt) + + limbo.run_test("Inserting values into table", "\n".join(big_stmt), "") + + for insert_stmt in self.insert_errors: + limbo.run_test_fn( + insert_stmt, + lambda val: "Runtime error: UNIQUE constraint failed" in val, + ) + limbo.run_test( + "Nothing was inserted after error", + f"SELECT count(*) from {self.table.name};", + str(len(self.insert_stmts)), + ) + + +def validate_with_expected(result: str, expected: str): + return (expected in result, expected) + + +def generate_test(col_amount: int, primary_keys: int) -> ConstraintTest: + assert col_amount >= primary_keys, "Cannot have more primary keys than columns" + cols: list[Column] = [] + for _ in range(col_amount): + cols.append(Column.generate(fake)) + + pk_cols = random.sample( + population=cols, + k=primary_keys, + ) + + for col in pk_cols: + for c in cols: + if col.name == c.name: + c.primary_key = 
True + + table = Table(columns=cols, name=fake.word()) + insert_stmts = [table.generate_insert() for _ in range(col_amount)] + return ConstraintTest( + table=table, insert_stmts=insert_stmts, insert_errors=insert_stmts + ) + + +def custom_test_1() -> ConstraintTest: + cols = [ + Column(name="id", col_type="integer", primary_key=True), + Column(name="username", col_type="text", primary_key=True), + ] + table = Table(columns=cols, name="users") + insert_stmts = [ + "INSERT INTO users VALUES (1, 'alice');", + "INSERT INTO users VALUES (2, 'bob');", + ] + return ConstraintTest( + table=table, insert_stmts=insert_stmts, insert_errors=insert_stmts + ) + + +def custom_test_2(limbo: TestLimboShell): + create = "CREATE TABLE users (id INT PRIMARY KEY, username TEXT);" + first_insert = "INSERT INTO users VALUES (1, 'alice');" + limbo.run_test("Create unique INT index", create + first_insert, "") + fail_insert = "INSERT INTO users VALUES (1, 'bob');" + limbo.run_test_fn( + fail_insert, + lambda val: "Runtime error: UNIQUE constraint failed" in val, + ) + + +def all_tests() -> list[ConstraintTest]: + tests: list[ConstraintTest] = [] + max_cols = 10 + + curr_fake = Faker() + for _ in range(25): + num_cols = curr_fake.pyint(1, max_cols) + test = generate_test(num_cols, curr_fake.pyint(1, num_cols)) + tests.append(test) + + tests.append(custom_test_1()) + return tests + + +def cleanup(db_fullpath: str): + wal_path = f"{db_fullpath}-wal" + shm_path = f"{db_fullpath}-shm" + paths = [db_fullpath, wal_path, shm_path] + for path in paths: + if os.path.exists(path): + os.remove(path) + + +def main(): + tests = all_tests() + for test in tests: + console.info(test.table) + db_path = test.db_path + try: + # Use with syntax to automatically close shell on error + with TestLimboShell("") as limbo: + limbo.execute_dot(f".open {db_path}") + test.run(limbo) + + except Exception as e: + console.error(f"Test FAILED: {e}") + console.debug(test.table.create_table(), test.insert_stmts) + 
cleanup(db_path) + exit(1) + # delete db after every constraint test so we have fresh db for next test + cleanup(db_path) + + db_path = "testing/constraint.db" + try: + with TestLimboShell("") as limbo: + limbo.execute_dot(f".open {db_path}") + custom_test_2(limbo) + except Exception as e: + console.error(f"Test FAILED: {e}") + cleanup(db_path) + exit(1) + cleanup(db_path) + console.info("All tests passed successfully.") + + +if __name__ == "__main__": + main() diff --git a/testing/pyproject.toml b/testing/pyproject.toml index cdd30ec54..0aed7b99b 100644 --- a/testing/pyproject.toml +++ b/testing/pyproject.toml @@ -16,6 +16,7 @@ test-extensions = "cli_tests.extensions:main" test-update = "cli_tests.update:main" test-memory = "cli_tests.memory:main" bench-vfs = "cli_tests.vfs_bench:main" +test-constraint = "cli_tests.constraint:main" [tool.uv] package = true