diff --git a/.github/workflows/long_fuzz_tests_btree.yml b/.github/workflows/fuzz.yml similarity index 66% rename from .github/workflows/long_fuzz_tests_btree.yml rename to .github/workflows/fuzz.yml index 982ac8604..4573d658f 100644 --- a/.github/workflows/long_fuzz_tests_btree.yml +++ b/.github/workflows/fuzz.yml @@ -12,7 +12,27 @@ on: - main jobs: - run-long-tests: + run-fuzz-tests: + runs-on: blacksmith-4vcpu-ubuntu-2404 + timeout-minutes: 30 + + steps: + - uses: actions/checkout@v3 + - uses: useblacksmith/rust-cache@v3 + with: + prefix-key: "v1-rust" # can be updated if we need to reset caches due to non-trivial change in the dependencies (for example, custom env var were set for single workspace project) + - name: Set up Python 3.10 + uses: useblacksmith/setup-python@v6 + with: + python-version: "3.10" + - name: Build + run: cargo build --verbose + - name: Run ignored long tests + run: cargo test --test fuzz_tests + env: + RUST_BACKTRACE: 1 + + run-long-fuzz-tests: runs-on: blacksmith-4vcpu-ubuntu-2404 timeout-minutes: 30 diff --git a/.github/workflows/java-publish.yml b/.github/workflows/java-publish.yml new file mode 100644 index 000000000..ec42744da --- /dev/null +++ b/.github/workflows/java-publish.yml @@ -0,0 +1,143 @@ +name: Publish Java Bindings to Maven Central + +on: + # Manually trigger the workflow + workflow_dispatch: + +env: + working-directory: bindings/java + +jobs: + # Build native libraries for each platform + build-natives: + strategy: + matrix: + include: + - os: ubuntu-latest + target: x86_64-unknown-linux-gnu + make-target: linux_x86 + artifact-name: linux-x86_64 + - os: macos-latest + target: x86_64-apple-darwin + make-target: macos_x86 + artifact-name: macos-x86_64 + - os: macos-latest + target: aarch64-apple-darwin + make-target: macos_arm64 + artifact-name: macos-arm64 + - os: ubuntu-latest + target: x86_64-pc-windows-gnu + make-target: windows + artifact-name: windows-x86_64 + + runs-on: ${{ matrix.os }} + timeout-minutes: 30 + + defaults: + run: + working-directory: ${{ env.working-directory }} + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + with: + targets: ${{ matrix.target }} + + - name: Verify and install Rust target + run: | + echo "Installing target: ${{ matrix.target }}" + rustup target add ${{ matrix.target }} + echo "Installed targets:" + rustup target list --installed + echo "Rust version:" + rustc --version + + - name: Install cross-compilation tools (Windows on Linux) + if: matrix.target == 'x86_64-pc-windows-gnu' + run: | + sudo apt-get update + sudo apt-get install -y mingw-w64 + + - name: Build native library + run: make ${{ matrix.make-target }} + + - name: Verify build output + run: | + echo "Build completed for ${{ matrix.target }}" + ls -lah libs/ + find libs/ -type f + + - name: Upload native library + uses: actions/upload-artifact@v4 + with: + name: native-${{ matrix.artifact-name }} + path: ${{ env.working-directory }}/libs/ + retention-days: 1 + + # Publish to Maven Central with all native libraries + publish: + needs: build-natives + runs-on: ubuntu-latest + timeout-minutes: 30 + + defaults: + run: + working-directory: ${{ env.working-directory }} + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up JDK + uses: actions/setup-java@v4 + with: + distribution: 'temurin' + java-version: '8' + + - name: Setup Gradle + uses: gradle/actions/setup-gradle@v3 + + - name: Install Rust (for test builds) + uses: dtolnay/rust-toolchain@stable + + - name: Download all native libraries + uses: actions/download-artifact@v4 + with: + pattern: native-* + path: ${{ env.working-directory }}/libs-temp + merge-multiple: true + + - name: Organize native libraries + run: | + # Move downloaded artifacts to libs directory + rm -rf libs + mv libs-temp libs + echo "Native libraries collected:" + ls -R libs/ + + - name: Build test natives + run: make build_test + + - name: Run tests + run: ./gradlew test + + - name: Publish to Maven Central + env: + MAVEN_UPLOAD_USERNAME: ${{ secrets.MAVEN_UPLOAD_USERNAME }} + MAVEN_UPLOAD_PASSWORD: ${{ secrets.MAVEN_UPLOAD_PASSWORD }} + MAVEN_SIGNING_KEY: ${{ secrets.MAVEN_SIGNING_KEY }} + MAVEN_SIGNING_PASSPHRASE: ${{ secrets.MAVEN_SIGNING_PASSPHRASE }} + run: | + echo "Building, signing, and publishing to Maven Central..." + ./gradlew clean publishToMavenCentral --no-daemon --stacktrace + + - name: Upload bundle artifact + if: always() + uses: actions/upload-artifact@v4 + with: + name: maven-central-bundle + path: ${{ env.working-directory }}/build/maven-central/*.zip + retention-days: 7 \ No newline at end of file diff --git a/.github/workflows/perf_nightly.yml b/.github/workflows/perf_nightly.yml new file mode 100644 index 000000000..690c37595 --- /dev/null +++ b/.github/workflows/perf_nightly.yml @@ -0,0 +1,165 @@ +name: Nightly Benchmarks on Nyrkiö Runners (stability) + +on: + workflow_dispatch: + branches: ["main", "notmain", "master"] + schedule: + - cron: '24 4 * * *' + push: + # branches: ["main", "notmain", "master"] + branches: ["notmain"] + pull_request: + # branches: ["main", "notmain", "master"] + branches: ["notmain"] + +env: + CARGO_TERM_COLOR: never + +jobs: + bench: + runs-on: nyrkio_perf_server_4cpu_ubuntu2404 + timeout-minutes: 30 + steps: + - uses: actions/checkout@v3 + - uses: useblacksmith/setup-node@v5 + with: + node-version: 20 + # cache: 'npm' + # - name: Install dependencies + # run: npm install && npm run build + + - name: Bench + run: make bench-exclude-tpc-h 2>&1 | tee output.txt + - name: Analyze benchmark result with Nyrkiö + uses: nyrkio/change-detection@HEAD + with: + name: nightly/turso + tool: criterion + output-file-path: output.txt + + # What to do if a change is immediately detected by Nyrkiö. + # Note that smaller changes are only detected with delay, usually after a change + # persisted over 2-7 commits. Go to nyrkiö.com to view those or configure alerts. + # Note that Nyrkiö will find all changes, also improvements. This means fail-on-alert + # on pull events isn't compatible with this workflow being required to pass branch protection. + fail-on-alert: false + comment-on-alert: true + comment-always: true + # Nyrkiö configuration + # Get yours from https://nyrkio.com/docs/getting-started + nyrkio-token: ${{ secrets.NYRKIO_JWT_TOKEN }} + # HTTP requests will fail for all non-core contributors that don't have their own token. + # Don't want that to spoil the build, so: + never-fail: true + # Make results and change points public, so that any oss contributor can see them + nyrkio-public: true + + # parameters of the algorithm. Note: These are global, so we only set them once and for all. + # Smaller p-value = less change points found. Larger p-value = more, but also more false positives. + nyrkio-settings-pvalue: 0.0001 + # Ignore changes smaller than this. + nyrkio-settings-threshold: 0% + + clickbench: + runs-on: nyrkio_perf_server_4cpu_ubuntu2404 + timeout-minutes: 30 + steps: + - uses: actions/checkout@v3 + - uses: useblacksmith/setup-node@v5 + with: + node-version: 20 + + - name: Clickbench + run: make clickbench + + - name: Analyze TURSO result with Nyrkiö + uses: nyrkio/change-detection@HEAD + with: + name: nightly/clickbench/turso + tool: time + output-file-path: clickbench-tursodb.txt + # What to do if a change is immediately detected by Nyrkiö. + # Note that smaller changes are only detected with delay, usually after a change + # persisted over 2-7 commits. Go to nyrkiö.com to view those or configure alerts. + # Note that Nyrkiö will find all changes, also improvements. This means fail-on-alert + # on pull events isn't compatible with this workflow being required to pass branch protection. + fail-on-alert: false + comment-on-alert: true + comment-always: true + # Nyrkiö configuration + # Get yours from https://nyrkio.com/docs/getting-started + nyrkio-token: ${{ secrets.NYRKIO_JWT_TOKEN }} + # HTTP requests will fail for all non-core contributors that don't have their own token. + # Don't want that to spoil the build, so: + never-fail: true + # Make results and change points public, so that any oss contributor can see them + nyrkio-public: true + + - name: Analyze SQLITE3 result with Nyrkiö + uses: nyrkio/change-detection@HEAD + with: + name: nightly/clickbench/sqlite3 + tool: time + output-file-path: clickbench-sqlite3.txt + fail-on-alert: false + comment-on-alert: true + comment-always: false + nyrkio-token: ${{ secrets.NYRKIO_JWT_TOKEN }} + never-fail: true + nyrkio-public: true + + tpc-h-criterion: + runs-on: nyrkio_perf_server_4cpu_ubuntu2404 + timeout-minutes: 60 + env: + DB_FILE: "perf/tpc-h/TPC-H.db" + steps: + - uses: actions/checkout@v3 + - uses: useblacksmith/rust-cache@v3 + with: + prefix-key: "v1-rust" # can be updated if we need to reset caches due to non-trivial change in the dependencies (for example, custom env var were set for single workspace project) + + - name: Cache TPC-H + id: cache-primes + uses: useblacksmith/cache@v5 + with: + path: ${{ env.DB_FILE }} + key: tpc-h + - name: Download TPC-H + if: steps.cache-primes.outputs.cache-hit != 'true' + env: + DB_URL: "https://github.com/lovasoa/TPCH-sqlite/releases/download/v1.0/TPC-H.db" + run: wget -O $DB_FILE --no-verbose $DB_URL + + - name: Bench + run: cargo bench --bench tpc_h_benchmark 2>&1 | tee output.txt + - name: Analyze benchmark result with Nyrkiö + uses: nyrkio/change-detection@HEAD + with: + name: nightly/tpc-h + tool: criterion + output-file-path: output.txt + + # What to do if a change is immediately detected by Nyrkiö. + # Note that smaller changes are only detected with delay, usually after a change + # persisted over 2-7 commits. Go to nyrkiö.com to view those or configure alerts. + # Note that Nyrkiö will find all changes, also improvements. This means fail-on-alert + # on pull events isn't compatible with this workflow being required to pass branch protection. + fail-on-alert: false + comment-on-alert: true + comment-always: true + # Nyrkiö configuration + # Get yours from https://nyrkio.com/docs/getting-started + nyrkio-token: ${{ secrets.NYRKIO_JWT_TOKEN }} + # HTTP requests will fail for all non-core contributors that don't have their own token. + # Don't want that to spoil the build, so: + never-fail: true + # Make results and change points public, so that any oss contributor can see them + nyrkio-public: true + + # parameters of the algorithm. Note: These are global, so we only set them once and for all. + # Smaller p-value = less change points found. Larger p-value = more, but also more false positives. + nyrkio-settings-pvalue: 0.0001 + # Ignore changes smaller than this. + nyrkio-settings-threshold: 0% + diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 49140d2b8..e059236c1 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -48,7 +48,7 @@ jobs: - name: Test env: RUST_LOG: ${{ runner.debug && 'turso_core::storage=trace' || '' }} - run: cargo test --verbose --features checksum + run: cargo test --verbose --features checksum --test integration_tests timeout-minutes: 20 clippy: diff --git a/.gitignore b/.gitignore index 2e6cf78f8..666b560b0 100644 --- a/.gitignore +++ b/.gitignore @@ -43,6 +43,7 @@ simulator.log **/*.txt profile.json.gz simulator-output/ +tests/*.sql &1 bisected.sql diff --git a/COMPAT.md b/COMPAT.md index d3f651453..e14c04005 100644 --- a/COMPAT.md +++ b/COMPAT.md @@ -95,6 +95,7 @@ Turso aims to be fully compatible with SQLite, with opt-in features not supporte | UPDATE | Yes | | | VACUUM | No | | | WITH clause | Partial | No RECURSIVE, no MATERIALIZED, only SELECT supported in CTEs | +| WINDOW functions | Partial | only default frame definition, no window-specific functions (rank() etc) | #### [PRAGMA](https://www.sqlite.org/pragma.html) @@ -448,8 +449,8 @@ Modifiers: | Eq | Yes | | | Expire | No | | | Explain | No | | -| FkCounter | No | | -| FkIfZero | No | | +| FkCounter | Yes | | +| FkIfZero | Yes | | | Found | Yes | | | Function | Yes | | | Ge | Yes | | diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f0e1b7aa5..1aaa08f32 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -128,6 +128,44 @@ echo -1 | sudo tee /proc/sys/kernel/perf_event_paranoid cargo bench --bench benchmark -- --profile-time=5 ``` +## Debugging bugs + +### Query execution debugging + +Turso aims towards SQLite compatibility. If you find a query that has different behavior than SQLite, the first step is to check what the generated bytecode looks like. + +To do that, first run the `EXPLAIN` command in `sqlite3` shell: + +``` +sqlite> EXPLAIN SELECT first_name FROM users; +addr opcode p1 p2 p3 p4 p5 comment +---- ------------- ---- ---- ---- ------------- -- ------------- +0 Init 0 7 0 0 Start at 7 +1 OpenRead 0 2 0 2 0 root=2 iDb=0; users +2 Rewind 0 6 0 0 +3 Column 0 1 1 0 r[1]= cursor 0 column 1 +4 ResultRow 1 1 0 0 output=r[1] +5 Next 0 3 0 1 +6 Halt 0 0 0 0 +7 Transaction 0 0 1 0 1 usesStmtJournal=0 +8 Goto 0 1 0 0 +``` + +and then run the same command in Turso's shell. + +If the bytecode is different, that's the bug -- work towards fixing code generation. +If the bytecode is the same, but query results are different, then the bug is somewhere in the virtual machine interpreter or storage layer. + +### Stress testing with sanitizers + +If you suspect a multi-threading issue, you can run the stress test with ThreadSanitizer enabled as follows: + +```console +rustup toolchain install nightly +rustup override set nightly +cargo run -Zbuild-std --target x86_64-unknown-linux-gnu -p turso_stress -- --vfs syscall --nr-threads 4 --nr-iterations 1000 +``` + ## Finding things to work on The issue tracker has issues tagged with [good first issue](https://github.com/tursodatabase/limbo/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22), @@ -155,33 +193,6 @@ To produce pull requests like this, you should learn how to use Git's interactiv For a longer discussion on good commits, see Al Tenhundfeld's [What makes a good git commit](https://www.simplethread.com/what-makes-a-good-git-commit/), for example. - -## Debugging query execution - -Turso aims towards SQLite compatibility. If you find a query that has different behavior than SQLite, the first step is to check what the generated bytecode looks like. - -To do that, first run the `EXPLAIN` command in `sqlite3` shell: - -``` -sqlite> EXPLAIN SELECT first_name FROM users; -addr opcode p1 p2 p3 p4 p5 comment ----- ------------- ---- ---- ---- ------------- -- ------------- -0 Init 0 7 0 0 Start at 7 -1 OpenRead 0 2 0 2 0 root=2 iDb=0; users -2 Rewind 0 6 0 0 -3 Column 0 1 1 0 r[1]= cursor 0 column 1 -4 ResultRow 1 1 0 0 output=r[1] -5 Next 0 3 0 1 -6 Halt 0 0 0 0 -7 Transaction 0 0 1 0 1 usesStmtJournal=0 -8 Goto 0 1 0 0 -``` - -and then run the same command in Turso's shell. - -If the bytecode is different, that's the bug -- work towards fixing code generation. -If the bytecode is the same, but query results are different, then the bug is somewhere in the virtual machine interpreter or storage layer. - ## Compatibility tests The `testing/test.all` is a starting point for adding functional tests using a similar syntax to SQLite. diff --git a/Cargo.lock b/Cargo.lock index ed065f587..4583e4553 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -227,6 +227,12 @@ version = "1.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" +[[package]] +name = "arc-swap" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" + [[package]] name = "arrayref" version = "0.3.9" @@ -255,18 +261,104 @@ dependencies = [ "wait-timeout", ] +[[package]] +name = "async-stream" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", +] + +[[package]] +name = "async-trait" +version = "0.1.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", +] + [[package]] name = "atomic" version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c59bdb34bc650a32731b31bd8f0829cc15d24a708ee31559e0bb34f2bc320cba" +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + [[package]] name = "autocfg" version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +[[package]] +name = "axum" +version = "0.7.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" +dependencies = [ + "async-trait", + "axum-core", + "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "sync_wrapper", + "tower 0.5.2", + "tower-layer", + "tower-service", +] + +[[package]] +name = "axum-core" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper", + "tower-layer", + "tower-service", +] + [[package]] name = "backtrace" version = "0.3.74" @@ -291,6 +383,12 @@ dependencies = [ "backtrace", ] +[[package]] +name = "base64" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + [[package]] name = "base64" version = "0.22.1" @@ -431,10 +529,11 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.17" +version = "1.2.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fcb57c740ae1daf453ae85f16e37396f672b039e00d9d866e07ddb24e328e3a" +checksum = "ac9fe6cdbb24b6ade63616c0a0688e45bb56732262c158df3c0c4bea4ca47cb7" dependencies = [ + "find-msvc-tools", "jobserver", "libc", "shlex", @@ -639,6 +738,45 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "console-api" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8030735ecb0d128428b64cd379809817e620a40e5001c54465b99ec5feec2857" +dependencies = [ + "futures-core", + "prost 0.13.5", + "prost-types", + "tonic", + "tracing-core", +] + +[[package]] +name = "console-subscriber" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6539aa9c6a4cd31f4b1c040f860a1eac9aa80e7df6b05d506a6e7179936d6a01" +dependencies = [ + "console-api", + "crossbeam-channel", + "crossbeam-utils", + "futures-task", + "hdrhistogram", + "humantime", + "hyper-util", + "prost 0.13.5", + "prost-types", + "serde", + "serde_json", + "thread_local", + "tokio", + "tokio-stream", + "tonic", + "tracing", + "tracing-core", + "tracing-subscriber", +] + [[package]] name = "console_error_panic_hook" version = "0.1.7" @@ -690,7 +828,7 @@ checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" [[package]] name = "core_tester" -version = "0.2.0" +version = "0.3.0-pre.4" dependencies = [ "anyhow", "assert_cmd", @@ -700,6 +838,7 @@ dependencies = [ "rand 0.9.2", "rand_chacha 0.9.0", "rusqlite", + "sql_generation", "tempfile", "test-log", "tokio", @@ -707,6 +846,7 @@ dependencies = [ "tracing-subscriber", "turso", "turso_core", + "turso_parser", "twox-hash", "zerocopy 0.8.26", ] @@ -1371,6 +1511,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "find-msvc-tools" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52051878f80a721bb68ebfbc930e07b65ba72f2da88968ea5c06fd6ca3d3a127" + [[package]] name = "findshlibs" version = "0.10.2" @@ -1666,6 +1812,25 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" +[[package]] +name = "h2" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3c0b69cfcb4e1b9f1bf2f53f95f766e4661169728ec61cd3fe5a0166f2d1386" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http", + "indexmap 2.11.1", + "slab", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "half" version = "2.5.0" @@ -1700,6 +1865,19 @@ dependencies = [ "hashbrown 0.15.2", ] +[[package]] +name = "hdrhistogram" +version = "7.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "765c9198f173dd59ce26ff9f95ef0aafd0a0fe01fb9d72841bc5066a4c06511d" +dependencies = [ + "base64 0.21.7", + "byteorder", + "flate2", + "nom", + "num-traits", +] + [[package]] name = "heck" version = "0.5.0" @@ -1744,6 +1922,104 @@ dependencies = [ "itoa", ] +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http", +] + +[[package]] +name = "http-body-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "pin-project-lite", +] + +[[package]] +name = "httparse" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" + +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + +[[package]] +name = "humantime" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "135b12329e5e3ce057a9f972339ea52bc954fe1e9358ef27f95e89716fbc5424" + +[[package]] +name = "hyper" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb3aa54a13a0dfe7fbe3a59e0c76093041720fdc77b110cc0fc260fafb4dc51e" +dependencies = [ + "atomic-waker", + "bytes", + "futures-channel", + "futures-core", + "h2", + "http", + "http-body", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "pin-utils", + "smallvec", + "tokio", + "want", +] + +[[package]] +name = "hyper-timeout" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b90d566bffbce6a75bd8b09a05aa8c2cb1fabb6cb348f8840c9e4c90a0d83b0" +dependencies = [ + "hyper", + "hyper-util", + "pin-project-lite", + "tokio", + "tower-service", +] + +[[package]] +name = "hyper-util" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c6995591a8f1380fcb4ba966a252a4b29188d51d2b89e3a252f5305be65aea8" +dependencies = [ + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "http", + "http-body", + "hyper", + "libc", + "pin-project-lite", + "socket2 0.6.0", + "tokio", + "tower-service", + "tracing", +] + [[package]] name = "iana-time-zone" version = "0.1.62" @@ -2140,15 +2416,6 @@ dependencies = [ "serde", ] -[[package]] -name = "julian_day_converter" -version = "0.4.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2987f71b89b85c812c8484cbf0c5d7912589e77bfdc66fd3e52f760e7859f16" -dependencies = [ - "chrono", -] - [[package]] name = "kqueue" version = "1.0.8" @@ -2278,7 +2545,7 @@ dependencies = [ [[package]] name = "limbo_completion" -version = "0.2.0" +version = "0.3.0-pre.4" dependencies = [ "mimalloc", "turso_ext", @@ -2286,7 +2553,7 @@ dependencies = [ [[package]] name = "limbo_crypto" -version = "0.2.0" +version = "0.3.0-pre.4" dependencies = [ "blake3", "data-encoding", @@ -2299,7 +2566,7 @@ dependencies = [ [[package]] name = "limbo_csv" -version = "0.2.0" +version = "0.3.0-pre.4" dependencies = [ "csv", "mimalloc", @@ -2309,7 +2576,7 @@ dependencies = [ [[package]] name = "limbo_fuzzy" -version = "0.2.0" +version = "0.3.0-pre.4" dependencies = [ "mimalloc", "turso_ext", @@ -2317,7 +2584,7 @@ dependencies = [ [[package]] name = "limbo_ipaddr" -version = "0.2.0" +version = "0.3.0-pre.4" dependencies = [ "ipnetwork", "mimalloc", @@ -2326,7 +2593,7 @@ dependencies = [ [[package]] name = "limbo_percentile" -version = "0.2.0" +version = "0.3.0-pre.4" dependencies = [ "mimalloc", "turso_ext", @@ -2334,7 +2601,7 @@ dependencies = [ [[package]] name = "limbo_regexp" -version = "0.2.0" +version = "0.3.0-pre.4" dependencies = [ "mimalloc", "regex", @@ -2343,9 +2610,10 @@ dependencies = [ [[package]] name = "limbo_sim" -version = "0.2.0" +version = "0.3.0-pre.4" dependencies = [ "anyhow", + "bitflags 2.9.4", "bitmaps", "chrono", "clap", @@ -2378,7 +2646,7 @@ dependencies = [ [[package]] name = "limbo_sqlite_test_ext" -version = "0.2.0" +version = "0.3.0-pre.4" dependencies = [ "cc", ] @@ -2458,6 +2726,12 @@ dependencies = [ "regex-automata", ] +[[package]] +name = "matchit" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" + [[package]] name = "md-5" version = "0.10.6" @@ -2547,6 +2821,12 @@ dependencies = [ "libmimalloc-sys", ] +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + [[package]] name = "minimad" version = "0.13.1" @@ -2556,6 +2836,12 @@ dependencies = [ "once_cell", ] +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + [[package]] name = "miniz_oxide" version = "0.8.5" @@ -2671,6 +2957,16 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2bf50223579dc7cdcfb3bfcacf7069ff68243f8c363f62ffa99cf000a6b9c451" +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + [[package]] name = "notify" version = "8.0.0" @@ -2901,6 +3197,26 @@ dependencies = [ "sha2", ] +[[package]] +name = "pin-project" +version = "1.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677f1add503faace112b9f1373e43e9e054bfdd22ff1a63c1bc485eaec6a6a8a" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", +] + [[package]] name = "pin-project-lite" version = "0.2.16" @@ -2925,7 +3241,7 @@ version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eac26e981c03a6e53e0aee43c113e3202f5581d5360dae7bd2c70e800dd0451d" dependencies = [ - "base64", + "base64 0.22.1", "indexmap 2.11.1", "quick-xml 0.32.0", "serde", @@ -3089,6 +3405,16 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "prost" +version = "0.13.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5" +dependencies = [ + "bytes", + "prost-derive 0.13.5", +] + [[package]] name = "prost" version = "0.14.1" @@ -3096,7 +3422,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7231bd9b3d3d33c86b58adbac74b5ec0ad9f496b19d22801d773636feaa95f3d" dependencies = [ "bytes", - "prost-derive", + "prost-derive 0.14.1", +] + +[[package]] +name = "prost-derive" +version = "0.13.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" +dependencies = [ + "anyhow", + "itertools 0.14.0", + "proc-macro2", + "quote", + "syn 2.0.100", ] [[package]] @@ -3112,9 +3451,18 @@ dependencies = [ "syn 2.0.100", ] +[[package]] +name = "prost-types" +version = "0.13.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52c2c1bf36ddb1a1c396b3601a3cec27c2462e45f07c386894ec3ccf5332bd16" +dependencies = [ + "prost 0.13.5", +] + [[package]] name = "py-turso" -version = "0.2.0" +version = "0.3.0-pre.4" dependencies = [ "anyhow", "pyo3", @@ -3799,6 +4147,15 @@ dependencies = [ "similar", ] +[[package]] +name = "simsimd" +version = "6.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e3f209c5a8155b8458b1a0d3a6fc9fa09d201e6086fdaae18e9e283b9274f8f" +dependencies = [ + "cc", +] + [[package]] name = "slab" version = "0.4.9" @@ -3817,6 +4174,16 @@ dependencies = [ "serde", ] +[[package]] +name = "socket2" +version = "0.5.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + [[package]] name = "socket2" version = "0.6.0" @@ -3841,7 +4208,7 @@ checksum = "d372029cb5195f9ab4e4b9aef550787dce78b124fcaee8d82519925defcd6f0d" [[package]] name = "sql_generation" -version = "0.2.0" +version = "0.3.0-pre.4" dependencies = [ "anarchist-readable-name-generator-lib 0.2.0", "anyhow", @@ -3853,6 +4220,7 @@ dependencies = [ "rand_chacha 0.9.0", "schemars 1.0.4", "serde", + "strum", "tracing", "turso_core", "turso_parser", @@ -3982,6 +4350,12 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "sync_wrapper" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" + [[package]] name = "synstructure" version = "0.13.1" @@ -4222,8 +4596,9 @@ dependencies = [ "pin-project-lite", "signal-hook-registry", "slab", - "socket2", + "socket2 0.6.0", "tokio-macros", + "tracing", "windows-sys 0.59.0", ] @@ -4238,6 +4613,30 @@ dependencies = [ "syn 2.0.100", ] +[[package]] +name = "tokio-stream" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tokio-util" +version = "0.7.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14307c986784f72ef81c89db7d9e28d6ac26d16213b109ea501696195e6e3ce5" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + [[package]] name = "toml" version = "0.8.22" @@ -4280,6 +4679,82 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bfb942dfe1d8e29a7ee7fcbde5bd2b9a25fb89aa70caea2eba3bee836ff41076" +[[package]] +name = "tonic" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877c5b330756d856ffcc4553ab34a5684481ade925ecc54bcd1bf02b1d0d4d52" +dependencies = [ + "async-stream", + "async-trait", + "axum", + "base64 0.22.1", + "bytes", + "h2", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-timeout", + "hyper-util", + "percent-encoding", + "pin-project", + "prost 0.13.5", + "socket2 0.5.10", + "tokio", + "tokio-stream", + "tower 0.4.13", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" +dependencies = [ + "futures-core", + "futures-util", + "indexmap 1.9.3", + "pin-project", + "pin-project-lite", + "rand 0.8.5", + "slab", + "tokio", + "tokio-util", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower-layer" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + +[[package]] +name = "tower-service" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" + [[package]] name = "tracing" version = "0.1.41" @@ -4353,21 +4828,29 @@ dependencies = [ "tracing-log", ] +[[package]] +name = "try-lock" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" + [[package]] name = "turso" -version = "0.2.0" +version = "0.3.0-pre.4" dependencies = [ "rand 0.9.2", "rand_chacha 0.9.0", "tempfile", "thiserror 2.0.16", "tokio", + "tracing", + "tracing-subscriber", "turso_core", ] [[package]] name = "turso-java" -version = "0.2.0" +version = "0.3.0-pre.4" dependencies = [ "jni", "thiserror 2.0.16", @@ -4376,7 +4859,7 @@ dependencies = [ [[package]] name = "turso_cli" -version = "0.2.0" +version = "0.3.0-pre.4" dependencies = [ "anyhow", "cfg-if", @@ -4412,12 +4895,13 @@ dependencies = [ [[package]] name = "turso_core" -version = "0.2.0" +version = "0.3.0-pre.4" dependencies = [ "aegis", "aes", "aes-gcm", "antithesis_sdk", + "arc-swap", "bitflags 2.9.4", "built", "bytemuck", @@ -4427,11 +4911,9 @@ dependencies = [ "crossbeam-skiplist", "env_logger 0.11.7", "fallible-iterator", - "getrandom 0.2.15", "hex", "intrusive-collections", "io-uring", - "julian_day_converter", "libc", "libloading", "libm", @@ -4445,15 +4927,17 @@ dependencies = [ "pprof", "quickcheck", "quickcheck_macros", - "rand 0.8.5", + "rand 0.9.2", "rand_chacha 0.9.0", "regex", "regex-syntax", + "roaring", "rstest", "rusqlite", "rustix 1.0.7", "ryu", "serde", + "simsimd", "sorted-vec", "strum", "strum_macros", @@ -4471,7 +4955,7 @@ dependencies = [ [[package]] name = "turso_dart" -version = "0.2.0" +version = "0.3.0-pre.4" dependencies = [ "flutter_rust_bridge", "turso_core", @@ -4479,7 +4963,7 @@ dependencies = [ [[package]] name = "turso_ext" -version = "0.2.0" +version = "0.3.0-pre.4" dependencies = [ "chrono", "getrandom 0.3.2", @@ -4488,7 +4972,7 @@ dependencies = [ [[package]] name = "turso_ext_tests" -version = "0.2.0" +version = "0.3.0-pre.4" dependencies = [ "env_logger 0.11.7", "lazy_static", @@ -4499,7 +4983,7 @@ dependencies = [ [[package]] name = "turso_macros" -version = "0.2.0" +version = "0.3.0-pre.4" dependencies = [ "proc-macro2", "quote", @@ -4508,7 +4992,7 @@ dependencies = [ [[package]] name = "turso_node" -version = "0.2.0" +version = "0.3.0-pre.4" dependencies = [ "chrono", "napi", @@ -4521,7 +5005,7 @@ dependencies = [ [[package]] name = "turso_parser" -version = "0.2.0" +version = "0.3.0-pre.4" dependencies = [ "bitflags 2.9.4", "criterion", @@ -4537,7 +5021,7 @@ dependencies = [ [[package]] name = "turso_sqlite3" -version = "0.2.0" +version = "0.3.0-pre.4" dependencies = [ "env_logger 0.11.7", "libc", @@ -4550,12 +5034,13 @@ dependencies = [ [[package]] name = "turso_stress" -version = "0.2.0" +version = "0.3.0-pre.4" dependencies = [ "anarchist-readable-name-generator-lib 0.1.2", "antithesis_sdk", "clap", "hex", + "rusqlite", "tempfile", "tokio", "tracing", @@ -4566,15 +5051,15 @@ dependencies = [ [[package]] name = "turso_sync_engine" -version = "0.2.0" +version = "0.3.0-pre.4" dependencies = [ - "base64", + "base64 0.22.1", "bytes", "ctor 0.4.2", "futures", "genawaiter", "http", - "prost", + "prost 0.14.1", "rand 0.9.2", "rand_chacha 0.9.0", "roaring", @@ -4593,7 +5078,7 @@ dependencies = [ [[package]] name = "turso_sync_js" -version = "0.2.0" +version = "0.3.0-pre.4" dependencies = [ "genawaiter", "napi", @@ -4608,7 +5093,7 @@ dependencies = [ [[package]] name = "turso_whopper" -version = "0.2.0" +version = "0.3.0-pre.4" dependencies = [ "anyhow", "clap", @@ -4819,6 +5304,15 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "want" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" +dependencies = [ + "try-lock", +] + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -5198,6 +5692,7 @@ name = "write-throughput" version = "0.1.0" dependencies = [ "clap", + "console-subscriber", "futures", "tokio", "tracing-subscriber", diff --git a/Cargo.toml b/Cargo.toml index 51998d10a..4460ca602 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,29 +39,29 @@ exclude = [ ] [workspace.package] -version = "0.2.0" +version = "0.3.0-pre.4" authors = ["the Limbo authors"] edition = "2021" license = "MIT" repository = "https://github.com/tursodatabase/turso" [workspace.dependencies] -turso = { path = "bindings/rust", version = "0.2.0" } -turso_node = { path = "bindings/javascript", version = "0.2.0" } -limbo_completion = { path = "extensions/completion", version = "0.2.0" } -turso_core = { path = "core", version = "0.2.0" } -turso_sync_engine = { path = "sync/engine", version = "0.2.0" } -limbo_crypto = { path = "extensions/crypto", version = "0.2.0" } -limbo_csv = { path = "extensions/csv", version = "0.2.0" } -turso_ext = { path = "extensions/core", version = "0.2.0" } -turso_ext_tests = { path = "extensions/tests", version = "0.2.0" } -limbo_ipaddr = { path = "extensions/ipaddr", version = "0.2.0" } -turso_macros = { path = "macros", version = "0.2.0" } -limbo_percentile = { path = "extensions/percentile", version = "0.2.0" } -limbo_regexp = { path = "extensions/regexp", version = "0.2.0" } -limbo_uuid = { path = "extensions/uuid", version = "0.2.0" } -turso_parser = { path = "parser", version = "0.2.0" } -limbo_fuzzy = { path = "extensions/fuzzy", version = "0.2.0" } +turso = { path = "bindings/rust", version = "0.3.0-pre.4" } +turso_node = { path = "bindings/javascript", version = "0.3.0-pre.4" } +limbo_completion = { path = "extensions/completion", version = "0.3.0-pre.4" } +turso_core = { path = "core", version = "0.3.0-pre.4" } +turso_sync_engine = { path = "sync/engine", version = "0.3.0-pre.4" } +limbo_crypto = { path = "extensions/crypto", version = "0.3.0-pre.4" } +limbo_csv = { path = "extensions/csv", version = "0.3.0-pre.4" } +turso_ext = { path = "extensions/core", version = "0.3.0-pre.4" } +turso_ext_tests = { path = "extensions/tests", version = "0.3.0-pre.4" } +limbo_ipaddr = { path = "extensions/ipaddr", version = "0.3.0-pre.4" } +turso_macros = { path = "macros", version = "0.3.0-pre.4" } +limbo_percentile = { path = "extensions/percentile", version = "0.3.0-pre.4" } +limbo_regexp = { path = "extensions/regexp", version = "0.3.0-pre.4" } +limbo_uuid = { path = "extensions/uuid", version = "0.3.0-pre.4" } +turso_parser = { path = "parser", version = "0.3.0-pre.4" } +limbo_fuzzy = { path = "extensions/fuzzy", version = "0.3.0-pre.4" } sql_generation = { path = "sql_generation" } strum = { version = "0.26", features = ["derive"] } strum_macros = "0.26" @@ -90,7 +90,7 @@ fallible-iterator = "0.3.0" criterion = "0.5" chrono = { version = "0.4.42", default-features = false } hex = "0.4" -antithesis_sdk = "0.2" +antithesis_sdk = { version = "0.2", default-features = false } cfg-if = "1.0.0" tracing-appender = "0.2.3" env_logger = { version = "0.11.6", default-features = false } @@ -99,6 +99,7 @@ regex-syntax = { version = "0.8.5", default-features = false } similar = { version = "2.7.0" } similar-asserts = { version = "1.7.0" } bitmaps = { version = "3.2.1", default-features = false } +console-subscriber = { version = "0.4.1" } [profile.dev.package.similar] opt-level = 3 diff --git a/README.md b/README.md index 00207893f..dffd2421e 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ Turso Database is an in-process SQL database written in Rust, compatible with SQ * **SQLite compatibility** for SQL dialect, file formats, and the C API [see [document](COMPAT.md) for details] * **Change data capture (CDC)** for real-time tracking of database changes. -* **Language support** for +* **Multi-language support** for * [Go](https://github.com/tursodatabase/turso-go) * [JavaScript](bindings/javascript) * [Java](bindings/java) diff --git a/antithesis-tests/stress-composer/parallel_driver_alter_table.py b/antithesis-tests/stress-composer/parallel_driver_alter_table.py index be0551e41..0d151dbbb 100755 --- a/antithesis-tests/stress-composer/parallel_driver_alter_table.py +++ b/antithesis-tests/stress-composer/parallel_driver_alter_table.py @@ -126,6 +126,10 @@ try: con.commit() con_init.commit() +except turso.ProgrammingError as e: + print(f"Table/column might have been dropped in parallel: {e}") + con.rollback() + con_init.rollback() except turso.OperationalError as e: print(f"Failed to alter table: {e}") con.rollback() diff --git a/antithesis-tests/stress-composer/parallel_driver_create_index.py b/antithesis-tests/stress-composer/parallel_driver_create_index.py index e384dcb2d..2391dac4f 100755 --- a/antithesis-tests/stress-composer/parallel_driver_create_index.py +++ b/antithesis-tests/stress-composer/parallel_driver_create_index.py @@ -90,6 +90,9 @@ if create_composite: """) con_init.commit() print(f"Successfully created composite index: {index_name}") + except turso.ProgrammingError as e: + print(f"Table/column might have been dropped in parallel: {e}") + con.rollback() except turso.OperationalError as e: print(f"Failed to create composite index: {e}") con.rollback() @@ -137,6 +140,9 @@ else: """) con_init.commit() print(f"Successfully created {idx_type} index: {index_name}") + except turso.ProgrammingError as e: + print(f"Table/column might have been dropped in parallel: {e}") + con.rollback() except turso.OperationalError as e: print(f"Failed to create index: {e}") con.rollback() diff --git a/antithesis-tests/stress-composer/parallel_driver_create_table.py b/antithesis-tests/stress-composer/parallel_driver_create_table.py index 5d786d2da..7f3d2ba3a 100755 --- a/antithesis-tests/stress-composer/parallel_driver_create_table.py +++ b/antithesis-tests/stress-composer/parallel_driver_create_table.py @@ -49,7 +49,7 @@ print(f"Creating new table: tbl_{next_table_num}") # Define possible data types and constraints data_types = ["INTEGER", "REAL", "TEXT", "BLOB", "NUMERIC"] -constraints = ["", "NOT NULL", "DEFAULT 0", "DEFAULT ''", "UNIQUE", "CHECK (col_0 > 0)"] +constraints = ["", "NOT NULL", "DEFAULT 0", "DEFAULT ''", "UNIQUE"] # Generate random number of columns (2-10) col_count = 2 + (get_random() % 9) diff --git a/antithesis-tests/stress-composer/parallel_driver_delete.py b/antithesis-tests/stress-composer/parallel_driver_delete.py index 951cffd62..7f32e816f 100755 --- a/antithesis-tests/stress-composer/parallel_driver_delete.py +++ b/antithesis-tests/stress-composer/parallel_driver_delete.py @@ -48,6 +48,10 @@ for i in range(deletions): cur.execute(f""" DELETE FROM tbl_{selected_tbl} WHERE {where_clause} """) + except turso.ProgrammingError: + # Table/column might have been dropped in parallel - this is expected + con.rollback() + break except turso.OperationalError: con.rollback() # Re-raise other operational errors diff --git a/antithesis-tests/stress-composer/parallel_driver_drop_index.py b/antithesis-tests/stress-composer/parallel_driver_drop_index.py index 2033b4e5e..031aa0509 100755 --- a/antithesis-tests/stress-composer/parallel_driver_drop_index.py +++ b/antithesis-tests/stress-composer/parallel_driver_drop_index.py @@ -55,11 +55,13 @@ try: con_init.commit() print(f"Successfully dropped index: {index_name}") +except turso.ProgrammingError as e: + print(f"Index {index_name} already dropped in parallel: {e}") + con.rollback() except turso.OperationalError as e: print(f"Failed to drop index: {e}") con.rollback() except Exception as e: - # Handle case where index might not exist in indexes table print(f"Warning: Could not remove index from metadata: {e}") con.commit() diff --git a/antithesis-tests/stress-composer/parallel_driver_drop_table.py b/antithesis-tests/stress-composer/parallel_driver_drop_table.py index d065d7442..dea33ee42 100755 --- a/antithesis-tests/stress-composer/parallel_driver_drop_table.py +++ b/antithesis-tests/stress-composer/parallel_driver_drop_table.py @@ -31,9 +31,14 @@ except Exception as e: cur = con.cursor() -cur.execute(f"DROP TABLE tbl_{selected_tbl}") - -con.commit() +try: + cur.execute(f"DROP TABLE tbl_{selected_tbl}") + con.commit() + print(f"Successfully dropped table tbl_{selected_tbl}") +except turso.ProgrammingError as e: + # Table might have been dropped in parallel - this is expected + print(f"Table tbl_{selected_tbl} already dropped in parallel: {e}") + con.rollback() con.close() diff --git a/antithesis-tests/stress-composer/parallel_driver_insert.py b/antithesis-tests/stress-composer/parallel_driver_insert.py index 707ec58a6..50079094b 100755 --- a/antithesis-tests/stress-composer/parallel_driver_insert.py +++ b/antithesis-tests/stress-composer/parallel_driver_insert.py @@ -46,6 +46,10 @@ for i in range(insertions): INSERT INTO tbl_{selected_tbl} ({cols}) VALUES ({", ".join(values)}) """) + except turso.ProgrammingError: + # Table/column might have been dropped in parallel - this is expected + con.rollback() + break except turso.OperationalError as e: if "UNIQUE constraint failed" in str(e): # Ignore UNIQUE constraint violations diff --git a/antithesis-tests/stress-composer/parallel_driver_rollback.py b/antithesis-tests/stress-composer/parallel_driver_rollback.py index 862da43e1..30435b55b 100755 --- a/antithesis-tests/stress-composer/parallel_driver_rollback.py +++ b/antithesis-tests/stress-composer/parallel_driver_rollback.py @@ -46,6 +46,10 @@ for i in range(insertions): INSERT INTO tbl_{selected_tbl} ({cols}) VALUES ({", ".join(values)}) """) + except turso.ProgrammingError: + # Table/column might have been dropped in parallel - this is expected + con.rollback() + break except turso.OperationalError as e: if "UNIQUE constraint failed" in str(e): # Ignore UNIQUE constraint violations diff --git a/antithesis-tests/stress-composer/parallel_driver_schema_rollback.py b/antithesis-tests/stress-composer/parallel_driver_schema_rollback.py index 62ad68e96..5243ce5b8 100755 --- a/antithesis-tests/stress-composer/parallel_driver_schema_rollback.py +++ b/antithesis-tests/stress-composer/parallel_driver_schema_rollback.py @@ -35,25 +35,36 @@ except Exception as e: exit(0) cur = con.cursor() -cur.execute("SELECT sql FROM sqlite_schema WHERE type = 'table' AND name = '" + tbl_name + "'") -result = cur.fetchone() +try: + cur.execute("SELECT sql FROM sqlite_schema WHERE type = 'table' AND name = '" + tbl_name + "'") -if result is None: - print(f"Table {tbl_name} not found") + result = cur.fetchone() + + if result is None: + print(f"Table {tbl_name} not found") + exit(0) + else: + schema_before = result[0] + + cur.execute("BEGIN TRANSACTION") + + cur.execute("ALTER TABLE " + tbl_name + " RENAME TO " + tbl_name + "_old") + + con.rollback() + + cur = con.cursor() + cur.execute("SELECT sql FROM sqlite_schema WHERE type = 'table' AND name = '" + tbl_name + "'") + + result_after = cur.fetchone() + if result_after is None: + print(f"Table {tbl_name} dropped in parallel after rollback") + exit(0) + + schema_after = result_after[0] + + always(schema_before == schema_after, "schema should be the same after rollback", {}) +except turso.ProgrammingError as e: + print(f"Table {tbl_name} dropped in parallel: {e}") + con.rollback() exit(0) -else: - schema_before = result[0] - -cur.execute("BEGIN TRANSACTION") - -cur.execute("ALTER TABLE " + tbl_name + " RENAME TO " + tbl_name + "_old") - -con.rollback() - -cur = con.cursor() -cur.execute("SELECT sql FROM sqlite_schema WHERE type = 'table' AND name = '" + tbl_name + "'") - -schema_after = cur.fetchone()[0] - -always(schema_before == schema_after, "schema should be the same after rollback", {}) diff --git a/antithesis-tests/stress-composer/parallel_driver_update.py b/antithesis-tests/stress-composer/parallel_driver_update.py index 9288b7287..d4c916bdb 100755 --- a/antithesis-tests/stress-composer/parallel_driver_update.py +++ b/antithesis-tests/stress-composer/parallel_driver_update.py @@ -60,6 +60,10 @@ for i in range(updates): cur.execute(f""" UPDATE tbl_{selected_tbl} SET {set_clause} WHERE {where_clause} """) + except turso.ProgrammingError: + # Table/column might have been dropped in parallel - this is expected + con.rollback() + break except turso.OperationalError as e: if "UNIQUE constraint failed" in str(e): # Ignore UNIQUE constraint violations diff --git a/bindings/java/.gitignore b/bindings/java/.gitignore index f4f8fc542..ea2eb6d0a 100644 --- a/bindings/java/.gitignore +++ b/bindings/java/.gitignore @@ -41,3 +41,4 @@ bin/ ### turso builds ### libs +temp \ No newline at end of file diff --git a/bindings/java/Makefile b/bindings/java/Makefile index 572521784..9251afe83 100644 --- a/bindings/java/Makefile +++ b/bindings/java/Makefile @@ -10,7 +10,7 @@ LINUX_X86_DIR := $(RELEASE_DIR)/linux_x86 .PHONY: libs macos_x86 macos_arm64 windows lint lint_apply test build_test -libs: macos_x86 macos_arm64 windows +libs: macos_x86 macos_arm64 windows linux_x86 macos_x86: @echo "Building release version for macOS x86_64..." diff --git a/bindings/java/build.gradle.kts b/bindings/java/build.gradle.kts index 1df88faf9..4fb7a704e 100644 --- a/bindings/java/build.gradle.kts +++ b/bindings/java/build.gradle.kts @@ -8,27 +8,35 @@ plugins { application `java-library` `maven-publish` + signing id("net.ltgt.errorprone") version "3.1.0" // If you're stuck on JRE 8, use id 'com.diffplug.spotless' version '6.13.0' or older. id("com.diffplug.spotless") version "6.13.0" } -group = properties["projectGroup"]!! -version = properties["projectVersion"]!! +// Apply publishing configuration +apply(from = "gradle/publish.gradle.kts") + +// Helper function to read properties with defaults +fun prop(key: String, default: String? = null): String? = + findProperty(key)?.toString() ?: default + +group = prop("projectGroup") ?: error("projectGroup must be set in gradle.properties") +version = prop("projectVersion") ?: error("projectVersion must be set in gradle.properties") java { sourceCompatibility = JavaVersion.VERSION_1_8 targetCompatibility = JavaVersion.VERSION_1_8 + withJavadocJar() + withSourcesJar() } -publishing { - publications { - create("mavenJava") { - from(components["java"]) - groupId = "tech.turso" - artifactId = "turso" - version = "0.0.1-SNAPSHOT" +// TODO: Add javadoc to required class and methods. After that, let's remove this settings +tasks.withType { + options { + (this as StandardJavadocDocletOptions).apply { + addStringOption("Xdoclint:none", "-quiet") } } } diff --git a/bindings/java/gradle.properties b/bindings/java/gradle.properties index 4b6e7d55e..c2f38979b 100644 --- a/bindings/java/gradle.properties +++ b/bindings/java/gradle.properties @@ -1,2 +1,16 @@ -projectGroup="tech.turso" -projectVersion=0.0.1-SNAPSHOT +projectGroup=tech.turso +projectVersion=0.0.1 +projectArtifactId=turso + +# POM metadata +pomName=Turso JDBC Driver +pomDescription=Turso JDBC driver for Java applications +pomUrl=https://github.com/tursodatabase/turso +pomLicenseName=MIT License +pomLicenseUrl=https://opensource.org/licenses/MIT +pomDeveloperId=turso +pomDeveloperName=Turso +pomDeveloperEmail=penberg@iki.fi +pomScmConnection=scm:git:git://github.com/tursodatabase/turso.git +pomScmDeveloperConnection=scm:git:ssh://github.com:tursodatabase/turso.git +pomScmUrl=https://github.com/tursodatabase/turso \ No newline at end of file diff --git a/bindings/java/gradle/publish.gradle.kts b/bindings/java/gradle/publish.gradle.kts new file mode 100644 index 000000000..97fd1dc41 --- /dev/null +++ b/bindings/java/gradle/publish.gradle.kts @@ -0,0 +1,250 @@ +import java.security.MessageDigest +import org.gradle.api.publish.PublishingExtension +import org.gradle.api.publish.maven.MavenPublication +import org.gradle.plugins.signing.SigningExtension + +// Helper function to read properties with defaults +fun prop(key: String, default: String? = null): String? = + project.findProperty(key)?.toString() ?: default + +// Maven Publishing Configuration +configure { + publications { + create("mavenJava") { + from(components["java"]) + groupId = prop("projectGroup")!! + artifactId = prop("projectArtifactId")!! + version = prop("projectVersion")!! + + pom { + name.set(prop("pomName")) + description.set(prop("pomDescription")) + url.set(prop("pomUrl")) + + licenses { + license { + name.set(prop("pomLicenseName")) + url.set(prop("pomLicenseUrl")) + } + } + + developers { + developer { + id.set(prop("pomDeveloperId")) + name.set(prop("pomDeveloperName")) + email.set(prop("pomDeveloperEmail")) + } + } + + scm { + connection.set(prop("pomScmConnection")) + developerConnection.set(prop("pomScmDeveloperConnection")) + url.set(prop("pomScmUrl")) + } + } + } + } +} + +// GPG Signing Configuration +configure { + // Make signing required for publishing + setRequired(true) + + // For CI/GitHub Actions: use in-memory keys + val signingKey = providers.environmentVariable("MAVEN_SIGNING_KEY").orNull + val signingPassword = providers.environmentVariable("MAVEN_SIGNING_PASSPHRASE").orNull + + if (signingKey != null && signingPassword != null) { + // CI mode: use in-memory keys + useInMemoryPgpKeys(signingKey, signingPassword) + } else { + // Local mode: use GPG command from system + useGpgCmd() + } + + sign(the().publications["mavenJava"]) +} + +// Helper task to generate checksums +val generateChecksums by tasks.registering { + dependsOn("jar", "sourcesJar", "javadocJar", "generatePomFileForMavenJavaPublication") + + val checksumDir = layout.buildDirectory.dir("checksums") + + doLast { + val files = listOf( + tasks.named("jar").get().outputs.files.singleFile, + tasks.named("sourcesJar").get().outputs.files.singleFile, + tasks.named("javadocJar").get().outputs.files.singleFile, + layout.buildDirectory.file("publications/mavenJava/pom-default.xml").get().asFile + ) + + checksumDir.get().asFile.mkdirs() + + files.forEach { file -> + if (file.exists()) { + // MD5 + val md5 = MessageDigest.getInstance("MD5") + .digest(file.readBytes()) + .joinToString("") { "%02x".format(it) } + file("${file.absolutePath}.md5").writeText(md5) + + // SHA1 + val sha1 = MessageDigest.getInstance("SHA-1") + .digest(file.readBytes()) + .joinToString("") { "%02x".format(it) } + file("${file.absolutePath}.sha1").writeText(sha1) + } + } + } +} + +// Task to create a bundle zip for Maven Central Portal +val createMavenCentralBundle by tasks.registering(Zip::class) { + group = "publishing" + description = "Creates a bundle zip for Maven Central Portal upload" + + dependsOn("generatePomFileForMavenJavaPublication", "jar", "sourcesJar", "javadocJar", "signMavenJavaPublication", generateChecksums) + + // Ensure signing happens before bundle creation + mustRunAfter("signMavenJavaPublication") + + val groupId = prop("projectGroup")!!.replace(".", "/") + val artifactId = prop("projectArtifactId")!! + val projectVer = project.version.toString() + + // Validate version is not SNAPSHOT for Maven Central + doFirst { + if (projectVer.contains("SNAPSHOT")) { + throw GradleException( + "Cannot publish SNAPSHOT version to Maven Central. " + + "Please change projectVersion in gradle.properties to a release version (e.g., 0.0.1)" + ) + } + } + + archiveFileName.set("$artifactId-$projectVer-bundle.zip") + destinationDirectory.set(layout.buildDirectory.dir("maven-central")) + + // Maven Central expects files in groupId/artifactId/version/ structure + val basePath = "$groupId/$artifactId/$projectVer" + + // Main JAR + checksums + signature + from(tasks.named("jar").get().outputs.files) { + into(basePath) + rename { "$artifactId-$projectVer.jar" } + } + from(tasks.named("jar").get().outputs.files.singleFile.absolutePath + ".md5") { + into(basePath) + rename { "$artifactId-$projectVer.jar.md5" } + } + from(tasks.named("jar").get().outputs.files.singleFile.absolutePath + ".sha1") { + into(basePath) + rename { "$artifactId-$projectVer.jar.sha1" } + } + + // Sources JAR + checksums + signature + from(tasks.named("sourcesJar").get().outputs.files) { + into(basePath) + rename { "$artifactId-$projectVer-sources.jar" } + } + from(tasks.named("sourcesJar").get().outputs.files.singleFile.absolutePath + ".md5") { + into(basePath) + rename { "$artifactId-$projectVer-sources.jar.md5" } + } + from(tasks.named("sourcesJar").get().outputs.files.singleFile.absolutePath + ".sha1") { + into(basePath) + rename { "$artifactId-$projectVer-sources.jar.sha1" } + } + + // Javadoc JAR + checksums + signature + from(tasks.named("javadocJar").get().outputs.files) { + into(basePath) + rename { "$artifactId-$projectVer-javadoc.jar" } + } + from(tasks.named("javadocJar").get().outputs.files.singleFile.absolutePath + ".md5") { + into(basePath) + rename { "$artifactId-$projectVer-javadoc.jar.md5" } + } + from(tasks.named("javadocJar").get().outputs.files.singleFile.absolutePath + ".sha1") { + into(basePath) + rename { "$artifactId-$projectVer-javadoc.jar.sha1" } + } + + // POM + checksums + signature + from(layout.buildDirectory.file("publications/mavenJava/pom-default.xml")) { + into(basePath) + rename { "$artifactId-$projectVer.pom" } + } + from(layout.buildDirectory.file("publications/mavenJava/pom-default.xml").get().asFile.absolutePath + ".md5") { + into(basePath) + rename { "$artifactId-$projectVer.pom.md5" } + } + from(layout.buildDirectory.file("publications/mavenJava/pom-default.xml").get().asFile.absolutePath + ".sha1") { + into(basePath) + rename { "$artifactId-$projectVer.pom.sha1" } + } + + // Signature files - get them from the signing task outputs + doFirst { + val signingTask = tasks.named("signMavenJavaPublication").get() + logger.lifecycle("Signing task outputs: ${signingTask.outputs.files.files}") + } + + // Include signature files generated by the signing plugin + from(tasks.named("signMavenJavaPublication").get().outputs.files) { + into(basePath) + include("*.jar.asc", "pom-default.xml.asc") + exclude("module.json.asc") // Exclude gradle module metadata signature + rename { name -> + // Only rename the POM signature file + // JAR signatures are already correctly named by the signing plugin + if (name == "pom-default.xml.asc") { + "$artifactId-$projectVer.pom.asc" + } else { + name // Keep original name (already correct) + } + } + } +} + +// Task to upload bundle to Maven Central Portal +tasks.register("publishToMavenCentral") { + group = "publishing" + description = "Publishes artifacts to Maven Central Portal" + + // Run publish first to generate signatures, then create bundle + dependsOn("publish") + dependsOn(createMavenCentralBundle) + + // Make sure bundle creation happens after publish + createMavenCentralBundle.get().mustRunAfter("publish") + + doLast { + val username = providers.environmentVariable("MAVEN_UPLOAD_USERNAME").orNull + val password = providers.environmentVariable("MAVEN_UPLOAD_PASSWORD").orNull + val bundleFile = createMavenCentralBundle.get().archiveFile.get().asFile + + require(username != null) { "MAVEN_UPLOAD_USERNAME environment variable must be set" } + require(password != null) { "MAVEN_UPLOAD_PASSWORD environment variable must be set" } + require(bundleFile.exists()) { "Bundle file does not exist: ${bundleFile.absolutePath}" } + + logger.lifecycle("Uploading bundle to Maven Central Portal...") + logger.lifecycle("Bundle: ${bundleFile.absolutePath}") + logger.lifecycle("Size: ${bundleFile.length() / 1024} KB") + + // Use curl for uploading (simple and available on most systems) + exec { + commandLine( + "curl", + "-X", "POST", + "-u", "$username:$password", + "-F", "bundle=@${bundleFile.absolutePath}", + "https://central.sonatype.com/api/v1/publisher/upload?name=${bundleFile.name}&publishingType=AUTOMATIC" + ) + } + + logger.lifecycle("Upload completed. Check https://central.sonatype.com/publishing for status.") + } +} diff --git a/bindings/java/src/main/java/tech/turso/JDBC.java b/bindings/java/src/main/java/tech/turso/JDBC.java index 9611398d9..904ece6c3 100644 --- a/bindings/java/src/main/java/tech/turso/JDBC.java +++ b/bindings/java/src/main/java/tech/turso/JDBC.java @@ -10,6 +10,7 @@ import tech.turso.jdbc4.JDBC4Connection; import tech.turso.utils.Logger; import tech.turso.utils.LoggerFactory; +/** Turso JDBC driver implementation. */ public final class JDBC implements Driver { private static final Logger logger = LoggerFactory.getLogger(JDBC.class); @@ -24,6 +25,14 @@ public final class JDBC implements Driver { } } + /** + * Creates a new Turso JDBC connection. + * + * @param url the database URL + * @param properties connection properties + * @return a new connection instance, or null if the URL is not valid + * @throws SQLException if a database access error occurs + */ @Nullable public static JDBC4Connection createConnection(String url, Properties properties) throws SQLException { diff --git a/bindings/java/src/main/java/tech/turso/core/TursoDBFactory.java b/bindings/java/src/main/java/tech/turso/core/TursoDBFactory.java index 0076011e4..900c0f681 100644 --- a/bindings/java/src/main/java/tech/turso/core/TursoDBFactory.java +++ b/bindings/java/src/main/java/tech/turso/core/TursoDBFactory.java @@ -23,7 +23,7 @@ public final class TursoDBFactory { * @param url the URL of the database * @param filePath the path to the database file * @param properties additional properties for the database connection - * @return an instance of {@link tursoDB} + * @return an instance of {@link TursoDB} * @throws SQLException if there is an error opening the connection * @throws IllegalArgumentException if the fileName is empty */ diff --git a/bindings/java/src/main/java/tech/turso/core/TursoResultSet.java b/bindings/java/src/main/java/tech/turso/core/TursoResultSet.java index 00d2d6aa0..0c7bcf4c5 100644 --- a/bindings/java/src/main/java/tech/turso/core/TursoResultSet.java +++ b/bindings/java/src/main/java/tech/turso/core/TursoResultSet.java @@ -57,7 +57,7 @@ public final class TursoResultSet { } /** - * Moves the cursor forward one row from its current position. A {@link tursoResultSet} cursor is + * Moves the cursor forward one row from its current position. A {@link TursoResultSet} cursor is * initially positioned before the first fow; the first call to the method next makes * the first row the current row; the second call makes the second row the current row, and so on. * When a call to the next method returns false, the cursor is @@ -65,6 +65,9 @@ public final class TursoResultSet { * *

Note that turso only supports ResultSet.TYPE_FORWARD_ONLY, which means that the * cursor can only move forward. + * + * @return true if the new current row is valid; false if there are no more rows + * @throws SQLException if a database access error occurs */ public boolean next() throws SQLException { if (!open) { diff --git a/bindings/java/src/main/java/tech/turso/core/TursoStatement.java b/bindings/java/src/main/java/tech/turso/core/TursoStatement.java index de4d86e7a..0c15f6586 100644 --- a/bindings/java/src/main/java/tech/turso/core/TursoStatement.java +++ b/bindings/java/src/main/java/tech/turso/core/TursoStatement.java @@ -91,8 +91,8 @@ public final class TursoStatement { private native void _close(long statementPointer); /** - * Initializes the column metadata, such as the names of the columns. Since {@link tursoStatement} - * can only have a single {@link tursoResultSet}, it is appropriate to place the initialization of + * Initializes the column metadata, such as the names of the columns. Since {@link TursoStatement} + * can only have a single {@link TursoResultSet}, it is appropriate to place the initialization of * column metadata here. * * @throws SQLException if a database access error occurs while retrieving column names diff --git a/bindings/java/src/main/java/tech/turso/jdbc4/JDBC4Connection.java b/bindings/java/src/main/java/tech/turso/jdbc4/JDBC4Connection.java index 6841a5cbc..1629a52ce 100644 --- a/bindings/java/src/main/java/tech/turso/jdbc4/JDBC4Connection.java +++ b/bindings/java/src/main/java/tech/turso/jdbc4/JDBC4Connection.java @@ -9,20 +9,43 @@ import tech.turso.annotations.SkipNullableCheck; import tech.turso.core.TursoConnection; import tech.turso.core.TursoStatement; +/** JDBC 4 Connection implementation for Turso databases. */ public final class JDBC4Connection implements Connection { private final TursoConnection connection; private Map> typeMap = new HashMap<>(); + /** + * Creates a new JDBC4 connection. + * + * @param url the database URL + * @param filePath the database file path + * @throws SQLException if a database access error occurs + */ public JDBC4Connection(String url, String filePath) throws SQLException { this.connection = new TursoConnection(url, filePath); } + /** + * Creates a new JDBC4 connection with properties. + * + * @param url the database URL + * @param filePath the database file path + * @param properties connection properties + * @throws SQLException if a database access error occurs + */ public JDBC4Connection(String url, String filePath, Properties properties) throws SQLException { this.connection = new TursoConnection(url, filePath, properties); } + /** + * Prepares a SQL statement for execution. + * + * @param sql the SQL statement to prepare + * @return the prepared statement + * @throws SQLException if a database access error occurs + */ public TursoStatement prepare(String sql) throws SQLException { final TursoStatement statement = connection.prepare(sql); statement.initializeColumnMetadata(); @@ -357,6 +380,11 @@ public final class JDBC4Connection implements Connection { return false; } + /** + * Sets the busy timeout for the connection. + * + * @param busyTimeout the timeout in milliseconds + */ public void setBusyTimeout(int busyTimeout) { // TODO: add support for busy timeout } @@ -367,10 +395,20 @@ public final class JDBC4Connection implements Connection { return 0; } + /** + * Gets the database URL. + * + * @return the database URL + */ public String getUrl() { return this.connection.getUrl(); } + /** + * Checks if the connection is open. + * + * @throws SQLException if the connection is closed + */ public void checkOpen() throws SQLException { connection.checkOpen(); } diff --git a/bindings/java/src/main/java/tech/turso/jdbc4/JDBC4DatabaseMetaData.java b/bindings/java/src/main/java/tech/turso/jdbc4/JDBC4DatabaseMetaData.java index c0137c96d..97f455b0a 100644 --- a/bindings/java/src/main/java/tech/turso/jdbc4/JDBC4DatabaseMetaData.java +++ b/bindings/java/src/main/java/tech/turso/jdbc4/JDBC4DatabaseMetaData.java @@ -13,6 +13,7 @@ import tech.turso.core.TursoPropertiesHolder; import tech.turso.utils.Logger; import tech.turso.utils.LoggerFactory; +/** JDBC 4 DatabaseMetaData implementation for Turso databases. */ public final class JDBC4DatabaseMetaData implements DatabaseMetaData { private static final Logger logger = LoggerFactory.getLogger(JDBC4DatabaseMetaData.class); @@ -51,6 +52,11 @@ public final class JDBC4DatabaseMetaData implements DatabaseMetaData { @Nullable private PreparedStatement getColumnPrivileges = null; + /** + * Creates a new JDBC4DatabaseMetaData instance. + * + * @param connection the database connection + */ public JDBC4DatabaseMetaData(JDBC4Connection connection) { this.connection = connection; } diff --git a/bindings/java/src/main/java/tech/turso/jdbc4/JDBC4PreparedStatement.java b/bindings/java/src/main/java/tech/turso/jdbc4/JDBC4PreparedStatement.java index d9508e0dd..6eba787b5 100644 --- a/bindings/java/src/main/java/tech/turso/jdbc4/JDBC4PreparedStatement.java +++ b/bindings/java/src/main/java/tech/turso/jdbc4/JDBC4PreparedStatement.java @@ -26,11 +26,19 @@ import java.util.Calendar; import tech.turso.annotations.SkipNullableCheck; import tech.turso.core.TursoResultSet; +/** JDBC 4 PreparedStatement implementation for Turso databases. */ public final class JDBC4PreparedStatement extends JDBC4Statement implements PreparedStatement { private final String sql; private final JDBC4ResultSet resultSet; + /** + * Creates a new JDBC4PreparedStatement. + * + * @param connection the database connection + * @param sql the SQL statement to prepare + * @throws SQLException if a database access error occurs + */ public JDBC4PreparedStatement(JDBC4Connection connection, String sql) throws SQLException { super(connection); this.sql = sql; diff --git a/bindings/java/src/main/java/tech/turso/jdbc4/JDBC4ResultSet.java b/bindings/java/src/main/java/tech/turso/jdbc4/JDBC4ResultSet.java index 777a84ffa..8df20ff2c 100644 --- a/bindings/java/src/main/java/tech/turso/jdbc4/JDBC4ResultSet.java +++ b/bindings/java/src/main/java/tech/turso/jdbc4/JDBC4ResultSet.java @@ -29,11 +29,17 @@ import tech.turso.annotations.Nullable; import tech.turso.annotations.SkipNullableCheck; import tech.turso.core.TursoResultSet; +/** JDBC 4 ResultSet implementation for Turso databases. */ public final class JDBC4ResultSet implements ResultSet, ResultSetMetaData { private final TursoResultSet resultSet; private boolean wasNull = false; + /** + * Creates a new JDBC4ResultSet. + * + * @param resultSet the underlying Turso result set + */ public JDBC4ResultSet(TursoResultSet resultSet) { this.resultSet = resultSet; } @@ -1361,8 +1367,19 @@ public final class JDBC4ResultSet implements ResultSet, ResultSetMetaData { - localCal.getTimeZone().getOffset(timeMillis); } + /** + * Functional interface for result set value suppliers. + * + * @param the type of value to supply + */ @FunctionalInterface public interface ResultSetSupplier { + /** + * Gets a result from the result set. + * + * @return the result value + * @throws Exception if an error occurs + */ T get() throws Exception; } diff --git a/bindings/java/src/main/java/tech/turso/jdbc4/JDBC4Statement.java b/bindings/java/src/main/java/tech/turso/jdbc4/JDBC4Statement.java index eb31c8d0b..8f1f50da3 100644 --- a/bindings/java/src/main/java/tech/turso/jdbc4/JDBC4Statement.java +++ b/bindings/java/src/main/java/tech/turso/jdbc4/JDBC4Statement.java @@ -17,6 +17,7 @@ import tech.turso.annotations.SkipNullableCheck; import tech.turso.core.TursoResultSet; import tech.turso.core.TursoStatement; +/** JDBC 4 Statement implementation for Turso databases. */ public class JDBC4Statement implements Statement { private static final Pattern BATCH_COMPATIBLE_PATTERN = @@ -35,7 +36,10 @@ public class JDBC4Statement implements Statement { private final JDBC4Connection connection; + /** The underlying Turso statement. */ @Nullable protected TursoStatement statement = null; + + /** The number of rows affected by the last update operation. */ protected long updateCount; // Because JDBC4Statement has different life cycle in compared to tursoStatement, let's use this @@ -475,8 +479,19 @@ public class JDBC4Statement implements Statement { } } + /** + * Functional interface for SQL callable operations. + * + * @param the return type + */ @FunctionalInterface protected interface SQLCallable { + /** + * Executes the SQL operation. + * + * @return the result of the operation + * @throws SQLException if a database access error occurs + */ T call() throws SQLException; } diff --git a/bindings/java/src/main/java/tech/turso/utils/ByteArrayUtils.java b/bindings/java/src/main/java/tech/turso/utils/ByteArrayUtils.java index 4922984f0..bd366b5ee 100644 --- a/bindings/java/src/main/java/tech/turso/utils/ByteArrayUtils.java +++ b/bindings/java/src/main/java/tech/turso/utils/ByteArrayUtils.java @@ -3,7 +3,14 @@ package tech.turso.utils; import java.nio.charset.StandardCharsets; import tech.turso.annotations.Nullable; +/** Utility class for converting between byte arrays and strings using UTF-8 encoding. */ public final class ByteArrayUtils { + /** + * Converts a UTF-8 encoded byte array to a string. + * + * @param buffer the byte array to convert, may be null + * @return the string representation, or null if the input is null + */ @Nullable public static String utf8ByteBufferToString(@Nullable byte[] buffer) { if (buffer == null) { @@ -13,6 +20,12 @@ public final class ByteArrayUtils { return new String(buffer, StandardCharsets.UTF_8); } + /** + * Converts a string to a UTF-8 encoded byte array. + * + * @param str the string to convert, may be null + * @return the byte array representation, or null if the input is null + */ @Nullable public static byte[] stringToUtf8ByteArray(@Nullable String str) { if (str == null) { diff --git a/bindings/javascript/package-lock.json b/bindings/javascript/package-lock.json index 34da4b8a3..56e6574bd 100644 --- a/bindings/javascript/package-lock.json +++ b/bindings/javascript/package-lock.json @@ -1,11 +1,11 @@ { "name": "javascript", - "version": "0.2.0", + "version": "0.3.0-pre.4", "lockfileVersion": 3, "requires": true, "packages": { "": { - "version": "0.2.0", + "version": "0.3.0-pre.4", "workspaces": [ "packages/common", "packages/wasm-common", @@ -3542,7 +3542,7 @@ }, "packages/common": { "name": "@tursodatabase/database-common", - "version": "0.2.0", + "version": "0.3.0-pre.4", "license": "MIT", "devDependencies": { "typescript": "^5.9.2", @@ -3551,10 +3551,10 @@ }, "packages/native": { "name": "@tursodatabase/database", - "version": "0.2.0", + "version": "0.3.0-pre.4", "license": "MIT", "dependencies": { - "@tursodatabase/database-common": "^0.2.0" + "@tursodatabase/database-common": "^0.3.0-pre.4" }, "devDependencies": { "@napi-rs/cli": "^3.1.5", @@ -3568,11 +3568,11 @@ }, "packages/wasm": { "name": "@tursodatabase/database-wasm", - "version": "0.2.0", + "version": "0.3.0-pre.4", "license": "MIT", "dependencies": { - "@tursodatabase/database-common": "^0.2.0", - "@tursodatabase/database-wasm-common": "^0.2.0" + "@tursodatabase/database-common": "^0.3.0-pre.4", + "@tursodatabase/database-wasm-common": "^0.3.0-pre.4" }, "devDependencies": { "@napi-rs/cli": "^3.1.5", @@ -3585,7 +3585,7 @@ }, "packages/wasm-common": { "name": "@tursodatabase/database-wasm-common", - "version": "0.2.0", + "version": "0.3.0-pre.4", "license": "MIT", "dependencies": { "@napi-rs/wasm-runtime": "^1.0.5" @@ -3596,10 +3596,10 @@ }, "sync/packages/common": { "name": "@tursodatabase/sync-common", - "version": "0.2.0", + "version": "0.3.0-pre.4", "license": "MIT", "dependencies": { - "@tursodatabase/database-common": "^0.2.0" + "@tursodatabase/database-common": "^0.3.0-pre.4" }, "devDependencies": { "typescript": "^5.9.2" @@ -3607,11 +3607,11 @@ }, "sync/packages/native": { "name": "@tursodatabase/sync", - "version": "0.2.0", + "version": "0.3.0-pre.4", "license": "MIT", "dependencies": { - "@tursodatabase/database-common": "^0.2.0", - "@tursodatabase/sync-common": "^0.2.0" + "@tursodatabase/database-common": "^0.3.0-pre.4", + "@tursodatabase/sync-common": "^0.3.0-pre.4" }, "devDependencies": { "@napi-rs/cli": "^3.1.5", @@ -3622,12 +3622,12 @@ }, "sync/packages/wasm": { "name": "@tursodatabase/sync-wasm", - "version": "0.2.0", + "version": "0.3.0-pre.4", "license": "MIT", "dependencies": { - "@tursodatabase/database-common": "^0.2.0", - "@tursodatabase/database-wasm-common": "^0.2.0", - "@tursodatabase/sync-common": "^0.2.0" + "@tursodatabase/database-common": "^0.3.0-pre.4", + "@tursodatabase/database-wasm-common": "^0.3.0-pre.4", + "@tursodatabase/sync-common": "^0.3.0-pre.4" }, "devDependencies": { "@napi-rs/cli": "^3.1.5", diff --git a/bindings/javascript/package.json b/bindings/javascript/package.json index 7625d4ca2..12f5d43f4 100644 --- a/bindings/javascript/package.json +++ b/bindings/javascript/package.json @@ -14,5 +14,5 @@ "sync/packages/native", "sync/packages/wasm" ], - "version": "0.2.0" + "version": "0.3.0-pre.4" } diff --git a/bindings/javascript/packages/common/package.json b/bindings/javascript/packages/common/package.json index 9412bdda0..72acbc0dd 100644 --- a/bindings/javascript/packages/common/package.json +++ b/bindings/javascript/packages/common/package.json @@ -1,6 +1,6 @@ { "name": "@tursodatabase/database-common", - "version": "0.2.0", + "version": "0.3.0-pre.4", "repository": { "type": "git", "url": "https://github.com/tursodatabase/turso" diff --git a/bindings/javascript/packages/native/package.json b/bindings/javascript/packages/native/package.json index f04125972..052d1c1ac 100644 --- a/bindings/javascript/packages/native/package.json +++ b/bindings/javascript/packages/native/package.json @@ -1,6 +1,6 @@ { "name": "@tursodatabase/database", - "version": "0.2.0", + "version": "0.3.0-pre.4", "repository": { "type": "git", "url": "https://github.com/tursodatabase/turso" @@ -47,7 +47,7 @@ ] }, "dependencies": { - "@tursodatabase/database-common": "^0.2.0" + "@tursodatabase/database-common": "^0.3.0-pre.4" }, "imports": { "#index": "./index.js" diff --git a/bindings/javascript/packages/wasm-common/package.json b/bindings/javascript/packages/wasm-common/package.json index c4f9d3bf4..23b0801ed 100644 --- a/bindings/javascript/packages/wasm-common/package.json +++ b/bindings/javascript/packages/wasm-common/package.json @@ -1,6 +1,6 @@ { "name": "@tursodatabase/database-wasm-common", - "version": "0.2.0", + "version": "0.3.0-pre.4", "repository": { "type": "git", "url": "https://github.com/tursodatabase/turso" diff --git a/bindings/javascript/packages/wasm/package.json b/bindings/javascript/packages/wasm/package.json index eb7e3a542..a474952e4 100644 --- a/bindings/javascript/packages/wasm/package.json +++ b/bindings/javascript/packages/wasm/package.json @@ -1,6 +1,6 @@ { "name": "@tursodatabase/database-wasm", - "version": "0.2.0", + "version": "0.3.0-pre.4", "repository": { "type": "git", "url": "https://github.com/tursodatabase/turso" @@ -51,7 +51,7 @@ ] }, "dependencies": { - "@tursodatabase/database-common": "^0.2.0", - "@tursodatabase/database-wasm-common": "^0.2.0" + "@tursodatabase/database-common": "^0.3.0-pre.4", + "@tursodatabase/database-wasm-common": "^0.3.0-pre.4" } } diff --git a/bindings/javascript/packages/wasm/promise.test.ts b/bindings/javascript/packages/wasm/promise.test.ts index 77176d9f9..7cdd8bc1b 100644 --- a/bindings/javascript/packages/wasm/promise.test.ts +++ b/bindings/javascript/packages/wasm/promise.test.ts @@ -1,6 +1,110 @@ import { expect, test } from 'vitest' import { connect, Database } from './promise-default.js' +test('vector-test', async () => { + const db = await connect(":memory:"); + const v1 = new Array(1024).fill(0).map((_, i) => i); + const v2 = new Array(1024).fill(0).map((_, i) => 1024 - i); + const result = await db.prepare(`SELECT + vector_distance_cos(vector32('${JSON.stringify(v1)}'), vector32('${JSON.stringify(v2)}')) as cosf32, + vector_distance_cos(vector64('${JSON.stringify(v1)}'), vector64('${JSON.stringify(v2)}')) as cosf64, + vector_distance_l2(vector32('${JSON.stringify(v1)}'), vector32('${JSON.stringify(v2)}')) as l2f32, + vector_distance_l2(vector64('${JSON.stringify(v1)}'), vector64('${JSON.stringify(v2)}')) as l2f64 + `).all(); + console.info(result); +}) + +test('explain', async () => { + const db = await connect(":memory:"); + const stmt = db.prepare("EXPLAIN SELECT 1"); + expect(stmt.columns()).toEqual([ + { + "name": "addr", + "type": "INTEGER", + }, + { + "name": "opcode", + "type": "TEXT", + }, + { + "name": "p1", + "type": "INTEGER", + }, + { + "name": "p2", + "type": "INTEGER", + }, + { + "name": "p3", + "type": "INTEGER", + }, + { + "name": "p4", + "type": "INTEGER", + }, + { + "name": "p5", + "type": "INTEGER", + }, + { + "name": "comment", + "type": "TEXT", + }, + ].map(x => ({ ...x, column: null, database: null, table: null }))); + expect(await stmt.all()).toEqual([ + { + "addr": 0, + "comment": "Start at 3", + "opcode": "Init", + "p1": 0, + "p2": 3, + "p3": 0, + "p4": "", + "p5": 0, + }, + { + "addr": 1, + "comment": "output=r[1]", + "opcode": "ResultRow", + "p1": 1, + "p2": 1, + "p3": 0, + "p4": "", + "p5": 0, + }, + { + "addr": 2, + "comment": "", + "opcode": "Halt", + "p1": 0, + "p2": 0, + "p3": 0, + "p4": "", + "p5": 0, + }, + { + "addr": 3, + "comment": "r[1]=1", + "opcode": "Integer", + "p1": 1, + "p2": 1, + "p3": 0, + "p4": "", + "p5": 0, + }, + { + "addr": 4, + "comment": "", + "opcode": "Goto", + "p1": 0, + "p2": 1, + "p3": 0, + "p4": "", + "p5": 0, + }, + ]); +}) + test('in-memory db', async () => { const db = await connect(":memory:"); await db.exec("CREATE TABLE t(x)"); @@ -10,6 +114,7 @@ test('in-memory db', async () => { expect(rows).toEqual([{ x: 1 }, { x: 3 }]); }) + test('implicit connect', async () => { const db = new Database(':memory:'); const defer = db.prepare("SELECT * FROM t"); diff --git a/bindings/javascript/src/lib.rs b/bindings/javascript/src/lib.rs index dcaaefc5b..d92661269 100644 --- a/bindings/javascript/src/lib.rs +++ b/bindings/javascript/src/lib.rs @@ -178,7 +178,7 @@ fn connect_sync(db: &DatabaseInner) -> napi::Result<()> { let io = &db.io; let file = io .open_file(&db.path, flags, false) - .map_err(|e| to_generic_error("failed to open file", e))?; + .map_err(|e| to_generic_error(&format!("failed to open file {}", db.path), e))?; let db_file = DatabaseFile::new(file); let db_core = turso_core::Database::open_with_flags( @@ -191,7 +191,7 @@ fn connect_sync(db: &DatabaseInner) -> napi::Result<()> { .with_indexes(true), None, ) - .map_err(|e| to_generic_error("failed to open database", e))?; + .map_err(|e| to_generic_error(&format!("failed to open database {}", db.path), e))?; let conn = db_core .connect() diff --git a/bindings/javascript/sync/packages/common/package.json b/bindings/javascript/sync/packages/common/package.json index 72962d9f8..0f83707d6 100644 --- a/bindings/javascript/sync/packages/common/package.json +++ b/bindings/javascript/sync/packages/common/package.json @@ -1,6 +1,6 @@ { "name": "@tursodatabase/sync-common", - "version": "0.2.0", + "version": "0.3.0-pre.4", "repository": { "type": "git", "url": "https://github.com/tursodatabase/turso" @@ -23,6 +23,6 @@ "test": "echo 'no tests'" }, "dependencies": { - "@tursodatabase/database-common": "^0.2.0" + "@tursodatabase/database-common": "^0.3.0-pre.4" } } diff --git a/bindings/javascript/sync/packages/native/package.json b/bindings/javascript/sync/packages/native/package.json index be78f1452..d8b20c163 100644 --- a/bindings/javascript/sync/packages/native/package.json +++ b/bindings/javascript/sync/packages/native/package.json @@ -1,6 +1,6 @@ { "name": "@tursodatabase/sync", - "version": "0.2.0", + "version": "0.3.0-pre.4", "repository": { "type": "git", "url": "https://github.com/tursodatabase/turso" @@ -44,8 +44,8 @@ ] }, "dependencies": { - "@tursodatabase/database-common": "^0.2.0", - "@tursodatabase/sync-common": "^0.2.0" + "@tursodatabase/database-common": "^0.3.0-pre.4", + "@tursodatabase/sync-common": "^0.3.0-pre.4" }, "imports": { "#index": "./index.js" diff --git a/bindings/javascript/sync/packages/wasm/package.json b/bindings/javascript/sync/packages/wasm/package.json index 88d8ffd04..c4bea4c94 100644 --- a/bindings/javascript/sync/packages/wasm/package.json +++ b/bindings/javascript/sync/packages/wasm/package.json @@ -1,6 +1,6 @@ { "name": "@tursodatabase/sync-wasm", - "version": "0.2.0", + "version": "0.3.0-pre.4", "repository": { "type": "git", "url": "https://github.com/tursodatabase/turso" @@ -54,8 +54,8 @@ "#index": "./index.js" }, "dependencies": { - "@tursodatabase/database-common": "^0.2.0", - "@tursodatabase/database-wasm-common": "^0.2.0", - "@tursodatabase/sync-common": "^0.2.0" + "@tursodatabase/database-common": "^0.3.0-pre.4", + "@tursodatabase/database-wasm-common": "^0.3.0-pre.4", + "@tursodatabase/sync-common": "^0.3.0-pre.4" } } diff --git a/bindings/javascript/yarn.lock b/bindings/javascript/yarn.lock index 7dcf99767..7b33e6081 100644 --- a/bindings/javascript/yarn.lock +++ b/bindings/javascript/yarn.lock @@ -1092,13 +1092,13 @@ __metadata: linkType: hard "@napi-rs/wasm-runtime@npm:^1.0.1": - version: 1.0.6 - resolution: "@napi-rs/wasm-runtime@npm:1.0.6" + version: 1.0.7 + resolution: "@napi-rs/wasm-runtime@npm:1.0.7" dependencies: "@emnapi/core": "npm:^1.5.0" "@emnapi/runtime": "npm:^1.5.0" "@tybys/wasm-util": "npm:^0.10.1" - checksum: 10c0/af48168c6e13c970498fda3ce7238234a906bc69dd474dc9abd560cdf8a7dea6410147afec8f0191a1d19767c8347d8ec0125a8a93225312f7ac37e06e8c15ad + checksum: 10c0/2d8635498136abb49d6dbf7395b78c63422292240963bf055f307b77aeafbde57ae2c0ceaaef215601531b36d6eb92a2cdd6f5ba90ed2aa8127c27aff9c4ae55 languageName: node linkType: hard @@ -1586,7 +1586,7 @@ __metadata: languageName: node linkType: hard -"@tursodatabase/database-common@npm:^0.2.0, @tursodatabase/database-common@workspace:packages/common": +"@tursodatabase/database-common@npm:^0.3.0-pre.4, @tursodatabase/database-common@workspace:packages/common": version: 0.0.0-use.local resolution: "@tursodatabase/database-common@workspace:packages/common" dependencies: @@ -1595,7 +1595,7 @@ __metadata: languageName: unknown linkType: soft -"@tursodatabase/database-wasm-common@npm:^0.2.0, @tursodatabase/database-wasm-common@workspace:packages/wasm-common": +"@tursodatabase/database-wasm-common@npm:^0.3.0-pre.4, @tursodatabase/database-wasm-common@workspace:packages/wasm-common": version: 0.0.0-use.local resolution: "@tursodatabase/database-wasm-common@workspace:packages/wasm-common" dependencies: @@ -1609,8 +1609,8 @@ __metadata: resolution: "@tursodatabase/database-wasm@workspace:packages/wasm" dependencies: "@napi-rs/cli": "npm:^3.1.5" - "@tursodatabase/database-common": "npm:^0.2.0" - "@tursodatabase/database-wasm-common": "npm:^0.2.0" + "@tursodatabase/database-common": "npm:^0.3.0-pre.4" + "@tursodatabase/database-wasm-common": "npm:^0.3.0-pre.4" "@vitest/browser": "npm:^3.2.4" playwright: "npm:^1.55.0" typescript: "npm:^5.9.2" @@ -1624,7 +1624,7 @@ __metadata: resolution: "@tursodatabase/database@workspace:packages/native" dependencies: "@napi-rs/cli": "npm:^3.1.5" - "@tursodatabase/database-common": "npm:^0.2.0" + "@tursodatabase/database-common": "npm:^0.3.0-pre.4" "@types/node": "npm:^24.3.1" better-sqlite3: "npm:^12.2.0" drizzle-kit: "npm:^0.31.4" @@ -1634,11 +1634,11 @@ __metadata: languageName: unknown linkType: soft -"@tursodatabase/sync-common@npm:^0.2.0, @tursodatabase/sync-common@workspace:sync/packages/common": +"@tursodatabase/sync-common@npm:^0.3.0-pre.4, @tursodatabase/sync-common@workspace:sync/packages/common": version: 0.0.0-use.local resolution: "@tursodatabase/sync-common@workspace:sync/packages/common" dependencies: - "@tursodatabase/database-common": "npm:^0.2.0" + "@tursodatabase/database-common": "npm:^0.3.0-pre.4" typescript: "npm:^5.9.2" languageName: unknown linkType: soft @@ -1648,9 +1648,9 @@ __metadata: resolution: "@tursodatabase/sync-wasm@workspace:sync/packages/wasm" dependencies: "@napi-rs/cli": "npm:^3.1.5" - "@tursodatabase/database-common": "npm:^0.2.0" - "@tursodatabase/database-wasm-common": "npm:^0.2.0" - "@tursodatabase/sync-common": "npm:^0.2.0" + "@tursodatabase/database-common": "npm:^0.3.0-pre.4" + "@tursodatabase/database-wasm-common": "npm:^0.3.0-pre.4" + "@tursodatabase/sync-common": "npm:^0.3.0-pre.4" "@vitest/browser": "npm:^3.2.4" playwright: "npm:^1.55.0" typescript: "npm:^5.9.2" @@ -1664,8 +1664,8 @@ __metadata: resolution: "@tursodatabase/sync@workspace:sync/packages/native" dependencies: "@napi-rs/cli": "npm:^3.1.5" - "@tursodatabase/database-common": "npm:^0.2.0" - "@tursodatabase/sync-common": "npm:^0.2.0" + "@tursodatabase/database-common": "npm:^0.3.0-pre.4" + "@tursodatabase/sync-common": "npm:^0.3.0-pre.4" "@types/node": "npm:^24.3.1" typescript: "npm:^5.9.2" vitest: "npm:^3.2.4" @@ -2573,9 +2573,9 @@ __metadata: linkType: hard "exponential-backoff@npm:^3.1.1": - version: 3.1.2 - resolution: "exponential-backoff@npm:3.1.2" - checksum: 10c0/d9d3e1eafa21b78464297df91f1776f7fbaa3d5e3f7f0995648ca5b89c069d17055033817348d9f4a43d1c20b0eab84f75af6991751e839df53e4dfd6f22e844 + version: 3.1.3 + resolution: "exponential-backoff@npm:3.1.3" + checksum: 10c0/77e3ae682b7b1f4972f563c6dbcd2b0d54ac679e62d5d32f3e5085feba20483cf28bd505543f520e287a56d4d55a28d7874299941faf637e779a1aa5994d1267 languageName: node linkType: hard @@ -3089,8 +3089,8 @@ __metadata: linkType: hard "node-gyp@npm:latest": - version: 11.4.2 - resolution: "node-gyp@npm:11.4.2" + version: 11.5.0 + resolution: "node-gyp@npm:11.5.0" dependencies: env-paths: "npm:^2.2.0" exponential-backoff: "npm:^3.1.1" @@ -3104,7 +3104,7 @@ __metadata: which: "npm:^5.0.0" bin: node-gyp: bin/node-gyp.js - checksum: 10c0/0bfd3e96770ed70f07798d881dd37b4267708966d868a0e585986baac487d9cf5831285579fd629a83dc4e434f53e6416ce301097f2ee464cb74d377e4d8bdbe + checksum: 10c0/31ff49586991b38287bb15c3d529dd689cfc32f992eed9e6997b9d712d5d21fe818a8b1bbfe3b76a7e33765c20210c5713212f4aa329306a615b87d8a786da3a languageName: node linkType: hard diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index 2ffc62f8a..0a9f99e84 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -18,7 +18,7 @@ tracing_release = ["turso_core/tracing_release"] [dependencies] anyhow = "1.0" -turso_core = { path = "../../core", features = ["io_uring"] } +turso_core = { workspace = true, features = ["io_uring"] } pyo3 = { version = "0.24.1", features = ["anyhow"] } [build-dependencies] diff --git a/bindings/rust/Cargo.toml b/bindings/rust/Cargo.toml index d799b5320..42bce00cd 100644 --- a/bindings/rust/Cargo.toml +++ b/bindings/rust/Cargo.toml @@ -19,6 +19,8 @@ tracing_release = ["turso_core/tracing_release"] [dependencies] turso_core = { workspace = true, features = ["io_uring"] } thiserror = { workspace = true } +tracing-subscriber.workspace = true +tracing.workspace = true [dev-dependencies] tempfile = { workspace = true } diff --git a/bindings/rust/README.md b/bindings/rust/README.md index a8c1ce4c4..3d25c0556 100644 --- a/bindings/rust/README.md +++ b/bindings/rust/README.md @@ -4,7 +4,7 @@ The next evolution of SQLite: A high-performance, SQLite-compatible database lib ## Features -- **SQLite Compatible**: Drop-in replacement for rusqlite with familiar API +- **SQLite Compatible**: Similar interface to rusqlite with familiar API apart from using async Rust - **High Performance**: Built with Rust for maximum speed and efficiency - **Async/Await Support**: Native async operations with tokio support - **In-Process**: No network overhead, runs directly in your application @@ -18,7 +18,7 @@ Add this to your `Cargo.toml`: ```toml [dependencies] -turso = "0.1" +turso = "0.2" tokio = { version = "1.0", features = ["full"] } ``` diff --git a/bindings/rust/src/lib.rs b/bindings/rust/src/lib.rs index be95b4333..8cd47d599 100644 --- a/bindings/rust/src/lib.rs +++ b/bindings/rust/src/lib.rs @@ -46,12 +46,18 @@ pub use params::params_from_iter; pub use params::IntoParams; use std::fmt::Debug; +use std::future::Future; use std::num::NonZero; +use std::sync::atomic::AtomicU8; +use std::sync::atomic::Ordering; use std::sync::{Arc, Mutex}; +use std::task::Poll; pub use turso_core::EncryptionOpts; use turso_core::OpenFlags; + // Re-exports rows pub use crate::rows::{Row, Rows}; +use crate::transaction::DropBehavior; #[derive(Debug, thiserror::Error)] pub enum Error { @@ -147,14 +153,14 @@ impl Builder { match vfs_choice { "memory" => Ok(Arc::new(turso_core::MemoryIO::new())), "syscall" => { - #[cfg(target_family = "unix")] + #[cfg(all(target_family = "unix", not(miri)))] { Ok(Arc::new( turso_core::UnixIO::new() .map_err(|e| Error::SqlExecutionFailure(e.to_string()))?, )) } - #[cfg(not(target_family = "unix"))] + #[cfg(any(not(target_family = "unix"), miri))] { Ok(Arc::new( turso_core::PlatformIO::new() @@ -162,12 +168,12 @@ impl Builder { )) } } - #[cfg(target_os = "linux")] + #[cfg(all(target_os = "linux", not(miri)))] "io_uring" => Ok(Arc::new( turso_core::UringIO::new() .map_err(|e| Error::SqlExecutionFailure(e.to_string()))?, )), - #[cfg(not(target_os = "linux"))] + #[cfg(any(not(target_os = "linux"), miri))] "io_uring" => Err(Error::SqlExecutionFailure( "io_uring is only available on Linux targets".to_string(), )), @@ -215,10 +221,39 @@ impl Database { } } +/// Atomic wrapper for [DropBehavior] +struct AtomicDropBehavior { + inner: AtomicU8, +} + +impl AtomicDropBehavior { + fn new(behavior: DropBehavior) -> Self { + Self { + inner: AtomicU8::new(behavior.into()), + } + } + + fn load(&self, ordering: Ordering) -> DropBehavior { + self.inner.load(ordering).into() + } + + fn store(&self, behavior: DropBehavior, ordering: Ordering) { + self.inner.store(behavior.into(), ordering); + } +} + /// A database connection. pub struct Connection { inner: Arc>>, transaction_behavior: TransactionBehavior, + /// If there is a dangling transaction after it was dropped without being finished, + /// [Connection::dangling_tx] will be set to the [DropBehavior] of the dangling transaction, + /// and the corresponding action will be taken when a new transaction is requested + /// or the connection queries/executes. + /// We cannot do this eagerly on Drop because drop is not async. + /// + /// By default, the value is [DropBehavior::Ignore] which effectively does nothing. + dangling_tx: AtomicDropBehavior, } impl Clone for Connection { @@ -226,6 +261,7 @@ impl Clone for Connection { Self { inner: Arc::clone(&self.inner), transaction_behavior: self.transaction_behavior, + dangling_tx: AtomicDropBehavior::new(self.dangling_tx.load(Ordering::SeqCst)), } } } @@ -239,17 +275,43 @@ impl Connection { let connection = Connection { inner: Arc::new(Mutex::new(conn)), transaction_behavior: TransactionBehavior::Deferred, + dangling_tx: AtomicDropBehavior::new(DropBehavior::Ignore), }; connection } + + async fn maybe_handle_dangling_tx(&self) -> Result<()> { + match self.dangling_tx.load(Ordering::SeqCst) { + DropBehavior::Rollback => { + let mut stmt = self.prepare("ROLLBACK").await?; + stmt.execute(()).await?; + self.dangling_tx + .store(DropBehavior::Ignore, Ordering::SeqCst); + } + DropBehavior::Commit => { + let mut stmt = self.prepare("COMMIT").await?; + stmt.execute(()).await?; + self.dangling_tx + .store(DropBehavior::Ignore, Ordering::SeqCst); + } + DropBehavior::Ignore => {} + DropBehavior::Panic => { + panic!("Transaction dropped unexpectedly."); + } + } + Ok(()) + } + /// Query the database with SQL. pub async fn query(&self, sql: &str, params: impl IntoParams) -> Result { + self.maybe_handle_dangling_tx().await?; let mut stmt = self.prepare(sql).await?; stmt.query(params).await } /// Execute SQL statement on the database. pub async fn execute(&self, sql: &str, params: impl IntoParams) -> Result { + self.maybe_handle_dangling_tx().await?; let mut stmt = self.prepare(sql).await?; stmt.execute(params).await } @@ -334,6 +396,7 @@ impl Connection { /// Execute a batch of SQL statements on the database. pub async fn execute_batch(&self, sql: &str) -> Result<()> { + self.maybe_handle_dangling_tx().await?; self.prepare_execute_batch(sql).await?; Ok(()) } @@ -355,6 +418,7 @@ impl Connection { } async fn prepare_execute_batch(&self, sql: impl AsRef) -> Result<()> { + self.maybe_handle_dangling_tx().await?; let conn = self .inner .lock() @@ -464,6 +528,45 @@ impl Clone for Statement { unsafe impl Send for Statement {} unsafe impl Sync for Statement {} +struct Execute { + stmt: Arc>, +} + +unsafe impl Send for Execute {} +unsafe impl Sync for Execute {} + +impl Future for Execute { + type Output = Result; + + fn poll( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + let mut stmt = self.stmt.lock().unwrap(); + match stmt.step_with_waker(cx.waker()) { + Ok(turso_core::StepResult::Row) => Poll::Ready(Err(Error::SqlExecutionFailure( + "unexpected row during execution".to_string(), + ))), + Ok(turso_core::StepResult::Done) => { + let changes = stmt.n_change(); + assert!(changes >= 0); + Poll::Ready(Ok(changes as u64)) + } + Ok(turso_core::StepResult::IO) => { + stmt.run_once()?; + Poll::Pending + } + Ok(turso_core::StepResult::Busy) => Poll::Ready(Err(Error::SqlExecutionFailure( + "database is locked".to_string(), + ))), + Ok(turso_core::StepResult::Interrupt) => { + Poll::Ready(Err(Error::SqlExecutionFailure("interrupted".to_string()))) + } + Err(err) => Poll::Ready(Err(err.into())), + } + } +} + impl Statement { /// Query the database with this prepared statement. pub async fn query(&mut self, params: impl IntoParams) -> Result { @@ -514,33 +617,11 @@ impl Statement { } } } - loop { - let mut stmt = self.inner.lock().unwrap(); - match stmt.step() { - Ok(turso_core::StepResult::Row) => { - return Err(Error::SqlExecutionFailure( - "unexpected row during execution".to_string(), - )); - } - Ok(turso_core::StepResult::Done) => { - let changes = stmt.n_change(); - assert!(changes >= 0); - return Ok(changes as u64); - } - Ok(turso_core::StepResult::IO) => { - stmt.run_once()?; - } - Ok(turso_core::StepResult::Busy) => { - return Err(Error::SqlExecutionFailure("database is locked".to_string())); - } - Ok(turso_core::StepResult::Interrupt) => { - return Err(Error::SqlExecutionFailure("interrupted".to_string())); - } - Err(err) => { - return Err(err.into()); - } - } - } + + let execute = Execute { + stmt: self.inner.clone(), + }; + execute.await } /// Returns columns of the result of this prepared statement. @@ -576,7 +657,11 @@ impl Statement { pub async fn query_row(&mut self, params: impl IntoParams) -> Result { let mut rows = self.query(params).await?; - rows.next().await?.ok_or(Error::QueryReturnedNoRows) + let first_row = rows.next().await?.ok_or(Error::QueryReturnedNoRows)?; + // Discard remaining rows so that the statement is executed to completion + // Otherwise Drop of the statement will cause transaction rollback + while rows.next().await?.is_some() {} + Ok(first_row) } } diff --git a/bindings/rust/src/rows.rs b/bindings/rust/src/rows.rs index 9102baaf3..d13193edc 100644 --- a/bindings/rust/src/rows.rs +++ b/bindings/rust/src/rows.rs @@ -2,7 +2,9 @@ use turso_core::types::FromValue; use crate::{Error, Result, Value}; use std::fmt::Debug; +use std::future::Future; use std::sync::{Arc, Mutex}; +use std::task::Poll; /// Results of a prepared statement query. pub struct Rows { @@ -28,33 +30,50 @@ impl Rows { } /// Fetch the next row of this result set. pub async fn next(&mut self) -> Result> { - loop { - let mut stmt = self - .inner - .lock() - .map_err(|e| Error::MutexError(e.to_string()))?; - match stmt.step()? { - turso_core::StepResult::Row => { - let row = stmt.row().unwrap(); - return Ok(Some(Row { - values: row.get_values().map(|v| v.to_owned()).collect(), - })); - } - turso_core::StepResult::Done => return Ok(None), - turso_core::StepResult::IO => { - if let Err(e) = stmt.run_once() { - return Err(e.into()); + struct Next { + stmt: Arc>, + } + + impl Future for Next { + type Output = Result>; + + fn poll( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + let mut stmt = self + .stmt + .lock() + .map_err(|e| Error::MutexError(e.to_string()))?; + match stmt.step_with_waker(cx.waker())? { + turso_core::StepResult::Row => { + let row = stmt.row().unwrap(); + Poll::Ready(Ok(Some(Row { + values: row.get_values().map(|v| v.to_owned()).collect(), + }))) + } + turso_core::StepResult::Done => Poll::Ready(Ok(None)), + turso_core::StepResult::IO => { + stmt.run_once()?; + Poll::Pending + } + turso_core::StepResult::Busy => Poll::Ready(Err(Error::SqlExecutionFailure( + "database is locked".to_string(), + ))), + turso_core::StepResult::Interrupt => { + Poll::Ready(Err(Error::SqlExecutionFailure("interrupted".to_string()))) } - continue; - } - turso_core::StepResult::Busy => { - return Err(Error::SqlExecutionFailure("database is locked".to_string())) - } - turso_core::StepResult::Interrupt => { - return Err(Error::SqlExecutionFailure("interrupted".to_string())) } } } + + unsafe impl Send for Next {} + + let next = Next { + stmt: self.inner.clone(), + }; + + next.await } } diff --git a/bindings/rust/src/transaction.rs b/bindings/rust/src/transaction.rs index b68cc1fab..dd77ed0d8 100644 --- a/bindings/rust/src/transaction.rs +++ b/bindings/rust/src/transaction.rs @@ -1,4 +1,4 @@ -use std::ops::Deref; +use std::{ops::Deref, sync::atomic::Ordering}; use crate::{Connection, Result}; @@ -36,13 +36,36 @@ pub enum DropBehavior { Panic, } +impl From for u8 { + fn from(behavior: DropBehavior) -> Self { + match behavior { + DropBehavior::Rollback => 0, + DropBehavior::Commit => 1, + DropBehavior::Ignore => 2, + DropBehavior::Panic => 3, + } + } +} + +impl From for DropBehavior { + fn from(value: u8) -> Self { + match value { + 0 => DropBehavior::Rollback, + 1 => DropBehavior::Commit, + 2 => DropBehavior::Ignore, + 3 => DropBehavior::Panic, + _ => panic!("Invalid drop behavior: {value}"), + } + } +} + /// Represents a transaction on a database connection. /// /// ## Note /// /// Transactions will roll back by default. Use `commit` method to explicitly /// commit the transaction, or use `set_drop_behavior` to change what happens -/// when the transaction is dropped. +/// on the next access to the connection after the transaction is dropped. /// /// ## Example /// @@ -63,7 +86,7 @@ pub enum DropBehavior { pub struct Transaction<'conn> { conn: &'conn Connection, drop_behavior: DropBehavior, - must_finish: bool, + in_progress: bool, } impl Transaction<'_> { @@ -99,7 +122,7 @@ impl Transaction<'_> { conn.execute(query, ()).await.map(move |_| Transaction { conn, drop_behavior: DropBehavior::Rollback, - must_finish: true, + in_progress: true, }) } @@ -126,8 +149,8 @@ impl Transaction<'_> { #[inline] async fn _commit(&mut self) -> Result<()> { - self.must_finish = false; self.conn.execute("COMMIT", ()).await?; + self.in_progress = false; Ok(()) } @@ -139,8 +162,8 @@ impl Transaction<'_> { #[inline] async fn _rollback(&mut self) -> Result<()> { - self.must_finish = false; self.conn.execute("ROLLBACK", ()).await?; + self.in_progress = false; Ok(()) } @@ -186,8 +209,14 @@ impl Deref for Transaction<'_> { impl Drop for Transaction<'_> { #[inline] fn drop(&mut self) { - if self.must_finish { - panic!("Transaction dropped without finish()") + if self.in_progress { + self.conn + .dangling_tx + .store(self.drop_behavior(), Ordering::SeqCst); + } else { + self.conn + .dangling_tx + .store(DropBehavior::Ignore, Ordering::SeqCst); } } } @@ -195,7 +224,8 @@ impl Drop for Transaction<'_> { impl Connection { /// Begin a new transaction with the default behavior (DEFERRED). /// - /// The transaction defaults to rolling back when it is dropped. If you + /// The transaction defaults to rolling back on the next access to the connection + /// if it is not finished when the transaction is dropped. If you /// want the transaction to commit, you must call /// [`commit`](Transaction::commit) or /// [`set_drop_behavior(DropBehavior::Commit)`](Transaction::set_drop_behavior). @@ -221,7 +251,8 @@ impl Connection { /// Will return `Err` if the call fails. #[inline] pub async fn transaction(&mut self) -> Result> { - Transaction::new(self, self.transaction_behavior).await + self.transaction_with_behavior(self.transaction_behavior) + .await } /// Begin a new transaction with a specified behavior. @@ -236,6 +267,7 @@ impl Connection { &mut self, behavior: TransactionBehavior, ) -> Result> { + self.maybe_handle_dangling_tx().await?; Transaction::new(self, behavior).await } @@ -318,29 +350,81 @@ mod test { } #[tokio::test] - #[should_panic(expected = "Transaction dropped without finish()")] - async fn test_drop_panic() { + async fn test_drop_rollback_on_new_transaction() { let mut conn = checked_memory_handle().await.unwrap(); { let tx = conn.transaction().await.unwrap(); tx.execute("INSERT INTO foo VALUES(?)", &[1]).await.unwrap(); + // Drop without finish - should be rolled back when next transaction starts } + + // Start a new transaction - this should rollback the dangling one + let tx = conn.transaction().await.unwrap(); + tx.execute("INSERT INTO foo VALUES(?)", &[2]).await.unwrap(); + let result = tx + .prepare("SELECT SUM(x) FROM foo") + .await + .unwrap() + .query_row(()) + .await + .unwrap(); + + // The insert from the dropped transaction should have been rolled back + assert_eq!(2, result.get::(0).unwrap()); + tx.finish().await.unwrap(); + } + + #[tokio::test] + async fn test_drop_rollback_on_query() { + let mut conn = checked_memory_handle().await.unwrap(); + { + let tx = conn.transaction().await.unwrap(); + tx.execute("INSERT INTO foo VALUES(?)", &[1]).await.unwrap(); + // Drop without finish - should be rolled back when conn.query is called + } + + // Using conn.query should rollback the dangling transaction + let mut rows = conn.query("SELECT count(*) FROM foo", ()).await.unwrap(); + let result = rows.next().await.unwrap().unwrap(); + + // The insert from the dropped transaction should have been rolled back + assert_eq!(0, result.get::(0).unwrap()); + } + + #[tokio::test] + async fn test_drop_rollback_on_execute() { + let mut conn = checked_memory_handle().await.unwrap(); + { + let tx = conn.transaction().await.unwrap(); + tx.execute("INSERT INTO foo VALUES(?)", &[1]).await.unwrap(); + // Drop without finish - should be rolled back when conn.execute is called + } + + // Using conn.execute should rollback the dangling transaction + conn.execute("INSERT INTO foo VALUES(?)", &[2]) + .await + .unwrap(); + + let mut rows = conn.query("SELECT count(*) FROM foo", ()).await.unwrap(); + let result = rows.next().await.unwrap().unwrap(); + + // The insert from the dropped transaction should have been rolled back + assert_eq!(1, result.get::(0).unwrap()); } #[tokio::test] async fn test_drop() -> Result<()> { + let _ = tracing_subscriber::fmt::try_init(); let mut conn = checked_memory_handle().await?; { let tx = conn.transaction().await?; tx.execute("INSERT INTO foo VALUES(?)", &[1]).await?; - tx.finish().await?; // default: rollback } { let mut tx = conn.transaction().await?; tx.execute("INSERT INTO foo VALUES(?)", &[2]).await?; tx.set_drop_behavior(DropBehavior::Commit); - tx.finish().await?; } { let tx = conn.transaction().await?; @@ -351,7 +435,6 @@ mod test { .await?; assert_eq!(2, result.get::(0)?); - tx.finish().await?; } Ok(()) } diff --git a/bindings/rust/tests/integration_tests.rs b/bindings/rust/tests/integration_tests.rs index 19514532d..c0b25244d 100644 --- a/bindings/rust/tests/integration_tests.rs +++ b/bindings/rust/tests/integration_tests.rs @@ -402,7 +402,8 @@ async fn test_concurrent_unique_constraint_regression() { match result { Ok(_) => (), Err(Error::SqlExecutionFailure(e)) - if e.contains("UNIQUE constraint failed") => {} + if e.contains("UNIQUE constraint failed") + | e.contains("database is locked") => {} Err(e) => { panic!("Error executing statement: {e:?}"); } diff --git a/cli/Cargo.toml b/cli/Cargo.toml index c1c60f928..35f691042 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -28,8 +28,8 @@ ctrlc = "3.4.4" dirs = "5.0.1" env_logger = { workspace = true } libc = "0.2.172" -turso_core = { path = "../core", default-features = true, features = ["cli_only"] } -limbo_completion = { path = "../extensions/completion", features = ["static"] } +turso_core = { workspace = true , default-features = true, features = ["cli_only"] } +limbo_completion = { workspace = true, features = ["static"] } miette = { workspace = true, features = ["fancy"] } nu-ansi-term = {version = "0.50.1", features = ["serde", "derive_serde_style"]} rustyline = { version = "15.0.0", default-features = true, features = [ diff --git a/cli/app.rs b/cli/app.rs index efc2f312f..923d280a3 100644 --- a/cli/app.rs +++ b/cli/app.rs @@ -106,44 +106,65 @@ macro_rules! row_step_result_query { return Ok(()); } - let start = Instant::now(); + let start = if $stats.is_some() { + Some(Instant::now()) + } else { + None + }; match $rows.step() { Ok(StepResult::Row) => { if let Some(ref mut stats) = $stats { - stats.execute_time_elapsed_samples.push(start.elapsed()); + stats + .execute_time_elapsed_samples + .push(start.unwrap().elapsed()); } $row_handle } Ok(StepResult::IO) => { - let start = Instant::now(); + if let Some(ref mut stats) = $stats { + stats.io_time_elapsed_samples.push(start.unwrap().elapsed()); + } + let start = if $stats.is_some() { + Some(Instant::now()) + } else { + None + }; $rows.run_once()?; if let Some(ref mut stats) = $stats { - stats.io_time_elapsed_samples.push(start.elapsed()); + stats.io_time_elapsed_samples.push(start.unwrap().elapsed()); } } Ok(StepResult::Interrupt) => { if let Some(ref mut stats) = $stats { - stats.execute_time_elapsed_samples.push(start.elapsed()); + stats + .execute_time_elapsed_samples + .push(start.unwrap().elapsed()); } break; } Ok(StepResult::Done) => { if let Some(ref mut stats) = $stats { - stats.execute_time_elapsed_samples.push(start.elapsed()); + stats + .execute_time_elapsed_samples + .push(start.unwrap().elapsed()); } break; } Ok(StepResult::Busy) => { if let Some(ref mut stats) = $stats { - stats.execute_time_elapsed_samples.push(start.elapsed()); + stats + .execute_time_elapsed_samples + .push(start.unwrap().elapsed()); } let _ = $app.writeln("database is busy"); break; } Err(err) => { if let Some(ref mut stats) = $stats { - stats.execute_time_elapsed_samples.push(start.elapsed()); + stats + .execute_time_elapsed_samples + .push(start.unwrap().elapsed()); } let report = miette::Error::from(err).with_source_code($sql.to_owned()); let _ = $app.writeln_fmt(format_args!("{report:?}")); @@ -1239,26 +1260,33 @@ impl Limbo { } fn display_indexes(&mut self, maybe_table: Option) -> anyhow::Result<()> { - let sql = match maybe_table { - Some(ref tbl_name) => format!( - "SELECT name FROM sqlite_schema WHERE type='index' AND tbl_name = '{tbl_name}' ORDER BY 1" - ), - None => String::from("SELECT name FROM sqlite_schema WHERE type='index' ORDER BY 1"), - }; - let mut indexes = String::new(); - let handler = |row: &turso_core::Row| -> anyhow::Result<()> { - if let Ok(Value::Text(idx)) = row.get::<&Value>(0) { - indexes.push_str(idx.as_str()); - indexes.push(' '); - } - Ok(()) - }; - if let Err(err) = self.handle_row(&sql, handler) { - if err.to_string().contains("no such table: sqlite_schema") { - return Err(anyhow::anyhow!("Unable to access database schema. The database may be using an older SQLite version or may not be properly initialized.")); - } else { - return Err(anyhow::anyhow!("Error querying schema: {}", err)); + + for name in self.database_names()? { + let prefix = (name != "main").then_some(&name); + let sql = match maybe_table { + Some(ref tbl_name) => format!( + "SELECT name FROM {name}.sqlite_schema WHERE type='index' AND tbl_name = '{tbl_name}' ORDER BY 1" + ), + None => format!("SELECT name FROM {name}.sqlite_schema WHERE type='index' ORDER BY 1"), + }; + let handler = |row: &turso_core::Row| -> anyhow::Result<()> { + if let Ok(Value::Text(idx)) = row.get::<&Value>(0) { + if let Some(prefix) = prefix { + indexes.push_str(prefix); + indexes.push('.'); + } + indexes.push_str(idx.as_str()); + indexes.push(' '); + } + Ok(()) + }; + if let Err(err) = self.handle_row(&sql, handler) { + if err.to_string().contains("no such table: sqlite_schema") { + return Err(anyhow::anyhow!("Unable to access database schema. The database may be using an older SQLite version or may not be properly initialized.")); + } else { + return Err(anyhow::anyhow!("Error querying schema: {}", err)); + } } } if !indexes.is_empty() { @@ -1268,28 +1296,35 @@ impl Limbo { } fn display_tables(&mut self, pattern: Option<&str>) -> anyhow::Result<()> { - let sql = match pattern { - Some(pattern) => format!( - "SELECT name FROM sqlite_schema WHERE type='table' AND name NOT LIKE 'sqlite_%' AND name LIKE '{pattern}' ORDER BY 1" - ), - None => String::from( - "SELECT name FROM sqlite_schema WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY 1" - ), - }; - let mut tables = String::new(); - let handler = |row: &turso_core::Row| -> anyhow::Result<()> { - if let Ok(Value::Text(table)) = row.get::<&Value>(0) { - tables.push_str(table.as_str()); - tables.push(' '); - } - Ok(()) - }; - if let Err(e) = self.handle_row(&sql, handler) { - if e.to_string().contains("no such table: sqlite_schema") { - return Err(anyhow::anyhow!("Unable to access database schema. The database may be using an older SQLite version or may not be properly initialized.")); - } else { - return Err(anyhow::anyhow!("Error querying schema: {}", e)); + + for name in self.database_names()? { + let prefix = (name != "main").then_some(&name); + let sql = match pattern { + Some(pattern) => format!( + "SELECT name FROM {name}.sqlite_schema WHERE type='table' AND name NOT LIKE 'sqlite_%' AND name LIKE '{pattern}' ORDER BY 1" + ), + None => format!( + "SELECT name FROM {name}.sqlite_schema WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY 1" + ), + }; + let handler = |row: &turso_core::Row| -> anyhow::Result<()> { + if let Ok(Value::Text(table)) = row.get::<&Value>(0) { + if let Some(prefix) = prefix { + tables.push_str(prefix); + tables.push('.'); + } + tables.push_str(table.as_str()); + tables.push(' '); + } + Ok(()) + }; + if let Err(e) = self.handle_row(&sql, handler) { + if e.to_string().contains("no such table: sqlite_schema") { + return Err(anyhow::anyhow!("Unable to access database schema. The database may be using an older SQLite version or may not be properly initialized.")); + } else { + return Err(anyhow::anyhow!("Error querying schema: {}", e)); + } } } if !tables.is_empty() { @@ -1304,6 +1339,21 @@ impl Limbo { Ok(()) } + fn database_names(&mut self) -> anyhow::Result> { + let sql = "PRAGMA database_list"; + let mut db_names: Vec = Vec::new(); + let handler = |row: &turso_core::Row| -> anyhow::Result<()> { + if let Ok(Value::Text(name)) = row.get::<&Value>(1) { + db_names.push(name.to_string()); + } + Ok(()) + }; + match self.handle_row(sql, handler) { + Ok(_) => Ok(db_names), + Err(e) => Err(anyhow::anyhow!("Error in database list: {}", e)), + } + } + fn handle_row(&mut self, sql: &str, mut handler: F) -> anyhow::Result<()> where F: FnMut(&turso_core::Row) -> anyhow::Result<()>, @@ -1453,6 +1503,10 @@ impl Limbo { StepResult::Row => { let row = rows.row().unwrap(); let name: &str = row.get::<&str>(0)?; + // Skip sqlite_sequence table + if name == "sqlite_sequence" { + continue; + } let ddl: &str = row.get::<&str>(1)?; writeln!(out, "{ddl};")?; Self::dump_table_from_conn(&conn, out, name, &mut progress)?; @@ -1567,7 +1621,6 @@ impl Limbo { if !has_seq { return Ok(()); } - writeln!(out, "DELETE FROM sqlite_sequence;")?; if let Some(mut rows) = conn.query("SELECT name, seq FROM sqlite_sequence")? { loop { diff --git a/cli/main.rs b/cli/main.rs index 5b5f98e2f..a9a4eeb89 100644 --- a/cli/main.rs +++ b/cli/main.rs @@ -16,7 +16,7 @@ use std::{ sync::{atomic::Ordering, LazyLock}, }; -#[cfg(not(target_family = "wasm"))] +#[cfg(all(not(target_family = "wasm"), not(miri)))] #[global_allocator] static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; diff --git a/cli/manual.rs b/cli/manual.rs index b6a459696..b1266a38e 100644 --- a/cli/manual.rs +++ b/cli/manual.rs @@ -1,7 +1,15 @@ use include_dir::{include_dir, Dir}; use rand::seq::SliceRandom; -use std::io::{IsTerminal, Write}; -use termimad::MadSkin; +use std::io::{stdout, IsTerminal, Write}; + +use termimad::{ + crossterm::{ + event::{read, Event, KeyCode}, + queue, + terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen}, + }, + Area, MadSkin, MadView, +}; static MANUAL_DIR: Dir = include_dir!("$CARGO_MANIFEST_DIR/manuals"); @@ -63,34 +71,89 @@ fn strip_frontmatter(content: &str) -> &str { } } +// not ideal but enough for our usecase , probably overkill maybe. +fn levenshtein(a: &str, b: &str) -> usize { + let a_chars: Vec<_> = a.chars().collect(); + let b_chars: Vec<_> = b.chars().collect(); + let (a_len, b_len) = (a_chars.len(), b_chars.len()); + if a_len == 0 { + return b_len; + } + if b_len == 0 { + return a_len; + } + let mut prev_row: Vec = (0..=b_len).collect(); + let mut current_row = vec![0; b_len + 1]; + for i in 1..=a_len { + current_row[0] = i; + for j in 1..=b_len { + let substitution_cost = if a_chars[i - 1] == b_chars[j - 1] { + 0 + } else { + 1 + }; + current_row[j] = (prev_row[j] + 1) + .min(current_row[j - 1] + 1) + .min(prev_row[j - 1] + substitution_cost); + } + prev_row.clone_from_slice(¤t_row); + } + prev_row[b_len] +} + +fn find_closest_manual_page<'a>( + page_name: &str, + available_pages: impl Iterator, +) -> Option<&'a str> { + const RELATIVE_SIMILARITY_THRESHOLD: f64 = 0.4; + + available_pages + .filter_map(|candidate| { + let distance = levenshtein(page_name, candidate); + let longer_len = std::cmp::max(page_name.chars().count(), candidate.chars().count()); + if longer_len == 0 { + return None; + } + let relative_distance = distance as f64 / longer_len as f64; + + if relative_distance < RELATIVE_SIMILARITY_THRESHOLD { + Some((candidate, distance)) + } else { + None + } + }) + .min_by_key(|&(_, score)| score) + .map(|(name, _)| name) +} + pub fn display_manual(page: Option<&str>, writer: &mut dyn Write) -> anyhow::Result<()> { let page_name = page.unwrap_or("index"); let file_name = format!("{page_name}.md"); - // Try to find the manual page - let content = if let Some(file) = MANUAL_DIR.get_file(&file_name) { - file.contents_utf8() - .ok_or_else(|| anyhow::anyhow!("Failed to read manual page: {}", page_name))? + if let Some(file) = MANUAL_DIR.get_file(&file_name) { + let content = file + .contents_utf8() + .ok_or_else(|| anyhow::anyhow!("Failed to read manual page: {}", page_name))?; + let content = strip_frontmatter(content); + if IsTerminal::is_terminal(&std::io::stdout()) { + render_in_terminal(content)?; + } else { + writeln!(writer, "{content}")?; + } + Ok(()) } else if page.is_none() { // If no page specified, list available pages return list_available_manuals(writer); } else { - return Err(anyhow::anyhow!("Manual page not found: {}", page_name)); - }; - - // Strip frontmatter before displaying - let content = strip_frontmatter(content); - - // Check if we're in a terminal or piped output - if IsTerminal::is_terminal(&std::io::stdout()) { - // Use termimad for nice terminal rendering - render_in_terminal(content)?; - } else { - // Plain output for pipes/redirects - writeln!(writer, "{content}")?; + let available_pages = MANUAL_DIR + .files() + .filter_map(|file| file.path().file_stem().and_then(|stem| stem.to_str())); + let mut error_message = format!("Manual page not found: {page_name}"); + if let Some(suggestion) = find_closest_manual_page(page_name, available_pages) { + error_message.push_str(&format!("\n\nDid you mean '.manual {suggestion}'?")); + } + Err(anyhow::anyhow!(error_message)) } - - Ok(()) } fn render_in_terminal(content: &str) -> anyhow::Result<()> { @@ -107,8 +170,41 @@ fn render_in_terminal(content: &str) -> anyhow::Result<()> { skin.code_block .set_fg(termimad::crossterm::style::Color::Green); - // Just print the formatted content - skin.print_text(content); + let mut w = stdout(); + queue!(w, EnterAlternateScreen)?; + enable_raw_mode()?; + + let area = Area::full_screen(); + let mut view = MadView::from(content.to_string(), area, skin); + + loop { + view.write_on(&mut w)?; + w.flush()?; + + match read()? { + Event::Key(key) => match key.code { + KeyCode::Up | KeyCode::Char('k') => view.try_scroll_lines(-1), + KeyCode::Down | KeyCode::Char('j') => view.try_scroll_lines(1), + KeyCode::PageUp => view.try_scroll_pages(-1), + KeyCode::PageDown => view.try_scroll_pages(1), + KeyCode::Char('g') => view.scroll = 0, + KeyCode::Char('G') => view.try_scroll_lines(i32::MAX), + + KeyCode::Esc | KeyCode::Char('q') | KeyCode::Enter => break, + + _ => {} + }, + Event::Resize(width, height) => { + let new_area = Area::new(0, 0, width, height); + view.resize(&new_area); + } + _ => {} + } + } + + disable_raw_mode()?; + queue!(w, LeaveAlternateScreen)?; + w.flush()?; Ok(()) } @@ -116,27 +212,19 @@ fn render_in_terminal(content: &str) -> anyhow::Result<()> { fn list_available_manuals(writer: &mut dyn Write) -> anyhow::Result<()> { writeln!(writer, "Available manual pages:")?; writeln!(writer)?; - - let mut pages: Vec = Vec::new(); - - for file in MANUAL_DIR.files() { - if let Some(name) = file.path().file_stem() { - if let Some(name_str) = name.to_str() { - pages.push(name_str.to_string()); - } - } - } - + let mut pages: Vec = MANUAL_DIR + .files() + .filter_map(|file| file.path().file_stem()?.to_str().map(String::from)) + .collect(); pages.sort(); - for page in pages { + for page in &pages { writeln!(writer, " .manual {page} # or .man {page}")?; } - if MANUAL_DIR.files().count() == 0 { + if pages.is_empty() { writeln!(writer, " (No manual pages found)")?; } - writeln!(writer)?; writeln!(writer, "Usage: .manual or .man ")?; diff --git a/core/Cargo.toml b/core/Cargo.toml index 7fa6afb78..b913b62f2 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -15,7 +15,7 @@ path = "lib.rs" [features] default = ["fs", "uuid", "time", "json", "series", "encryption"] -antithesis = ["dep:antithesis_sdk"] +antithesis = ["dep:antithesis_sdk", "antithesis_sdk?/full"] tracing_release = ["tracing/release_max_level_info"] conn_raw_api = [] fs = ["turso_ext/vfs"] @@ -52,14 +52,12 @@ cfg_block = "0.1.1" fallible-iterator = { workspace = true } hex = { workspace = true } thiserror = { workspace = true } -getrandom = { version = "0.2.15" } regex = { workspace = true } regex-syntax = { workspace = true, default-features = false, features = [ "unicode", ] } chrono = { workspace = true, default-features = false, features = ["clock"] } -julian_day_converter = "0.4.5" -rand = "0.8.5" +rand = { workspace = true } libm = "0.2" turso_macros = { workspace = true } miette = { workspace = true } @@ -83,6 +81,9 @@ turso_parser = { workspace = true } aegis = "0.9.0" twox-hash = "2.1.1" intrusive-collections = "0.9.7" +roaring = "0.11.2" +simsimd = "6.5.3" +arc-swap = "1.7" [build-dependencies] chrono = { workspace = true, default-features = false } @@ -99,10 +100,9 @@ criterion = { workspace = true, features = [ "async_futures", ] } rstest = "0.18.2" -rusqlite.workspace = true +rusqlite = { workspace = true, features = ["series"] } quickcheck = { version = "1.0", default-features = false } quickcheck_macros = { version = "1.0", default-features = false } -rand = "0.8.5" # Required for quickcheck rand_chacha = { workspace = true } env_logger = { workspace = true } test-log = { version = "0.2.17", features = ["trace"] } diff --git a/core/benches/benchmark.rs b/core/benches/benchmark.rs index 51b844aad..0ef1c6875 100644 --- a/core/benches/benchmark.rs +++ b/core/benches/benchmark.rs @@ -2,6 +2,7 @@ use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criteri use pprof::criterion::{Output, PProfProfiler}; use regex::Regex; use std::{sync::Arc, time::Instant}; +use tempfile::TempDir; use turso_core::{Database, LimboError, PlatformIO, StepResult}; #[cfg(not(target_family = "wasm"))] @@ -16,6 +17,36 @@ fn rusqlite_open() -> rusqlite::Connection { sqlite_conn } +fn setup_rusqlite(temp_dir: &TempDir, query: &str) -> rusqlite::Connection { + let db_path = temp_dir.path().join("bench.db"); + let sqlite_conn = rusqlite::Connection::open(db_path).unwrap(); + sqlite_conn + .pragma_update(None, "synchronous", "FULL") + .unwrap(); + sqlite_conn + .pragma_update(None, "journal_mode", "WAL") + .unwrap(); + sqlite_conn + .pragma_update(None, "locking_mode", "EXCLUSIVE") + .unwrap(); + let journal_mode = sqlite_conn + .pragma_query_value(None, "journal_mode", |row| row.get::<_, String>(0)) + .unwrap(); + assert_eq!(journal_mode.to_lowercase(), "wal"); + let synchronous = sqlite_conn + .pragma_query_value(None, "synchronous", |row| row.get::<_, usize>(0)) + .unwrap(); + const FULL: usize = 2; + assert_eq!(synchronous, FULL); + + // load the generate_series extension + rusqlite::vtab::series::load_module(&sqlite_conn).unwrap(); + + // Create test table + sqlite_conn.execute(query, []).unwrap(); + sqlite_conn +} + fn bench_open(criterion: &mut Criterion) { // https://github.com/tursodatabase/turso/issues/174 // The rusqlite benchmark crashes on Mac M1 when using the flamegraph features @@ -896,9 +927,90 @@ fn bench_concurrent_writes(criterion: &mut Criterion) { }); } +fn bench_insert_randomblob(criterion: &mut Criterion) { + // The rusqlite benchmark crashes on Mac M1 when using the flamegraph features + let enable_rusqlite = std::env::var("DISABLE_RUSQLITE_BENCHMARK").is_err(); + + let mut group = criterion.benchmark_group("Insert rows in batches"); + + // Test different batch sizes + for batch_size in [1, 10, 100] { + let temp_dir = tempfile::tempdir().unwrap(); + let db_path = temp_dir.path().join("bench.db"); + + #[allow(clippy::arc_with_non_send_sync)] + let io = Arc::new(PlatformIO::new().unwrap()); + let db = Database::open_file(io.clone(), db_path.to_str().unwrap(), false, false).unwrap(); + let limbo_conn = db.connect().unwrap(); + + let mut stmt = limbo_conn.query("CREATE TABLE test(x)").unwrap().unwrap(); + + loop { + match stmt.step().unwrap() { + turso_core::StepResult::IO => { + stmt.run_once().unwrap(); + } + turso_core::StepResult::Done => { + break; + } + turso_core::StepResult::Row => { + unreachable!(); + } + turso_core::StepResult::Interrupt | turso_core::StepResult::Busy => { + unreachable!(); + } + } + } + + let random_blob = format!( + "INSERT INTO test select randomblob(1024 * 100) from generate_series(1, {batch_size});" + ); + + group.bench_function(format!("limbo_insert_{batch_size}_randomblob"), |b| { + let mut stmt = limbo_conn.prepare(&random_blob).unwrap(); + b.iter(|| { + loop { + match stmt.step().unwrap() { + turso_core::StepResult::IO => { + stmt.run_once().unwrap(); + } + turso_core::StepResult::Done => { + break; + } + turso_core::StepResult::Row => { + unreachable!(); + } + turso_core::StepResult::Interrupt | turso_core::StepResult::Busy => { + unreachable!(); + } + } + } + stmt.reset(); + }); + }); + + if enable_rusqlite { + let temp_dir = tempfile::tempdir().unwrap(); + let sqlite_conn = setup_rusqlite(&temp_dir, "CREATE TABLE test(x)"); + + group.bench_function(format!("sqlite_insert_{batch_size}_randomblob"), |b| { + let mut stmt = sqlite_conn.prepare(&random_blob).unwrap(); + b.iter(|| { + let mut rows = stmt.raw_query(); + while let Some(row) = rows.next().unwrap() { + black_box(row); + } + }); + }); + } + } + + group.finish(); +} + criterion_group! { name = benches; config = Criterion::default().with_profiler(PProfProfiler::new(100, Output::Flamegraph(None))); - targets = bench_open, bench_alter, bench_prepare_query, bench_execute_select_1, bench_execute_select_rows, bench_execute_select_count, bench_insert_rows, bench_concurrent_writes + targets = bench_open, bench_alter, bench_prepare_query, bench_execute_select_1, bench_execute_select_rows, bench_execute_select_count, bench_insert_rows, bench_concurrent_writes, bench_insert_randomblob } criterion_main!(benches); diff --git a/core/benches/mvcc_benchmark.rs b/core/benches/mvcc_benchmark.rs index 0ebd33fa5..7d316707d 100644 --- a/core/benches/mvcc_benchmark.rs +++ b/core/benches/mvcc_benchmark.rs @@ -36,8 +36,7 @@ fn bench(c: &mut Criterion) { let conn = db.conn.clone(); let tx_id = db.mvcc_store.begin_tx(conn.get_pager().clone()).unwrap(); db.mvcc_store - .rollback_tx(tx_id, conn.get_pager().clone(), &conn) - .unwrap(); + .rollback_tx(tx_id, conn.get_pager().clone(), &conn); }) }); diff --git a/core/build.rs b/core/build.rs index 50afee6bf..270fae925 100644 --- a/core/build.rs +++ b/core/build.rs @@ -1,7 +1,13 @@ -use std::fs; use std::path::PathBuf; +use std::{env, fs}; fn main() { + let profile = env::var("PROFILE").unwrap_or_else(|_| "debug".to_string()); + + if profile == "debug" { + println!("cargo::rerun-if-changed=build.rs"); + } + let out_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap()); let built_file = out_dir.join("built.rs"); diff --git a/core/error.rs b/core/error.rs index 3dd4841ad..c5bb811db 100644 --- a/core/error.rs +++ b/core/error.rs @@ -49,8 +49,6 @@ pub enum LimboError { ExtensionError(String), #[error("Runtime error: integer overflow")] IntegerOverflow, - #[error("Schema is locked for write")] - SchemaLocked, #[error("Runtime error: database table is locked")] TableLocked, #[error("Error: Resource is read-only")] @@ -165,6 +163,8 @@ impl From for LimboError { pub const SQLITE_CONSTRAINT: usize = 19; pub const SQLITE_CONSTRAINT_PRIMARYKEY: usize = SQLITE_CONSTRAINT | (6 << 8); +#[allow(dead_code)] +pub const SQLITE_CONSTRAINT_FOREIGNKEY: usize = SQLITE_CONSTRAINT | (7 << 8); pub const SQLITE_CONSTRAINT_NOTNULL: usize = SQLITE_CONSTRAINT | (5 << 8); pub const SQLITE_FULL: usize = 13; // we want this in autoincrement - incase if user inserts max allowed int pub const SQLITE_CONSTRAINT_UNIQUE: usize = 2067; diff --git a/core/ext/mod.rs b/core/ext/mod.rs index 1d73c3ba2..d58c86909 100644 --- a/core/ext/mod.rs +++ b/core/ext/mod.rs @@ -2,7 +2,7 @@ mod dynamic; mod vtab_xconnect; use crate::schema::{Schema, Table}; -#[cfg(all(target_os = "linux", feature = "io_uring"))] +#[cfg(all(target_os = "linux", feature = "io_uring", not(miri)))] use crate::UringIO; use crate::{function::ExternalFunc, Connection, Database}; use crate::{vtab::VirtualTable, SymbolTable}; @@ -146,7 +146,7 @@ impl Database { let io: Arc = match vfs { "memory" => Arc::new(MemoryIO::new()), "syscall" => Arc::new(SyscallIO::new()?), - #[cfg(all(target_os = "linux", feature = "io_uring"))] + #[cfg(all(target_os = "linux", feature = "io_uring", not(miri)))] "io_uring" => Arc::new(UringIO::new()?), other => match get_vfs_modules().iter().find(|v| v.0 == vfs) { Some((_, vfs)) => vfs.clone(), diff --git a/core/function.rs b/core/function.rs index b2858af4a..0a3236856 100644 --- a/core/function.rs +++ b/core/function.rs @@ -153,10 +153,12 @@ impl Display for JsonFunc { pub enum VectorFunc { Vector, Vector32, + Vector32Sparse, Vector64, VectorExtract, VectorDistanceCos, - VectorDistanceEuclidean, + VectorDistanceL2, + VectorDistanceJaccard, VectorConcat, VectorSlice, } @@ -172,11 +174,12 @@ impl Display for VectorFunc { let str = match self { Self::Vector => "vector".to_string(), Self::Vector32 => "vector32".to_string(), + Self::Vector32Sparse => "vector32_sparse".to_string(), Self::Vector64 => "vector64".to_string(), Self::VectorExtract => "vector_extract".to_string(), Self::VectorDistanceCos => "vector_distance_cos".to_string(), - // We use `distance_l2` to reduce user input - Self::VectorDistanceEuclidean => "vector_distance_l2".to_string(), + Self::VectorDistanceL2 => "vector_distance_l2".to_string(), + Self::VectorDistanceJaccard => "vector_distance_jaccard".to_string(), Self::VectorConcat => "vector_concat".to_string(), Self::VectorSlice => "vector_slice".to_string(), }; @@ -310,6 +313,7 @@ pub enum ScalarFunc { Unicode, Quote, SqliteVersion, + TursoVersion, SqliteSourceId, UnixEpoch, JulianDay, @@ -373,6 +377,7 @@ impl ScalarFunc { ScalarFunc::Unicode => true, ScalarFunc::Quote => true, ScalarFunc::SqliteVersion => true, + ScalarFunc::TursoVersion => true, ScalarFunc::SqliteSourceId => true, ScalarFunc::UnixEpoch => false, ScalarFunc::JulianDay => false, @@ -437,6 +442,7 @@ impl Display for ScalarFunc { Self::Unicode => "unicode".to_string(), Self::Quote => "quote".to_string(), Self::SqliteVersion => "sqlite_version".to_string(), + Self::TursoVersion => "turso_version".to_string(), Self::SqliteSourceId => "sqlite_source_id".to_string(), Self::JulianDay => "julianday".to_string(), Self::UnixEpoch => "unixepoch".to_string(), @@ -642,6 +648,29 @@ impl Func { Self::AlterTable(_) => true, } } + + pub fn supports_star_syntax(&self) -> bool { + match self { + Self::Scalar(scalar_func) => { + matches!( + scalar_func, + ScalarFunc::Changes + | ScalarFunc::Random + | ScalarFunc::TotalChanges + | ScalarFunc::SqliteVersion + | ScalarFunc::TursoVersion + | ScalarFunc::SqliteSourceId + | ScalarFunc::LastInsertRowid + ) + } + Self::Math(math_func) => { + matches!(math_func.arity(), MathFuncArity::Nullary) + } + // Aggregate functions with (*) syntax are handled separately in the planner + Self::Agg(_) => false, + _ => false, + } + } pub fn resolve_function(name: &str, arg_count: usize) -> Result { let normalized_name = crate::util::normalize_ident(name); match normalized_name.as_str() { @@ -723,7 +752,7 @@ impl Func { "total_changes" => Ok(Self::Scalar(ScalarFunc::TotalChanges)), "glob" => Ok(Self::Scalar(ScalarFunc::Glob)), "ifnull" => Ok(Self::Scalar(ScalarFunc::IfNull)), - "iif" => Ok(Self::Scalar(ScalarFunc::Iif)), + "if" | "iif" => Ok(Self::Scalar(ScalarFunc::Iif)), "instr" => Ok(Self::Scalar(ScalarFunc::Instr)), "like" => Ok(Self::Scalar(ScalarFunc::Like)), "abs" => Ok(Self::Scalar(ScalarFunc::Abs)), @@ -748,6 +777,7 @@ impl Func { "unicode" => Ok(Self::Scalar(ScalarFunc::Unicode)), "quote" => Ok(Self::Scalar(ScalarFunc::Quote)), "sqlite_version" => Ok(Self::Scalar(ScalarFunc::SqliteVersion)), + "turso_version" => Ok(Self::Scalar(ScalarFunc::TursoVersion)), "sqlite_source_id" => Ok(Self::Scalar(ScalarFunc::SqliteSourceId)), "replace" => Ok(Self::Scalar(ScalarFunc::Replace)), "likely" => Ok(Self::Scalar(ScalarFunc::Likely)), @@ -843,10 +873,12 @@ impl Func { "printf" => Ok(Self::Scalar(ScalarFunc::Printf)), "vector" => Ok(Self::Vector(VectorFunc::Vector)), "vector32" => Ok(Self::Vector(VectorFunc::Vector32)), + "vector32_sparse" => Ok(Self::Vector(VectorFunc::Vector32Sparse)), "vector64" => Ok(Self::Vector(VectorFunc::Vector64)), "vector_extract" => Ok(Self::Vector(VectorFunc::VectorExtract)), "vector_distance_cos" => Ok(Self::Vector(VectorFunc::VectorDistanceCos)), - "vector_distance_l2" => Ok(Self::Vector(VectorFunc::VectorDistanceEuclidean)), + "vector_distance_l2" => Ok(Self::Vector(VectorFunc::VectorDistanceL2)), + "vector_distance_jaccard" => Ok(Self::Vector(VectorFunc::VectorDistanceJaccard)), "vector_concat" => Ok(Self::Vector(VectorFunc::VectorConcat)), "vector_slice" => Ok(Self::Vector(VectorFunc::VectorSlice)), _ => crate::bail_parse_error!("no such function: {}", name), diff --git a/core/functions/datetime.rs b/core/functions/datetime.rs index b5969855f..3ccbef5a2 100644 --- a/core/functions/datetime.rs +++ b/core/functions/datetime.rs @@ -348,30 +348,35 @@ pub fn exec_julianday(values: &[Register]) -> Value { } fn to_julian_day_exact(dt: &NaiveDateTime) -> f64 { - let year = dt.year(); - let month = dt.month() as i32; - let day = dt.day() as i32; - let (adjusted_year, adjusted_month) = if month <= 2 { - (year - 1, month + 12) - } else { - (year, month) - }; + // SQLite's computeJD algorithm + let mut y = dt.year(); + let mut m = dt.month() as i32; + let d = dt.day() as i32; - let a = adjusted_year / 100; - let b = 2 - a + a / 4; - let jd_days = (365.25 * ((adjusted_year + 4716) as f64)).floor() - + (30.6001 * ((adjusted_month + 1) as f64)).floor() - + (day as f64) - + (b as f64) - - 1524.5; + if m <= 2 { + y -= 1; + m += 12; + } - let seconds = dt.hour() as f64 * 3600.0 - + dt.minute() as f64 * 60.0 - + dt.second() as f64 - + (dt.nanosecond() as f64) / 1_000_000_000.0; + let a = (y + 4800) / 100; + let b = 38 - a + (a / 4); + let x1 = 36525 * (y + 4716) / 100; + let x2 = 306001 * (m + 1) / 10000; - let jd_fraction = seconds / 86400.0; - jd_days + jd_fraction + // iJD = (sqlite3_int64)((X1 + X2 + D + B - 1524.5) * 86400000) + let jd_days = (x1 + x2 + d + b) as f64 - 1524.5; + let mut i_jd = (jd_days * 86400000.0) as i64; + + // Add time component in milliseconds + // iJD += h*3600000 + m*60000 + (sqlite3_int64)(s*1000 + 0.5) + let h_ms = dt.hour() as i64 * 3600000; + let m_ms = dt.minute() as i64 * 60000; + let s_ms = (dt.second() as f64 * 1000.0 + dt.nanosecond() as f64 / 1_000_000.0 + 0.5) as i64; + + i_jd += h_ms + m_ms + s_ms; + + // Convert back to floating point JD + i_jd as f64 / 86400000.0 } pub fn exec_unixepoch(time_value: &Value) -> Result { @@ -490,7 +495,58 @@ fn get_date_time_from_time_value_float(value: f64) -> Option { if value.is_infinite() || value.is_nan() || !is_julian_day_value(value) { return None; } - julian_day_converter::julian_day_to_datetime(value).ok() + julian_day_to_datetime(value).ok() +} + +/// Convert a Julian Day number (as f64) to a NaiveDateTime +/// This uses SQLite's algorithm which converts to integer milliseconds first +/// to preserve precision, then converts back to date/time components. +fn julian_day_to_datetime(jd: f64) -> Result { + // SQLite approach: Convert JD to integer milliseconds + // iJD = (sqlite3_int64)(jd * 86400000.0 + 0.5) + let i_jd = (jd * 86400000.0 + 0.5) as i64; + + // Compute the date (Year, Month, Day) from iJD + // Z = (int)((iJD + 43200000)/86400000) + let z = ((i_jd + 43200000) / 86400000) as i32; + + // SQLite's algorithm from computeYMD + let alpha = ((z as f64 + 32044.75) / 36524.25) as i32 - 52; + let a = z + 1 + alpha - ((alpha + 100) / 4) + 25; + let b = a + 1524; + let c = ((b as f64 - 122.1) / 365.25) as i32; + let d = (36525 * (c & 32767)) / 100; + let e = ((b - d) as f64 / 30.6001) as i32; + let x1 = (30.6001 * e as f64) as i32; + + let day = (b - d - x1) as u32; + let month = if e < 14 { e - 1 } else { e - 13 } as u32; + let year = if month > 2 { c - 4716 } else { c - 4715 }; + + // Compute the time (Hour, Minute, Second) from iJD + // day_ms = (int)((iJD + 43200000) % 86400000) + let day_ms = ((i_jd + 43200000) % 86400000) as i32; + + // s = (day_ms % 60000) / 1000.0 + let s_millis = day_ms % 60000; + let seconds = (s_millis / 1000) as u32; + let millis = (s_millis % 1000) as u32; + + // day_min = day_ms / 60000 + let day_min = day_ms / 60000; + let minutes = (day_min % 60) as u32; + let hours = (day_min / 60) as u32; + + // Create the date + let date = NaiveDate::from_ymd_opt(year, month, day) + .ok_or_else(|| crate::LimboError::InternalError("Invalid date".to_string()))?; + + // Create time with millisecond precision converted to nanoseconds + let nanos = millis * 1_000_000; + let time = NaiveTime::from_hms_nano_opt(hours, minutes, seconds, nanos) + .ok_or_else(|| crate::LimboError::InternalError("Invalid time".to_string()))?; + + Ok(NaiveDateTime::new(date, time)) } fn is_leap_second(dt: &NaiveDateTime) -> bool { @@ -1584,17 +1640,15 @@ mod tests { assert_eq!(weekday_sunday_based(&dt), 5); } - #[allow(deprecated)] #[test] fn test_apply_modifier_julianday() { - use julian_day_converter::*; - let dt = create_datetime(2000, 1, 1, 12, 0, 0); - let julian_day = &dt.to_jd(); - let mut dt_result = NaiveDateTime::default(); - if let Some(ndt) = JulianDay::from_jd(*julian_day) { - dt_result = ndt; - } + + // Convert datetime to julian day using our implementation + let julian_day_value = to_julian_day_exact(&dt); + + // Convert back + let dt_result = julian_day_to_datetime(julian_day_value).unwrap(); assert_eq!(dt_result, dt); } diff --git a/core/incremental/aggregate_operator.rs b/core/incremental/aggregate_operator.rs index e24b930a9..a6d0fb5b8 100644 --- a/core/incremental/aggregate_operator.rs +++ b/core/incremental/aggregate_operator.rs @@ -7,15 +7,78 @@ use crate::incremental::operator::{ generate_storage_id, ComputationTracker, DbspStateCursors, EvalState, IncrementalOperator, }; use crate::incremental::persistence::{ReadRecord, WriteRow}; -use crate::types::{IOResult, ImmutableRecord, RefValue, SeekKey, SeekOp, SeekResult}; +use crate::storage::btree::CursorTrait; +use crate::types::{IOResult, ImmutableRecord, SeekKey, SeekOp, SeekResult, ValueRef}; use crate::{return_and_restore_if_io, return_if_io, LimboError, Result, Value}; use std::collections::{BTreeMap, HashMap, HashSet}; use std::fmt::{self, Display}; use std::sync::{Arc, Mutex}; +// Architecture of the Aggregate Operator +// ======================================== +// +// This operator implements SQL aggregations (GROUP BY, DISTINCT, COUNT, SUM, AVG, MIN, MAX) +// using DBSP-style incremental computation. The key insight is that all these operations +// can be expressed as operations on weighted sets (Z-sets) stored in persistent BTrees. +// +// ## Storage Strategy +// +// We use three different storage encodings (identified by 2-bit type codes in storage IDs): +// - **Regular aggregates** (COUNT/SUM/AVG): Store accumulated state as a blob +// - **MIN/MAX aggregates**: Store individual values; BTree ordering gives us min/max efficiently +// - **DISTINCT tracking**: Store distinct values with weights (positive = present, zero = deleted) +// +// ## MIN/MAX Handling +// +// MIN/MAX are special because they're not fully incrementalizable: +// - **Inserts**: Can be computed incrementally (new_min = min(old_min, new_value)) +// - **Deletes**: Must recompute from the BTree when the current min/max is deleted +// +// Our approach: +// 1. Store each value with its weight in a BTree (leveraging natural ordering) +// 2. On insert: Simply compare with current min/max (incremental) +// 3. On delete of current min/max: Scan the BTree to find the next min/max +// - For MIN: scan forward from the beginning to find first value with positive weight +// - For MAX: scan backward from the end to find last value with positive weight +// +// ## DISTINCT Handling +// +// DISTINCT operations (COUNT(DISTINCT), SUM(DISTINCT), etc.) are implemented using the +// weighted set pattern: +// - Each distinct value is stored with a weight (occurrence count) +// - Weight > 0 means the value exists in the current dataset +// - Weight = 0 means the value has been deleted (we may clean these up) +// - We track transitions: when a value's weight crosses zero (appears/disappears) +// +// ## Plain DISTINCT (SELECT DISTINCT) +// +// A clever reuse of infrastructure: SELECT DISTINCT x, y, z is compiled to: +// - GROUP BY x, y, z (making each unique row combination a group) +// - Empty aggregates vector (no actual aggregations to compute) +// - The groups themselves become the distinct rows +// +// This allows us to reuse all the incremental machinery for DISTINCT without special casing. +// The `is_distinct_only` flag indicates this pattern, where the groups ARE the output rows. +// +// ## State Machines +// +// The operator uses async-ready state machines to handle I/O operations: +// - **Eval state machine**: Fetches existing state, applies deltas, recomputes MIN/MAX +// - **Commit state machine**: Persists updated state back to storage +// - Each state represents a resumption point for when I/O operations yield + /// Constants for aggregate type encoding in storage IDs (2 bits) pub const AGG_TYPE_REGULAR: u8 = 0b00; // COUNT/SUM/AVG pub const AGG_TYPE_MINMAX: u8 = 0b01; // MIN/MAX (BTree ordering gives both) +pub const AGG_TYPE_DISTINCT: u8 = 0b10; // DISTINCT values tracking + +/// Hash a Value to generate an element_id for DISTINCT storage +/// Uses HashableRow with column_idx as rowid for consistent hashing +fn hash_value(value: &Value, column_idx: usize) -> Hash128 { + // Use column_idx as rowid to ensure different columns with same value get different hashes + let row = HashableRow::new(column_idx as i64, vec![value.clone()]); + row.cached_hash() +} // Serialization type codes for aggregate functions const AGG_FUNC_COUNT: i64 = 0; @@ -23,22 +86,31 @@ const AGG_FUNC_SUM: i64 = 1; const AGG_FUNC_AVG: i64 = 2; const AGG_FUNC_MIN: i64 = 3; const AGG_FUNC_MAX: i64 = 4; +const AGG_FUNC_COUNT_DISTINCT: i64 = 5; +const AGG_FUNC_SUM_DISTINCT: i64 = 6; +const AGG_FUNC_AVG_DISTINCT: i64 = 7; #[derive(Debug, Clone, PartialEq)] pub enum AggregateFunction { Count, - Sum(usize), // Column index - Avg(usize), // Column index - Min(usize), // Column index - Max(usize), // Column index + CountDistinct(usize), // COUNT(DISTINCT column_index) + Sum(usize), // Column index + SumDistinct(usize), // SUM(DISTINCT column_index) + Avg(usize), // Column index + AvgDistinct(usize), // AVG(DISTINCT column_index) + Min(usize), // Column index + Max(usize), // Column index } impl Display for AggregateFunction { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { AggregateFunction::Count => write!(f, "COUNT(*)"), + AggregateFunction::CountDistinct(idx) => write!(f, "COUNT(DISTINCT col{idx})"), AggregateFunction::Sum(idx) => write!(f, "SUM(col{idx})"), + AggregateFunction::SumDistinct(idx) => write!(f, "SUM(DISTINCT col{idx})"), AggregateFunction::Avg(idx) => write!(f, "AVG(col{idx})"), + AggregateFunction::AvgDistinct(idx) => write!(f, "AVG(DISTINCT col{idx})"), AggregateFunction::Min(idx) => write!(f, "MIN(col{idx})"), AggregateFunction::Max(idx) => write!(f, "MAX(col{idx})"), } @@ -57,12 +129,30 @@ impl AggregateFunction { pub fn to_values(&self) -> Vec { match self { AggregateFunction::Count => vec![Value::Integer(AGG_FUNC_COUNT)], + AggregateFunction::CountDistinct(idx) => { + vec![ + Value::Integer(AGG_FUNC_COUNT_DISTINCT), + Value::Integer(*idx as i64), + ] + } AggregateFunction::Sum(idx) => { vec![Value::Integer(AGG_FUNC_SUM), Value::Integer(*idx as i64)] } + AggregateFunction::SumDistinct(idx) => { + vec![ + Value::Integer(AGG_FUNC_SUM_DISTINCT), + Value::Integer(*idx as i64), + ] + } AggregateFunction::Avg(idx) => { vec![Value::Integer(AGG_FUNC_AVG), Value::Integer(*idx as i64)] } + AggregateFunction::AvgDistinct(idx) => { + vec![ + Value::Integer(AGG_FUNC_AVG_DISTINCT), + Value::Integer(*idx as i64), + ] + } AggregateFunction::Min(idx) => { vec![Value::Integer(AGG_FUNC_MIN), Value::Integer(*idx as i64)] } @@ -84,6 +174,20 @@ impl AggregateFunction { *cursor += 1; AggregateFunction::Count } + Value::Integer(AGG_FUNC_COUNT_DISTINCT) => { + *cursor += 1; + let idx = values.get(*cursor).ok_or_else(|| { + LimboError::InternalError("Missing COUNT(DISTINCT) column index".into()) + })?; + if let Value::Integer(idx) = idx { + *cursor += 1; + AggregateFunction::CountDistinct(*idx as usize) + } else { + return Err(LimboError::InternalError(format!( + "Expected Integer for COUNT(DISTINCT) column index, got {idx:?}" + ))); + } + } Value::Integer(AGG_FUNC_SUM) => { *cursor += 1; let idx = values @@ -98,6 +202,20 @@ impl AggregateFunction { ))); } } + Value::Integer(AGG_FUNC_SUM_DISTINCT) => { + *cursor += 1; + let idx = values.get(*cursor).ok_or_else(|| { + LimboError::InternalError("Missing SUM(DISTINCT) column index".into()) + })?; + if let Value::Integer(idx) = idx { + *cursor += 1; + AggregateFunction::SumDistinct(*idx as usize) + } else { + return Err(LimboError::InternalError(format!( + "Expected Integer for SUM(DISTINCT) column index, got {idx:?}" + ))); + } + } Value::Integer(AGG_FUNC_AVG) => { *cursor += 1; let idx = values @@ -112,6 +230,20 @@ impl AggregateFunction { ))); } } + Value::Integer(AGG_FUNC_AVG_DISTINCT) => { + *cursor += 1; + let idx = values.get(*cursor).ok_or_else(|| { + LimboError::InternalError("Missing AVG(DISTINCT) column index".into()) + })?; + if let Value::Integer(idx) = idx { + *cursor += 1; + AggregateFunction::AvgDistinct(*idx as usize) + } else { + return Err(LimboError::InternalError(format!( + "Expected Integer for AVG(DISTINCT) column index, got {idx:?}" + ))); + } + } Value::Integer(AGG_FUNC_MIN) => { *cursor += 1; let idx = values @@ -188,6 +320,27 @@ type ComputedStates = HashMap, AggregateState)>; // group_key_str -> (column_index, value_as_hashable_row) -> accumulated_weight pub type MinMaxDeltas = HashMap>; +/// Type for tracking distinct values within a batch +/// Maps: group_key_str -> (column_idx, HashableRow) -> accumulated_weight +/// HashableRow contains the value with column_idx as rowid for proper hashing +type DistinctDeltas = HashMap>; + +/// Return type for merge_delta_with_existing function +type MergeResult = (Delta, HashMap, AggregateState)>); + +/// Information about distinct value transitions for a single column +#[derive(Debug, Clone)] +pub struct DistinctTransition { + pub transition_type: TransitionType, + pub transitioned_value: Value, // The value that was added/removed +} + +#[derive(Debug, Clone, PartialEq)] +pub enum TransitionType { + Added, // Value added to distinct set + Removed, // Value removed from distinct set +} + #[derive(Debug)] enum AggregateCommitState { Idle, @@ -197,13 +350,21 @@ enum AggregateCommitState { PersistDelta { delta: Delta, computed_states: ComputedStates, + old_states: HashMap, // Track old counts for plain DISTINCT current_idx: usize, write_row: WriteRow, min_max_deltas: MinMaxDeltas, + distinct_deltas: DistinctDeltas, + input_delta: Delta, // Keep original input delta for distinct processing }, PersistMinMax { delta: Delta, min_max_persist_state: MinMaxPersistState, + distinct_deltas: DistinctDeltas, + }, + PersistDistinctValues { + delta: Delta, + distinct_persist_state: DistinctPersistState, }, Done { delta: Delta, @@ -220,8 +381,9 @@ pub enum AggregateEvalState { groups_to_read: Vec<(String, Vec)>, // Changed to Vec for index-based access existing_groups: HashMap, old_values: HashMap>, + pre_existing_groups: HashSet, // Track groups that existed before this delta }, - FetchData { + FetchAggregateState { delta: Delta, // Keep original delta for merge operation current_idx: usize, groups_to_read: Vec<(String, Vec)>, // Changed to Vec for index-based access @@ -229,12 +391,23 @@ pub enum AggregateEvalState { old_values: HashMap>, rowid: Option, // Rowid found by FetchKey (None if not found) read_record_state: Box, + pre_existing_groups: HashSet, // Track groups that existed before this delta + }, + FetchDistinctValues { + delta: Delta, // Keep original delta for merge operation + current_idx: usize, + groups_to_read: Vec<(String, Vec)>, // Changed to Vec for index-based access + existing_groups: HashMap, + old_values: HashMap>, + fetch_distinct_state: Box, + pre_existing_groups: HashSet, // Track groups that existed before this delta }, RecomputeMinMax { delta: Delta, existing_groups: HashMap, old_values: HashMap>, recompute_state: Box, + pre_existing_groups: HashSet, // Track groups that existed before this delta }, Done { output: (Delta, ComputedStates), @@ -256,10 +429,15 @@ pub struct AggregateOperator { pub input_column_names: Vec, // Map from column index to aggregate info for quick lookup pub column_min_max: HashMap, + // Set of column indices that have distinct aggregates + pub distinct_columns: HashSet, tracker: Option>>, // State machine for commit operation commit_state: AggregateCommitState, + + // SELECT DISTINCT x,y,z.... with no aggregations. + is_distinct_only: bool, } /// State for a single group's aggregates @@ -275,9 +453,34 @@ pub struct AggregateState { pub mins: HashMap, // For MAX: column_index -> maximum value pub maxs: HashMap, + // For DISTINCT aggregates: column_index -> computed value + // These are populated during eval when we scan the BTree (or in-memory map) + pub distinct_counts: HashMap, + pub distinct_sums: HashMap, + + // Weights of specific distinct values needed for current delta processing + // (column_index, value) -> weight + // Populated during FetchKey for values mentioned in the delta + pub(crate) distinct_value_weights: HashMap<(usize, HashableRow), i64>, } impl AggregateEvalState { + /// Process a delta through the aggregate state machine. + /// + /// Control flow is strictly linear for maintainability: + /// 1. FetchKey → FetchAggregateState (always) + /// 2. FetchAggregateState → FetchKey (always, loops until all groups processed) + /// 3. FetchKey (when done) → FetchDistinctValues (always) + /// 4. FetchDistinctValues → RecomputeMinMax (always) + /// 5. RecomputeMinMax → Done (always) + /// + /// Some states may be no-ops depending on the operator configuration: + /// - FetchAggregateState: For plain DISTINCT, skips reading aggregate blob (no aggregates to fetch) + /// - FetchDistinctValues: No-op if no distinct columns exist (distinct_columns is empty) + /// - RecomputeMinMax: No-op if no MIN/MAX aggregates exist (has_min_max() returns false) + /// + /// This deterministic flow ensures each state always transitions to the same next state, + /// making the state machine easier to understand and debug. fn process_delta( &mut self, operator: &mut AggregateOperator, @@ -291,40 +494,47 @@ impl AggregateEvalState { groups_to_read, existing_groups, old_values, + pre_existing_groups, } => { if *current_idx >= groups_to_read.len() { - // All groups have been fetched, move to RecomputeMinMax - // Extract MIN/MAX deltas from the input delta - let min_max_deltas = operator.extract_min_max_deltas(delta); - - let recompute_state = Box::new(RecomputeMinMax::new( - min_max_deltas, + // All groups have been fetched, move to FetchDistinctValues + // Create FetchDistinctState based on the delta and existing groups + let fetch_distinct_state = FetchDistinctState::new( + delta, + &operator.distinct_columns, + |values| operator.extract_group_key(values), + AggregateOperator::group_key_to_string, existing_groups, - operator, - )); + operator.is_distinct_only, + ); - *self = AggregateEvalState::RecomputeMinMax { + *self = AggregateEvalState::FetchDistinctValues { delta: std::mem::take(delta), + current_idx: 0, + groups_to_read: std::mem::take(groups_to_read), existing_groups: std::mem::take(existing_groups), old_values: std::mem::take(old_values), - recompute_state, + fetch_distinct_state: Box::new(fetch_distinct_state), + pre_existing_groups: std::mem::take(pre_existing_groups), }; } else { // Get the current group to read let (group_key_str, _group_key) = &groups_to_read[*current_idx]; - // Build the key for the index: (operator_id, zset_hash, element_id) - // For regular aggregates, use column_index=0 and AGG_TYPE_REGULAR + // For plain DISTINCT, we still need to transition to FetchAggregateState + // to add the group to existing_groups, but we won't read any aggregate blob + + // Build the key for regular aggregate state: (operator_id, zset_hash, element_id=0) let operator_storage_id = generate_storage_id(operator.operator_id, 0, AGG_TYPE_REGULAR); let zset_hash = operator.generate_group_hash(group_key_str); - let element_id = 0i64; // Always 0 for aggregators + let element_id = Hash128::new(0, 0); // Always zeros for aggregate state // Create index key values let index_key_values = vec![ Value::Integer(operator_storage_id), zset_hash.to_value(), - Value::Integer(element_id), + element_id.to_value(), ]; // Create an immutable record for the index key @@ -346,10 +556,10 @@ impl AggregateEvalState { None }; - // Always transition to FetchData + // Always transition to FetchAggregateState let taken_existing = std::mem::take(existing_groups); let taken_old_values = std::mem::take(old_values); - let next_state = AggregateEvalState::FetchData { + let next_state = AggregateEvalState::FetchAggregateState { delta: std::mem::take(delta), current_idx: *current_idx, groups_to_read: std::mem::take(groups_to_read), @@ -357,11 +567,12 @@ impl AggregateEvalState { old_values: taken_old_values, rowid, read_record_state: Box::new(ReadRecord::new()), + pre_existing_groups: std::mem::take(pre_existing_groups), // Pass through existing }; *self = next_state; } } - AggregateEvalState::FetchData { + AggregateEvalState::FetchAggregateState { delta, current_idx, groups_to_read, @@ -369,13 +580,20 @@ impl AggregateEvalState { old_values, rowid, read_record_state, + pre_existing_groups, } => { // Get the current group to read let (group_key_str, group_key) = &groups_to_read[*current_idx]; - // Only try to read if we have a rowid - if let Some(rowid) = rowid { + // For plain DISTINCT, skip aggregate state fetch entirely + // The distinct values are handled separately in FetchDistinctValues + if operator.is_distinct_only { + // Always insert the group key so FetchDistinctState will process it + // The count will be set properly when we fetch distinct values + existing_groups.insert(group_key_str.clone(), AggregateState::default()); + } else if let Some(rowid) = rowid { let key = SeekKey::TableRowId(*rowid); + // Regular aggregates - read the blob let state = return_if_io!( read_record_state.read_record(key, &mut cursors.table_cursor) ); @@ -384,23 +602,75 @@ impl AggregateEvalState { let mut old_row = group_key.clone(); old_row.extend(state.to_values(&operator.aggregates)); old_values.insert(group_key_str.clone(), old_row); - existing_groups.insert(group_key_str.clone(), state.clone()); + existing_groups.insert(group_key_str.clone(), state); + // Track that this group exists in storage + pre_existing_groups.insert(group_key_str.clone()); } - } else { - // No rowid for this group, skipping read } // If no rowid, there's no existing state for this group - // Move to next group + // Always move to next group via FetchKey let next_idx = *current_idx + 1; + let taken_existing = std::mem::take(existing_groups); let taken_old_values = std::mem::take(old_values); + let taken_pre_existing_groups = std::mem::take(pre_existing_groups); let next_state = AggregateEvalState::FetchKey { delta: std::mem::take(delta), current_idx: next_idx, groups_to_read: std::mem::take(groups_to_read), existing_groups: taken_existing, old_values: taken_old_values, + pre_existing_groups: taken_pre_existing_groups, + }; + *self = next_state; + } + AggregateEvalState::FetchDistinctValues { + delta, + current_idx: _, + groups_to_read: _, + existing_groups, + old_values, + fetch_distinct_state, + pre_existing_groups, + } => { + // Use FetchDistinctState to read distinct values from BTree storage + return_if_io!(fetch_distinct_state.fetch_distinct_values( + operator.operator_id, + existing_groups, + cursors, + |group_key| operator.generate_group_hash(group_key), + operator.is_distinct_only + )); + + // For plain DISTINCT, mark groups as "from storage" if they have distinct values + if operator.is_distinct_only { + for (group_key_str, state) in existing_groups.iter() { + // Check if this group has any distinct values with positive weight + let has_values = state.distinct_value_weights.values().any(|&w| w > 0); + if has_values { + pre_existing_groups.insert(group_key_str.clone()); + } + } + } + + // Extract MIN/MAX deltas for recomputation + let min_max_deltas = operator.extract_min_max_deltas(delta); + + // Create RecomputeMinMax before moving existing_groups + let recompute_state = Box::new(RecomputeMinMax::new( + min_max_deltas, + existing_groups, + operator, + )); + + // Transition to RecomputeMinMax + let next_state = AggregateEvalState::RecomputeMinMax { + delta: std::mem::take(delta), + existing_groups: std::mem::take(existing_groups), + old_values: std::mem::take(old_values), + recompute_state, + pre_existing_groups: std::mem::take(pre_existing_groups), }; *self = next_state; } @@ -409,6 +679,7 @@ impl AggregateEvalState { existing_groups, old_values, recompute_state, + pre_existing_groups, } => { if operator.has_min_max() { // Process MIN/MAX recomputation - this will update existing_groups with correct MIN/MAX @@ -416,15 +687,20 @@ impl AggregateEvalState { } // Now compute final output with updated MIN/MAX values - let (output_delta, computed_states) = - operator.merge_delta_with_existing(delta, existing_groups, old_values); + let (output_delta, computed_states) = operator.merge_delta_with_existing( + delta, + existing_groups, + old_values, + pre_existing_groups, + ); *self = AggregateEvalState::Done { output: (output_delta, computed_states), }; } AggregateEvalState::Done { output } => { - return Ok(IOResult::Done(output.clone())); + let (delta, computed_states) = output.clone(); + return Ok(IOResult::Done((delta, computed_states))); } } } @@ -458,15 +734,34 @@ impl AggregateState { AggregateFunction::Count => { // Count state is already stored at the beginning } + AggregateFunction::CountDistinct(col_idx) => { + // Store the distinct count for this column + let count = self.distinct_counts.get(col_idx).copied().unwrap_or(0); + values.push(Value::Integer(count)); + } AggregateFunction::Sum(col_idx) => { let sum = self.sums.get(col_idx).copied().unwrap_or(0.0); values.push(Value::Float(sum)); } + AggregateFunction::SumDistinct(col_idx) => { + // Store both the distinct count and sum for this column + let count = self.distinct_counts.get(col_idx).copied().unwrap_or(0); + let sum = self.distinct_sums.get(col_idx).copied().unwrap_or(0.0); + values.push(Value::Integer(count)); + values.push(Value::Float(sum)); + } AggregateFunction::Avg(col_idx) => { let (sum, count) = self.avgs.get(col_idx).copied().unwrap_or((0.0, 0)); values.push(Value::Float(sum)); values.push(Value::Integer(count)); } + AggregateFunction::AvgDistinct(col_idx) => { + // Store both the distinct count and sum for this column + let count = self.distinct_counts.get(col_idx).copied().unwrap_or(0); + let sum = self.distinct_sums.get(col_idx).copied().unwrap_or(0.0); + values.push(Value::Integer(count)); + values.push(Value::Float(sum)); + } AggregateFunction::Min(col_idx) => { if let Some(min_val) = self.mins.get(col_idx) { values.push(Value::Integer(1)); // Has value @@ -531,6 +826,69 @@ impl AggregateState { AggregateFunction::Count => { // Count state is already stored at the beginning } + AggregateFunction::CountDistinct(col_idx) => { + let count = values.get(cursor).ok_or_else(|| { + LimboError::InternalError("Missing COUNT(DISTINCT) value".into()) + })?; + if let Value::Integer(count) = count { + state.distinct_counts.insert(col_idx, *count); + cursor += 1; + } else { + return Err(LimboError::InternalError(format!( + "Expected Integer for COUNT(DISTINCT) value, got {count:?}" + ))); + } + } + AggregateFunction::SumDistinct(col_idx) => { + let count = values.get(cursor).ok_or_else(|| { + LimboError::InternalError("Missing SUM(DISTINCT) count".into()) + })?; + if let Value::Integer(count) = count { + state.distinct_counts.insert(col_idx, *count); + cursor += 1; + } else { + return Err(LimboError::InternalError(format!( + "Expected Integer for SUM(DISTINCT) count, got {count:?}" + ))); + } + + let sum = values.get(cursor).ok_or_else(|| { + LimboError::InternalError("Missing SUM(DISTINCT) sum".into()) + })?; + if let Value::Float(sum) = sum { + state.distinct_sums.insert(col_idx, *sum); + cursor += 1; + } else { + return Err(LimboError::InternalError(format!( + "Expected Float for SUM(DISTINCT) sum, got {sum:?}" + ))); + } + } + AggregateFunction::AvgDistinct(col_idx) => { + let count = values.get(cursor).ok_or_else(|| { + LimboError::InternalError("Missing AVG(DISTINCT) count".into()) + })?; + if let Value::Integer(count) = count { + state.distinct_counts.insert(col_idx, *count); + cursor += 1; + } else { + return Err(LimboError::InternalError(format!( + "Expected Integer for AVG(DISTINCT) count, got {count:?}" + ))); + } + + let sum = values.get(cursor).ok_or_else(|| { + LimboError::InternalError("Missing AVG(DISTINCT) sum".into()) + })?; + if let Value::Float(sum) = sum { + state.distinct_sums.insert(col_idx, *sum); + cursor += 1; + } else { + return Err(LimboError::InternalError(format!( + "Expected Float for AVG(DISTINCT) sum, got {sum:?}" + ))); + } + } AggregateFunction::Sum(col_idx) => { let sum = values .get(cursor) @@ -688,16 +1046,72 @@ impl AggregateState { weight: isize, aggregates: &[AggregateFunction], _column_names: &[String], // No longer needed + distinct_transitions: &HashMap, ) { // Update COUNT self.count += weight as i64; - // Update other aggregates + // Track which columns have had their distinct counts/sums updated + // This prevents double-counting when multiple distinct aggregates + // operate on the same column (e.g., COUNT(DISTINCT col), SUM(DISTINCT col), AVG(DISTINCT col)) + let mut processed_counts: HashSet = HashSet::new(); + let mut processed_sums: HashSet = HashSet::new(); + + // Update distinct aggregate state for agg in aggregates { match agg { AggregateFunction::Count => { // Already handled above } + AggregateFunction::CountDistinct(col_idx) => { + // Only update count if we haven't processed this column yet + if !processed_counts.contains(col_idx) { + if let Some(transition) = distinct_transitions.get(col_idx) { + let current_count = + self.distinct_counts.get(col_idx).copied().unwrap_or(0); + let new_count = match transition.transition_type { + TransitionType::Added => current_count + 1, + TransitionType::Removed => current_count - 1, + }; + self.distinct_counts.insert(*col_idx, new_count); + processed_counts.insert(*col_idx); + } + } + } + AggregateFunction::SumDistinct(col_idx) + | AggregateFunction::AvgDistinct(col_idx) => { + if let Some(transition) = distinct_transitions.get(col_idx) { + // Update count if not already processed (needed for AVG) + if !processed_counts.contains(col_idx) { + let current_count = + self.distinct_counts.get(col_idx).copied().unwrap_or(0); + let new_count = match transition.transition_type { + TransitionType::Added => current_count + 1, + TransitionType::Removed => current_count - 1, + }; + self.distinct_counts.insert(*col_idx, new_count); + processed_counts.insert(*col_idx); + } + + // Update sum if not already processed + if !processed_sums.contains(col_idx) { + let current_sum = + self.distinct_sums.get(col_idx).copied().unwrap_or(0.0); + let value_as_float = match &transition.transitioned_value { + Value::Integer(i) => *i as f64, + Value::Float(f) => *f, + _ => 0.0, + }; + + let new_sum = match transition.transition_type { + TransitionType::Added => current_sum + value_as_float, + TransitionType::Removed => current_sum - value_as_float, + }; + self.distinct_sums.insert(*col_idx, new_sum); + processed_sums.insert(*col_idx); + } + } + } AggregateFunction::Sum(col_idx) => { if let Some(val) = values.get(*col_idx) { let num_val = match val { @@ -751,6 +1165,14 @@ impl AggregateState { } /// Convert aggregate state to output values + /// + /// Note: SQLite returns INTEGER for SUM when all inputs are integers, and REAL when any input is REAL. + /// However, in an incremental system like DBSP, we cannot track whether all current values are integers + /// after deletions. For example: + /// - Initial: SUM(10, 20, 30.5) = 60.5 (REAL) + /// - After DELETE 30.5: SUM(10, 20) = 30 (SQLite returns INTEGER, but we only know the sum is 30.0) + /// + /// Therefore, we always return REAL for SUM operations. pub fn to_values(&self, aggregates: &[AggregateFunction]) -> Vec { let mut result = Vec::new(); @@ -759,14 +1181,19 @@ impl AggregateState { AggregateFunction::Count => { result.push(Value::Integer(self.count)); } + AggregateFunction::CountDistinct(col_idx) => { + // Return the computed DISTINCT count + let count = self.distinct_counts.get(col_idx).copied().unwrap_or(0); + result.push(Value::Integer(count)); + } AggregateFunction::Sum(col_idx) => { let sum = self.sums.get(col_idx).copied().unwrap_or(0.0); - // Return as integer if it's a whole number, otherwise as float - if sum.fract() == 0.0 { - result.push(Value::Integer(sum as i64)); - } else { - result.push(Value::Float(sum)); - } + result.push(Value::Float(sum)); + } + AggregateFunction::SumDistinct(col_idx) => { + // Return the computed SUM(DISTINCT) + let sum = self.distinct_sums.get(col_idx).copied().unwrap_or(0.0); + result.push(Value::Float(sum)); } AggregateFunction::Avg(col_idx) => { if let Some((sum, count)) = self.avgs.get(col_idx) { @@ -779,6 +1206,18 @@ impl AggregateState { result.push(Value::Null); } } + AggregateFunction::AvgDistinct(col_idx) => { + // Compute AVG from SUM(DISTINCT) / COUNT(DISTINCT) + let count = self.distinct_counts.get(col_idx).copied().unwrap_or(0); + if count > 0 { + let sum = self.distinct_sums.get(col_idx).copied().unwrap_or(0.0); + let avg = sum / count as f64; + // AVG always returns a float value for consistency with SQLite + result.push(Value::Float(avg)); + } else { + result.push(Value::Null); + } + } AggregateFunction::Min(col_idx) => { // Return the MIN value from our state result.push(self.mins.get(col_idx).cloned().unwrap_or(Value::Null)); @@ -795,12 +1234,109 @@ impl AggregateState { } impl AggregateOperator { + /// Detect if a distinct value crosses the zero boundary (using pre-fetched weights and batch-accumulated weights) + fn detect_distinct_value_transition( + col_idx: usize, + val: &Value, + weight: isize, + existing_state: &AggregateState, + group_distinct_deltas: Option<&HashMap<(usize, HashableRow), isize>>, + ) -> Option { + let hashable_row = HashableRow::new(col_idx as i64, vec![val.clone()]); + + // Get the weight from storage (pre-fetched in AggregateState) + let storage_count = existing_state + .distinct_value_weights + .get(&(col_idx, hashable_row.clone())) + .copied() + .unwrap_or(0); + + // Get the accumulated weight from the current batch (before this row) + let batch_accumulated = if let Some(deltas) = group_distinct_deltas { + deltas + .get(&(col_idx, hashable_row.clone())) + .copied() + .unwrap_or(0) + } else { + 0 + }; + + // The old count is storage + batch accumulated so far (before this row) + let old_count = storage_count + batch_accumulated as i64; + // The new count includes the current weight + let new_count = old_count + weight as i64; + + // Detect transitions + if old_count <= 0 && new_count > 0 { + // Value added to distinct set + Some(DistinctTransition { + transition_type: TransitionType::Added, + transitioned_value: val.clone(), + }) + } else if old_count > 0 && new_count <= 0 { + // Value removed from distinct set + Some(DistinctTransition { + transition_type: TransitionType::Removed, + transitioned_value: val.clone(), + }) + } else { + // No transition + None + } + } + + /// Detect distinct value transitions for a single row + fn detect_distinct_transitions( + &self, + row_values: &[Value], + weight: isize, + existing_state: &AggregateState, + group_distinct_deltas: Option<&HashMap<(usize, HashableRow), isize>>, + ) -> HashMap { + let mut transitions = HashMap::new(); + + // Plain Distinct doesn't track individual values, so no transitions needed + if self.is_distinct_only { + // Distinct is handled by the count alone in apply_delta + return transitions; + } + + // Process each distinct column + for &col_idx in &self.distinct_columns { + let val = match row_values.get(col_idx) { + Some(v) => v, + None => continue, + }; + + // Skip null values + if val == &Value::Null { + continue; + } + + if let Some(transition) = Self::detect_distinct_value_transition( + col_idx, + val, + weight, + existing_state, + group_distinct_deltas, + ) { + transitions.insert(col_idx, transition); + } + } + + transitions + } + pub fn new( operator_id: i64, group_by: Vec, aggregates: Vec, input_column_names: Vec, ) -> Self { + // Precompute flags for runtime efficiency + // Plain DISTINCT is indicated by empty aggregates vector + let is_distinct_only = aggregates.is_empty(); + // Build map of column indices to their MIN/MAX info let mut column_min_max = HashMap::new(); let mut storage_indices = HashMap::new(); @@ -820,7 +1356,7 @@ impl AggregateOperator { } } - // Second pass: build the column info map + // Second pass: build the column info map for MIN/MAX for agg in &aggregates { match agg { AggregateFunction::Min(col_idx) => { @@ -845,14 +1381,29 @@ impl AggregateOperator { } } + // Build the distinct columns set + let mut distinct_columns = HashSet::new(); + for agg in &aggregates { + match agg { + AggregateFunction::CountDistinct(col_idx) + | AggregateFunction::SumDistinct(col_idx) + | AggregateFunction::AvgDistinct(col_idx) => { + distinct_columns.insert(*col_idx); + } + _ => {} + } + } + Self { operator_id, group_by, aggregates, input_column_names, column_min_max, + distinct_columns, tracker: None, commit_state: AggregateCommitState::Idle, + is_distinct_only, } } @@ -860,6 +1411,11 @@ impl AggregateOperator { !self.column_min_max.is_empty() } + /// Check if this operator has any DISTINCT aggregates or plain DISTINCT + pub fn has_distinct(&self) -> bool { + !self.distinct_columns.is_empty() || self.is_distinct_only + } + fn eval_internal( &mut self, state: &mut EvalState, @@ -895,6 +1451,7 @@ impl AggregateOperator { groups_to_read: groups_to_read.into_iter().collect(), existing_groups: HashMap::new(), old_values: HashMap::new(), + pre_existing_groups: HashSet::new(), // Initialize empty })); } EvalState::Aggregate(_agg_state) => { @@ -923,12 +1480,17 @@ impl AggregateOperator { delta: &Delta, existing_groups: &mut HashMap, old_values: &mut HashMap>, - ) -> (Delta, HashMap, AggregateState)>) { + pre_existing_groups: &HashSet, + ) -> MergeResult { let mut output_delta = Delta::new(); let mut temp_keys: HashMap> = HashMap::new(); + // Track distinct value weights as we process the batch + let mut batch_distinct_weights: HashMap> = + HashMap::new(); + // Process each change in the delta - for (row, weight) in &delta.changes { + for (row, weight) in delta.changes.iter() { if let Some(tracker) = &self.tracker { tracker.lock().unwrap().record_aggregation(); } @@ -937,50 +1499,159 @@ impl AggregateOperator { let group_key = self.extract_group_key(&row.values); let group_key_str = Self::group_key_to_string(&group_key); + // Get or create the state for this group let state = existing_groups.entry(group_key_str.clone()).or_default(); + // Get batch weights for this group + let group_batch_weights = batch_distinct_weights.get(&group_key_str); + + // Detect distinct transitions using the existing state and batch-accumulated weights + let distinct_transitions = if self.has_distinct() { + self.detect_distinct_transitions(&row.values, *weight, state, group_batch_weights) + } else { + HashMap::new() + }; + + // Update batch weights after detecting transitions + if self.has_distinct() { + for &col_idx in &self.distinct_columns { + if let Some(val) = row.values.get(col_idx) { + if val != &Value::Null { + let hashable_row = HashableRow::new(col_idx as i64, vec![val.clone()]); + let group_entry = batch_distinct_weights + .entry(group_key_str.clone()) + .or_default(); + let weight_entry = + group_entry.entry((col_idx, hashable_row)).or_insert(0); + *weight_entry += weight; + } + } + } + } + temp_keys.insert(group_key_str.clone(), group_key.clone()); - // Apply the delta to the temporary state + // Apply the delta to the state with pre-computed transitions state.apply_delta( &row.values, *weight, &self.aggregates, &self.input_column_names, + &distinct_transitions, ); } // Generate output delta from temporary states and collect final states let mut final_states = HashMap::new(); - for (group_key_str, state) in existing_groups { - let group_key = temp_keys.get(group_key_str).cloned().unwrap_or_default(); + for (group_key_str, state) in existing_groups.iter() { + let group_key = if let Some(key) = temp_keys.get(group_key_str) { + key.clone() + } else if let Some(old_row) = old_values.get(group_key_str) { + // Extract group key from old row (first N columns where N = group_by.len()) + old_row[0..self.group_by.len()].to_vec() + } else { + vec![] + }; // Generate synthetic rowid for this group let result_key = self.generate_group_rowid(group_key_str); - if let Some(old_row_values) = old_values.get(group_key_str) { - let old_row = HashableRow::new(result_key, old_row_values.clone()); - output_delta.changes.push((old_row, -1)); - } - // Always store the state for persistence (even if count=0, we need to delete it) final_states.insert(group_key_str.clone(), (group_key.clone(), state.clone())); - // Only include groups with count > 0 in the output delta - if state.count > 0 { - // Build output row: group_by columns + aggregate values - let mut output_values = group_key.clone(); - let aggregate_values = state.to_values(&self.aggregates); - output_values.extend(aggregate_values); + // Check if we only have DISTINCT (no other aggregates) + if self.is_distinct_only { + // For plain DISTINCT, we output each distinct VALUE (not group) + // state.count tells us how many distinct values have positive weight - let output_row = HashableRow::new(result_key, output_values.clone()); - output_delta.changes.push((output_row, 1)); + // Check if this group had any values before + let old_existed = pre_existing_groups.contains(group_key_str); + let new_exists = state.count > 0; + + if old_existed && !new_exists { + // All distinct values removed: output deletion + if let Some(old_row_values) = old_values.get(group_key_str) { + let old_row = HashableRow::new(result_key, old_row_values.clone()); + output_delta.changes.push((old_row, -1)); + } else { + // For plain DISTINCT, the old row is just the group key itself + let old_row = HashableRow::new(result_key, group_key.clone()); + output_delta.changes.push((old_row, -1)); + } + } else if !old_existed && new_exists { + // First distinct value added: output insertion + let output_values = group_key.clone(); + // DISTINCT doesn't add aggregate values - just the group key + let output_row = HashableRow::new(result_key, output_values.clone()); + output_delta.changes.push((output_row, 1)); + } + // No output if staying positive or staying at zero + } else { + // Normal aggregates: output deletions and insertions as before + if let Some(old_row_values) = old_values.get(group_key_str) { + let old_row = HashableRow::new(result_key, old_row_values.clone()); + output_delta.changes.push((old_row, -1)); + } + + // Only include groups with count > 0 in the output delta + if state.count > 0 { + // Build output row: group_by columns + aggregate values + let mut output_values = group_key.clone(); + let aggregate_values = state.to_values(&self.aggregates); + output_values.extend(aggregate_values); + + let output_row = HashableRow::new(result_key, output_values.clone()); + output_delta.changes.push((output_row, 1)); + } } } + (output_delta, final_states) } + /// Extract distinct values from delta changes for batch tracking + fn extract_distinct_deltas(&self, delta: &Delta) -> DistinctDeltas { + let mut distinct_deltas: DistinctDeltas = HashMap::new(); + + for (row, weight) in &delta.changes { + let group_key = self.extract_group_key(&row.values); + let group_key_str = Self::group_key_to_string(&group_key); + + // Get or create entry for this group + let group_entry = distinct_deltas.entry(group_key_str.clone()).or_default(); + + if self.is_distinct_only { + // For plain DISTINCT, the group itself is what we're tracking + // We store a single entry that represents "this group exists N times" + // Use column index 0 with the group_key_str as the value + // For group key, use 0 as column index + let key = ( + 0, + HashableRow::new(0, vec![Value::Text(group_key_str.clone().into())]), + ); + let value_entry = group_entry.entry(key).or_insert(0); + *value_entry += weight; + } else { + // For DISTINCT aggregates, track individual column values + for &col_idx in &self.distinct_columns { + if let Some(val) = row.values.get(col_idx) { + // Skip NULL values + if val == &Value::Null { + continue; + } + + let key = (col_idx, HashableRow::new(col_idx as i64, vec![val.clone()])); + let value_entry = group_entry.entry(key).or_insert(0); + *value_entry += weight; + } + } + } + } + + distinct_deltas + } + /// Extract MIN/MAX values from delta changes for persistence to index fn extract_min_max_deltas(&self, delta: &Delta) -> MinMaxDeltas { let mut min_max_deltas: MinMaxDeltas = HashMap::new(); @@ -1103,14 +1774,25 @@ impl IncrementalOperator for AggregateOperator { self.commit_state = AggregateCommitState::Eval { eval_state }; } AggregateCommitState::Eval { ref mut eval_state } => { - // Extract input delta before eval for MIN/MAX processing - let input_delta = eval_state.extract_delta(); + // Clone the delta for MIN/MAX processing before eval consumes it + // We need to get the delta from the eval_state if it's still in Init + let input_delta = match eval_state { + EvalState::Init { deltas } => deltas.left.clone(), + _ => Delta::new(), // Empty delta if already processed + }; - // Extract MIN/MAX deltas before any I/O operations + // Extract MIN/MAX and DISTINCT deltas before any I/O operations let min_max_deltas = self.extract_min_max_deltas(&input_delta); + // For plain DISTINCT, we need to extract deltas too + let distinct_deltas = if self.has_distinct() || self.is_distinct_only { + self.extract_distinct_deltas(&input_delta) + } else { + HashMap::new() + }; - // Create a new eval state with the same delta - *eval_state = EvalState::from_delta(input_delta.clone()); + // Get old counts before eval modifies the states + // We need to extract this from the eval_state before it's consumed + let old_states = HashMap::new(); // TODO: Extract from eval_state let (output_delta, computed_states) = return_and_restore_if_io!( &mut self.commit_state, @@ -1121,17 +1803,23 @@ impl IncrementalOperator for AggregateOperator { self.commit_state = AggregateCommitState::PersistDelta { delta: output_delta, computed_states, + old_states, current_idx: 0, write_row: WriteRow::new(), - min_max_deltas, // Store for later use + min_max_deltas, // Store for later use + distinct_deltas, // Store for distinct processing + input_delta, // Store original input }; } AggregateCommitState::PersistDelta { delta, computed_states, + old_states, current_idx, write_row, min_max_deltas, + distinct_deltas, + input_delta, } => { let states_vec: Vec<_> = computed_states.iter().collect(); @@ -1140,28 +1828,59 @@ impl IncrementalOperator for AggregateOperator { self.commit_state = AggregateCommitState::PersistMinMax { delta: delta.clone(), min_max_persist_state: MinMaxPersistState::new(min_max_deltas.clone()), + distinct_deltas: distinct_deltas.clone(), }; } else { let (group_key_str, (group_key, agg_state)) = states_vec[*current_idx]; - // Build the key components for the new table structure - // For regular aggregates, use column_index=0 and AGG_TYPE_REGULAR + // Skip aggregate state persistence for plain DISTINCT + // Plain DISTINCT only uses the distinct value weights, not aggregate state + if self.is_distinct_only { + // Skip to next - distinct values are handled in PersistDistinctValues + // We still need to transition states properly + let next_idx = *current_idx + 1; + if next_idx >= states_vec.len() { + // Done with all groups, move to PersistMinMax + self.commit_state = AggregateCommitState::PersistMinMax { + delta: std::mem::take(delta), + min_max_persist_state: MinMaxPersistState::new(std::mem::take( + min_max_deltas, + )), + distinct_deltas: std::mem::take(distinct_deltas), + }; + } else { + // Move to next group + self.commit_state = AggregateCommitState::PersistDelta { + delta: std::mem::take(delta), + computed_states: std::mem::take(computed_states), + old_states: std::mem::take(old_states), + current_idx: next_idx, + write_row: WriteRow::new(), + min_max_deltas: std::mem::take(min_max_deltas), + distinct_deltas: std::mem::take(distinct_deltas), + input_delta: std::mem::take(input_delta), + }; + } + continue; + } + + // Build the key components for regular aggregates let operator_storage_id = generate_storage_id(self.operator_id, 0, AGG_TYPE_REGULAR); let zset_hash = self.generate_group_hash(group_key_str); - let element_id = 0i64; + let element_id = Hash128::new(0, 0); // Always zeros for regular aggregates - // Determine weight: -1 to delete (cancels existing weight=1), 1 to insert/update + // Determine weight: 1 if exists, -1 if deleted let weight = if agg_state.count == 0 { -1 } else { 1 }; - // Serialize the aggregate state with group key (even for deletion, we need a row) + // Serialize the aggregate state (only for regular aggregates, not plain DISTINCT) let state_blob = agg_state.to_blob(&self.aggregates, group_key); let blob_value = Value::Blob(state_blob); // Build the aggregate storage format: [operator_id, zset_hash, element_id, value, weight] let operator_id_val = Value::Integer(operator_storage_id); let zset_hash_val = zset_hash.to_value(); - let element_id_val = Value::Integer(element_id); + let element_id_val = element_id.to_value(); let blob_val = blob_value.clone(); // Create index key - the first 3 columns of our primary key @@ -1184,24 +1903,27 @@ impl IncrementalOperator for AggregateOperator { let delta = std::mem::take(delta); let computed_states = std::mem::take(computed_states); let min_max_deltas = std::mem::take(min_max_deltas); + let distinct_deltas = std::mem::take(distinct_deltas); + let input_delta = std::mem::take(input_delta); self.commit_state = AggregateCommitState::PersistDelta { delta, computed_states, + old_states: std::mem::take(old_states), current_idx: *current_idx + 1, write_row: WriteRow::new(), // Reset for next write min_max_deltas, + distinct_deltas, + input_delta, }; } } AggregateCommitState::PersistMinMax { delta, min_max_persist_state, + distinct_deltas, } => { - if !self.has_min_max() { - let delta = std::mem::take(delta); - self.commit_state = AggregateCommitState::Done { delta }; - } else { + if self.has_min_max() { return_and_restore_if_io!( &mut self.commit_state, state, @@ -1212,10 +1934,37 @@ impl IncrementalOperator for AggregateOperator { |group_key_str| self.generate_group_hash(group_key_str) ) ); - - let delta = std::mem::take(delta); - self.commit_state = AggregateCommitState::Done { delta }; } + + // Transition to PersistDistinctValues + let delta = std::mem::take(delta); + let distinct_deltas = std::mem::take(distinct_deltas); + let distinct_persist_state = DistinctPersistState::new(distinct_deltas); + self.commit_state = AggregateCommitState::PersistDistinctValues { + delta, + distinct_persist_state, + }; + } + AggregateCommitState::PersistDistinctValues { + delta, + distinct_persist_state, + } => { + if self.has_distinct() { + // Use the state machine to persist distinct values to BTree + return_and_restore_if_io!( + &mut self.commit_state, + state, + distinct_persist_state.persist_distinct_values( + self.operator_id, + cursors, + |group_key_str| self.generate_group_hash(group_key_str) + ) + ); + } + + // Transition to Done + let delta = std::mem::take(delta); + self.commit_state = AggregateCommitState::Done { delta }; } AggregateCommitState::Done { delta } => { self.commit_state = AggregateCommitState::Idle; @@ -1546,7 +2295,7 @@ impl ScanState { }; // Check if we're still in the same group - if let RefValue::Integer(rec_sid) = rec_storage_id { + if let ValueRef::Integer(rec_sid) = rec_storage_id { if *rec_sid != storage_id { return Ok(IOResult::Done(None)); } @@ -1555,8 +2304,8 @@ impl ScanState { } // Compare zset_hash as blob - if let RefValue::Blob(rec_zset_blob) = rec_zset_hash { - if let Some(rec_hash) = Hash128::from_blob(rec_zset_blob.to_slice()) { + if let ValueRef::Blob(rec_zset_blob) = rec_zset_hash { + if let Some(rec_hash) = Hash128::from_blob(rec_zset_blob) { if rec_hash != zset_hash { return Ok(IOResult::Done(None)); } @@ -1754,6 +2503,484 @@ pub enum MinMaxPersistState { Done, } +/// State machine for fetching distinct values from BTree storage +#[derive(Debug)] +pub enum FetchDistinctState { + Init { + groups_to_fetch: Vec<(String, HashMap>)>, + }, + FetchGroup { + groups_to_fetch: Vec<(String, HashMap>)>, + group_idx: usize, + value_idx: usize, + values_to_fetch: Vec<(usize, Value)>, + }, + ReadValue { + groups_to_fetch: Vec<(String, HashMap>)>, + group_idx: usize, + value_idx: usize, + values_to_fetch: Vec<(usize, Value)>, + group_key: String, + column_idx: usize, + value: Value, + }, + Done, +} + +impl FetchDistinctState { + /// Add fetch entry for plain DISTINCT - the group itself is the distinct value + fn add_plain_distinct_fetch( + group_entry: &mut HashMap>, + group_key_str: &str, + ) { + let group_value = Value::Text(group_key_str.to_string().into()); + group_entry + .entry(0) + .or_default() + .insert(HashableRow::new(0, vec![group_value])); + } + + /// Add fetch entries for DISTINCT aggregates - individual column values + fn add_aggregate_distinct_fetch( + group_entry: &mut HashMap>, + row_values: &[Value], + distinct_columns: &HashSet, + ) { + for &col_idx in distinct_columns { + if let Some(val) = row_values.get(col_idx) { + if val != &Value::Null { + group_entry + .entry(col_idx) + .or_default() + .insert(HashableRow::new(col_idx as i64, vec![val.clone()])); + } + } + } + } + + pub fn new( + delta: &Delta, + distinct_columns: &HashSet, + extract_group_key: impl Fn(&[Value]) -> Vec, + group_key_to_string: impl Fn(&[Value]) -> String, + existing_groups: &HashMap, + is_plain_distinct: bool, + ) -> Self { + let mut groups_to_fetch: HashMap>> = + HashMap::new(); + + for (row, _weight) in &delta.changes { + let group_key = extract_group_key(&row.values); + let group_key_str = group_key_to_string(&group_key); + + // Skip groups we don't need to fetch + // For DISTINCT aggregates, only fetch for existing groups + if !is_plain_distinct && !existing_groups.contains_key(&group_key_str) { + continue; + } + + let group_entry = groups_to_fetch.entry(group_key_str.clone()).or_default(); + + if is_plain_distinct { + Self::add_plain_distinct_fetch(group_entry, &group_key_str); + } else { + Self::add_aggregate_distinct_fetch(group_entry, &row.values, distinct_columns); + } + } + + let groups_to_fetch: Vec<_> = groups_to_fetch.into_iter().collect(); + + if groups_to_fetch.is_empty() { + Self::Done + } else { + Self::Init { groups_to_fetch } + } + } + + pub fn fetch_distinct_values( + &mut self, + operator_id: i64, + existing_groups: &mut HashMap, + cursors: &mut DbspStateCursors, + generate_group_hash: impl Fn(&str) -> Hash128, + is_plain_distinct: bool, + ) -> Result> { + loop { + match self { + FetchDistinctState::Init { groups_to_fetch } => { + if groups_to_fetch.is_empty() { + *self = FetchDistinctState::Done; + continue; + } + + let groups = std::mem::take(groups_to_fetch); + *self = FetchDistinctState::FetchGroup { + groups_to_fetch: groups, + group_idx: 0, + value_idx: 0, + values_to_fetch: Vec::new(), + }; + } + FetchDistinctState::FetchGroup { + groups_to_fetch, + group_idx, + value_idx, + values_to_fetch, + } => { + if *group_idx >= groups_to_fetch.len() { + *self = FetchDistinctState::Done; + continue; + } + + // Build list of values to fetch for current group if not done + if values_to_fetch.is_empty() && *group_idx < groups_to_fetch.len() { + let (_group_key, cols_values) = &groups_to_fetch[*group_idx]; + for (col_idx, values) in cols_values { + for hashable_row in values { + // Extract the value from HashableRow + values_to_fetch + .push((*col_idx, hashable_row.values.first().unwrap().clone())); + } + } + } + + if *value_idx >= values_to_fetch.len() { + // Move to next group + *group_idx += 1; + *value_idx = 0; + values_to_fetch.clear(); + continue; + } + + // Fetch current value + let (group_key, _) = groups_to_fetch[*group_idx].clone(); + let (column_idx, value) = values_to_fetch[*value_idx].clone(); + + let groups = std::mem::take(groups_to_fetch); + let values = std::mem::take(values_to_fetch); + *self = FetchDistinctState::ReadValue { + groups_to_fetch: groups, + group_idx: *group_idx, + value_idx: *value_idx, + values_to_fetch: values, + group_key, + column_idx, + value, + }; + } + FetchDistinctState::ReadValue { + groups_to_fetch, + group_idx, + value_idx, + values_to_fetch, + group_key, + column_idx, + value, + } => { + // Read the record from BTree using the same pattern as WriteRow: + // 1. Seek in index to find the entry + // 2. Get rowid from index cursor + // 3. Use rowid to read from table cursor + let storage_id = + generate_storage_id(operator_id, *column_idx, AGG_TYPE_DISTINCT); + let zset_hash = generate_group_hash(group_key); + let element_id = hash_value(value, *column_idx); + + // First, seek in the index cursor + let index_key = vec![ + Value::Integer(storage_id), + zset_hash.to_value(), + element_id.to_value(), + ]; + let index_record = ImmutableRecord::from_values(&index_key, index_key.len()); + + let seek_result = return_if_io!(cursors.index_cursor.seek( + SeekKey::IndexKey(&index_record), + SeekOp::GE { eq_only: true } + )); + + // Early exit if not found in index + if !matches!(seek_result, SeekResult::Found) { + let groups = std::mem::take(groups_to_fetch); + let values = std::mem::take(values_to_fetch); + *self = FetchDistinctState::FetchGroup { + groups_to_fetch: groups, + group_idx: *group_idx, + value_idx: *value_idx + 1, + values_to_fetch: values, + }; + continue; + } + + // Get the rowid from the index cursor + let rowid = return_if_io!(cursors.index_cursor.rowid()); + + // Early exit if no rowid + let rowid = match rowid { + Some(id) => id, + None => { + let groups = std::mem::take(groups_to_fetch); + let values = std::mem::take(values_to_fetch); + *self = FetchDistinctState::FetchGroup { + groups_to_fetch: groups, + group_idx: *group_idx, + value_idx: *value_idx + 1, + values_to_fetch: values, + }; + continue; + } + }; + + // Now seek in the table cursor using the rowid + let table_result = return_if_io!(cursors + .table_cursor + .seek(SeekKey::TableRowId(rowid), SeekOp::GE { eq_only: true })); + + // Early exit if not found in table + if !matches!(table_result, SeekResult::Found) { + let groups = std::mem::take(groups_to_fetch); + let values = std::mem::take(values_to_fetch); + *self = FetchDistinctState::FetchGroup { + groups_to_fetch: groups, + group_idx: *group_idx, + value_idx: *value_idx + 1, + values_to_fetch: values, + }; + continue; + } + + // Read the actual record from the table cursor + let record = return_if_io!(cursors.table_cursor.record()); + + if let Some(r) = record { + let values = r.get_values(); + + // The table has 5 columns: storage_id, zset_hash, element_id, blob, weight + // The weight is at index 4 + if values.len() >= 5 { + // Get the weight directly from column 4 + let weight = match values[4].to_owned() { + Value::Integer(w) => w, + _ => 0, + }; + + // Store the weight in the existing group's state + let state = existing_groups.entry(group_key.clone()).or_default(); + state.distinct_value_weights.insert( + ( + *column_idx, + HashableRow::new(*column_idx as i64, vec![value.clone()]), + ), + weight, + ); + } + } + + // Move to next value + let groups = std::mem::take(groups_to_fetch); + let values = std::mem::take(values_to_fetch); + *self = FetchDistinctState::FetchGroup { + groups_to_fetch: groups, + group_idx: *group_idx, + value_idx: *value_idx + 1, + values_to_fetch: values, + }; + } + FetchDistinctState::Done => { + // For plain DISTINCT, construct AggregateState from the weights we fetched + if is_plain_distinct { + for (_group_key_str, state) in existing_groups.iter_mut() { + // For plain DISTINCT, sum all the weights to get total count + // Each weight represents how many times the distinct value appears + let total_weight: i64 = state.distinct_value_weights.values().sum(); + + // Set the count based on total weight + state.count = total_weight; + } + } + return Ok(IOResult::Done(())); + } + } + } + } +} + +/// State machine for persisting distinct values to BTree storage +#[derive(Debug)] +pub enum DistinctPersistState { + Init { + distinct_deltas: DistinctDeltas, + group_keys: Vec, + }, + ProcessGroup { + distinct_deltas: DistinctDeltas, + group_keys: Vec, + group_idx: usize, + value_keys: Vec<(usize, HashableRow)>, // (col_idx, value) pairs for current group + value_idx: usize, + }, + WriteValue { + distinct_deltas: DistinctDeltas, + group_keys: Vec, + group_idx: usize, + value_keys: Vec<(usize, HashableRow)>, + value_idx: usize, + group_key: String, + col_idx: usize, + value: Value, + weight: isize, + write_row: WriteRow, + }, + Done, +} + +impl DistinctPersistState { + pub fn new(distinct_deltas: DistinctDeltas) -> Self { + let group_keys: Vec = distinct_deltas.keys().cloned().collect(); + Self::Init { + distinct_deltas, + group_keys, + } + } + + pub fn persist_distinct_values( + &mut self, + operator_id: i64, + cursors: &mut DbspStateCursors, + generate_group_hash: impl Fn(&str) -> Hash128, + ) -> Result> { + loop { + match self { + DistinctPersistState::Init { + distinct_deltas, + group_keys, + } => { + let distinct_deltas = std::mem::take(distinct_deltas); + let group_keys = std::mem::take(group_keys); + *self = DistinctPersistState::ProcessGroup { + distinct_deltas, + group_keys, + group_idx: 0, + value_keys: Vec::new(), + value_idx: 0, + }; + } + DistinctPersistState::ProcessGroup { + distinct_deltas, + group_keys, + group_idx, + value_keys, + value_idx, + } => { + // Check if we're past all groups + if *group_idx >= group_keys.len() { + *self = DistinctPersistState::Done; + continue; + } + + // Check if we need to get value_keys for current group + if value_keys.is_empty() && *group_idx < group_keys.len() { + let group_key_str = &group_keys[*group_idx]; + if let Some(group_values) = distinct_deltas.get(group_key_str) { + *value_keys = group_values.keys().cloned().collect(); + } + } + + // Check if we have more values in current group + if *value_idx >= value_keys.len() { + *group_idx += 1; + *value_idx = 0; + value_keys.clear(); + continue; + } + + // Process current value + let group_key = group_keys[*group_idx].clone(); + let (col_idx, hashable_row) = value_keys[*value_idx].clone(); + let weight = distinct_deltas[&group_key][&(col_idx, hashable_row.clone())]; + // Extract the value from HashableRow (it's the first element in values vector) + let value = hashable_row.values.first().unwrap().clone(); + + let distinct_deltas = std::mem::take(distinct_deltas); + let group_keys = std::mem::take(group_keys); + let value_keys = std::mem::take(value_keys); + *self = DistinctPersistState::WriteValue { + distinct_deltas, + group_keys, + group_idx: *group_idx, + value_keys, + value_idx: *value_idx, + group_key, + col_idx, + value, + weight, + write_row: WriteRow::new(), + }; + } + DistinctPersistState::WriteValue { + distinct_deltas, + group_keys, + group_idx, + value_keys, + value_idx, + group_key, + col_idx, + value, + weight, + write_row, + } => { + // Build the key components for DISTINCT storage + let storage_id = generate_storage_id(operator_id, *col_idx, AGG_TYPE_DISTINCT); + let zset_hash = generate_group_hash(group_key); + + // For DISTINCT, element_id is a hash of the value + let element_id = hash_value(value, *col_idx); + + // Create index key + let index_key = vec![ + Value::Integer(storage_id), + zset_hash.to_value(), + element_id.to_value(), + ]; + + // Record values (operator_id, zset_hash, element_id, weight_blob) + // Store weight as a minimal AggregateState blob so ReadRecord can parse it + let weight_state = AggregateState { + count: *weight as i64, + ..Default::default() + }; + let weight_blob = weight_state.to_blob(&[], &[]); + + let record_values = vec![ + Value::Integer(storage_id), + zset_hash.to_value(), + element_id.to_value(), + Value::Blob(weight_blob), + ]; + + // Write to BTree + return_if_io!(write_row.write_row(cursors, index_key, record_values, *weight)); + + // Move to next value + let distinct_deltas = std::mem::take(distinct_deltas); + let group_keys = std::mem::take(group_keys); + let value_keys = std::mem::take(value_keys); + *self = DistinctPersistState::ProcessGroup { + distinct_deltas, + group_keys, + group_idx: *group_idx, + value_keys, + value_idx: *value_idx + 1, + }; + } + DistinctPersistState::Done => { + return Ok(IOResult::Done(())); + } + } + } + } +} + impl MinMaxPersistState { pub fn new(min_max_deltas: MinMaxDeltas) -> Self { let group_keys: Vec = min_max_deltas.keys().cloned().collect(); diff --git a/core/incremental/compiler.rs b/core/incremental/compiler.rs index 40a8ca2af..28bad0d0d 100644 --- a/core/incremental/compiler.rs +++ b/core/incremental/compiler.rs @@ -5,6 +5,7 @@ //! //! Based on the DBSP paper: "DBSP: Automatic Incremental View Maintenance for Rich Query Languages" +use crate::incremental::aggregate_operator::AggregateOperator; use crate::incremental::dbsp::{Delta, DeltaPair}; use crate::incremental::expr_compiler::CompiledExpression; use crate::incremental::operator::{ @@ -12,7 +13,7 @@ use crate::incremental::operator::{ IncrementalOperator, InputOperator, JoinOperator, JoinType, ProjectOperator, }; use crate::schema::Type; -use crate::storage::btree::{BTreeCursor, BTreeKey}; +use crate::storage::btree::{BTreeCursor, BTreeKey, CursorTrait}; // Note: logical module must be made pub(crate) in translate/mod.rs use crate::translate::logical::{ BinaryOperator, Column, ColumnInfo, JoinType as LogicalJoinType, LogicalExpr, LogicalPlan, @@ -300,6 +301,8 @@ pub enum DbspOperator { Input { name: String, schema: SchemaRef }, /// Merge operator for combining streams (used in recursive CTEs and UNION) Merge { schema: SchemaRef }, + /// Distinct operator - removes duplicates + Distinct { schema: SchemaRef }, } /// Represents an expression in DBSP @@ -329,6 +332,11 @@ pub struct DbspNode { pub executable: Box, } +// SAFETY: This needs to be audited for thread safety. +// See: https://github.com/tursodatabase/turso/issues/1552 +unsafe impl Send for DbspNode {} +unsafe impl Sync for DbspNode {} + impl std::fmt::Debug for DbspNode { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("DbspNode") @@ -395,6 +403,11 @@ pub struct DbspCircuit { pub(super) internal_state_index_root: i64, } +// SAFETY: This needs to be audited for thread safety. +// See: https://github.com/tursodatabase/turso/issues/1552 +unsafe impl Send for DbspCircuit {} +unsafe impl Sync for DbspCircuit {} + impl DbspCircuit { /// Create a new empty circuit with initial empty schema /// The actual output schema will be set when the root node is established @@ -480,15 +493,10 @@ impl DbspCircuit { ) -> Result> { if let Some(root_id) = self.root { // Create temporary cursors for execute (non-commit) operations - let table_cursor = BTreeCursor::new_table( - None, - pager.clone(), - self.internal_state_root, - OPERATOR_COLUMNS, - ); + let table_cursor = + BTreeCursor::new_table(pager.clone(), self.internal_state_root, OPERATOR_COLUMNS); let index_def = create_dbsp_state_index(self.internal_state_index_root); let index_cursor = BTreeCursor::new_index( - None, pager.clone(), self.internal_state_index_root, &index_def, @@ -537,14 +545,12 @@ impl DbspCircuit { CommitState::Init => { // Create state cursors when entering CommitOperators state let state_table_cursor = BTreeCursor::new_table( - None, pager.clone(), self.internal_state_root, OPERATOR_COLUMNS, ); let index_def = create_dbsp_state_index(self.internal_state_index_root); let state_index_cursor = BTreeCursor::new_index( - None, pager.clone(), self.internal_state_index_root, &index_def, @@ -575,7 +581,6 @@ impl DbspCircuit { // Create view cursor when entering UpdateView state let view_cursor = Box::new(BTreeCursor::new_table( - None, pager.clone(), main_data_root, num_columns, @@ -605,7 +610,6 @@ impl DbspCircuit { // due to btree cursor state machine limitations if matches!(write_row_state, WriteRowView::GetRecord) { *view_cursor = Box::new(BTreeCursor::new_table( - None, pager.clone(), main_data_root, num_columns, @@ -633,7 +637,6 @@ impl DbspCircuit { let view_cursor = std::mem::replace( view_cursor, Box::new(BTreeCursor::new_table( - None, pager.clone(), main_data_root, num_columns, @@ -729,14 +732,12 @@ impl DbspCircuit { // Create temporary cursors for the recursive call let temp_table_cursor = BTreeCursor::new_table( - None, pager.clone(), self.internal_state_root, OPERATOR_COLUMNS, ); let index_def = create_dbsp_state_index(self.internal_state_index_root); let temp_index_cursor = BTreeCursor::new_index( - None, pager.clone(), self.internal_state_index_root, &index_def, @@ -820,6 +821,13 @@ impl DbspCircuit { schema.columns.len() )?; } + DbspOperator::Distinct { schema } => { + writeln!( + f, + "{indent}Distinct[{node_id}]: (schema: {} columns)", + schema.columns.len() + )?; + } } for input_id in &node.inputs { @@ -1145,16 +1153,34 @@ impl DbspCompiler { } } - // Compile aggregate expressions + // Compile aggregate expressions (both DISTINCT and regular) let mut aggregate_functions = Vec::new(); for expr in &agg.aggr_expr { - if let LogicalExpr::AggregateFunction { fun, args, .. } = expr { + if let LogicalExpr::AggregateFunction { fun, args, distinct } = expr { use crate::function::AggFunc; use crate::incremental::aggregate_operator::AggregateFunction; match fun { AggFunc::Count | AggFunc::Count0 => { - aggregate_functions.push(AggregateFunction::Count); + if *distinct { + // COUNT(DISTINCT col) + if args.is_empty() { + return Err(LimboError::ParseError("COUNT(DISTINCT) requires an argument".to_string())); + } + if let LogicalExpr::Column(col) = &args[0] { + let (col_idx, _) = input_schema.find_column(&col.name, col.table.as_deref()) + .ok_or_else(|| LimboError::ParseError( + format!("COUNT(DISTINCT) column '{}' not found in input", col.name) + ))?; + aggregate_functions.push(AggregateFunction::CountDistinct(col_idx)); + } else { + return Err(LimboError::ParseError( + "Only column references are supported in aggregate functions for incremental views".to_string() + )); + } + } else { + aggregate_functions.push(AggregateFunction::Count); + } } AggFunc::Sum => { if args.is_empty() { @@ -1166,7 +1192,11 @@ impl DbspCompiler { .ok_or_else(|| LimboError::ParseError( format!("SUM column '{}' not found in input", col.name) ))?; - aggregate_functions.push(AggregateFunction::Sum(col_idx)); + if *distinct { + aggregate_functions.push(AggregateFunction::SumDistinct(col_idx)); + } else { + aggregate_functions.push(AggregateFunction::Sum(col_idx)); + } } else { return Err(LimboError::ParseError( "Only column references are supported in aggregate functions for incremental views".to_string() @@ -1182,7 +1212,11 @@ impl DbspCompiler { .ok_or_else(|| LimboError::ParseError( format!("AVG column '{}' not found in input", col.name) ))?; - aggregate_functions.push(AggregateFunction::Avg(col_idx)); + if *distinct { + aggregate_functions.push(AggregateFunction::AvgDistinct(col_idx)); + } else { + aggregate_functions.push(AggregateFunction::Avg(col_idx)); + } } else { return Err(LimboError::ParseError( "Only column references are supported in aggregate functions for incremental views".to_string() @@ -1366,14 +1400,48 @@ impl DbspCompiler { // Handle UNION and UNION ALL self.compile_union(union) } + LogicalPlan::Distinct(distinct) => { + // DISTINCT is implemented as GROUP BY all columns with a special aggregate + let input_id = self.compile_plan(&distinct.input)?; + let input_schema = distinct.input.schema(); + + // Create GROUP BY indices for all columns + let group_by: Vec = (0..input_schema.columns.len()).collect(); + + // Column names for the operator + let input_column_names: Vec = input_schema.columns.iter() + .map(|col| col.name.clone()) + .collect(); + + // Create the aggregate operator with DISTINCT mode + let operator_id = self.circuit.next_id; + let executable: Box = Box::new( + AggregateOperator::new( + operator_id, + group_by, + vec![], // Empty aggregates indicates plain DISTINCT + input_column_names, + ), + ); + + // Add the node to the circuit + let node_id = self.circuit.add_node( + DbspOperator::Distinct { + schema: input_schema.clone(), + }, + vec![input_id], + executable, + ); + + Ok(node_id) + } _ => Err(LimboError::ParseError( format!("Unsupported operator in DBSP compiler: only Filter, Projection, Join, Aggregate, and Union are supported, got: {:?}", match plan { LogicalPlan::Sort(_) => "Sort", LogicalPlan::Limit(_) => "Limit", LogicalPlan::Union(_) => "Union", - LogicalPlan::Distinct(_) => "Distinct", - LogicalPlan::EmptyRelation(_) => "EmptyRelation", + LogicalPlan::EmptyRelation(_) => "EmptyRelation", LogicalPlan::Values(_) => "Values", LogicalPlan::WithCTE(_) => "WithCTE", LogicalPlan::CTERef(_) => "CTERef", @@ -2149,6 +2217,31 @@ impl DbspCompiler { )) } } + LogicalExpr::IsNull { expr, negated } => { + // Extract column index from the inner expression + if let LogicalExpr::Column(col) = expr.as_ref() { + let column_idx = schema + .columns + .iter() + .position(|c| c.name == col.name) + .ok_or_else(|| { + LimboError::ParseError(format!( + "Column '{}' not found in schema for IS NULL filter", + col.name + )) + })?; + + if *negated { + Ok(FilterPredicate::IsNotNull { column_idx }) + } else { + Ok(FilterPredicate::IsNull { column_idx }) + } + } else { + Err(LimboError::ParseError( + "IS NULL/IS NOT NULL expects a column reference".to_string(), + )) + } + } _ => Err(LimboError::ParseError(format!( "Unsupported filter expression: {expr:?}" ))), @@ -2220,8 +2313,11 @@ mod tests { is_strict: false, has_autoincrement: false, unique_sets: vec![], + foreign_keys: vec![], }; - schema.add_btree_table(Arc::new(users_table)); + schema + .add_btree_table(Arc::new(users_table)) + .expect("Test setup: failed to add users table"); // Add products table for join tests let products_table = BTreeTable { @@ -2273,8 +2369,11 @@ mod tests { is_strict: false, has_autoincrement: false, unique_sets: vec![], + foreign_keys: vec![], }; - schema.add_btree_table(Arc::new(products_table)); + schema + .add_btree_table(Arc::new(products_table)) + .expect("Test setup: failed to add products table"); // Add orders table for join tests let orders_table = BTreeTable { @@ -2338,8 +2437,11 @@ mod tests { has_autoincrement: false, is_strict: false, unique_sets: vec![], + foreign_keys: vec![], }; - schema.add_btree_table(Arc::new(orders_table)); + schema + .add_btree_table(Arc::new(orders_table)) + .expect("Test setup: failed to add orders table"); // Add customers table with id and name for testing column ambiguity let customers_table = BTreeTable { @@ -2376,8 +2478,11 @@ mod tests { is_strict: false, has_autoincrement: false, unique_sets: vec![], + foreign_keys: vec![], }; - schema.add_btree_table(Arc::new(customers_table)); + schema + .add_btree_table(Arc::new(customers_table)) + .expect("Test setup: failed to add customers table"); // Add purchases table (junction table for three-way join) let purchases_table = BTreeTable { @@ -2438,8 +2543,11 @@ mod tests { is_strict: false, has_autoincrement: false, unique_sets: vec![], + foreign_keys: vec![], }; - schema.add_btree_table(Arc::new(purchases_table)); + schema + .add_btree_table(Arc::new(purchases_table)) + .expect("Test setup: failed to add purchases table"); // Add vendors table with id, name, and price (ambiguous columns with customers) let vendors_table = BTreeTable { @@ -2488,8 +2596,11 @@ mod tests { is_strict: false, has_autoincrement: false, unique_sets: vec![], + foreign_keys: vec![], }; - schema.add_btree_table(Arc::new(vendors_table)); + schema + .add_btree_table(Arc::new(vendors_table)) + .expect("Test setup: failed to add vendors table"); let sales_table = BTreeTable { name: "sales".to_string(), @@ -2525,8 +2636,11 @@ mod tests { is_strict: false, has_autoincrement: false, unique_sets: vec![], + foreign_keys: vec![], }; - schema.add_btree_table(Arc::new(sales_table)); + schema + .add_btree_table(Arc::new(sales_table)) + .expect("Test setup: failed to add sales table"); schema }}; @@ -2536,7 +2650,7 @@ mod tests { let io: Arc = Arc::new(MemoryIO::new()); let db = Database::open_file(io.clone(), ":memory:", false, false).unwrap(); let conn = db.connect().unwrap(); - let pager = conn.pager.read().clone(); + let pager = conn.pager.load().clone(); let _ = pager.io.block(|| pager.allocate_page1()).unwrap(); @@ -2710,14 +2824,15 @@ mod tests { // This reads the actual persisted data from the BTree #[cfg(test)] fn get_current_state(pager: Arc, circuit: &DbspCircuit) -> Result { + use crate::storage::btree::CursorTrait; + let mut delta = Delta::new(); let main_data_root = circuit.main_data_root; let num_columns = circuit.output_schema.columns.len() + 1; // Create a cursor to read the btree - let mut btree_cursor = - BTreeCursor::new_table(None, pager.clone(), main_data_root, num_columns); + let mut btree_cursor = BTreeCursor::new_table(pager.clone(), main_data_root, num_columns); // Rewind to the beginning pager.io.block(|| btree_cursor.rewind())?; @@ -3266,11 +3381,12 @@ mod tests { assert_eq!(row.values.len(), 1); // The hex function converts the number to string first, then to hex - // 96 as string is "96", which in hex is "3936" (hex of ASCII '9' and '6') + // SUM now returns Float, so 96.0 as string is "96.0", which in hex is "39362E30" + // (hex of ASCII '9', '6', '.', '0') assert_eq!( row.values[0], - Value::Text("3936".to_string().into()), - "HEX(SUM(age + 2)) should return '3936' for sum of 96" + Value::Text("39362E30".to_string().into()), + "HEX(SUM(age + 2)) should return '39362E30' for sum of 96.0" ); // Test incremental update: add a new user @@ -3289,22 +3405,22 @@ mod tests { let result = test_execute(&mut circuit, input_data, pager.clone()).unwrap(); - // Expected: new SUM(age + 2) = 96 + (40+2) = 138 - // HEX(138) = hex of "138" = "313338" + // Expected: new SUM(age + 2) = 96.0 + (40+2) = 138.0 + // HEX(138.0) = hex of "138.0" = "3133382E30" assert_eq!(result.changes.len(), 2); - // First change: remove old aggregate (96) + // First change: remove old aggregate (96.0) let (row, weight) = &result.changes[0]; assert_eq!(*weight, -1); - assert_eq!(row.values[0], Value::Text("3936".to_string().into())); + assert_eq!(row.values[0], Value::Text("39362E30".to_string().into())); - // Second change: add new aggregate (138) + // Second change: add new aggregate (138.0) let (row, weight) = &result.changes[1]; assert_eq!(*weight, 1); assert_eq!( row.values[0], - Value::Text("313338".to_string().into()), - "HEX(SUM(age + 2)) should return '313338' for sum of 138" + Value::Text("3133382E30".to_string().into()), + "HEX(SUM(age + 2)) should return '3133382E30' for sum of 138.0" ); } @@ -3352,8 +3468,8 @@ mod tests { .unwrap(); // Expected results: - // Alice: SUM(25*2 + 35*2) = 50 + 70 = 120, HEX("120") = "313230" - // Bob: SUM(30*2) = 60, HEX("60") = "3630" + // Alice: SUM(25*2 + 35*2) = 50 + 70 = 120.0, HEX("120.0") = "3132302E30" + // Bob: SUM(30*2) = 60.0, HEX("60.0") = "36302E30" assert_eq!(result.changes.len(), 2); let results: HashMap = result @@ -3374,13 +3490,13 @@ mod tests { assert_eq!( results.get("Alice").unwrap(), - "313230", - "Alice's HEX(SUM(age * 2)) should be '313230' (120)" + "3132302E30", + "Alice's HEX(SUM(age * 2)) should be '3132302E30' (120.0)" ); assert_eq!( results.get("Bob").unwrap(), - "3630", - "Bob's HEX(SUM(age * 2)) should be '3630' (60)" + "36302E30", + "Bob's HEX(SUM(age * 2)) should be '36302E30' (60.0)" ); } @@ -4697,12 +4813,12 @@ mod tests { ); // Check the results - let mut results_map: HashMap = HashMap::new(); + let mut results_map: HashMap = HashMap::new(); for (row, weight) in result.changes { assert_eq!(weight, 1); assert_eq!(row.values.len(), 2); // name and total_quantity - if let (Value::Text(name), Value::Integer(total)) = (&row.values[0], &row.values[1]) { + if let (Value::Text(name), Value::Float(total)) = (&row.values[0], &row.values[1]) { results_map.insert(name.to_string(), *total); } else { panic!("Unexpected value types in result"); @@ -4711,12 +4827,12 @@ mod tests { assert_eq!( results_map.get("Alice"), - Some(&10), + Some(&10.0), "Alice should have total quantity 10" ); assert_eq!( results_map.get("Bob"), - Some(&7), + Some(&7.0), "Bob should have total quantity 7" ); } @@ -4813,24 +4929,24 @@ mod tests { ); // Check the results - let mut results_map: HashMap = HashMap::new(); + let mut results_map: HashMap = HashMap::new(); for (row, weight) in result.changes { assert_eq!(weight, 1); assert_eq!(row.values.len(), 2); // name and total - if let (Value::Text(name), Value::Integer(total)) = (&row.values[0], &row.values[1]) { + if let (Value::Text(name), Value::Float(total)) = (&row.values[0], &row.values[1]) { results_map.insert(name.to_string(), *total); } } assert_eq!( results_map.get("Alice"), - Some(&8), + Some(&8.0), "Alice should have total 8" ); assert_eq!( results_map.get("Charlie"), - Some(&7), + Some(&7.0), "Charlie should have total 7" ); assert_eq!(results_map.get("Bob"), None, "Bob should be filtered out"); @@ -4969,7 +5085,7 @@ mod tests { // Row should have name, product_name, and sum columns assert_eq!(row.values.len(), 3); - if let (Value::Text(name), Value::Text(product), Value::Integer(total)) = + if let (Value::Text(name), Value::Text(product), Value::Float(total)) = (&row.values[0], &row.values[1], &row.values[2]) { let key = format!("{}-{}", name.as_ref(), product.as_ref()); @@ -4977,12 +5093,14 @@ mod tests { match key.as_str() { "Alice-Widget" => { - assert_eq!(*total, 9, "Alice should have ordered 9 Widgets total") + assert_eq!(*total, 9.0, "Alice should have ordered 9 Widgets total") } - "Alice-Gadget" => assert_eq!(*total, 3, "Alice should have ordered 3 Gadgets"), - "Bob-Widget" => assert_eq!(*total, 7, "Bob should have ordered 7 Widgets"), + "Alice-Gadget" => { + assert_eq!(*total, 3.0, "Alice should have ordered 3 Gadgets") + } + "Bob-Widget" => assert_eq!(*total, 7.0, "Bob should have ordered 7 Widgets"), "Bob-Doohickey" => { - assert_eq!(*total, 2, "Bob should have ordered 2 Doohickeys") + assert_eq!(*total, 2.0, "Bob should have ordered 2 Doohickeys") } _ => panic!("Unexpected result: {key}"), } diff --git a/core/incremental/cursor.rs b/core/incremental/cursor.rs index 67814bd0d..12e3c49c6 100644 --- a/core/incremental/cursor.rs +++ b/core/incremental/cursor.rs @@ -5,7 +5,7 @@ use crate::{ view::{IncrementalView, ViewTransactionState}, }, return_if_io, - storage::btree::BTreeCursor, + storage::btree::CursorTrait, types::{IOResult, SeekKey, SeekOp, SeekResult, Value}, LimboError, Pager, Result, }; @@ -35,7 +35,7 @@ enum SeekState { /// and overlays transaction changes as needed. pub struct MaterializedViewCursor { // Core components - btree_cursor: Box, + btree_cursor: Box, view: Arc>, pager: Arc, @@ -62,7 +62,7 @@ pub struct MaterializedViewCursor { impl MaterializedViewCursor { pub fn new( - btree_cursor: Box, + btree_cursor: Box, view: Arc>, pager: Arc, tx_state: Arc, @@ -117,15 +117,12 @@ impl MaterializedViewCursor { Some(rowid) => rowid, }; - let btree_record = return_if_io!(self.btree_cursor.record()); - let btree_ref_values = btree_record - .ok_or_else(|| { - crate::LimboError::InternalError( - "Invalid data in materialized view: found a rowid, but not the row!" - .to_string(), - ) - })? - .get_values(); + let btree_record = return_if_io!(self.btree_cursor.record()).ok_or_else(|| { + crate::LimboError::InternalError( + "Invalid data in materialized view: found a rowid, but not the row!".to_string(), + ) + })?; + let btree_ref_values = btree_record.get_values(); // Convert RefValues to Values (copying for now - can optimize later) let mut btree_values: Vec = @@ -299,6 +296,7 @@ impl MaterializedViewCursor { #[cfg(test)] mod tests { use super::*; + use crate::storage::btree::BTreeCursor; use crate::util::IOExt; use crate::{Connection, Database, OpenFlags}; use std::sync::Arc; @@ -362,12 +360,7 @@ mod tests { // Create a btree cursor let pager = conn.get_pager(); - let btree_cursor = Box::new(BTreeCursor::new( - None, // No MvCursor - pager.clone(), - root_page, - num_columns, - )); + let btree_cursor = Box::new(BTreeCursor::new(pager.clone(), root_page, num_columns)); // Get or create transaction state for this view let tx_state = conn.view_transaction_states.get_or_create("test_view"); diff --git a/core/incremental/expr_compiler.rs b/core/incremental/expr_compiler.rs index 44b2cef49..f823c870d 100644 --- a/core/incremental/expr_compiler.rs +++ b/core/incremental/expr_compiler.rs @@ -458,11 +458,6 @@ impl CompiledExpression { "Expression evaluation produced unexpected row".to_string(), )); } - crate::vdbe::execute::InsnFunctionStepResult::Interrupt => { - return Err(crate::LimboError::InternalError( - "Expression evaluation was interrupted".to_string(), - )); - } crate::vdbe::execute::InsnFunctionStepResult::Step => { pc = state.pc as usize; } diff --git a/core/incremental/filter_operator.rs b/core/incremental/filter_operator.rs index 84a3c53ce..5b9c7e5d9 100644 --- a/core/incremental/filter_operator.rs +++ b/core/incremental/filter_operator.rs @@ -39,6 +39,11 @@ pub enum FilterPredicate { /// Column <= Column comparisons ColumnLessThanOrEqual { left_idx: usize, right_idx: usize }, + /// Column IS NULL check + IsNull { column_idx: usize }, + /// Column IS NOT NULL check + IsNotNull { column_idx: usize }, + /// Logical AND of two predicates And(Box, Box), /// Logical OR of two predicates @@ -214,6 +219,18 @@ impl FilterOperator { } false } + FilterPredicate::IsNull { column_idx } => { + if let Some(v) = values.get(*column_idx) { + return matches!(v, Value::Null); + } + false + } + FilterPredicate::IsNotNull { column_idx } => { + if let Some(v) = values.get(*column_idx) { + return !matches!(v, Value::Null); + } + false + } } } } @@ -293,3 +310,202 @@ impl IncrementalOperator for FilterOperator { self.tracker = Some(tracker); } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::types::Text; + + #[test] + fn test_is_null_predicate() { + let predicate = FilterPredicate::IsNull { column_idx: 1 }; + let filter = FilterOperator::new(predicate); + + // Test with NULL value + let values_with_null = vec![ + Value::Integer(1), + Value::Null, + Value::Text(Text::from("test")), + ]; + assert!(filter.evaluate_predicate(&values_with_null)); + + // Test with non-NULL value + let values_without_null = vec![ + Value::Integer(1), + Value::Integer(42), + Value::Text(Text::from("test")), + ]; + assert!(!filter.evaluate_predicate(&values_without_null)); + + // Test with different non-NULL types + let values_with_text = vec![ + Value::Integer(1), + Value::Text(Text::from("not null")), + Value::Text(Text::from("test")), + ]; + assert!(!filter.evaluate_predicate(&values_with_text)); + + let values_with_blob = vec![ + Value::Integer(1), + Value::Blob(vec![1, 2, 3]), + Value::Text(Text::from("test")), + ]; + assert!(!filter.evaluate_predicate(&values_with_blob)); + } + + #[test] + fn test_is_not_null_predicate() { + let predicate = FilterPredicate::IsNotNull { column_idx: 1 }; + let filter = FilterOperator::new(predicate); + + // Test with NULL value + let values_with_null = vec![ + Value::Integer(1), + Value::Null, + Value::Text(Text::from("test")), + ]; + assert!(!filter.evaluate_predicate(&values_with_null)); + + // Test with non-NULL value (Integer) + let values_with_integer = vec![ + Value::Integer(1), + Value::Integer(42), + Value::Text(Text::from("test")), + ]; + assert!(filter.evaluate_predicate(&values_with_integer)); + + // Test with non-NULL value (Text) + let values_with_text = vec![ + Value::Integer(1), + Value::Text(Text::from("not null")), + Value::Text(Text::from("test")), + ]; + assert!(filter.evaluate_predicate(&values_with_text)); + + // Test with non-NULL value (Blob) + let values_with_blob = vec![ + Value::Integer(1), + Value::Blob(vec![1, 2, 3]), + Value::Text(Text::from("test")), + ]; + assert!(filter.evaluate_predicate(&values_with_blob)); + } + + #[test] + fn test_is_null_with_and() { + // Test: column_0 = 1 AND column_1 IS NULL + let predicate = FilterPredicate::And( + Box::new(FilterPredicate::Equals { + column_idx: 0, + value: Value::Integer(1), + }), + Box::new(FilterPredicate::IsNull { column_idx: 1 }), + ); + let filter = FilterOperator::new(predicate); + + // Should match: column_0 = 1 AND column_1 IS NULL + let values_match = vec![ + Value::Integer(1), + Value::Null, + Value::Text(Text::from("test")), + ]; + assert!(filter.evaluate_predicate(&values_match)); + + // Should not match: column_0 = 2 AND column_1 IS NULL + let values_wrong_first = vec![ + Value::Integer(2), + Value::Null, + Value::Text(Text::from("test")), + ]; + assert!(!filter.evaluate_predicate(&values_wrong_first)); + + // Should not match: column_0 = 1 AND column_1 IS NOT NULL + let values_not_null = vec![ + Value::Integer(1), + Value::Integer(42), + Value::Text(Text::from("test")), + ]; + assert!(!filter.evaluate_predicate(&values_not_null)); + } + + #[test] + fn test_is_not_null_with_or() { + // Test: column_0 = 1 OR column_1 IS NOT NULL + let predicate = FilterPredicate::Or( + Box::new(FilterPredicate::Equals { + column_idx: 0, + value: Value::Integer(1), + }), + Box::new(FilterPredicate::IsNotNull { column_idx: 1 }), + ); + let filter = FilterOperator::new(predicate); + + // Should match: column_0 = 1 (regardless of column_1) + let values_first_matches = vec![ + Value::Integer(1), + Value::Null, + Value::Text(Text::from("test")), + ]; + assert!(filter.evaluate_predicate(&values_first_matches)); + + // Should match: column_1 IS NOT NULL (regardless of column_0) + let values_second_matches = vec![ + Value::Integer(2), + Value::Integer(42), + Value::Text(Text::from("test")), + ]; + assert!(filter.evaluate_predicate(&values_second_matches)); + + // Should not match: column_0 != 1 AND column_1 IS NULL + let values_no_match = vec![ + Value::Integer(2), + Value::Null, + Value::Text(Text::from("test")), + ]; + assert!(!filter.evaluate_predicate(&values_no_match)); + } + + #[test] + fn test_complex_null_predicates() { + // Test: (column_0 IS NULL OR column_1 IS NOT NULL) AND column_2 = 'test' + let predicate = FilterPredicate::And( + Box::new(FilterPredicate::Or( + Box::new(FilterPredicate::IsNull { column_idx: 0 }), + Box::new(FilterPredicate::IsNotNull { column_idx: 1 }), + )), + Box::new(FilterPredicate::Equals { + column_idx: 2, + value: Value::Text(Text::from("test")), + }), + ); + let filter = FilterOperator::new(predicate); + + // Should match: column_0 IS NULL, column_2 = 'test' + let values1 = vec![Value::Null, Value::Null, Value::Text(Text::from("test"))]; + assert!(filter.evaluate_predicate(&values1)); + + // Should match: column_1 IS NOT NULL, column_2 = 'test' + let values2 = vec![ + Value::Integer(1), + Value::Integer(42), + Value::Text(Text::from("test")), + ]; + assert!(filter.evaluate_predicate(&values2)); + + // Should not match: column_2 != 'test' + let values3 = vec![ + Value::Null, + Value::Integer(42), + Value::Text(Text::from("other")), + ]; + assert!(!filter.evaluate_predicate(&values3)); + + // Should not match: column_0 IS NOT NULL AND column_1 IS NULL AND column_2 = 'test' + let values4 = vec![ + Value::Integer(1), + Value::Null, + Value::Text(Text::from("test")), + ]; + assert!(!filter.evaluate_predicate(&values4)); + } +} diff --git a/core/incremental/join_operator.rs b/core/incremental/join_operator.rs index 722274559..982545ca9 100644 --- a/core/incremental/join_operator.rs +++ b/core/incremental/join_operator.rs @@ -6,6 +6,7 @@ use crate::incremental::operator::{ generate_storage_id, ComputationTracker, DbspStateCursors, EvalState, IncrementalOperator, }; use crate::incremental::persistence::WriteRow; +use crate::storage::btree::CursorTrait; use crate::types::{IOResult, ImmutableRecord, SeekKey, SeekOp, SeekResult}; use crate::{return_and_restore_if_io, return_if_io, Result, Value}; use std::sync::{Arc, Mutex}; diff --git a/core/incremental/operator.rs b/core/incremental/operator.rs index f4598e1fb..764a59167 100644 --- a/core/incremental/operator.rs +++ b/core/incremental/operator.rs @@ -218,7 +218,9 @@ pub enum QueryOperator { /// Operator DAG (Directed Acyclic Graph) /// Base trait for incremental operators -pub trait IncrementalOperator: Debug { +// SAFETY: This needs to be audited for thread safety. +// See: https://github.com/tursodatabase/turso/issues/1552 +pub trait IncrementalOperator: Debug + Send { /// Evaluate the operator with a state, without modifying internal state /// This is used during query execution to compute results /// May need to read from storage to get current state (e.g., for aggregates) @@ -254,6 +256,7 @@ mod tests { use super::*; use crate::incremental::aggregate_operator::{AggregateOperator, AGG_TYPE_REGULAR}; use crate::incremental::dbsp::HashableRow; + use crate::storage::btree::CursorTrait; use crate::storage::pager::CreateBTreeFlags; use crate::types::Text; use crate::util::IOExt; @@ -267,7 +270,7 @@ mod tests { let db = Database::open_file(io.clone(), ":memory:", false, false).unwrap(); let conn = db.connect().unwrap(); - let pager = conn.pager.read().clone(); + let pager = conn.pager.load().clone(); // Allocate page 1 first (database header) let _ = pager.io.block(|| pager.allocate_page1()); @@ -390,12 +393,11 @@ mod tests { // Create a persistent pager for the test let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_root_page_id, 5); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); // Create index cursor with proper index definition for DBSP state table let index_def = create_dbsp_state_index(index_root_page_id); // Index has 4 columns: operator_id, zset_id, element_id, rowid - let index_cursor = - BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); // Create an aggregate operator for SUM(age) with no GROUP BY @@ -510,12 +512,11 @@ mod tests { // Create an aggregate operator for SUM(score) GROUP BY team // Create a persistent pager for the test let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_root_page_id, 5); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); // Create index cursor with proper index definition for DBSP state table let index_def = create_dbsp_state_index(index_root_page_id); // Index has 4 columns: operator_id, zset_id, element_id, rowid - let index_cursor = - BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( @@ -661,12 +662,11 @@ mod tests { // Create a persistent pager for the test let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_root_page_id, 5); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); // Create index cursor with proper index definition for DBSP state table let index_def = create_dbsp_state_index(index_root_page_id); // Index has 4 columns: operator_id, zset_id, element_id, rowid - let index_cursor = - BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); // Create COUNT(*) GROUP BY category @@ -742,12 +742,11 @@ mod tests { // Create SUM(amount) GROUP BY product // Create a persistent pager for the test let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_root_page_id, 5); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); // Create index cursor with proper index definition for DBSP state table let index_def = create_dbsp_state_index(index_root_page_id); // Index has 4 columns: operator_id, zset_id, element_id, rowid - let index_cursor = - BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( @@ -839,12 +838,11 @@ mod tests { // Test the example from DBSP_ROADMAP: COUNT(*) and SUM(amount) GROUP BY user_id // Create a persistent pager for the test let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_root_page_id, 5); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); // Create index cursor with proper index definition for DBSP state table let index_def = create_dbsp_state_index(index_root_page_id); // Index has 4 columns: operator_id, zset_id, element_id, rowid - let index_cursor = - BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( @@ -931,12 +929,11 @@ mod tests { // Test AVG aggregation // Create a persistent pager for the test let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_root_page_id, 5); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); // Create index cursor with proper index definition for DBSP state table let index_def = create_dbsp_state_index(index_root_page_id); // Index has 4 columns: operator_id, zset_id, element_id, rowid - let index_cursor = - BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( @@ -1031,12 +1028,11 @@ mod tests { // Test that deletes (negative weights) properly update aggregates // Create a persistent pager for the test let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_root_page_id, 5); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); // Create index cursor with proper index definition for DBSP state table let index_def = create_dbsp_state_index(index_root_page_id); // Index has 4 columns: operator_id, zset_id, element_id, rowid - let index_cursor = - BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( @@ -1118,12 +1114,11 @@ mod tests { // Create a persistent pager for the test let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_root_page_id, 5); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); // Create index cursor with proper index definition for DBSP state table let index_def = create_dbsp_state_index(index_root_page_id); // Index has 4 columns: operator_id, zset_id, element_id, rowid - let index_cursor = - BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( @@ -1208,12 +1203,11 @@ mod tests { // Create a persistent pager for the test let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_root_page_id, 5); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); // Create index cursor with proper index definition for DBSP state table let index_def = create_dbsp_state_index(index_root_page_id); // Index has 4 columns: operator_id, zset_id, element_id, rowid - let index_cursor = - BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( @@ -1292,12 +1286,11 @@ mod tests { // Create a persistent pager for the test let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_root_page_id, 5); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); // Create index cursor with proper index definition for DBSP state table let index_def = create_dbsp_state_index(index_root_page_id); // Index has 4 columns: operator_id, zset_id, element_id, rowid - let index_cursor = - BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( @@ -1361,12 +1354,11 @@ mod tests { // Create a persistent pager for the test let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_root_page_id, 5); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); // Create index cursor with proper index definition for DBSP state table let index_def = create_dbsp_state_index(index_root_page_id); // Index has 4 columns: operator_id, zset_id, element_id, rowid - let index_cursor = - BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( @@ -1447,12 +1439,11 @@ mod tests { // Create a persistent pager for the test let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_root_page_id, 5); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); // Create index cursor with proper index definition for DBSP state table let index_def = create_dbsp_state_index(index_root_page_id); // Index has 4 columns: operator_id, zset_id, element_id, rowid - let index_cursor = - BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut filter = FilterOperator::new(FilterPredicate::GreaterThan { @@ -1506,12 +1497,11 @@ mod tests { fn test_filter_eval_with_uncommitted() { // Create a persistent pager for the test let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_root_page_id, 5); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); // Create index cursor with proper index definition for DBSP state table let index_def = create_dbsp_state_index(index_root_page_id); // Index has 4 columns: operator_id, zset_id, element_id, rowid - let index_cursor = - BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut filter = FilterOperator::new(FilterPredicate::GreaterThan { @@ -1597,12 +1587,11 @@ mod tests { // This is the critical test - aggregations must not modify internal state during eval // Create a persistent pager for the test let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_root_page_id, 5); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); // Create index cursor with proper index definition for DBSP state table let index_def = create_dbsp_state_index(index_root_page_id); // Index has 4 columns: operator_id, zset_id, element_id, rowid - let index_cursor = - BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( @@ -1767,12 +1756,11 @@ mod tests { // doesn't pollute the internal state // Create a persistent pager for the test let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_root_page_id, 5); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); // Create index cursor with proper index definition for DBSP state table let index_def = create_dbsp_state_index(index_root_page_id); // Index has 4 columns: operator_id, zset_id, element_id, rowid - let index_cursor = - BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( @@ -1849,12 +1837,11 @@ mod tests { // Test eval with both committed delta and uncommitted changes // Create a persistent pager for the test let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_root_page_id, 5); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); // Create index cursor with proper index definition for DBSP state table let index_def = create_dbsp_state_index(index_root_page_id); // Index has 4 columns: operator_id, zset_id, element_id, rowid - let index_cursor = - BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( @@ -1965,10 +1952,9 @@ mod tests { fn test_min_max_basic() { // Test basic MIN/MAX functionality let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_root_page_id, 5); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); let index_def = create_dbsp_state_index(index_root_page_id); - let index_cursor = - BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( @@ -2033,10 +2019,9 @@ mod tests { fn test_min_max_deletion_updates_min() { // Test that deleting the MIN value updates to the next lowest let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_root_page_id, 5); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); let index_def = create_dbsp_state_index(index_root_page_id); - let index_cursor = - BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( @@ -2123,10 +2108,9 @@ mod tests { fn test_min_max_deletion_updates_max() { // Test that deleting the MAX value updates to the next highest let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_root_page_id, 5); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); let index_def = create_dbsp_state_index(index_root_page_id); - let index_cursor = - BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( @@ -2213,10 +2197,9 @@ mod tests { fn test_min_max_insertion_updates_min() { // Test that inserting a new MIN value updates the aggregate let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_root_page_id, 5); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); let index_def = create_dbsp_state_index(index_root_page_id); - let index_cursor = - BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( @@ -2295,10 +2278,9 @@ mod tests { fn test_min_max_insertion_updates_max() { // Test that inserting a new MAX value updates the aggregate let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_root_page_id, 5); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); let index_def = create_dbsp_state_index(index_root_page_id); - let index_cursor = - BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( @@ -2377,10 +2359,9 @@ mod tests { fn test_min_max_update_changes_min() { // Test that updating a row to become the new MIN updates the aggregate let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_root_page_id, 5); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); let index_def = create_dbsp_state_index(index_root_page_id); - let index_cursor = - BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( @@ -2467,10 +2448,9 @@ mod tests { fn test_min_max_with_group_by() { // Test MIN/MAX with GROUP BY let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_root_page_id, 5); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); let index_def = create_dbsp_state_index(index_root_page_id); - let index_cursor = - BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( @@ -2569,10 +2549,9 @@ mod tests { fn test_min_max_with_nulls() { // Test that NULL values are ignored in MIN/MAX let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_root_page_id, 5); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); let index_def = create_dbsp_state_index(index_root_page_id); - let index_cursor = - BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( @@ -2645,10 +2624,9 @@ mod tests { fn test_min_max_integer_values() { // Test MIN/MAX with integer values instead of floats let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_root_page_id, 5); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); let index_def = create_dbsp_state_index(index_root_page_id); - let index_cursor = - BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( @@ -2713,10 +2691,9 @@ mod tests { fn test_min_max_text_values() { // Test MIN/MAX with text values (alphabetical ordering) let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_root_page_id, 5); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); let index_def = create_dbsp_state_index(index_root_page_id); - let index_cursor = - BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( @@ -2752,10 +2729,9 @@ mod tests { #[test] fn test_min_max_with_other_aggregates() { let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_root_page_id, 5); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); let index_def = create_dbsp_state_index(index_root_page_id); - let index_cursor = - BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( @@ -2844,10 +2820,9 @@ mod tests { #[test] fn test_min_max_multiple_columns() { let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_root_page_id, 5); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); let index_def = create_dbsp_state_index(index_root_page_id); - let index_cursor = - BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( @@ -2923,10 +2898,9 @@ mod tests { fn test_join_operator_inner() { // Test INNER JOIN with incremental updates let (pager, table_page_id, index_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_page_id, 10); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_page_id, 10); let index_def = create_dbsp_state_index(index_page_id); - let index_cursor = - BTreeCursor::new_index(None, pager.clone(), index_page_id, &index_def, 10); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_page_id, &index_def, 10); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut join = JoinOperator::new( 1, // operator_id @@ -3020,10 +2994,9 @@ mod tests { fn test_join_operator_with_deletions() { // Test INNER JOIN with deletions (negative weights) let (pager, table_page_id, index_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_page_id, 10); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_page_id, 10); let index_def = create_dbsp_state_index(index_page_id); - let index_cursor = - BTreeCursor::new_index(None, pager.clone(), index_page_id, &index_def, 10); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_page_id, &index_def, 10); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut join = JoinOperator::new( @@ -3111,10 +3084,9 @@ mod tests { fn test_join_operator_one_to_many() { // Test one-to-many relationship: one customer with multiple orders let (pager, table_page_id, index_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_page_id, 10); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_page_id, 10); let index_def = create_dbsp_state_index(index_page_id); - let index_cursor = - BTreeCursor::new_index(None, pager.clone(), index_page_id, &index_def, 10); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_page_id, &index_def, 10); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut join = JoinOperator::new( @@ -3248,10 +3220,9 @@ mod tests { fn test_join_operator_many_to_many() { // Test many-to-many: multiple rows with same key on both sides let (pager, table_page_id, index_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_page_id, 10); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_page_id, 10); let index_def = create_dbsp_state_index(index_page_id); - let index_cursor = - BTreeCursor::new_index(None, pager.clone(), index_page_id, &index_def, 10); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_page_id, &index_def, 10); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut join = JoinOperator::new( @@ -3365,10 +3336,9 @@ mod tests { fn test_join_operator_update_in_one_to_many() { // Test updates in one-to-many scenarios let (pager, table_page_id, index_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_page_id, 10); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_page_id, 10); let index_def = create_dbsp_state_index(index_page_id); - let index_cursor = - BTreeCursor::new_index(None, pager.clone(), index_page_id, &index_def, 10); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_page_id, &index_def, 10); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut join = JoinOperator::new( @@ -3488,10 +3458,9 @@ mod tests { fn test_join_operator_weight_accumulation_complex() { // Test complex weight accumulation with multiple identical rows let (pager, table_page_id, index_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_page_id, 10); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_page_id, 10); let index_def = create_dbsp_state_index(index_page_id); - let index_cursor = - BTreeCursor::new_index(None, pager.clone(), index_page_id, &index_def, 10); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_page_id, &index_def, 10); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut join = JoinOperator::new( @@ -3624,9 +3593,9 @@ mod tests { let mut state = EvalState::Init { deltas: delta_pair }; let (pager, table_root, index_root) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_root, 5); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root, 5); let index_def = create_dbsp_state_index(index_root); - let index_cursor = BTreeCursor::new_index(None, pager.clone(), index_root, &index_def, 4); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let result = pager @@ -3684,10 +3653,10 @@ mod tests { #[test] fn test_merge_operator_basic() { let (_pager, table_root_page_id, index_root_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, _pager.clone(), table_root_page_id, 5); + let table_cursor = BTreeCursor::new_table(_pager.clone(), table_root_page_id, 5); let index_def = create_dbsp_state_index(index_root_page_id); let index_cursor = - BTreeCursor::new_index(None, _pager.clone(), index_root_page_id, &index_def, 4); + BTreeCursor::new_index(_pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut merge_op = MergeOperator::new( @@ -3745,10 +3714,10 @@ mod tests { #[test] fn test_merge_operator_stateful_distinct() { let (_pager, table_root_page_id, index_root_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, _pager.clone(), table_root_page_id, 5); + let table_cursor = BTreeCursor::new_table(_pager.clone(), table_root_page_id, 5); let index_def = create_dbsp_state_index(index_root_page_id); let index_cursor = - BTreeCursor::new_index(None, _pager.clone(), index_root_page_id, &index_def, 4); + BTreeCursor::new_index(_pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); // Test that UNION (distinct) properly deduplicates across multiple operations @@ -3819,10 +3788,10 @@ mod tests { #[test] fn test_merge_operator_single_sided_inputs_union_all() { let (_pager, table_root_page_id, index_root_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, _pager.clone(), table_root_page_id, 5); + let table_cursor = BTreeCursor::new_table(_pager.clone(), table_root_page_id, 5); let index_def = create_dbsp_state_index(index_root_page_id); let index_cursor = - BTreeCursor::new_index(None, _pager.clone(), index_root_page_id, &index_def, 4); + BTreeCursor::new_index(_pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); // Test UNION ALL with inputs coming from only one side at a time @@ -3939,10 +3908,10 @@ mod tests { #[test] fn test_merge_operator_both_sides_empty() { let (_pager, table_root_page_id, index_root_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, _pager.clone(), table_root_page_id, 5); + let table_cursor = BTreeCursor::new_table(_pager.clone(), table_root_page_id, 5); let index_def = create_dbsp_state_index(index_root_page_id); let index_cursor = - BTreeCursor::new_index(None, _pager.clone(), index_root_page_id, &index_def, 4); + BTreeCursor::new_index(_pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); // Test that both sides being empty works correctly @@ -4019,10 +3988,9 @@ mod tests { // Test that aggregate state serialization correctly preserves column indices // when multiple aggregates operate on different columns let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); - let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_root_page_id, 5); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); let index_def = create_dbsp_state_index(index_root_page_id); - let index_cursor = - BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); // Create first operator with SUM(col1), MIN(col3) GROUP BY col0 @@ -4124,4 +4092,518 @@ mod tests { "MIN(col1) should be 20 (new data only)" ); } + + #[test] + fn test_distinct_removes_duplicates() { + let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); + let index_def = create_dbsp_state_index(index_root_page_id); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); + let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); + + // Create a DISTINCT operator that groups by all columns + let mut operator = AggregateOperator::new( + 0, // operator_id + vec![0], // group by column 0 (value) + vec![], // Empty aggregates for plain DISTINCT + vec!["value".to_string()], + ); + + // Create input with duplicates + let mut input = Delta::new(); + input.insert(1, vec![Value::Integer(100)]); // First 100 + input.insert(2, vec![Value::Integer(200)]); // First 200 + input.insert(3, vec![Value::Integer(100)]); // Duplicate 100 + input.insert(4, vec![Value::Integer(300)]); // First 300 + input.insert(5, vec![Value::Integer(200)]); // Duplicate 200 + input.insert(6, vec![Value::Integer(100)]); // Another duplicate 100 + + // Execute commit (for materialized views) instead of eval + let result = pager + .io + .block(|| operator.commit((&input).into(), &mut cursors)) + .unwrap(); + + // Should have exactly 3 distinct values (100, 200, 300) + let distinct_values: std::collections::HashSet = result + .changes + .iter() + .map(|(row, _weight)| match &row.values[0] { + Value::Integer(i) => *i, + _ => panic!("Expected integer value"), + }) + .collect(); + + assert_eq!( + distinct_values.len(), + 3, + "Should have exactly 3 distinct values" + ); + assert!(distinct_values.contains(&100)); + assert!(distinct_values.contains(&200)); + assert!(distinct_values.contains(&300)); + + // All weights should be 1 (distinct normalizes weights) + for (_row, weight) in &result.changes { + assert_eq!(*weight, 1, "DISTINCT should output weight 1 for all groups"); + } + } + + #[test] + fn test_distinct_incremental_updates() { + let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); + let index_def = create_dbsp_state_index(index_root_page_id); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); + let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); + + let mut operator = AggregateOperator::new( + 0, + vec![0, 1], // group by both columns + vec![], // Empty aggregates for plain DISTINCT + vec!["category".to_string(), "value".to_string()], + ); + + // First batch: insert some values + let mut delta1 = Delta::new(); + delta1.insert(1, vec![Value::Text("A".into()), Value::Integer(100)]); + delta1.insert(2, vec![Value::Text("B".into()), Value::Integer(200)]); + delta1.insert(3, vec![Value::Text("A".into()), Value::Integer(100)]); // Duplicate + + // Commit first batch + let result1 = pager + .io + .block(|| operator.commit((&delta1).into(), &mut cursors)) + .unwrap(); + + // Should have 2 distinct groups: (A,100) and (B,200) + assert_eq!( + result1.changes.len(), + 2, + "First commit should output 2 distinct groups" + ); + + // Verify each group appears with weight +1 + for (_row, weight) in &result1.changes { + assert_eq!(*weight, 1, "New groups should have weight +1"); + } + + // Second batch: delete one instance of (A,100) and add new group + let mut delta2 = Delta::new(); + delta2.delete(1, vec![Value::Text("A".into()), Value::Integer(100)]); + delta2.insert(4, vec![Value::Text("C".into()), Value::Integer(300)]); + + let result2 = pager + .io + .block(|| operator.commit((&delta2).into(), &mut cursors)) + .unwrap(); + + // Should only output the new group (C,300) with weight +1 + // (A,100) still exists (weight went from 2 to 1), so no output for it + assert_eq!( + result2.changes.len(), + 1, + "Second commit should only output new group" + ); + + let (row, weight) = &result2.changes[0]; + assert_eq!(*weight, 1); + assert_eq!(row.values[0], Value::Text("C".into())); + assert_eq!(row.values[1], Value::Integer(300)); + + // Third batch: delete last instance of (A,100) + let mut delta3 = Delta::new(); + delta3.delete(3, vec![Value::Text("A".into()), Value::Integer(100)]); + + let result3 = pager + .io + .block(|| operator.commit((&delta3).into(), &mut cursors)) + .unwrap(); + + // Should output (A,100) with weight -1 (group disappeared) + assert_eq!( + result3.changes.len(), + 1, + "Third commit should output disappeared group" + ); + + let (row, weight) = &result3.changes[0]; + assert_eq!(*weight, -1, "Disappeared group should have weight -1"); + assert_eq!(row.values[0], Value::Text("A".into())); + assert_eq!(row.values[1], Value::Integer(100)) + } + + #[test] + fn test_distinct_state_transitions() { + let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); + let index_def = create_dbsp_state_index(index_root_page_id); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); + let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); + + // Test that DISTINCT correctly tracks state transitions (0 ↔ positive) + let mut operator = AggregateOperator::new( + 0, + vec![0], + vec![], // Empty aggregates for plain DISTINCT + vec!["value".to_string()], + ); + + // Insert value with weight 3 + let mut delta1 = Delta::new(); + for i in 1..=3 { + delta1.insert(i, vec![Value::Integer(100)]); + } + + let result1 = pager + .io + .block(|| operator.commit((&delta1).into(), &mut cursors)) + .unwrap(); + + assert_eq!(result1.changes.len(), 1); + assert_eq!(result1.changes[0].1, 1, "First appearance should output +1"); + + // Remove 2 instances (weight goes from 3 to 1, still positive) + let mut delta2 = Delta::new(); + for i in 1..=2 { + delta2.delete(i, vec![Value::Integer(100)]); + } + + let result2 = pager + .io + .block(|| operator.commit((&delta2).into(), &mut cursors)) + .unwrap(); + + assert_eq!(result2.changes.len(), 0, "No transition, no output"); + + // Remove last instance (weight goes from 1 to 0) + let mut delta3 = Delta::new(); + delta3.delete(3, vec![Value::Integer(100)]); + + let result3 = pager + .io + .block(|| operator.commit((&delta3).into(), &mut cursors)) + .unwrap(); + + assert_eq!(result3.changes.len(), 1); + assert_eq!(result3.changes[0].1, -1, "Disappearance should output -1"); + + // Re-add the value (weight goes from 0 to 1) + let mut delta4 = Delta::new(); + delta4.insert(4, vec![Value::Integer(100)]); + + let result4 = pager + .io + .block(|| operator.commit((&delta4).into(), &mut cursors)) + .unwrap(); + + assert_eq!(result4.changes.len(), 1); + assert_eq!(result4.changes[0].1, 1, "Reappearance should output +1") + } + + #[test] + fn test_distinct_persistence() { + let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); + let index_def = create_dbsp_state_index(index_root_page_id); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); + let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); + + // First operator instance + let mut operator1 = AggregateOperator::new( + 0, + vec![0], + vec![], // Empty aggregates for plain DISTINCT + vec!["value".to_string()], + ); + + // Insert some values + let mut delta1 = Delta::new(); + delta1.insert(1, vec![Value::Integer(100)]); + delta1.insert(2, vec![Value::Integer(100)]); // Duplicate + delta1.insert(3, vec![Value::Integer(200)]); + + let result1 = pager + .io + .block(|| operator1.commit((&delta1).into(), &mut cursors)) + .unwrap(); + + // Should have 2 distinct values + assert_eq!(result1.changes.len(), 2, "Should output 2 distinct values"); + + // Create new operator instance with same ID (simulates restart) + // Create new cursors for the second operator + let table_cursor2 = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); + let index_cursor2 = + BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); + let mut cursors2 = DbspStateCursors::new(table_cursor2, index_cursor2); + + let mut operator2 = AggregateOperator::new( + 0, // Same operator_id + vec![0], + vec![], // Empty aggregates for plain DISTINCT + vec!["value".to_string()], + ); + + // Add new value and delete existing (100 has weight 2, so it stays) + let mut delta2 = Delta::new(); + delta2.insert(4, vec![Value::Integer(300)]); + delta2.delete(1, vec![Value::Integer(100)]); // Remove one of the 100s + + let result2 = pager + .io + .block(|| operator2.commit((&delta2).into(), &mut cursors2)) + .unwrap(); + + // Should only output the new value (300) + // 100 still exists (went from weight 2 to 1) + assert_eq!(result2.changes.len(), 1, "Should only output new value"); + assert_eq!(result2.changes[0].1, 1, "Should be insertion"); + assert_eq!(result2.changes[0].0.values[0], Value::Integer(300)); + + // Now delete the last instance of 100 + let mut delta3 = Delta::new(); + delta3.delete(2, vec![Value::Integer(100)]); + + let result3 = pager + .io + .block(|| operator2.commit((&delta3).into(), &mut cursors2)) + .unwrap(); + + // Should output deletion of 100 + assert_eq!(result3.changes.len(), 1, "Should output deletion"); + assert_eq!(result3.changes[0].1, -1, "Should be deletion"); + assert_eq!(result3.changes[0].0.values[0], Value::Integer(100)); + } + + #[test] + fn test_distinct_batch_with_multiple_groups() { + let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); + let index_def = create_dbsp_state_index(index_root_page_id); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); + let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); + + let mut operator = AggregateOperator::new( + 0, + vec![0, 1], // group by category and value + vec![], // Empty aggregates for plain DISTINCT + vec!["category".to_string(), "value".to_string()], + ); + + // Create a complex batch with multiple groups and duplicates within each group + let mut delta = Delta::new(); + + // Group (A, 100): 3 instances + delta.insert(1, vec![Value::Text("A".into()), Value::Integer(100)]); + delta.insert(2, vec![Value::Text("A".into()), Value::Integer(100)]); + delta.insert(3, vec![Value::Text("A".into()), Value::Integer(100)]); + + // Group (B, 200): 2 instances + delta.insert(4, vec![Value::Text("B".into()), Value::Integer(200)]); + delta.insert(5, vec![Value::Text("B".into()), Value::Integer(200)]); + + // Group (A, 200): 1 instance + delta.insert(6, vec![Value::Text("A".into()), Value::Integer(200)]); + + // Group (C, 100): 2 instances + delta.insert(7, vec![Value::Text("C".into()), Value::Integer(100)]); + delta.insert(8, vec![Value::Text("C".into()), Value::Integer(100)]); + + // More instances of Group (A, 100) + delta.insert(9, vec![Value::Text("A".into()), Value::Integer(100)]); + delta.insert(10, vec![Value::Text("A".into()), Value::Integer(100)]); + + // Group (B, 100): 1 instance + delta.insert(11, vec![Value::Text("B".into()), Value::Integer(100)]); + + let result = pager + .io + .block(|| operator.commit((&delta).into(), &mut cursors)) + .unwrap(); + + // Should have exactly 5 distinct groups: + // (A, 100), (A, 200), (B, 100), (B, 200), (C, 100) + assert_eq!( + result.changes.len(), + 5, + "Should have exactly 5 distinct groups" + ); + + // All should have weight +1 (new groups appearing) + for (_row, weight) in &result.changes { + assert_eq!(*weight, 1, "All groups should have weight +1"); + } + + // Verify the distinct groups + let groups: std::collections::HashSet<(String, i64)> = result + .changes + .iter() + .map(|(row, _)| { + let category = match &row.values[0] { + Value::Text(s) => String::from_utf8(s.value.clone()).unwrap(), + _ => panic!("Expected text for category"), + }; + let value = match &row.values[1] { + Value::Integer(i) => *i, + _ => panic!("Expected integer for value"), + }; + (category, value) + }) + .collect(); + + assert!(groups.contains(&("A".to_string(), 100))); + assert!(groups.contains(&("A".to_string(), 200))); + assert!(groups.contains(&("B".to_string(), 100))); + assert!(groups.contains(&("B".to_string(), 200))); + assert!(groups.contains(&("C".to_string(), 100))); + } + + #[test] + fn test_multiple_distinct_aggregates_same_column() { + // Test that multiple DISTINCT aggregates on the same column don't interfere + let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); + + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); + let index_def = create_dbsp_state_index(index_root_page_id); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); + + let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); + + // Create operator with COUNT(DISTINCT value), SUM(DISTINCT value), AVG(DISTINCT value) + // all on the same column + let mut operator = AggregateOperator::new( + 0, + vec![], // No group by - single group + vec![ + AggregateFunction::CountDistinct(0), // COUNT(DISTINCT value) + AggregateFunction::SumDistinct(0), // SUM(DISTINCT value) + AggregateFunction::AvgDistinct(0), // AVG(DISTINCT value) + ], + vec!["value".to_string()], + ); + + // Insert distinct values: 10, 20, 30 (each appearing multiple times) + let mut input = Delta::new(); + input.insert(1, vec![Value::Integer(10)]); + input.insert(2, vec![Value::Integer(10)]); // duplicate + input.insert(3, vec![Value::Integer(20)]); + input.insert(4, vec![Value::Integer(20)]); // duplicate + input.insert(5, vec![Value::Integer(30)]); + input.insert(6, vec![Value::Integer(10)]); // duplicate + + let output = pager + .io + .block(|| operator.commit((&input).into(), &mut cursors)) + .unwrap(); + + // Should have exactly one output row (no group by) + assert_eq!(output.changes.len(), 1); + let (row, weight) = &output.changes[0]; + assert_eq!(*weight, 1); + + // Extract the aggregate values + let values = &row.values; + assert_eq!(values.len(), 3); // 3 aggregate values + + // COUNT(DISTINCT value) should be 3 (distinct values: 10, 20, 30) + assert_eq!(values[0], Value::Integer(3)); + + // SUM(DISTINCT value) should be 60 (10 + 20 + 30) + assert_eq!(values[1], Value::Integer(60)); + + // AVG(DISTINCT value) should be 20.0 (60 / 3) + assert_eq!(values[2], Value::Float(20.0)); + } + + #[test] + fn test_count_distinct_with_deletions() { + let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); + + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); + let index_def = create_dbsp_state_index(index_root_page_id); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); + let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); + + let mut operator = AggregateOperator::new( + 1, + vec![], // No GROUP BY + vec![AggregateFunction::CountDistinct(1)], + vec!["id".to_string(), "value".to_string()], + ); + + // Insert 3 distinct values + let mut delta1 = Delta::new(); + delta1.insert(1, vec![Value::Integer(1), Value::Integer(100)]); + delta1.insert(2, vec![Value::Integer(2), Value::Integer(200)]); + delta1.insert(3, vec![Value::Integer(3), Value::Integer(300)]); + + let result1 = pager + .io + .block(|| operator.commit((&delta1).into(), &mut cursors)) + .unwrap(); + + assert_eq!(result1.changes.len(), 1); + assert_eq!(result1.changes[0].1, 1); + assert_eq!(result1.changes[0].0.values[0], Value::Integer(3)); + + // Delete one value + let mut delta2 = Delta::new(); + delta2.delete(2, vec![Value::Integer(2), Value::Integer(200)]); + + let result2 = pager + .io + .block(|| operator.commit((&delta2).into(), &mut cursors)) + .unwrap(); + + assert_eq!(result2.changes.len(), 2); + let new_row = result2.changes.iter().find(|(_, w)| *w == 1).unwrap(); + assert_eq!(new_row.0.values[0], Value::Integer(2)); + } + + #[test] + fn test_sum_distinct_with_deletions() { + let (pager, table_root_page_id, index_root_page_id) = create_test_pager(); + + let table_cursor = BTreeCursor::new_table(pager.clone(), table_root_page_id, 5); + let index_def = create_dbsp_state_index(index_root_page_id); + let index_cursor = BTreeCursor::new_index(pager.clone(), index_root_page_id, &index_def, 4); + let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); + + let mut operator = AggregateOperator::new( + 1, + vec![], + vec![AggregateFunction::SumDistinct(1)], + vec!["id".to_string(), "value".to_string()], + ); + + // Insert values including a duplicate + let mut delta1 = Delta::new(); + delta1.insert(1, vec![Value::Integer(1), Value::Integer(100)]); + delta1.insert(2, vec![Value::Integer(2), Value::Integer(200)]); + delta1.insert(3, vec![Value::Integer(3), Value::Integer(100)]); // Duplicate + delta1.insert(4, vec![Value::Integer(4), Value::Integer(300)]); + + let result1 = pager + .io + .block(|| operator.commit((&delta1).into(), &mut cursors)) + .unwrap(); + + assert_eq!(result1.changes.len(), 1); + assert_eq!(result1.changes[0].1, 1); + assert_eq!(result1.changes[0].0.values[0], Value::Float(600.0)); // 100 + 200 + 300 + + // Delete value 200 + let mut delta2 = Delta::new(); + delta2.delete(2, vec![Value::Integer(2), Value::Integer(200)]); + + let result2 = pager + .io + .block(|| operator.commit((&delta2).into(), &mut cursors)) + .unwrap(); + + assert_eq!(result2.changes.len(), 2); + let new_row = result2.changes.iter().find(|(_, w)| *w == 1).unwrap(); + assert_eq!(new_row.0.values[0], Value::Float(400.0)); // 100 + 300 + } } diff --git a/core/incremental/persistence.rs b/core/incremental/persistence.rs index 81d0837c2..26b9b8b3f 100644 --- a/core/incremental/persistence.rs +++ b/core/incremental/persistence.rs @@ -1,5 +1,5 @@ use crate::incremental::operator::{AggregateState, DbspStateCursors}; -use crate::storage::btree::{BTreeCursor, BTreeKey}; +use crate::storage::btree::{BTreeCursor, BTreeKey, CursorTrait}; use crate::types::{IOResult, ImmutableRecord, SeekKey, SeekOp, SeekResult}; use crate::{return_if_io, LimboError, Result, Value}; @@ -8,7 +8,7 @@ pub enum ReadRecord { #[default] GetRecord, Done { - state: Option, + state: Box>, }, } @@ -27,7 +27,9 @@ impl ReadRecord { ReadRecord::GetRecord => { let res = return_if_io!(cursor.seek(key.clone(), SeekOp::GE { eq_only: true })); if !matches!(res, SeekResult::Found) { - *self = ReadRecord::Done { state: None }; + *self = ReadRecord::Done { + state: Box::new(None), + }; } else { let record = return_if_io!(cursor.record()); let r = record.ok_or_else(|| { @@ -41,14 +43,21 @@ impl ReadRecord { let (state, _group_key) = match blob { Value::Blob(blob) => AggregateState::from_blob(&blob), + Value::Null => { + // For plain DISTINCT, we store null value and just track weight + // Return a minimal state indicating existence + Ok((AggregateState::default(), vec![])) + } _ => Err(LimboError::ParseError( - "Value in aggregator not blob".to_string(), + "Value in aggregator not blob or null".to_string(), )), }?; - *self = ReadRecord::Done { state: Some(state) } + *self = ReadRecord::Done { + state: Box::new(Some(state)), + } } } - ReadRecord::Done { state } => return Ok(IOResult::Done(state.clone())), + ReadRecord::Done { state } => return Ok(IOResult::Done((**state).clone())), } } } diff --git a/core/incremental/project_operator.rs b/core/incremental/project_operator.rs index b82a1a138..8103ac7ff 100644 --- a/core/incremental/project_operator.rs +++ b/core/incremental/project_operator.rs @@ -86,7 +86,7 @@ impl ProjectOperator { for col in &self.columns { // Use the internal connection's pager for expression evaluation - let internal_pager = self.internal_conn.pager.read().clone(); + let internal_pager = self.internal_conn.pager.load().clone(); // Execute the compiled expression (handles both columns and complex expressions) let result = col diff --git a/core/incremental/view.rs b/core/incremental/view.rs index cb90b57b9..a82a1188b 100644 --- a/core/incremental/view.rs +++ b/core/incremental/view.rs @@ -2,7 +2,7 @@ use super::compiler::{DbspCircuit, DbspCompiler, DeltaSet}; use super::dbsp::Delta; use super::operator::ComputationTracker; use crate::schema::{BTreeTable, Schema}; -use crate::storage::btree::BTreeCursor; +use crate::storage::btree::CursorTrait; use crate::translate::logical::LogicalPlanBuilder; use crate::types::{IOResult, Value}; use crate::util::{extract_view_columns, ViewColumnSchema}; @@ -40,6 +40,11 @@ pub enum PopulateState { Done, } +// SAFETY: This needs to be audited for thread safety. +// See: https://github.com/tursodatabase/turso/issues/1552 +unsafe impl Send for PopulateState {} +unsafe impl Sync for PopulateState {} + /// State machine for merge_delta to handle I/O operations impl fmt::Debug for PopulateState { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -130,6 +135,11 @@ pub struct AllViewsTxState { states: Rc>>>, } +// SAFETY: This needs to be audited for thread safety. +// See: https://github.com/tursodatabase/turso/issues/1552 +unsafe impl Send for AllViewsTxState {} +unsafe impl Sync for AllViewsTxState {} + impl AllViewsTxState { /// Create a new container for view transaction states pub fn new() -> Self { @@ -210,6 +220,11 @@ pub struct IncrementalView { root_page: i64, } +// SAFETY: This needs to be audited for thread safety. +// See: https://github.com/tursodatabase/turso/issues/1552 +unsafe impl Send for IncrementalView {} +unsafe impl Sync for IncrementalView {} + impl IncrementalView { /// Try to compile the SELECT statement into a DBSP circuit fn try_compile_circuit( @@ -1112,7 +1127,7 @@ impl IncrementalView { &mut self, conn: &std::sync::Arc, pager: &std::sync::Arc, - _btree_cursor: &mut BTreeCursor, + _btree_cursor: &mut dyn CursorTrait, ) -> crate::Result> { // Assert that this is a materialized view with a root page assert!( @@ -1282,7 +1297,7 @@ impl IncrementalView { pending_row: None, // No pending row when interrupted between rows }; // TODO: Get the actual I/O completion from the statement - let completion = crate::io::Completion::new_dummy(); + let completion = crate::io::Completion::new_yield(); return Ok(IOResult::IO(crate::types::IOCompletions::Single( completion, ))); @@ -1411,6 +1426,7 @@ mod tests { has_rowid: true, is_strict: false, unique_sets: vec![], + foreign_keys: vec![], has_autoincrement: false, }; @@ -1460,6 +1476,7 @@ mod tests { has_rowid: true, is_strict: false, has_autoincrement: false, + foreign_keys: vec![], unique_sets: vec![], }; @@ -1509,6 +1526,7 @@ mod tests { has_rowid: true, is_strict: false, has_autoincrement: false, + foreign_keys: vec![], unique_sets: vec![], }; @@ -1558,13 +1576,26 @@ mod tests { has_rowid: true, // Has implicit rowid but no alias is_strict: false, has_autoincrement: false, + foreign_keys: vec![], unique_sets: vec![], }; - schema.add_btree_table(Arc::new(customers_table)); - schema.add_btree_table(Arc::new(orders_table)); - schema.add_btree_table(Arc::new(products_table)); - schema.add_btree_table(Arc::new(logs_table)); + schema + .add_btree_table(Arc::new(customers_table)) + .expect("Test setup: failed to add customers table"); + + schema + .add_btree_table(Arc::new(orders_table)) + .expect("Test setup: failed to add orders table"); + + schema + .add_btree_table(Arc::new(products_table)) + .expect("Test setup: failed to add products table"); + + schema + .add_btree_table(Arc::new(logs_table)) + .expect("Test setup: failed to add logs table"); + schema } diff --git a/core/io/clock.rs b/core/io/clock.rs index d522ac278..06edc65e3 100644 --- a/core/io/clock.rs +++ b/core/io/clock.rs @@ -87,3 +87,15 @@ impl std::ops::Sub for Instant { pub trait Clock { fn now(&self) -> Instant; } + +pub struct DefaultClock; + +impl Clock for DefaultClock { + fn now(&self) -> Instant { + let now = chrono::Local::now(); + Instant { + secs: now.timestamp(), + micros: now.timestamp_subsec_micros(), + } + } +} diff --git a/core/io/completions.rs b/core/io/completions.rs new file mode 100644 index 000000000..a381324c6 --- /dev/null +++ b/core/io/completions.rs @@ -0,0 +1,984 @@ +use core::fmt::{self, Debug}; +use std::{ + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, OnceLock, + }, + task::Waker, +}; + +use parking_lot::Mutex; + +use crate::{Buffer, CompletionError}; + +pub type ReadComplete = dyn Fn(Result<(Arc, i32), CompletionError>); +pub type WriteComplete = dyn Fn(Result); +pub type SyncComplete = dyn Fn(Result); +pub type TruncateComplete = dyn Fn(Result); + +#[must_use] +#[derive(Debug, Clone)] +pub struct Completion { + /// Optional completion state. If None, it means we are Yield in order to not allocate anything + pub(super) inner: Option>, +} + +#[derive(Debug, Default)] +struct ContextInner { + waker: Option, + // TODO: add abort signal +} + +#[derive(Debug, Clone)] +pub struct Context { + inner: Arc>, +} + +impl ContextInner { + pub fn new() -> Self { + Self { waker: None } + } + + pub fn wake(&mut self) { + if let Some(waker) = self.waker.take() { + waker.wake(); + } + } + + pub fn set_waker(&mut self, waker: &Waker) { + if let Some(curr_waker) = self.waker.as_mut() { + // only call and change waker if it would awake a different task + if !curr_waker.will_wake(waker) { + let prev_waker = std::mem::replace(curr_waker, waker.clone()); + prev_waker.wake(); + } + } else { + self.waker = Some(waker.clone()); + } + } +} + +impl Context { + pub fn new() -> Self { + Self { + inner: Arc::new(Mutex::new(ContextInner::new())), + } + } + + pub fn wake(&self) { + self.inner.lock().wake(); + } + + pub fn set_waker(&self, waker: &Waker) { + self.inner.lock().set_waker(waker); + } +} + +pub(super) struct CompletionInner { + completion_type: CompletionType, + /// None means we completed successfully + // Thread safe with OnceLock + pub(super) result: std::sync::OnceLock>, + needs_link: bool, + context: Context, + /// Optional parent group this completion belongs to + parent: OnceLock>, +} + +impl fmt::Debug for CompletionInner { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("CompletionInner") + .field("completion_type", &self.completion_type) + .field("needs_link", &self.needs_link) + .field("parent", &self.parent.get().is_some()) + .finish() + } +} + +pub struct CompletionGroup { + completions: Vec, + callback: Box) + Send + Sync>, +} + +impl CompletionGroup { + pub fn new(callback: F) -> Self + where + F: Fn(Result) + Send + Sync + 'static, + { + Self { + completions: Vec::new(), + callback: Box::new(callback), + } + } + + pub fn add(&mut self, completion: &Completion) { + self.completions.push(completion.clone()); + } + + pub fn cancel(&self) { + for c in &self.completions { + c.abort(); + } + } + + pub fn build(self) -> Completion { + let total = self.completions.len(); + if total == 0 { + (self.callback)(Ok(0)); + return Completion::new_yield(); + } + let group_completion = GroupCompletion::new(self.callback, total); + let group = Completion::new(CompletionType::Group(group_completion)); + + // Store the group completion reference for later callback + if let CompletionType::Group(ref g) = group.get_inner().completion_type { + let _ = g.inner.self_completion.set(group.clone()); + } + + for mut c in self.completions { + // If the completion has not completed, link it to the group. + if !c.finished() { + c.link_internal(&group); + continue; + } + let group_inner = match &group.get_inner().completion_type { + CompletionType::Group(g) => &g.inner, + _ => unreachable!(), + }; + // Return early if there was an error. + if let Some(err) = c.get_error() { + let _ = group_inner.result.set(Some(err)); + group_inner.outstanding.store(0, Ordering::SeqCst); + (group_inner.complete)(Err(err)); + return group; + } + // Mark the successful completion as done. + group_inner.outstanding.fetch_sub(1, Ordering::SeqCst); + } + + let group_inner = match &group.get_inner().completion_type { + CompletionType::Group(g) => &g.inner, + _ => unreachable!(), + }; + if group_inner.outstanding.load(Ordering::SeqCst) == 0 { + (group_inner.complete)(Ok(0)); + } + group + } +} + +pub struct GroupCompletion { + inner: Arc, +} + +impl fmt::Debug for GroupCompletion { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("GroupCompletion") + .field( + "outstanding", + &self.inner.outstanding.load(Ordering::SeqCst), + ) + .finish() + } +} + +struct GroupCompletionInner { + /// Number of completions that need to finish + outstanding: AtomicUsize, + /// Callback to invoke when all completions finish + complete: Box) + Send + Sync>, + /// Cached result after all completions finish + result: OnceLock>, + /// Reference to the group's own Completion for notifying parents + self_completion: OnceLock, +} + +impl GroupCompletion { + pub fn new(complete: F, outstanding: usize) -> Self + where + F: Fn(Result) + Send + Sync + 'static, + { + Self { + inner: Arc::new(GroupCompletionInner { + outstanding: AtomicUsize::new(outstanding), + complete: Box::new(complete), + result: OnceLock::new(), + self_completion: OnceLock::new(), + }), + } + } + + pub fn callback(&self, result: Result) { + assert_eq!( + self.inner.outstanding.load(Ordering::SeqCst), + 0, + "callback called before all completions finished" + ); + (self.inner.complete)(result); + } +} + +impl Debug for CompletionType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Read(..) => f.debug_tuple("Read").finish(), + Self::Write(..) => f.debug_tuple("Write").finish(), + Self::Sync(..) => f.debug_tuple("Sync").finish(), + Self::Truncate(..) => f.debug_tuple("Truncate").finish(), + Self::Group(..) => f.debug_tuple("Group").finish(), + Self::Yield => f.debug_tuple("Yield").finish(), + } + } +} + +pub enum CompletionType { + Read(ReadCompletion), + Write(WriteCompletion), + Sync(SyncCompletion), + Truncate(TruncateCompletion), + Group(GroupCompletion), + Yield, +} + +impl CompletionInner { + fn new(completion_type: CompletionType, needs_link: bool) -> Self { + Self { + completion_type, + result: OnceLock::new(), + needs_link, + context: Context::new(), + parent: OnceLock::new(), + } + } +} + +impl Completion { + pub fn new(completion_type: CompletionType) -> Self { + Self { + inner: Some(Arc::new(CompletionInner::new(completion_type, false))), + } + } + + pub fn new_linked(completion_type: CompletionType) -> Self { + Self { + inner: Some(Arc::new(CompletionInner::new(completion_type, true))), + } + } + + pub(super) fn get_inner(&self) -> &Arc { + self.inner.as_ref().unwrap() + } + + pub fn needs_link(&self) -> bool { + self.get_inner().needs_link + } + + pub fn new_write_linked(complete: F) -> Self + where + F: Fn(Result) + 'static, + { + Self::new_linked(CompletionType::Write(WriteCompletion::new(Box::new( + complete, + )))) + } + + pub fn new_write(complete: F) -> Self + where + F: Fn(Result) + 'static, + { + Self::new(CompletionType::Write(WriteCompletion::new(Box::new( + complete, + )))) + } + + pub fn new_read(buf: Arc, complete: F) -> Self + where + F: Fn(Result<(Arc, i32), CompletionError>) + 'static, + { + Self::new(CompletionType::Read(ReadCompletion::new( + buf, + Box::new(complete), + ))) + } + pub fn new_sync(complete: F) -> Self + where + F: Fn(Result) + 'static, + { + Self::new(CompletionType::Sync(SyncCompletion::new(Box::new( + complete, + )))) + } + + pub fn new_trunc(complete: F) -> Self + where + F: Fn(Result) + 'static, + { + Self::new(CompletionType::Truncate(TruncateCompletion::new(Box::new( + complete, + )))) + } + + /// Create a yield completion. These are completed by default allowing to yield control without + /// allocating memory. + pub fn new_yield() -> Self { + Self { inner: None } + } + + pub fn wake(&self) { + self.get_inner().context.wake(); + } + + pub fn set_waker(&self, waker: &Waker) { + if self.finished() || self.inner.is_none() { + waker.wake_by_ref(); + } else { + self.get_inner().context.set_waker(waker); + } + } + + pub fn succeeded(&self) -> bool { + match &self.inner { + Some(inner) => match &inner.completion_type { + CompletionType::Group(g) => { + g.inner.outstanding.load(Ordering::SeqCst) == 0 + && g.inner.result.get().is_none_or(|e| e.is_none()) + } + _ => inner.result.get().is_some(), + }, + None => true, + } + } + + pub fn failed(&self) -> bool { + match &self.inner { + Some(inner) => inner.result.get().is_some_and(|val| val.is_some()), + None => false, + } + } + + pub fn get_error(&self) -> Option { + match &self.inner { + Some(inner) => { + match &inner.completion_type { + CompletionType::Group(g) => { + // For groups, check the group's cached result field + // (set when the last completion finishes) + g.inner.result.get().and_then(|res| *res) + } + _ => inner.result.get().and_then(|res| *res), + } + } + None => None, + } + } + + /// Checks if the Completion completed or errored + pub fn finished(&self) -> bool { + match &self.inner { + Some(inner) => match &inner.completion_type { + CompletionType::Group(g) => g.inner.outstanding.load(Ordering::SeqCst) == 0, + _ => inner.result.get().is_some(), + }, + None => true, + } + } + + pub fn complete(&self, result: i32) { + let result = Ok(result); + self.callback(result); + } + + pub fn error(&self, err: CompletionError) { + let result = Err(err); + self.callback(result); + } + + pub fn abort(&self) { + self.error(CompletionError::Aborted); + } + + fn callback(&self, result: Result) { + let inner = self.get_inner(); + inner.result.get_or_init(|| { + match &inner.completion_type { + CompletionType::Read(r) => r.callback(result), + CompletionType::Write(w) => w.callback(result), + CompletionType::Sync(s) => s.callback(result), // fix + CompletionType::Truncate(t) => t.callback(result), + CompletionType::Group(g) => g.callback(result), + CompletionType::Yield => {} + }; + + if let Some(group) = inner.parent.get() { + // Capture first error in group + if let Err(err) = result { + let _ = group.result.set(Some(err)); + } + let prev = group.outstanding.fetch_sub(1, Ordering::SeqCst); + + // If this was the last completion in the group, trigger the group's callback + // which will recursively call this same callback() method to notify parents + if prev == 1 { + if let Some(group_completion) = group.self_completion.get() { + let group_result = group.result.get().and_then(|e| *e); + group_completion.callback(group_result.map_or(Ok(0), Err)); + } + } + } + + result.err() + }); + // call the waker regardless + inner.context.wake(); + } + + /// only call this method if you are sure that the completion is + /// a ReadCompletion, panics otherwise + pub fn as_read(&self) -> &ReadCompletion { + let inner = self.get_inner(); + match inner.completion_type { + CompletionType::Read(ref r) => r, + _ => unreachable!(), + } + } + + /// Link this completion to a group completion (internal use only) + fn link_internal(&mut self, group: &Completion) { + let group_inner = match &group.get_inner().completion_type { + CompletionType::Group(g) => &g.inner, + _ => panic!("link_internal() requires a group completion"), + }; + + // Set the parent (can only be set once) + if self.get_inner().parent.set(group_inner.clone()).is_err() { + panic!("completion can only be linked once"); + } + } +} + +pub struct ReadCompletion { + pub buf: Arc, + pub complete: Box, +} + +impl ReadCompletion { + pub fn new(buf: Arc, complete: Box) -> Self { + Self { buf, complete } + } + + pub fn buf(&self) -> &Buffer { + &self.buf + } + + pub fn callback(&self, bytes_read: Result) { + (self.complete)(bytes_read.map(|b| (self.buf.clone(), b))); + } + + pub fn buf_arc(&self) -> Arc { + self.buf.clone() + } +} + +pub struct WriteCompletion { + pub complete: Box, +} + +impl WriteCompletion { + pub fn new(complete: Box) -> Self { + Self { complete } + } + + pub fn callback(&self, bytes_written: Result) { + (self.complete)(bytes_written); + } +} + +pub struct SyncCompletion { + pub complete: Box, +} + +impl SyncCompletion { + pub fn new(complete: Box) -> Self { + Self { complete } + } + + pub fn callback(&self, res: Result) { + (self.complete)(res); + } +} + +pub struct TruncateCompletion { + pub complete: Box, +} + +impl TruncateCompletion { + pub fn new(complete: Box) -> Self { + Self { complete } + } + + pub fn callback(&self, res: Result) { + (self.complete)(res); + } +} + +#[cfg(test)] +mod tests { + use crate::CompletionError; + + use super::*; + + #[test] + fn test_completion_group_empty() { + use std::sync::atomic::{AtomicBool, Ordering}; + + let callback_called = Arc::new(AtomicBool::new(false)); + let callback_called_clone = callback_called.clone(); + + let group = CompletionGroup::new(move |_| { + callback_called_clone.store(true, Ordering::SeqCst); + }); + let group = group.build(); + assert!(group.finished()); + assert!(group.succeeded()); + assert!(group.get_error().is_none()); + + // Verify the callback was actually called + assert!( + callback_called.load(Ordering::SeqCst), + "callback should be called for empty group" + ); + } + + #[test] + fn test_completion_group_single_completion() { + let mut group = CompletionGroup::new(|_| {}); + let c = Completion::new_write(|_| {}); + group.add(&c); + let group = group.build(); + + assert!(!group.finished()); + assert!(!group.succeeded()); + + c.complete(0); + + assert!(group.finished()); + assert!(group.succeeded()); + assert!(group.get_error().is_none()); + } + + #[test] + fn test_completion_group_multiple_completions() { + let mut group = CompletionGroup::new(|_| {}); + let c1 = Completion::new_write(|_| {}); + let c2 = Completion::new_write(|_| {}); + let c3 = Completion::new_write(|_| {}); + group.add(&c1); + group.add(&c2); + group.add(&c3); + let group = group.build(); + + assert!(!group.succeeded()); + assert!(!group.finished()); + + c1.complete(0); + assert!(!group.succeeded()); + assert!(!group.finished()); + + c2.complete(0); + assert!(!group.succeeded()); + assert!(!group.finished()); + + c3.complete(0); + assert!(group.succeeded()); + assert!(group.finished()); + } + + #[test] + fn test_completion_group_with_error() { + let mut group = CompletionGroup::new(|_| {}); + let c1 = Completion::new_write(|_| {}); + let c2 = Completion::new_write(|_| {}); + group.add(&c1); + group.add(&c2); + let group = group.build(); + + c1.complete(0); + c2.error(CompletionError::Aborted); + + assert!(group.finished()); + assert!(!group.succeeded()); + assert_eq!(group.get_error(), Some(CompletionError::Aborted)); + } + + #[test] + fn test_completion_group_callback() { + use std::sync::atomic::{AtomicBool, Ordering}; + let called = Arc::new(AtomicBool::new(false)); + let called_clone = called.clone(); + + let mut group = CompletionGroup::new(move |_| { + called_clone.store(true, Ordering::SeqCst); + }); + + let c1 = Completion::new_write(|_| {}); + let c2 = Completion::new_write(|_| {}); + group.add(&c1); + group.add(&c2); + let group = group.build(); + + assert!(!called.load(Ordering::SeqCst)); + + c1.complete(0); + assert!(!called.load(Ordering::SeqCst)); + + c2.complete(0); + assert!(called.load(Ordering::SeqCst)); + assert!(group.finished()); + assert!(group.succeeded()); + } + + #[test] + fn test_completion_group_some_already_completed() { + // Test some completions added to group, then finish before build() + let mut group = CompletionGroup::new(|_| {}); + let c1 = Completion::new_write(|_| {}); + let c2 = Completion::new_write(|_| {}); + let c3 = Completion::new_write(|_| {}); + + // Add all to group while pending + group.add(&c1); + group.add(&c2); + group.add(&c3); + + // Complete c1 and c2 AFTER adding but BEFORE build() + c1.complete(0); + c2.complete(0); + + let group = group.build(); + + // c1 and c2 finished before build(), so outstanding should account for them + // Only c3 should be pending + assert!(!group.finished()); + assert!(!group.succeeded()); + + // Complete c3 + c3.complete(0); + + // Now the group should be finished + assert!(group.finished()); + assert!(group.succeeded()); + assert!(group.get_error().is_none()); + } + + #[test] + fn test_completion_group_all_already_completed() { + // Test when all completions are already finished before build() + let mut group = CompletionGroup::new(|_| {}); + let c1 = Completion::new_write(|_| {}); + let c2 = Completion::new_write(|_| {}); + + // Complete both before adding to group + c1.complete(0); + c2.complete(0); + + group.add(&c1); + group.add(&c2); + + let group = group.build(); + + // All completions were already complete, so group should be finished immediately + assert!(group.finished()); + assert!(group.succeeded()); + assert!(group.get_error().is_none()); + } + + #[test] + fn test_completion_group_mixed_finished_and_pending() { + use std::sync::atomic::{AtomicBool, Ordering}; + let called = Arc::new(AtomicBool::new(false)); + let called_clone = called.clone(); + + let mut group = CompletionGroup::new(move |_| { + called_clone.store(true, Ordering::SeqCst); + }); + + let c1 = Completion::new_write(|_| {}); + let c2 = Completion::new_write(|_| {}); + let c3 = Completion::new_write(|_| {}); + let c4 = Completion::new_write(|_| {}); + + // Complete c1 and c3 before adding to group + c1.complete(0); + c3.complete(0); + + group.add(&c1); + group.add(&c2); + group.add(&c3); + group.add(&c4); + + let group = group.build(); + + // Only c2 and c4 should be pending + assert!(!group.finished()); + assert!(!called.load(Ordering::SeqCst)); + + c2.complete(0); + assert!(!group.finished()); + assert!(!called.load(Ordering::SeqCst)); + + c4.complete(0); + assert!(group.finished()); + assert!(group.succeeded()); + assert!(called.load(Ordering::SeqCst)); + } + + #[test] + fn test_completion_group_already_completed_with_error() { + // Test when a completion finishes with error before build() + let mut group = CompletionGroup::new(|_| {}); + let c1 = Completion::new_write(|_| {}); + let c2 = Completion::new_write(|_| {}); + + // Complete c1 with error before adding to group + c1.error(CompletionError::Aborted); + + group.add(&c1); + group.add(&c2); + + let group = group.build(); + + // Group should immediately fail with the error + assert!(group.finished()); + assert!(!group.succeeded()); + assert_eq!(group.get_error(), Some(CompletionError::Aborted)); + } + + #[test] + fn test_completion_group_tracks_all_completions() { + // This test verifies the fix for the bug where CompletionGroup::add() + // would skip successfully-finished completions. This caused problems + // when code used drain() to move completions into a group, because + // finished completions would be removed from the source but not tracked + // by the group, effectively losing them. + use std::sync::atomic::{AtomicUsize, Ordering}; + + let callback_count = Arc::new(AtomicUsize::new(0)); + let callback_count_clone = callback_count.clone(); + + // Simulate the pattern: create multiple completions, complete some, + // then add ALL of them to a group (like drain() would do) + let mut completions = Vec::new(); + + // Create 4 completions + for _ in 0..4 { + completions.push(Completion::new_write(|_| {})); + } + + // Complete 2 of them before adding to group (simulate async completion) + completions[0].complete(0); + completions[2].complete(0); + + // Now create a group and add ALL completions (like drain() would do) + let mut group = CompletionGroup::new(move |_| { + callback_count_clone.fetch_add(1, Ordering::SeqCst); + }); + + // Add all completions to the group + for c in &completions { + group.add(c); + } + + let group = group.build(); + + // The group should track all 4 completions: + // - c[0] and c[2] are already finished + // - c[1] and c[3] are still pending + // So the group should not be finished yet + assert!(!group.finished()); + assert_eq!(callback_count.load(Ordering::SeqCst), 0); + + // Complete the first pending completion + completions[1].complete(0); + assert!(!group.finished()); + assert_eq!(callback_count.load(Ordering::SeqCst), 0); + + // Complete the last pending completion - now group should finish + completions[3].complete(0); + assert!(group.finished()); + assert!(group.succeeded()); + assert_eq!(callback_count.load(Ordering::SeqCst), 1); + + // Verify no errors + assert!(group.get_error().is_none()); + } + + #[test] + fn test_completion_group_with_all_finished_successfully() { + // Edge case: all completions are already successfully finished + // when added to the group. The group should complete immediately. + use std::sync::atomic::{AtomicBool, Ordering}; + + let callback_called = Arc::new(AtomicBool::new(false)); + let callback_called_clone = callback_called.clone(); + + let mut completions = Vec::new(); + + // Create and immediately complete 3 completions + for _ in 0..3 { + let c = Completion::new_write(|_| {}); + c.complete(0); + completions.push(c); + } + + // Add all already-completed completions to group + let mut group = CompletionGroup::new(move |_| { + callback_called_clone.store(true, Ordering::SeqCst); + }); + + for c in &completions { + group.add(c); + } + + let group = group.build(); + + // Group should be immediately finished since all completions were done + assert!(group.finished()); + assert!(group.succeeded()); + assert!(callback_called.load(Ordering::SeqCst)); + assert!(group.get_error().is_none()); + } + + #[test] + fn test_completion_group_nested() { + use std::sync::atomic::{AtomicUsize, Ordering}; + + // Track callbacks at different levels + let parent_called = Arc::new(AtomicUsize::new(0)); + let child1_called = Arc::new(AtomicUsize::new(0)); + let child2_called = Arc::new(AtomicUsize::new(0)); + + // Create child group 1 with 2 completions + let child1_called_clone = child1_called.clone(); + let mut child_group1 = CompletionGroup::new(move |_| { + child1_called_clone.fetch_add(1, Ordering::SeqCst); + }); + let c1 = Completion::new_write(|_| {}); + let c2 = Completion::new_write(|_| {}); + child_group1.add(&c1); + child_group1.add(&c2); + let child_group1 = child_group1.build(); + + // Create child group 2 with 2 completions + let child2_called_clone = child2_called.clone(); + let mut child_group2 = CompletionGroup::new(move |_| { + child2_called_clone.fetch_add(1, Ordering::SeqCst); + }); + let c3 = Completion::new_write(|_| {}); + let c4 = Completion::new_write(|_| {}); + child_group2.add(&c3); + child_group2.add(&c4); + let child_group2 = child_group2.build(); + + // Create parent group containing both child groups + let parent_called_clone = parent_called.clone(); + let mut parent_group = CompletionGroup::new(move |_| { + parent_called_clone.fetch_add(1, Ordering::SeqCst); + }); + parent_group.add(&child_group1); + parent_group.add(&child_group2); + let parent_group = parent_group.build(); + + // Initially nothing should be finished + assert!(!parent_group.finished()); + assert!(!child_group1.finished()); + assert!(!child_group2.finished()); + assert_eq!(parent_called.load(Ordering::SeqCst), 0); + assert_eq!(child1_called.load(Ordering::SeqCst), 0); + assert_eq!(child2_called.load(Ordering::SeqCst), 0); + + // Complete first completion in child group 1 + c1.complete(0); + assert!(!child_group1.finished()); + assert!(!parent_group.finished()); + assert_eq!(child1_called.load(Ordering::SeqCst), 0); + assert_eq!(parent_called.load(Ordering::SeqCst), 0); + + // Complete second completion in child group 1 - should finish child group 1 + c2.complete(0); + assert!(child_group1.finished()); + assert!(child_group1.succeeded()); + assert_eq!(child1_called.load(Ordering::SeqCst), 1); + + // Parent should not be finished yet because child group 2 is still pending + assert!(!parent_group.finished()); + assert_eq!(parent_called.load(Ordering::SeqCst), 0); + + // Complete first completion in child group 2 + c3.complete(0); + assert!(!child_group2.finished()); + assert!(!parent_group.finished()); + assert_eq!(child2_called.load(Ordering::SeqCst), 0); + assert_eq!(parent_called.load(Ordering::SeqCst), 0); + + // Complete second completion in child group 2 - should finish everything + c4.complete(0); + assert!(child_group2.finished()); + assert!(child_group2.succeeded()); + assert_eq!(child2_called.load(Ordering::SeqCst), 1); + + // Parent should now be finished + assert!(parent_group.finished()); + assert!(parent_group.succeeded()); + assert_eq!(parent_called.load(Ordering::SeqCst), 1); + assert!(parent_group.get_error().is_none()); + } + + #[test] + fn test_completion_group_nested_with_error() { + use std::sync::atomic::{AtomicBool, Ordering}; + + let parent_called = Arc::new(AtomicBool::new(false)); + let child_called = Arc::new(AtomicBool::new(false)); + + // Create child group with 2 completions + let child_called_clone = child_called.clone(); + let mut child_group = CompletionGroup::new(move |_| { + child_called_clone.store(true, Ordering::SeqCst); + }); + let c1 = Completion::new_write(|_| {}); + let c2 = Completion::new_write(|_| {}); + child_group.add(&c1); + child_group.add(&c2); + let child_group = child_group.build(); + + // Create parent group containing child group and another completion + let parent_called_clone = parent_called.clone(); + let mut parent_group = CompletionGroup::new(move |_| { + parent_called_clone.store(true, Ordering::SeqCst); + }); + let c3 = Completion::new_write(|_| {}); + parent_group.add(&child_group); + parent_group.add(&c3); + let parent_group = parent_group.build(); + + // Complete child group with success + c1.complete(0); + c2.complete(0); + assert!(child_group.finished()); + assert!(child_group.succeeded()); + assert!(child_called.load(Ordering::SeqCst)); + + // Parent still pending + assert!(!parent_group.finished()); + assert!(!parent_called.load(Ordering::SeqCst)); + + // Complete c3 with error + c3.error(CompletionError::Aborted); + + // Parent should finish with error + assert!(parent_group.finished()); + assert!(!parent_group.succeeded()); + assert_eq!(parent_group.get_error(), Some(CompletionError::Aborted)); + assert!(parent_called.load(Ordering::SeqCst)); + } +} diff --git a/core/io/generic.rs b/core/io/generic.rs index 8eef59d3b..b465a24cb 100644 --- a/core/io/generic.rs +++ b/core/io/generic.rs @@ -1,4 +1,6 @@ -use crate::{Clock, Completion, File, Instant, LimboError, OpenFlags, Result, IO}; +use crate::{ + io::clock::DefaultClock, Clock, Completion, File, Instant, LimboError, OpenFlags, Result, IO, +}; use parking_lot::RwLock; use std::io::{Read, Seek, Write}; use std::sync::Arc; @@ -44,11 +46,7 @@ impl IO for GenericIO { impl Clock for GenericIO { fn now(&self) -> Instant { - let now = chrono::Local::now(); - Instant { - secs: now.timestamp(), - micros: now.timestamp_subsec_micros(), - } + DefaultClock.now() } } @@ -59,12 +57,12 @@ pub struct GenericFile { impl File for GenericFile { #[instrument(err, skip_all, level = Level::TRACE)] fn lock_file(&self, exclusive: bool) -> Result<()> { - unimplemented!() + Ok(()) } #[instrument(err, skip_all, level = Level::TRACE)] fn unlock_file(&self) -> Result<()> { - unimplemented!() + Ok(()) } #[instrument(skip(self, c), level = Level::TRACE)] diff --git a/core/io/io_uring.rs b/core/io/io_uring.rs index 03f8dc3a0..c681649eb 100644 --- a/core/io/io_uring.rs +++ b/core/io/io_uring.rs @@ -1,7 +1,7 @@ #![allow(clippy::arc_with_non_send_sync)] use super::{common, Completion, CompletionInner, File, OpenFlags, IO}; -use crate::io::clock::{Clock, Instant}; +use crate::io::clock::{Clock, DefaultClock, Instant}; use crate::storage::wal::CKPT_BATCH_PAGES; use crate::{turso_assert, CompletionError, LimboError, Result}; use parking_lot::Mutex; @@ -697,11 +697,7 @@ impl IO for UringIO { impl Clock for UringIO { fn now(&self) -> Instant { - let now = chrono::Local::now(); - Instant { - secs: now.timestamp(), - micros: now.timestamp_subsec_micros(), - } + DefaultClock.now() } } @@ -709,14 +705,16 @@ impl Clock for UringIO { /// use the callback pointer as the user_data for the operation as is /// common practice for io_uring to prevent more indirection fn get_key(c: Completion) -> u64 { - Arc::into_raw(c.inner.clone()) as u64 + Arc::into_raw(c.get_inner().clone()) as u64 } #[inline(always)] /// convert the user_data back to an Completion pointer fn completion_from_key(key: u64) -> Completion { let c_inner = unsafe { Arc::from_raw(key as *const CompletionInner) }; - Completion { inner: c_inner } + Completion { + inner: Some(c_inner), + } } pub struct UringFile { diff --git a/core/io/memory.rs b/core/io/memory.rs index fc0549ca7..31c78a4b1 100644 --- a/core/io/memory.rs +++ b/core/io/memory.rs @@ -1,5 +1,5 @@ use super::{Buffer, Clock, Completion, File, OpenFlags, IO}; -use crate::Result; +use crate::{io::clock::DefaultClock, Result}; use crate::io::clock::Instant; use std::{ @@ -35,11 +35,7 @@ impl Default for MemoryIO { impl Clock for MemoryIO { fn now(&self) -> Instant { - let now = chrono::Local::now(); - Instant { - secs: now.timestamp(), - micros: now.timestamp_subsec_micros(), - } + DefaultClock.now() } } diff --git a/core/io/mod.rs b/core/io/mod.rs index 0698df9c6..1c9d107a4 100644 --- a/core/io/mod.rs +++ b/core/io/mod.rs @@ -1,15 +1,47 @@ use crate::storage::buffer_pool::ArenaBuffer; use crate::storage::sqlite3_ondisk::WAL_FRAME_HEADER_SIZE; -use crate::{BufferPool, CompletionError, Result}; +use crate::{BufferPool, Result}; use bitflags::bitflags; use cfg_block::cfg_block; +use rand::{Rng, RngCore}; use std::cell::RefCell; use std::fmt; use std::ptr::NonNull; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use std::{fmt::Debug, pin::Pin}; +cfg_block! { + #[cfg(all(target_os = "linux", feature = "io_uring", not(miri)))] { + mod io_uring; + #[cfg(feature = "fs")] + pub use io_uring::UringIO; + } + + #[cfg(all(target_family = "unix", not(miri)))] { + mod unix; + #[cfg(feature = "fs")] + pub use unix::UnixIO; + pub use unix::UnixIO as PlatformIO; + pub use PlatformIO as SyscallIO; + } + + #[cfg(any(not(any(target_family = "unix", target_os = "android", target_os = "ios")), miri))] { + mod generic; + pub use generic::GenericIO as PlatformIO; + pub use PlatformIO as SyscallIO; + } +} + +mod memory; +#[cfg(feature = "fs")] +mod vfs; +pub use memory::MemoryIO; +pub mod clock; +mod common; +mod completions; +pub use clock::Clock; +pub use completions::*; + pub trait File: Send + Sync { fn lock_file(&self, exclusive: bool) -> Result<()>; fn unlock_file(&self) -> Result<()>; @@ -65,6 +97,11 @@ pub trait File: Send + Sync { #[derive(Debug, Copy, Clone, PartialEq)] pub struct OpenFlags(i32); +// SAFETY: This needs to be audited for thread safety. +// See: https://github.com/tursodatabase/turso/issues/1552 +unsafe impl Send for OpenFlags {} +unsafe impl Sync for OpenFlags {} + bitflags! { impl OpenFlags: i32 { const None = 0b00000000; @@ -102,16 +139,21 @@ pub trait IO: Clock + Send + Sync { while !c.finished() { self.step()? } - if let Some(Some(err)) = c.inner.result.get().copied() { - return Err(err.into()); + if let Some(inner) = &c.inner { + if let Some(Some(err)) = inner.result.get().copied() { + return Err(err.into()); + } } Ok(()) } fn generate_random_number(&self) -> i64 { - let mut buf = [0u8; 8]; - getrandom::getrandom(&mut buf).unwrap(); - i64::from_ne_bytes(buf) + rand::rng().random() + } + + /// Fill `dest` with random data. + fn fill_bytes(&self, dest: &mut [u8]) { + rand::rng().fill_bytes(dest); } fn get_memory_io(&self) -> Arc { @@ -125,412 +167,6 @@ pub trait IO: Clock + Send + Sync { } } -pub type ReadComplete = dyn Fn(Result<(Arc, i32), CompletionError>); -pub type WriteComplete = dyn Fn(Result); -pub type SyncComplete = dyn Fn(Result); -pub type TruncateComplete = dyn Fn(Result); - -#[must_use] -#[derive(Debug, Clone)] -pub struct Completion { - inner: Arc, -} - -struct CompletionInner { - completion_type: CompletionType, - /// None means we completed successfully - // Thread safe with OnceLock - result: std::sync::OnceLock>, - needs_link: bool, - /// Optional parent group this completion belongs to - parent: OnceLock>, -} - -impl fmt::Debug for CompletionInner { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("CompletionInner") - .field("completion_type", &self.completion_type) - .field("needs_link", &self.needs_link) - .field("parent", &self.parent.get().is_some()) - .finish() - } -} - -pub struct CompletionGroup { - completions: Vec, - callback: Box) + Send + Sync>, -} - -impl CompletionGroup { - pub fn new(callback: F) -> Self - where - F: Fn(Result) + Send + Sync + 'static, - { - Self { - completions: Vec::new(), - callback: Box::new(callback), - } - } - - pub fn add(&mut self, completion: &Completion) { - if !completion.finished() || completion.failed() { - self.completions.push(completion.clone()); - } - // Skip successfully finished completions - } - - pub fn build(self) -> Completion { - let total = self.completions.len(); - if total == 0 { - let group_completion = GroupCompletion::new(self.callback, 0); - return Completion::new(CompletionType::Group(group_completion)); - } - let group_completion = GroupCompletion::new(self.callback, total); - let group = Completion::new(CompletionType::Group(group_completion)); - - for mut c in self.completions { - // If the completion has not completed, link it to the group. - if !c.finished() { - c.link_internal(&group); - continue; - } - let group_inner = match &group.inner.completion_type { - CompletionType::Group(g) => &g.inner, - _ => unreachable!(), - }; - // Return early if there was an error. - if let Some(err) = c.get_error() { - let _ = group_inner.result.set(Some(err)); - group_inner.outstanding.store(0, Ordering::SeqCst); - (group_inner.complete)(Err(err)); - return group; - } - // Mark the successful completion as done. - group_inner.outstanding.fetch_sub(1, Ordering::SeqCst); - } - - let group_inner = match &group.inner.completion_type { - CompletionType::Group(g) => &g.inner, - _ => unreachable!(), - }; - if group_inner.outstanding.load(Ordering::SeqCst) == 0 { - (group_inner.complete)(Ok(0)); - } - group - } -} - -pub struct GroupCompletion { - inner: Arc, -} - -impl fmt::Debug for GroupCompletion { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("GroupCompletion") - .field( - "outstanding", - &self.inner.outstanding.load(Ordering::SeqCst), - ) - .finish() - } -} - -struct GroupCompletionInner { - /// Number of completions that need to finish - outstanding: AtomicUsize, - /// Callback to invoke when all completions finish - complete: Box) + Send + Sync>, - /// Cached result after all completions finish - result: OnceLock>, -} - -impl GroupCompletion { - pub fn new(complete: F, outstanding: usize) -> Self - where - F: Fn(Result) + Send + Sync + 'static, - { - Self { - inner: Arc::new(GroupCompletionInner { - outstanding: AtomicUsize::new(outstanding), - complete: Box::new(complete), - result: OnceLock::new(), - }), - } - } - - pub fn callback(&self, result: Result) { - assert_eq!( - self.inner.outstanding.load(Ordering::SeqCst), - 0, - "callback called before all completions finished" - ); - (self.inner.complete)(result); - } -} - -impl Debug for CompletionType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Read(..) => f.debug_tuple("Read").finish(), - Self::Write(..) => f.debug_tuple("Write").finish(), - Self::Sync(..) => f.debug_tuple("Sync").finish(), - Self::Truncate(..) => f.debug_tuple("Truncate").finish(), - Self::Group(..) => f.debug_tuple("Group").finish(), - } - } -} - -pub enum CompletionType { - Read(ReadCompletion), - Write(WriteCompletion), - Sync(SyncCompletion), - Truncate(TruncateCompletion), - Group(GroupCompletion), -} - -impl Completion { - pub fn new(completion_type: CompletionType) -> Self { - Self { - inner: Arc::new(CompletionInner { - completion_type, - result: OnceLock::new(), - needs_link: false, - parent: OnceLock::new(), - }), - } - } - - pub fn new_linked(completion_type: CompletionType) -> Self { - Self { - inner: Arc::new(CompletionInner { - completion_type, - result: OnceLock::new(), - needs_link: true, - parent: OnceLock::new(), - }), - } - } - - pub fn needs_link(&self) -> bool { - self.inner.needs_link - } - - pub fn new_write_linked(complete: F) -> Self - where - F: Fn(Result) + 'static, - { - Self::new_linked(CompletionType::Write(WriteCompletion::new(Box::new( - complete, - )))) - } - - pub fn new_write(complete: F) -> Self - where - F: Fn(Result) + 'static, - { - Self::new(CompletionType::Write(WriteCompletion::new(Box::new( - complete, - )))) - } - - pub fn new_read(buf: Arc, complete: F) -> Self - where - F: Fn(Result<(Arc, i32), CompletionError>) + 'static, - { - Self::new(CompletionType::Read(ReadCompletion::new( - buf, - Box::new(complete), - ))) - } - pub fn new_sync(complete: F) -> Self - where - F: Fn(Result) + 'static, - { - Self::new(CompletionType::Sync(SyncCompletion::new(Box::new( - complete, - )))) - } - - pub fn new_trunc(complete: F) -> Self - where - F: Fn(Result) + 'static, - { - Self::new(CompletionType::Truncate(TruncateCompletion::new(Box::new( - complete, - )))) - } - - /// Create a dummy completed completion - pub fn new_dummy() -> Self { - let c = Self::new_write(|_| {}); - c.complete(0); - c - } - - pub fn succeeded(&self) -> bool { - match &self.inner.completion_type { - CompletionType::Group(g) => { - g.inner.outstanding.load(Ordering::SeqCst) == 0 - && g.inner.result.get().is_none_or(|e| e.is_none()) - } - _ => self.inner.result.get().is_some(), - } - } - - pub fn failed(&self) -> bool { - self.inner.result.get().is_some_and(|val| val.is_some()) - } - - pub fn get_error(&self) -> Option { - match &self.inner.completion_type { - CompletionType::Group(g) => { - // For groups, check the group's cached result field - // (set when the last completion finishes) - g.inner.result.get().and_then(|res| *res) - } - _ => self.inner.result.get().and_then(|res| *res), - } - } - - /// Checks if the Completion completed or errored - pub fn finished(&self) -> bool { - match &self.inner.completion_type { - CompletionType::Group(g) => g.inner.outstanding.load(Ordering::SeqCst) == 0, - _ => self.inner.result.get().is_some(), - } - } - - pub fn complete(&self, result: i32) { - let result = Ok(result); - self.callback(result); - } - - pub fn error(&self, err: CompletionError) { - let result = Err(err); - self.callback(result); - } - - pub fn abort(&self) { - self.error(CompletionError::Aborted); - } - - fn callback(&self, result: Result) { - self.inner.result.get_or_init(|| { - match &self.inner.completion_type { - CompletionType::Read(r) => r.callback(result), - CompletionType::Write(w) => w.callback(result), - CompletionType::Sync(s) => s.callback(result), // fix - CompletionType::Truncate(t) => t.callback(result), - CompletionType::Group(g) => g.callback(result), - }; - - if let Some(group) = self.inner.parent.get() { - // Capture first error in group - if let Err(err) = result { - let _ = group.result.set(Some(err)); - } - let prev = group.outstanding.fetch_sub(1, Ordering::SeqCst); - - // If this was the last completion, call the group callback - if prev == 1 { - let group_result = group.result.get().and_then(|e| *e); - (group.complete)(group_result.map_or(Ok(0), Err)); - } - // TODO: remove self from parent group - } - - result.err() - }); - } - - /// only call this method if you are sure that the completion is - /// a ReadCompletion, panics otherwise - pub fn as_read(&self) -> &ReadCompletion { - match self.inner.completion_type { - CompletionType::Read(ref r) => r, - _ => unreachable!(), - } - } - - /// Link this completion to a group completion (internal use only) - fn link_internal(&mut self, group: &Completion) { - let group_inner = match &group.inner.completion_type { - CompletionType::Group(g) => &g.inner, - _ => panic!("link_internal() requires a group completion"), - }; - - // Set the parent (can only be set once) - if self.inner.parent.set(group_inner.clone()).is_err() { - panic!("completion can only be linked once"); - } - } -} - -pub struct ReadCompletion { - pub buf: Arc, - pub complete: Box, -} - -impl ReadCompletion { - pub fn new(buf: Arc, complete: Box) -> Self { - Self { buf, complete } - } - - pub fn buf(&self) -> &Buffer { - &self.buf - } - - pub fn callback(&self, bytes_read: Result) { - (self.complete)(bytes_read.map(|b| (self.buf.clone(), b))); - } - - pub fn buf_arc(&self) -> Arc { - self.buf.clone() - } -} - -pub struct WriteCompletion { - pub complete: Box, -} - -impl WriteCompletion { - pub fn new(complete: Box) -> Self { - Self { complete } - } - - pub fn callback(&self, bytes_written: Result) { - (self.complete)(bytes_written); - } -} - -pub struct SyncCompletion { - pub complete: Box, -} - -impl SyncCompletion { - pub fn new(complete: Box) -> Self { - Self { complete } - } - - pub fn callback(&self, res: Result) { - (self.complete)(res); - } -} - -pub struct TruncateCompletion { - pub complete: Box, -} - -impl TruncateCompletion { - pub fn new(complete: Box) -> Self { - Self { complete } - } - - pub fn callback(&self, res: Result) { - (self.complete)(res); - } -} - pub type BufferData = Pin>; pub enum Buffer { @@ -690,251 +326,3 @@ impl TempBufferCache { } } } - -cfg_block! { - #[cfg(all(target_os = "linux", feature = "io_uring"))] { - mod io_uring; - #[cfg(feature = "fs")] - pub use io_uring::UringIO; - } - - #[cfg(target_family = "unix")] { - mod unix; - #[cfg(feature = "fs")] - pub use unix::UnixIO; - pub use unix::UnixIO as PlatformIO; - pub use PlatformIO as SyscallIO; - } - - #[cfg(not(any(target_family = "unix", target_os = "android", target_os = "ios")))] { - mod generic; - pub use generic::GenericIO as PlatformIO; - pub use PlatformIO as SyscallIO; - } -} - -mod memory; -#[cfg(feature = "fs")] -mod vfs; -pub use memory::MemoryIO; -pub mod clock; -mod common; -pub use clock::Clock; - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_completion_group_empty() { - let group = CompletionGroup::new(|_| {}); - let group = group.build(); - assert!(group.finished()); - assert!(group.succeeded()); - assert!(group.get_error().is_none()); - } - - #[test] - fn test_completion_group_single_completion() { - let mut group = CompletionGroup::new(|_| {}); - let c = Completion::new_write(|_| {}); - group.add(&c); - let group = group.build(); - - assert!(!group.finished()); - assert!(!group.succeeded()); - - c.complete(0); - - assert!(group.finished()); - assert!(group.succeeded()); - assert!(group.get_error().is_none()); - } - - #[test] - fn test_completion_group_multiple_completions() { - let mut group = CompletionGroup::new(|_| {}); - let c1 = Completion::new_write(|_| {}); - let c2 = Completion::new_write(|_| {}); - let c3 = Completion::new_write(|_| {}); - group.add(&c1); - group.add(&c2); - group.add(&c3); - let group = group.build(); - - assert!(!group.succeeded()); - assert!(!group.finished()); - - c1.complete(0); - assert!(!group.succeeded()); - assert!(!group.finished()); - - c2.complete(0); - assert!(!group.succeeded()); - assert!(!group.finished()); - - c3.complete(0); - assert!(group.succeeded()); - assert!(group.finished()); - } - - #[test] - fn test_completion_group_with_error() { - let mut group = CompletionGroup::new(|_| {}); - let c1 = Completion::new_write(|_| {}); - let c2 = Completion::new_write(|_| {}); - group.add(&c1); - group.add(&c2); - let group = group.build(); - - c1.complete(0); - c2.error(CompletionError::Aborted); - - assert!(group.finished()); - assert!(!group.succeeded()); - assert_eq!(group.get_error(), Some(CompletionError::Aborted)); - } - - #[test] - fn test_completion_group_callback() { - use std::sync::atomic::{AtomicBool, Ordering}; - let called = Arc::new(AtomicBool::new(false)); - let called_clone = called.clone(); - - let mut group = CompletionGroup::new(move |_| { - called_clone.store(true, Ordering::SeqCst); - }); - - let c1 = Completion::new_write(|_| {}); - let c2 = Completion::new_write(|_| {}); - group.add(&c1); - group.add(&c2); - let group = group.build(); - - assert!(!called.load(Ordering::SeqCst)); - - c1.complete(0); - assert!(!called.load(Ordering::SeqCst)); - - c2.complete(0); - assert!(called.load(Ordering::SeqCst)); - assert!(group.finished()); - assert!(group.succeeded()); - } - - #[test] - fn test_completion_group_some_already_completed() { - // Test some completions added to group, then finish before build() - let mut group = CompletionGroup::new(|_| {}); - let c1 = Completion::new_write(|_| {}); - let c2 = Completion::new_write(|_| {}); - let c3 = Completion::new_write(|_| {}); - - // Add all to group while pending - group.add(&c1); - group.add(&c2); - group.add(&c3); - - // Complete c1 and c2 AFTER adding but BEFORE build() - c1.complete(0); - c2.complete(0); - - let group = group.build(); - - // c1 and c2 finished before build(), so outstanding should account for them - // Only c3 should be pending - assert!(!group.finished()); - assert!(!group.succeeded()); - - // Complete c3 - c3.complete(0); - - // Now the group should be finished - assert!(group.finished()); - assert!(group.succeeded()); - assert!(group.get_error().is_none()); - } - - #[test] - fn test_completion_group_all_already_completed() { - // Test when all completions are already finished before build() - let mut group = CompletionGroup::new(|_| {}); - let c1 = Completion::new_write(|_| {}); - let c2 = Completion::new_write(|_| {}); - - // Complete both before adding to group - c1.complete(0); - c2.complete(0); - - group.add(&c1); - group.add(&c2); - - let group = group.build(); - - // All completions were already complete, so group should be finished immediately - assert!(group.finished()); - assert!(group.succeeded()); - assert!(group.get_error().is_none()); - } - - #[test] - fn test_completion_group_mixed_finished_and_pending() { - use std::sync::atomic::{AtomicBool, Ordering}; - let called = Arc::new(AtomicBool::new(false)); - let called_clone = called.clone(); - - let mut group = CompletionGroup::new(move |_| { - called_clone.store(true, Ordering::SeqCst); - }); - - let c1 = Completion::new_write(|_| {}); - let c2 = Completion::new_write(|_| {}); - let c3 = Completion::new_write(|_| {}); - let c4 = Completion::new_write(|_| {}); - - // Complete c1 and c3 before adding to group - c1.complete(0); - c3.complete(0); - - group.add(&c1); - group.add(&c2); - group.add(&c3); - group.add(&c4); - - let group = group.build(); - - // Only c2 and c4 should be pending - assert!(!group.finished()); - assert!(!called.load(Ordering::SeqCst)); - - c2.complete(0); - assert!(!group.finished()); - assert!(!called.load(Ordering::SeqCst)); - - c4.complete(0); - assert!(group.finished()); - assert!(group.succeeded()); - assert!(called.load(Ordering::SeqCst)); - } - - #[test] - fn test_completion_group_already_completed_with_error() { - // Test when a completion finishes with error before build() - let mut group = CompletionGroup::new(|_| {}); - let c1 = Completion::new_write(|_| {}); - let c2 = Completion::new_write(|_| {}); - - // Complete c1 with error before adding to group - c1.error(CompletionError::Aborted); - - group.add(&c1); - group.add(&c2); - - let group = group.build(); - - // Group should immediately fail with the error - assert!(group.finished()); - assert!(!group.succeeded()); - assert_eq!(group.get_error(), Some(CompletionError::Aborted)); - } -} diff --git a/core/io/unix.rs b/core/io/unix.rs index bb17765f2..f95c9b95d 100644 --- a/core/io/unix.rs +++ b/core/io/unix.rs @@ -1,6 +1,6 @@ use super::{Completion, File, OpenFlags, IO}; use crate::error::LimboError; -use crate::io::clock::{Clock, Instant}; +use crate::io::clock::{Clock, DefaultClock, Instant}; use crate::io::common; use crate::Result; use parking_lot::Mutex; @@ -27,11 +27,7 @@ impl UnixIO { impl Clock for UnixIO { fn now(&self) -> Instant { - let now = chrono::Local::now(); - Instant { - secs: now.timestamp(), - micros: now.timestamp_subsec_micros(), - } + DefaultClock.now() } } diff --git a/core/io/vfs.rs b/core/io/vfs.rs index 9c7b116c0..b2ce62424 100644 --- a/core/io/vfs.rs +++ b/core/io/vfs.rs @@ -1,6 +1,6 @@ use super::{Buffer, Completion, File, OpenFlags, IO}; use crate::ext::VfsMod; -use crate::io::clock::{Clock, Instant}; +use crate::io::clock::{Clock, DefaultClock, Instant}; use crate::io::CompletionInner; use crate::{LimboError, Result}; use std::ffi::{c_void, CString}; @@ -10,11 +10,7 @@ use turso_ext::{BufferRef, IOCallback, SendPtr, VfsFileImpl, VfsImpl}; impl Clock for VfsMod { fn now(&self) -> Instant { - let now = chrono::Local::now(); - Instant { - secs: now.timestamp(), - micros: now.timestamp_subsec_micros(), - } + DefaultClock.now() } } @@ -86,14 +82,14 @@ impl VfsMod { /// that the into_raw/from_raw contract will hold unsafe extern "C" fn callback_fn(result: i32, ctx: SendPtr) { let completion = Completion { - inner: (Arc::from_raw(ctx.inner().as_ptr() as *mut CompletionInner)), + inner: (Some(Arc::from_raw(ctx.inner().as_ptr() as *mut CompletionInner))), }; completion.complete(result); } fn to_callback(c: Completion) -> IOCallback { IOCallback::new(callback_fn, unsafe { - NonNull::new_unchecked(Arc::into_raw(c.inner) as *mut c_void) + NonNull::new_unchecked(Arc::into_raw(c.get_inner().clone()) as *mut c_void) }) } diff --git a/core/io/windows.rs b/core/io/windows.rs index a884cc922..3431e5454 100644 --- a/core/io/windows.rs +++ b/core/io/windows.rs @@ -1,4 +1,6 @@ -use crate::{Clock, Completion, File, Instant, LimboError, OpenFlags, Result, IO}; +use crate::{ + io::clock::DefaultClock, Clock, Completion, File, Instant, LimboError, OpenFlags, Result, IO, +}; use parking_lot::RwLock; use std::io::{Read, Seek, Write}; use std::sync::Arc; @@ -44,11 +46,7 @@ impl IO for WindowsIO { impl Clock for WindowsIO { fn now(&self) -> Instant { - let now = chrono::Local::now(); - Instant { - secs: now.timestamp(), - micros: now.timestamp_subsec_micros(), - } + DefaultClock.now() } } diff --git a/core/json/mod.rs b/core/json/mod.rs index 6209b6774..d3b1f4e19 100644 --- a/core/json/mod.rs +++ b/core/json/mod.rs @@ -11,9 +11,9 @@ pub use crate::json::ops::{ jsonb_replace, }; use crate::json::path::{json_path, JsonPath, PathElement}; -use crate::types::{RawSlice, Text, TextRef, TextSubtype, Value, ValueType}; +use crate::types::{Text, TextSubtype, Value, ValueType}; use crate::vdbe::Register; -use crate::{bail_constraint_error, bail_parse_error, LimboError, RefValue}; +use crate::{bail_constraint_error, bail_parse_error, LimboError, ValueRef}; pub use cache::JsonCacheCell; use jsonb::{ElementType, Jsonb, JsonbHeader, PathOperationMode, SearchOperation, SetOperation}; use std::borrow::Cow; @@ -105,14 +105,12 @@ pub fn json_from_raw_bytes_agg(data: &[u8], raw: bool) -> crate::Result { pub fn convert_dbtype_to_jsonb(val: &Value, strict: Conv) -> crate::Result { convert_ref_dbtype_to_jsonb( - &match val { - Value::Null => RefValue::Null, - Value::Integer(x) => RefValue::Integer(*x), - Value::Float(x) => RefValue::Float(*x), - Value::Text(text) => { - RefValue::Text(TextRef::create_from(text.as_str().as_bytes(), text.subtype)) - } - Value::Blob(items) => RefValue::Blob(RawSlice::create_from(items)), + match val { + Value::Null => ValueRef::Null, + Value::Integer(x) => ValueRef::Integer(*x), + Value::Float(x) => ValueRef::Float(*x), + Value::Text(text) => ValueRef::Text(text.as_str().as_bytes(), text.subtype), + Value::Blob(items) => ValueRef::Blob(items.as_slice()), }, strict, ) @@ -124,14 +122,14 @@ fn parse_as_json_text(slice: &[u8]) -> crate::Result { Jsonb::from_str_with_mode(str, Conv::Strict).map_err(Into::into) } -pub fn convert_ref_dbtype_to_jsonb(val: &RefValue, strict: Conv) -> crate::Result { +pub fn convert_ref_dbtype_to_jsonb(val: ValueRef<'_>, strict: Conv) -> crate::Result { match val { - RefValue::Text(text) => { - let res = if text.subtype == TextSubtype::Json || matches!(strict, Conv::Strict) { - Jsonb::from_str_with_mode(text.as_str(), strict) + ValueRef::Text(text, subtype) => { + let res = if subtype == TextSubtype::Json || matches!(strict, Conv::Strict) { + Jsonb::from_str_with_mode(&String::from_utf8_lossy(text), strict) } else { // Handle as a string literal otherwise - let mut str = text.as_str().replace('"', "\\\""); + let mut str = String::from_utf8_lossy(text).replace('"', "\\\""); // Quote the string to make it a JSON string str.insert(0, '"'); str.push('"'); @@ -139,8 +137,8 @@ pub fn convert_ref_dbtype_to_jsonb(val: &RefValue, strict: Conv) -> crate::Resul }; res.map_err(|_| LimboError::ParseError("malformed JSON".to_string())) } - RefValue::Blob(blob) => { - let bytes = blob.to_slice(); + ValueRef::Blob(blob) => { + let bytes = blob; // Valid JSON can start with these whitespace characters let index = bytes .iter() @@ -177,15 +175,15 @@ pub fn convert_ref_dbtype_to_jsonb(val: &RefValue, strict: Conv) -> crate::Resul json.element_type()?; Ok(json) } - RefValue::Null => Ok(Jsonb::from_raw_data( + ValueRef::Null => Ok(Jsonb::from_raw_data( JsonbHeader::make_null().into_bytes().as_bytes(), )), - RefValue::Float(float) => { + ValueRef::Float(float) => { let mut buff = ryu::Buffer::new(); - Jsonb::from_str(buff.format(*float)) + Jsonb::from_str(buff.format(float)) .map_err(|_| LimboError::ParseError("malformed JSON".to_string())) } - RefValue::Integer(int) => Jsonb::from_str(&int.to_string()) + ValueRef::Integer(int) => Jsonb::from_str(&int.to_string()) .map_err(|_| LimboError::ParseError("malformed JSON".to_string())), } } diff --git a/core/lib.rs b/core/lib.rs index 11b85be81..4e5eeea8d 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -40,21 +40,24 @@ pub mod numeric; mod numeric; use crate::storage::checksum::CHECKSUM_REQUIRED_RESERVED_BYTES; +use crate::storage::encryption::AtomicCipherMode; use crate::translate::display::PlanContext; use crate::translate::pragma::TURSO_CDC_DEFAULT_TABLE_NAME; #[cfg(all(feature = "fs", feature = "conn_raw_api"))] use crate::types::{WalFrameInfo, WalState}; #[cfg(feature = "fs")] use crate::util::{OpenMode, OpenOptions}; +use crate::vdbe::explain::{EXPLAIN_COLUMNS_TYPE, EXPLAIN_QUERY_PLAN_COLUMNS_TYPE}; use crate::vdbe::metrics::ConnectionMetrics; use crate::vtab::VirtualTable; use crate::{incremental::view::AllViewsTxState, translate::emitter::TransactionMode}; +use arc_swap::ArcSwap; use core::str; pub use error::{CompletionError, LimboError}; pub use io::clock::{Clock, Instant}; -#[cfg(all(feature = "fs", target_family = "unix"))] +#[cfg(all(feature = "fs", target_family = "unix", not(miri)))] pub use io::UnixIO; -#[cfg(all(feature = "fs", target_os = "linux", feature = "io_uring"))] +#[cfg(all(feature = "fs", target_os = "linux", feature = "io_uring", not(miri)))] pub use io::UringIO; pub use io::{ Buffer, Completion, CompletionType, File, GroupCompletion, MemoryIO, OpenFlags, PlatformIO, @@ -62,16 +65,17 @@ pub use io::{ }; use parking_lot::RwLock; use schema::Schema; +use std::task::Waker; use std::{ borrow::Cow, - cell::RefCell, + cell::{Cell, RefCell}, collections::HashMap, fmt::{self, Display}, num::NonZero, ops::Deref, rc::Rc, sync::{ - atomic::{AtomicBool, AtomicI32, AtomicI64, AtomicU16, AtomicUsize, Ordering}, + atomic::{AtomicBool, AtomicI32, AtomicI64, AtomicIsize, AtomicU16, AtomicUsize, Ordering}, Arc, LazyLock, Mutex, Weak, }, time::Duration, @@ -91,12 +95,12 @@ pub use storage::{ wal::{CheckpointMode, CheckpointResult, Wal, WalFile, WalFileShared}, }; use tracing::{instrument, Level}; -use turso_macros::match_ignore_ascii_case; +use turso_macros::{match_ignore_ascii_case, AtomicEnum}; use turso_parser::ast::fmt::ToTokens; use turso_parser::{ast, ast::Cmd, parser::Parser}; use types::IOResult; -pub use types::RefValue; pub use types::Value; +pub use types::ValueRef; use util::parse_schema_rows; pub use util::IOExt; pub use vdbe::{builder::QueryMode, explain::EXPLAIN_COLUMNS, explain::EXPLAIN_QUERY_PLAN_COLUMNS}; @@ -176,7 +180,7 @@ impl EncryptionOpts { pub type Result = std::result::Result; -#[derive(Clone, Copy, PartialEq, Eq, Debug)] +#[derive(Clone, AtomicEnum, Copy, PartialEq, Eq, Debug)] enum TransactionState { Write { schema_did_change: bool }, Read, @@ -184,7 +188,7 @@ enum TransactionState { None, } -#[derive(Clone, Copy, PartialEq, Eq, Debug)] +#[derive(Debug, AtomicEnum, Clone, Copy, PartialEq, Eq)] pub enum SyncMode { Off = 0, Full = 2, @@ -217,12 +221,14 @@ pub struct Database { shared_wal: Arc>, db_state: Arc, init_lock: Arc>, - open_flags: OpenFlags, + open_flags: Cell, builtin_syms: RwLock, opts: DatabaseOpts, n_connections: AtomicUsize, } +// SAFETY: This needs to be audited for thread safety. +// See: https://github.com/tursodatabase/turso/issues/1552 unsafe impl Send for Database {} unsafe impl Sync for Database {} @@ -231,7 +237,7 @@ impl fmt::Debug for Database { let mut debug_struct = f.debug_struct("Database"); debug_struct .field("path", &self.path) - .field("open_flags", &self.open_flags); + .field("open_flags", &self.open_flags.get()); // Database state information let db_state_value = match self.db_state.get() { @@ -468,7 +474,7 @@ impl Database { db_file, builtin_syms: syms.into(), io: io.clone(), - open_flags: flags, + open_flags: flags.into(), db_state: Arc::new(AtomicDbState::new(db_state)), init_lock: Arc::new(Mutex::new(())), opts, @@ -482,13 +488,13 @@ impl Database { let conn = db.connect()?; let syms = conn.syms.read(); - let pager = conn.pager.read().clone(); + let pager = conn.pager.load().clone(); if let Some(encryption_opts) = encryption_opts { conn.pragma_update("cipher", format!("'{}'", encryption_opts.cipher))?; conn.pragma_update("hexkey", format!("'{}'", encryption_opts.hexkey))?; // Clear page cache so the header page can be reread from disk and decrypted using the encryption context. - pager.clear_page_cache(); + pager.clear_page_cache(false); } db.with_schema_mut(|schema| { let header_schema_cookie = pager @@ -497,10 +503,7 @@ impl Database { schema.schema_version = header_schema_cookie; let result = schema .make_from_btree(None, pager.clone(), &syms) - .or_else(|e| { - pager.end_read_tx()?; - Err(e) - }); + .inspect_err(|_| pager.end_read_tx()); if let Err(LimboError::ExtensionError(e)) = result { // this means that a vtab exists and we no longer have the module loaded. we print // a warning to the user to load the module @@ -557,16 +560,11 @@ impl Database { .get(); let conn = Arc::new(Connection { db: self.clone(), - pager: RwLock::new(pager), - schema: RwLock::new( - self.schema - .lock() - .map_err(|_| LimboError::SchemaLocked)? - .clone(), - ), + pager: ArcSwap::new(pager), + schema: RwLock::new(self.schema.lock().unwrap().clone()), database_schemas: RwLock::new(std::collections::HashMap::new()), auto_commit: AtomicBool::new(true), - transaction_state: RwLock::new(TransactionState::None), + transaction_state: AtomicTransactionState::new(TransactionState::None), last_insert_rowid: AtomicI64::new(0), last_change: AtomicI64::new(0), total_changes: AtomicI64::new(0), @@ -584,11 +582,13 @@ impl Database { metrics: RwLock::new(ConnectionMetrics::new()), is_nested_stmt: AtomicBool::new(false), encryption_key: RwLock::new(None), - encryption_cipher_mode: RwLock::new(None), - sync_mode: RwLock::new(SyncMode::Full), + encryption_cipher_mode: AtomicCipherMode::new(CipherMode::None), + sync_mode: AtomicSyncMode::new(SyncMode::Full), data_sync_retry: AtomicBool::new(false), busy_timeout: RwLock::new(Duration::new(0, 0)), is_mvcc_bootstrap_connection: AtomicBool::new(is_mvcc_bootstrap_connection), + fk_pragma: AtomicBool::new(false), + fk_deferred_violations: AtomicIsize::new(0), }); self.n_connections .fetch_add(1, std::sync::atomic::Ordering::SeqCst); @@ -599,14 +599,14 @@ impl Database { } pub fn is_readonly(&self) -> bool { - self.open_flags.contains(OpenFlags::ReadOnly) + self.open_flags.get().contains(OpenFlags::ReadOnly) } /// If we do not have a physical WAL file, but we know the database file is initialized on disk, /// we need to read the page_size from the database header. fn read_page_size_from_db_header(&self) -> Result { turso_assert!( - self.db_state.is_initialized(), + self.db_state.get().is_initialized(), "read_page_size_from_db_header called on uninitialized database" ); turso_assert!( @@ -624,7 +624,7 @@ impl Database { fn read_reserved_space_bytes_from_db_header(&self) -> Result { turso_assert!( - self.db_state.is_initialized(), + self.db_state.get().is_initialized(), "read_reserved_space_bytes_from_db_header called on uninitialized database" ); turso_assert!( @@ -660,7 +660,7 @@ impl Database { return Ok(page_size); } } - if self.db_state.is_initialized() { + if self.db_state.get().is_initialized() { Ok(self.read_page_size_from_db_header()?) } else { let Some(size) = requested_page_size else { @@ -676,7 +676,7 @@ impl Database { /// if the database is initialized i.e. it exists on disk, return the reserved space bytes from /// the header or None fn maybe_get_reserved_space_bytes(&self) -> Result> { - if self.db_state.is_initialized() { + if self.db_state.get().is_initialized() { Ok(Some(self.read_reserved_space_bytes_from_db_header()?)) } else { Ok(None) @@ -698,7 +698,7 @@ impl Database { drop(shared_wal); let buffer_pool = self.buffer_pool.clone(); - if self.db_state.is_initialized() { + if self.db_state.get().is_initialized() { buffer_pool.finalize_with_page_size(page_size.get() as usize)?; } @@ -731,7 +731,7 @@ impl Database { let buffer_pool = self.buffer_pool.clone(); - if self.db_state.is_initialized() { + if self.db_state.get().is_initialized() { buffer_pool.finalize_with_page_size(page_size.get() as usize)?; } @@ -796,7 +796,7 @@ impl Database { None => match vfs.as_ref() { "memory" => Arc::new(MemoryIO::new()), "syscall" => Arc::new(SyscallIO::new()?), - #[cfg(all(target_os = "linux", feature = "io_uring"))] + #[cfg(all(target_os = "linux", feature = "io_uring", not(miri)))] "io_uring" => Arc::new(UringIO::new()?), other => { return Err(LimboError::InvalidArgument(format!("no such VFS: {other}"))); @@ -835,17 +835,17 @@ impl Database { #[inline] pub(crate) fn with_schema_mut(&self, f: impl FnOnce(&mut Schema) -> Result) -> Result { - let mut schema_ref = self.schema.lock().map_err(|_| LimboError::SchemaLocked)?; + let mut schema_ref = self.schema.lock().unwrap(); let schema = Arc::make_mut(&mut *schema_ref); f(schema) } - pub(crate) fn clone_schema(&self) -> Result> { - let schema = self.schema.lock().map_err(|_| LimboError::SchemaLocked)?; - Ok(schema.clone()) + pub(crate) fn clone_schema(&self) -> Arc { + let schema = self.schema.lock().unwrap(); + schema.clone() } - pub(crate) fn update_schema_if_newer(&self, another: Arc) -> Result<()> { - let mut schema = self.schema.lock().map_err(|_| LimboError::SchemaLocked)?; + pub(crate) fn update_schema_if_newer(&self, another: Arc) { + let mut schema = self.schema.lock().unwrap(); if schema.schema_version < another.schema_version { tracing::debug!( "DB schema is outdated: {} < {}", @@ -860,7 +860,6 @@ impl Database { another.schema_version ); } - Ok(()) } pub fn get_mv_store(&self) -> Option<&Arc> { @@ -1063,14 +1062,14 @@ impl DatabaseCatalog { pub struct Connection { db: Arc, - pager: RwLock>, + pager: ArcSwap, schema: RwLock>, /// Per-database schema cache (database_index -> schema) /// Loaded lazily to avoid copying all schemas on connection open database_schemas: RwLock>>, /// Whether to automatically commit transaction auto_commit: AtomicBool, - transaction_state: RwLock, + transaction_state: AtomicTransactionState, last_insert_rowid: AtomicI64, last_change: AtomicI64, total_changes: AtomicI64, @@ -1099,16 +1098,24 @@ pub struct Connection { /// Generally this is only true for ParseSchema. is_nested_stmt: AtomicBool, encryption_key: RwLock>, - encryption_cipher_mode: RwLock>, - sync_mode: RwLock, + encryption_cipher_mode: AtomicCipherMode, + sync_mode: AtomicSyncMode, data_sync_retry: AtomicBool, /// User defined max accumulated Busy timeout duration /// Default is 0 (no timeout) busy_timeout: RwLock, /// Whether this is an internal connection used for MVCC bootstrap is_mvcc_bootstrap_connection: AtomicBool, + /// Whether pragma foreign_keys=ON for this connection + fk_pragma: AtomicBool, + fk_deferred_violations: AtomicIsize, } +// SAFETY: This needs to be audited for thread safety. +// See: https://github.com/tursodatabase/turso/issues/1552 +unsafe impl Send for Connection {} +unsafe impl Sync for Connection {} + impl Drop for Connection { fn drop(&mut self) { if !self.is_closed() { @@ -1154,8 +1161,8 @@ impl Connection { let input = str::from_utf8(&sql.as_bytes()[..byte_offset_end]) .unwrap() .trim(); - self.maybe_update_schema()?; - let pager = self.pager.read().clone(); + self.maybe_update_schema(); + let pager = self.pager.load().clone(); let mode = QueryMode::new(&cmd); let (Cmd::Stmt(stmt) | Cmd::Explain(stmt) | Cmd::ExplainQueryPlan(stmt)) = cmd; let program = translate::translate( @@ -1186,7 +1193,7 @@ impl Connection { /// This function must be called outside of any transaction because internally it will start transaction session by itself #[allow(dead_code)] fn maybe_reparse_schema(self: &Arc) -> Result<()> { - let pager = self.pager.read().clone(); + let pager = self.pager.load().clone(); // first, quickly read schema_version from the root page in order to check if schema changed pager.begin_read_tx()?; @@ -1201,11 +1208,11 @@ impl Connection { 0 } Err(err) => { - pager.end_read_tx().expect("read txn must be finished"); + pager.end_read_tx(); return Err(err); } }; - pager.end_read_tx().expect("read txn must be finished"); + pager.end_read_tx(); let db_schema_version = self.db.schema.lock().unwrap().schema_version; tracing::debug!( @@ -1233,8 +1240,7 @@ impl Connection { let reparse_result = self.reparse_schema(); - let previous = - std::mem::replace(&mut *self.transaction_state.write(), TransactionState::None); + let previous = self.transaction_state.swap(TransactionState::None); turso_assert!( matches!(previous, TransactionState::None | TransactionState::Read), "unexpected end transaction state" @@ -1242,17 +1248,18 @@ impl Connection { // close opened transaction if it was kept open // (in most cases, it will be automatically closed if stmt was executed properly) if previous == TransactionState::Read { - pager.end_read_tx().expect("read txn must be finished"); + pager.end_read_tx(); } reparse_result?; let schema = self.schema.read().clone(); - self.db.update_schema_if_newer(schema) + self.db.update_schema_if_newer(schema); + Ok(()) } fn reparse_schema(self: &Arc) -> Result<()> { - let pager = self.pager.read().clone(); + let pager = self.pager.load().clone(); // read cookie before consuming statement program - otherwise we can end up reading cookie with closed transaction state let cookie = pager @@ -1303,13 +1310,13 @@ impl Connection { "The supplied SQL string contains no statements".to_string(), )); } - self.maybe_update_schema()?; + self.maybe_update_schema(); let sql = sql.as_ref(); tracing::trace!("Preparing and executing batch: {}", sql); let mut parser = Parser::new(sql.as_bytes()); while let Some(cmd) = parser.next_cmd()? { let syms = self.syms.read(); - let pager = self.pager.read().clone(); + let pager = self.pager.load().clone(); let byte_offset_end = parser.offset(); let input = str::from_utf8(&sql.as_bytes()[..byte_offset_end]) .unwrap() @@ -1337,7 +1344,7 @@ impl Connection { return Err(LimboError::InternalError("Connection closed".to_string())); } let sql = sql.as_ref(); - self.maybe_update_schema()?; + self.maybe_update_schema(); tracing::trace!("Querying: {}", sql); let mut parser = Parser::new(sql.as_bytes()); let cmd = parser.next_cmd()?; @@ -1361,7 +1368,7 @@ impl Connection { return Err(LimboError::InternalError("Connection closed".to_string())); } let syms = self.syms.read(); - let pager = self.pager.read().clone(); + let pager = self.pager.load().clone(); let mode = QueryMode::new(&cmd); let (Cmd::Stmt(stmt) | Cmd::Explain(stmt) | Cmd::ExplainQueryPlan(stmt)) = cmd; let program = translate::translate( @@ -1389,11 +1396,11 @@ impl Connection { return Err(LimboError::InternalError("Connection closed".to_string())); } let sql = sql.as_ref(); - self.maybe_update_schema()?; + self.maybe_update_schema(); let mut parser = Parser::new(sql.as_bytes()); while let Some(cmd) = parser.next_cmd()? { let syms = self.syms.read(); - let pager = self.pager.read().clone(); + let pager = self.pager.load().clone(); let byte_offset_end = parser.offset(); let input = str::from_utf8(&sql.as_bytes()[..byte_offset_end]) .unwrap() @@ -1422,7 +1429,7 @@ impl Connection { return Ok(None); }; let syms = self.syms.read(); - let pager = self.pager.read().clone(); + let pager = self.pager.load().clone(); let byte_offset_end = parser.offset(); let input = str::from_utf8(&sql.as_bytes()[..byte_offset_end]) .unwrap() @@ -1512,10 +1519,10 @@ impl Connection { if let Some(encryption_opts) = encryption_opts { let _ = conn.pragma_update("cipher", encryption_opts.cipher.to_string()); let _ = conn.pragma_update("hexkey", encryption_opts.hexkey.to_string()); - let pager = conn.pager.read(); - if db.db_state.is_initialized() { + let pager = conn.pager.load(); + if db.db_state.get().is_initialized() { // Clear page cache so the header page can be reread from disk and decrypted using the encryption context. - pager.clear_page_cache(); + pager.clear_page_cache(false); } } Ok((io, conn)) @@ -1527,9 +1534,7 @@ impl Connection { db_opts: DatabaseOpts, io: Arc, ) -> Result> { - let mut opts = OpenOptions::parse(uri)?; - // FIXME: for now, only support read only attach - opts.mode = OpenMode::ReadOnly; + let opts = OpenOptions::parse(uri)?; let flags = opts.get_flags()?; let io = opts.vfs.map(Database::io_for_vfs).unwrap_or(Ok(io))?; let db = Database::open_file_with_flags(io.clone(), &opts.path, flags, db_opts, None)?; @@ -1540,26 +1545,35 @@ impl Connection { Ok(db) } - pub fn maybe_update_schema(&self) -> Result<()> { + pub fn set_foreign_keys_enabled(&self, enable: bool) { + self.fk_pragma.store(enable, Ordering::Release); + } + + pub fn foreign_keys_enabled(&self) -> bool { + self.fk_pragma.load(Ordering::Acquire) + } + pub(crate) fn clear_deferred_foreign_key_violations(&self) -> isize { + self.fk_deferred_violations.swap(0, Ordering::Release) + } + + pub(crate) fn get_deferred_foreign_key_violations(&self) -> isize { + self.fk_deferred_violations.load(Ordering::Acquire) + } + + pub fn maybe_update_schema(&self) { let current_schema_version = self.schema.read().schema_version; - let schema = self - .db - .schema - .lock() - .map_err(|_| LimboError::SchemaLocked)?; + let schema = self.db.schema.lock().unwrap(); if matches!(self.get_tx_state(), TransactionState::None) && current_schema_version != schema.schema_version { *self.schema.write() = schema.clone(); } - - Ok(()) } /// Read schema version at current transaction #[cfg(all(feature = "fs", feature = "conn_raw_api"))] pub fn read_schema_version(&self) -> Result { - let pager = self.pager.read(); + let pager = self.pager.load(); pager .io .block(|| pager.with_header(|header| header.schema_cookie)) @@ -1577,16 +1591,16 @@ impl Connection { "write_schema_version must be called from within Write transaction".to_string(), )); }; - let pager = self.pager.read(); + let pager = self.pager.load(); pager.io.block(|| { pager.with_header_mut(|header| { turso_assert!( header.schema_cookie.get() < version, "cookie can't go back in time" ); - *self.transaction_state.write() = TransactionState::Write { + self.set_tx_state(TransactionState::Write { schema_did_change: true, - }; + }); self.with_schema_mut(|schema| schema.schema_version = version); header.schema_cookie = version.into(); }) @@ -1604,7 +1618,7 @@ impl Connection { page: &mut [u8], frame_watermark: Option, ) -> Result { - let pager = self.pager.read(); + let pager = self.pager.load(); let (page_ref, c) = match pager.read_page_no_cache(page_idx as i64, frame_watermark, true) { Ok(result) => result, // on windows, zero read will trigger UnexpectedEof @@ -1630,19 +1644,19 @@ impl Connection { /// (so, if concurrent connection wrote something to the WAL - this method will not see this change) #[cfg(all(feature = "fs", feature = "conn_raw_api"))] pub fn wal_changed_pages_after(&self, frame_watermark: u64) -> Result> { - self.pager.read().wal_changed_pages_after(frame_watermark) + self.pager.load().wal_changed_pages_after(frame_watermark) } #[cfg(all(feature = "fs", feature = "conn_raw_api"))] pub fn wal_state(&self) -> Result { - self.pager.read().wal_state() + self.pager.load().wal_state() } #[cfg(all(feature = "fs", feature = "conn_raw_api"))] pub fn wal_get_frame(&self, frame_no: u64, frame: &mut [u8]) -> Result { use crate::storage::sqlite3_ondisk::parse_wal_frame_header; - let c = self.pager.read().wal_get_frame(frame_no, frame)?; + let c = self.pager.load().wal_get_frame(frame_no, frame)?; self.db.io.wait_for_completion(c)?; let (header, _) = parse_wal_frame_header(frame); Ok(WalFrameInfo { @@ -1656,22 +1670,22 @@ impl Connection { /// If attempt to write frame at the position `frame_no` will create gap in the WAL - method will return error #[cfg(all(feature = "fs", feature = "conn_raw_api"))] pub fn wal_insert_frame(&self, frame_no: u64, frame: &[u8]) -> Result { - self.pager.read().wal_insert_frame(frame_no, frame) + self.pager.load().wal_insert_frame(frame_no, frame) } /// Start WAL session by initiating read+write transaction for this connection #[cfg(all(feature = "fs", feature = "conn_raw_api"))] pub fn wal_insert_begin(&self) -> Result<()> { - let pager = self.pager.read(); + let pager = self.pager.load(); pager.begin_read_tx()?; pager.io.block(|| pager.begin_write_tx()).inspect_err(|_| { - pager.end_read_tx().expect("read txn must be closed"); + pager.end_read_tx(); })?; // start write transaction and disable auto-commit mode as SQL can be executed within WAL session (at caller own risk) - *self.transaction_state.write() = TransactionState::Write { + self.set_tx_state(TransactionState::Write { schema_did_change: false, - }; + }); self.auto_commit.store(false, Ordering::SeqCst); Ok(()) @@ -1682,7 +1696,7 @@ impl Connection { #[cfg(all(feature = "fs", feature = "conn_raw_api"))] pub fn wal_insert_end(self: &Arc, force_commit: bool) -> Result<()> { { - let pager = self.pager.read(); + let pager = self.pager.load(); let Some(wal) = pager.wal.as_ref() else { return Err(LimboError::InternalError( @@ -1713,13 +1727,11 @@ impl Connection { wal.end_read_tx(); } - let rollback_err = if !force_commit { + if !force_commit { // remove all non-commited changes in case if WAL session left some suffix without commit frame - pager.rollback(false, self, true).err() - } else { - None - }; - if let Some(err) = commit_err.or(rollback_err) { + pager.rollback(false, self, true); + } + if let Some(err) = commit_err { return Err(err); } } @@ -1734,19 +1746,14 @@ impl Connection { if self.is_closed() { return Err(LimboError::InternalError("Connection closed".to_string())); } - self.pager.read().cacheflush() - } - - pub fn clear_page_cache(&self) -> Result<()> { - self.pager.read().clear_page_cache(); - Ok(()) + self.pager.load().cacheflush() } pub fn checkpoint(&self, mode: CheckpointMode) -> Result { if self.is_closed() { return Err(LimboError::InternalError("Connection closed".to_string())); } - self.pager.read().wal_checkpoint(mode) + self.pager.load().wal_checkpoint(mode) } /// Close a connection and checkpoint. @@ -1762,13 +1769,8 @@ impl Connection { } _ => { if !self.mvcc_enabled() { - let pager = self.pager.read(); - pager.io.block(|| { - pager.end_tx( - true, // rollback = true for close - self, - ) - })?; + let pager = self.pager.load(); + pager.rollback_tx(self); } self.set_tx_state(TransactionState::None); } @@ -1781,7 +1783,7 @@ impl Connection { .eq(&1) { self.pager - .read() + .load() .checkpoint_shutdown(self.is_wal_auto_checkpoint_disabled())?; }; Ok(()) @@ -1890,11 +1892,11 @@ impl Connection { shared_wal.enabled.store(false, Ordering::SeqCst); shared_wal.file = None; } - self.pager.write().clear_page_cache(); + self.pager.load().clear_page_cache(false); let pager = self.db.init_pager(Some(size.get() as usize))?; pager.enable_encryption(self.db.opts.enable_encryption); - *self.pager.write() = Arc::new(pager); - self.pager.read().set_initial_page_size(size); + self.pager.store(Arc::new(pager)); + self.pager.load().set_initial_page_size(size); Ok(()) } @@ -2028,12 +2030,12 @@ impl Connection { } pub fn is_db_initialized(&self) -> bool { - self.db.db_state.is_initialized() + self.db.db_state.get().is_initialized() } fn get_pager_from_database_index(&self, index: &usize) -> Arc { if *index < 2 { - self.pager.read().clone() + self.pager.load().clone() } else { self.attached_databases.read().get_pager_by_index(index) } @@ -2075,12 +2077,7 @@ impl Connection { ))); } - let use_indexes = self - .db - .schema - .lock() - .map_err(|_| LimboError::SchemaLocked)? - .indexes_enabled(); + let use_indexes = self.db.schema.lock().unwrap().indexes_enabled(); let use_mvcc = self.db.mv_store.is_some(); let use_views = self.db.experimental_views_enabled(); let use_strict = self.db.experimental_strict_enabled(); @@ -2090,9 +2087,15 @@ impl Connection { .with_indexes(use_indexes) .with_views(use_views) .with_strict(use_strict); - let db = Self::from_uri_attached(path, db_opts, self.db.io.clone())?; + let io: Arc = if path.contains(":memory:") { + Arc::new(MemoryIO::new()) + } else { + Arc::new(PlatformIO::new()?) + }; + let db = Self::from_uri_attached(path, db_opts, io)?; let pager = Arc::new(db.init_pager(None)?); - + // FIXME: for now, only support read only attach + db.open_flags.set(OpenFlags::ReadOnly); self.attached_databases.write().insert(alias, (db, pager)); Ok(()) @@ -2245,7 +2248,7 @@ impl Connection { } pub fn get_pager(&self) -> Arc { - self.pager.read().clone() + self.pager.load().clone() } pub fn get_query_only(&self) -> bool { @@ -2257,11 +2260,11 @@ impl Connection { } pub fn get_sync_mode(&self) -> SyncMode { - *self.sync_mode.read() + self.sync_mode.get() } pub fn set_sync_mode(&self, mode: SyncMode) { - *self.sync_mode.write() = mode; + self.sync_mode.set(mode); } pub fn get_data_sync_retry(&self) -> bool { @@ -2287,18 +2290,21 @@ impl Connection { pub fn set_encryption_cipher(&self, cipher_mode: CipherMode) -> Result<()> { tracing::trace!("setting encryption cipher for connection"); - *self.encryption_cipher_mode.write() = Some(cipher_mode); + self.encryption_cipher_mode.set(cipher_mode); self.set_encryption_context() } pub fn set_reserved_bytes(&self, reserved_bytes: u8) -> Result<()> { - let pager = self.pager.read(); + let pager = self.pager.load(); pager.set_reserved_space_bytes(reserved_bytes); Ok(()) } pub fn get_encryption_cipher_mode(&self) -> Option { - *self.encryption_cipher_mode.read() + match self.encryption_cipher_mode.get() { + CipherMode::None => None, + mode => Some(mode), + } } // if both key and cipher are set, set encryption context on pager @@ -2307,12 +2313,12 @@ impl Connection { let Some(key) = key_guard.as_ref() else { return Ok(()); }; - let cipher_guard = self.encryption_cipher_mode.read(); - let Some(cipher_mode) = *cipher_guard else { + let cipher_mode = self.get_encryption_cipher_mode(); + let Some(cipher_mode) = cipher_mode else { return Ok(()); }; tracing::trace!("setting encryption ctx for connection"); - let pager = self.pager.read(); + let pager = self.pager.load(); if pager.is_encryption_ctx_set() { return Err(LimboError::InvalidArgument( "cannot reset encryption attributes if already set in the session".to_string(), @@ -2346,11 +2352,11 @@ impl Connection { } fn set_tx_state(&self, state: TransactionState) { - *self.transaction_state.write() = state; + self.transaction_state.set(state); } fn get_tx_state(&self) -> TransactionState { - *self.transaction_state.read() + self.transaction_state.get() } pub(crate) fn get_mv_tx_id(&self) -> Option { @@ -2360,6 +2366,23 @@ impl Connection { pub(crate) fn get_mv_tx(&self) -> Option<(u64, TransactionMode)> { *self.mv_tx.read() } + + pub(crate) fn set_mvcc_checkpoint_threshold(&self, threshold: i64) -> Result<()> { + match self.db.mv_store.as_ref() { + Some(mv_store) => { + mv_store.set_checkpoint_threshold(threshold); + Ok(()) + } + None => Err(LimboError::InternalError("MVCC not enabled".into())), + } + } + + pub(crate) fn mvcc_checkpoint_threshold(&self) -> Result { + match self.db.mv_store.as_ref() { + Some(mv_store) => Ok(mv_store.checkpoint_threshold()), + None => Err(LimboError::InternalError("MVCC not enabled".into())), + } + } } #[derive(Debug)] @@ -2426,7 +2449,7 @@ impl BusyTimeout { } } - self.iteration += 1; + self.iteration = self.iteration.saturating_add(1); self.timeout = now + delay; } } @@ -2449,6 +2472,12 @@ pub struct Statement { busy_timeout: Option, } +impl Drop for Statement { + fn drop(&mut self) { + self.reset(); + } +} + impl Statement { pub fn new( program: vdbe::Program, @@ -2492,10 +2521,13 @@ impl Statement { self.state.interrupt(); } - pub fn step(&mut self) -> Result { + fn _step(&mut self, waker: Option<&Waker>) -> Result { if let Some(busy_timeout) = self.busy_timeout.as_ref() { if self.pager.io.now() < busy_timeout.timeout { // Yield the query as the timeout has not been reached yet + if let Some(waker) = waker { + waker.wake_by_ref(); + } return Ok(StepResult::IO); } } @@ -2506,6 +2538,7 @@ impl Statement { self.mv_store.as_ref(), self.pager.clone(), self.query_mode, + waker, ) } else { const MAX_SCHEMA_RETRY: usize = 50; @@ -2514,6 +2547,7 @@ impl Statement { self.mv_store.as_ref(), self.pager.clone(), self.query_mode, + waker, ); for attempt in 0..MAX_SCHEMA_RETRY { // Only reprepare if we still need to update schema @@ -2527,6 +2561,7 @@ impl Statement { self.mv_store.as_ref(), self.pager.clone(), self.query_mode, + waker, ); } res @@ -2557,6 +2592,9 @@ impl Statement { }; if now < self.busy_timeout.as_ref().unwrap().timeout { + if let Some(waker) = waker { + waker.wake_by_ref(); + } res = Ok(StepResult::IO); } } @@ -2564,6 +2602,14 @@ impl Statement { res } + pub fn step(&mut self) -> Result { + self._step(None) + } + + pub fn step_with_waker(&mut self, waker: &Waker) -> Result { + self._step(Some(waker)) + } + pub(crate) fn run_ignore_rows(&mut self) -> Result<()> { loop { match self.step()? { @@ -2598,7 +2644,8 @@ impl Statement { fn reprepare(&mut self) -> Result<()> { tracing::trace!("repreparing statement"); let conn = self.program.connection.clone(); - *conn.schema.write() = conn.db.clone_schema()?; + + *conn.schema.write() = conn.db.clone_schema(); self.program = { let mut parser = Parser::new(self.program.sql.as_bytes()); let cmd = parser.next_cmd()?; @@ -2626,7 +2673,7 @@ impl Statement { QueryMode::Explain => (EXPLAIN_COLUMNS.len(), 0), QueryMode::ExplainQueryPlan => (EXPLAIN_QUERY_PLAN_COLUMNS.len(), 0), }; - self._reset(Some(max_registers), Some(cursor_count)); + self.reset_internal(Some(max_registers), Some(cursor_count)); // Load the parameters back into the state self.state.parameters = parameters; Ok(()) @@ -2648,12 +2695,8 @@ impl Statement { } let state = self.program.connection.get_tx_state(); if let TransactionState::Write { .. } = state { - let end_tx_res = self.pager.end_tx(true, &self.program.connection)?; + self.pager.rollback_tx(&self.program.connection); self.program.connection.set_tx_state(TransactionState::None); - assert!( - matches!(end_tx_res, IOResult::Done(_)), - "end_tx should not return IO as it should just end txn without flushing anything. Got {end_tx_res:?}" - ); } } res @@ -2668,6 +2711,17 @@ impl Statement { } pub fn get_column_name(&self, idx: usize) -> Cow<'_, str> { + if self.query_mode == QueryMode::Explain { + return Cow::Owned(EXPLAIN_COLUMNS.get(idx).expect("No column").to_string()); + } + if self.query_mode == QueryMode::ExplainQueryPlan { + return Cow::Owned( + EXPLAIN_QUERY_PLAN_COLUMNS + .get(idx) + .expect("No column") + .to_string(), + ); + } match self.query_mode { QueryMode::Normal => { let column = &self.program.result_columns.get(idx).expect("No column"); @@ -2686,6 +2740,9 @@ impl Statement { } pub fn get_column_table_name(&self, idx: usize) -> Option> { + if self.query_mode == QueryMode::Explain || self.query_mode == QueryMode::ExplainQueryPlan { + return None; + } let column = &self.program.result_columns.get(idx).expect("No column"); match &column.expr { turso_parser::ast::Expr::Column { table, .. } => self @@ -2698,6 +2755,22 @@ impl Statement { } pub fn get_column_type(&self, idx: usize) -> Option { + if self.query_mode == QueryMode::Explain { + return Some( + EXPLAIN_COLUMNS_TYPE + .get(idx) + .expect("No column") + .to_string(), + ); + } + if self.query_mode == QueryMode::ExplainQueryPlan { + return Some( + EXPLAIN_QUERY_PLAN_COLUMNS_TYPE + .get(idx) + .expect("No column") + .to_string(), + ); + } let column = &self.program.result_columns.get(idx).expect("No column"); match &column.expr { turso_parser::ast::Expr::Column { @@ -2744,10 +2817,17 @@ impl Statement { } pub fn reset(&mut self) { - self._reset(None, None); + self.reset_internal(None, None); } - pub fn _reset(&mut self, max_registers: Option, max_cursors: Option) { + fn reset_internal(&mut self, max_registers: Option, max_cursors: Option) { + // as abort uses auto_txn_cleanup value - it needs to be called before state.reset + self.program.abort( + self.mv_store.as_ref(), + &self.pager, + None, + &mut self.state.auto_txn_cleanup, + ); self.state.reset(max_registers, max_cursors); self.busy = false; self.busy_timeout = None; diff --git a/core/mvcc/cursor.rs b/core/mvcc/cursor.rs index 0709b2641..b9e0ce136 100644 --- a/core/mvcc/cursor.rs +++ b/core/mvcc/cursor.rs @@ -1,8 +1,13 @@ +use parking_lot::RwLock; + use crate::mvcc::clock::LogicalClock; use crate::mvcc::database::{MVTableId, MvStore, Row, RowID}; -use crate::types::{IOResult, SeekKey, SeekOp, SeekResult}; -use crate::Result; +use crate::storage::btree::{BTreeCursor, BTreeKey, CursorTrait}; +use crate::types::{IOResult, ImmutableRecord, RecordCursor, SeekKey, SeekOp, SeekResult}; +use crate::{return_if_io, Result}; use crate::{Pager, Value}; +use std::any::Any; +use std::cell::{Ref, RefCell}; use std::fmt::Debug; use std::ops::Bound; use std::sync::Arc; @@ -16,75 +21,49 @@ enum CursorPosition { /// We have reached the end of the table. End, } -#[derive(Debug)] + pub struct MvccLazyCursor { pub db: Arc>, - current_pos: CursorPosition, + current_pos: RefCell, pub table_id: MVTableId, tx_id: u64, + /// Reusable immutable record, used to allow better allocation strategy. + reusable_immutable_record: RefCell>, + _btree_cursor: Box, + null_flag: bool, + record_cursor: RefCell, + next_rowid_lock: Arc>, } -impl MvccLazyCursor { +impl MvccLazyCursor { pub fn new( db: Arc>, tx_id: u64, root_page_or_table_id: i64, pager: Arc, + btree_cursor: Box, ) -> Result> { + assert!( + (&*btree_cursor as &dyn Any).is::(), + "BTreeCursor expected for mvcc cursor" + ); let table_id = db.get_table_id_from_root_page(root_page_or_table_id); db.maybe_initialize_table(table_id, pager)?; - let cursor = Self { + Ok(Self { db, tx_id, - current_pos: CursorPosition::BeforeFirst, + current_pos: RefCell::new(CursorPosition::BeforeFirst), table_id, - }; - Ok(cursor) + reusable_immutable_record: RefCell::new(None), + _btree_cursor: btree_cursor, + null_flag: false, + record_cursor: RefCell::new(RecordCursor::new()), + next_rowid_lock: Arc::new(RwLock::new(())), + }) } - /// Insert a row into the table. - /// Sets the cursor to the inserted row. - pub fn insert(&mut self, row: Row) -> Result<()> { - self.current_pos = CursorPosition::Loaded(row.id); - if self.db.read(self.tx_id, row.id)?.is_some() { - self.db.update(self.tx_id, row).inspect_err(|_| { - self.current_pos = CursorPosition::BeforeFirst; - })?; - } else { - self.db.insert(self.tx_id, row).inspect_err(|_| { - self.current_pos = CursorPosition::BeforeFirst; - })?; - } - Ok(()) - } - - pub fn delete(&mut self, rowid: RowID) -> Result<()> { - self.db.delete(self.tx_id, rowid)?; - Ok(()) - } - - pub fn current_row_id(&mut self) -> Option { - match self.current_pos { - CursorPosition::Loaded(id) => Some(id), - CursorPosition::BeforeFirst => { - // If we are before first, we need to try and find the first row. - let maybe_rowid = - self.db - .get_next_row_id_for_table(self.table_id, i64::MIN, self.tx_id); - if let Some(id) = maybe_rowid { - self.current_pos = CursorPosition::Loaded(id); - Some(id) - } else { - self.current_pos = CursorPosition::BeforeFirst; - None - } - } - CursorPosition::End => None, - } - } - - pub fn current_row(&mut self) -> Result> { - match self.current_pos { + pub fn current_row(&self) -> Result> { + match *self.current_pos.borrow() { CursorPosition::Loaded(id) => self.db.read(self.tx_id, id), CursorPosition::BeforeFirst => { // If we are before first, we need to try and find the first row. @@ -92,7 +71,7 @@ impl MvccLazyCursor { self.db .get_next_row_id_for_table(self.table_id, i64::MIN, self.tx_id); if let Some(id) = maybe_rowid { - self.current_pos = CursorPosition::Loaded(id); + self.current_pos.replace(CursorPosition::Loaded(id)); self.db.read(self.tx_id, id) } else { Ok(None) @@ -106,19 +85,61 @@ impl MvccLazyCursor { Ok(()) } + pub fn get_next_rowid(&mut self) -> i64 { + // lock so we don't get same two rowids + let lock = self.next_rowid_lock.clone(); + let _lock = lock.write(); + let _ = self.last(); + match *self.current_pos.borrow() { + CursorPosition::Loaded(id) => id.row_id + 1, + CursorPosition::BeforeFirst => 1, + CursorPosition::End => i64::MAX, + } + } + + fn get_immutable_record_or_create(&self) -> std::cell::RefMut<'_, Option> { + let mut reusable_immutable_record = self.reusable_immutable_record.borrow_mut(); + if reusable_immutable_record.is_none() { + let record = ImmutableRecord::new(1024); + reusable_immutable_record.replace(record); + } + reusable_immutable_record + } + + fn get_current_pos(&self) -> CursorPosition { + *self.current_pos.borrow() + } +} + +impl CursorTrait for MvccLazyCursor { + fn last(&mut self) -> Result> { + let last_rowid = self.db.get_last_rowid(self.table_id); + if let Some(last_rowid) = last_rowid { + self.current_pos.replace(CursorPosition::Loaded(RowID { + table_id: self.table_id, + row_id: last_rowid, + })); + } else { + self.current_pos.replace(CursorPosition::BeforeFirst); + } + self.invalidate_record(); + Ok(IOResult::Done(())) + } + /// Move the cursor to the next row. Returns true if the cursor moved to the next row, false if the cursor is at the end of the table. - pub fn forward(&mut self) -> bool { - let before_first = matches!(self.current_pos, CursorPosition::BeforeFirst); - let min_id = match self.current_pos { + fn next(&mut self) -> Result> { + let before_first = matches!(self.get_current_pos(), CursorPosition::BeforeFirst); + let min_id = match *self.current_pos.borrow() { CursorPosition::Loaded(id) => id.row_id + 1, // TODO: do we need to forward twice? CursorPosition::BeforeFirst => i64::MIN, // we need to find first row, so we look from the first id, CursorPosition::End => { // let's keep same state, we reached the end so no point in moving forward. - return false; + return Ok(IOResult::Done(false)); } }; - self.current_pos = + + let new_position = match self .db .get_next_row_id_for_table(self.table_id, min_id, self.tx_id) @@ -134,46 +155,60 @@ impl MvccLazyCursor { } } }; - matches!(self.current_pos, CursorPosition::Loaded(_)) + self.current_pos.replace(new_position); + self.invalidate_record(); + + Ok(IOResult::Done(matches!( + self.get_current_pos(), + CursorPosition::Loaded(_) + ))) } - /// Returns true if the is not pointing to any row. - pub fn is_empty(&self) -> bool { - // If we reached the end of the table, it means we traversed the whole table therefore there must be something in the table. - // If we have loaded a row, it means there is something in the table. - match self.current_pos { - CursorPosition::Loaded(_) => false, - CursorPosition::BeforeFirst => true, - CursorPosition::End => true, + fn prev(&mut self) -> Result> { + todo!() + } + + fn rowid(&self) -> Result>> { + let rowid = match self.get_current_pos() { + CursorPosition::Loaded(id) => Some(id.row_id), + CursorPosition::BeforeFirst => { + // If we are before first, we need to try and find the first row. + let maybe_rowid = + self.db + .get_next_row_id_for_table(self.table_id, i64::MIN, self.tx_id); + if let Some(id) = maybe_rowid { + self.current_pos.replace(CursorPosition::Loaded(id)); + Some(id.row_id) + } else { + self.current_pos.replace(CursorPosition::BeforeFirst); + None + } + } + CursorPosition::End => None, + }; + Ok(IOResult::Done(rowid)) + } + + fn record( + &self, + ) -> Result>>> { + let Some(row) = self.current_row()? else { + return Ok(IOResult::Done(None)); + }; + + { + let mut record = self.get_immutable_record_or_create(); + let record = record.as_mut().unwrap(); + record.invalidate(); + record.start_serialization(&row.data); } + + let record_ref = + Ref::filter_map(self.reusable_immutable_record.borrow(), |opt| opt.as_ref()).unwrap(); + Ok(IOResult::Done(Some(record_ref))) } - pub fn rewind(&mut self) { - self.current_pos = CursorPosition::BeforeFirst; - } - - pub fn last(&mut self) { - let last_rowid = self.db.get_last_rowid(self.table_id); - if let Some(last_rowid) = last_rowid { - self.current_pos = CursorPosition::Loaded(RowID { - table_id: self.table_id, - row_id: last_rowid, - }); - } else { - self.current_pos = CursorPosition::BeforeFirst; - } - } - - pub fn get_next_rowid(&mut self) -> i64 { - self.last(); - match self.current_pos { - CursorPosition::Loaded(id) => id.row_id + 1, - CursorPosition::BeforeFirst => 1, - CursorPosition::End => i64::MAX, - } - } - - pub fn seek(&mut self, seek_key: SeekKey<'_>, op: SeekOp) -> Result> { + fn seek(&mut self, seek_key: SeekKey<'_>, op: SeekOp) -> Result> { let row_id = match seek_key { SeekKey::TableRowId(row_id) => row_id, SeekKey::IndexKey(_) => { @@ -194,9 +229,10 @@ impl MvccLazyCursor { SeekOp::LT => (Bound::Excluded(&rowid), false), SeekOp::LE { eq_only: _ } => (Bound::Included(&rowid), false), }; + self.invalidate_record(); let rowid = self.db.seek_rowid(bound, lower_bound, self.tx_id); if let Some(rowid) = rowid { - self.current_pos = CursorPosition::Loaded(rowid); + self.current_pos.replace(CursorPosition::Loaded(rowid)); if op.eq_only() { if rowid.row_id == row_id { Ok(IOResult::Done(SeekResult::Found)) @@ -209,36 +245,193 @@ impl MvccLazyCursor { } else { let forwards = matches!(op, SeekOp::GE { eq_only: _ } | SeekOp::GT); if forwards { - self.last(); + let _ = self.last()?; } else { - self.rewind(); + let _ = self.rewind()?; } Ok(IOResult::Done(SeekResult::NotFound)) } } - pub fn exists(&mut self, key: &Value) -> Result> { + /// Insert a row into the table. + /// Sets the cursor to the inserted row. + fn insert(&mut self, key: &BTreeKey) -> Result> { + let Some(rowid) = key.maybe_rowid() else { + todo!() + }; + let row_id = RowID::new(self.table_id, rowid); + let record_buf = key.get_record().unwrap().get_payload().to_vec(); + let num_columns = match key { + BTreeKey::IndexKey(record) => record.column_count(), + BTreeKey::TableRowId((_, record)) => record.as_ref().unwrap().column_count(), + }; + let row = crate::mvcc::database::Row::new(row_id, record_buf, num_columns); + + self.current_pos.replace(CursorPosition::Loaded(row.id)); + if self.db.read(self.tx_id, row.id)?.is_some() { + self.db.update(self.tx_id, row).inspect_err(|_| { + self.current_pos.replace(CursorPosition::BeforeFirst); + })?; + } else { + self.db.insert(self.tx_id, row).inspect_err(|_| { + self.current_pos.replace(CursorPosition::BeforeFirst); + })?; + } + self.invalidate_record(); + Ok(IOResult::Done(())) + } + + fn delete(&mut self) -> Result> { + let IOResult::Done(Some(rowid)) = self.rowid()? else { + todo!(); + }; + let rowid = RowID::new(self.table_id, rowid); + self.db.delete(self.tx_id, rowid)?; + self.invalidate_record(); + Ok(IOResult::Done(())) + } + + fn set_null_flag(&mut self, flag: bool) { + self.null_flag = flag; + } + + fn get_null_flag(&self) -> bool { + self.null_flag + } + + fn exists(&mut self, key: &Value) -> Result> { + self.invalidate_record(); let int_key = match key { Value::Integer(i) => i, _ => unreachable!("btree tables are indexed by integers!"), }; - let exists = self - .db - .seek_rowid( - Bound::Included(&RowID { - table_id: self.table_id, - row_id: *int_key, - }), - true, - self.tx_id, - ) - .is_some(); - if exists { - self.current_pos = CursorPosition::Loaded(RowID { + let rowid = self.db.seek_rowid( + Bound::Included(&RowID { table_id: self.table_id, row_id: *int_key, - }); + }), + true, + self.tx_id, + ); + tracing::trace!("found {rowid:?}"); + let exists = if let Some(rowid) = rowid { + rowid.row_id == *int_key + } else { + false + }; + if exists { + self.current_pos.replace(CursorPosition::Loaded(RowID { + table_id: self.table_id, + row_id: *int_key, + })); } Ok(IOResult::Done(exists)) } + + fn clear_btree(&mut self) -> Result>> { + todo!() + } + + fn btree_destroy(&mut self) -> Result>> { + todo!() + } + + fn count(&mut self) -> Result> { + todo!() + } + + /// Returns true if the is not pointing to any row. + fn is_empty(&self) -> bool { + // If we reached the end of the table, it means we traversed the whole table therefore there must be something in the table. + // If we have loaded a row, it means there is something in the table. + match self.get_current_pos() { + CursorPosition::Loaded(_) => false, + CursorPosition::BeforeFirst => true, + CursorPosition::End => true, + } + } + + fn root_page(&self) -> i64 { + self.table_id.into() + } + + fn rewind(&mut self) -> Result> { + self.invalidate_record(); + if !matches!(self.get_current_pos(), CursorPosition::BeforeFirst) { + self.current_pos.replace(CursorPosition::BeforeFirst); + } + // Next will set cursor position to a valid position if it exists, otherwise it will set it to one that doesn't exist. + let _ = return_if_io!(self.next()); + Ok(IOResult::Done(())) + } + + fn has_record(&self) -> bool { + todo!() + } + + fn set_has_record(&self, _has_record: bool) { + todo!() + } + + fn get_index_info(&self) -> &crate::types::IndexInfo { + todo!() + } + + fn seek_end(&mut self) -> Result> { + todo!() + } + + fn seek_to_last(&mut self) -> Result> { + self.invalidate_record(); + let max_rowid = RowID { + table_id: self.table_id, + row_id: i64::MAX, + }; + let bound = Bound::Included(&max_rowid); + let lower_bound = false; + + let rowid = self.db.seek_rowid(bound, lower_bound, self.tx_id); + if let Some(rowid) = rowid { + self.current_pos.replace(CursorPosition::Loaded(rowid)); + } else { + self.current_pos.replace(CursorPosition::End); + } + Ok(IOResult::Done(())) + } + + fn invalidate_record(&mut self) { + self.get_immutable_record_or_create() + .as_mut() + .unwrap() + .invalidate(); + self.record_cursor.borrow_mut().invalidate(); + } + + fn has_rowid(&self) -> bool { + todo!() + } + + fn record_cursor_mut(&self) -> std::cell::RefMut<'_, crate::types::RecordCursor> { + self.record_cursor.borrow_mut() + } + + fn get_pager(&self) -> Arc { + todo!() + } + + fn get_skip_advance(&self) -> bool { + todo!() + } +} + +impl Debug for MvccLazyCursor { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MvccLazyCursor") + .field("current_pos", &self.current_pos) + .field("table_id", &self.table_id) + .field("tx_id", &self.tx_id) + .field("reusable_immutable_record", &self.reusable_immutable_record) + .field("btree_cursor", &()) + .finish() + } } diff --git a/core/mvcc/database/checkpoint_state_machine.rs b/core/mvcc/database/checkpoint_state_machine.rs index 344dab0d8..712cc8048 100644 --- a/core/mvcc/database/checkpoint_state_machine.rs +++ b/core/mvcc/database/checkpoint_state_machine.rs @@ -4,13 +4,13 @@ use crate::mvcc::database::{ SQLITE_SCHEMA_MVCC_TABLE_ID, }; use crate::state_machine::{StateMachine, StateTransition, TransitionResult}; -use crate::storage::btree::BTreeCursor; +use crate::storage::btree::{BTreeCursor, CursorTrait}; use crate::storage::pager::CreateBTreeFlags; use crate::storage::wal::{CheckpointMode, TursoRwLock}; use crate::types::{IOCompletions, IOResult, ImmutableRecord, RecordCursor}; use crate::{ - CheckpointResult, Completion, Connection, IOExt, Pager, RefValue, Result, TransactionState, - Value, + CheckpointResult, Completion, Connection, IOExt, Pager, Result, TransactionState, Value, + ValueRef, }; use parking_lot::RwLock; use std::collections::{HashMap, HashSet}; @@ -164,92 +164,66 @@ impl CheckpointStateMachine { // 2. A checkpointed table that was destroyed in the logical log. We need to destroy the btree in the pager/btree layer. continue; } + let row_versions = entry.value().read(); + + let mut version_to_checkpoint = None; let mut exists_in_db_file = false; - for (i, version) in row_versions.iter().enumerate() { - let is_last = i == row_versions.len() - 1; - if let TxTimestampOrID::Timestamp(ts) = &version.begin { - if *ts <= self.checkpointed_txid_max_old { + for version in row_versions.iter() { + if let Some(TxTimestampOrID::Timestamp(ts)) = version.begin { + //TODO: garbage collect row versions after checkpointing. + if ts > self.checkpointed_txid_max_old { + version_to_checkpoint = Some(version); + } else { exists_in_db_file = true; } + } + } - let current_version_ts = - if let Some(TxTimestampOrID::Timestamp(ts_end)) = version.end { - ts_end.max(*ts) - } else { - *ts - }; - if current_version_ts <= self.checkpointed_txid_max_old { - // already checkpointed. TODO: garbage collect row versions after checkpointing. - continue; - } + if let Some(version) = version_to_checkpoint { + let is_delete = version.end.is_some(); + if let Some(TxTimestampOrID::Timestamp(ts)) = version.begin { + max_timestamp = max_timestamp.max(ts); + } - // Row versions in sqlite_schema are temporarily assigned a negative root page that is equal to the table id, - // because the root page is not known until it's actually allocated during the checkpoint. - // However, existing tables have a real root page. - let get_table_id_or_root_page_from_sqlite_schema = |row_data: &Vec| { - let row_data = ImmutableRecord::from_bin_record(row_data.clone()); + // Only write the row to the B-tree if it is not a delete, or if it is a delete and it exists in + // the database file. + if !is_delete || exists_in_db_file { + let mut special_write = None; + + if version.row.id.table_id == SQLITE_SCHEMA_MVCC_TABLE_ID { + let row_data = ImmutableRecord::from_bin_record(version.row.data.clone()); let mut record_cursor = RecordCursor::new(); record_cursor.parse_full_header(&row_data).unwrap(); - let RefValue::Integer(root_page) = + if let ValueRef::Integer(root_page) = record_cursor.get_value(&row_data, 3).unwrap() - else { - panic!( - "Expected integer value for root page, got {:?}", - record_cursor.get_value(&row_data, 3) - ); - }; - root_page - }; + { + if is_delete { + let table_id = self + .mvstore + .table_id_to_rootpage + .iter() + .find(|entry| { + entry.value().is_some_and(|r| r == root_page as u64) + }) + .map(|entry| *entry.key()) + .unwrap(); // This assumes a valid mapping exists. + self.destroyed_tables.insert(table_id); - max_timestamp = max_timestamp.max(current_version_ts); - if is_last { - let is_delete = version.end.is_some(); - let is_delete_of_table = - is_delete && version.row.id.table_id == SQLITE_SCHEMA_MVCC_TABLE_ID; - let is_create_of_table = !exists_in_db_file - && !is_delete - && version.row.id.table_id == SQLITE_SCHEMA_MVCC_TABLE_ID; - - // We might need to create or destroy a B-tree in the pager during checkpoint if a row in root page 1 is deleted or created. - let special_write = if is_delete_of_table { - let root_page = - get_table_id_or_root_page_from_sqlite_schema(&version.row.data); - assert!(root_page > 0, "rootpage is positive integer"); - let root_page = root_page as u64; - let table_id = *self - .mvstore - .table_id_to_rootpage - .iter() - .find(|entry| entry.value().is_some_and(|r| r == root_page)) - .unwrap() - .key(); - self.destroyed_tables.insert(table_id); - - if exists_in_db_file { - Some(SpecialWrite::BTreeDestroy { + // We might need to create or destroy a B-tree in the pager during checkpoint if a row in root page 1 is deleted or created. + special_write = Some(SpecialWrite::BTreeDestroy { table_id, - root_page, + root_page: root_page as u64, num_columns: version.row.column_count, - }) - } else { - None + }); + } else if !exists_in_db_file { + let table_id = MVTableId::from(root_page); + special_write = Some(SpecialWrite::BTreeCreate { table_id }); } - } else if is_create_of_table { - let table_id = - get_table_id_or_root_page_from_sqlite_schema(&version.row.data); - let table_id = MVTableId::from(table_id); - Some(SpecialWrite::BTreeCreate { table_id }) - } else { - None - }; - - // Only write the row to the B-tree if it is not a delete, or if it is a delete and it exists in the database file. - let should_be_deleted_from_db_file = is_delete && exists_in_db_file; - if !is_delete || should_be_deleted_from_db_file { - self.write_set.push((version.clone(), special_write)); } } + + self.write_set.push((version.clone(), special_write)); } } } @@ -351,9 +325,9 @@ impl CheckpointStateMachine { } result?; if self.update_transaction_state { - *self.connection.transaction_state.write() = TransactionState::Write { + self.connection.set_tx_state(TransactionState::Write { schema_did_change: false, - }; // TODO: schema_did_change?? + }); // TODO: schema_did_change?? } self.lock_states.pager_write_tx = true; self.state = CheckpointState::WriteRow { @@ -416,7 +390,6 @@ impl CheckpointStateMachine { cursor.clone() } else { let cursor = BTreeCursor::new_table( - None, self.pager.clone(), known_root_page as i64, num_columns, @@ -491,12 +464,8 @@ impl CheckpointStateMachine { let cursor = if let Some(cursor) = self.cursors.get(&root_page) { cursor.clone() } else { - let cursor = BTreeCursor::new_table( - None, // Write directly to B-tree - self.pager.clone(), - root_page as i64, - num_columns, - ); + let cursor = + BTreeCursor::new_table(self.pager.clone(), root_page as i64, num_columns); let cursor = Arc::new(RwLock::new(cursor)); self.cursors.insert(root_page, cursor.clone()); cursor @@ -558,14 +527,14 @@ impl CheckpointStateMachine { CheckpointState::CommitPagerTxn => { tracing::debug!("Committing pager transaction"); - let result = self.pager.end_tx(false, &self.connection)?; + let result = self.pager.commit_tx(&self.connection)?; match result { IOResult::Done(_) => { self.state = CheckpointState::TruncateLogicalLog; self.lock_states.pager_read_tx = false; self.lock_states.pager_write_tx = false; if self.update_transaction_state { - *self.connection.transaction_state.write() = TransactionState::None; + self.connection.set_tx_state(TransactionState::None); } let header = self .pager @@ -652,18 +621,14 @@ impl StateTransition for CheckpointStateMachine { Err(err) => { tracing::info!("Error in checkpoint state machine: {err}"); if self.lock_states.pager_write_tx { - let rollback = true; - self.pager - .io - .block(|| self.pager.end_tx(rollback, self.connection.as_ref())) - .expect("failed to end pager write tx"); + self.pager.rollback_tx(self.connection.as_ref()); if self.update_transaction_state { - *self.connection.transaction_state.write() = TransactionState::None; + self.connection.set_tx_state(TransactionState::None); } } else if self.lock_states.pager_read_tx { - self.pager.end_read_tx().unwrap(); + self.pager.end_read_tx(); if self.update_transaction_state { - *self.connection.transaction_state.write() = TransactionState::None; + self.connection.set_tx_state(TransactionState::None); } } if self.lock_states.blocking_checkpoint_lock_held { diff --git a/core/mvcc/database/mod.rs b/core/mvcc/database/mod.rs index 3c3b4aaff..0057e56d8 100644 --- a/core/mvcc/database/mod.rs +++ b/core/mvcc/database/mod.rs @@ -5,6 +5,7 @@ use crate::state_machine::StateTransition; use crate::state_machine::TransitionResult; use crate::storage::btree::BTreeCursor; use crate::storage::btree::BTreeKey; +use crate::storage::btree::CursorTrait; use crate::storage::btree::CursorValidState; use crate::storage::sqlite3_ondisk::DatabaseHeader; use crate::storage::wal::TursoRwLock; @@ -18,11 +19,11 @@ use crate::Completion; use crate::File; use crate::IOExt; use crate::LimboError; -use crate::RefValue; use crate::Result; use crate::Statement; use crate::StepResult; use crate::Value; +use crate::ValueRef; use crate::{Connection, Pager}; use crossbeam_skiplist::{SkipMap, SkipSet}; use parking_lot::RwLock; @@ -115,9 +116,10 @@ impl Row { } /// A row version. +/// TODO: we can optimize this by using bitpacking for the begin and end fields. #[derive(Clone, Debug, PartialEq)] pub struct RowVersion { - pub begin: TxTimestampOrID, + pub begin: Option, pub end: Option, pub row: Row, } @@ -402,7 +404,7 @@ impl CommitStateMachine { commit_coordinator: Arc, header: Arc>>, ) -> Self { - let pager = connection.pager.read().clone(); + let pager = connection.pager.load().clone(); Self { state, is_finalized: false, @@ -572,11 +574,11 @@ impl StateTransition for CommitStateMachine { if let Some(row_versions) = mvcc_store.rows.get(id) { let mut row_versions = row_versions.value().write(); for row_version in row_versions.iter_mut() { - if let TxTimestampOrID::TxID(id) = row_version.begin { + if let Some(TxTimestampOrID::TxID(id)) = row_version.begin { if id == self.tx_id { // New version is valid STARTING FROM committing transaction's end timestamp // See diagram on page 299: https://www.cs.cmu.edu/~15721-f24/papers/Hekaton.pdf - row_version.begin = TxTimestampOrID::Timestamp(*end_ts); + row_version.begin = Some(TxTimestampOrID::Timestamp(*end_ts)); mvcc_store.insert_version_raw( &mut log_record.row_versions, row_version.clone(), @@ -625,7 +627,7 @@ impl StateTransition for CommitStateMachine { let locked = self.commit_coordinator.pager_commit_lock.write(); if !locked { return Ok(TransitionResult::Io(IOCompletions::Single( - Completion::new_dummy(), + Completion::new_yield(), ))); } } @@ -654,7 +656,7 @@ impl StateTransition for CommitStateMachine { let schema_did_change = self.did_commit_schema_change; if schema_did_change { let schema = connection.schema.read().clone(); - connection.db.update_schema_if_newer(schema)?; + connection.db.update_schema_if_newer(schema); } let tx = mvcc_store.txs.get(&self.tx_id).unwrap(); let tx_unlocked = tx.value(); @@ -1045,7 +1047,7 @@ impl MvStore { self.insert_table_id_to_rootpage(root_page_as_table_id, Some(*root_page)); } - if !self.maybe_recover_logical_log(bootstrap_conn.pager.read().clone())? { + if !self.maybe_recover_logical_log(bootstrap_conn.pager.load().clone())? { // There was no logical log to recover, so we're done. return Ok(()); } @@ -1091,7 +1093,7 @@ impl MvStore { assert_eq!(tx.state, TransactionState::Active); let id = row.id; let row_version = RowVersion { - begin: TxTimestampOrID::TxID(tx.tx_id), + begin: Some(TxTimestampOrID::TxID(tx.tx_id)), end: None, row, }; @@ -1354,7 +1356,7 @@ impl MvStore { &self, pager: Arc, maybe_existing_tx_id: Option, - ) -> Result> { + ) -> Result { if !self.blocking_checkpoint_lock.read() { // If there is a stop-the-world checkpoint in progress, we cannot begin any transaction at all. return Err(LimboError::Busy); @@ -1390,7 +1392,7 @@ impl MvStore { ); tracing::debug!("begin_exclusive_tx: tx_id={} succeeded", tx_id); self.txs.insert(tx_id, tx); - Ok(IOResult::Done(tx_id)) + Ok(tx_id) } /// Begins a new transaction in the database. @@ -1547,14 +1549,14 @@ impl MvStore { // Hekaton uses oldest-to-newest order for row versions, so we reverse iterate to find the newest one // this transaction changed. for row_version in row_versions.iter_mut().rev() { - if let TxTimestampOrID::TxID(id) = row_version.begin { + if let Some(TxTimestampOrID::TxID(id)) = row_version.begin { turso_assert!( id == tx_id, "only one tx(0) should exist on loading logical log" ); // New version is valid STARTING FROM committing transaction's end timestamp // See diagram on page 299: https://www.cs.cmu.edu/~15721-f24/papers/Hekaton.pdf - row_version.begin = TxTimestampOrID::Timestamp(end_ts); + row_version.begin = Some(TxTimestampOrID::Timestamp(end_ts)); } if let Some(TxTimestampOrID::TxID(id)) = row_version.end { turso_assert!( @@ -1578,39 +1580,37 @@ impl MvStore { /// # Arguments /// /// * `tx_id` - The ID of the transaction to abort. - pub fn rollback_tx( - &self, - tx_id: TxID, - _pager: Arc, - connection: &Connection, - ) -> Result<()> { + pub fn rollback_tx(&self, tx_id: TxID, _pager: Arc, connection: &Connection) { let tx_unlocked = self.txs.get(&tx_id).unwrap(); let tx = tx_unlocked.value(); *connection.mv_tx.write() = None; assert!(tx.state == TransactionState::Active || tx.state == TransactionState::Preparing); tx.state.store(TransactionState::Aborted); tracing::trace!("abort(tx_id={})", tx_id); - let write_set: Vec = tx.write_set.iter().map(|v| *v.value()).collect(); if self.is_exclusive_tx(&tx_id) { self.commit_coordinator.pager_commit_lock.unlock(); self.release_exclusive_tx(&tx_id); } - for ref id in write_set { - if let Some(row_versions) = self.rows.get(id) { + for rowid in &tx.write_set { + let rowid = rowid.value(); + if let Some(row_versions) = self.rows.get(rowid) { let mut row_versions = row_versions.value().write(); for rv in row_versions.iter_mut() { - if rv.end == Some(TxTimestampOrID::TxID(tx_id)) { + if let Some(TxTimestampOrID::TxID(id)) = rv.begin { + assert_eq!(id, tx_id); + // If the transaction has aborted, + // it marks all its new versions as garbage and sets their Begin + // and End timestamps to infinity to make them invisible + // See section 2.4: https://www.cs.cmu.edu/~15721-f24/papers/Hekaton.pdf + rv.begin = None; + rv.end = None; + } else if rv.end == Some(TxTimestampOrID::TxID(tx_id)) { // undo deletions by this transaction rv.end = None; } } - // remove insertions by this transaction - row_versions.retain(|rv| rv.begin != TxTimestampOrID::TxID(tx_id)); - if row_versions.is_empty() { - self.rows.remove(id); - } } } @@ -1618,7 +1618,7 @@ impl MvStore { > connection.db.schema.lock().unwrap().schema_version { // Connection made schema changes during tx and rolled back -> revert connection-local schema. - *connection.schema.write() = connection.db.clone_schema()?; + *connection.schema.write() = connection.db.clone_schema(); } let tx = tx_unlocked.value(); @@ -1627,8 +1627,6 @@ impl MvStore { // FIXME: verify that we can already remove the transaction here! // Maybe it's fine for snapshot isolation, but too early for serializable? self.remove_tx(tx_id); - - Ok(()) } /// Returns true if the given transaction is the exclusive transaction. @@ -1759,10 +1757,20 @@ impl MvStore { // Extracts the begin timestamp from a transaction #[inline] - fn get_begin_timestamp(&self, ts_or_id: &TxTimestampOrID) -> u64 { + fn get_begin_timestamp(&self, ts_or_id: &Option) -> u64 { match ts_or_id { - TxTimestampOrID::Timestamp(ts) => *ts, - TxTimestampOrID::TxID(tx_id) => self.txs.get(tx_id).unwrap().value().begin_ts, + Some(TxTimestampOrID::Timestamp(ts)) => *ts, + Some(TxTimestampOrID::TxID(tx_id)) => self.txs.get(tx_id).unwrap().value().begin_ts, + // This function is intended to be used in the ordering of row versions within the row version chain in `insert_version_raw`. + // + // The row version chain should be append-only (aside from garbage collection), + // so the specific ordering handled by this function may not be critical. We might + // be able to append directly to the row version chain in the future. + // + // The value 0 is used here to represent an infinite timestamp value. This is a deliberate + // choice for a planned future bitpacking optimization, reserving 0 for this purpose, + // while actual timestamps will start from 1. + None => 0, } } @@ -1876,7 +1884,6 @@ impl MvStore { .value() .unwrap_or_else(|| panic!("Table ID does not have a root page: {table_id}")); let mut cursor = BTreeCursor::new_table( - None, // No MVCC cursor for scanning pager.clone(), root_page as i64, 1, // We'll adjust this as needed @@ -1914,7 +1921,7 @@ impl MvStore { self.insert_version( id, RowVersion { - begin: TxTimestampOrID::Timestamp(0), + begin: Some(TxTimestampOrID::Timestamp(0)), end: None, row: Row::new(id, record.get_payload().to_vec(), column_count), }, @@ -1985,7 +1992,7 @@ impl MvStore { let record = ImmutableRecord::from_bin_record(row_data); let mut record_cursor = RecordCursor::new(); let record_values = record_cursor.get_values(&record).unwrap(); - let RefValue::Integer(root_page) = record_values[3] else { + let ValueRef::Integer(root_page) = record_values[3] else { panic!( "Expected integer value for root page, got {:?}", record_values[3] @@ -2023,7 +2030,7 @@ impl MvStore { } StreamingResult::Eof => { // Set offset to the end so that next writes go to the end of the file - self.storage.logical_log.write().unwrap().offset = reader.offset as u64; + self.storage.logical_log.write().offset = reader.offset as u64; break; } } @@ -2032,9 +2039,13 @@ impl MvStore { Ok(true) } - pub fn set_checkpoint_threshold(&self, threshold: u64) { + pub fn set_checkpoint_threshold(&self, threshold: i64) { self.storage.set_checkpoint_threshold(threshold) } + + pub fn checkpoint_threshold(&self) -> i64 { + self.storage.checkpoint_threshold() + } } /// A write-write conflict happens when transaction T_current attempts to update a @@ -2080,8 +2091,8 @@ impl RowVersion { fn is_begin_visible(txs: &SkipMap, tx: &Transaction, rv: &RowVersion) -> bool { match rv.begin { - TxTimestampOrID::Timestamp(rv_begin_ts) => tx.begin_ts >= rv_begin_ts, - TxTimestampOrID::TxID(rv_begin) => { + Some(TxTimestampOrID::Timestamp(rv_begin_ts)) => tx.begin_ts >= rv_begin_ts, + Some(TxTimestampOrID::TxID(rv_begin)) => { let tb = txs.get(&rv_begin).unwrap(); let tb = tb.value(); let visible = match tb.state.load() { @@ -2101,6 +2112,7 @@ fn is_begin_visible(txs: &SkipMap, tx: &Transaction, rv: &Row ); visible } + None => false, } } diff --git a/core/mvcc/database/tests.rs b/core/mvcc/database/tests.rs index a5cd43c50..0af15e370 100644 --- a/core/mvcc/database/tests.rs +++ b/core/mvcc/database/tests.rs @@ -115,13 +115,17 @@ pub(crate) fn generate_simple_string_row(table_id: MVTableId, id: i64, data: &st } } +pub(crate) fn generate_simple_string_record(data: &str) -> ImmutableRecord { + ImmutableRecord::from_values(&[Value::Text(Text::new(data))], 1) +} + #[test] fn test_insert_read() { let db = MvccTestDb::new(); let tx1 = db .mvcc_store - .begin_tx(db.conn.pager.read().clone()) + .begin_tx(db.conn.pager.load().clone()) .unwrap(); let tx1_row = generate_simple_string_row((-2).into(), 1, "Hello"); db.mvcc_store.insert(tx1, tx1_row.clone()).unwrap(); @@ -141,7 +145,7 @@ fn test_insert_read() { let tx2 = db .mvcc_store - .begin_tx(db.conn.pager.read().clone()) + .begin_tx(db.conn.pager.load().clone()) .unwrap(); let row = db .mvcc_store @@ -162,7 +166,7 @@ fn test_read_nonexistent() { let db = MvccTestDb::new(); let tx = db .mvcc_store - .begin_tx(db.conn.pager.read().clone()) + .begin_tx(db.conn.pager.load().clone()) .unwrap(); let row = db.mvcc_store.read( tx, @@ -180,7 +184,7 @@ fn test_delete() { let tx1 = db .mvcc_store - .begin_tx(db.conn.pager.read().clone()) + .begin_tx(db.conn.pager.load().clone()) .unwrap(); let tx1_row = generate_simple_string_row((-2).into(), 1, "Hello"); db.mvcc_store.insert(tx1, tx1_row.clone()).unwrap(); @@ -220,7 +224,7 @@ fn test_delete() { let tx2 = db .mvcc_store - .begin_tx(db.conn.pager.read().clone()) + .begin_tx(db.conn.pager.load().clone()) .unwrap(); let row = db .mvcc_store @@ -240,7 +244,7 @@ fn test_delete_nonexistent() { let db = MvccTestDb::new(); let tx = db .mvcc_store - .begin_tx(db.conn.pager.read().clone()) + .begin_tx(db.conn.pager.load().clone()) .unwrap(); assert!(!db .mvcc_store @@ -259,7 +263,7 @@ fn test_commit() { let db = MvccTestDb::new(); let tx1 = db .mvcc_store - .begin_tx(db.conn.pager.read().clone()) + .begin_tx(db.conn.pager.load().clone()) .unwrap(); let tx1_row = generate_simple_string_row((-2).into(), 1, "Hello"); db.mvcc_store.insert(tx1, tx1_row.clone()).unwrap(); @@ -293,7 +297,7 @@ fn test_commit() { let tx2 = db .mvcc_store - .begin_tx(db.conn.pager.read().clone()) + .begin_tx(db.conn.pager.load().clone()) .unwrap(); let row = db .mvcc_store @@ -316,7 +320,7 @@ fn test_rollback() { let db = MvccTestDb::new(); let tx1 = db .mvcc_store - .begin_tx(db.conn.pager.read().clone()) + .begin_tx(db.conn.pager.load().clone()) .unwrap(); let row1 = generate_simple_string_row((-2).into(), 1, "Hello"); db.mvcc_store.insert(tx1, row1.clone()).unwrap(); @@ -347,11 +351,10 @@ fn test_rollback() { .unwrap(); assert_eq!(row3, row4); db.mvcc_store - .rollback_tx(tx1, db.conn.pager.read().clone(), &db.conn) - .unwrap(); + .rollback_tx(tx1, db.conn.pager.load().clone(), &db.conn); let tx2 = db .mvcc_store - .begin_tx(db.conn.pager.read().clone()) + .begin_tx(db.conn.pager.load().clone()) .unwrap(); let row5 = db .mvcc_store @@ -373,7 +376,7 @@ fn test_dirty_write() { // T1 inserts a row with ID 1, but does not commit. let tx1 = db .mvcc_store - .begin_tx(db.conn.pager.read().clone()) + .begin_tx(db.conn.pager.load().clone()) .unwrap(); let tx1_row = generate_simple_string_row((-2).into(), 1, "Hello"); db.mvcc_store.insert(tx1, tx1_row.clone()).unwrap(); @@ -392,7 +395,7 @@ fn test_dirty_write() { let conn2 = db.db.connect().unwrap(); // T2 attempts to delete row with ID 1, but fails because T1 has not committed. - let tx2 = db.mvcc_store.begin_tx(conn2.pager.read().clone()).unwrap(); + let tx2 = db.mvcc_store.begin_tx(conn2.pager.load().clone()).unwrap(); let tx2_row = generate_simple_string_row((-2).into(), 1, "World"); assert!(!db.mvcc_store.update(tx2, tx2_row).unwrap()); @@ -417,14 +420,14 @@ fn test_dirty_read() { // T1 inserts a row with ID 1, but does not commit. let tx1 = db .mvcc_store - .begin_tx(db.conn.pager.read().clone()) + .begin_tx(db.conn.pager.load().clone()) .unwrap(); let row1 = generate_simple_string_row((-2).into(), 1, "Hello"); db.mvcc_store.insert(tx1, row1).unwrap(); // T2 attempts to read row with ID 1, but doesn't see one because T1 has not committed. let conn2 = db.db.connect().unwrap(); - let tx2 = db.mvcc_store.begin_tx(conn2.pager.read().clone()).unwrap(); + let tx2 = db.mvcc_store.begin_tx(conn2.pager.load().clone()).unwrap(); let row2 = db .mvcc_store .read( @@ -445,7 +448,7 @@ fn test_dirty_read_deleted() { // T1 inserts a row with ID 1 and commits. let tx1 = db .mvcc_store - .begin_tx(db.conn.pager.read().clone()) + .begin_tx(db.conn.pager.load().clone()) .unwrap(); let tx1_row = generate_simple_string_row((-2).into(), 1, "Hello"); db.mvcc_store.insert(tx1, tx1_row.clone()).unwrap(); @@ -453,7 +456,7 @@ fn test_dirty_read_deleted() { // T2 deletes row with ID 1, but does not commit. let conn2 = db.db.connect().unwrap(); - let tx2 = db.mvcc_store.begin_tx(conn2.pager.read().clone()).unwrap(); + let tx2 = db.mvcc_store.begin_tx(conn2.pager.load().clone()).unwrap(); assert!(db .mvcc_store .delete( @@ -467,7 +470,7 @@ fn test_dirty_read_deleted() { // T3 reads row with ID 1, but doesn't see the delete because T2 hasn't committed. let conn3 = db.db.connect().unwrap(); - let tx3 = db.mvcc_store.begin_tx(conn3.pager.read().clone()).unwrap(); + let tx3 = db.mvcc_store.begin_tx(conn3.pager.load().clone()).unwrap(); let row = db .mvcc_store .read( @@ -489,7 +492,7 @@ fn test_fuzzy_read() { // T1 inserts a row with ID 1 and commits. let tx1 = db .mvcc_store - .begin_tx(db.conn.pager.read().clone()) + .begin_tx(db.conn.pager.load().clone()) .unwrap(); let tx1_row = generate_simple_string_row((-2).into(), 1, "First"); db.mvcc_store.insert(tx1, tx1_row.clone()).unwrap(); @@ -509,7 +512,7 @@ fn test_fuzzy_read() { // T2 reads the row with ID 1 within an active transaction. let conn2 = db.db.connect().unwrap(); - let tx2 = db.mvcc_store.begin_tx(conn2.pager.read().clone()).unwrap(); + let tx2 = db.mvcc_store.begin_tx(conn2.pager.load().clone()).unwrap(); let row = db .mvcc_store .read( @@ -525,7 +528,7 @@ fn test_fuzzy_read() { // T3 updates the row and commits. let conn3 = db.db.connect().unwrap(); - let tx3 = db.mvcc_store.begin_tx(conn3.pager.read().clone()).unwrap(); + let tx3 = db.mvcc_store.begin_tx(conn3.pager.load().clone()).unwrap(); let tx3_row = generate_simple_string_row((-2).into(), 1, "Second"); db.mvcc_store.update(tx3, tx3_row).unwrap(); commit_tx(db.mvcc_store.clone(), &conn3, tx3).unwrap(); @@ -558,7 +561,7 @@ fn test_lost_update() { // T1 inserts a row with ID 1 and commits. let tx1 = db .mvcc_store - .begin_tx(db.conn.pager.read().clone()) + .begin_tx(db.conn.pager.load().clone()) .unwrap(); let tx1_row = generate_simple_string_row((-2).into(), 1, "Hello"); db.mvcc_store.insert(tx1, tx1_row.clone()).unwrap(); @@ -578,13 +581,13 @@ fn test_lost_update() { // T2 attempts to update row ID 1 within an active transaction. let conn2 = db.db.connect().unwrap(); - let tx2 = db.mvcc_store.begin_tx(conn2.pager.read().clone()).unwrap(); + let tx2 = db.mvcc_store.begin_tx(conn2.pager.load().clone()).unwrap(); let tx2_row = generate_simple_string_row((-2).into(), 1, "World"); assert!(db.mvcc_store.update(tx2, tx2_row.clone()).unwrap()); // T3 also attempts to update row ID 1 within an active transaction. let conn3 = db.db.connect().unwrap(); - let tx3 = db.mvcc_store.begin_tx(conn3.pager.read().clone()).unwrap(); + let tx3 = db.mvcc_store.begin_tx(conn3.pager.load().clone()).unwrap(); let tx3_row = generate_simple_string_row((-2).into(), 1, "Hello, world!"); assert!(matches!( db.mvcc_store.update(tx3, tx3_row), @@ -592,8 +595,7 @@ fn test_lost_update() { )); // hack: in the actual tursodb database we rollback the mvcc tx ourselves, so manually roll it back here db.mvcc_store - .rollback_tx(tx3, conn3.pager.read().clone(), &conn3) - .unwrap(); + .rollback_tx(tx3, conn3.pager.load().clone(), &conn3); commit_tx(db.mvcc_store.clone(), &conn2, tx2).unwrap(); assert!(matches!( @@ -602,7 +604,7 @@ fn test_lost_update() { )); let conn4 = db.db.connect().unwrap(); - let tx4 = db.mvcc_store.begin_tx(conn4.pager.read().clone()).unwrap(); + let tx4 = db.mvcc_store.begin_tx(conn4.pager.load().clone()).unwrap(); let row = db .mvcc_store .read( @@ -626,7 +628,7 @@ fn test_committed_visibility() { // let's add $10 to my account since I like money let tx1 = db .mvcc_store - .begin_tx(db.conn.pager.read().clone()) + .begin_tx(db.conn.pager.load().clone()) .unwrap(); let tx1_row = generate_simple_string_row((-2).into(), 1, "10"); db.mvcc_store.insert(tx1, tx1_row.clone()).unwrap(); @@ -634,7 +636,7 @@ fn test_committed_visibility() { // but I like more money, so let me try adding $10 more let conn2 = db.db.connect().unwrap(); - let tx2 = db.mvcc_store.begin_tx(conn2.pager.read().clone()).unwrap(); + let tx2 = db.mvcc_store.begin_tx(conn2.pager.load().clone()).unwrap(); let tx2_row = generate_simple_string_row((-2).into(), 1, "20"); assert!(db.mvcc_store.update(tx2, tx2_row.clone()).unwrap()); let row = db @@ -652,7 +654,7 @@ fn test_committed_visibility() { // can I check how much money I have? let conn3 = db.db.connect().unwrap(); - let tx3 = db.mvcc_store.begin_tx(conn3.pager.read().clone()).unwrap(); + let tx3 = db.mvcc_store.begin_tx(conn3.pager.load().clone()).unwrap(); let row = db .mvcc_store .read( @@ -674,11 +676,11 @@ fn test_future_row() { let tx1 = db .mvcc_store - .begin_tx(db.conn.pager.read().clone()) + .begin_tx(db.conn.pager.load().clone()) .unwrap(); let conn2 = db.db.connect().unwrap(); - let tx2 = db.mvcc_store.begin_tx(conn2.pager.read().clone()).unwrap(); + let tx2 = db.mvcc_store.begin_tx(conn2.pager.load().clone()).unwrap(); let tx2_row = generate_simple_string_row((-2).into(), 1, "Hello"); db.mvcc_store.insert(tx2, tx2_row).unwrap(); @@ -716,7 +718,7 @@ use crate::types::Text; use crate::Value; use crate::{Database, StepResult}; use crate::{MemoryIO, Statement}; -use crate::{RefValue, DATABASE_MANAGER}; +use crate::{ValueRef, DATABASE_MANAGER}; // Simple atomic clock implementation for testing @@ -724,7 +726,7 @@ fn setup_test_db() -> (MvccTestDb, u64) { let db = MvccTestDb::new(); let tx_id = db .mvcc_store - .begin_tx(db.conn.pager.read().clone()) + .begin_tx(db.conn.pager.load().clone()) .unwrap(); let table_id = MVTableId::new(-1); @@ -747,7 +749,7 @@ fn setup_test_db() -> (MvccTestDb, u64) { let tx_id = db .mvcc_store - .begin_tx(db.conn.pager.read().clone()) + .begin_tx(db.conn.pager.load().clone()) .unwrap(); (db, tx_id) } @@ -756,7 +758,7 @@ fn setup_lazy_db(initial_keys: &[i64]) -> (MvccTestDb, u64) { let db = MvccTestDb::new(); let tx_id = db .mvcc_store - .begin_tx(db.conn.pager.read().clone()) + .begin_tx(db.conn.pager.load().clone()) .unwrap(); let table_id = -1; @@ -772,7 +774,7 @@ fn setup_lazy_db(initial_keys: &[i64]) -> (MvccTestDb, u64) { let tx_id = db .mvcc_store - .begin_tx(db.conn.pager.read().clone()) + .begin_tx(db.conn.pager.load().clone()) .unwrap(); (db, tx_id) } @@ -827,19 +829,27 @@ fn test_lazy_scan_cursor_basic() { db.mvcc_store.clone(), tx_id, table_id, - db.conn.pager.read().clone(), + db.conn.pager.load().clone(), + Box::new(BTreeCursor::new(db.conn.pager.load().clone(), table_id, 1)), ) .unwrap(); // Check first row - assert!(cursor.forward()); + assert!(matches!(cursor.next().unwrap(), IOResult::Done(true))); assert!(!cursor.is_empty()); let row = cursor.current_row().unwrap().unwrap(); assert_eq!(row.id.row_id, 1); // Iterate through all rows let mut count = 1; - while cursor.forward() { + loop { + let res = cursor.next().unwrap(); + let IOResult::Done(res) = res else { + panic!("unexpected next result {res:?}"); + }; + if !res { + break; + } count += 1; let row = cursor.current_row().unwrap().unwrap(); assert_eq!(row.id.row_id, count); @@ -849,7 +859,7 @@ fn test_lazy_scan_cursor_basic() { assert_eq!(count, 5); // After the last row, is_empty should return true - assert!(!cursor.forward()); + assert!(!matches!(cursor.next().unwrap(), IOResult::Done(true))); assert!(cursor.is_empty()); } @@ -862,12 +872,13 @@ fn test_lazy_scan_cursor_with_gaps() { db.mvcc_store.clone(), tx_id, table_id, - db.conn.pager.read().clone(), + db.conn.pager.load().clone(), + Box::new(BTreeCursor::new(db.conn.pager.load().clone(), table_id, 1)), ) .unwrap(); // Check first row - assert!(cursor.forward()); + assert!(matches!(cursor.next().unwrap(), IOResult::Done(true))); assert!(!cursor.is_empty()); let row = cursor.current_row().unwrap().unwrap(); assert_eq!(row.id.row_id, 5); @@ -876,12 +887,27 @@ fn test_lazy_scan_cursor_with_gaps() { let expected_ids = [5, 10, 15, 20, 30]; let mut index = 0; - assert_eq!(cursor.current_row_id().unwrap().row_id, expected_ids[index]); + let IOResult::Done(rowid) = cursor.rowid().unwrap() else { + unreachable!(); + }; + let rowid = rowid.unwrap(); + assert_eq!(rowid, expected_ids[index]); - while cursor.forward() { + loop { + let res = cursor.next().unwrap(); + let IOResult::Done(res) = res else { + panic!("unexpected next result {res:?}"); + }; + if !res { + break; + } index += 1; if index < expected_ids.len() { - assert_eq!(cursor.current_row_id().unwrap().row_id, expected_ids[index]); + let IOResult::Done(rowid) = cursor.rowid().unwrap() else { + unreachable!(); + }; + let rowid = rowid.unwrap(); + assert_eq!(rowid, expected_ids[index]); } } @@ -898,11 +924,12 @@ fn test_cursor_basic() { db.mvcc_store.clone(), tx_id, table_id, - db.conn.pager.read().clone(), + db.conn.pager.load().clone(), + Box::new(BTreeCursor::new(db.conn.pager.load().clone(), table_id, 1)), ) .unwrap(); - cursor.forward(); + let _ = cursor.next().unwrap(); // Check first row assert!(!cursor.is_empty()); @@ -911,7 +938,14 @@ fn test_cursor_basic() { // Iterate through all rows let mut count = 1; - while cursor.forward() { + loop { + let res = cursor.next().unwrap(); + let IOResult::Done(res) = res else { + panic!("unexpected next result {res:?}"); + }; + if !res { + break; + } count += 1; let row = cursor.current_row().unwrap().unwrap(); assert_eq!(row.id.row_id, count); @@ -921,7 +955,7 @@ fn test_cursor_basic() { assert_eq!(count, 5); // After the last row, is_empty should return true - assert!(!cursor.forward()); + assert!(!matches!(cursor.next().unwrap(), IOResult::Done(true))); assert!(cursor.is_empty()); } @@ -930,26 +964,28 @@ fn test_cursor_with_empty_table() { let db = MvccTestDb::new(); { // FIXME: force page 1 initialization - let pager = db.conn.pager.read().clone(); + let pager = db.conn.pager.load().clone(); let tx_id = db.mvcc_store.begin_tx(pager.clone()).unwrap(); commit_tx(db.mvcc_store.clone(), &db.conn, tx_id).unwrap(); } let tx_id = db .mvcc_store - .begin_tx(db.conn.pager.read().clone()) + .begin_tx(db.conn.pager.load().clone()) .unwrap(); let table_id = -1; // Empty table // Test LazyScanCursor with empty table - let mut cursor = MvccLazyCursor::new( + let cursor = MvccLazyCursor::new( db.mvcc_store.clone(), tx_id, table_id, - db.conn.pager.read().clone(), + db.conn.pager.load().clone(), + Box::new(BTreeCursor::new(db.conn.pager.load().clone(), table_id, 1)), ) .unwrap(); assert!(cursor.is_empty()); - assert!(cursor.current_row_id().is_none()); + let rowid = cursor.rowid().unwrap(); + assert!(matches!(rowid, IOResult::Done(None))); } #[test] @@ -961,34 +997,37 @@ fn test_cursor_modification_during_scan() { db.mvcc_store.clone(), tx_id, table_id, - db.conn.pager.read().clone(), + db.conn.pager.load().clone(), + Box::new(BTreeCursor::new(db.conn.pager.load().clone(), table_id, 1)), ) .unwrap(); // Read first row - assert!(cursor.forward()); + assert!(matches!(cursor.next().unwrap(), IOResult::Done(true))); let first_row = cursor.current_row().unwrap().unwrap(); assert_eq!(first_row.id.row_id, 1); // Insert a new row with ID between existing rows let new_row_id = RowID::new(table_id.into(), 3); - let new_row = generate_simple_string_row(table_id.into(), new_row_id.row_id, "new_row"); + let new_row = generate_simple_string_record("new_row"); - cursor.insert(new_row).unwrap(); + let _ = cursor + .insert(&BTreeKey::TableRowId((new_row_id.row_id, Some(&new_row)))) + .unwrap(); let row = db.mvcc_store.read(tx_id, new_row_id).unwrap().unwrap(); let mut record = ImmutableRecord::new(1024); record.start_serialization(&row.data); let value = record.get_value(0).unwrap(); match value { - RefValue::Text(text) => { - assert_eq!(text.as_str(), "new_row"); + ValueRef::Text(text, _) => { + assert_eq!(text, b"new_row"); } _ => panic!("Expected Text value"), } assert_eq!(row.id.row_id, 3); // Continue scanning - the cursor should still work correctly - cursor.forward(); // Move to 4 + let _ = cursor.next().unwrap(); // Move to 4 let row = db .mvcc_store .read(tx_id, RowID::new(table_id.into(), 4)) @@ -996,14 +1035,14 @@ fn test_cursor_modification_during_scan() { .unwrap(); assert_eq!(row.id.row_id, 4); - cursor.forward(); // Move to 5 (our new row) + let _ = cursor.next().unwrap(); // Move to 5 (our new row) let row = db .mvcc_store .read(tx_id, RowID::new(table_id.into(), 5)) .unwrap() .unwrap(); assert_eq!(row.id.row_id, 5); - assert!(!cursor.forward()); + assert!(!matches!(cursor.next().unwrap(), IOResult::Done(true))); assert!(cursor.is_empty()); } @@ -1076,7 +1115,7 @@ fn test_snapshot_isolation_tx_visible1() { let current_tx = new_tx(4, 4, TransactionState::Preparing); - let rv_visible = |begin: TxTimestampOrID, end: Option| { + let rv_visible = |begin: Option, end: Option| { let row_version = RowVersion { begin, end, @@ -1088,60 +1127,60 @@ fn test_snapshot_isolation_tx_visible1() { // begin visible: transaction committed with ts < current_tx.begin_ts // end visible: inf - assert!(rv_visible(TxTimestampOrID::TxID(1), None)); + assert!(rv_visible(Some(TxTimestampOrID::TxID(1)), None)); // begin invisible: transaction committed with ts > current_tx.begin_ts - assert!(!rv_visible(TxTimestampOrID::TxID(2), None)); + assert!(!rv_visible(Some(TxTimestampOrID::TxID(2)), None)); // begin invisible: transaction aborted - assert!(!rv_visible(TxTimestampOrID::TxID(3), None)); + assert!(!rv_visible(Some(TxTimestampOrID::TxID(3)), None)); // begin visible: timestamp < current_tx.begin_ts // end invisible: transaction committed with ts > current_tx.begin_ts assert!(!rv_visible( - TxTimestampOrID::Timestamp(0), + Some(TxTimestampOrID::Timestamp(0)), Some(TxTimestampOrID::TxID(1)) )); // begin visible: timestamp < current_tx.begin_ts // end visible: transaction committed with ts < current_tx.begin_ts assert!(rv_visible( - TxTimestampOrID::Timestamp(0), + Some(TxTimestampOrID::Timestamp(0)), Some(TxTimestampOrID::TxID(2)) )); // begin visible: timestamp < current_tx.begin_ts // end invisible: transaction aborted assert!(!rv_visible( - TxTimestampOrID::Timestamp(0), + Some(TxTimestampOrID::Timestamp(0)), Some(TxTimestampOrID::TxID(3)) )); // begin invisible: transaction preparing - assert!(!rv_visible(TxTimestampOrID::TxID(5), None)); + assert!(!rv_visible(Some(TxTimestampOrID::TxID(5)), None)); // begin invisible: transaction committed with ts > current_tx.begin_ts - assert!(!rv_visible(TxTimestampOrID::TxID(6), None)); + assert!(!rv_visible(Some(TxTimestampOrID::TxID(6)), None)); // begin invisible: transaction active - assert!(!rv_visible(TxTimestampOrID::TxID(7), None)); + assert!(!rv_visible(Some(TxTimestampOrID::TxID(7)), None)); // begin invisible: transaction committed with ts > current_tx.begin_ts - assert!(!rv_visible(TxTimestampOrID::TxID(6), None)); + assert!(!rv_visible(Some(TxTimestampOrID::TxID(6)), None)); // begin invisible: transaction active - assert!(!rv_visible(TxTimestampOrID::TxID(7), None)); + assert!(!rv_visible(Some(TxTimestampOrID::TxID(7)), None)); // begin visible: timestamp < current_tx.begin_ts // end invisible: transaction preparing assert!(!rv_visible( - TxTimestampOrID::Timestamp(0), + Some(TxTimestampOrID::Timestamp(0)), Some(TxTimestampOrID::TxID(5)) )); // begin invisible: timestamp > current_tx.begin_ts assert!(!rv_visible( - TxTimestampOrID::Timestamp(6), + Some(TxTimestampOrID::Timestamp(6)), Some(TxTimestampOrID::TxID(6)) )); @@ -1150,9 +1189,11 @@ fn test_snapshot_isolation_tx_visible1() { // but that hasn't happened // (this is the https://avi.im/blag/2023/hekaton-paper-typo/ case, I believe!) assert!(rv_visible( - TxTimestampOrID::Timestamp(0), + Some(TxTimestampOrID::Timestamp(0)), Some(TxTimestampOrID::TxID(7)) )); + + assert!(!rv_visible(None, None)); } #[test] @@ -1161,7 +1202,7 @@ fn test_restart() { { let conn = db.connect(); let mvcc_store = db.get_mvcc_store(); - let tx_id = mvcc_store.begin_tx(conn.pager.read().clone()).unwrap(); + let tx_id = mvcc_store.begin_tx(conn.pager.load().clone()).unwrap(); // insert table id -2 into sqlite_schema table (table_id -1) let data = ImmutableRecord::from_values( &[ @@ -1199,21 +1240,21 @@ fn test_restart() { { let conn = db.connect(); let mvcc_store = db.get_mvcc_store(); - let tx_id = mvcc_store.begin_tx(conn.pager.read().clone()).unwrap(); + let tx_id = mvcc_store.begin_tx(conn.pager.load().clone()).unwrap(); let row = generate_simple_string_row((-2).into(), 2, "bar"); mvcc_store.insert(tx_id, row).unwrap(); commit_tx(mvcc_store.clone(), &conn, tx_id).unwrap(); - let tx_id = mvcc_store.begin_tx(conn.pager.read().clone()).unwrap(); + let tx_id = mvcc_store.begin_tx(conn.pager.load().clone()).unwrap(); let row = mvcc_store .read(tx_id, RowID::new((-2).into(), 2)) .unwrap() .unwrap(); let record = get_record_value(&row); match record.get_value(0).unwrap() { - RefValue::Text(text) => { - assert_eq!(text.as_str(), "bar"); + ValueRef::Text(text, _) => { + assert_eq!(text, b"bar"); } _ => panic!("Expected Text value"), } diff --git a/core/mvcc/mod.rs b/core/mvcc/mod.rs index ba96b4c3e..319215448 100644 --- a/core/mvcc/mod.rs +++ b/core/mvcc/mod.rs @@ -65,7 +65,7 @@ mod tests { let conn = db.get_db().connect().unwrap(); let mvcc_store = db.get_db().mv_store.as_ref().unwrap().clone(); for _ in 0..iterations { - let tx = mvcc_store.begin_tx(conn.pager.read().clone()).unwrap(); + let tx = mvcc_store.begin_tx(conn.pager.load().clone()).unwrap(); let id = IDS.fetch_add(1, Ordering::SeqCst); let id = RowID { table_id: (-2).into(), @@ -74,7 +74,7 @@ mod tests { let row = generate_simple_string_row((-2).into(), id.row_id, "Hello"); mvcc_store.insert(tx, row.clone()).unwrap(); commit_tx_no_conn(&db, tx, &conn).unwrap(); - let tx = mvcc_store.begin_tx(conn.pager.read().clone()).unwrap(); + let tx = mvcc_store.begin_tx(conn.pager.load().clone()).unwrap(); let committed_row = mvcc_store.read(tx, id).unwrap(); commit_tx_no_conn(&db, tx, &conn).unwrap(); assert_eq!(committed_row, Some(row)); @@ -86,7 +86,7 @@ mod tests { let conn = db.get_db().connect().unwrap(); let mvcc_store = db.get_db().mv_store.as_ref().unwrap().clone(); for _ in 0..iterations { - let tx = mvcc_store.begin_tx(conn.pager.read().clone()).unwrap(); + let tx = mvcc_store.begin_tx(conn.pager.load().clone()).unwrap(); let id = IDS.fetch_add(1, Ordering::SeqCst); let id = RowID { table_id: (-2).into(), @@ -95,7 +95,7 @@ mod tests { let row = generate_simple_string_row((-2).into(), id.row_id, "World"); mvcc_store.insert(tx, row.clone()).unwrap(); commit_tx_no_conn(&db, tx, &conn).unwrap(); - let tx = mvcc_store.begin_tx(conn.pager.read().clone()).unwrap(); + let tx = mvcc_store.begin_tx(conn.pager.load().clone()).unwrap(); let committed_row = mvcc_store.read(tx, id).unwrap(); commit_tx_no_conn(&db, tx, &conn).unwrap(); assert_eq!(committed_row, Some(row)); @@ -127,7 +127,7 @@ mod tests { let dropped = mvcc_store.drop_unused_row_versions(); tracing::debug!("garbage collected {dropped} versions"); } - let tx = mvcc_store.begin_tx(conn.pager.read().clone()).unwrap(); + let tx = mvcc_store.begin_tx(conn.pager.load().clone()).unwrap(); let id = i % 16; let id = RowID { table_id: (-2).into(), diff --git a/core/mvcc/persistent_storage/logical_log.rs b/core/mvcc/persistent_storage/logical_log.rs index ee572f225..db8a33317 100644 --- a/core/mvcc/persistent_storage/logical_log.rs +++ b/core/mvcc/persistent_storage/logical_log.rs @@ -12,17 +12,14 @@ use std::sync::{Arc, RwLock}; use crate::File; -pub const DEFAULT_LOG_CHECKPOINT_THRESHOLD: u64 = 1024 * 1024 * 8; // 8 MiB as default to mimic - // 2000 pages in sqlite which is - // pretty much equal to - // 8MiB if page_size == - // 4096 bytes +pub const DEFAULT_LOG_CHECKPOINT_THRESHOLD: i64 = -1; // Disabled by default pub struct LogicalLog { pub file: Arc, pub offset: u64, /// Size at which we start performing a checkpoint on the logical log. - checkpoint_threshold: u64, + /// Set to -1 to disable automatic checkpointing. + checkpoint_threshold: i64, } /// Log's Header, this will be the 64 bytes in any logical log file. @@ -229,12 +226,19 @@ impl LogicalLog { } pub fn should_checkpoint(&self) -> bool { - self.offset >= self.checkpoint_threshold + if self.checkpoint_threshold < 0 { + return false; + } + self.offset >= self.checkpoint_threshold as u64 } - pub fn set_checkpoint_threshold(&mut self, threshold: u64) { + pub fn set_checkpoint_threshold(&mut self, threshold: i64) { self.checkpoint_threshold = threshold; } + + pub fn checkpoint_threshold(&self) -> i64 { + self.checkpoint_threshold + } } pub enum StreamingResult { @@ -481,7 +485,7 @@ impl StreamingLogicalLogReader { mod tests { use std::{collections::HashSet, sync::Arc}; - use rand::{thread_rng, Rng}; + use rand::{rng, Rng}; use rand_chacha::{ rand_core::{RngCore, SeedableRng}, ChaCha8Rng, @@ -497,7 +501,7 @@ mod tests { LocalClock, MvStore, }, types::{ImmutableRecord, Text}, - OpenFlags, RefValue, Value, + OpenFlags, Value, ValueRef, }; use super::LogRecordType; @@ -509,7 +513,7 @@ mod tests { let db = MvccTestDbNoConn::new_with_random_db(); let (io, pager) = { let conn = db.connect(); - let pager = conn.pager.read().clone(); + let pager = conn.pager.load().clone(); let mvcc_store = db.get_mvcc_store(); let tx_id = mvcc_store.begin_tx(pager.clone()).unwrap(); // insert table id -2 into sqlite_schema table (table_id -1) @@ -561,10 +565,10 @@ mod tests { let record = ImmutableRecord::from_bin_record(row.data.clone()); let values = record.get_values(); let foo = values.first().unwrap(); - let RefValue::Text(foo) = foo else { + let ValueRef::Text(foo, _) = foo else { unreachable!() }; - assert_eq!(foo.as_str(), "foo"); + assert_eq!(foo, b"foo"); } #[test] @@ -576,7 +580,7 @@ mod tests { let db = MvccTestDbNoConn::new_with_random_db(); let (io, pager) = { let conn = db.connect(); - let pager = conn.pager.read().clone(); + let pager = conn.pager.load().clone(); let mvcc_store = db.get_mvcc_store(); let tx_id = mvcc_store.begin_tx(pager.clone()).unwrap(); @@ -633,16 +637,16 @@ mod tests { let record = ImmutableRecord::from_bin_record(row.data.clone()); let values = record.get_values(); let foo = values.first().unwrap(); - let RefValue::Text(foo) = foo else { + let ValueRef::Text(foo, _) = foo else { unreachable!() }; - assert_eq!(foo.as_str(), value.as_str()); + assert_eq!(*foo, value.as_bytes()); } } #[test] fn test_logical_log_read_fuzz() { - let seed = thread_rng().gen(); + let seed = rng().random(); let mut rng = ChaCha8Rng::seed_from_u64(seed); let num_transactions = rng.next_u64() % 128; let mut txns = vec![]; @@ -687,7 +691,7 @@ mod tests { let mut db = MvccTestDbNoConn::new_with_random_db(); let pager = { let conn = db.connect(); - let pager = conn.pager.read().clone(); + let pager = conn.pager.load().clone(); let mvcc_store = db.get_mvcc_store(); // insert table id -2 into sqlite_schema table (table_id -1) @@ -754,11 +758,14 @@ mod tests { let record = ImmutableRecord::from_bin_record(row.data.clone()); let values = record.get_values(); let foo = values.first().unwrap(); - let RefValue::Text(foo) = foo else { + let ValueRef::Text(foo, _) = foo else { unreachable!() }; - assert_eq!(foo.as_str(), format!("row_{}", present_rowid.row_id as u64)); + assert_eq!( + String::from_utf8_lossy(foo), + format!("row_{}", present_rowid.row_id as u64) + ); } // Check rowids that were deleted diff --git a/core/mvcc/persistent_storage/mod.rs b/core/mvcc/persistent_storage/mod.rs index 1cc8d0c2b..0ddf14223 100644 --- a/core/mvcc/persistent_storage/mod.rs +++ b/core/mvcc/persistent_storage/mod.rs @@ -1,5 +1,6 @@ +use parking_lot::RwLock; use std::fmt::Debug; -use std::sync::{Arc, RwLock}; +use std::sync::Arc; pub mod logical_log; use crate::mvcc::database::LogRecord; @@ -20,7 +21,7 @@ impl Storage { impl Storage { pub fn log_tx(&self, m: &LogRecord) -> Result { - self.logical_log.write().unwrap().log_tx(m) + self.logical_log.write().log_tx(m) } pub fn read_tx_log(&self) -> Result> { @@ -28,26 +29,27 @@ impl Storage { } pub fn sync(&self) -> Result { - self.logical_log.write().unwrap().sync() + self.logical_log.write().sync() } pub fn truncate(&self) -> Result { - self.logical_log.write().unwrap().truncate() + self.logical_log.write().truncate() } pub fn get_logical_log_file(&self) -> Arc { - self.logical_log.write().unwrap().file.clone() + self.logical_log.write().file.clone() } pub fn should_checkpoint(&self) -> bool { - self.logical_log.read().unwrap().should_checkpoint() + self.logical_log.read().should_checkpoint() } - pub fn set_checkpoint_threshold(&self, threshold: u64) { - self.logical_log - .write() - .unwrap() - .set_checkpoint_threshold(threshold) + pub fn set_checkpoint_threshold(&self, threshold: i64) { + self.logical_log.write().set_checkpoint_threshold(threshold) + } + + pub fn checkpoint_threshold(&self) -> i64 { + self.logical_log.read().checkpoint_threshold() } } diff --git a/core/pragma.rs b/core/pragma.rs index c83509a69..8cf9a99c5 100644 --- a/core/pragma.rs +++ b/core/pragma.rs @@ -127,6 +127,14 @@ pub fn pragma_for(pragma: &PragmaName) -> Pragma { PragmaFlags::Result0 | PragmaFlags::SchemaReq | PragmaFlags::NoColumns1, &["cipher"], ), + PragmaName::MvccCheckpointThreshold => Pragma::new( + PragmaFlags::NoColumns1 | PragmaFlags::Result0, + &["mvcc_checkpoint_threshold"], + ), + ForeignKeys => Pragma::new( + PragmaFlags::NoColumns1 | PragmaFlags::Result0, + &["foreign_keys"], + ), } } diff --git a/core/pseudo.rs b/core/pseudo.rs index d55ba0e98..3c43a55ea 100644 --- a/core/pseudo.rs +++ b/core/pseudo.rs @@ -1,16 +1,39 @@ -use crate::types::ImmutableRecord; +use std::cell::{Ref, RefCell}; + +use crate::{ + types::{ImmutableRecord, RecordCursor}, + Result, Value, +}; -#[derive(Default)] pub struct PseudoCursor { - current: Option, + record_cursor: RecordCursor, + current: RefCell>, +} + +impl Default for PseudoCursor { + fn default() -> Self { + Self { + record_cursor: RecordCursor::new(), + current: RefCell::new(None), + } + } } impl PseudoCursor { - pub fn record(&self) -> Option<&ImmutableRecord> { - self.current.as_ref() + pub fn record(&self) -> Ref> { + self.current.borrow() } pub fn insert(&mut self, record: ImmutableRecord) { - self.current = Some(record); + self.record_cursor.invalidate(); + self.current.replace(Some(record)); + } + + pub fn get_value(&mut self, column: usize) -> Result { + if let Some(record) = self.current.borrow().as_ref() { + Ok(self.record_cursor.get_value(record, column)?.to_owned()) + } else { + Ok(Value::Null) + } } } diff --git a/core/schema.rs b/core/schema.rs index 5188f962f..336fe45e4 100644 --- a/core/schema.rs +++ b/core/schema.rs @@ -72,7 +72,7 @@ impl Clone for View { /// Type alias for regular views collection pub type ViewsMap = HashMap>; -use crate::storage::btree::BTreeCursor; +use crate::storage::btree::{BTreeCursor, CursorTrait}; use crate::translate::collate::CollationSeq; use crate::translate::plan::{SelectPlan, TableReferences}; use crate::util::{ @@ -80,7 +80,7 @@ use crate::util::{ }; use crate::{ bail_parse_error, contains_ignore_ascii_case, eq_ignore_ascii_case, match_ignore_ascii_case, - Connection, LimboError, MvCursor, MvStore, Pager, RefValue, SymbolTable, VirtualTable, + Connection, LimboError, MvCursor, MvStore, Pager, SymbolTable, ValueRef, VirtualTable, }; use crate::{util::normalize_ident, Result}; use core::fmt; @@ -89,7 +89,9 @@ use std::ops::Deref; use std::sync::Arc; use std::sync::Mutex; use tracing::trace; -use turso_parser::ast::{self, ColumnDefinition, Expr, Literal, SortOrder, TableOptions}; +use turso_parser::ast::{ + self, ColumnDefinition, Expr, InitDeferredPred, Literal, RefAct, SortOrder, TableOptions, +}; use turso_parser::{ ast::{Cmd, CreateTableBody, ResultColumn, Stmt}, parser::Parser, @@ -287,9 +289,11 @@ impl Schema { } /// Add a regular (non-materialized) view - pub fn add_view(&mut self, view: View) { + pub fn add_view(&mut self, view: View) -> Result<()> { + self.check_object_name_conflict(&view.name)?; let name = normalize_ident(&view.name); self.views.insert(name, Arc::new(view)); + Ok(()) } /// Get a regular view by name @@ -298,14 +302,18 @@ impl Schema { self.views.get(&name).cloned() } - pub fn add_btree_table(&mut self, table: Arc) { + pub fn add_btree_table(&mut self, table: Arc) -> Result<()> { + self.check_object_name_conflict(&table.name)?; let name = normalize_ident(&table.name); self.tables.insert(name, Table::BTree(table).into()); + Ok(()) } - pub fn add_virtual_table(&mut self, table: Arc) { + pub fn add_virtual_table(&mut self, table: Arc) -> Result<()> { + self.check_object_name_conflict(&table.name)?; let name = normalize_ident(&table.name); self.tables.insert(name, Table::Virtual(table).into()); + Ok(()) } pub fn get_table(&self, name: &str) -> Option> { @@ -338,7 +346,8 @@ impl Schema { } } - pub fn add_index(&mut self, index: Arc) { + pub fn add_index(&mut self, index: Arc) -> Result<()> { + self.check_object_name_conflict(&index.name)?; let table_name = normalize_ident(&index.table_name); // We must add the new index to the front of the deque, because SQLite stores index definitions as a linked list // where the newest parsed index entry is at the head of list. If we would add it to the back of a regular Vec for example, @@ -348,7 +357,8 @@ impl Schema { self.indexes .entry(table_name) .or_default() - .push_front(index.clone()) + .push_front(index.clone()); + Ok(()) } pub fn get_indices(&self, table_name: &str) -> impl Iterator> { @@ -404,7 +414,7 @@ impl Schema { mv_cursor.is_none(), "mvcc not yet supported for make_from_btree" ); - let mut cursor = BTreeCursor::new_table(mv_cursor, Arc::clone(&pager), 1, 10); + let mut cursor = BTreeCursor::new_table(Arc::clone(&pager), 1, 10); let mut from_sql_indexes = Vec::with_capacity(10); let mut automatic_indices: HashMap> = HashMap::with_capacity(10); @@ -428,36 +438,36 @@ impl Schema { let mut record_cursor = cursor.record_cursor.borrow_mut(); // sqlite schema table has 5 columns: type, name, tbl_name, rootpage, sql let ty_value = record_cursor.get_value(&row, 0)?; - let RefValue::Text(ty) = ty_value else { + let ValueRef::Text(ty, _) = ty_value else { return Err(LimboError::ConversionError("Expected text value".into())); }; - let ty = ty.as_str(); - let RefValue::Text(name) = record_cursor.get_value(&row, 1)? else { + let ty = String::from_utf8_lossy(ty); + let ValueRef::Text(name, _) = record_cursor.get_value(&row, 1)? else { return Err(LimboError::ConversionError("Expected text value".into())); }; - let name = name.as_str(); + let name = String::from_utf8_lossy(name); let table_name_value = record_cursor.get_value(&row, 2)?; - let RefValue::Text(table_name) = table_name_value else { + let ValueRef::Text(table_name, _) = table_name_value else { return Err(LimboError::ConversionError("Expected text value".into())); }; - let table_name = table_name.as_str(); + let table_name = String::from_utf8_lossy(table_name); let root_page_value = record_cursor.get_value(&row, 3)?; - let RefValue::Integer(root_page) = root_page_value else { + let ValueRef::Integer(root_page) = root_page_value else { return Err(LimboError::ConversionError("Expected integer value".into())); }; let sql_value = record_cursor.get_value(&row, 4)?; let sql_textref = match sql_value { - RefValue::Text(sql) => Some(sql), + ValueRef::Text(sql, _) => Some(sql), _ => None, }; - let sql = sql_textref.as_ref().map(|s| s.as_str()); + let sql = sql_textref.map(|s| String::from_utf8_lossy(s)); self.handle_schema_row( - ty, - name, - table_name, + &ty, + &name, + &table_name, root_page, - sql, + sql.as_deref(), syms, &mut from_sql_indexes, &mut automatic_indices, @@ -472,7 +482,7 @@ impl Schema { pager.io.block(|| cursor.next())?; } - pager.end_read_tx()?; + pager.end_read_tx(); self.populate_indices(from_sql_indexes, automatic_indices)?; @@ -505,7 +515,7 @@ impl Schema { unparsed_sql_from_index.root_page, table.as_ref(), )?; - self.add_index(Arc::new(index)); + self.add_index(Arc::new(index))?; } } @@ -547,7 +557,7 @@ impl Schema { table.as_ref(), automatic_indexes.pop().unwrap(), 1, - )?)); + )?))?; } else { // Add single column unique index if let Some(autoidx) = automatic_indexes.pop() { @@ -555,7 +565,7 @@ impl Schema { table.as_ref(), autoidx, vec![(pos_in_table, unique_set.columns.first().unwrap().1)], - )?)); + )?))?; } } } @@ -573,7 +583,7 @@ impl Schema { table.as_ref(), automatic_indexes.pop().unwrap(), unique_set.columns.len(), - )?)); + )?))?; } else { // Add composite unique index let mut column_indices_and_sort_orders = @@ -591,7 +601,7 @@ impl Schema { table.as_ref(), automatic_indexes.pop().unwrap(), column_indices_and_sort_orders, - )?)); + )?))?; } } @@ -646,6 +656,7 @@ impl Schema { has_rowid: true, is_strict: false, has_autoincrement: false, + foreign_keys: vec![], unique_sets: vec![], }))); @@ -698,7 +709,7 @@ impl Schema { syms, )? }; - self.add_virtual_table(vtab); + self.add_virtual_table(vtab)?; } else { let table = BTreeTable::from_sql(sql, root_page)?; @@ -732,7 +743,7 @@ impl Schema { } } - self.add_btree_table(Arc::new(table)); + self.add_btree_table(Arc::new(table))?; } } "index" => { @@ -831,7 +842,7 @@ impl Schema { // Create regular view let view = View::new(name.to_string(), sql.to_string(), select, final_columns); - self.add_view(view); + self.add_view(view)?; } _ => {} } @@ -842,6 +853,292 @@ impl Schema { Ok(()) } + + /// Compute all resolved FKs *referencing* `table_name` (arg: `table_name` is the parent). + /// Each item contains the child table, normalized columns/positions, and the parent lookup + /// strategy (rowid vs. UNIQUE index or PK). + pub fn resolved_fks_referencing(&self, table_name: &str) -> Result> { + let fk_mismatch_err = |child: &str, parent: &str| -> crate::LimboError { + crate::LimboError::Constraint(format!( + "foreign key mismatch - \"{child}\" referencing \"{parent}\"" + )) + }; + let target = normalize_ident(table_name); + let mut out = Vec::with_capacity(4); // arbitrary estimate + let parent_tbl = self + .get_btree_table(&target) + .ok_or_else(|| fk_mismatch_err("", &target))?; + + // Precompute helper to find parent unique index, if it's not the rowid + let find_parent_unique = |cols: &Vec| -> Option> { + self.get_indices(&parent_tbl.name) + .find(|idx| { + idx.unique + && idx.columns.len() == cols.len() + && idx + .columns + .iter() + .zip(cols.iter()) + .all(|(ic, pc)| ic.name.eq_ignore_ascii_case(pc)) + }) + .cloned() + }; + + for t in self.tables.values() { + let Some(child) = t.btree() else { + continue; + }; + for fk in &child.foreign_keys { + if !fk.parent_table.eq_ignore_ascii_case(&target) { + continue; + } + if fk.child_columns.is_empty() { + // SQLite requires an explicit child column list unless the table has a single-column PK that + return Err(fk_mismatch_err(&child.name, &parent_tbl.name)); + } + let child_cols: Vec = fk.child_columns.clone(); + let mut child_pos = Vec::with_capacity(child_cols.len()); + + for cname in &child_cols { + let (i, _) = child + .get_column(cname) + .ok_or_else(|| fk_mismatch_err(&child.name, &parent_tbl.name))?; + child_pos.push(i); + } + let parent_cols: Vec = if fk.parent_columns.is_empty() { + if !parent_tbl.primary_key_columns.is_empty() { + parent_tbl + .primary_key_columns + .iter() + .map(|(col, _)| col) + .cloned() + .collect() + } else { + return Err(fk_mismatch_err(&child.name, &parent_tbl.name)); + } + } else { + fk.parent_columns.clone() + }; + + // Same length required + if parent_cols.len() != child_cols.len() { + return Err(fk_mismatch_err(&child.name, &parent_tbl.name)); + } + + let mut parent_pos = Vec::with_capacity(parent_cols.len()); + for pc in &parent_cols { + let pos = parent_tbl.get_column(pc).map(|(i, _)| i).or_else(|| { + ROWID_STRS + .iter() + .any(|s| pc.eq_ignore_ascii_case(s)) + .then_some(0) + }); + let Some(p) = pos else { + return Err(fk_mismatch_err(&child.name, &parent_tbl.name)); + }; + parent_pos.push(p); + } + + // Determine if parent key is ROWID/alias + let parent_uses_rowid = parent_tbl.primary_key_columns.len().eq(&1) && { + if parent_tbl.primary_key_columns.len() == 1 { + let pk_name = &parent_tbl.primary_key_columns[0].0; + // rowid or alias INTEGER PRIMARY KEY; either is ok implicitly + parent_tbl.columns.iter().any(|c| { + c.is_rowid_alias + && c.name + .as_deref() + .is_some_and(|n| n.eq_ignore_ascii_case(pk_name)) + }) || ROWID_STRS.iter().any(|&r| r.eq_ignore_ascii_case(pk_name)) + } else { + false + } + }; + + // If not rowid, there must be a non-partial UNIQUE exactly on parent_cols + let parent_unique_index = if parent_uses_rowid { + None + } else { + find_parent_unique(&parent_cols) + .ok_or_else(|| fk_mismatch_err(&child.name, &parent_tbl.name))? + .into() + }; + fk.validate()?; + out.push(ResolvedFkRef { + child_table: Arc::clone(&child), + fk: Arc::clone(fk), + parent_cols, + child_cols, + child_pos, + parent_pos, + parent_uses_rowid, + parent_unique_index, + }); + } + } + Ok(out) + } + + /// Compute all resolved FKs *declared by* `child_table` + pub fn resolved_fks_for_child(&self, child_table: &str) -> crate::Result> { + let fk_mismatch_err = |child: &str, parent: &str| -> crate::LimboError { + crate::LimboError::Constraint(format!( + "foreign key mismatch - \"{child}\" referencing \"{parent}\"" + )) + }; + let child_name = normalize_ident(child_table); + let child = self + .get_btree_table(&child_name) + .ok_or_else(|| fk_mismatch_err(&child_name, ""))?; + + let mut out = Vec::with_capacity(child.foreign_keys.len()); + + for fk in &child.foreign_keys { + let parent_name = normalize_ident(&fk.parent_table); + let parent_tbl = self + .get_btree_table(&parent_name) + .ok_or_else(|| fk_mismatch_err(&child.name, &parent_name))?; + + let child_cols: Vec = fk.child_columns.clone(); + if child_cols.is_empty() { + return Err(fk_mismatch_err(&child.name, &parent_tbl.name)); + } + + // Child positions exist + let mut child_pos = Vec::with_capacity(child_cols.len()); + for cname in &child_cols { + let (i, _) = child + .get_column(cname) + .ok_or_else(|| fk_mismatch_err(&child.name, &parent_tbl.name))?; + child_pos.push(i); + } + + let parent_cols: Vec = if fk.parent_columns.is_empty() { + if !parent_tbl.primary_key_columns.is_empty() { + parent_tbl + .primary_key_columns + .iter() + .map(|(col, _)| col) + .cloned() + .collect() + } else { + return Err(fk_mismatch_err(&child.name, &parent_tbl.name)); + } + } else { + fk.parent_columns.clone() + }; + + if parent_cols.len() != child_cols.len() { + return Err(fk_mismatch_err(&child.name, &parent_tbl.name)); + } + + // Parent positions exist, or rowid sentinel + let mut parent_pos = Vec::with_capacity(parent_cols.len()); + for pc in &parent_cols { + let pos = parent_tbl.get_column(pc).map(|(i, _)| i).or_else(|| { + ROWID_STRS + .iter() + .any(|&r| r.eq_ignore_ascii_case(pc)) + .then_some(0) + }); + let Some(p) = pos else { + return Err(fk_mismatch_err(&child.name, &parent_tbl.name)); + }; + parent_pos.push(p); + } + + let parent_uses_rowid = parent_cols.len().eq(&1) && { + let c = parent_cols[0].as_str(); + ROWID_STRS.iter().any(|&r| r.eq_ignore_ascii_case(c)) + || parent_tbl.columns.iter().any(|col| { + col.is_rowid_alias + && col + .name + .as_deref() + .is_some_and(|n| n.eq_ignore_ascii_case(c)) + }) + }; + + // Must be PK or a non-partial UNIQUE on exactly those columns. + let parent_unique_index = if parent_uses_rowid { + None + } else { + self.get_indices(&parent_tbl.name) + .find(|idx| { + idx.unique + && idx.where_clause.is_none() + && idx.columns.len() == parent_cols.len() + && idx + .columns + .iter() + .zip(parent_cols.iter()) + .all(|(ic, pc)| ic.name.eq_ignore_ascii_case(pc)) + }) + .cloned() + .ok_or_else(|| fk_mismatch_err(&child.name, &parent_tbl.name))? + .into() + }; + + fk.validate()?; + out.push(ResolvedFkRef { + child_table: Arc::clone(&child), + fk: Arc::clone(fk), + parent_cols, + child_cols, + child_pos, + parent_pos, + parent_uses_rowid, + parent_unique_index, + }); + } + + Ok(out) + } + + /// Returns if any table declares a FOREIGN KEY whose parent is `table_name`. + pub fn any_resolved_fks_referencing(&self, table_name: &str) -> bool { + self.tables.values().any(|t| { + let Some(bt) = t.btree() else { + return false; + }; + bt.foreign_keys + .iter() + .any(|fk| fk.parent_table == table_name) + }) + } + + /// Returns true if `table_name` declares any FOREIGN KEYs + pub fn has_child_fks(&self, table_name: &str) -> bool { + self.get_table(table_name) + .and_then(|t| t.btree()) + .is_some_and(|t| !t.foreign_keys.is_empty()) + } + + fn check_object_name_conflict(&self, name: &str) -> Result<()> { + let normalized_name = normalize_ident(name); + + if self.tables.contains_key(&normalized_name) { + return Err(crate::LimboError::ParseError( + ["table \"", name, "\" already exists"].concat().to_string(), + )); + } + + if self.views.contains_key(&normalized_name) { + return Err(crate::LimboError::ParseError( + ["view \"", name, "\" already exists"].concat().to_string(), + )); + } + + for index_list in self.indexes.values() { + if index_list.iter().any(|i| i.name.eq_ignore_ascii_case(name)) { + return Err(crate::LimboError::ParseError( + ["index \"", name, "\" already exists"].concat().to_string(), + )); + } + } + + Ok(()) + } } impl Clone for Schema { @@ -1016,6 +1313,7 @@ pub struct BTreeTable { pub is_strict: bool, pub has_autoincrement: bool, pub unique_sets: Vec, + pub foreign_keys: Vec>, } impl BTreeTable { @@ -1060,6 +1358,8 @@ impl BTreeTable { /// `CREATE TABLE t (x)`, whereas sqlite stores it with the original extra whitespace. pub fn to_sql(&self) -> String { let mut sql = format!("CREATE TABLE {} (", self.name); + let needs_pk_inline = self.primary_key_columns.len() == 1; + // Add columns for (i, column) in self.columns.iter().enumerate() { if i > 0 { sql.push_str(", "); @@ -1086,8 +1386,7 @@ impl BTreeTable { if column.unique { sql.push_str(" UNIQUE"); } - - if column.primary_key { + if needs_pk_inline && column.primary_key { sql.push_str(" PRIMARY KEY"); } @@ -1096,6 +1395,64 @@ impl BTreeTable { sql.push_str(&default.to_string()); } } + + let has_table_pk = !self.primary_key_columns.is_empty(); + // Add table-level PRIMARY KEY constraint if exists + if !needs_pk_inline && has_table_pk { + sql.push_str(", PRIMARY KEY ("); + for (i, col) in self.primary_key_columns.iter().enumerate() { + if i > 0 { + sql.push_str(", "); + } + sql.push_str(&col.0); + } + sql.push(')'); + } + + for fk in &self.foreign_keys { + sql.push_str(", FOREIGN KEY ("); + for (i, col) in fk.child_columns.iter().enumerate() { + if i > 0 { + sql.push_str(", "); + } + sql.push_str(col); + } + sql.push_str(") REFERENCES "); + sql.push_str(&fk.parent_table); + sql.push('('); + for (i, col) in fk.parent_columns.iter().enumerate() { + if i > 0 { + sql.push_str(", "); + } + sql.push_str(col); + } + sql.push(')'); + + // Add ON DELETE/UPDATE actions, NoAction is default so just make empty in that case + if fk.on_delete != RefAct::NoAction { + sql.push_str(" ON DELETE "); + sql.push_str(match fk.on_delete { + RefAct::SetNull => "SET NULL", + RefAct::SetDefault => "SET DEFAULT", + RefAct::Cascade => "CASCADE", + RefAct::Restrict => "RESTRICT", + _ => "", + }); + } + if fk.on_update != RefAct::NoAction { + sql.push_str(" ON UPDATE "); + sql.push_str(match fk.on_update { + RefAct::SetNull => "SET NULL", + RefAct::SetDefault => "SET DEFAULT", + RefAct::Cascade => "CASCADE", + RefAct::Restrict => "RESTRICT", + _ => "", + }); + } + if fk.deferred { + sql.push_str(" DEFERRABLE INITIALLY DEFERRED"); + } + } sql.push(')'); sql } @@ -1146,6 +1503,7 @@ pub fn create_table(tbl_name: &str, body: &CreateTableBody, root_page: i64) -> R let mut has_rowid = true; let mut has_autoincrement = false; let mut primary_key_columns = vec![]; + let mut foreign_keys = vec![]; let mut cols = vec![]; let is_strict: bool; let mut unique_sets: Vec = vec![]; @@ -1219,8 +1577,81 @@ pub fn create_table(tbl_name: &str, body: &CreateTableBody, root_page: i64) -> R is_primary_key: false, }; unique_sets.push(unique_set); + } else if let ast::TableConstraint::ForeignKey { + columns, + clause, + defer_clause, + } = &c.constraint + { + let child_columns: Vec = columns + .iter() + .map(|ic| normalize_ident(ic.col_name.as_str())) + .collect(); + // derive parent columns: explicit or default to parent PK + let parent_table = normalize_ident(clause.tbl_name.as_str()); + let parent_columns: Vec = clause + .columns + .iter() + .map(|ic| normalize_ident(ic.col_name.as_str())) + .collect(); + + // Only check arity if parent columns were explicitly listed + if !parent_columns.is_empty() && child_columns.len() != parent_columns.len() { + crate::bail_parse_error!( + "foreign key on \"{}\" has {} child column(s) but {} parent column(s)", + tbl_name, + child_columns.len(), + parent_columns.len() + ); + } + // deferrable semantics + let deferred = match defer_clause { + Some(d) => { + d.deferrable + && matches!( + d.init_deferred, + Some(InitDeferredPred::InitiallyDeferred) + ) + } + None => false, // NOT DEFERRABLE INITIALLY IMMEDIATE by default + }; + let fk = ForeignKey { + parent_table, + parent_columns, + child_columns, + on_delete: clause + .args + .iter() + .find_map(|a| { + if let ast::RefArg::OnDelete(x) = a { + Some(*x) + } else { + None + } + }) + .unwrap_or(RefAct::NoAction), + on_update: clause + .args + .iter() + .find_map(|a| { + if let ast::RefArg::OnUpdate(x) = a { + Some(*x) + } else { + None + } + }) + .unwrap_or(RefAct::NoAction), + deferred, + }; + foreign_keys.push(Arc::new(fk)); } } + + // Due to a bug in SQLite, this check is needed to maintain backwards compatibility with rowid alias + // SQLite docs: https://sqlite.org/lang_createtable.html#rowids_and_the_integer_primary_key + // Issue: https://github.com/tursodatabase/turso/issues/3665 + let mut primary_key_desc_columns_constraint = false; + for ast::ColumnDefinition { col_name, col_type, @@ -1259,12 +1690,24 @@ pub fn create_table(tbl_name: &str, body: &CreateTableBody, root_page: i64) -> R let mut unique = false; let mut collation = None; for c_def in constraints { - match c_def.constraint { + match &c_def.constraint { + ast::ColumnConstraint::Check { .. } => { + crate::bail_parse_error!("CHECK constraints are not yet supported"); + } + ast::ColumnConstraint::Generated { .. } => { + crate::bail_parse_error!("GENERATED columns are not yet supported"); + } ast::ColumnConstraint::PrimaryKey { order: o, auto_increment, + conflict_clause, .. } => { + if conflict_clause.is_some() { + crate::bail_parse_error!( + "ON CONFLICT not implemented for column definition" + ); + } if !primary_key_columns.is_empty() { crate::bail_parse_error!( "table \"{}\" has more than one primary key", @@ -1272,18 +1715,27 @@ pub fn create_table(tbl_name: &str, body: &CreateTableBody, root_page: i64) -> R ); } primary_key = true; - if auto_increment { + if *auto_increment { has_autoincrement = true; } if let Some(o) = o { - order = o; + order = *o; } unique_sets.push(UniqueSet { columns: vec![(name.clone(), order)], is_primary_key: true, }); } - ast::ColumnConstraint::NotNull { nullable, .. } => { + ast::ColumnConstraint::NotNull { + nullable, + conflict_clause, + .. + } => { + if conflict_clause.is_some() { + crate::bail_parse_error!( + "ON CONFLICT not implemented for column definition" + ); + } notnull = !nullable; } ast::ColumnConstraint::Default(ref expr) => { @@ -1292,9 +1744,11 @@ pub fn create_table(tbl_name: &str, body: &CreateTableBody, root_page: i64) -> R ); } // TODO: for now we don't check Resolve type of unique - ast::ColumnConstraint::Unique(on_conflict) => { - if on_conflict.is_some() { - unimplemented!("ON CONFLICT not implemented"); + ast::ColumnConstraint::Unique(conflict) => { + if conflict.is_some() { + crate::bail_parse_error!( + "ON CONFLICT not implemented for column definition" + ); } unique = true; unique_sets.push(UniqueSet { @@ -1305,12 +1759,61 @@ pub fn create_table(tbl_name: &str, body: &CreateTableBody, root_page: i64) -> R ast::ColumnConstraint::Collate { ref collation_name } => { collation = Some(CollationSeq::new(collation_name.as_str())?); } - _ => {} + ast::ColumnConstraint::ForeignKey { + clause, + defer_clause, + } => { + let fk = ForeignKey { + parent_table: normalize_ident(clause.tbl_name.as_str()), + parent_columns: clause + .columns + .iter() + .map(|c| normalize_ident(c.col_name.as_str())) + .collect(), + on_delete: clause + .args + .iter() + .find_map(|arg| { + if let ast::RefArg::OnDelete(act) = arg { + Some(*act) + } else { + None + } + }) + .unwrap_or(RefAct::NoAction), + on_update: clause + .args + .iter() + .find_map(|arg| { + if let ast::RefArg::OnUpdate(act) = arg { + Some(*act) + } else { + None + } + }) + .unwrap_or(RefAct::NoAction), + child_columns: vec![name.clone()], + deferred: match defer_clause { + Some(d) => { + d.deferrable + && matches!( + d.init_deferred, + Some(InitDeferredPred::InitiallyDeferred) + ) + } + None => false, + }, + }; + foreign_keys.push(Arc::new(fk)); + } } } if primary_key { primary_key_columns.push((name.clone(), order)); + if order == SortOrder::Desc { + primary_key_desc_columns_constraint = true; + } } else if primary_key_columns .iter() .any(|(col_name, _)| col_name == &name) @@ -1323,7 +1826,9 @@ pub fn create_table(tbl_name: &str, body: &CreateTableBody, root_page: i64) -> R ty, ty_str, primary_key, - is_rowid_alias: typename_exactly_integer && primary_key, + is_rowid_alias: typename_exactly_integer + && primary_key + && !primary_key_desc_columns_constraint, notnull, default, unique, @@ -1384,6 +1889,7 @@ pub fn create_table(tbl_name: &str, body: &CreateTableBody, root_page: i64) -> R has_autoincrement, columns: cols, is_strict, + foreign_keys, unique_sets: { // If there are any unique sets that have identical column names in the same order (even if they are PRIMARY KEY and UNIQUE and have different sort orders), remove the duplicates. // Examples: @@ -1441,6 +1947,115 @@ pub fn _build_pseudo_table(columns: &[ResultColumn]) -> PseudoCursorType { table } +#[derive(Debug, Clone)] +pub struct ForeignKey { + /// Columns in this table (child side) + pub child_columns: Vec, + /// Referenced (parent) table + pub parent_table: String, + /// Parent-side referenced columns + pub parent_columns: Vec, + pub on_delete: RefAct, + pub on_update: RefAct, + /// DEFERRABLE INITIALLY DEFERRED + pub deferred: bool, +} +impl ForeignKey { + fn validate(&self) -> Result<()> { + // TODO: remove this when actions are implemented + if !(matches!(self.on_update, RefAct::NoAction) + && matches!(self.on_delete, RefAct::NoAction)) + { + crate::bail_parse_error!( + "foreign key actions other than NO ACTION are not implemented" + ); + } + if self + .parent_columns + .iter() + .any(|c| ROWID_STRS.iter().any(|&r| r.eq_ignore_ascii_case(c))) + { + return Err(crate::LimboError::Constraint(format!( + "foreign key mismatch referencing \"{}\"", + self.parent_table + ))); + } + Ok(()) + } +} + +/// A single resolved foreign key where `parent_table == target`. +#[derive(Clone, Debug)] +pub struct ResolvedFkRef { + /// Child table that owns the FK. + pub child_table: Arc, + /// The FK as declared on the child table. + pub fk: Arc, + + /// Resolved, normalized column names. + pub parent_cols: Vec, + pub child_cols: Vec, + + /// Column positions in the child/parent tables (pos_in_table) + pub child_pos: Vec, + pub parent_pos: Vec, + + /// If the parent key is rowid or a rowid-alias (single-column only) + pub parent_uses_rowid: bool, + /// For non-rowid parents: the UNIQUE index that enforces the parent key. + /// (None when `parent_uses_rowid == true`.) + pub parent_unique_index: Option>, +} + +impl ResolvedFkRef { + /// Returns if any referenced parent column can change when these column positions are updated. + pub fn parent_key_may_change( + &self, + updated_parent_positions: &HashSet, + parent_tbl: &BTreeTable, + ) -> bool { + if self.parent_uses_rowid { + // parent rowid changes if the parent's rowid or alias is updated + if let Some((idx, _)) = parent_tbl + .columns + .iter() + .enumerate() + .find(|(_, c)| c.is_rowid_alias) + { + return updated_parent_positions.contains(&idx); + } + // Without a rowid alias, a direct rowid update is represented separately with ROWID_SENTINEL + return true; + } + self.parent_pos + .iter() + .any(|p| updated_parent_positions.contains(p)) + } + + /// Returns if any child column of this FK is in `updated_child_positions` + pub fn child_key_changed( + &self, + updated_child_positions: &HashSet, + child_tbl: &BTreeTable, + ) -> bool { + if self + .child_pos + .iter() + .any(|p| updated_child_positions.contains(p)) + { + return true; + } + // special case: if FK uses a rowid alias on child, and rowid changed + if self.child_cols.len() == 1 { + let (i, col) = child_tbl.get_column(&self.child_cols[0]).unwrap(); + if col.is_rowid_alias && updated_child_positions.contains(&i) { + return true; + } + } + false + } +} + #[derive(Debug, Clone)] pub struct Column { pub name: Option, @@ -1782,6 +2397,7 @@ pub fn sqlite_schema_table() -> BTreeTable { hidden: false, }, ], + foreign_keys: vec![], unique_sets: vec![], } } @@ -2392,6 +3008,7 @@ mod tests { hidden: false, }], unique_sets: vec![], + foreign_keys: vec![], }; let result = diff --git a/core/storage/btree.rs b/core/storage/btree.rs index 0dfaff8c1..fff63726f 100644 --- a/core/storage/btree.rs +++ b/core/storage/btree.rs @@ -1,7 +1,8 @@ use tracing::{instrument, Level}; use crate::{ - io_yield_many, io_yield_one, + io::CompletionGroup, + io_yield_one, schema::Index, storage::{ pager::{BtreePageAllocMode, Pager}, @@ -23,12 +24,12 @@ use crate::{ RecordCursor, SeekResult, }, util::IOExt, - Completion, MvCursor, Page, + Completion, Page, }; use crate::{ return_corrupt, return_if_io, - types::{compare_immutable, IOResult, ImmutableRecord, RefValue, SeekKey, SeekOp, Value}, + types::{compare_immutable, IOResult, ImmutableRecord, SeekKey, SeekOp, Value, ValueRef}, LimboError, Result, }; @@ -38,8 +39,8 @@ use super::{ write_varint_to_vec, IndexInteriorCell, IndexLeafCell, OverflowCell, MINIMUM_CELL_SIZE, }, }; -use parking_lot::RwLock; use std::{ + any::Any, cell::{Cell, Ref, RefCell}, cmp::{Ordering, Reverse}, collections::{BinaryHeap, HashMap}, @@ -336,7 +337,7 @@ impl BTreeKey<'_> { } /// Get the record, if present. Index will always be present, - fn get_record(&self) -> Option<&'_ ImmutableRecord> { + pub fn get_record(&self) -> Option<&'_ ImmutableRecord> { match self { BTreeKey::TableRowId((_, record)) => *record, BTreeKey::IndexKey(record) => Some(record), @@ -344,7 +345,7 @@ impl BTreeKey<'_> { } /// Get the rowid, if present. Index will never be present. - fn maybe_rowid(&self) -> Option { + pub fn maybe_rowid(&self) -> Option { match self { BTreeKey::TableRowId((rowid, _)) => Some(*rowid), BTreeKey::IndexKey(_) => None, @@ -509,9 +510,55 @@ pub enum CursorSeekState { }, } +pub trait CursorTrait: Any { + /// Move cursor to last entry. + fn last(&mut self) -> Result>; + /// Move cursor to next entry. + fn next(&mut self) -> Result>; + /// Move cursor to previous entry. + fn prev(&mut self) -> Result>; + /// Get the rowid of the entry the cursor is poiting to if any + fn rowid(&self) -> Result>>; + /// Get the record of the entry the cursor is poiting to if any + fn record(&self) -> Result>>>; + /// Move the cursor based on the key and the type of operation (op). + fn seek(&mut self, key: SeekKey<'_>, op: SeekOp) -> Result>; + /// Insert a record in the position the cursor is at. + fn insert(&mut self, key: &BTreeKey) -> Result>; + /// Delete a record in the position the cursor is at. + fn delete(&mut self) -> Result>; + fn set_null_flag(&mut self, flag: bool); + fn get_null_flag(&self) -> bool; + /// Check if a key exists. + fn exists(&mut self, key: &Value) -> Result>; + fn clear_btree(&mut self) -> Result>>; + fn btree_destroy(&mut self) -> Result>>; + /// Count the number of entries in the b-tree + /// + /// Only supposed to be used in the context of a simple Count Select Statement + fn count(&mut self) -> Result>; + fn is_empty(&self) -> bool; + fn root_page(&self) -> i64; + /// Move cursor at the start. + fn rewind(&mut self) -> Result>; + /// Check if cursor is poiting at a valid entry with a record. + fn has_record(&self) -> bool; + fn set_has_record(&self, has_record: bool); + fn get_index_info(&self) -> &IndexInfo; + + fn seek_end(&mut self) -> Result>; + fn seek_to_last(&mut self) -> Result>; + + // --- start: BTreeCursor specific functions ---- + fn invalidate_record(&mut self); + fn has_rowid(&self) -> bool; + fn record_cursor_mut(&self) -> std::cell::RefMut<'_, RecordCursor>; + fn get_pager(&self) -> Arc; + fn get_skip_advance(&self) -> bool; + // --- end: BTreeCursor specific functions ---- +} + pub struct BTreeCursor { - /// The multi-version cursor that is used to read and write to the database file. - mv_cursor: Option>>, /// The pager that is used to read and write to the database file. pub pager: Arc, /// Cached value of the usable space of a BTree page, since it is very expensive to call in a hot loop via pager.usable_space(). @@ -614,20 +661,14 @@ impl BTreeNodeState { } impl BTreeCursor { - pub fn new( - mv_cursor: Option>>, - pager: Arc, - root_page: i64, - num_columns: usize, - ) -> Self { - let valid_state = if root_page == 1 && !pager.db_state.is_initialized() { + pub fn new(pager: Arc, root_page: i64, num_columns: usize) -> Self { + let valid_state = if root_page == 1 && !pager.db_state.get().is_initialized() { CursorValidState::Invalid } else { CursorValidState::Valid }; let usable_space = pager.usable_space(); Self { - mv_cursor, pager, root_page, usable_space_cached: usable_space, @@ -665,34 +706,16 @@ impl BTreeCursor { } } - pub fn new_table( - mv_cursor: Option>>, - pager: Arc, - root_page: i64, - num_columns: usize, - ) -> Self { - Self::new(mv_cursor, pager, root_page, num_columns) + pub fn new_table(pager: Arc, root_page: i64, num_columns: usize) -> Self { + Self::new(pager, root_page, num_columns) } - pub fn new_index( - mv_cursor: Option>>, - pager: Arc, - root_page: i64, - index: &Index, - num_columns: usize, - ) -> Self { - let mut cursor = Self::new(mv_cursor, pager, root_page, num_columns); + pub fn new_index(pager: Arc, root_page: i64, index: &Index, num_columns: usize) -> Self { + let mut cursor = Self::new(pager, root_page, num_columns); cursor.index_info = Some(IndexInfo::new_from_index(index)); cursor } - pub fn has_rowid(&self) -> bool { - match &self.index_info { - Some(index_key_info) => index_key_info.has_rowid, - None => true, // currently we don't support WITHOUT ROWID tables - } - } - pub fn get_index_rowid_from_record(&self) -> Option { if !self.has_rowid() { return None; @@ -705,7 +728,7 @@ impl BTreeCursor { .unwrap() .last_value(record_cursor) { - Some(Ok(RefValue::Integer(rowid))) => rowid, + Some(Ok(ValueRef::Integer(rowid))) => rowid, _ => unreachable!( "index where has_rowid() is true should have an integer rowid as the last value" ), @@ -721,10 +744,6 @@ impl BTreeCursor { let state = self.is_empty_table_state.borrow().clone(); match state { EmptyTableState::Start => { - if let Some(mv_cursor) = &self.mv_cursor { - let mv_cursor = mv_cursor.read(); - return Ok(IOResult::Done(mv_cursor.is_empty())); - } let (page, c) = self.pager.read_page(self.root_page)?; *self.is_empty_table_state.borrow_mut() = EmptyTableState::ReadPage { page }; if let Some(c) = c { @@ -1030,7 +1049,13 @@ impl BTreeCursor { local_amount = local_size as u32 - offset; } if is_write { - self.write_payload_to_page(offset, local_amount, payload, buffer, page.clone()); + self.write_payload_to_page( + offset, + local_amount, + payload, + buffer, + page.clone(), + )?; } else { self.read_payload_from_page(offset, local_amount, payload, buffer); } @@ -1156,7 +1181,7 @@ impl BTreeCursor { page_payload, buffer, page.clone(), - ); + )?; } else { self.read_payload_from_page( payload_offset as u32, @@ -1228,13 +1253,14 @@ impl BTreeCursor { payload: &[u8], buffer: &mut [u8], page: PageRef, - ) { - self.pager.add_dirty(&page); + ) -> Result<()> { + self.pager.add_dirty(&page)?; // SAFETY: This is safe as long as the page is not evicted from the cache. let payload_mut = unsafe { std::slice::from_raw_parts_mut(payload.as_ptr() as *mut u8, payload.len()) }; payload_mut[payload_offset as usize..payload_offset as usize + num_bytes as usize] .copy_from_slice(&buffer[..num_bytes as usize]); + Ok(()) } /// Check if any ancestor pages still have cells to iterate. @@ -1250,17 +1276,7 @@ impl BTreeCursor { /// Used in forwards iteration, which is the default. #[instrument(skip(self), level = Level::DEBUG, name = "next")] pub fn get_next_record(&mut self) -> Result> { - if let Some(mv_cursor) = &self.mv_cursor { - let mut mv_cursor = mv_cursor.write(); - mv_cursor.forward(); - let rowid = mv_cursor.current_row_id(); - match rowid { - Some(_rowid) => { - return Ok(IOResult::Done(true)); - } - None => return Ok(IOResult::Done(false)), - } - } else if self.stack.current_page == -1 { + if self.stack.current_page == -1 { // This can happen in nested left joins. See: // https://github.com/tursodatabase/turso/issues/2924 return Ok(IOResult::Done(false)); @@ -1835,10 +1851,6 @@ impl BTreeCursor { /// of iterating cells in order. #[instrument(skip_all, level = Level::DEBUG)] fn tablebtree_seek(&mut self, rowid: i64, seek_op: SeekOp) -> Result> { - turso_assert!( - self.mv_cursor.is_none(), - "attempting to seek with MV cursor" - ); let iter_dir = seek_op.iteration_direction(); if matches!( @@ -2164,7 +2176,7 @@ impl BTreeCursor { fn compare_with_current_record( &self, - key_values: &[RefValue], + key_values: &[ValueRef], seek_op: SeekOp, record_comparer: &RecordCompare, index_info: &IndexInfo, @@ -2191,10 +2203,6 @@ impl BTreeCursor { #[instrument(skip_all, level = Level::DEBUG)] pub fn move_to(&mut self, key: SeekKey<'_>, cmp: SeekOp) -> Result> { - turso_assert!( - self.mv_cursor.is_none(), - "attempting to move with MV cursor" - ); tracing::trace!(?key, ?cmp); // For a table with N rows, we can find any row by row id in O(log(N)) time by starting at the root page and following the B-tree pointers. // B-trees consist of interior pages and leaf pages. Interior pages contain pointers to other pages, while leaf pages contain the actual row data. @@ -2274,7 +2282,7 @@ impl BTreeCursor { // get page and find cell let cell_idx = { - self.pager.add_dirty(&page); + self.pager.add_dirty(&page)?; self.stack.current_cell_index() }; if cell_idx == -1 { @@ -2642,8 +2650,8 @@ impl BTreeCursor { usable_space, )?; parent_contents.write_rightmost_ptr(new_rightmost_leaf.get().id as u32); - self.pager.add_dirty(parent); - self.pager.add_dirty(&new_rightmost_leaf); + self.pager.add_dirty(parent)?; + self.pager.add_dirty(&new_rightmost_leaf)?; // Continue balance from the parent page (inserting the new divider cell may have overflowed the parent) self.stack.pop(); @@ -2720,7 +2728,7 @@ impl BTreeCursor { overflow_cell.index ); } - self.pager.add_dirty(parent_page); + self.pager.add_dirty(parent_page)?; let parent_contents = parent_page.get_contents(); let page_to_balance_idx = self.stack.current_cell_index() as usize; @@ -2809,23 +2817,22 @@ impl BTreeCursor { // load sibling pages // start loading right page first - let mut pgno: u32 = unsafe { right_pointer.cast::().read().swap_bytes() }; + let mut pgno: u32 = + unsafe { right_pointer.cast::().read_unaligned().swap_bytes() }; let current_sibling = sibling_pointer; - let mut completions: Vec = Vec::with_capacity(current_sibling + 1); + let mut group = CompletionGroup::new(|_| {}); for i in (0..=current_sibling).rev() { match btree_read_page(&self.pager, pgno as i64) { Err(e) => { tracing::error!("error reading page {}: {}", pgno, e); - self.pager.io.cancel(&completions)?; + group.cancel(); self.pager.io.drain()?; return Err(e); } Ok((page, c)) => { - // mark as dirty - self.pager.add_dirty(&page); pages_to_balance[i].replace(page); if let Some(c) = c { - completions.push(c); + group.add(&c); } } } @@ -2892,8 +2899,9 @@ impl BTreeCursor { first_divider_cell: first_cell_divider, }); *sub_state = BalanceSubState::NonRootDoBalancing; - if !completions.is_empty() { - io_yield_many!(completions); + let completion = group.build(); + if !completion.finished() { + io_yield_one!(completion); } } BalanceSubState::NonRootDoBalancing => { @@ -2906,7 +2914,7 @@ impl BTreeCursor { .take(balance_info.sibling_count) { let page = page.as_ref().unwrap(); - turso_assert!(page.is_loaded(), "page should be loaded"); + self.pager.add_dirty(page)?; #[cfg(debug_assertions)] let page_type_of_siblings = balance_info.pages_to_balance[0] @@ -3479,7 +3487,7 @@ impl BTreeCursor { if *new_id != page.get().id { page.get().id = *new_id; self.pager - .update_dirty_loaded_page_in_cache(*new_id, page.clone())?; + .upsert_page_in_cache(*new_id, page.clone(), true)?; } } @@ -4405,118 +4413,471 @@ impl BTreeCursor { self.usable_space_cached } - pub fn seek_end(&mut self) -> Result> { - assert!(self.mv_cursor.is_none()); // unsure about this -_- + /// Clear the overflow pages linked to a specific page provided by the leaf cell + /// Uses a state machine to keep track of it's operations so that traversal can be + /// resumed from last point after IO interruption + #[instrument(skip_all, level = Level::DEBUG)] + fn clear_overflow_pages(&mut self, cell: &BTreeCell) -> Result> { loop { - match self.seek_end_state { - SeekEndState::Start => { - let c = self.move_to_root()?; - self.seek_end_state = SeekEndState::ProcessPage; - if let Some(c) = c { - io_yield_one!(c); + match self.overflow_state.clone() { + OverflowState::Start => { + let first_overflow_page = match cell { + BTreeCell::TableLeafCell(leaf_cell) => leaf_cell.first_overflow_page, + BTreeCell::IndexLeafCell(leaf_cell) => leaf_cell.first_overflow_page, + BTreeCell::IndexInteriorCell(interior_cell) => { + interior_cell.first_overflow_page + } + BTreeCell::TableInteriorCell(_) => return Ok(IOResult::Done(())), // No overflow pages + }; + + if let Some(next_page) = first_overflow_page { + if next_page < 2 + || next_page + > self + .pager + .io + .block(|| { + self.pager.with_header(|header| header.database_size) + })? + .get() + { + self.overflow_state = OverflowState::Start; + return Err(LimboError::Corrupt("Invalid overflow page number".into())); + } + let (page, c) = self.read_page(next_page as i64)?; + self.overflow_state = OverflowState::ProcessPage { next_page: page }; + if let Some(c) = c { + io_yield_one!(c); + } + } else { + self.overflow_state = OverflowState::Done; } } - SeekEndState::ProcessPage => { - let mem_page = self.stack.top_ref(); - let contents = mem_page.get_contents(); - if contents.is_leaf() { - // set cursor just past the last cell to append - self.stack.set_cell_index(contents.cell_count() as i32); - self.seek_end_state = SeekEndState::Start; - return Ok(IOResult::Done(())); + OverflowState::ProcessPage { next_page: page } => { + turso_assert!(page.is_loaded(), "page should be loaded"); + + let contents = page.get_contents(); + let next = contents.read_u32_no_offset(0); + let next_page_id = page.get().id; + + return_if_io!(self.pager.free_page(Some(page), next_page_id)); + + if next != 0 { + if next < 2 + || next + > self + .pager + .io + .block(|| { + self.pager.with_header(|header| header.database_size) + })? + .get() + { + self.overflow_state = OverflowState::Start; + return Err(LimboError::Corrupt("Invalid overflow page number".into())); + } + let (page, c) = self.read_page(next as i64)?; + self.overflow_state = OverflowState::ProcessPage { next_page: page }; + if let Some(c) = c { + io_yield_one!(c); + } + } else { + self.overflow_state = OverflowState::Done; + } + } + OverflowState::Done => { + self.overflow_state = OverflowState::Start; + return Ok(IOResult::Done(())); + } + }; + } + } + + /// Deletes all contents of the B-tree by freeing all its pages in an iterative depth-first order. + /// This ensures child pages are freed before their parents + /// Uses a state machine to keep track of the operation to ensure IO doesn't cause repeated traversals + /// + /// Depending on the caller, the root page may either be freed as well or left allocated but emptied. + /// + /// # Example + /// For a B-tree with this structure (where 4' is an overflow page): + /// ```text + /// 1 (root) + /// / \ + /// 2 3 + /// / \ / \ + /// 4' <- 4 5 6 7 + /// ``` + /// + /// The destruction order would be: [4',4,5,2,6,7,3,1] + fn destroy_btree_contents(&mut self, keep_root: bool) -> Result>> { + if let CursorState::None = &self.state { + let c = self.move_to_root()?; + self.state = CursorState::Destroy(DestroyInfo { + state: DestroyState::Start, + }); + if let Some(c) = c { + io_yield_one!(c); + } + } + + loop { + let destroy_state = { + let destroy_info = self + .state + .destroy_info() + .expect("unable to get a mut reference to destroy state in cursor"); + destroy_info.state.clone() + }; + + match destroy_state { + DestroyState::Start => { + let destroy_info = self + .state + .mut_destroy_info() + .expect("unable to get a mut reference to destroy state in cursor"); + destroy_info.state = DestroyState::LoadPage; + } + DestroyState::LoadPage => { + let _page = self.stack.top_ref(); + + let destroy_info = self + .state + .mut_destroy_info() + .expect("unable to get a mut reference to destroy state in cursor"); + destroy_info.state = DestroyState::ProcessPage; + } + DestroyState::ProcessPage => { + self.stack.advance(); + let page = self.stack.top_ref(); + let contents = page.get_contents(); + let cell_idx = self.stack.current_cell_index(); + + // If we've processed all cells in this page, figure out what to do with this page + if cell_idx >= contents.cell_count() as i32 { + match (contents.is_leaf(), cell_idx) { + // Leaf pages with all cells processed + (true, n) if n >= contents.cell_count() as i32 => { + let destroy_info = self.state.mut_destroy_info().expect( + "unable to get a mut reference to destroy state in cursor", + ); + destroy_info.state = DestroyState::FreePage; + continue; + } + // Non-leaf page which has processed all children but not it's potential right child + (false, n) if n == contents.cell_count() as i32 => { + if let Some(rightmost) = contents.rightmost_pointer() { + let (rightmost_page, c) = self.read_page(rightmost as i64)?; + self.stack.push(rightmost_page); + let destroy_info = self.state.mut_destroy_info().expect( + "unable to get a mut reference to destroy state in cursor", + ); + destroy_info.state = DestroyState::LoadPage; + if let Some(c) = c { + io_yield_one!(c); + } + } else { + let destroy_info = self.state.mut_destroy_info().expect( + "unable to get a mut reference to destroy state in cursor", + ); + destroy_info.state = DestroyState::FreePage; + } + continue; + } + // Non-leaf page which has processed all children and it's right child + (false, n) if n > contents.cell_count() as i32 => { + let destroy_info = self.state.mut_destroy_info().expect( + "unable to get a mut reference to destroy state in cursor", + ); + destroy_info.state = DestroyState::FreePage; + continue; + } + _ => unreachable!("Invalid cell idx state"), + } } - match contents.rightmost_pointer() { - Some(right_most_pointer) => { - self.stack.set_cell_index(contents.cell_count() as i32 + 1); // invalid on interior - let (child, c) = self.read_page(right_most_pointer as i64)?; - self.stack.push(child); + // We have not yet processed all cells in this page + // Get the current cell + let cell = contents.cell_get(cell_idx as usize, self.usable_space())?; + + match contents.is_leaf() { + // For a leaf cell, clear the overflow pages associated with this cell + true => { + let destroy_info = self + .state + .mut_destroy_info() + .expect("unable to get a mut reference to destroy state in cursor"); + destroy_info.state = DestroyState::ClearOverflowPages { cell }; + continue; + } + // For interior cells, check the type of cell to determine what to do + false => match &cell { + // For index interior cells, remove the overflow pages + BTreeCell::IndexInteriorCell(_) => { + let destroy_info = self.state.mut_destroy_info().expect( + "unable to get a mut reference to destroy state in cursor", + ); + destroy_info.state = DestroyState::ClearOverflowPages { cell }; + continue; + } + // For all other interior cells, load the left child page + _ => { + let child_page_id = match &cell { + BTreeCell::TableInteriorCell(cell) => cell.left_child_page, + BTreeCell::IndexInteriorCell(cell) => cell.left_child_page, + _ => panic!("expected interior cell"), + }; + let (child_page, c) = self.read_page(child_page_id as i64)?; + self.stack.push(child_page); + let destroy_info = self.state.mut_destroy_info().expect( + "unable to get a mut reference to destroy state in cursor", + ); + destroy_info.state = DestroyState::LoadPage; + if let Some(c) = c { + io_yield_one!(c); + } + } + }, + } + } + DestroyState::ClearOverflowPages { cell } => { + return_if_io!(self.clear_overflow_pages(&cell)); + match cell { + // For an index interior cell, clear the left child page now that overflow pages have been cleared + BTreeCell::IndexInteriorCell(index_int_cell) => { + let (child_page, c) = + self.read_page(index_int_cell.left_child_page as i64)?; + self.stack.push(child_page); + let destroy_info = self + .state + .mut_destroy_info() + .expect("unable to get a mut reference to destroy state in cursor"); + destroy_info.state = DestroyState::LoadPage; if let Some(c) = c { io_yield_one!(c); } } - None => unreachable!("interior page must have rightmost pointer"), - } - } - } - } - } - - #[instrument(skip_all, level = Level::DEBUG)] - pub fn seek_to_last(&mut self) -> Result> { - loop { - match self.seek_to_last_state { - SeekToLastState::Start => { - assert!(self.mv_cursor.is_none()); - let has_record = return_if_io!(self.move_to_rightmost()); - self.invalidate_record(); - self.has_record.replace(has_record); - if !has_record { - self.seek_to_last_state = SeekToLastState::IsEmpty; - continue; - } - return Ok(IOResult::Done(())); - } - SeekToLastState::IsEmpty => { - let is_empty = return_if_io!(self.is_empty_table()); - assert!(is_empty); - self.seek_to_last_state = SeekToLastState::Start; - return Ok(IOResult::Done(())); - } - } - } - } - - pub fn is_empty(&self) -> bool { - !self.has_record.get() - } - - pub fn root_page(&self) -> i64 { - self.root_page - } - - #[instrument(skip_all, level = Level::DEBUG)] - pub fn rewind(&mut self) -> Result> { - if self.valid_state == CursorValidState::Invalid { - return Ok(IOResult::Done(())); - } - self.skip_advance.set(false); - loop { - match self.rewind_state { - RewindState::Start => { - self.rewind_state = RewindState::NextRecord; - if let Some(mv_cursor) = &self.mv_cursor { - let mut mv_cursor = mv_cursor.write(); - mv_cursor.rewind(); - } else { - let c = self.move_to_root()?; - if let Some(c) = c { - io_yield_one!(c); + // For any leaf cell, advance the index now that overflow pages have been cleared + BTreeCell::TableLeafCell(_) | BTreeCell::IndexLeafCell(_) => { + let destroy_info = self + .state + .mut_destroy_info() + .expect("unable to get a mut reference to destroy state in cursor"); + destroy_info.state = DestroyState::LoadPage; } + _ => panic!("unexpected cell type"), } } - RewindState::NextRecord => { - let cursor_has_record = return_if_io!(self.get_next_record()); - self.invalidate_record(); - self.has_record.replace(cursor_has_record); - self.rewind_state = RewindState::Start; + DestroyState::FreePage => { + let page = self.stack.top(); + let page_id = page.get().id; + + if self.stack.has_parent() { + return_if_io!(self.pager.free_page(Some(page), page_id)); + + self.stack.pop(); + let destroy_info = self + .state + .mut_destroy_info() + .expect("unable to get a mut reference to destroy state in cursor"); + destroy_info.state = DestroyState::ProcessPage; + } else { + if keep_root { + self.clear_root(&page)?; + } else { + return_if_io!(self.pager.free_page(Some(page), page_id)); + } + + self.state = CursorState::None; + // TODO: For now, no-op the result return None always. This will change once [AUTO_VACUUM](https://www.sqlite.org/lang_vacuum.html) is introduced + // At that point, the last root page(call this x) will be moved into the position of the root page of this table and the value returned will be x + return Ok(IOResult::Done(None)); + } + } + } + } + } + + fn clear_root(&mut self, root_page: &PageRef) -> Result<()> { + let page_ref = root_page.get(); + let contents = page_ref.contents.as_ref().unwrap(); + + let page_type = match contents.page_type() { + PageType::TableLeaf | PageType::TableInterior => PageType::TableLeaf, + PageType::IndexLeaf | PageType::IndexInterior => PageType::IndexLeaf, + }; + + self.pager.add_dirty(root_page)?; + btree_init_page(root_page, page_type, 0, self.pager.usable_space()); + Ok(()) + } + + pub fn overwrite_cell( + &mut self, + page: &PageRef, + cell_idx: usize, + record: &ImmutableRecord, + state: &mut OverwriteCellState, + ) -> Result> { + loop { + turso_assert!(page.is_loaded(), "page {} is not loaded", page.get().id); + match state { + OverwriteCellState::AllocatePayload => { + let serial_types_len = self.record_cursor.borrow_mut().len(record); + let new_payload = Vec::with_capacity(serial_types_len); + let rowid = return_if_io!(self.rowid()); + *state = OverwriteCellState::FillPayload { + new_payload, + rowid, + fill_cell_payload_state: FillCellPayloadState::Start, + }; + continue; + } + OverwriteCellState::FillPayload { + new_payload, + rowid, + fill_cell_payload_state, + } => { + { + return_if_io!(fill_cell_payload( + page, + *rowid, + new_payload, + cell_idx, + record, + self.usable_space(), + self.pager.clone(), + fill_cell_payload_state, + )); + } + // figure out old cell offset & size + let (old_offset, old_local_size) = { + let contents = page.get_contents(); + contents.cell_get_raw_region(cell_idx, self.usable_space()) + }; + + *state = OverwriteCellState::ClearOverflowPagesAndOverwrite { + new_payload: new_payload.clone(), + old_offset, + old_local_size, + }; + continue; + } + OverwriteCellState::ClearOverflowPagesAndOverwrite { + new_payload, + old_offset, + old_local_size, + } => { + let contents = page.get_contents(); + let cell = contents.cell_get(cell_idx, self.usable_space())?; + return_if_io!(self.clear_overflow_pages(&cell)); + + // if it all fits in local space and old_local_size is enough, do an in-place overwrite + if new_payload.len() == *old_local_size { + let _res = + BTreeCursor::overwrite_content(page.clone(), *old_offset, new_payload)?; + return Ok(IOResult::Done(())); + } + + drop_cell(contents, cell_idx, self.usable_space())?; + insert_into_cell(contents, new_payload, cell_idx, self.usable_space())?; return Ok(IOResult::Done(())); } } } } - #[instrument(skip_all, level = Level::DEBUG)] - pub fn last(&mut self) -> Result> { - assert!(self.mv_cursor.is_none()); - let cursor_has_record = return_if_io!(self.move_to_rightmost()); - self.has_record.replace(cursor_has_record); - self.invalidate_record(); + pub fn overwrite_content( + page: PageRef, + dest_offset: usize, + new_payload: &[u8], + ) -> Result> { + turso_assert!(page.is_loaded(), "page should be loaded"); + let buf = page.get_contents().as_ptr(); + buf[dest_offset..dest_offset + new_payload.len()].copy_from_slice(new_payload); + Ok(IOResult::Done(())) } + fn get_immutable_record_or_create(&self) -> std::cell::RefMut<'_, Option> { + let mut reusable_immutable_record = self.reusable_immutable_record.borrow_mut(); + if reusable_immutable_record.is_none() { + let page_size = self.pager.get_page_size_unchecked().get(); + let record = ImmutableRecord::new(page_size as usize); + reusable_immutable_record.replace(record); + } + reusable_immutable_record + } + + fn get_immutable_record(&self) -> std::cell::RefMut<'_, Option> { + self.reusable_immutable_record.borrow_mut() + } + + pub fn is_write_in_progress(&self) -> bool { + matches!(self.state, CursorState::Write(_)) + } + + // Save cursor context, to be restored later + pub fn save_context(&mut self, cursor_context: CursorContext) { + self.valid_state = CursorValidState::RequireSeek; + self.context = Some(cursor_context); + } + + /// If context is defined, restore it and set it None on success #[instrument(skip_all, level = Level::DEBUG)] - pub fn next(&mut self) -> Result> { + fn restore_context(&mut self) -> Result> { + if self.context.is_none() || matches!(self.valid_state, CursorValidState::Valid) { + return Ok(IOResult::Done(())); + } + if let CursorValidState::RequireAdvance(direction) = self.valid_state { + let has_record = return_if_io!(match direction { + // Avoid calling next()/prev() directly because they immediately call restore_context() + IterationDirection::Forwards => self.get_next_record(), + IterationDirection::Backwards => self.get_prev_record(), + }); + self.has_record.set(has_record); + self.invalidate_record(); + self.context = None; + self.valid_state = CursorValidState::Valid; + return Ok(IOResult::Done(())); + } + let ctx = self.context.take().unwrap(); + let seek_key = match ctx.key { + CursorContextKey::TableRowId(rowid) => SeekKey::TableRowId(rowid), + CursorContextKey::IndexKeyRowId(ref record) => SeekKey::IndexKey(record), + }; + let res = self.seek(seek_key, ctx.seek_op)?; + match res { + IOResult::Done(res) => { + if let SeekResult::TryAdvance = res { + self.valid_state = + CursorValidState::RequireAdvance(ctx.seek_op.iteration_direction()); + self.context = Some(ctx); + io_yield_one!(Completion::new_yield()); + } + self.valid_state = CursorValidState::Valid; + Ok(IOResult::Done(())) + } + IOResult::IO(io) => { + self.context = Some(ctx); + Ok(IOResult::IO(io)) + } + } + } + + pub fn read_page(&self, page_idx: i64) -> Result<(PageRef, Option)> { + btree_read_page(&self.pager, page_idx) + } + + pub fn allocate_page(&self, page_type: PageType, offset: usize) -> Result> { + self.pager + .do_allocate_page(page_type, offset, BtreePageAllocMode::Any) + } +} + +impl CursorTrait for BTreeCursor { + #[instrument(skip_all, level = Level::DEBUG)] + fn next(&mut self) -> Result> { if self.valid_state == CursorValidState::Invalid { return Ok(IOResult::Done(false)); } @@ -4553,17 +4914,16 @@ impl BTreeCursor { } } - pub fn invalidate_record(&mut self) { - self.get_immutable_record_or_create() - .as_mut() - .unwrap() - .invalidate(); - self.record_cursor.borrow_mut().invalidate(); + #[instrument(skip_all, level = Level::DEBUG)] + fn last(&mut self) -> Result> { + let cursor_has_record = return_if_io!(self.move_to_rightmost()); + self.has_record.replace(cursor_has_record); + self.invalidate_record(); + Ok(IOResult::Done(())) } #[instrument(skip_all, level = Level::DEBUG)] - pub fn prev(&mut self) -> Result> { - assert!(self.mv_cursor.is_none()); + fn prev(&mut self) -> Result> { loop { match self.advance_state { AdvanceState::Start => { @@ -4581,14 +4941,7 @@ impl BTreeCursor { } #[instrument(skip(self), level = Level::DEBUG)] - pub fn rowid(&self) -> Result>> { - if let Some(mv_cursor) = &self.mv_cursor { - let mut mv_cursor = mv_cursor.write(); - let Some(rowid) = mv_cursor.current_row_id() else { - return Ok(IOResult::Done(None)); - }; - return Ok(IOResult::Done(Some(rowid.row_id))); - } + fn rowid(&self) -> Result>> { if self.get_null_flag() { return Ok(IOResult::Done(None)); } @@ -4610,11 +4963,7 @@ impl BTreeCursor { } #[instrument(skip(self, key), level = Level::DEBUG)] - pub fn seek(&mut self, key: SeekKey<'_>, op: SeekOp) -> Result> { - if let Some(mv_cursor) = &self.mv_cursor { - let mut mv_cursor = mv_cursor.write(); - return mv_cursor.seek(key, op); - } + fn seek(&mut self, key: SeekKey<'_>, op: SeekOp) -> Result> { self.skip_advance.set(false); // Empty trace to capture the span information tracing::trace!(""); @@ -4630,12 +4979,9 @@ impl BTreeCursor { Ok(IOResult::Done(seek_result)) } - /// Return a reference to the record the cursor is currently pointing to. - /// If record was not parsed yet, then we have to parse it and in case of I/O we yield control - /// back. #[instrument(skip(self), level = Level::DEBUG)] - pub fn record(&self) -> Result>>> { - if !self.has_record.get() && self.mv_cursor.is_none() { + fn record(&self) -> Result>>> { + if !self.has_record.get() { return Ok(IOResult::Done(None)); } let invalidated = self @@ -4649,25 +4995,6 @@ impl BTreeCursor { .unwrap(); return Ok(IOResult::Done(Some(record_ref))); } - if let Some(mv_cursor) = &self.mv_cursor { - let mut mv_cursor = mv_cursor.write(); - let Some(row) = mv_cursor.current_row()? else { - return Ok(IOResult::Done(None)); - }; - self.get_immutable_record_or_create() - .as_mut() - .unwrap() - .invalidate(); - self.get_immutable_record_or_create() - .as_mut() - .unwrap() - .start_serialization(&row.data); - self.record_cursor.borrow_mut().invalidate(); - let record_ref = - Ref::filter_map(self.reusable_immutable_record.borrow(), |opt| opt.as_ref()) - .unwrap(); - return Ok(IOResult::Done(Some(record_ref))); - } let page = self.stack.top_ref(); let contents = page.get_contents(); @@ -4713,35 +5040,16 @@ impl BTreeCursor { } #[instrument(skip_all, level = Level::DEBUG)] - pub fn insert(&mut self, key: &BTreeKey) -> Result> { + fn insert(&mut self, key: &BTreeKey) -> Result> { tracing::debug!(valid_state = ?self.valid_state, cursor_state = ?self.state, is_write_in_progress = self.is_write_in_progress()); - match &self.mv_cursor { - Some(mv_cursor) => match key.maybe_rowid() { - Some(rowid) => { - let row_id = - crate::mvcc::database::RowID::new(mv_cursor.read().table_id, rowid); - let record_buf = key.get_record().unwrap().get_payload().to_vec(); - let num_columns = match key { - BTreeKey::IndexKey(record) => record.column_count(), - BTreeKey::TableRowId((_, record)) => { - record.as_ref().unwrap().column_count() - } - }; - let row = crate::mvcc::database::Row::new(row_id, record_buf, num_columns); - mv_cursor.write().insert(row)?; - } - None => todo!("Support mvcc inserts with index btrees"), - }, - None => { - return_if_io!(self.insert_into_page(key)); - if key.maybe_rowid().is_some() { - self.has_record.replace(true); - } - } - }; + return_if_io!(self.insert_into_page(key)); + if key.maybe_rowid().is_some() { + self.has_record.replace(true); + } Ok(IOResult::Done(())) } + #[instrument(skip(self), level = Level::DEBUG)] /// Delete state machine flow: /// 1. Start -> check if the rowid to be delete is present in the page or not. If not we early return /// 2. DeterminePostBalancingSeekKey -> determine the key to seek to after balancing. @@ -4754,14 +5062,7 @@ impl BTreeCursor { /// 8. PostInteriorNodeReplacement -> if an interior node was replaced, we need to advance the cursor once. /// 9. SeekAfterBalancing -> adjust the cursor to a node that is closer to the deleted value. go to Finish /// 10. Finish -> Delete operation is done. Return CursorResult(Ok()) - #[instrument(skip(self), level = Level::DEBUG)] - pub fn delete(&mut self) -> Result> { - if let Some(mv_cursor) = &self.mv_cursor { - let rowid = mv_cursor.write().current_row_id().unwrap(); - mv_cursor.write().delete(rowid)?; - return Ok(IOResult::Done(())); - } - + fn delete(&mut self) -> Result> { if let CursorState::None = &self.state { self.state = CursorState::Delete(DeleteState::Start); } @@ -4777,7 +5078,7 @@ impl BTreeCursor { match delete_state { DeleteState::Start => { let page = self.stack.top_ref(); - self.pager.add_dirty(page); + self.pager.add_dirty(page)?; if matches!( page.get_contents().page_type(), PageType::TableLeaf | PageType::TableInterior @@ -4974,8 +5275,8 @@ impl BTreeCursor { let leaf_page = self.stack.top_ref(); - self.pager.add_dirty(page); - self.pager.add_dirty(leaf_page); + self.pager.add_dirty(page)?; + self.pager.add_dirty(leaf_page)?; // Step 2: Replace the cell in the parent (interior) page. { @@ -5120,22 +5421,21 @@ impl BTreeCursor { } } + #[inline(always)] /// In outer joins, whenever the right-side table has no matching row, the query must still return a row /// for each left-side row. In order to achieve this, we set the null flag on the right-side table cursor /// so that it returns NULL for all columns until cleared. - #[inline(always)] - pub fn set_null_flag(&mut self, flag: bool) { + fn set_null_flag(&mut self, flag: bool) { self.null_flag = flag; } #[inline(always)] - pub fn get_null_flag(&self) -> bool { + fn get_null_flag(&self) -> bool { self.null_flag } #[instrument(skip_all, level = Level::DEBUG)] - pub fn exists(&mut self, key: &Value) -> Result> { - assert!(self.mv_cursor.is_none()); + fn exists(&mut self, key: &Value) -> Result> { let int_key = match key { Value::Integer(i) => i, _ => unreachable!("btree tables are indexed by integers!"), @@ -5147,92 +5447,12 @@ impl BTreeCursor { Ok(IOResult::Done(exists)) } - /// Clear the overflow pages linked to a specific page provided by the leaf cell - /// Uses a state machine to keep track of it's operations so that traversal can be - /// resumed from last point after IO interruption - #[instrument(skip_all, level = Level::DEBUG)] - fn clear_overflow_pages(&mut self, cell: &BTreeCell) -> Result> { - loop { - match self.overflow_state.clone() { - OverflowState::Start => { - let first_overflow_page = match cell { - BTreeCell::TableLeafCell(leaf_cell) => leaf_cell.first_overflow_page, - BTreeCell::IndexLeafCell(leaf_cell) => leaf_cell.first_overflow_page, - BTreeCell::IndexInteriorCell(interior_cell) => { - interior_cell.first_overflow_page - } - BTreeCell::TableInteriorCell(_) => return Ok(IOResult::Done(())), // No overflow pages - }; - - if let Some(next_page) = first_overflow_page { - if next_page < 2 - || next_page - > self - .pager - .io - .block(|| { - self.pager.with_header(|header| header.database_size) - })? - .get() - { - self.overflow_state = OverflowState::Start; - return Err(LimboError::Corrupt("Invalid overflow page number".into())); - } - let (page, c) = self.read_page(next_page as i64)?; - self.overflow_state = OverflowState::ProcessPage { next_page: page }; - if let Some(c) = c { - io_yield_one!(c); - } - } else { - self.overflow_state = OverflowState::Done; - } - } - OverflowState::ProcessPage { next_page: page } => { - turso_assert!(page.is_loaded(), "page should be loaded"); - - let contents = page.get_contents(); - let next = contents.read_u32_no_offset(0); - let next_page_id = page.get().id; - - return_if_io!(self.pager.free_page(Some(page), next_page_id)); - - if next != 0 { - if next < 2 - || next - > self - .pager - .io - .block(|| { - self.pager.with_header(|header| header.database_size) - })? - .get() - { - self.overflow_state = OverflowState::Start; - return Err(LimboError::Corrupt("Invalid overflow page number".into())); - } - let (page, c) = self.read_page(next as i64)?; - self.overflow_state = OverflowState::ProcessPage { next_page: page }; - if let Some(c) = c { - io_yield_one!(c); - } - } else { - self.overflow_state = OverflowState::Done; - } - } - OverflowState::Done => { - self.overflow_state = OverflowState::Start; - return Ok(IOResult::Done(())); - } - }; - } - } - /// Deletes all content from the B-Tree but preserves the root page. /// /// Unlike [`btree_destroy`], which frees all pages including the root, /// this method only clears the tree’s contents. The root page remains /// allocated and is reset to an empty leaf page. - pub fn clear_btree(&mut self) -> Result>> { + fn clear_btree(&mut self) -> Result>> { self.destroy_btree_contents(true) } @@ -5243,342 +5463,15 @@ impl BTreeCursor { /// /// For cases where the B-Tree should remain allocated but emptied, see [`btree_clear`]. #[instrument(skip(self), level = Level::DEBUG)] - pub fn btree_destroy(&mut self) -> Result>> { + fn btree_destroy(&mut self) -> Result>> { self.destroy_btree_contents(false) } - /// Deletes all contents of the B-tree by freeing all its pages in an iterative depth-first order. - /// This ensures child pages are freed before their parents - /// Uses a state machine to keep track of the operation to ensure IO doesn't cause repeated traversals - /// - /// Depending on the caller, the root page may either be freed as well or left allocated but emptied. - /// - /// # Example - /// For a B-tree with this structure (where 4' is an overflow page): - /// ```text - /// 1 (root) - /// / \ - /// 2 3 - /// / \ / \ - /// 4' <- 4 5 6 7 - /// ``` - /// - /// The destruction order would be: [4',4,5,2,6,7,3,1] - fn destroy_btree_contents(&mut self, keep_root: bool) -> Result>> { - if let CursorState::None = &self.state { - let c = self.move_to_root()?; - self.state = CursorState::Destroy(DestroyInfo { - state: DestroyState::Start, - }); - if let Some(c) = c { - io_yield_one!(c); - } - } - - loop { - let destroy_state = { - let destroy_info = self - .state - .destroy_info() - .expect("unable to get a mut reference to destroy state in cursor"); - destroy_info.state.clone() - }; - - match destroy_state { - DestroyState::Start => { - let destroy_info = self - .state - .mut_destroy_info() - .expect("unable to get a mut reference to destroy state in cursor"); - destroy_info.state = DestroyState::LoadPage; - } - DestroyState::LoadPage => { - let _page = self.stack.top_ref(); - - let destroy_info = self - .state - .mut_destroy_info() - .expect("unable to get a mut reference to destroy state in cursor"); - destroy_info.state = DestroyState::ProcessPage; - } - DestroyState::ProcessPage => { - self.stack.advance(); - let page = self.stack.top_ref(); - let contents = page.get_contents(); - let cell_idx = self.stack.current_cell_index(); - - // If we've processed all cells in this page, figure out what to do with this page - if cell_idx >= contents.cell_count() as i32 { - match (contents.is_leaf(), cell_idx) { - // Leaf pages with all cells processed - (true, n) if n >= contents.cell_count() as i32 => { - let destroy_info = self.state.mut_destroy_info().expect( - "unable to get a mut reference to destroy state in cursor", - ); - destroy_info.state = DestroyState::FreePage; - continue; - } - // Non-leaf page which has processed all children but not it's potential right child - (false, n) if n == contents.cell_count() as i32 => { - if let Some(rightmost) = contents.rightmost_pointer() { - let (rightmost_page, c) = self.read_page(rightmost as i64)?; - self.stack.push(rightmost_page); - let destroy_info = self.state.mut_destroy_info().expect( - "unable to get a mut reference to destroy state in cursor", - ); - destroy_info.state = DestroyState::LoadPage; - if let Some(c) = c { - io_yield_one!(c); - } - } else { - let destroy_info = self.state.mut_destroy_info().expect( - "unable to get a mut reference to destroy state in cursor", - ); - destroy_info.state = DestroyState::FreePage; - } - continue; - } - // Non-leaf page which has processed all children and it's right child - (false, n) if n > contents.cell_count() as i32 => { - let destroy_info = self.state.mut_destroy_info().expect( - "unable to get a mut reference to destroy state in cursor", - ); - destroy_info.state = DestroyState::FreePage; - continue; - } - _ => unreachable!("Invalid cell idx state"), - } - } - - // We have not yet processed all cells in this page - // Get the current cell - let cell = contents.cell_get(cell_idx as usize, self.usable_space())?; - - match contents.is_leaf() { - // For a leaf cell, clear the overflow pages associated with this cell - true => { - let destroy_info = self - .state - .mut_destroy_info() - .expect("unable to get a mut reference to destroy state in cursor"); - destroy_info.state = DestroyState::ClearOverflowPages { cell }; - continue; - } - // For interior cells, check the type of cell to determine what to do - false => match &cell { - // For index interior cells, remove the overflow pages - BTreeCell::IndexInteriorCell(_) => { - let destroy_info = self.state.mut_destroy_info().expect( - "unable to get a mut reference to destroy state in cursor", - ); - destroy_info.state = DestroyState::ClearOverflowPages { cell }; - continue; - } - // For all other interior cells, load the left child page - _ => { - let child_page_id = match &cell { - BTreeCell::TableInteriorCell(cell) => cell.left_child_page, - BTreeCell::IndexInteriorCell(cell) => cell.left_child_page, - _ => panic!("expected interior cell"), - }; - let (child_page, c) = self.read_page(child_page_id as i64)?; - self.stack.push(child_page); - let destroy_info = self.state.mut_destroy_info().expect( - "unable to get a mut reference to destroy state in cursor", - ); - destroy_info.state = DestroyState::LoadPage; - if let Some(c) = c { - io_yield_one!(c); - } - } - }, - } - } - DestroyState::ClearOverflowPages { cell } => { - return_if_io!(self.clear_overflow_pages(&cell)); - match cell { - // For an index interior cell, clear the left child page now that overflow pages have been cleared - BTreeCell::IndexInteriorCell(index_int_cell) => { - let (child_page, c) = - self.read_page(index_int_cell.left_child_page as i64)?; - self.stack.push(child_page); - let destroy_info = self - .state - .mut_destroy_info() - .expect("unable to get a mut reference to destroy state in cursor"); - destroy_info.state = DestroyState::LoadPage; - if let Some(c) = c { - io_yield_one!(c); - } - } - // For any leaf cell, advance the index now that overflow pages have been cleared - BTreeCell::TableLeafCell(_) | BTreeCell::IndexLeafCell(_) => { - let destroy_info = self - .state - .mut_destroy_info() - .expect("unable to get a mut reference to destroy state in cursor"); - destroy_info.state = DestroyState::LoadPage; - } - _ => panic!("unexpected cell type"), - } - } - DestroyState::FreePage => { - let page = self.stack.top(); - let page_id = page.get().id; - - if self.stack.has_parent() { - return_if_io!(self.pager.free_page(Some(page), page_id)); - - self.stack.pop(); - let destroy_info = self - .state - .mut_destroy_info() - .expect("unable to get a mut reference to destroy state in cursor"); - destroy_info.state = DestroyState::ProcessPage; - } else { - if keep_root { - self.clear_root(&page); - } else { - return_if_io!(self.pager.free_page(Some(page), page_id)); - } - - self.state = CursorState::None; - // TODO: For now, no-op the result return None always. This will change once [AUTO_VACUUM](https://www.sqlite.org/lang_vacuum.html) is introduced - // At that point, the last root page(call this x) will be moved into the position of the root page of this table and the value returned will be x - return Ok(IOResult::Done(None)); - } - } - } - } - } - - fn clear_root(&mut self, root_page: &PageRef) { - let page_ref = root_page.get(); - let contents = page_ref.contents.as_ref().unwrap(); - - let page_type = match contents.page_type() { - PageType::TableLeaf | PageType::TableInterior => PageType::TableLeaf, - PageType::IndexLeaf | PageType::IndexInterior => PageType::IndexLeaf, - }; - - self.pager.add_dirty(root_page); - btree_init_page(root_page, page_type, 0, self.pager.usable_space()); - } - - pub fn overwrite_cell( - &mut self, - page: &PageRef, - cell_idx: usize, - record: &ImmutableRecord, - state: &mut OverwriteCellState, - ) -> Result> { - loop { - turso_assert!(page.is_loaded(), "page {} is not loaded", page.get().id); - match state { - OverwriteCellState::AllocatePayload => { - let serial_types_len = self.record_cursor.borrow_mut().len(record); - let new_payload = Vec::with_capacity(serial_types_len); - let rowid = return_if_io!(self.rowid()); - *state = OverwriteCellState::FillPayload { - new_payload, - rowid, - fill_cell_payload_state: FillCellPayloadState::Start, - }; - continue; - } - OverwriteCellState::FillPayload { - new_payload, - rowid, - fill_cell_payload_state, - } => { - { - return_if_io!(fill_cell_payload( - page, - *rowid, - new_payload, - cell_idx, - record, - self.usable_space(), - self.pager.clone(), - fill_cell_payload_state, - )); - } - // figure out old cell offset & size - let (old_offset, old_local_size) = { - let contents = page.get_contents(); - contents.cell_get_raw_region(cell_idx, self.usable_space()) - }; - - *state = OverwriteCellState::ClearOverflowPagesAndOverwrite { - new_payload: new_payload.clone(), - old_offset, - old_local_size, - }; - continue; - } - OverwriteCellState::ClearOverflowPagesAndOverwrite { - new_payload, - old_offset, - old_local_size, - } => { - let contents = page.get_contents(); - let cell = contents.cell_get(cell_idx, self.usable_space())?; - return_if_io!(self.clear_overflow_pages(&cell)); - - // if it all fits in local space and old_local_size is enough, do an in-place overwrite - if new_payload.len() == *old_local_size { - let _res = - BTreeCursor::overwrite_content(page.clone(), *old_offset, new_payload)?; - return Ok(IOResult::Done(())); - } - - drop_cell(contents, cell_idx, self.usable_space())?; - insert_into_cell(contents, new_payload, cell_idx, self.usable_space())?; - return Ok(IOResult::Done(())); - } - } - } - } - - pub fn overwrite_content( - page: PageRef, - dest_offset: usize, - new_payload: &[u8], - ) -> Result> { - turso_assert!(page.is_loaded(), "page should be loaded"); - let buf = page.get_contents().as_ptr(); - buf[dest_offset..dest_offset + new_payload.len()].copy_from_slice(new_payload); - - Ok(IOResult::Done(())) - } - - fn get_immutable_record_or_create(&self) -> std::cell::RefMut<'_, Option> { - let mut reusable_immutable_record = self.reusable_immutable_record.borrow_mut(); - if reusable_immutable_record.is_none() { - let page_size = self.pager.get_page_size_unchecked().get(); - let record = ImmutableRecord::new(page_size as usize); - reusable_immutable_record.replace(record); - } - reusable_immutable_record - } - - fn get_immutable_record(&self) -> std::cell::RefMut<'_, Option> { - self.reusable_immutable_record.borrow_mut() - } - - pub fn is_write_in_progress(&self) -> bool { - matches!(self.state, CursorState::Write(_)) - } - + #[instrument(skip(self), level = Level::DEBUG)] /// Count the number of entries in the b-tree /// /// Only supposed to be used in the context of a simple Count Select Statement - #[instrument(skip(self), level = Level::DEBUG)] - pub fn count(&mut self) -> Result> { - if let Some(_mv_cursor) = &self.mv_cursor { - todo!("Implement count for mvcc"); - } - + fn count(&mut self) -> Result> { let mut mem_page; let mut contents; @@ -5680,66 +5573,136 @@ impl BTreeCursor { } } } - - // Save cursor context, to be restored later - pub fn save_context(&mut self, cursor_context: CursorContext) { - self.valid_state = CursorValidState::RequireSeek; - self.context = Some(cursor_context); + fn is_empty(&self) -> bool { + !self.has_record.get() + } + + fn root_page(&self) -> i64 { + self.root_page } - /// If context is defined, restore it and set it None on success #[instrument(skip_all, level = Level::DEBUG)] - fn restore_context(&mut self) -> Result> { - if self.context.is_none() || matches!(self.valid_state, CursorValidState::Valid) { + fn rewind(&mut self) -> Result> { + if self.valid_state == CursorValidState::Invalid { return Ok(IOResult::Done(())); } - if let CursorValidState::RequireAdvance(direction) = self.valid_state { - let has_record = return_if_io!(match direction { - // Avoid calling next()/prev() directly because they immediately call restore_context() - IterationDirection::Forwards => self.get_next_record(), - IterationDirection::Backwards => self.get_prev_record(), - }); - self.has_record.set(has_record); - self.invalidate_record(); - self.context = None; - self.valid_state = CursorValidState::Valid; - return Ok(IOResult::Done(())); - } - let ctx = self.context.take().unwrap(); - let seek_key = match ctx.key { - CursorContextKey::TableRowId(rowid) => SeekKey::TableRowId(rowid), - CursorContextKey::IndexKeyRowId(ref record) => SeekKey::IndexKey(record), - }; - let res = self.seek(seek_key, ctx.seek_op)?; - match res { - IOResult::Done(res) => { - if let SeekResult::TryAdvance = res { - self.valid_state = - CursorValidState::RequireAdvance(ctx.seek_op.iteration_direction()); - self.context = Some(ctx); - io_yield_one!(Completion::new_dummy()); + self.skip_advance.set(false); + loop { + match self.rewind_state { + RewindState::Start => { + self.rewind_state = RewindState::NextRecord; + let c = self.move_to_root()?; + if let Some(c) = c { + io_yield_one!(c); + } + } + RewindState::NextRecord => { + let cursor_has_record = return_if_io!(self.get_next_record()); + self.invalidate_record(); + self.has_record.replace(cursor_has_record); + self.rewind_state = RewindState::Start; + return Ok(IOResult::Done(())); } - self.valid_state = CursorValidState::Valid; - Ok(IOResult::Done(())) - } - IOResult::IO(io) => { - self.context = Some(ctx); - Ok(IOResult::IO(io)) } } } - pub fn read_page(&self, page_idx: i64) -> Result<(PageRef, Option)> { - btree_read_page(&self.pager, page_idx) + fn has_rowid(&self) -> bool { + match &self.index_info { + Some(index_key_info) => index_key_info.has_rowid, + None => true, // currently we don't support WITHOUT ROWID tables + } } - pub fn allocate_page(&self, page_type: PageType, offset: usize) -> Result> { - self.pager - .do_allocate_page(page_type, offset, BtreePageAllocMode::Any) + fn invalidate_record(&mut self) { + self.get_immutable_record_or_create() + .as_mut() + .unwrap() + .invalidate(); + self.record_cursor.borrow_mut().invalidate(); + } + fn record_cursor_mut(&self) -> std::cell::RefMut<'_, RecordCursor> { + self.record_cursor.borrow_mut() } - pub fn get_mvcc_cursor(&self) -> Arc> { - self.mv_cursor.as_ref().unwrap().clone() + fn get_pager(&self) -> Arc { + self.pager.clone() + } + + fn get_skip_advance(&self) -> bool { + self.skip_advance.get() + } + + fn has_record(&self) -> bool { + self.has_record.get() + } + + fn set_has_record(&self, has_record: bool) { + self.has_record.set(has_record) + } + + fn get_index_info(&self) -> &IndexInfo { + self.index_info.as_ref().unwrap() + } + + fn seek_end(&mut self) -> Result> { + loop { + match self.seek_end_state { + SeekEndState::Start => { + let c = self.move_to_root()?; + self.seek_end_state = SeekEndState::ProcessPage; + if let Some(c) = c { + io_yield_one!(c); + } + } + SeekEndState::ProcessPage => { + let mem_page = self.stack.top_ref(); + let contents = mem_page.get_contents(); + if contents.is_leaf() { + // set cursor just past the last cell to append + self.stack.set_cell_index(contents.cell_count() as i32); + self.seek_end_state = SeekEndState::Start; + return Ok(IOResult::Done(())); + } + + match contents.rightmost_pointer() { + Some(right_most_pointer) => { + self.stack.set_cell_index(contents.cell_count() as i32 + 1); // invalid on interior + let (child, c) = self.read_page(right_most_pointer as i64)?; + self.stack.push(child); + if let Some(c) = c { + io_yield_one!(c); + } + } + None => unreachable!("interior page must have rightmost pointer"), + } + } + } + } + } + + #[instrument(skip_all, level = Level::DEBUG)] + fn seek_to_last(&mut self) -> Result> { + loop { + match self.seek_to_last_state { + SeekToLastState::Start => { + let has_record = return_if_io!(self.move_to_rightmost()); + self.invalidate_record(); + self.has_record.replace(has_record); + if !has_record { + self.seek_to_last_state = SeekToLastState::IsEmpty; + continue; + } + return Ok(IOResult::Done(())); + } + SeekToLastState::IsEmpty => { + let is_empty = return_if_io!(self.is_empty_table()); + assert!(is_empty); + self.seek_to_last_state = SeekToLastState::Start; + return Ok(IOResult::Done(())); + } + } + } } } @@ -5811,6 +5774,8 @@ pub enum IntegrityCheckError { actual_count: usize, expected_count: usize, }, + #[error("Page {page_id}: never used")] + PageNeverUsed { page_id: i64 }, } #[derive(Debug, Clone, Copy, PartialEq)] @@ -5836,16 +5801,18 @@ struct IntegrityCheckPageEntry { } pub struct IntegrityCheckState { page_stack: Vec, + pub db_size: usize, first_leaf_level: Option, - page_reference: HashMap, + pub page_reference: HashMap, page: Option, pub freelist_count: CheckFreelist, } impl IntegrityCheckState { - pub fn new() -> Self { + pub fn new(db_size: usize) -> Self { Self { page_stack: Vec::new(), + db_size, page_reference: HashMap::new(), first_leaf_level: None, page: None, @@ -7804,7 +7771,7 @@ fn shift_pointers_left(page: &mut PageContent, cell_idx: usize) { #[cfg(test)] mod tests { - use rand::{thread_rng, Rng}; + use rand::{rng, Rng}; use rand_chacha::{ rand_core::{RngCore, SeedableRng}, ChaCha8Rng, @@ -7903,11 +7870,11 @@ mod tests { pos, &record, 4096, - conn.pager.read().clone(), + conn.pager.load().clone(), &mut fill_cell_payload_state, ) }, - &conn.pager.read().clone(), + &conn.pager.load().clone(), ) .unwrap(); insert_into_cell(page.get_contents(), &payload, pos, 4096).unwrap(); @@ -8005,7 +7972,7 @@ mod tests { fn validate_btree(pager: Arc, page_idx: i64) -> (usize, bool) { let num_columns = 5; - let cursor = BTreeCursor::new_table(None, pager.clone(), page_idx, num_columns); + let cursor = BTreeCursor::new_table(pager.clone(), page_idx, num_columns); let (page, _c) = cursor.read_page(page_idx).unwrap(); while page.is_locked() { pager.io.step().unwrap(); @@ -8116,7 +8083,7 @@ mod tests { fn format_btree(pager: Arc, page_idx: i64, depth: usize) -> String { let num_columns = 5; - let cursor = BTreeCursor::new_table(None, pager.clone(), page_idx, num_columns); + let cursor = BTreeCursor::new_table(pager.clone(), page_idx, num_columns); let (page, _c) = cursor.read_page(page_idx).unwrap(); while page.is_locked() { pager.io.step().unwrap(); @@ -8176,14 +8143,14 @@ mod tests { let io: Arc = Arc::new(MemoryIO::new()); let db = Database::open_file(io.clone(), ":memory:", false, false).unwrap(); let conn = db.connect().unwrap(); - let pager = conn.pager.read().clone(); + let pager = conn.pager.load().clone(); // FIXME: handle page cache is full // force allocate page1 with a transaction pager.begin_read_tx().unwrap(); run_until_done(|| pager.begin_write_tx(), &pager).unwrap(); - run_until_done(|| pager.end_tx(false, &conn), &pager).unwrap(); + run_until_done(|| pager.commit_tx(&conn), &pager).unwrap(); let page2 = run_until_done(|| pager.allocate_page(), &pager).unwrap(); btree_init_page(&page2, PageType::TableLeaf, 0, pager.usable_space()); @@ -8196,9 +8163,9 @@ mod tests { let io: Arc = Arc::new(MemoryIO::new()); let db = Database::open_file(io.clone(), ":memory:", false, false).unwrap(); let conn = db.connect().unwrap(); - let pager = conn.pager.read().clone(); + let pager = conn.pager.load().clone(); - let mut cursor = BTreeCursor::new(None, pager, 1, 5); + let mut cursor = BTreeCursor::new(pager, 1, 5); let result = cursor.rewind()?; assert!(matches!(result, IOResult::Done(_))); let result = cursor.next()?; @@ -8228,7 +8195,7 @@ mod tests { let large_record = ImmutableRecord::from_registers(regs, regs.len()); // Create cursor for the table - let mut cursor = BTreeCursor::new_table(None, pager.clone(), root_page, num_columns); + let mut cursor = BTreeCursor::new_table(pager.clone(), root_page, num_columns); let initial_pagecount = pager .io @@ -8380,7 +8347,7 @@ mod tests { let (pager, root_page, _, _) = empty_btree(); let num_columns = 5; - let mut cursor = BTreeCursor::new_table(None, pager.clone(), root_page, num_columns); + let mut cursor = BTreeCursor::new_table(pager.clone(), root_page, num_columns); for (key, size) in sequence.iter() { run_until_done( || { @@ -8446,7 +8413,7 @@ mod tests { for _ in 0..attempts { let (pager, root_page, _db, conn) = empty_btree(); - let mut cursor = BTreeCursor::new_table(None, pager.clone(), root_page, num_columns); + let mut cursor = BTreeCursor::new_table(pager.clone(), root_page, num_columns); let mut keys = SortedVec::new(); tracing::info!("seed: {seed}"); for insert_id in 0..inserts { @@ -8495,7 +8462,7 @@ mod tests { pager.deref(), ) .unwrap(); - pager.io.block(|| pager.end_tx(false, &conn)).unwrap(); + pager.io.block(|| pager.commit_tx(&conn)).unwrap(); pager.begin_read_tx().unwrap(); // FIXME: add sorted vector instead, should be okay for small amounts of keys for now :P, too lazy to fix right now let _c = cursor.move_to_root().unwrap(); @@ -8524,7 +8491,7 @@ mod tests { println!("btree after:\n{btree_after}"); panic!("invalid btree"); } - pager.end_read_tx().unwrap(); + pager.end_read_tx(); } pager.begin_read_tx().unwrap(); tracing::info!( @@ -8546,7 +8513,7 @@ mod tests { "key {key} is not found, got {cursor_rowid}" ); } - pager.end_read_tx().unwrap(); + pager.end_read_tx(); } } @@ -8588,13 +8555,8 @@ mod tests { has_rowid: false, }; let num_columns = index_def.columns.len(); - let mut cursor = BTreeCursor::new_index( - None, - pager.clone(), - index_root_page, - &index_def, - num_columns, - ); + let mut cursor = + BTreeCursor::new_index(pager.clone(), index_root_page, &index_def, num_columns); let mut keys = SortedVec::new(); tracing::info!("seed: {seed}"); for i in 0..inserts { @@ -8641,7 +8603,7 @@ mod tests { if let Some(c) = c { pager.io.wait_for_completion(c).unwrap(); } - pager.io.block(|| pager.end_tx(false, &conn)).unwrap(); + pager.io.block(|| pager.commit_tx(&conn)).unwrap(); } // Check that all keys can be found by seeking @@ -8690,7 +8652,11 @@ mod tests { run_until_done(|| cursor.next(), pager.deref()).unwrap(); let record = run_until_done(|| cursor.record(), &pager).unwrap(); let record = record.as_ref().unwrap(); - let cur = record.get_values().clone(); + let cur = record + .get_values() + .iter() + .map(ValueRef::to_owned) + .collect::>(); if let Some(prev) = prev { if prev >= cur { println!("Seed: {seed}"); @@ -8702,7 +8668,7 @@ mod tests { } prev = Some(cur); } - pager.end_read_tx().unwrap(); + pager.end_read_tx(); } } @@ -8747,8 +8713,7 @@ mod tests { ephemeral: false, has_rowid: false, }; - let mut cursor = - BTreeCursor::new_index(None, pager.clone(), index_root_page, &index_def, 1); + let mut cursor = BTreeCursor::new_index(pager.clone(), index_root_page, &index_def, 1); // Track expected keys that should be present in the tree let mut expected_keys = Vec::new(); @@ -8848,7 +8813,7 @@ mod tests { if let Some(c) = c { pager.io.wait_for_completion(c).unwrap(); } - pager.io.block(|| pager.end_tx(false, &conn)).unwrap(); + pager.io.block(|| pager.commit_tx(&conn)).unwrap(); } // Final validation @@ -8856,7 +8821,7 @@ mod tests { sorted_keys.sort(); validate_expected_keys(&pager, &mut cursor, &sorted_keys, seed); - pager.end_read_tx().unwrap(); + pager.end_read_tx(); } } @@ -8930,16 +8895,12 @@ mod tests { let record = record.as_ref().unwrap(); let cur = record.get_values().clone(); let cur = cur.first().unwrap(); - let RefValue::Blob(ref cur) = cur else { + let ValueRef::Blob(ref cur) = cur else { panic!("expected blob, got {cur:?}"); }; - assert_eq!( - cur.to_slice(), - key, - "key {key:?} is not found, seed: {seed}" - ); + assert_eq!(cur, key, "key {key:?} is not found, seed: {seed}"); } - pager.end_read_tx().unwrap(); + pager.end_read_tx(); } #[test] @@ -9143,7 +9104,7 @@ mod tests { let pager = setup_test_env(5); let num_columns = 5; - let mut cursor = BTreeCursor::new_table(None, pager.clone(), 1, num_columns); + let mut cursor = BTreeCursor::new_table(pager.clone(), 1, num_columns); let max_local = payload_overflow_threshold_max(PageType::TableLeaf, 4096); let usable_size = cursor.usable_space(); @@ -9254,7 +9215,7 @@ mod tests { let pager = setup_test_env(5); let num_columns = 5; - let mut cursor = BTreeCursor::new_table(None, pager.clone(), 1, num_columns); + let mut cursor = BTreeCursor::new_table(pager.clone(), 1, num_columns); let small_payload = vec![b'A'; 10]; @@ -9303,7 +9264,7 @@ mod tests { let pager = setup_test_env(initial_size); let num_columns = 5; - let mut cursor = BTreeCursor::new_table(None, pager.clone(), 2, num_columns); + let mut cursor = BTreeCursor::new_table(pager.clone(), 2, num_columns); // Initialize page 2 as a root page (interior) let root_page = run_until_done( @@ -9396,7 +9357,7 @@ mod tests { let num_columns = 5; let record_count = 10; - let mut cursor = BTreeCursor::new_table(None, pager.clone(), root_page, num_columns); + let mut cursor = BTreeCursor::new_table(pager.clone(), root_page, num_columns); for rowid in 1..=record_count { insert_record(&mut cursor, &pager, rowid, Value::Integer(rowid))?; @@ -9421,7 +9382,7 @@ mod tests { let num_columns = 5; let record_count = 1000; - let mut cursor = BTreeCursor::new_table(None, pager.clone(), root_page, num_columns); + let mut cursor = BTreeCursor::new_table(pager.clone(), root_page, num_columns); for rowid in 1..=record_count { insert_record(&mut cursor, &pager, rowid, Value::Integer(rowid))?; @@ -9447,7 +9408,7 @@ mod tests { let num_columns = 5; let record_count = 1000; - let mut cursor = BTreeCursor::new_table(None, pager.clone(), root_page, num_columns); + let mut cursor = BTreeCursor::new_table(pager.clone(), root_page, num_columns); for rowid in 1..=record_count { insert_record(&mut cursor, &pager, rowid, Value::Integer(rowid))?; @@ -9469,11 +9430,11 @@ mod tests { let exists = run_until_done(|| cursor.next(), &pager)?; assert!(exists, "Record {i} not found"); - let record = run_until_done(|| cursor.record(), &pager)?; - let value = record.unwrap().get_value(0)?; + let record = run_until_done(|| cursor.record(), &pager)?.unwrap(); + let value = record.get_value(0)?; assert_eq!( value, - RefValue::Integer(i), + ValueRef::Integer(i), "Unexpected value for record {i}", ); } @@ -9487,8 +9448,8 @@ mod tests { let num_columns = 5; let record_count = 1000; - let mut cursor1 = BTreeCursor::new_table(None, pager.clone(), root_page, num_columns); - let mut cursor2 = BTreeCursor::new_table(None, pager.clone(), root_page, num_columns); + let mut cursor1 = BTreeCursor::new_table(pager.clone(), root_page, num_columns); + let mut cursor2 = BTreeCursor::new_table(pager.clone(), root_page, num_columns); // Use cursor1 to insert records for rowid in 1..=record_count { @@ -9521,7 +9482,7 @@ mod tests { let num_columns = 5; let record_count = 100; - let mut cursor = BTreeCursor::new_table(None, pager.clone(), root_page, num_columns); + let mut cursor = BTreeCursor::new_table(pager.clone(), root_page, num_columns); let initial_page_count = pager .io @@ -9648,7 +9609,7 @@ mod tests { let mut cells = Vec::new(); let usable_space = 4096; let mut i = 100000; - let seed = thread_rng().gen(); + let seed = rng().random(); tracing::info!("seed {}", seed); let mut rng = ChaCha8Rng::seed_from_u64(seed); while i > 0 { @@ -9671,11 +9632,11 @@ mod tests { cell_idx, &record, 4096, - conn.pager.read().clone(), + conn.pager.load().clone(), &mut fill_cell_payload_state, ) }, - &conn.pager.read().clone(), + &conn.pager.load().clone(), ) .unwrap(); if (free as usize) < payload.len() + 2 { @@ -9753,11 +9714,11 @@ mod tests { cell_idx, &record, 4096, - conn.pager.read().clone(), + conn.pager.load().clone(), &mut fill_cell_payload_state, ) }, - &conn.pager.read().clone(), + &conn.pager.load().clone(), ) .unwrap(); if (free as usize) < payload.len() - 2 { @@ -10126,11 +10087,11 @@ mod tests { 0, &record, 4096, - conn.pager.read().clone(), + conn.pager.load().clone(), &mut fill_cell_payload_state, ) }, - &conn.pager.read().clone(), + &conn.pager.load().clone(), ) .unwrap(); @@ -10157,7 +10118,7 @@ mod tests { let num_columns = 5; for i in 0..10000 { - let mut cursor = BTreeCursor::new_table(None, pager.clone(), root_page, num_columns); + let mut cursor = BTreeCursor::new_table(pager.clone(), root_page, num_columns); tracing::info!("INSERT INTO t VALUES ({});", i,); let regs = &[Register::Value(Value::Integer(i))]; let value = ImmutableRecord::from_registers(regs, regs.len()); @@ -10185,7 +10146,7 @@ mod tests { format_btree(pager.clone(), root_page, 0) ); for key in keys.iter() { - let mut cursor = BTreeCursor::new_table(None, pager.clone(), root_page, num_columns); + let mut cursor = BTreeCursor::new_table(pager.clone(), root_page, num_columns); let key = Value::Integer(*key); let exists = run_until_done(|| cursor.exists(&key), pager.deref()).unwrap(); assert!(exists, "key not found {key}"); @@ -10212,11 +10173,11 @@ mod tests { 0, &record, 4096, - conn.pager.read().clone(), + conn.pager.load().clone(), &mut fill_cell_payload_state, ) }, - &conn.pager.read().clone(), + &conn.pager.load().clone(), ) .unwrap(); insert_into_cell(page.get_contents(), &payload, 0, 4096).unwrap(); @@ -10244,7 +10205,7 @@ mod tests { // Insert 10,000 records in to the BTree. for i in 1..=10000 { - let mut cursor = BTreeCursor::new_table(None, pager.clone(), root_page, num_columns); + let mut cursor = BTreeCursor::new_table(pager.clone(), root_page, num_columns); let regs = &[Register::Value(Value::Text(Text::new("hello world")))]; let value = ImmutableRecord::from_registers(regs, regs.len()); @@ -10271,7 +10232,7 @@ mod tests { // Delete records with 500 <= key <= 3500 for i in 500..=3500 { - let mut cursor = BTreeCursor::new_table(None, pager.clone(), root_page, num_columns); + let mut cursor = BTreeCursor::new_table(pager.clone(), root_page, num_columns); let seek_key = SeekKey::TableRowId(i); let seek_result = run_until_done( @@ -10291,7 +10252,7 @@ mod tests { continue; } - let mut cursor = BTreeCursor::new_table(None, pager.clone(), root_page, num_columns); + let mut cursor = BTreeCursor::new_table(pager.clone(), root_page, num_columns); let key = Value::Integer(i); let exists = run_until_done(|| cursor.exists(&key), pager.deref()).unwrap(); assert!(exists, "Key {i} should exist but doesn't"); @@ -10299,7 +10260,7 @@ mod tests { // Verify the deleted records don't exist. for i in 500..=3500 { - let mut cursor = BTreeCursor::new_table(None, pager.clone(), root_page, num_columns); + let mut cursor = BTreeCursor::new_table(pager.clone(), root_page, num_columns); let key = Value::Integer(i); let exists = run_until_done(|| cursor.exists(&key), pager.deref()).unwrap(); assert!(!exists, "Deleted key {i} still exists"); @@ -10322,7 +10283,7 @@ mod tests { let num_columns = 5; for (i, huge_text) in huge_texts.iter().enumerate().take(iterations) { - let mut cursor = BTreeCursor::new_table(None, pager.clone(), root_page, num_columns); + let mut cursor = BTreeCursor::new_table(pager.clone(), root_page, num_columns); tracing::info!("INSERT INTO t VALUES ({});", i,); let regs = &[Register::Value(Value::Text(Text { value: huge_text.as_bytes().to_vec(), @@ -10352,7 +10313,7 @@ mod tests { format_btree(pager.clone(), root_page, 0) ); } - let mut cursor = BTreeCursor::new_table(None, pager.clone(), root_page, num_columns); + let mut cursor = BTreeCursor::new_table(pager.clone(), root_page, num_columns); let _c = cursor.move_to_root().unwrap(); for i in 0..iterations { let has_next = run_until_done(|| cursor.next(), pager.deref()).unwrap(); @@ -10370,7 +10331,7 @@ mod tests { pub fn test_read_write_payload_with_offset() { let (pager, root_page, _, _) = empty_btree(); let num_columns = 5; - let mut cursor = BTreeCursor::new(None, pager.clone(), root_page, num_columns); + let mut cursor = BTreeCursor::new(pager.clone(), root_page, num_columns); let offset = 2; // blobs data starts at offset 2 let initial_text = "hello world"; let initial_blob = initial_text.as_bytes().to_vec(); @@ -10447,7 +10408,7 @@ mod tests { pub fn test_read_write_payload_with_overflow_page() { let (pager, root_page, _, _) = empty_btree(); let num_columns = 5; - let mut cursor = BTreeCursor::new(None, pager.clone(), root_page, num_columns); + let mut cursor = BTreeCursor::new(pager.clone(), root_page, num_columns); let mut large_blob = vec![b'A'; 40960 - 11]; // insert large blob. 40960 = 10 page long. let hello_world = b"hello world"; large_blob.extend_from_slice(hello_world); diff --git a/core/storage/buffer_pool.rs b/core/storage/buffer_pool.rs index 05fe77d6a..6bac73f71 100644 --- a/core/storage/buffer_pool.rs +++ b/core/storage/buffer_pool.rs @@ -239,13 +239,13 @@ impl PoolInner { .wal_frame_arena .as_ref() .and_then(|wal_arena| Arena::try_alloc(wal_arena, len)) - .unwrap_or(Buffer::new_temporary(len)); + .unwrap_or_else(|| Buffer::new_temporary(len)); } // For all other sizes, use regular arena self.page_arena .as_ref() .and_then(|arena| Arena::try_alloc(arena, len)) - .unwrap_or(Buffer::new_temporary(len)) + .unwrap_or_else(|| Buffer::new_temporary(len)) } fn get_db_page_buffer(&mut self) -> Buffer { @@ -253,7 +253,7 @@ impl PoolInner { self.page_arena .as_ref() .and_then(|arena| Arena::try_alloc(arena, db_page_size)) - .unwrap_or(Buffer::new_temporary(db_page_size)) + .unwrap_or_else(|| Buffer::new_temporary(db_page_size)) } fn get_wal_frame_buffer(&mut self) -> Buffer { @@ -261,7 +261,7 @@ impl PoolInner { self.wal_frame_arena .as_ref() .and_then(|wal_arena| Arena::try_alloc(wal_arena, len)) - .unwrap_or(Buffer::new_temporary(len)) + .unwrap_or_else(|| Buffer::new_temporary(len)) } /// Allocate a new arena for the pool to use @@ -427,11 +427,8 @@ impl Arena { } } -#[cfg(unix)] +#[cfg(all(unix, not(miri)))] mod arena { - #[cfg(target_vendor = "apple")] - use libc::MAP_ANON as MAP_ANONYMOUS; - #[cfg(target_os = "linux")] use libc::MAP_ANONYMOUS; use libc::{mmap, munmap, MAP_PRIVATE, PROT_READ, PROT_WRITE}; use std::ffi::c_void; @@ -463,11 +460,11 @@ mod arena { } } -#[cfg(not(unix))] +#[cfg(any(not(unix), miri))] mod arena { pub fn alloc(len: usize) -> *mut u8 { let layout = std::alloc::Layout::from_size_align(len, std::mem::size_of::()).unwrap(); - unsafe { std::alloc::alloc(layout) } + unsafe { std::alloc::alloc_zeroed(layout) } } pub fn dealloc(ptr: *mut u8, len: usize) { let layout = std::alloc::Layout::from_size_align(len, std::mem::size_of::()).unwrap(); diff --git a/core/storage/checksum.rs b/core/storage/checksum.rs index a67376048..43bd0c5be 100644 --- a/core/storage/checksum.rs +++ b/core/storage/checksum.rs @@ -94,6 +94,7 @@ impl Default for ChecksumContext { } #[cfg(test)] +#[cfg(feature = "checksum")] mod tests { use super::*; @@ -110,7 +111,6 @@ mod tests { } #[test] - #[cfg(feature = "checksum")] fn test_add_checksum_to_page() { let ctx = ChecksumContext::new(); let mut page = get_random_page(); @@ -128,7 +128,7 @@ mod tests { } #[test] - fn test_verify_and_strip_checksum_valid() { + fn test_verify_checksum_valid() { let ctx = ChecksumContext::new(); let mut page = get_random_page(); @@ -139,8 +139,7 @@ mod tests { } #[test] - #[cfg(feature = "checksum")] - fn test_verify_and_strip_checksum_mismatch() { + fn test_verify_checksum_mismatch() { let ctx = ChecksumContext::new(); let mut page = get_random_page(); @@ -165,8 +164,7 @@ mod tests { } #[test] - #[cfg(feature = "checksum")] - fn test_verify_and_strip_checksum_corrupted_checksum() { + fn test_verify_checksum_corrupted_checksum() { let ctx = ChecksumContext::new(); let mut page = get_random_page(); diff --git a/core/storage/encryption.rs b/core/storage/encryption.rs index 3a1bb97b9..f0184dbc0 100644 --- a/core/storage/encryption.rs +++ b/core/storage/encryption.rs @@ -10,7 +10,7 @@ use aes_gcm::{ aead::{Aead, AeadCore, KeyInit, OsRng}, Aes128Gcm, Aes256Gcm, Key, Nonce, }; -use turso_macros::match_ignore_ascii_case; +use turso_macros::{match_ignore_ascii_case, AtomicEnum}; /// Encryption Scheme /// We support two major algorithms: AEGIS, AES GCM. These algorithms picked so that they also do @@ -319,8 +319,9 @@ define_aegis_cipher!( "AEGIS-128X4" ); -#[derive(Debug, Clone, Copy, PartialEq)] +#[derive(Debug, AtomicEnum, Clone, Copy, PartialEq)] pub enum CipherMode { + None, Aes128Gcm, Aes256Gcm, Aegis256, @@ -363,6 +364,7 @@ impl std::fmt::Display for CipherMode { CipherMode::Aegis128X4 => write!(f, "aegis128x4"), CipherMode::Aegis256X2 => write!(f, "aegis256x2"), CipherMode::Aegis256X4 => write!(f, "aegis256x4"), + CipherMode::None => write!(f, "None"), } } } @@ -380,6 +382,7 @@ impl CipherMode { CipherMode::Aegis128L => 16, CipherMode::Aegis128X2 => 16, CipherMode::Aegis128X4 => 16, + CipherMode::None => 0, } } @@ -394,6 +397,7 @@ impl CipherMode { CipherMode::Aegis128L => 16, CipherMode::Aegis128X2 => 16, CipherMode::Aegis128X4 => 16, + CipherMode::None => 0, } } @@ -408,6 +412,7 @@ impl CipherMode { CipherMode::Aegis128L => 16, CipherMode::Aegis128X2 => 16, CipherMode::Aegis128X4 => 16, + CipherMode::None => 0, } } @@ -427,6 +432,7 @@ impl CipherMode { CipherMode::Aegis128L => 6, CipherMode::Aegis128X2 => 7, CipherMode::Aegis128X4 => 8, + CipherMode::None => 0, } } @@ -503,6 +509,11 @@ impl EncryptionContext { CipherMode::Aegis128L => Cipher::Aegis128L(Box::new(Aegis128LCipher::new(key))), CipherMode::Aegis128X2 => Cipher::Aegis128X2(Box::new(Aegis128X2Cipher::new(key))), CipherMode::Aegis128X4 => Cipher::Aegis128X4(Box::new(Aegis128X4Cipher::new(key))), + CipherMode::None => { + return Err(LimboError::InvalidArgument( + "must select valid CipherMode".into(), + )) + } }; Ok(Self { cipher_mode, @@ -979,14 +990,14 @@ mod tests { } fn generate_random_hex_key() -> String { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let mut bytes = [0u8; 32]; rng.fill(&mut bytes); hex::encode(bytes) } fn generate_random_hex_key_128() -> String { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let mut bytes = [0u8; 16]; rng.fill(&mut bytes); hex::encode(bytes) @@ -995,7 +1006,7 @@ mod tests { fn create_test_page_1() -> Vec { let mut page = vec![0u8; DEFAULT_ENCRYPTED_PAGE_SIZE]; page[..SQLITE_HEADER.len()].copy_from_slice(SQLITE_HEADER); - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); // 48 is the max reserved bytes we might need for metadata with any cipher rng.fill(&mut page[SQLITE_HEADER.len()..DEFAULT_ENCRYPTED_PAGE_SIZE - 48]); page @@ -1135,7 +1146,7 @@ mod tests { #[test] fn test_aes128gcm_encrypt_decrypt_round_trip() { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let cipher_mode = CipherMode::Aes128Gcm; let metadata_size = cipher_mode.metadata_size(); let data_size = DEFAULT_ENCRYPTED_PAGE_SIZE - metadata_size; @@ -1144,7 +1155,7 @@ mod tests { let mut page = vec![0u8; DEFAULT_ENCRYPTED_PAGE_SIZE]; page.iter_mut() .take(data_size) - .for_each(|byte| *byte = rng.gen()); + .for_each(|byte| *byte = rng.random()); page }; @@ -1165,7 +1176,7 @@ mod tests { #[test] fn test_aes_encrypt_decrypt_round_trip() { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let cipher_mode = CipherMode::Aes256Gcm; let metadata_size = cipher_mode.metadata_size(); let data_size = DEFAULT_ENCRYPTED_PAGE_SIZE - metadata_size; @@ -1174,7 +1185,7 @@ mod tests { let mut page = vec![0u8; DEFAULT_ENCRYPTED_PAGE_SIZE]; page.iter_mut() .take(data_size) - .for_each(|byte| *byte = rng.gen()); + .for_each(|byte| *byte = rng.random()); page }; @@ -1211,7 +1222,7 @@ mod tests { #[test] fn test_aegis256_encrypt_decrypt_round_trip() { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let cipher_mode = CipherMode::Aegis256; let metadata_size = cipher_mode.metadata_size(); let data_size = DEFAULT_ENCRYPTED_PAGE_SIZE - metadata_size; @@ -1220,7 +1231,7 @@ mod tests { let mut page = vec![0u8; DEFAULT_ENCRYPTED_PAGE_SIZE]; page.iter_mut() .take(data_size) - .for_each(|byte| *byte = rng.gen()); + .for_each(|byte| *byte = rng.random()); page }; @@ -1256,7 +1267,7 @@ mod tests { #[test] fn test_aegis128x2_encrypt_decrypt_round_trip() { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let cipher_mode = CipherMode::Aegis128X2; let metadata_size = cipher_mode.metadata_size(); let data_size = DEFAULT_ENCRYPTED_PAGE_SIZE - metadata_size; @@ -1265,7 +1276,7 @@ mod tests { let mut page = vec![0u8; DEFAULT_ENCRYPTED_PAGE_SIZE]; page.iter_mut() .take(data_size) - .for_each(|byte| *byte = rng.gen()); + .for_each(|byte| *byte = rng.random()); page }; @@ -1301,7 +1312,7 @@ mod tests { #[test] fn test_aegis128l_encrypt_decrypt_round_trip() { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let cipher_mode = CipherMode::Aegis128L; let metadata_size = cipher_mode.metadata_size(); let data_size = DEFAULT_ENCRYPTED_PAGE_SIZE - metadata_size; @@ -1310,7 +1321,7 @@ mod tests { let mut page = vec![0u8; DEFAULT_ENCRYPTED_PAGE_SIZE]; page.iter_mut() .take(data_size) - .for_each(|byte| *byte = rng.gen()); + .for_each(|byte| *byte = rng.random()); page }; @@ -1346,7 +1357,7 @@ mod tests { #[test] fn test_aegis128x4_encrypt_decrypt_round_trip() { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let cipher_mode = CipherMode::Aegis128X4; let metadata_size = cipher_mode.metadata_size(); let data_size = DEFAULT_ENCRYPTED_PAGE_SIZE - metadata_size; @@ -1355,7 +1366,7 @@ mod tests { let mut page = vec![0u8; DEFAULT_ENCRYPTED_PAGE_SIZE]; page.iter_mut() .take(data_size) - .for_each(|byte| *byte = rng.gen()); + .for_each(|byte| *byte = rng.random()); page }; @@ -1391,7 +1402,7 @@ mod tests { #[test] fn test_aegis256x2_encrypt_decrypt_round_trip() { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let cipher_mode = CipherMode::Aegis256X2; let metadata_size = cipher_mode.metadata_size(); let data_size = DEFAULT_ENCRYPTED_PAGE_SIZE - metadata_size; @@ -1400,7 +1411,7 @@ mod tests { let mut page = vec![0u8; DEFAULT_ENCRYPTED_PAGE_SIZE]; page.iter_mut() .take(data_size) - .for_each(|byte| *byte = rng.gen()); + .for_each(|byte| *byte = rng.random()); page }; @@ -1436,7 +1447,7 @@ mod tests { #[test] fn test_aegis256x4_encrypt_decrypt_round_trip() { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let cipher_mode = CipherMode::Aegis256X4; let metadata_size = cipher_mode.metadata_size(); let data_size = DEFAULT_ENCRYPTED_PAGE_SIZE - metadata_size; @@ -1445,7 +1456,7 @@ mod tests { let mut page = vec![0u8; DEFAULT_ENCRYPTED_PAGE_SIZE]; page.iter_mut() .take(data_size) - .for_each(|byte| *byte = rng.gen()); + .for_each(|byte| *byte = rng.random()); page }; diff --git a/core/storage/mod.rs b/core/storage/mod.rs index ee5029ba0..f64d26385 100644 --- a/core/storage/mod.rs +++ b/core/storage/mod.rs @@ -22,6 +22,7 @@ pub(crate) mod pager; pub(super) mod slot_bitmap; pub(crate) mod sqlite3_ondisk; mod state_machines; +pub(crate) mod subjournal; #[allow(clippy::arc_with_non_send_sync)] pub(crate) mod wal; diff --git a/core/storage/page_cache.rs b/core/storage/page_cache.rs index 396332eb2..25bfe357a 100644 --- a/core/storage/page_cache.rs +++ b/core/storage/page_cache.rs @@ -425,11 +425,11 @@ impl PageCache { Err(CacheError::Full) } - pub fn clear(&mut self) -> Result<(), CacheError> { + pub fn clear(&mut self, clear_dirty: bool) -> Result<(), CacheError> { // Check all pages are clean for &entry_ptr in self.map.values() { let entry = unsafe { &*entry_ptr }; - if entry.page.is_dirty() { + if entry.page.is_dirty() && !clear_dirty { return Err(CacheError::Dirty { pgno: entry.page.get().id, }); @@ -852,7 +852,7 @@ mod tests { let key1 = insert_page(&mut cache, 1); let key2 = insert_page(&mut cache, 2); - assert!(cache.clear().is_ok()); + assert!(cache.clear(false).is_ok()); assert!(cache.get(&key1).unwrap().is_none()); assert!(cache.get(&key2).unwrap().is_none()); assert_eq!(cache.len(), 0); @@ -1141,7 +1141,7 @@ mod tests { cache.insert(key, page).unwrap(); } - cache.clear().unwrap(); + cache.clear(false).unwrap(); drop(cache); } @@ -1231,7 +1231,7 @@ mod tests { for i in 1..=3 { let _ = insert_page(&mut c, i); } - c.clear().unwrap(); + c.clear(false).unwrap(); // No elements; insert should not rely on stale hand let _ = insert_page(&mut c, 10); diff --git a/core/storage/pager.rs b/core/storage/pager.rs index 3dd8183a8..75d4a47db 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -1,4 +1,5 @@ use crate::storage::database::DatabaseFile; +use crate::storage::subjournal::Subjournal; use crate::storage::wal::IOV_MAX; use crate::storage::{ buffer_pool::BufferPool, @@ -10,12 +11,13 @@ use crate::storage::{ }; use crate::types::{IOCompletions, WalState}; use crate::util::IOExt as _; -use crate::{io_yield_many, io_yield_one, IOContext}; use crate::{ - return_if_io, turso_assert, types::WalFrameInfo, Completion, Connection, IOResult, LimboError, - Result, TransactionState, + io::CompletionGroup, return_if_io, turso_assert, types::WalFrameInfo, Completion, Connection, + IOResult, LimboError, Result, TransactionState, }; +use crate::{io_yield_one, CompletionError, IOContext, OpenFlags, IO}; use parking_lot::RwLock; +use roaring::RoaringBitmap; use std::cell::{RefCell, UnsafeCell}; use std::collections::HashSet; use std::hash; @@ -25,6 +27,7 @@ use std::sync::atomic::{ }; use std::sync::{Arc, Mutex}; use tracing::{instrument, trace, Level}; +use turso_macros::AtomicEnum; use super::btree::btree_init_page; use super::page_cache::{CacheError, CacheResizeResult, PageCache, PageCacheKey}; @@ -57,7 +60,7 @@ impl HeaderRef { tracing::trace!("HeaderRef::from_pager - {:?}", state); match state { HeaderRefState::Start => { - if !pager.db_state.is_initialized() { + if !pager.db_state.get().is_initialized() { return Err(LimboError::Page1NotAlloc); } @@ -97,7 +100,7 @@ impl HeaderRefMut { tracing::trace!(?state); match state { HeaderRefState::Start => { - if !pager.db_state.is_initialized() { + if !pager.db_state.get().is_initialized() { return Err(LimboError::Page1NotAlloc); } @@ -114,7 +117,7 @@ impl HeaderRefMut { "incorrect header page id" ); - pager.add_dirty(&page); + pager.add_dirty(&page)?; *pager.header_ref_state.write() = HeaderRefState::Start; break Ok(IOResult::Done(Self(page))); } @@ -416,57 +419,19 @@ impl From for AutoVacuumMode { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[repr(usize)] +#[derive(Debug, AtomicEnum, Clone, Copy, PartialEq, Eq)] pub enum DbState { - Uninitialized = Self::UNINITIALIZED, - Initializing = Self::INITIALIZING, - Initialized = Self::INITIALIZED, + Uninitialized, + Initializing, + Initialized, } impl DbState { - pub(self) const UNINITIALIZED: usize = 0; - pub(self) const INITIALIZING: usize = 1; - pub(self) const INITIALIZED: usize = 2; - - #[inline] pub fn is_initialized(&self) -> bool { matches!(self, DbState::Initialized) } } -#[derive(Debug)] -#[repr(transparent)] -pub struct AtomicDbState(AtomicUsize); - -impl AtomicDbState { - #[inline] - pub const fn new(state: DbState) -> Self { - Self(AtomicUsize::new(state as usize)) - } - - #[inline] - pub fn set(&self, state: DbState) { - self.0.store(state as usize, Ordering::Release); - } - - #[inline] - pub fn get(&self) -> DbState { - let v = self.0.load(Ordering::Acquire); - match v { - DbState::UNINITIALIZED => DbState::Uninitialized, - DbState::INITIALIZING => DbState::Initializing, - DbState::INITIALIZED => DbState::Initialized, - _ => unreachable!(), - } - } - - #[inline] - pub fn is_initialized(&self) -> bool { - self.get().is_initialized() - } -} - #[derive(Debug, Clone)] #[cfg(not(feature = "omit_autovacuum"))] enum PtrMapGetState { @@ -501,6 +466,38 @@ enum BtreeCreateVacuumFullState { PtrMapPut { allocated_page_id: u32 }, } +pub struct Savepoint { + /// Start offset of this savepoint in the subjournal. + start_offset: AtomicU64, + /// Current write offset in the subjournal. + write_offset: AtomicU64, + /// Bitmap of page numbers that are dirty in the savepoint. + page_bitmap: RwLock, + /// Database size at the start of the savepoint. + /// If the database grows during the savepoint and a rollback to the savepoint is performed, + /// the pages exceeding the database size at the start of the savepoint will be ignored. + db_size: AtomicU32, +} + +impl Savepoint { + pub fn new(subjournal_offset: u64, db_size: u32) -> Self { + Self { + start_offset: AtomicU64::new(subjournal_offset), + write_offset: AtomicU64::new(subjournal_offset), + page_bitmap: RwLock::new(RoaringBitmap::new()), + db_size: AtomicU32::new(db_size), + } + } + + pub fn add_dirty_page(&self, page_num: u32) { + self.page_bitmap.write().insert(page_num); + } + + pub fn has_dirty_page(&self, page_num: u32) -> bool { + self.page_bitmap.read().contains(page_num) + } +} + /// The pager interface implements the persistence layer by providing access /// to pages of the database file, including caching, concurrency control, and /// transaction management. @@ -517,7 +514,8 @@ pub struct Pager { /// I/O interface for input/output operations. pub io: Arc, dirty_pages: Arc>>>, - + subjournal: RwLock>, + savepoints: Arc>>, commit_info: RwLock, checkpoint_state: RwLock, syncing: Arc, @@ -537,6 +535,11 @@ pub struct Pager { /// to change it. pub(crate) page_size: AtomicU32, reserved_space: AtomicU16, + /// Schema cookie cache. + /// + /// Note that schema cookie is 32-bits, but we use 64-bit field so we can + /// represent case where value is not set. + schema_cookie: AtomicU64, free_page_state: RwLock, /// Maximum number of pages allowed in the database. Default is 1073741823 (SQLite default). max_page_count: AtomicU32, @@ -548,6 +551,11 @@ pub struct Pager { enable_encryption: AtomicBool, } +// SAFETY: This needs to be audited for thread safety. +// See: https://github.com/tursodatabase/turso/issues/1552 +unsafe impl Send for Pager {} +unsafe impl Sync for Pager {} + #[cfg(not(feature = "omit_autovacuum"))] pub struct VacuumState { /// State machine for [Pager::ptrmap_get] @@ -616,7 +624,7 @@ impl Pager { db_state: Arc, init_lock: Arc>, ) -> Result { - let allocate_page1_state = if !db_state.is_initialized() { + let allocate_page1_state = if !db_state.get().is_initialized() { RwLock::new(AllocatePage1State::Start) } else { RwLock::new(AllocatePage1State::Done) @@ -630,6 +638,8 @@ impl Pager { dirty_pages: Arc::new(RwLock::new(HashSet::with_hasher( hash::BuildHasherDefault::new(), ))), + subjournal: RwLock::new(None), + savepoints: Arc::new(RwLock::new(Vec::new())), commit_info: RwLock::new(CommitInfo { result: None, completions: Vec::new(), @@ -645,6 +655,7 @@ impl Pager { allocate_page1_state, page_size: AtomicU32::new(0), // 0 means not set reserved_space: AtomicU16::new(RESERVED_SPACE_NOT_SET), + schema_cookie: AtomicU64::new(Self::SCHEMA_COOKIE_NOT_SET), free_page_state: RwLock::new(FreePageState::Start), allocate_page_state: RwLock::new(AllocatePageState::Start), max_page_count: AtomicU32::new(DEFAULT_MAX_PAGE_COUNT), @@ -660,6 +671,239 @@ impl Pager { }) } + pub fn begin_statement(&self, db_size: u32) -> Result<()> { + self.open_subjournal()?; + self.open_savepoint(db_size)?; + Ok(()) + } + + /// Open the subjournal if not yet open. + /// The subjournal is a file that is used to store the "before images" of pages for the + /// current savepoint. If the savepoint is rolled back, the pages can be restored from the subjournal. + /// + /// Currently uses MemoryIO, but should eventually be backed by temporary on-disk files. + pub fn open_subjournal(&self) -> Result<()> { + if self.subjournal.read().is_some() { + return Ok(()); + } + use crate::MemoryIO; + + let db_file_io = Arc::new(MemoryIO::new()); + let file = db_file_io.open_file("subjournal", OpenFlags::Create, false)?; + let db_file = Subjournal::new(file); + *self.subjournal.write() = Some(db_file); + Ok(()) + } + + /// Write page to subjournal if the current savepoint does not currently + /// contain an an entry for it. In case of a statement-level rollback, + /// the page image can be restored from the subjournal. + /// + /// A buffer of length page_size + 4 bytes is allocated and the page id + /// is written to the beginning of the buffer. The rest of the buffer is filled with the page contents. + pub fn subjournal_page_if_required(&self, page: &Page) -> Result<()> { + if self.subjournal.read().is_none() { + return Ok(()); + } + let write_offset = { + let savepoints = self.savepoints.read(); + let Some(cur_savepoint) = savepoints.last() else { + return Ok(()); + }; + if cur_savepoint.has_dirty_page(page.get().id as u32) { + return Ok(()); + } + cur_savepoint.write_offset.load(Ordering::SeqCst) + }; + let page_id = page.get().id; + let page_size = self.page_size.load(Ordering::SeqCst) as usize; + let buffer = { + let page_id = page.get().id as u32; + let contents = page.get_contents(); + let buffer = self.buffer_pool.allocate(page_size + 4); + let contents_buffer = contents.buffer.as_slice(); + turso_assert!( + contents_buffer.len() == page_size, + "contents buffer length should be equal to page size" + ); + + buffer.as_mut_slice()[0..4].copy_from_slice(&page_id.to_be_bytes()); + buffer.as_mut_slice()[4..4 + page_size].copy_from_slice(contents_buffer); + + Arc::new(buffer) + }; + + let savepoints = self.savepoints.clone(); + + let write_complete = { + let buf_copy = buffer.clone(); + Box::new(move |res: Result| { + let Ok(bytes_written) = res else { + return; + }; + let buf_copy = buf_copy.clone(); + let buf_len = buf_copy.len(); + + turso_assert!( + bytes_written == buf_len as i32, + "wrote({bytes_written}) != expected({buf_len})" + ); + + let savepoints = savepoints.read(); + let cur_savepoint = savepoints.last().unwrap(); + cur_savepoint.add_dirty_page(page_id as u32); + cur_savepoint + .write_offset + .fetch_add(page_size as u64 + 4, Ordering::SeqCst); + }) + }; + let c = Completion::new_write(write_complete); + + let subjournal = self.subjournal.read(); + let subjournal = subjournal.as_ref().unwrap(); + + let c = subjournal.write_page(write_offset, page_size, buffer.clone(), c)?; + assert!(c.succeeded(), "memory IO should complete immediately"); + Ok(()) + } + + pub fn open_savepoint(&self, db_size: u32) -> Result<()> { + self.open_subjournal()?; + let subjournal_offset = self.subjournal.read().as_ref().unwrap().size()?; + // Currently as we only have anonymous savepoints opened at the start of a statement, + // the subjournal offset should always be 0 as we should only have max 1 savepoint + // opened at any given time. + turso_assert!(subjournal_offset == 0, "subjournal offset should be 0"); + let savepoint = Savepoint::new(subjournal_offset, db_size); + let mut savepoints = self.savepoints.write(); + turso_assert!( + savepoints.is_empty(), + "savepoints should be empty, but had {} savepoints open", + savepoints.len() + ); + savepoints.push(savepoint); + Ok(()) + } + + /// Release i.e. commit the current savepoint. This basically just means removing it. + pub fn release_savepoint(&self) -> Result<()> { + let mut savepoints = self.savepoints.write(); + let Some(savepoint) = savepoints.pop() else { + return Ok(()); + }; + let subjournal = self.subjournal.read(); + let Some(subjournal) = subjournal.as_ref() else { + return Ok(()); + }; + let start_offset = savepoint.start_offset.load(Ordering::SeqCst); + // Same reason as in open_savepoint, the start offset should always be 0 as we should only have max 1 savepoint + // opened at any given time. + turso_assert!(start_offset == 0, "start offset should be 0"); + let c = subjournal.truncate(start_offset)?; + assert!(c.succeeded(), "memory IO should complete immediately"); + Ok(()) + } + + pub fn clear_savepoints(&self) -> Result<()> { + *self.savepoints.write() = Vec::new(); + let subjournal = self.subjournal.read(); + let Some(subjournal) = subjournal.as_ref() else { + return Ok(()); + }; + let c = subjournal.truncate(0)?; + assert!(c.succeeded(), "memory IO should complete immediately"); + Ok(()) + } + + /// Rollback to the newest savepoint. This basically just means reading the subjournal from the start offset + /// of the savepoint to the end of the subjournal and restoring the page images to the page cache. + pub fn rollback_to_newest_savepoint(&self) -> Result<()> { + let subjournal = self.subjournal.read(); + let Some(subjournal) = subjournal.as_ref() else { + return Ok(()); + }; + let mut savepoints = self.savepoints.write(); + let Some(savepoint) = savepoints.pop() else { + return Ok(()); + }; + let journal_start_offset = savepoint.start_offset.load(Ordering::SeqCst); + + let mut rollback_bitset = RoaringBitmap::new(); + + // Read the subjournal starting from start offset, first reading 4 bytes to get page id, then if rollback_bitset already has the page, skip reading the page + // and just advance the offset. otherwise read the page and add the page id to the rollback_bitset + put the page image into the page cache + let mut current_offset = journal_start_offset; + let page_size = self.page_size.load(Ordering::SeqCst) as u64; + let journal_end_offset = savepoint.write_offset.load(Ordering::SeqCst); + let db_size = savepoint.db_size.load(Ordering::SeqCst); + + let mut dirty_pages = self.dirty_pages.write(); + + while current_offset < journal_end_offset { + // Read 4 bytes for page id + let page_id_buffer = Arc::new(self.buffer_pool.allocate(4)); + let c = subjournal.read_page_number(current_offset, page_id_buffer.clone())?; + assert!(c.succeeded(), "memory IO should complete immediately"); + let page_id = u32::from_be_bytes(page_id_buffer.as_slice()[0..4].try_into().unwrap()); + current_offset += 4; + + // Check if we've already rolled back this page or if the page is beyond the database size at the start of the savepoint + let already_rolled_back = rollback_bitset.contains(page_id); + if already_rolled_back { + current_offset += page_size; + continue; + } + let page_wont_exist_after_rollback = page_id > db_size; + if page_wont_exist_after_rollback { + dirty_pages.remove(&(page_id as usize)); + if let Some(page) = self + .page_cache + .write() + .get(&PageCacheKey::new(page_id as usize))? + { + page.clear_dirty(); + page.try_unpin(); + } + current_offset += page_size; + rollback_bitset.insert(page_id); + continue; + } + + // Read the page data + let page_buffer = Arc::new(self.buffer_pool.allocate(page_size as usize)); + let page = Arc::new(Page::new(page_id as i64)); + let c = subjournal.read_page( + current_offset, + page_buffer.clone(), + page.clone(), + page_size as usize, + )?; + assert!(c.succeeded(), "memory IO should complete immediately"); + current_offset += page_size; + + // Add page to rollback bitset + rollback_bitset.insert(page_id); + + // Put the page image into the page cache + self.upsert_page_in_cache(page_id as usize, page, false)?; + } + + let truncate_completion = self + .subjournal + .read() + .as_ref() + .unwrap() + .truncate(journal_start_offset)?; + assert!( + truncate_completion.succeeded(), + "memory IO should complete immediately" + ); + + self.page_cache.write().truncate(db_size as usize)?; + + Ok(()) + } + #[cfg(feature = "test_helper")] pub fn get_pending_byte() -> u32 { PENDING_BYTE.load(Ordering::Relaxed) @@ -909,7 +1153,7 @@ impl Pager { ptrmap_page.get().id == ptrmap_pg_no, "ptrmap page has unexpected number" ); - self.add_dirty(&ptrmap_page); + self.add_dirty(&ptrmap_page)?; self.vacuum_state.write().ptrmap_put_state = PtrMapPutState::Start; break Ok(IOResult::Done(())); } @@ -1110,6 +1354,41 @@ impl Pager { self.reserved_space.store(space as u16, Ordering::SeqCst); } + /// Schema cookie sentinel value that represents value not set. + const SCHEMA_COOKIE_NOT_SET: u64 = u64::MAX; + + /// Get the cached schema cookie. Returns None if not set yet. + pub fn get_schema_cookie_cached(&self) -> Option { + let value = self.schema_cookie.load(Ordering::SeqCst); + if value == Self::SCHEMA_COOKIE_NOT_SET { + None + } else { + Some(value as u32) + } + } + + /// Set the schema cookie cache. + pub fn set_schema_cookie(&self, cookie: Option) { + match cookie { + Some(value) => { + self.schema_cookie.store(value as u64, Ordering::SeqCst); + } + None => self + .schema_cookie + .store(Self::SCHEMA_COOKIE_NOT_SET, Ordering::SeqCst), + } + } + + /// Get the schema cookie, using the cached value if available to avoid reading page 1. + pub fn get_schema_cookie(&self) -> Result> { + // Try to use cached value first + if let Some(cookie) = self.get_schema_cookie_cached() { + return Ok(IOResult::Done(cookie)); + } + // If not cached, read from header and cache it + self.with_header(|header| header.schema_cookie.get()) + } + #[inline(always)] #[instrument(skip_all, level = Level::DEBUG)] pub fn begin_read_tx(&self) -> Result<()> { @@ -1119,14 +1398,16 @@ impl Pager { let changed = wal.borrow_mut().begin_read_tx()?; if changed { // Someone else changed the database -> assume our page cache is invalid (this is default SQLite behavior, we can probably do better with more granular invalidation) - self.clear_page_cache(); + self.clear_page_cache(false); + // Invalidate cached schema cookie to force re-read on next access + self.set_schema_cookie(None); } Ok(()) } #[instrument(skip_all, level = Level::DEBUG)] pub fn maybe_allocate_page1(&self) -> Result> { - if !self.db_state.is_initialized() { + if !self.db_state.get().is_initialized() { if let Ok(_lock) = self.init_lock.try_lock() { match (self.db_state.get(), self.allocating_page1()) { // In case of being empty or (allocating and this connection is performing allocation) then allocate the first page @@ -1142,7 +1423,7 @@ impl Pager { } } else { // Give a chance for the allocation to happen elsewhere - io_yield_one!(Completion::new_dummy()); + io_yield_one!(Completion::new_yield()); } } Ok(IOResult::Done(())) @@ -1161,33 +1442,20 @@ impl Pager { } #[instrument(skip_all, level = Level::DEBUG)] - pub fn end_tx( - &self, - rollback: bool, - connection: &Connection, - ) -> Result> { + pub fn commit_tx(&self, connection: &Connection) -> Result> { if connection.is_nested_stmt.load(Ordering::SeqCst) { // Parent statement will handle the transaction rollback. return Ok(IOResult::Done(PagerCommitResult::Rollback)); } - tracing::trace!("end_tx(rollback={})", rollback); let Some(wal) = self.wal.as_ref() else { // TODO: Unsure what the semantics of "end_tx" is for in-memory databases, ephemeral tables and ephemeral indexes. return Ok(IOResult::Done(PagerCommitResult::Rollback)); }; - let (is_write, schema_did_change) = match connection.get_tx_state() { + let (_, schema_did_change) = match connection.get_tx_state() { TransactionState::Write { schema_did_change } => (true, schema_did_change), _ => (false, false), }; - tracing::trace!("end_tx(schema_did_change={})", schema_did_change); - if rollback { - if is_write { - wal.borrow().end_write_tx(); - } - wal.borrow().end_read_tx(); - self.rollback(schema_did_change, connection, is_write)?; - return Ok(IOResult::Done(PagerCommitResult::Rollback)); - } + tracing::trace!("commit_tx(schema_did_change={})", schema_did_change); let commit_status = return_if_io!(self.commit_dirty_pages( connection.is_wal_auto_checkpoint_disabled(), connection.get_sync_mode(), @@ -1198,18 +1466,41 @@ impl Pager { if schema_did_change { let schema = connection.schema.read().clone(); - connection.db.update_schema_if_newer(schema)?; + connection.db.update_schema_if_newer(schema); } Ok(IOResult::Done(commit_status)) } #[instrument(skip_all, level = Level::DEBUG)] - pub fn end_read_tx(&self) -> Result<()> { + pub fn rollback_tx(&self, connection: &Connection) { + if connection.is_nested_stmt.load(Ordering::SeqCst) { + // Parent statement will handle the transaction rollback. + return; + } let Some(wal) = self.wal.as_ref() else { - return Ok(()); + // TODO: Unsure what the semantics of "end_tx" is for in-memory databases, ephemeral tables and ephemeral indexes. + return; + }; + let (is_write, schema_did_change) = match connection.get_tx_state() { + TransactionState::Write { schema_did_change } => (true, schema_did_change), + _ => (false, false), + }; + tracing::trace!("rollback_tx(schema_did_change={})", schema_did_change); + if is_write { + self.clear_savepoints() + .expect("in practice, clear_savepoints() should never fail as it uses memory IO"); + wal.borrow().end_write_tx(); + } + wal.borrow().end_read_tx(); + self.rollback(schema_did_change, connection, is_write); + } + + #[instrument(skip_all, level = Level::DEBUG)] + pub fn end_read_tx(&self) { + let Some(wal) = self.wal.as_ref() else { + return; }; wal.borrow().end_read_tx(); - Ok(()) } /// Reads a page from disk (either WAL or DB file) bypassing page-cache @@ -1357,11 +1648,18 @@ impl Pager { Ok(page_cache.resize(capacity)) } - pub fn add_dirty(&self, page: &Page) { + pub fn add_dirty(&self, page: &Page) -> Result<()> { + turso_assert!( + page.is_loaded(), + "page {} must be loaded in add_dirty() so its contents can be subjournaled", + page.get().id + ); + self.subjournal_page_if_required(page)?; // TODO: check duplicates? let mut dirty_pages = self.dirty_pages.write(); dirty_pages.insert(page.get().id); page.set_dirty(); + Ok(()) } pub fn wal_state(&self) -> Result { @@ -1619,20 +1917,21 @@ impl Pager { CommitState::Checkpoint => { match self.checkpoint()? { IOResult::IO(cmp) => { - let completions = { + let completion = { let mut commit_info = self.commit_info.write(); match cmp { IOCompletions::Single(c) => { commit_info.completions.push(c); } - IOCompletions::Many(c) => { - commit_info.completions.extend(c); - } } - std::mem::take(&mut commit_info.completions) + let mut group = CompletionGroup::new(|_| {}); + for c in commit_info.completions.drain(..) { + group.add(&c); + } + group.build() }; // TODO: remove serialization of checkpoint path - io_yield_many!(completions); + io_yield_one!(completion); } IOResult::Done(res) => { let mut commit_info = self.commit_info.write(); @@ -1671,21 +1970,25 @@ impl Pager { .unwrap() .as_millis() ); - let (should_finish, result, completions) = { + let (should_finish, result, completion) = { let mut commit_info = self.commit_info.write(); if commit_info.completions.iter().all(|c| c.succeeded()) { commit_info.completions.clear(); commit_info.state = CommitState::PrepareWal; - (true, commit_info.result.take(), Vec::new()) + (true, commit_info.result.take(), Completion::new_yield()) } else { - (false, None, std::mem::take(&mut commit_info.completions)) + let mut group = CompletionGroup::new(|_| {}); + for c in commit_info.completions.drain(..) { + group.add(&c); + } + (false, None, group.build()) } }; if should_finish { wal.borrow_mut().finish_append_frames_commit()?; return Ok(IOResult::Done(result.expect("commit result should be set"))); } - io_yield_many!(completions); + io_yield_one!(completion); } } } @@ -1800,7 +2103,7 @@ impl Pager { /// Invalidates entire page cache by removing all dirty and clean pages. Usually used in case /// of a rollback or in case we want to invalidate page cache after starting a read transaction /// right after new writes happened which would invalidate current page cache. - pub fn clear_page_cache(&self) { + pub fn clear_page_cache(&self, clear_dirty: bool) { let dirty_pages = self.dirty_pages.read(); let mut cache = self.page_cache.write(); for page_id in dirty_pages.iter() { @@ -1809,7 +2112,9 @@ impl Pager { page.clear_dirty(); } } - cache.clear().expect("Failed to clear page cache"); + cache + .clear(clear_dirty) + .expect("Failed to clear page cache"); } /// Checkpoint in Truncate mode and delete the WAL file. This method is _only_ to be called @@ -1914,7 +2219,7 @@ impl Pager { // TODO: only clear cache of things that are really invalidated self.page_cache .write() - .clear() + .clear(false) .map_err(|e| LimboError::InternalError(format!("Failed to clear page cache: {e:?}")))?; Ok(IOResult::Done(())) } @@ -2011,7 +2316,7 @@ impl Pager { trunk_page.get().id == trunk_page_id as usize, "trunk page has unexpected id" ); - self.add_dirty(&trunk_page); + self.add_dirty(&trunk_page)?; trunk_page_contents.write_u32_no_offset( TRUNK_PAGE_LEAF_COUNT_OFFSET, @@ -2031,7 +2336,7 @@ impl Pager { turso_assert!(page.is_loaded(), "page should be loaded"); // If we get here, need to make this page a new trunk turso_assert!(page.get().id == page_id, "page has unexpected id"); - self.add_dirty(page); + self.add_dirty(page)?; let trunk_page_id = header.freelist_trunk_page.get(); @@ -2171,7 +2476,7 @@ impl Pager { // we will allocate a ptrmap page, so increment size new_db_size += 1; let page = allocate_new_page(new_db_size as i64, &self.buffer_pool, 0); - self.add_dirty(&page); + self.add_dirty(&page)?; let page_key = PageCacheKey::new(page.get().id as usize); let mut cache = self.page_cache.write(); cache.insert(page_key, page.clone())?; @@ -2247,7 +2552,7 @@ impl Pager { // and update the database's first freelist trunk page to the next trunk page. header.freelist_trunk_page = next_trunk_page_id.into(); header.freelist_pages = (header.freelist_pages.get() - 1).into(); - self.add_dirty(trunk_page); + self.add_dirty(trunk_page)?; // zero out the page turso_assert!( trunk_page.get_contents().overflow_cells.is_empty(), @@ -2279,7 +2584,7 @@ impl Pager { leaf_page.get().id ); let page_contents = trunk_page.get_contents(); - self.add_dirty(leaf_page); + self.add_dirty(leaf_page)?; // zero out the page turso_assert!( leaf_page.get_contents().overflow_cells.is_empty(), @@ -2317,7 +2622,7 @@ impl Pager { FREELIST_TRUNK_OFFSET_LEAF_COUNT, remaining_leaves_count as u32, ); - self.add_dirty(trunk_page); + self.add_dirty(trunk_page)?; header.freelist_pages = (header.freelist_pages.get() - 1).into(); let leaf_page = leaf_page.clone(); @@ -2331,7 +2636,7 @@ impl Pager { if Some(new_db_size) == self.pending_byte_page_id() { let richard_hipp_special_page = allocate_new_page(new_db_size as i64, &self.buffer_pool, 0); - self.add_dirty(&richard_hipp_special_page); + self.add_dirty(&richard_hipp_special_page)?; let page_key = PageCacheKey::new(richard_hipp_special_page.get().id); { let mut cache = self.page_cache.write(); @@ -2355,7 +2660,7 @@ impl Pager { let page = allocate_new_page(new_db_size as i64, &self.buffer_pool, 0); { // setup page and add to cache - self.add_dirty(&page); + self.add_dirty(&page)?; let page_key = PageCacheKey::new(page.get().id as usize); { @@ -2372,16 +2677,19 @@ impl Pager { } } - pub fn update_dirty_loaded_page_in_cache( + pub fn upsert_page_in_cache( &self, id: usize, page: PageRef, + dirty_page_must_exist: bool, ) -> Result<(), LimboError> { let mut cache = self.page_cache.write(); let page_key = PageCacheKey::new(id); // FIXME: use specific page key for writer instead of max frame, this will make readers not conflict - assert!(page.is_dirty()); + if dirty_page_must_exist { + assert!(page.is_dirty()); + } cache.upsert_page(page_key, page.clone()).map_err(|e| { LimboError::InternalError(format!( "Failed to insert loaded page {id} into cache: {e:?}" @@ -2393,14 +2701,9 @@ impl Pager { } #[instrument(skip_all, level = Level::DEBUG)] - pub fn rollback( - &self, - schema_did_change: bool, - connection: &Connection, - is_write: bool, - ) -> Result<(), LimboError> { + pub fn rollback(&self, schema_did_change: bool, connection: &Connection, is_write: bool) { tracing::debug!(schema_did_change); - self.clear_page_cache(); + self.clear_page_cache(is_write); if is_write { self.dirty_pages.write().clear(); } else { @@ -2410,16 +2713,16 @@ impl Pager { ); } self.reset_internal_states(); + // Invalidate cached schema cookie since rollback may have restored the database schema cookie + self.set_schema_cookie(None); if schema_did_change { - *connection.schema.write() = connection.db.clone_schema()?; + *connection.schema.write() = connection.db.clone_schema(); } if is_write { if let Some(wal) = self.wal.as_ref() { - wal.borrow_mut().rollback()?; + wal.borrow_mut().rollback(); } } - - Ok(()) } fn reset_internal_states(&self) { @@ -2443,13 +2746,18 @@ impl Pager { pub fn with_header(&self, f: impl Fn(&DatabaseHeader) -> T) -> Result> { let header_ref = return_if_io!(HeaderRef::from_pager(self)); let header = header_ref.borrow(); + // Update cached schema cookie when reading header + self.set_schema_cookie(Some(header.schema_cookie.get())); Ok(IOResult::Done(f(header))) } pub fn with_header_mut(&self, f: impl Fn(&mut DatabaseHeader) -> T) -> Result> { let header_ref = return_if_io!(HeaderRefMut::from_pager(self)); let header = header_ref.borrow_mut(); - Ok(IOResult::Done(f(header))) + let result = f(header); + // Update cached schema cookie after modification + self.set_schema_cookie(Some(header.schema_cookie.get())); + Ok(IOResult::Done(result)) } pub fn is_encryption_ctx_set(&self) -> bool { @@ -2483,7 +2791,7 @@ impl Pager { // might have been loaded with page 1 to initialise the connection. During initialisation, // we only read the header which is unencrypted, but the rest of the page is. If so, lets // clear the cache. - self.clear_page_cache(); + self.clear_page_cache(false); Ok(()) } @@ -2764,7 +3072,7 @@ mod ptrmap_tests { use super::*; use crate::io::{MemoryIO, OpenFlags, IO}; use crate::storage::buffer_pool::BufferPool; - use crate::storage::database::{DatabaseFile, DatabaseStorage}; + use crate::storage::database::DatabaseFile; use crate::storage::page_cache::PageCache; use crate::storage::pager::Pager; use crate::storage::sqlite3_ondisk::PageSize; diff --git a/core/storage/slot_bitmap.rs b/core/storage/slot_bitmap.rs index 140eb708d..86050d3e9 100644 --- a/core/storage/slot_bitmap.rs +++ b/core/storage/slot_bitmap.rs @@ -516,14 +516,14 @@ pub mod tests { ]; for &seed in seeds { let mut rng = StdRng::seed_from_u64(seed); - let n_slots = rng.gen_range(1..10) * 64; + let n_slots = rng.random_range(1..10) * 64; let mut pb = SlotBitmap::new(n_slots); let mut model = vec![true; n_slots as usize]; let iters = 2000usize; for _ in 0..iters { - let op = rng.gen_range(0..100); + let op = rng.random_range(0..100); match op { 0..=49 => { // alloc_one @@ -540,8 +540,9 @@ pub mod tests { } 50..=79 => { // alloc_run with random length - let need = - rng.gen_range(1..=std::cmp::max(1, (n_slots as usize).min(128))) as u32; + let need = rng + .random_range(1..=std::cmp::max(1, (n_slots as usize).min(128))) + as u32; let got = pb.alloc_run(need); if let Some(start) = got { assert!(start + need <= n_slots, "within bounds"); @@ -560,13 +561,14 @@ pub mod tests { } _ => { // free_run on a random valid range - let len = - rng.gen_range(1..=std::cmp::max(1, (n_slots as usize).min(128))) as u32; + let len = rng + .random_range(1..=std::cmp::max(1, (n_slots as usize).min(128))) + as u32; let max_start = n_slots.saturating_sub(len); let start = if max_start == 0 { 0 } else { - rng.gen_range(0..=max_start) + rng.random_range(0..=max_start) }; pb.free_run(start, len); ref_mark_run(&mut model, start, len, true); diff --git a/core/storage/sqlite3_ondisk.rs b/core/storage/sqlite3_ondisk.rs index 5fb353804..93f897dc1 100644 --- a/core/storage/sqlite3_ondisk.rs +++ b/core/storage/sqlite3_ondisk.rs @@ -61,7 +61,7 @@ use crate::storage::buffer_pool::BufferPool; use crate::storage::database::{DatabaseFile, DatabaseStorage, EncryptionOrChecksum}; use crate::storage::pager::Pager; use crate::storage::wal::READMARK_NOT_USED; -use crate::types::{RawSlice, RefValue, SerialType, SerialTypeKind, TextRef, TextSubtype}; +use crate::types::{SerialType, SerialTypeKind, TextSubtype, ValueRef}; use crate::{ bail_corrupt_error, turso_assert, CompletionError, File, IOContext, Result, WalFileShared, }; @@ -1320,22 +1320,22 @@ impl Iterator for SmallVecIter<'_, T, N> { /// Reads a value that might reference the buffer it is reading from. Be sure to store RefValue with the buffer /// always. #[inline(always)] -pub fn read_value(buf: &[u8], serial_type: SerialType) -> Result<(RefValue, usize)> { +pub fn read_value<'a>(buf: &'a [u8], serial_type: SerialType) -> Result<(ValueRef<'a>, usize)> { match serial_type.kind() { - SerialTypeKind::Null => Ok((RefValue::Null, 0)), + SerialTypeKind::Null => Ok((ValueRef::Null, 0)), SerialTypeKind::I8 => { if buf.is_empty() { crate::bail_corrupt_error!("Invalid UInt8 value"); } let val = buf[0] as i8; - Ok((RefValue::Integer(val as i64), 1)) + Ok((ValueRef::Integer(val as i64), 1)) } SerialTypeKind::I16 => { if buf.len() < 2 { crate::bail_corrupt_error!("Invalid BEInt16 value"); } Ok(( - RefValue::Integer(i16::from_be_bytes([buf[0], buf[1]]) as i64), + ValueRef::Integer(i16::from_be_bytes([buf[0], buf[1]]) as i64), 2, )) } @@ -1345,7 +1345,7 @@ pub fn read_value(buf: &[u8], serial_type: SerialType) -> Result<(RefValue, usiz } let sign_extension = if buf[0] <= 127 { 0 } else { 255 }; Ok(( - RefValue::Integer( + ValueRef::Integer( i32::from_be_bytes([sign_extension, buf[0], buf[1], buf[2]]) as i64 ), 3, @@ -1356,7 +1356,7 @@ pub fn read_value(buf: &[u8], serial_type: SerialType) -> Result<(RefValue, usiz crate::bail_corrupt_error!("Invalid BEInt32 value"); } Ok(( - RefValue::Integer(i32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as i64), + ValueRef::Integer(i32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as i64), 4, )) } @@ -1366,7 +1366,7 @@ pub fn read_value(buf: &[u8], serial_type: SerialType) -> Result<(RefValue, usiz } let sign_extension = if buf[0] <= 127 { 0 } else { 255 }; Ok(( - RefValue::Integer(i64::from_be_bytes([ + ValueRef::Integer(i64::from_be_bytes([ sign_extension, sign_extension, buf[0], @@ -1384,7 +1384,7 @@ pub fn read_value(buf: &[u8], serial_type: SerialType) -> Result<(RefValue, usiz crate::bail_corrupt_error!("Invalid BEInt64 value"); } Ok(( - RefValue::Integer(i64::from_be_bytes([ + ValueRef::Integer(i64::from_be_bytes([ buf[0], buf[1], buf[2], buf[3], buf[4], buf[5], buf[6], buf[7], ])), 8, @@ -1395,26 +1395,20 @@ pub fn read_value(buf: &[u8], serial_type: SerialType) -> Result<(RefValue, usiz crate::bail_corrupt_error!("Invalid BEFloat64 value"); } Ok(( - RefValue::Float(f64::from_be_bytes([ + ValueRef::Float(f64::from_be_bytes([ buf[0], buf[1], buf[2], buf[3], buf[4], buf[5], buf[6], buf[7], ])), 8, )) } - SerialTypeKind::ConstInt0 => Ok((RefValue::Integer(0), 0)), - SerialTypeKind::ConstInt1 => Ok((RefValue::Integer(1), 0)), + SerialTypeKind::ConstInt0 => Ok((ValueRef::Integer(0), 0)), + SerialTypeKind::ConstInt1 => Ok((ValueRef::Integer(1), 0)), SerialTypeKind::Blob => { let content_size = serial_type.size(); if buf.len() < content_size { crate::bail_corrupt_error!("Invalid Blob value"); } - if content_size == 0 { - Ok((RefValue::Blob(RawSlice::new(std::ptr::null(), 0)), 0)) - } else { - let ptr = &buf[0] as *const u8; - let slice = RawSlice::new(ptr, content_size); - Ok((RefValue::Blob(slice), content_size)) - } + Ok((ValueRef::Blob(&buf[..content_size]), content_size)) } SerialTypeKind::Text => { let content_size = serial_type.size(); @@ -1427,10 +1421,7 @@ pub fn read_value(buf: &[u8], serial_type: SerialType) -> Result<(RefValue, usiz } Ok(( - RefValue::Text(TextRef::create_from( - &buf[..content_size], - TextSubtype::Text, - )), + ValueRef::Text(&buf[..content_size], TextSubtype::Text), content_size, )) } @@ -1495,6 +1486,44 @@ pub fn read_integer(buf: &[u8], serial_type: u8) -> Result { } } +/// Fast varint reader optimized for the common cases of 1-byte and 2-byte varints. +/// +/// This function is a performance-optimized version of `read_varint()` that handles +/// the most common varint cases inline before falling back to the full implementation. +/// It follows the same varint encoding as SQLite. +/// +/// # Optimized Cases +/// +/// - **Single-byte case**: Values 0-127 (0x00-0x7F) are returned immediately +/// - **Two-byte case**: Values 128-16383 (0x80-0x3FFF) are handled inline +/// - **Multi-byte case**: Larger values fall back to the full `read_varint()` implementation +/// +/// This function is similar to `sqlite3GetVarint32` +#[inline(always)] +pub fn read_varint_fast(buf: &[u8]) -> Result<(u64, usize)> { + // Fast path: Single-byte varint + if let Some(&first_byte) = buf.first() { + if first_byte & 0x80 == 0 { + return Ok((first_byte as u64, 1)); + } + } else { + crate::bail_corrupt_error!("Invalid varint"); + } + + // Fast path: Two-byte varint + if let Some(&second_byte) = buf.get(1) { + if second_byte & 0x80 == 0 { + let v = (((buf[0] & 0x7f) as u64) << 7) + (second_byte as u64); + return Ok((v, 2)); + } + } else { + crate::bail_corrupt_error!("Invalid varint"); + } + + //Fallback: Multi-byte varint + read_varint(buf) +} + #[inline(always)] pub fn read_varint(buf: &[u8]) -> Result<(u64, usize)> { let mut v: u64 = 0; @@ -1572,7 +1601,7 @@ pub fn write_varint(buf: &mut [u8], value: u64) -> usize { return 9; } - let mut encoded: [u8; 10] = [0; 10]; + let mut encoded: [u8; 9] = [0; 9]; let mut bytes = value; let mut n = 0; while bytes != 0 { @@ -1746,7 +1775,7 @@ impl StreamingWalReader { .min((self.file_size - offset) as usize); if read_size == 0 { // end-of-file; let caller finalize - return Ok((0, Completion::new_dummy())); + return Ok((0, Completion::new_yield())); } let buf = Arc::new(Buffer::new_temporary(read_size)); @@ -1921,7 +1950,7 @@ impl StreamingWalReader { wfs.loaded.store(true, Ordering::SeqCst); self.done.store(true, Ordering::Release); - tracing::info!( + tracing::debug!( "WAL loading complete: {} frames processed, last commit at frame {}", st.frame_idx - 1, max_frame diff --git a/core/storage/subjournal.rs b/core/storage/subjournal.rs new file mode 100644 index 000000000..62a30cf1d --- /dev/null +++ b/core/storage/subjournal.rs @@ -0,0 +1,88 @@ +use std::sync::Arc; + +use crate::{ + storage::sqlite3_ondisk::finish_read_page, Buffer, Completion, CompletionError, PageRef, Result, +}; + +#[derive(Clone)] +pub struct Subjournal { + file: Arc, +} + +impl Subjournal { + pub fn new(file: Arc) -> Self { + Self { file } + } + + pub fn size(&self) -> Result { + self.file.size() + } + + pub fn write_page( + &self, + offset: u64, + page_size: usize, + buffer: Arc, + c: Completion, + ) -> Result { + assert!( + buffer.len() == page_size + 4, + "buffer length should be page_size + 4 bytes for page id" + ); + self.file.pwrite(offset, buffer, c) + } + + pub fn read_page_number(&self, offset: u64, page_id_buffer: Arc) -> Result { + assert!( + page_id_buffer.len() == 4, + "page_id_buffer length should be 4 bytes" + ); + let c = Completion::new_read( + page_id_buffer, + move |res: Result<(Arc, i32), CompletionError>| { + let Ok((_buffer, _bytes_read)) = res else { + return; + }; + }, + ); + let c = self.file.pread(offset, c)?; + Ok(c) + } + + pub fn read_page( + &self, + offset: u64, + buffer: Arc, + page: PageRef, + page_size: usize, + ) -> Result { + assert!( + buffer.len() == page_size, + "buffer length should be page_size" + ); + let c = Completion::new_read( + buffer, + move |res: Result<(Arc, i32), CompletionError>| { + let Ok((buffer, bytes_read)) = res else { + return; + }; + assert!( + bytes_read == page_size as i32, + "bytes_read should be page_size" + ); + finish_read_page(page.get().id, buffer, page.clone()); + }, + ); + let c = self.file.pread(offset, c)?; + Ok(c) + } + + pub fn truncate(&self, offset: u64) -> Result { + let c = Completion::new_trunc(move |res: Result| { + let Ok(_) = res else { + return; + }; + }); + self.file.truncate(offset, c) + } +} diff --git a/core/storage/wal.rs b/core/storage/wal.rs index 61791c5ee..941e07f4d 100644 --- a/core/storage/wal.rs +++ b/core/storage/wal.rs @@ -171,7 +171,7 @@ impl TursoRwLock { // for success, Acquire establishes happens-before relationship with the previous Release from unlock // for failure we only care about reading it for the next iteration so we can use Relaxed. self.0 - .compare_exchange_weak(cur, desired, Ordering::Acquire, Ordering::Relaxed) + .compare_exchange(cur, desired, Ordering::Acquire, Ordering::Relaxed) .is_ok() } @@ -302,7 +302,7 @@ pub trait Wal: Debug { fn get_checkpoint_seq(&self) -> u32; fn get_max_frame(&self) -> u64; fn get_min_frame(&self) -> u64; - fn rollback(&mut self) -> Result<()>; + fn rollback(&mut self); /// Return unique set of pages changed **after** frame_watermark position and until current WAL session max_frame_no fn changed_pages_after(&self, frame_watermark: u64) -> Result>; @@ -822,16 +822,16 @@ impl Wal for WalFile { self.max_frame_read_lock_index.load(Ordering::Acquire), NO_LOCK_HELD ); - let (shared_max, nbackfills, last_checksum, checkpoint_seq, transaction_count) = { - let shared = self.get_shared(); - let mx = shared.max_frame.load(Ordering::Acquire); - let nb = shared.nbackfills.load(Ordering::Acquire); - let ck = shared.last_checksum; - let checkpoint_seq = shared.wal_header.lock().checkpoint_seq; - let transaction_count = shared.transaction_count.load(Ordering::Acquire); - (mx, nb, ck, checkpoint_seq, transaction_count) - }; - let db_changed = self.db_changed(&self.get_shared()); + let (shared_max, nbackfills, last_checksum, checkpoint_seq, transaction_count) = self + .with_shared(|shared| { + let mx = shared.max_frame.load(Ordering::Acquire); + let nb = shared.nbackfills.load(Ordering::Acquire); + let ck = shared.last_checksum; + let checkpoint_seq = shared.wal_header.lock().checkpoint_seq; + let transaction_count = shared.transaction_count.load(Ordering::Acquire); + (mx, nb, ck, checkpoint_seq, transaction_count) + }); + let db_changed = self.with_shared(|shared| self.db_changed(shared)); // WAL is already fully back‑filled into the main DB image // (mxFrame == nBackfill). Readers can therefore ignore the @@ -840,7 +840,7 @@ impl Wal for WalFile { if shared_max == nbackfills { tracing::debug!("begin_read_tx: WAL is already fully back‑filled into the main DB image, shared_max={}, nbackfills={}", shared_max, nbackfills); let lock_0_idx = 0; - if !self.get_shared().read_locks[lock_0_idx].read() { + if !self.with_shared(|shared| shared.read_locks[lock_0_idx].read()) { tracing::debug!("begin_read_tx: read lock 0 is already held, returning Busy"); return Err(LimboError::Busy); } @@ -864,33 +864,31 @@ impl Wal for WalFile { // Find largest mark <= mx among slots 1..N let mut best_idx: i64 = -1; let mut best_mark: u32 = 0; - for (idx, lock) in self.get_shared().read_locks.iter().enumerate().skip(1) { - let m = lock.get_value(); - if m != READMARK_NOT_USED && m <= shared_max as u32 && m > best_mark { - best_mark = m; - best_idx = idx as i64; + self.with_shared(|shared| { + for (idx, lock) in shared.read_locks.iter().enumerate().skip(1) { + let m = lock.get_value(); + if m != READMARK_NOT_USED && m <= shared_max as u32 && m > best_mark { + best_mark = m; + best_idx = idx as i64; + } } - } + }); // If none found or lagging, try to claim/update a slot if best_idx == -1 || (best_mark as u64) < shared_max { - for (idx, lock) in self - .get_shared_mut() - .read_locks - .iter_mut() - .enumerate() - .skip(1) - { - if !lock.write() { - continue; // busy slot + self.with_shared_mut(|shared| { + for (idx, lock) in shared.read_locks.iter_mut().enumerate().skip(1) { + if !lock.write() { + continue; // busy slot + } + // claim or bump this slot + lock.set_value_exclusive(shared_max as u32); + best_idx = idx as i64; + best_mark = shared_max as u32; + lock.unlock(); + break; } - // claim or bump this slot - lock.set_value_exclusive(shared_max as u32); - best_idx = idx as i64; - best_mark = shared_max as u32; - lock.unlock(); - break; - } + }) } if best_idx == -1 || best_mark != shared_max as u32 { @@ -901,20 +899,19 @@ impl Wal for WalFile { // Now take a shared read on that slot, and if we are successful, // grab another snapshot of the shared state. - let (mx2, nb2, cksm2, ckpt_seq2) = { - let shared = self.get_shared(); + let (mx2, nb2, cksm2, ckpt_seq2) = self.with_shared(|shared| { if !shared.read_locks[best_idx as usize].read() { // TODO: we should retry here instead of always returning Busy return Err(LimboError::Busy); } let checkpoint_seq = shared.wal_header.lock().checkpoint_seq; - ( + Ok(( shared.max_frame.load(Ordering::Acquire), shared.nbackfills.load(Ordering::Acquire), shared.last_checksum, checkpoint_seq, - ) - }; + )) + })?; // sqlite/src/wal.c 3225 // Now that the read-lock has been obtained, check that neither the @@ -967,7 +964,7 @@ impl Wal for WalFile { fn end_read_tx(&self) { let slot = self.max_frame_read_lock_index.load(Ordering::Acquire); if slot != NO_LOCK_HELD { - self.get_shared_mut().read_locks[slot].unlock(); + self.with_shared_mut(|shared| shared.read_locks[slot].unlock()); self.max_frame_read_lock_index .store(NO_LOCK_HELD, Ordering::Release); tracing::debug!("end_read_tx(slot={slot})"); @@ -979,36 +976,36 @@ impl Wal for WalFile { /// Begin a write transaction #[instrument(skip_all, level = Level::DEBUG)] fn begin_write_tx(&mut self) -> Result<()> { - let shared = self.get_shared_mut(); - // sqlite/src/wal.c 3702 - // Cannot start a write transaction without first holding a read - // transaction. - // assert(pWal->readLock >= 0); - // assert(pWal->writeLock == 0 && pWal->iReCksum == 0); - turso_assert!( - self.max_frame_read_lock_index.load(Ordering::Acquire) != NO_LOCK_HELD, - "must have a read transaction to begin a write transaction" - ); - if !shared.write_lock.write() { - return Err(LimboError::Busy); - } - let db_changed = self.db_changed(&shared); - if !db_changed { - drop(shared); - return Ok(()); - } + self.with_shared_mut(|shared| { + // sqlite/src/wal.c 3702 + // Cannot start a write transaction without first holding a read + // transaction. + // assert(pWal->readLock >= 0); + // assert(pWal->writeLock == 0 && pWal->iReCksum == 0); + turso_assert!( + self.max_frame_read_lock_index.load(Ordering::Acquire) != NO_LOCK_HELD, + "must have a read transaction to begin a write transaction" + ); + if !shared.write_lock.write() { + return Err(LimboError::Busy); + } + let db_changed = self.db_changed(shared); + if !db_changed { + return Ok(()); + } - // Snapshot is stale, give up and let caller retry from scratch - tracing::debug!("unable to upgrade transaction from read to write: snapshot is stale, give up and let caller retry from scratch, self.max_frame={}, shared_max={}", self.max_frame.load(Ordering::Acquire), shared.max_frame.load(Ordering::Acquire)); - shared.write_lock.unlock(); - Err(LimboError::Busy) + // Snapshot is stale, give up and let caller retry from scratch + tracing::debug!("unable to upgrade transaction from read to write: snapshot is stale, give up and let caller retry from scratch, self.max_frame={}, shared_max={}", self.max_frame.load(Ordering::Acquire), shared.max_frame.load(Ordering::Acquire)); + shared.write_lock.unlock(); + Err(LimboError::Busy) + }) } /// End a write transaction #[instrument(skip_all, level = Level::DEBUG)] fn end_write_tx(&self) { tracing::debug!("end_write_txn"); - self.get_shared().write_lock.unlock(); + self.with_shared(|shared| shared.write_lock.unlock()); } /// Find the latest frame containing a page. @@ -1029,10 +1026,13 @@ impl Wal for WalFile { // // if it's not, than pages from WAL range [frame_watermark..nBackfill] are already in the DB file, // and in case if page first occurrence in WAL was after frame_watermark - we will be unable to read proper previous version of the page - turso_assert!( - frame_watermark.is_none() || frame_watermark.unwrap() >= self.get_shared().nbackfills.load(Ordering::Acquire), - "frame_watermark must be >= than current WAL backfill amount: frame_watermark={:?}, nBackfill={}", frame_watermark, self.get_shared().nbackfills.load(Ordering::Acquire) - ); + self.with_shared(|shared| { + let nbackfills = shared.nbackfills.load(Ordering::Acquire); + turso_assert!( + frame_watermark.is_none() || frame_watermark.unwrap() >= nbackfills, + "frame_watermark must be >= than current WAL backfill amount: frame_watermark={:?}, nBackfill={}", frame_watermark, nbackfills + ); + }); // if we are holding read_lock 0 and didn't write anything to the WAL, skip and read right from db file. // @@ -1050,30 +1050,31 @@ impl Wal for WalFile { ); return Ok(None); } - let shared = self.get_shared(); - let frames = shared.frame_cache.lock(); - let range = frame_watermark.map(|x| 0..=x).unwrap_or( - self.min_frame.load(Ordering::Acquire)..=self.max_frame.load(Ordering::Acquire), - ); - tracing::debug!( - "find_frame(page_id={}, frame_watermark={:?}): min_frame={}, max_frame={}", - page_id, - frame_watermark, - self.min_frame.load(Ordering::Acquire), - self.max_frame.load(Ordering::Acquire) - ); - if let Some(list) = frames.get(&page_id) { - if let Some(f) = list.iter().rfind(|&&f| range.contains(&f)) { - tracing::debug!( - "find_frame(page_id={}, frame_watermark={:?}): found frame={}", - page_id, - frame_watermark, - *f - ); - return Ok(Some(*f)); + self.with_shared(|shared| { + let frames = shared.frame_cache.lock(); + let range = frame_watermark.map(|x| 0..=x).unwrap_or( + self.min_frame.load(Ordering::Acquire)..=self.max_frame.load(Ordering::Acquire), + ); + tracing::debug!( + "find_frame(page_id={}, frame_watermark={:?}): min_frame={}, max_frame={}", + page_id, + frame_watermark, + self.min_frame.load(Ordering::Acquire), + self.max_frame.load(Ordering::Acquire) + ); + if let Some(list) = frames.get(&page_id) { + if let Some(f) = list.iter().rfind(|&&f| range.contains(&f)) { + tracing::debug!( + "find_frame(page_id={}, frame_watermark={:?}): found frame={}", + page_id, + frame_watermark, + *f + ); + return Ok(Some(*f)); + } } - } - Ok(None) + Ok(None) + }) } /// Read a frame from the WAL. @@ -1110,8 +1111,7 @@ impl Wal for WalFile { let epoch = shared_file.read().epoch.load(Ordering::Acquire); frame.set_wal_tag(frame_id, epoch); }); - let file = { - let shared = self.get_shared(); + let file = self.with_shared(|shared| { assert!(shared.enabled.load(Ordering::SeqCst), "WAL must be enabled"); // important not to hold shared lock beyond this point to avoid deadlock scenario where: // thread 1: takes readlock here, passes reference to shared.file to begin_read_wal_frame @@ -1125,7 +1125,7 @@ impl Wal for WalFile { // when there are writers waiting to acquire the lock. // Because of this, attempts to recursively acquire a read lock within a single thread may result in a deadlock." shared.file.as_ref().unwrap().clone() - }; + }); begin_read_wal_frame( file.as_ref(), offset + WAL_FRAME_HEADER_SIZE as u64, @@ -1184,9 +1184,10 @@ impl Wal for WalFile { } } }); - let shared = self.get_shared(); - assert!(shared.enabled.load(Ordering::SeqCst), "WAL must be enabled"); - let file = shared.file.as_ref().unwrap(); + let file = self.with_shared(|shared| { + assert!(shared.enabled.load(Ordering::SeqCst), "WAL must be enabled"); + shared.file.as_ref().unwrap().clone() + }); let c = begin_read_wal_frame_raw(&self.buffer_pool, file.as_ref(), offset, complete)?; Ok(c) } @@ -1243,9 +1244,10 @@ impl Wal for WalFile { } } }); - let shared = self.get_shared(); - assert!(shared.enabled.load(Ordering::SeqCst), "WAL must be enabled"); - let file = shared.file.as_ref().unwrap(); + let file = self.with_shared(|shared| { + assert!(shared.enabled.load(Ordering::SeqCst), "WAL must be enabled"); + shared.file.as_ref().unwrap().clone() + }); let c = begin_read_wal_frame( file.as_ref(), offset + WAL_FRAME_HEADER_SIZE as u64, @@ -1266,13 +1268,12 @@ impl Wal for WalFile { // perform actual write let offset = self.frame_offset(frame_id); - let (header, file) = { - let shared = self.get_shared(); + let (header, file) = self.with_shared(|shared| { let header = shared.wal_header.clone(); assert!(shared.enabled.load(Ordering::SeqCst), "WAL must be enabled"); let file = shared.file.as_ref().unwrap().clone(); (header, file) - }; + }); let header = header.lock(); let checksums = self.last_checksum; let (checksums, frame_bytes) = prepare_wal_frame( @@ -1296,10 +1297,11 @@ impl Wal for WalFile { #[instrument(skip_all, level = Level::DEBUG)] fn should_checkpoint(&self) -> bool { - let shared = self.get_shared(); - let frame_id = shared.max_frame.load(Ordering::Acquire) as usize; - let nbackfills = shared.nbackfills.load(Ordering::Acquire) as usize; - frame_id > self.checkpoint_threshold + nbackfills + self.with_shared(|shared| { + let frame_id = shared.max_frame.load(Ordering::Acquire) as usize; + let nbackfills = shared.nbackfills.load(Ordering::Acquire) as usize; + frame_id > self.checkpoint_threshold + nbackfills + }) } #[instrument(skip_all, level = Level::DEBUG)] @@ -1322,10 +1324,11 @@ impl Wal for WalFile { tracing::debug!("wal_sync finish"); syncing.store(false, Ordering::SeqCst); }); - let shared = self.get_shared(); + let file = self.with_shared(|shared| { + assert!(shared.enabled.load(Ordering::SeqCst), "WAL must be enabled"); + shared.file.as_ref().unwrap().clone() + }); self.syncing.store(true, Ordering::SeqCst); - assert!(shared.enabled.load(Ordering::SeqCst), "WAL must be enabled"); - let file = shared.file.as_ref().unwrap(); let c = file.sync(completion)?; Ok(c) } @@ -1336,11 +1339,11 @@ impl Wal for WalFile { } fn get_max_frame_in_wal(&self) -> u64 { - self.get_shared().max_frame.load(Ordering::Acquire) + self.with_shared(|shared| shared.max_frame.load(Ordering::Acquire)) } fn get_checkpoint_seq(&self) -> u32 { - self.get_shared().wal_header.lock().checkpoint_seq + self.with_shared(|shared| shared.wal_header.lock().checkpoint_seq) } fn get_max_frame(&self) -> u64 { @@ -1351,10 +1354,9 @@ impl Wal for WalFile { self.min_frame.load(Ordering::Acquire) } - #[instrument(err, skip_all, level = Level::DEBUG)] - fn rollback(&mut self) -> Result<()> { - let (max_frame, last_checksum) = { - let shared = self.get_shared(); + #[instrument(skip_all, level = Level::DEBUG)] + fn rollback(&mut self) { + let (max_frame, last_checksum) = self.with_shared(|shared| { let max_frame = shared.max_frame.load(Ordering::Acquire); let mut frame_cache = shared.frame_cache.lock(); frame_cache.retain(|_page_id, frames| { @@ -1365,27 +1367,27 @@ impl Wal for WalFile { !frames.is_empty() }); (max_frame, shared.last_checksum) - }; + }); self.last_checksum = last_checksum; self.max_frame.store(max_frame, Ordering::Release); self.reset_internal_states(); - Ok(()) } #[instrument(skip_all, level = Level::DEBUG)] fn finish_append_frames_commit(&mut self) -> Result<()> { - let mut shared = self.get_shared_mut(); - shared - .max_frame - .store(self.max_frame.load(Ordering::Acquire), Ordering::Release); - tracing::trace!(max_frame = self.max_frame.load(Ordering::Acquire), ?self.last_checksum); - shared.last_checksum = self.last_checksum; - self.transaction_count.fetch_add(1, Ordering::Release); - shared.transaction_count.store( - self.transaction_count.load(Ordering::Acquire), - Ordering::Release, - ); - Ok(()) + self.with_shared_mut(|shared| { + shared + .max_frame + .store(self.max_frame.load(Ordering::Acquire), Ordering::Release); + tracing::trace!(max_frame = self.max_frame.load(Ordering::Acquire), ?self.last_checksum); + shared.last_checksum = self.last_checksum; + self.transaction_count.fetch_add(1, Ordering::Release); + shared.transaction_count.store( + self.transaction_count.load(Ordering::Acquire), + Ordering::Release, + ); + Ok(()) + }) } fn changed_pages_after(&self, frame_watermark: u64) -> Result> { @@ -1412,12 +1414,11 @@ impl Wal for WalFile { } fn prepare_wal_start(&mut self, page_size: PageSize) -> Result> { - if self.get_shared().is_initialized()? { + if self.with_shared(|shared| shared.is_initialized())? { return Ok(None); } tracing::debug!("ensure_header_if_needed"); - self.last_checksum = { - let mut shared = self.get_shared_mut(); + self.last_checksum = self.with_shared_mut(|shared| { let checksum = { let mut hdr = shared.wal_header.lock(); hdr.magic = if cfg!(target_endian = "big") { @@ -1443,20 +1444,25 @@ impl Wal for WalFile { }; shared.last_checksum = checksum; checksum - }; + }); self.max_frame.store(0, Ordering::Release); - let shared = self.get_shared(); - assert!(shared.enabled.load(Ordering::SeqCst), "WAL must be enabled"); - let file = shared.file.as_ref().unwrap(); - let c = sqlite3_ondisk::begin_write_wal_header(file.as_ref(), &shared.wal_header.lock())?; + let (header, file) = self.with_shared(|shared| { + assert!(shared.enabled.load(Ordering::SeqCst), "WAL must be enabled"); + ( + *shared.wal_header.lock(), + shared.file.as_ref().unwrap().clone(), + ) + }); + let c = sqlite3_ondisk::begin_write_wal_header(file.as_ref(), &header)?; Ok(Some(c)) } fn prepare_wal_finish(&mut self) -> Result { - let shared = self.get_shared(); - assert!(shared.enabled.load(Ordering::SeqCst), "WAL must be enabled"); - let file = shared.file.as_ref().unwrap(); + let file = self.with_shared(|shared| { + assert!(shared.enabled.load(Ordering::SeqCst), "WAL must be enabled"); + shared.file.as_ref().unwrap().clone() + }); let shared = self.shared.clone(); let c = file.sync(Completion::new_sync(move |_| { shared.read().initialized.store(true, Ordering::Release); @@ -1476,18 +1482,17 @@ impl Wal for WalFile { "we limit number of iovecs to IOV_MAX" ); turso_assert!( - self.get_shared().is_initialized()?, + self.with_shared(|shared| shared.is_initialized())?, "WAL must be prepared with prepare_wal_start/prepare_wal_finish method" ); - let (header, shared_page_size, epoch) = { - let shared = self.get_shared(); + let (header, shared_page_size, epoch) = self.with_shared(|shared| { let hdr_guard = shared.wal_header.lock(); let header: WalHeader = *hdr_guard; let shared_page_size = header.page_size; let epoch = shared.epoch.load(Ordering::Acquire); (header, shared_page_size, epoch) - }; + }); turso_assert!( shared_page_size == page_sz.get(), "page size mismatch, tried to change page size after WAL header was already initialized: shared.page_size={shared_page_size}, page_size={}", @@ -1576,9 +1581,10 @@ impl Wal for WalFile { let c = Completion::new_write_linked(cmp); - let shared = self.get_shared(); - assert!(shared.enabled.load(Ordering::SeqCst), "WAL must be enabled"); - let file = shared.file.as_ref().unwrap(); + let file = self.with_shared(|shared| { + assert!(shared.enabled.load(Ordering::SeqCst), "WAL must be enabled"); + shared.file.as_ref().unwrap().clone() + }); let c = file.pwritev(start_off, iovecs, c)?; Ok(c) } @@ -1593,8 +1599,10 @@ impl Wal for WalFile { } fn update_max_frame(&mut self) { - let new_max_frame = self.get_shared().max_frame.load(Ordering::Acquire); - self.max_frame.store(new_max_frame, Ordering::Release); + self.with_shared(|shared| { + let new_max_frame = shared.max_frame.load(Ordering::Acquire); + self.max_frame.store(new_max_frame, Ordering::Release); + }) } } @@ -1646,7 +1654,7 @@ impl WalFile { } fn page_size(&self) -> u32 { - self.get_shared().wal_header.lock().page_size + self.with_shared(|shared| shared.wal_header.lock().page_size) } fn frame_offset(&self, frame_id: u64) -> u64 { @@ -1655,7 +1663,7 @@ impl WalFile { WAL_HEADER_SIZE as u64 + page_offset } - fn get_shared_mut(&self) -> parking_lot::RwLockWriteGuard<'_, WalFileShared> { + fn _get_shared_mut(&self) -> parking_lot::RwLockWriteGuard<'_, WalFileShared> { // WASM in browser main thread doesn't have a way to "park" a thread // so, we spin way here instead of calling blocking lock #[cfg(target_family = "wasm")] @@ -1674,7 +1682,7 @@ impl WalFile { } } - fn get_shared(&self) -> parking_lot::RwLockReadGuard<'_, WalFileShared> { + fn _get_shared(&self) -> parking_lot::RwLockReadGuard<'_, WalFileShared> { // WASM in browser main thread doesn't have a way to "park" a thread // so, we spin way here instead of calling blocking lock #[cfg(target_family = "wasm")] @@ -1693,11 +1701,28 @@ impl WalFile { } } + #[inline] + fn with_shared_mut(&self, func: F) -> R + where + F: FnOnce(&mut WalFileShared) -> R, + { + let mut shared = self._get_shared_mut(); + func(&mut shared) + } + + #[inline] + fn with_shared(&self, func: F) -> R + where + F: FnOnce(&WalFileShared) -> R, + { + let shared = self._get_shared(); + func(&shared) + } + fn complete_append_frame(&mut self, page_id: u64, frame_id: u64, checksums: (u32, u32)) { self.last_checksum = checksums; self.max_frame.store(frame_id, Ordering::Release); - let shared = self.get_shared(); - { + self.with_shared(|shared| { let mut frame_cache = shared.frame_cache.lock(); match frame_cache.get_mut(&page_id) { Some(frames) => { @@ -1707,7 +1732,7 @@ impl WalFile { frame_cache.insert(page_id, vec![frame_id]); } } - } + }) } fn reset_internal_states(&mut self) { @@ -1742,12 +1767,11 @@ impl WalFile { // so no other checkpointer can run. fsync WAL if there are unapplied frames. // Decide the largest frame we are allowed to back‑fill. CheckpointState::Start => { - let (max_frame, nbackfills) = { - let shared = self.get_shared(); + let (max_frame, nbackfills) = self.with_shared(|shared| { let max_frame = shared.max_frame.load(Ordering::Acquire); let n_backfills = shared.nbackfills.load(Ordering::Acquire); (max_frame, n_backfills) - }; + }); let needs_backfill = max_frame > nbackfills; if !needs_backfill && !mode.should_restart_log() { // there are no frames to copy over and we don't need to reset @@ -1783,8 +1807,7 @@ impl WalFile { self.ongoing_checkpoint.max_frame = max_frame; self.ongoing_checkpoint.min_frame = nbackfills + 1; - let to_checkpoint = { - let shared = self.get_shared(); + let to_checkpoint = self.with_shared(|shared| { let frame_cache = shared.frame_cache.lock(); let mut list = Vec::with_capacity( self.ongoing_checkpoint @@ -1805,7 +1828,7 @@ impl WalFile { // sort by frame_id for read locality list.sort_unstable_by(|a, b| (a.1, a.0).cmp(&(b.1, b.0))); list - }; + }); self.ongoing_checkpoint.pages_to_checkpoint = to_checkpoint; self.ongoing_checkpoint.current_page = 0; self.ongoing_checkpoint.inflight_writes.clear(); @@ -1836,7 +1859,7 @@ impl WalFile { if self.ongoing_checkpoint.process_pending_reads() { tracing::trace!("Drained reads into batch"); } - let epoch = self.get_shared().epoch.load(Ordering::Acquire); + let epoch = self.with_shared(|shared| shared.epoch.load(Ordering::Acquire)); // Issue reads until we hit limits 'inner: while self.ongoing_checkpoint.should_issue_reads() { let (page_id, target_frame) = self.ongoing_checkpoint.pages_to_checkpoint @@ -1928,8 +1951,7 @@ impl WalFile { self.ongoing_checkpoint.complete(), "checkpoint pending flush must have finished" ); - let checkpoint_result = { - let shared = self.get_shared(); + let checkpoint_result = self.with_shared(|shared| { let current_mx = shared.max_frame.load(Ordering::Acquire); let nbackfills = shared.nbackfills.load(Ordering::Acquire); // Record two num pages fields to return as checkpoint result to caller. @@ -1961,14 +1983,16 @@ impl WalFile { checkpoint_max_frame, ) } - }; + }); // store the max frame we were able to successfully checkpoint. // NOTE: we don't have a .shm file yet, so it's safe to update nbackfills here // before we sync, because if we crash and then recover, we will checkpoint the entire db anyway. - self.get_shared() - .nbackfills - .store(self.ongoing_checkpoint.max_frame, Ordering::Release); + self.with_shared(|shared| { + shared + .nbackfills + .store(self.ongoing_checkpoint.max_frame, Ordering::Release) + }); if mode.require_all_backfilled() && !checkpoint_result.everything_backfilled() { return Err(LimboError::Busy); @@ -1999,7 +2023,7 @@ impl WalFile { checkpoint_result.take().unwrap() }; // increment wal epoch to ensure no stale pages are used for backfilling - self.get_shared().epoch.fetch_add(1, Ordering::Release); + self.with_shared(|shared| shared.epoch.fetch_add(1, Ordering::Release)); // store a copy of the checkpoint result to return in the future if pragma // wal_checkpoint is called and we haven't backfilled again since. @@ -2068,29 +2092,30 @@ impl WalFile { /// We never modify slot values while a reader holds that slot's lock. /// TOOD: implement proper BUSY handling behavior fn determine_max_safe_checkpoint_frame(&self) -> u64 { - let mut shared = self.get_shared_mut(); - let shared_max = shared.max_frame.load(Ordering::Acquire); - let mut max_safe_frame = shared_max; + self.with_shared_mut(|shared| { + let shared_max = shared.max_frame.load(Ordering::Acquire); + let mut max_safe_frame = shared_max; - for (read_lock_idx, read_lock) in shared.read_locks.iter_mut().enumerate().skip(1) { - let this_mark = read_lock.get_value(); - if this_mark < max_safe_frame as u32 { - let busy = !read_lock.write(); - if !busy { - let val = if read_lock_idx == 1 { - // store the max_frame for the default read slot 1 - max_safe_frame as u32 + for (read_lock_idx, read_lock) in shared.read_locks.iter_mut().enumerate().skip(1) { + let this_mark = read_lock.get_value(); + if this_mark < max_safe_frame as u32 { + let busy = !read_lock.write(); + if !busy { + let val = if read_lock_idx == 1 { + // store the max_frame for the default read slot 1 + max_safe_frame as u32 + } else { + READMARK_NOT_USED + }; + read_lock.set_value_exclusive(val); + read_lock.unlock(); } else { - READMARK_NOT_USED - }; - read_lock.set_value_exclusive(val); - read_lock.unlock(); - } else { - max_safe_frame = this_mark as u64; + max_safe_frame = this_mark as u64; + } } } - } - max_safe_frame + max_safe_frame + }) } /// Called once the entire WAL has been back‑filled in RESTART or TRUNCATE mode. @@ -2106,9 +2131,8 @@ impl WalFile { self.checkpoint_guard ); tracing::debug!("restart_log(mode={mode:?})"); - { + self.with_shared_mut(|shared| { // Block all readers - let mut shared = self.get_shared_mut(); for idx in 1..shared.read_locks.len() { let lock = &mut shared.read_locks[idx]; if !lock.write() { @@ -2122,11 +2146,12 @@ impl WalFile { // after the log is reset, we must set all secondary marks to READMARK_NOT_USED so the next reader selects a fresh slot lock.set_value_exclusive(READMARK_NOT_USED); } - } + Ok(()) + })?; // reinitialize in‑memory state - self.get_shared_mut().restart_wal_header(&self.io, mode); - let cksm = self.get_shared().last_checksum; + self.with_shared_mut(|shared| shared.restart_wal_header(&self.io, mode)); + let cksm = self.with_shared(|shared| shared.last_checksum); self.last_checksum = cksm; self.max_frame.store(0, Ordering::Release); self.min_frame.store(0, Ordering::Release); @@ -2135,12 +2160,11 @@ impl WalFile { } fn truncate_log(&mut self) -> Result> { - let file = { - let shared = self.get_shared(); + let file = self.with_shared(|shared| { assert!(shared.enabled.load(Ordering::SeqCst), "WAL must be enabled"); shared.initialized.store(false, Ordering::Release); shared.file.as_ref().unwrap().clone() - }; + }); let CheckpointState::Truncate { sync_sent, @@ -2247,9 +2271,10 @@ impl WalFile { }) }; // schedule read of the page payload - let shared = self.get_shared(); - assert!(shared.enabled.load(Ordering::SeqCst), "WAL must be enabled"); - let file = shared.file.as_ref().unwrap(); + let file = self.with_shared(|shared| { + assert!(shared.enabled.load(Ordering::SeqCst), "WAL must be enabled"); + shared.file.as_ref().unwrap().clone() + }); let c = begin_read_wal_frame( file.as_ref(), offset + WAL_FRAME_HEADER_SIZE as u64, @@ -2388,7 +2413,7 @@ impl WalFileShared { pub fn create(&mut self, file: Arc) -> Result<()> { if self.enabled.load(Ordering::SeqCst) { - return Err(LimboError::InternalError("WAL already enabled".to_string())); + return Ok(()); } let magic = if cfg!(target_endian = "big") { @@ -2535,7 +2560,7 @@ pub mod test { for _i in 0..25 { let _ = conn.execute("insert into test (value) values (randomblob(1024)), (randomblob(1024)), (randomblob(1024))"); } - let pager = conn.pager.write(); + let pager = conn.pager.load(); let _ = pager.cacheflush(); let mut wal = pager.wal.as_ref().unwrap().borrow_mut(); @@ -2626,7 +2651,7 @@ pub mod test { conn.execute("create table test(id integer primary key, value text)") .unwrap(); bulk_inserts(&conn, 20, 3); - let completions = conn.pager.write().cacheflush().unwrap(); + let completions = conn.pager.load().cacheflush().unwrap(); for c in completions { db.io.wait_for_completion(c).unwrap(); } @@ -2652,7 +2677,7 @@ pub mod test { // Run a RESTART checkpoint, should backfill everything and reset WAL counters, // but NOT truncate the file. { - let pager = conn.pager.read(); + let pager = conn.pager.load(); let mut wal = pager.wal.as_ref().unwrap().borrow_mut(); let res = run_checkpoint_until_done(&mut *wal, &pager, CheckpointMode::Restart); assert_eq!(res.num_attempted, mx_before); @@ -2698,7 +2723,7 @@ pub mod test { conn.execute("insert into test(value) values ('post_restart')") .unwrap(); conn.pager - .write() + .load() .wal .as_ref() .unwrap() @@ -2721,14 +2746,14 @@ pub mod test { .execute("create table test(id integer primary key, value text)") .unwrap(); bulk_inserts(&conn1.clone(), 15, 2); - let completions = conn1.pager.write().cacheflush().unwrap(); + let completions = conn1.pager.load().cacheflush().unwrap(); for c in completions { db.io.wait_for_completion(c).unwrap(); } // Force a read transaction that will freeze a lower read mark let readmark = { - let pager = conn2.pager.write(); + let pager = conn2.pager.load(); let mut wal2 = pager.wal.as_ref().unwrap().borrow_mut(); wal2.begin_read_tx().unwrap(); wal2.get_max_frame() @@ -2736,14 +2761,14 @@ pub mod test { // generate more frames that the reader will not see. bulk_inserts(&conn1.clone(), 15, 2); - let completions = conn1.pager.write().cacheflush().unwrap(); + let completions = conn1.pager.load().cacheflush().unwrap(); for c in completions { db.io.wait_for_completion(c).unwrap(); } // Run passive checkpoint, expect partial let (res1, max_before) = { - let pager = conn1.pager.read(); + let pager = conn1.pager.load(); let mut wal = pager.wal.as_ref().unwrap().borrow_mut(); let res = run_checkpoint_until_done( &mut *wal, @@ -2768,13 +2793,13 @@ pub mod test { ); // Release reader { - let pager = conn2.pager.write(); + let pager = conn2.pager.load(); let wal2 = pager.wal.as_ref().unwrap().borrow_mut(); wal2.end_read_tx(); } // Second passive checkpoint should finish - let pager = conn1.pager.read(); + let pager = conn1.pager.load(); let mut wal = pager.wal.as_ref().unwrap().borrow_mut(); let res2 = run_checkpoint_until_done( &mut *wal, @@ -2798,7 +2823,7 @@ pub mod test { // Start a read transaction conn2 .pager - .write() + .load() .wal .as_ref() .unwrap() @@ -2808,7 +2833,7 @@ pub mod test { // checkpoint should succeed here because the wal is fully checkpointed (empty) // so the reader is using readmark0 to read directly from the db file. - let p = conn1.pager.read(); + let p = conn1.pager.load(); let mut w = p.wal.as_ref().unwrap().borrow_mut(); loop { match w.checkpoint(&p, CheckpointMode::Restart) { @@ -2825,7 +2850,7 @@ pub mod test { } } drop(w); - conn2.pager.write().end_read_tx().unwrap(); + conn2.pager.load().end_read_tx(); conn1 .execute("create table test(id integer primary key, value text)") @@ -2836,8 +2861,8 @@ pub mod test { .unwrap(); } // now that we have some frames to checkpoint, try again - conn2.pager.write().begin_read_tx().unwrap(); - let p = conn1.pager.read(); + conn2.pager.load().begin_read_tx().unwrap(); + let p = conn1.pager.load(); let mut w = p.wal.as_ref().unwrap().borrow_mut(); loop { match w.checkpoint(&p, CheckpointMode::Restart) { @@ -2869,7 +2894,7 @@ pub mod test { bulk_inserts(&conn, 10, 5); // Checkpoint with restart { - let pager = conn.pager.read(); + let pager = conn.pager.load(); let mut wal = pager.wal.as_ref().unwrap().borrow_mut(); let result = run_checkpoint_until_done(&mut *wal, &pager, CheckpointMode::Restart); assert!(result.everything_backfilled()); @@ -2910,7 +2935,7 @@ pub mod test { // R1 starts reading let r1_max_frame = { - let pager = conn_r1.pager.write(); + let pager = conn_r1.pager.load(); let mut wal = pager.wal.as_ref().unwrap().borrow_mut(); wal.begin_read_tx().unwrap(); wal.get_max_frame() @@ -2919,7 +2944,7 @@ pub mod test { // R2 starts reading, sees more frames than R1 let r2_max_frame = { - let pager = conn_r2.pager.write(); + let pager = conn_r2.pager.load(); let mut wal = pager.wal.as_ref().unwrap().borrow_mut(); wal.begin_read_tx().unwrap(); wal.get_max_frame() @@ -2927,7 +2952,7 @@ pub mod test { // try passive checkpoint, should only checkpoint up to R1's position let checkpoint_result = { - let pager = conn_writer.pager.read(); + let pager = conn_writer.pager.load(); let mut wal = pager.wal.as_ref().unwrap().borrow_mut(); run_checkpoint_until_done( &mut *wal, @@ -2951,7 +2976,7 @@ pub mod test { assert_eq!( conn_r2 .pager - .read() + .load() .wal .as_ref() .unwrap() @@ -2976,7 +3001,7 @@ pub mod test { let max_frame_before = wal_shared.read().max_frame.load(Ordering::SeqCst); { - let pager = conn.pager.read(); + let pager = conn.pager.load(); let mut wal = pager.wal.as_ref().unwrap().borrow_mut(); let _result = run_checkpoint_until_done( &mut *wal, @@ -3009,7 +3034,7 @@ pub mod test { // start a write transaction { - let pager = conn2.pager.write(); + let pager = conn2.pager.load(); let mut wal = pager.wal.as_ref().unwrap().borrow_mut(); let _ = wal.begin_read_tx().unwrap(); wal.begin_write_tx().unwrap(); @@ -3017,7 +3042,7 @@ pub mod test { // should fail because writer lock is held let result = { - let pager = conn1.pager.read(); + let pager = conn1.pager.load(); let mut wal = pager.wal.as_ref().unwrap().borrow_mut(); wal.checkpoint(&pager, CheckpointMode::Restart) }; @@ -3029,7 +3054,7 @@ pub mod test { conn2 .pager - .read() + .load() .wal .as_ref() .unwrap() @@ -3038,7 +3063,7 @@ pub mod test { // release write lock conn2 .pager - .read() + .load() .wal .as_ref() .unwrap() @@ -3047,7 +3072,7 @@ pub mod test { // now restart should succeed let result = { - let pager = conn1.pager.read(); + let pager = conn1.pager.load(); let mut wal = pager.wal.as_ref().unwrap().borrow_mut(); run_checkpoint_until_done(&mut *wal, &pager, CheckpointMode::Restart) }; @@ -3065,13 +3090,13 @@ pub mod test { .unwrap(); // Attempt to start a write transaction without a read transaction - let pager = conn.pager.read(); + let pager = conn.pager.load(); let mut wal = pager.wal.as_ref().unwrap().borrow_mut(); let _ = wal.begin_write_tx(); } fn check_read_lock_slot(conn: &Arc, expected_slot: usize) -> bool { - let pager = conn.pager.read(); + let pager = conn.pager.load(); let wal = pager.wal.as_ref().unwrap().borrow(); #[cfg(debug_assertions)] { @@ -3099,7 +3124,7 @@ pub mod test { stmt.step().unwrap(); let frame = conn .pager - .read() + .load() .wal .as_ref() .unwrap() @@ -3127,7 +3152,7 @@ pub mod test { // passive checkpoint #1 let result1 = { - let pager = conn_writer.pager.read(); + let pager = conn_writer.pager.load(); let mut wal = pager.wal.as_ref().unwrap().borrow_mut(); run_checkpoint_until_done( &mut *wal, @@ -3144,7 +3169,7 @@ pub mod test { // passive checkpoint #2 let result2 = { - let pager = conn_writer.pager.read(); + let pager = conn_writer.pager.load(); let mut wal = pager.wal.as_ref().unwrap().borrow_mut(); run_checkpoint_until_done( &mut *wal, @@ -3193,7 +3218,7 @@ pub mod test { // Do a TRUNCATE checkpoint { - let pager = conn.pager.read(); + let pager = conn.pager.load(); let mut wal = pager.wal.as_ref().unwrap().borrow_mut(); run_checkpoint_until_done( &mut *wal, @@ -3254,7 +3279,7 @@ pub mod test { // Do a TRUNCATE checkpoint { - let pager = conn.pager.read(); + let pager = conn.pager.load(); let mut wal = pager.wal.as_ref().unwrap().borrow_mut(); run_checkpoint_until_done( &mut *wal, @@ -3292,7 +3317,7 @@ pub mod test { assert_eq!(hdr.page_size, 4096, "invalid page size"); assert_eq!(hdr.checkpoint_seq, 1, "invalid checkpoint_seq"); { - let pager = conn.pager.read(); + let pager = conn.pager.load(); let mut wal = pager.wal.as_ref().unwrap().borrow_mut(); run_checkpoint_until_done( &mut *wal, @@ -3342,7 +3367,7 @@ pub mod test { .unwrap(); // Start a read transaction on conn2 { - let pager = conn2.pager.write(); + let pager = conn2.pager.load(); let mut wal = pager.wal.as_ref().unwrap().borrow_mut(); wal.begin_read_tx().unwrap(); } @@ -3350,7 +3375,7 @@ pub mod test { bulk_inserts(&conn1, 5, 5); // Try to start a write transaction on conn2 with a stale snapshot let result = { - let pager = conn2.pager.read(); + let pager = conn2.pager.load(); let mut wal = pager.wal.as_ref().unwrap().borrow_mut(); wal.begin_write_tx() }; @@ -3359,14 +3384,14 @@ pub mod test { // End read transaction and start a fresh one { - let pager = conn2.pager.read(); + let pager = conn2.pager.load(); let mut wal = pager.wal.as_ref().unwrap().borrow_mut(); wal.end_read_tx(); wal.begin_read_tx().unwrap(); } // Now write transaction should work let result = { - let pager = conn2.pager.read(); + let pager = conn2.pager.load(); let mut wal = pager.wal.as_ref().unwrap().borrow_mut(); wal.begin_write_tx() }; @@ -3385,7 +3410,7 @@ pub mod test { bulk_inserts(&conn1, 5, 5); // Do a full checkpoint to move all data to DB file { - let pager = conn1.pager.read(); + let pager = conn1.pager.load(); let mut wal = pager.wal.as_ref().unwrap().borrow_mut(); run_checkpoint_until_done( &mut *wal, @@ -3398,14 +3423,14 @@ pub mod test { // Start a read transaction on conn2 { - let pager = conn2.pager.write(); + let pager = conn2.pager.load(); let mut wal = pager.wal.as_ref().unwrap().borrow_mut(); wal.begin_read_tx().unwrap(); } // should use slot 0, as everything is backfilled assert!(check_read_lock_slot(&conn2, 0)); { - let pager = conn1.pager.read(); + let pager = conn1.pager.load(); let wal = pager.wal.as_ref().unwrap().borrow(); let frame = wal.find_frame(5, None); // since we hold readlock0, we should ignore the db file and find_frame should return none @@ -3413,7 +3438,7 @@ pub mod test { } // Try checkpoint, should fail because reader has slot 0 { - let pager = conn1.pager.read(); + let pager = conn1.pager.load(); let mut wal = pager.wal.as_ref().unwrap().borrow_mut(); let result = wal.checkpoint(&pager, CheckpointMode::Restart); @@ -3424,12 +3449,12 @@ pub mod test { } // End the read transaction { - let pager = conn2.pager.read(); + let pager = conn2.pager.load(); let wal = pager.wal.as_ref().unwrap().borrow(); wal.end_read_tx(); } { - let pager = conn1.pager.read(); + let pager = conn1.pager.load(); let mut wal = pager.wal.as_ref().unwrap().borrow_mut(); let result = run_checkpoint_until_done(&mut *wal, &pager, CheckpointMode::Restart); assert!( @@ -3450,7 +3475,7 @@ pub mod test { bulk_inserts(&conn, 8, 4); // Ensure frames are flushed to the WAL - let completions = conn.pager.write().cacheflush().unwrap(); + let completions = conn.pager.load().cacheflush().unwrap(); for c in completions { db.io.wait_for_completion(c).unwrap(); } @@ -3462,7 +3487,7 @@ pub mod test { // Run FULL checkpoint - must backfill *all* frames up to mx_before let result = { - let pager = conn.pager.read(); + let pager = conn.pager.load(); let mut wal = pager.wal.as_ref().unwrap().borrow_mut(); run_checkpoint_until_done(&mut *wal, &pager, CheckpointMode::Full) }; @@ -3483,26 +3508,26 @@ pub mod test { // First commit some data and flush (reader will snapshot here) bulk_inserts(&writer, 2, 3); - let completions = writer.pager.write().cacheflush().unwrap(); + let completions = writer.pager.load().cacheflush().unwrap(); for c in completions { db.io.wait_for_completion(c).unwrap(); } // Start a read transaction pinned at the current snapshot { - let pager = reader.pager.write(); + let pager = reader.pager.load(); let mut wal = pager.wal.as_ref().unwrap().borrow_mut(); wal.begin_read_tx().unwrap(); } let r_snapshot = { - let pager = reader.pager.read(); + let pager = reader.pager.load(); let wal = pager.wal.as_ref().unwrap().borrow(); wal.get_max_frame() }; // Advance WAL beyond the reader's snapshot bulk_inserts(&writer, 3, 4); - let completions = writer.pager.write().cacheflush().unwrap(); + let completions = writer.pager.load().cacheflush().unwrap(); for c in completions { db.io.wait_for_completion(c).unwrap(); } @@ -3511,7 +3536,7 @@ pub mod test { // FULL must return Busy while a reader is stuck behind { - let pager = writer.pager.read(); + let pager = writer.pager.load(); let mut wal = pager.wal.as_ref().unwrap().borrow_mut(); loop { match wal.checkpoint(&pager, CheckpointMode::Full) { @@ -3529,13 +3554,13 @@ pub mod test { // Release the reader, now full mode should succeed and backfill everything { - let pager = reader.pager.read(); + let pager = reader.pager.load(); let wal = pager.wal.as_ref().unwrap().borrow(); wal.end_read_tx(); } let result = { - let pager = writer.pager.read(); + let pager = writer.pager.load(); let mut wal = pager.wal.as_ref().unwrap().borrow_mut(); run_checkpoint_until_done(&mut *wal, &pager, CheckpointMode::Full) }; diff --git a/core/translate/alter.rs b/core/translate/alter.rs index 77a1949ea..4f8de0d99 100644 --- a/core/translate/alter.rs +++ b/core/translate/alter.rs @@ -6,7 +6,7 @@ use turso_parser::{ use crate::{ function::{AlterTableFunc, Func}, - schema::{Column, Table}, + schema::{Column, Table, RESERVED_TABLE_PREFIXES}, translate::{ emitter::Resolver, expr::{walk_expr, WalkControl}, @@ -41,6 +41,17 @@ pub fn translate_alter_table( crate::bail_parse_error!("table {} may not be modified", table_name); } + if let ast::AlterTableBody::RenameTo(new_table_name) = &alter_table { + let normalized_new_name = normalize_ident(new_table_name.as_str()); + + if RESERVED_TABLE_PREFIXES + .iter() + .any(|prefix| normalized_new_name.starts_with(prefix)) + { + crate::bail_parse_error!("Object name reserved for internal use: {}", new_table_name); + } + } + let table_indexes = resolver.schema.get_indices(table_name).collect::>(); if !table_indexes.is_empty() && !resolver.schema.indexes_enabled() { diff --git a/core/translate/analyze.rs b/core/translate/analyze.rs index ece9a558b..665e43e0f 100644 --- a/core/translate/analyze.rs +++ b/core/translate/analyze.rs @@ -97,6 +97,7 @@ pub fn translate_analyze( program.emit_insn(Insn::Delete { cursor_id, table_name: "sqlite_stat1".to_string(), + is_part_of_update: false, }); program.emit_insn(Insn::Next { cursor_id, diff --git a/core/translate/collate.rs b/core/translate/collate.rs index 04324424c..89804c470 100644 --- a/core/translate/collate.rs +++ b/core/translate/collate.rs @@ -1,6 +1,5 @@ use std::{cmp::Ordering, str::FromStr as _}; -use tracing::Level; use turso_parser::ast::Expr; use crate::{ @@ -37,8 +36,8 @@ impl CollationSeq { }) } + #[inline(always)] pub fn compare_strings(&self, lhs: &str, rhs: &str) -> Ordering { - tracing::event!(Level::DEBUG, collate = %self, lhs, rhs); match self { CollationSeq::Binary => Self::binary_cmp(lhs, rhs), CollationSeq::NoCase => Self::nocase_cmp(lhs, rhs), @@ -46,16 +45,19 @@ impl CollationSeq { } } + #[inline(always)] fn binary_cmp(lhs: &str, rhs: &str) -> Ordering { lhs.cmp(rhs) } + #[inline(always)] fn nocase_cmp(lhs: &str, rhs: &str) -> Ordering { let nocase_lhs = uncased::UncasedStr::new(lhs); let nocase_rhs = uncased::UncasedStr::new(rhs); nocase_lhs.cmp(nocase_rhs) } + #[inline(always)] fn rtrim_cmp(lhs: &str, rhs: &str) -> Ordering { lhs.trim_end().cmp(rhs.trim_end()) } @@ -371,6 +373,7 @@ mod tests { hidden: false, }], unique_sets: vec![], + foreign_keys: vec![], })), }); @@ -413,6 +416,7 @@ mod tests { hidden: false, }], unique_sets: vec![], + foreign_keys: vec![], })), }); // Right table t2(id=2) @@ -446,6 +450,7 @@ mod tests { hidden: false, }], unique_sets: vec![], + foreign_keys: vec![], })), }); table_references @@ -486,6 +491,7 @@ mod tests { hidden: false, }], unique_sets: vec![], + foreign_keys: vec![], })), }); table_references diff --git a/core/translate/compound_select.rs b/core/translate/compound_select.rs index 08a55d099..619d07585 100644 --- a/core/translate/compound_select.rs +++ b/core/translate/compound_select.rs @@ -3,14 +3,13 @@ use crate::translate::collate::get_collseq_from_expr; use crate::translate::emitter::{emit_query, LimitCtx, Resolver, TranslateCtx}; use crate::translate::expr::translate_expr; use crate::translate::plan::{Plan, QueryDestination, SelectPlan}; -use crate::translate::result_row::try_fold_expr_to_i64; use crate::vdbe::builder::{CursorType, ProgramBuilder}; use crate::vdbe::insn::Insn; use crate::vdbe::BranchOffset; -use crate::{emit_explain, QueryMode, SymbolTable}; +use crate::{emit_explain, LimboError, QueryMode, SymbolTable}; use std::sync::Arc; use tracing::instrument; -use turso_parser::ast::{CompoundOperator, SortOrder}; +use turso_parser::ast::{CompoundOperator, Expr, Literal, SortOrder}; use tracing::Level; @@ -41,44 +40,66 @@ pub fn emit_program_for_compound_select( // Each subselect shares the same limit_ctx and offset, because the LIMIT, OFFSET applies to // the entire compound select, not just a single subselect. - let limit_ctx = limit.as_ref().map(|limit| { - let reg = program.alloc_register(); - if let Some(val) = try_fold_expr_to_i64(limit) { - program.emit_insn(Insn::Integer { - value: val, - dest: reg, - }); - } else { - program.add_comment(program.offset(), "OFFSET expr"); - _ = translate_expr(program, None, limit, reg, &right_most_ctx.resolver); + let limit_ctx = limit + .as_ref() + .map(|limit| { + let reg = program.alloc_register(); + match limit.as_ref() { + Expr::Literal(Literal::Numeric(n)) => { + if let Ok(value) = n.parse::() { + program.add_comment(program.offset(), "LIMIT counter"); + program.emit_insn(Insn::Integer { value, dest: reg }); + } else { + let value = n + .parse::() + .map_err(|_| LimboError::ParseError("invalid limit".to_string()))?; + program.emit_insn(Insn::Real { value, dest: reg }); + program.add_comment(program.offset(), "LIMIT counter"); + program.emit_insn(Insn::MustBeInt { reg }); + } + } + _ => { + _ = translate_expr(program, None, limit, reg, &right_most_ctx.resolver); + program.add_comment(program.offset(), "LIMIT counter"); + program.emit_insn(Insn::MustBeInt { reg }); + } + } + Ok::<_, LimboError>(LimitCtx::new_shared(reg)) + }) + .transpose()?; + let offset_reg = offset + .as_ref() + .map(|offset_expr| { + let reg = program.alloc_register(); + match offset_expr.as_ref() { + Expr::Literal(Literal::Numeric(n)) => { + // Compile-time constant offset + if let Ok(value) = n.parse::() { + program.emit_insn(Insn::Integer { value, dest: reg }); + } else { + let value = n + .parse::() + .map_err(|_| LimboError::ParseError("invalid offset".to_string()))?; + program.emit_insn(Insn::Real { value, dest: reg }); + } + } + _ => { + _ = translate_expr(program, None, offset_expr, reg, &right_most_ctx.resolver); + } + } + program.add_comment(program.offset(), "OFFSET counter"); program.emit_insn(Insn::MustBeInt { reg }); - } - LimitCtx::new_shared(reg) - }); - let offset_reg = offset.as_ref().map(|offset_expr| { - let reg = program.alloc_register(); - - if let Some(val) = try_fold_expr_to_i64(offset_expr) { - // Compile-time constant offset - program.emit_insn(Insn::Integer { - value: val, - dest: reg, + let combined_reg = program.alloc_register(); + program.add_comment(program.offset(), "OFFSET + LIMIT"); + program.emit_insn(Insn::OffsetLimit { + offset_reg: reg, + combined_reg, + limit_reg: limit_ctx.as_ref().unwrap().reg_limit, }); - } else { - program.add_comment(program.offset(), "OFFSET expr"); - _ = translate_expr(program, None, offset_expr, reg, &right_most_ctx.resolver); - program.emit_insn(Insn::MustBeInt { reg }); - } - let combined_reg = program.alloc_register(); - program.emit_insn(Insn::OffsetLimit { - offset_reg: reg, - combined_reg, - limit_reg: limit_ctx.as_ref().unwrap().reg_limit, - }); - - reg - }); + Ok::<_, LimboError>(reg) + }) + .transpose()?; // When a compound SELECT is part of a query that yields results to a coroutine (e.g. within an INSERT clause), // we must allocate registers for the result columns to be yielded. Each subselect will then yield to diff --git a/core/translate/delete.rs b/core/translate/delete.rs index 4d2dbdbff..e9d49da49 100644 --- a/core/translate/delete.rs +++ b/core/translate/delete.rs @@ -54,7 +54,7 @@ pub fn translate_delete( result_columns, connection, )?; - optimize_plan(&mut delete_plan, resolver.schema)?; + optimize_plan(&mut program, &mut delete_plan, resolver.schema)?; let Plan::Delete(ref delete) = delete_plan else { panic!("delete_plan is not a DeletePlan"); }; diff --git a/core/translate/emitter.rs b/core/translate/emitter.rs index 5e60617e6..ff22d1035 100644 --- a/core/translate/emitter.rs +++ b/core/translate/emitter.rs @@ -1,11 +1,12 @@ // This module contains code for emitting bytecode instructions for SQL query execution. // It handles translating high-level SQL operations into low-level bytecode that can be executed by the virtual machine. +use std::collections::HashSet; use std::num::NonZeroUsize; use std::sync::Arc; use tracing::{instrument, Level}; -use turso_parser::ast::{self, Expr}; +use turso_parser::ast::{self, Expr, Literal}; use super::aggregation::emit_ungrouped_aggregation; use super::expr::translate_expr; @@ -29,15 +30,18 @@ use crate::translate::expr::{ emit_returning_results, translate_expr_no_constant_opt, walk_expr_mut, NoConstantOptReason, ReturningValueRegisters, WalkControl, }; +use crate::translate::fkeys::{ + build_index_affinity_string, emit_fk_child_update_counters, + emit_fk_delete_parent_existence_checks, emit_guarded_fk_decrement, + emit_parent_key_change_checks, open_read_index, open_read_table, stabilize_new_row_for_fk, +}; use crate::translate::plan::{DeletePlan, JoinedTable, Plan, QueryDestination, Search}; use crate::translate::planner::ROWID_STRS; -use crate::translate::result_row::try_fold_expr_to_i64; use crate::translate::values::emit_values; use crate::translate::window::{emit_window_results, init_window, WindowMetadata}; use crate::util::{exprs_are_equivalent, normalize_ident}; use crate::vdbe::builder::{CursorKey, CursorType, ProgramBuilder}; use crate::vdbe::insn::{CmpInsFlags, IdxInsertFlags, InsertFlags, RegisterOrLiteral}; -use crate::vdbe::CursorID; use crate::vdbe::{insn::Insn, BranchOffset}; use crate::Connection; use crate::{bail_parse_error, Result, SymbolTable}; @@ -184,13 +188,30 @@ impl<'a> TranslateCtx<'a> { } } +#[derive(Debug, Clone)] +/// Update row source for UPDATE statements +/// `Normal` is the default mode, it will iterate either the table itself or an index on the table. +/// `PrebuiltEphemeralTable` is used when an ephemeral table containing the target rowids to update has +/// been built and it is being used for iteration. +pub enum UpdateRowSource { + /// Iterate over the table itself or an index on the table + Normal, + /// Iterate over an ephemeral table containing the target rowids to update + PrebuiltEphemeralTable { + /// The cursor id of the ephemeral table that is being used to iterate the target rowids to update. + ephemeral_table_cursor_id: usize, + /// The table that is being updated. + target_table: Arc, + }, +} + /// Used to distinguish database operations #[allow(clippy::upper_case_acronyms, dead_code)] -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone)] pub enum OperationMode { SELECT, INSERT, - UPDATE, + UPDATE(UpdateRowSource), DELETE, } @@ -251,7 +272,7 @@ pub fn emit_query<'a>( let after_main_loop_label = program.allocate_label(); t_ctx.label_main_loop_end = Some(after_main_loop_label); - init_limit(program, t_ctx, &plan.limit, &plan.offset); + init_limit(program, t_ctx, &plan.limit, &plan.offset)?; if !plan.values.is_empty() { let reg_result_cols_start = emit_values(program, plan, t_ctx)?; @@ -298,6 +319,7 @@ pub fn emit_query<'a>( &plan.order_by, &plan.table_references, plan.group_by.is_some(), + plan.distinctness != Distinctness::NonDistinct, &plan.aggregates, )?; } @@ -357,6 +379,7 @@ pub fn emit_query<'a>( &plan.join_order, &plan.where_clause, None, + OperationMode::SELECT, )?; // Process result columns and expressions in the inner loop @@ -368,7 +391,7 @@ pub fn emit_query<'a>( t_ctx, &plan.table_references, &plan.join_order, - None, + OperationMode::SELECT, )?; program.preassign_label_to_next_insn(after_main_loop_label); @@ -422,7 +445,7 @@ fn emit_program_for_delete( let after_main_loop_label = program.allocate_label(); t_ctx.label_main_loop_end = Some(after_main_loop_label); - init_limit(program, &mut t_ctx, &plan.limit, &None); + init_limit(program, &mut t_ctx, &plan.limit, &None)?; // No rows will be read from source table loops if there is a constant false condition eg. WHERE 0 if plan.contains_constant_false_condition { @@ -450,6 +473,7 @@ fn emit_program_for_delete( &[JoinOrderMember::default()], &plan.where_clause, None, + OperationMode::DELETE, )?; emit_delete_insns( @@ -466,16 +490,152 @@ fn emit_program_for_delete( &mut t_ctx, &plan.table_references, &[JoinOrderMember::default()], - None, + OperationMode::DELETE, )?; program.preassign_label_to_next_insn(after_main_loop_label); - // Finalize program program.result_columns = plan.result_columns; program.table_references.extend(plan.table_references); Ok(()) } +pub fn emit_fk_child_decrement_on_delete( + program: &mut ProgramBuilder, + resolver: &Resolver, + child_tbl: &BTreeTable, + child_table_name: &str, + child_cursor_id: usize, + child_rowid_reg: usize, +) -> crate::Result<()> { + for fk_ref in resolver.schema.resolved_fks_for_child(child_table_name)? { + if !fk_ref.fk.deferred { + continue; + } + // Fast path: if any FK column is NULL can't be a violation + let null_skip = program.allocate_label(); + for cname in &fk_ref.child_cols { + let (pos, col) = child_tbl.get_column(cname).unwrap(); + let src = if col.is_rowid_alias { + child_rowid_reg + } else { + let tmp = program.alloc_register(); + program.emit_insn(Insn::Column { + cursor_id: child_cursor_id, + column: pos, + dest: tmp, + default: None, + }); + tmp + }; + program.emit_insn(Insn::IsNull { + reg: src, + target_pc: null_skip, + }); + } + + if fk_ref.parent_uses_rowid { + // Probe parent table by rowid + let parent_tbl = resolver + .schema + .get_btree_table(&fk_ref.fk.parent_table) + .expect("parent btree"); + let pcur = open_read_table(program, &parent_tbl); + + let (pos, col) = child_tbl.get_column(&fk_ref.child_cols[0]).unwrap(); + let val = if col.is_rowid_alias { + child_rowid_reg + } else { + let tmp = program.alloc_register(); + program.emit_insn(Insn::Column { + cursor_id: child_cursor_id, + column: pos, + dest: tmp, + default: None, + }); + tmp + }; + let tmpi = program.alloc_register(); + program.emit_insn(Insn::Copy { + src_reg: val, + dst_reg: tmpi, + extra_amount: 0, + }); + program.emit_insn(Insn::MustBeInt { reg: tmpi }); + + // NotExists jumps when the parent key is missing, so we decrement there + let missing = program.allocate_label(); + let done = program.allocate_label(); + + program.emit_insn(Insn::NotExists { + cursor: pcur, + rowid_reg: tmpi, + target_pc: missing, + }); + + // Parent FOUND, no decrement + program.emit_insn(Insn::Close { cursor_id: pcur }); + program.emit_insn(Insn::Goto { target_pc: done }); + + // Parent MISSING, decrement is guarded by FkIfZero to avoid underflow + program.preassign_label_to_next_insn(missing); + program.emit_insn(Insn::Close { cursor_id: pcur }); + emit_guarded_fk_decrement(program, done); + program.preassign_label_to_next_insn(done); + } else { + // Probe parent unique index + let parent_tbl = resolver + .schema + .get_btree_table(&fk_ref.fk.parent_table) + .expect("parent btree"); + let idx = fk_ref.parent_unique_index.as_ref().expect("unique index"); + let icur = open_read_index(program, idx); + + // Build probe from current child row + let n = fk_ref.child_cols.len(); + let probe = program.alloc_registers(n); + for (i, cname) in fk_ref.child_cols.iter().enumerate() { + let (pos, col) = child_tbl.get_column(cname).unwrap(); + let src = if col.is_rowid_alias { + child_rowid_reg + } else { + let r = program.alloc_register(); + program.emit_insn(Insn::Column { + cursor_id: child_cursor_id, + column: pos, + dest: r, + default: None, + }); + r + }; + program.emit_insn(Insn::Copy { + src_reg: src, + dst_reg: probe + i, + extra_amount: 0, + }); + } + program.emit_insn(Insn::Affinity { + start_reg: probe, + count: std::num::NonZeroUsize::new(n).unwrap(), + affinities: build_index_affinity_string(idx, &parent_tbl), + }); + + let ok = program.allocate_label(); + program.emit_insn(Insn::Found { + cursor_id: icur, + target_pc: ok, + record_reg: probe, + num_regs: n, + }); + program.emit_insn(Insn::Close { cursor_id: icur }); + emit_guarded_fk_decrement(program, ok); + program.preassign_label_to_next_insn(ok); + program.emit_insn(Insn::Close { cursor_id: icur }); + } + program.preassign_label_to_next_insn(null_skip); + } + Ok(()) +} + fn emit_delete_insns( connection: &Arc, program: &mut ProgramBuilder, @@ -514,6 +674,34 @@ fn emit_delete_insns( dest: key_reg, }); + if connection.foreign_keys_enabled() { + if let Some(table) = unsafe { &*table_reference }.btree() { + if t_ctx + .resolver + .schema + .any_resolved_fks_referencing(table_name) + { + emit_fk_delete_parent_existence_checks( + program, + &t_ctx.resolver, + table_name, + main_table_cursor_id, + key_reg, + )?; + } + if t_ctx.resolver.schema.has_child_fks(table_name) { + emit_fk_child_decrement_on_delete( + program, + &t_ctx.resolver, + &table, + table_name, + main_table_cursor_id, + key_reg, + )?; + } + } + } + if unsafe { &*table_reference }.virtual_table().is_some() { let conflict_action = 0u16; let start_reg = key_reg; @@ -671,6 +859,7 @@ fn emit_delete_insns( program.emit_insn(Insn::Delete { cursor_id: main_table_cursor_id, table_name: table_name.to_string(), + is_part_of_update: false, }); if let Some(index) = iteration_index { @@ -679,6 +868,7 @@ fn emit_delete_insns( program.emit_insn(Insn::Delete { cursor_id: iteration_index_cursor, table_name: index.name.clone(), + is_part_of_update: false, }); } } @@ -710,7 +900,7 @@ fn emit_program_for_update( let after_main_loop_label = program.allocate_label(); t_ctx.label_main_loop_end = Some(after_main_loop_label); - init_limit(program, &mut t_ctx, &plan.limit, &plan.offset); + init_limit(program, &mut t_ctx, &plan.limit, &plan.offset)?; // No rows will be read from source table loops if there is a constant false condition eg. WHERE 0 if plan.contains_constant_false_condition { @@ -726,7 +916,15 @@ fn emit_program_for_update( }; *cursor_id }); - if let Some(ephemeral_plan) = ephemeral_plan { + let has_ephemeral_table = ephemeral_plan.is_some(); + + let target_table = if let Some(ephemeral_plan) = ephemeral_plan { + let table = ephemeral_plan + .table_references + .joined_tables() + .first() + .unwrap() + .clone(); program.emit_insn(Insn::OpenEphemeral { cursor_id: temp_cursor_id.unwrap(), is_table: true, @@ -734,7 +932,27 @@ fn emit_program_for_update( program.incr_nesting(); emit_program_for_select(program, resolver, ephemeral_plan)?; program.decr_nesting(); - } + Arc::new(table) + } else { + Arc::new( + plan.table_references + .joined_tables() + .first() + .unwrap() + .clone(), + ) + }; + + let mode = OperationMode::UPDATE(if has_ephemeral_table { + UpdateRowSource::PrebuiltEphemeralTable { + ephemeral_table_cursor_id: temp_cursor_id.expect( + "ephemeral table cursor id is always allocated if has_ephemeral_table is true", + ), + target_table: target_table.clone(), + } + } else { + UpdateRowSource::Normal + }); // Initialize the main loop init_loop( @@ -743,7 +961,7 @@ fn emit_program_for_update( &plan.table_references, &mut [], None, - OperationMode::UPDATE, + mode.clone(), &plan.where_clause, )?; @@ -780,8 +998,18 @@ fn emit_program_for_update( &[JoinOrderMember::default()], &plan.where_clause, temp_cursor_id, + mode.clone(), )?; + let target_table_cursor_id = + program.resolve_cursor_id(&CursorKey::table(target_table.internal_id)); + + let iteration_cursor_id = if has_ephemeral_table { + temp_cursor_id.unwrap() + } else { + target_table_cursor_id + }; + // Emit update instructions emit_update_insns( connection, @@ -789,7 +1017,9 @@ fn emit_program_for_update( &t_ctx, program, index_cursors, - temp_cursor_id, + iteration_cursor_id, + target_table_cursor_id, + target_table, )?; // Close the main loop @@ -798,11 +1028,10 @@ fn emit_program_for_update( &mut t_ctx, &plan.table_references, &[JoinOrderMember::default()], - temp_cursor_id, + mode.clone(), )?; program.preassign_label_to_next_insn(after_main_loop_label); - after(program); program.result_columns = plan.returning.unwrap_or_default(); @@ -811,20 +1040,28 @@ fn emit_program_for_update( } #[instrument(skip_all, level = Level::DEBUG)] +#[allow(clippy::too_many_arguments)] +/// Emits the instructions for the UPDATE loop. +/// +/// `iteration_cursor_id` is the cursor id of the table that is being iterated over. This can be either the table itself, an index, or an ephemeral table (see [crate::translate::plan::UpdatePlan]). +/// +/// `target_table_cursor_id` is the cursor id of the table that is being updated. +/// +/// `target_table` is the table that is being updated. fn emit_update_insns( connection: &Arc, plan: &mut UpdatePlan, t_ctx: &TranslateCtx, program: &mut ProgramBuilder, index_cursors: Vec<(usize, usize)>, - temp_cursor_id: Option, + iteration_cursor_id: usize, + target_table_cursor_id: usize, + target_table: Arc, ) -> crate::Result<()> { - // we can either use this obviously safe raw pointer or we can clone it - let table_ref: *const JoinedTable = plan.table_references.joined_tables().first().unwrap(); - let internal_id = unsafe { (*table_ref).internal_id }; + let internal_id = target_table.internal_id; let loop_labels = t_ctx.labels_main_loop.first().unwrap(); - let cursor_id = program.resolve_cursor_id(&CursorKey::table(internal_id)); - let (index, is_virtual) = match &unsafe { &*table_ref }.op { + let source_table = plan.table_references.joined_tables().first().unwrap(); + let (index, is_virtual) = match &source_table.op { Operation::Scan(Scan::BTreeTable { index, .. }) => ( index.as_ref().map(|index| { ( @@ -834,7 +1071,7 @@ fn emit_update_insns( }), false, ), - Operation::Scan(_) => (None, unsafe { &*table_ref }.virtual_table().is_some()), + Operation::Scan(_) => (None, target_table.virtual_table().is_some()), Operation::Search(search) => match search { &Search::RowidEq { .. } | Search::Seek { index: None, .. } => (None, false), Search::Seek { @@ -850,7 +1087,7 @@ fn emit_update_insns( }; let beg = program.alloc_registers( - unsafe { &*table_ref }.table.columns().len() + target_table.table.columns().len() + if is_virtual { 2 // two args before the relevant columns for VUpdate } else { @@ -858,12 +1095,13 @@ fn emit_update_insns( }, ); program.emit_insn(Insn::RowId { - cursor_id: temp_cursor_id.unwrap_or(cursor_id), + cursor_id: iteration_cursor_id, dest: beg, }); // Check if rowid was provided (through INTEGER PRIMARY KEY as a rowid alias) - let rowid_alias_index = unsafe { &*table_ref } + let rowid_alias_index = target_table + .table .columns() .iter() .position(|c| c.is_rowid_alias); @@ -885,15 +1123,18 @@ fn emit_update_insns( None }; - let check_rowid_not_exists_label = if has_user_provided_rowid { + let not_exists_check_required = + has_user_provided_rowid || iteration_cursor_id != target_table_cursor_id; + + let check_rowid_not_exists_label = if not_exists_check_required { Some(program.allocate_label()) } else { None }; - if has_user_provided_rowid { + if not_exists_check_required { program.emit_insn(Insn::NotExists { - cursor: cursor_id, + cursor: target_table_cursor_id, rowid_reg: beg, target_pc: check_rowid_not_exists_label.unwrap(), }); @@ -920,7 +1161,7 @@ fn emit_update_insns( decrement_by: 1, }); } - let col_len = unsafe { &*table_ref }.columns().len(); + let col_len = target_table.table.columns().len(); // we scan a column at a time, loading either the column's values, or the new value // from the Set expression, into registers so we can emit a MakeRecord and update the row. @@ -932,7 +1173,7 @@ fn emit_update_insns( } else { None }; - let table_name = unsafe { &*table_ref }.table.get_name(); + let table_name = target_table.table.get_name(); let start = if is_virtual { beg + 2 } else { beg + 1 }; @@ -951,7 +1192,7 @@ fn emit_update_insns( }); } } - for (idx, table_column) in unsafe { &*table_ref }.columns().iter().enumerate() { + for (idx, table_column) in target_table.table.columns().iter().enumerate() { let target_reg = start + idx; if let Some((col_idx, expr)) = plan.set_clauses.iter().find(|(i, _)| *i == idx) { // Skip if this is the sentinel value @@ -1034,7 +1275,7 @@ fn emit_update_insns( program.emit_null(target_reg, None); } else if is_virtual { program.emit_insn(Insn::VColumn { - cursor_id, + cursor_id: target_table_cursor_id, column: idx, dest: target_reg, }); @@ -1048,7 +1289,7 @@ fn emit_update_insns( None } }) - .unwrap_or(&cursor_id); + .unwrap_or(&target_table_cursor_id); program.emit_column_or_rowid( cursor_id, column_idx_in_index.unwrap_or(idx), @@ -1067,6 +1308,60 @@ fn emit_update_insns( } } + if connection.foreign_keys_enabled() { + let rowid_new_reg = rowid_set_clause_reg.unwrap_or(beg); + if let Some(table_btree) = target_table.table.btree() { + stabilize_new_row_for_fk( + program, + &table_btree, + &plan.set_clauses, + target_table_cursor_id, + start, + rowid_new_reg, + )?; + if t_ctx.resolver.schema.has_child_fks(table_name) { + // Child-side checks: + // this ensures updated row still satisfies child FKs that point OUT from this table + emit_fk_child_update_counters( + program, + &t_ctx.resolver, + &table_btree, + table_name, + target_table_cursor_id, + start, + rowid_new_reg, + &plan + .set_clauses + .iter() + .map(|(i, _)| *i) + .collect::>(), + )?; + } + // Parent-side checks: + // We only need to do work if the referenced key (the parent key) might change. + // we detect that by comparing OLD vs NEW primary key representation + // then run parent FK checks only when it actually changes. + if t_ctx + .resolver + .schema + .any_resolved_fks_referencing(table_name) + { + emit_parent_key_change_checks( + program, + &t_ctx.resolver, + &table_btree, + plan.indexes_to_update.iter(), + target_table_cursor_id, + beg, + start, + rowid_new_reg, + rowid_set_clause_reg, + &plan.set_clauses, + )?; + } + } + } + for (index, (idx_cursor_id, record_reg)) in plan.indexes_to_update.iter().zip(&index_cursors) { // We need to know whether or not the OLD values satisfied the predicate on the // partial index, so we can know whether or not to delete the old index entry, @@ -1098,7 +1393,7 @@ fn emit_update_insns( // to refer to the new values, which are already loaded into registers starting at `start`. rewrite_where_for_update_registers( &mut new_where, - unsafe { &*table_ref }.columns(), + target_table.table.columns(), start, rowid_set_clause_reg.unwrap_or(beg), )?; @@ -1140,13 +1435,13 @@ fn emit_update_insns( let delete_start_reg = program.alloc_registers(num_regs); for (reg_offset, column_index) in index.columns.iter().enumerate() { program.emit_column_or_rowid( - cursor_id, + target_table_cursor_id, column_index.pos_in_table, delete_start_reg + reg_offset, ); } program.emit_insn(Insn::RowId { - cursor_id, + cursor_id: target_table_cursor_id, dest: delete_start_reg + num_regs - 1, }); program.emit_insn(Insn::IdxDelete { @@ -1178,7 +1473,8 @@ fn emit_update_insns( let rowid_reg = rowid_set_clause_reg.unwrap_or(beg); for (i, col) in index.columns.iter().enumerate() { - let col_in_table = unsafe { &*table_ref } + let col_in_table = target_table + .table .columns() .get(col.pos_in_table) .expect("column index out of bounds"); @@ -1213,7 +1509,7 @@ fn emit_update_insns( .columns .iter() .map(|ic| { - unsafe { &*table_ref }.columns()[ic.pos_in_table] + target_table.table.columns()[ic.pos_in_table] .affinity() .aff_mask() }) @@ -1283,7 +1579,7 @@ fn emit_update_insns( } } - if let Some(btree_table) = unsafe { &*table_ref }.btree() { + if let Some(btree_table) = target_table.table.btree() { if btree_table.is_strict { program.emit_insn(Insn::TypeCheck { start_reg: start, @@ -1306,7 +1602,7 @@ fn emit_update_insns( }); program.emit_insn(Insn::NotExists { - cursor: cursor_id, + cursor: target_table_cursor_id, rowid_reg: target_reg, target_pc: record_label, }); @@ -1314,7 +1610,8 @@ fn emit_update_insns( let description = if let Some(idx) = rowid_alias_index { String::from(table_name) + "." - + unsafe { &*table_ref } + + target_table + .table .columns() .get(idx) .unwrap() @@ -1335,7 +1632,8 @@ fn emit_update_insns( let record_reg = program.alloc_register(); - let affinity_str = unsafe { &*table_ref } + let affinity_str = target_table + .table .columns() .iter() .map(|col| col.affinity().aff_mask()) @@ -1349,9 +1647,9 @@ fn emit_update_insns( affinity_str: Some(affinity_str), }); - if has_user_provided_rowid { + if not_exists_check_required { program.emit_insn(Insn::NotExists { - cursor: cursor_id, + cursor: target_table_cursor_id, rowid_reg: beg, target_pc: check_rowid_not_exists_label.unwrap(), }); @@ -1365,7 +1663,7 @@ fn emit_update_insns( let cdc_rowid_before_reg = program.alloc_register(); if has_user_provided_rowid { program.emit_insn(Insn::RowId { - cursor_id, + cursor_id: target_table_cursor_id, dest: cdc_rowid_before_reg, }); Some(cdc_rowid_before_reg) @@ -1380,8 +1678,8 @@ fn emit_update_insns( let cdc_before_reg = if program.capture_data_changes_mode().has_before() { Some(emit_cdc_full_record( program, - unsafe { &*table_ref }.table.columns(), - cursor_id, + target_table.table.columns(), + target_table_cursor_id, cdc_rowid_before_reg.expect("cdc_rowid_before_reg must be set"), )) } else { @@ -1391,25 +1689,26 @@ fn emit_update_insns( // If we are updating the rowid, we cannot rely on overwrite on the // Insert instruction to update the cell. We need to first delete the current cell // and later insert the updated record - if has_user_provided_rowid { + if not_exists_check_required { program.emit_insn(Insn::Delete { - cursor_id, + cursor_id: target_table_cursor_id, table_name: table_name.to_string(), + is_part_of_update: true, }); } program.emit_insn(Insn::Insert { - cursor: cursor_id, + cursor: target_table_cursor_id, key_reg: rowid_set_clause_reg.unwrap_or(beg), record_reg, - flag: if has_user_provided_rowid { + flag: if not_exists_check_required { // The previous Insn::NotExists and Insn::Delete seek to the old rowid, // so to insert a new user-provided rowid, we need to seek to the correct place. InsertFlags::new().require_seek().update_rowid_change() } else { InsertFlags::new() }, - table_name: unsafe { &*table_ref }.identifier.clone(), + table_name: target_table.identifier.clone(), }); // Emit RETURNING results if specified @@ -1429,7 +1728,7 @@ fn emit_update_insns( let cdc_after_reg = if program.capture_data_changes_mode().has_after() { Some(emit_cdc_patch_record( program, - &unsafe { &*table_ref }.table, + &target_table.table, start, record_reg, cdc_rowid_after_reg, @@ -1483,7 +1782,14 @@ fn emit_update_insns( emit_cdc_insns( program, &t_ctx.resolver, - OperationMode::UPDATE, + OperationMode::UPDATE(if plan.ephemeral_plan.is_some() { + UpdateRowSource::PrebuiltEphemeralTable { + ephemeral_table_cursor_id: iteration_cursor_id, + target_table: target_table.clone(), + } + } else { + UpdateRowSource::Normal + }), cdc_cursor_id, cdc_rowid_before_reg, cdc_before_reg, @@ -1493,10 +1799,10 @@ fn emit_update_insns( )?; } } - } else if unsafe { &*table_ref }.virtual_table().is_some() { + } else if target_table.virtual_table().is_some() { let arg_count = col_len + 2; program.emit_insn(Insn::VUpdate { - cursor_id, + cursor_id: target_table_cursor_id, arg_count, start_reg: beg, conflict_action: 0u16, @@ -1651,7 +1957,7 @@ pub fn emit_cdc_insns( let change_type = match operation_mode { OperationMode::INSERT => 1, - OperationMode::UPDATE | OperationMode::SELECT => 0, + OperationMode::UPDATE { .. } | OperationMode::SELECT => 0, OperationMode::DELETE => -1, }; program.emit_int(change_type, turso_cdc_registers + 2); @@ -1732,62 +2038,82 @@ fn init_limit( t_ctx: &mut TranslateCtx, limit: &Option>, offset: &Option>, -) { +) -> Result<()> { if t_ctx.limit_ctx.is_none() && limit.is_some() { t_ctx.limit_ctx = Some(LimitCtx::new(program)); } let Some(limit_ctx) = &t_ctx.limit_ctx else { - return; + return Ok(()); }; if limit_ctx.initialize_counter { if let Some(expr) = limit { - if let Some(value) = try_fold_expr_to_i64(expr) { - program.emit_insn(Insn::Integer { - value, - dest: limit_ctx.reg_limit, - }); - } else { - let r = limit_ctx.reg_limit; - program.add_comment(program.offset(), "OFFSET expr"); - _ = translate_expr(program, None, expr, r, &t_ctx.resolver); - program.emit_insn(Insn::MustBeInt { reg: r }); + match expr.as_ref() { + Expr::Literal(Literal::Numeric(n)) => { + if let Ok(value) = n.parse::() { + program.add_comment(program.offset(), "LIMIT counter"); + program.emit_insn(Insn::Integer { + value, + dest: limit_ctx.reg_limit, + }); + } else { + program.emit_insn(Insn::Real { + value: n.parse::().unwrap(), + dest: limit_ctx.reg_limit, + }); + program.add_comment(program.offset(), "LIMIT counter"); + program.emit_insn(Insn::MustBeInt { + reg: limit_ctx.reg_limit, + }); + } + } + _ => { + let r = limit_ctx.reg_limit; + + _ = translate_expr(program, None, expr, r, &t_ctx.resolver); + program.emit_insn(Insn::MustBeInt { reg: r }); + } } } } if t_ctx.reg_offset.is_none() { if let Some(expr) = offset { - if let Some(value) = try_fold_expr_to_i64(expr) { - if value != 0 { - let reg = program.alloc_register(); - t_ctx.reg_offset = Some(reg); - program.emit_insn(Insn::Integer { value, dest: reg }); - let combined_reg = program.alloc_register(); - t_ctx.reg_limit_offset_sum = Some(combined_reg); - program.emit_insn(Insn::OffsetLimit { - limit_reg: limit_ctx.reg_limit, - offset_reg: reg, - combined_reg, - }); + let offset_reg = program.alloc_register(); + t_ctx.reg_offset = Some(offset_reg); + match expr.as_ref() { + Expr::Literal(Literal::Numeric(n)) => { + if let Ok(value) = n.parse::() { + program.emit_insn(Insn::Integer { + value, + dest: offset_reg, + }); + } else { + let value = n.parse::()?; + program.emit_insn(Insn::Real { + value, + dest: limit_ctx.reg_limit, + }); + program.emit_insn(Insn::MustBeInt { + reg: limit_ctx.reg_limit, + }); + } + } + _ => { + _ = translate_expr(program, None, expr, offset_reg, &t_ctx.resolver); } - } else { - let reg = program.alloc_register(); - t_ctx.reg_offset = Some(reg); - let r = reg; - - program.add_comment(program.offset(), "OFFSET expr"); - _ = translate_expr(program, None, expr, r, &t_ctx.resolver); - program.emit_insn(Insn::MustBeInt { reg: r }); - - let combined_reg = program.alloc_register(); - t_ctx.reg_limit_offset_sum = Some(combined_reg); - program.emit_insn(Insn::OffsetLimit { - limit_reg: limit_ctx.reg_limit, - offset_reg: reg, - combined_reg, - }); } + program.add_comment(program.offset(), "OFFSET counter"); + program.emit_insn(Insn::MustBeInt { reg: offset_reg }); + + let combined_reg = program.alloc_register(); + t_ctx.reg_limit_offset_sum = Some(combined_reg); + program.add_comment(program.offset(), "OFFSET + LIMIT"); + program.emit_insn(Insn::OffsetLimit { + limit_reg: limit_ctx.reg_limit, + offset_reg, + combined_reg, + }); } } @@ -1800,6 +2126,8 @@ fn init_limit( target_pc: main_loop_end, jump_if_null: false, }); + + Ok(()) } /// We have `Expr`s which have *not* had column references bound to them, diff --git a/core/translate/expr.rs b/core/translate/expr.rs index e110523f3..e082ca6ae 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -30,6 +30,7 @@ pub struct ConditionMetadata { pub jump_if_condition_is_true: bool, pub jump_target_when_true: BranchOffset, pub jump_target_when_false: BranchOffset, + pub jump_target_when_null: BranchOffset, } /// Container for register locations of values that can be referenced in RETURNING expressions @@ -154,128 +155,127 @@ macro_rules! expect_arguments_even { /// /// This is extracted from the original conditional implementation to be reusable. /// The logic exactly matches the original conditional InList implementation. +/// +/// An IN expression has one of the following formats: +/// ```sql +/// x IN (y1, y2,...,yN) +/// x IN (subquery) (Not yet implemented) +/// ``` +/// The result of an IN operator is one of TRUE, FALSE, or NULL. A NULL result +/// means that it cannot be determined if the LHS is contained in the RHS due +/// to the presence of NULL values. +/// +/// Currently, we do a simple full-scan, yet it's not ideal when there are many rows +/// on RHS. (Check sqlite's in-operator.md) +/// +/// Algorithm: +/// 1. Set the null-flag to false +/// 2. For each row in the RHS: +/// - Compare LHS and RHS +/// - If LHS matches RHS, returns TRUE +/// - If the comparison results in NULL, set the null-flag to true +/// 3. If the null-flag is true, return NULL +/// 4. Return FALSE +/// +/// A "NOT IN" operator is computed by first computing the equivalent IN +/// operator, then interchanging the TRUE and FALSE results. +// todo: Check right affinities #[instrument(skip(program, referenced_tables, resolver), level = Level::DEBUG)] fn translate_in_list( program: &mut ProgramBuilder, referenced_tables: Option<&TableReferences>, lhs: &ast::Expr, rhs: &[Box], - not: bool, condition_metadata: ConditionMetadata, + // dest if null should be in ConditionMetadata resolver: &Resolver, ) -> Result<()> { - // lhs is e.g. a column reference - // rhs is an Option> - // If rhs is None, it means the IN expression is always false, i.e. tbl.id IN (). - // If rhs is Some, it means the IN expression has a list of values to compare against, e.g. tbl.id IN (1, 2, 3). - // - // The IN expression is equivalent to a series of OR expressions. - // For example, `a IN (1, 2, 3)` is equivalent to `a = 1 OR a = 2 OR a = 3`. - // The NOT IN expression is equivalent to a series of AND expressions. - // For example, `a NOT IN (1, 2, 3)` is equivalent to `a != 1 AND a != 2 AND a != 3`. - // - // SQLite typically optimizes IN expressions to use a binary search on an ephemeral index if there are many values. - // For now we don't have the plumbing to do that, so we'll just emit a series of comparisons, - // which is what SQLite also does for small lists of values. - // TODO: Let's refactor this later to use a more efficient implementation conditionally based on the number of values. + let lhs_reg = if let Expr::Parenthesized(v) = lhs { + program.alloc_registers(v.len()) + } else { + program.alloc_register() + }; + let _ = translate_expr(program, referenced_tables, lhs, lhs_reg, resolver)?; + let mut check_null_reg = 0; + let label_ok = program.allocate_label(); - if rhs.is_empty() { - // If rhs is None, IN expressions are always false and NOT IN expressions are always true. - if not { - // On a trivially true NOT IN () expression we can only jump to the 'jump_target_when_true' label if 'jump_if_condition_is_true'; otherwise me must fall through. - // This is because in a more complex condition we might need to evaluate the rest of the condition. - // Note that we are already breaking up our WHERE clauses into a series of terms at "AND" boundaries, so right now we won't be running into cases where jumping on true would be incorrect, - // but once we have e.g. parenthesization and more complex conditions, not having this 'if' here would introduce a bug. - if condition_metadata.jump_if_condition_is_true { - program.emit_insn(Insn::Goto { - target_pc: condition_metadata.jump_target_when_true, - }); - } - } else { - program.emit_insn(Insn::Goto { - target_pc: condition_metadata.jump_target_when_false, - }); - } - return Ok(()); + if condition_metadata.jump_target_when_false != condition_metadata.jump_target_when_null { + check_null_reg = program.alloc_register(); + program.emit_insn(Insn::BitAnd { + lhs: lhs_reg, + rhs: lhs_reg, + dest: check_null_reg, + }); } - // The left hand side only needs to be evaluated once we have a list of values to compare against. - let lhs_reg = program.alloc_register(); - let _ = translate_expr(program, referenced_tables, lhs, lhs_reg, resolver)?; + for (i, expr) in rhs.iter().enumerate() { + let last_condition = i == rhs.len() - 1; + let rhs_reg = program.alloc_register(); + let _ = translate_expr(program, referenced_tables, expr, rhs_reg, resolver)?; - // The difference between a local jump and an "upper level" jump is that for example in this case: - // WHERE foo IN (1,2,3) OR bar = 5, - // we can immediately jump to the 'jump_target_when_true' label of the ENTIRE CONDITION if foo = 1, foo = 2, or foo = 3 without evaluating the bar = 5 condition. - // This is why in Binary-OR expressions we set jump_if_condition_is_true to true for the first condition. - // However, in this example: - // WHERE foo IN (1,2,3) AND bar = 5, - // we can't jump to the 'jump_target_when_true' label of the entire condition foo = 1, foo = 2, or foo = 3, because we still need to evaluate the bar = 5 condition later. - // This is why in that case we just jump over the rest of the IN conditions in this "local" branch which evaluates the IN condition. - let jump_target_when_true = if condition_metadata.jump_if_condition_is_true { - condition_metadata.jump_target_when_true - } else { - program.allocate_label() - }; + if check_null_reg != 0 && expr.can_be_null() { + program.emit_insn(Insn::BitAnd { + lhs: check_null_reg, + rhs: rhs_reg, + dest: check_null_reg, + }); + } - if !not { - // If it's an IN expression, we need to jump to the 'jump_target_when_true' label if any of the conditions are true. - for (i, expr) in rhs.iter().enumerate() { - let rhs_reg = program.alloc_register(); - let last_condition = i == rhs.len() - 1; - let _ = translate_expr(program, referenced_tables, expr, rhs_reg, resolver)?; - // If this is not the last condition, we need to jump to the 'jump_target_when_true' label if the condition is true. - if !last_condition { + if !last_condition + || condition_metadata.jump_target_when_false != condition_metadata.jump_target_when_null + { + if lhs_reg != rhs_reg { program.emit_insn(Insn::Eq { lhs: lhs_reg, rhs: rhs_reg, - target_pc: jump_target_when_true, + target_pc: label_ok, + // Use affinity instead flags: CmpInsFlags::default(), collation: program.curr_collation(), }); } else { - // If this is the last condition, we need to jump to the 'jump_target_when_false' label if there is no match. - program.emit_insn(Insn::Ne { - lhs: lhs_reg, - rhs: rhs_reg, - target_pc: condition_metadata.jump_target_when_false, - flags: CmpInsFlags::default().jump_if_null(), - collation: program.curr_collation(), + program.emit_insn(Insn::NotNull { + reg: lhs_reg, + target_pc: label_ok, }); } - } - // If we got here, then the last condition was a match, so we jump to the 'jump_target_when_true' label if 'jump_if_condition_is_true'. - // If not, we can just fall through without emitting an unnecessary instruction. - if condition_metadata.jump_if_condition_is_true { - program.emit_insn(Insn::Goto { - target_pc: condition_metadata.jump_target_when_true, - }); - } - } else { - // If it's a NOT IN expression, we need to jump to the 'jump_target_when_false' label if any of the conditions are true. - for expr in rhs.iter() { - let rhs_reg = program.alloc_register(); - let _ = translate_expr(program, referenced_tables, expr, rhs_reg, resolver)?; - program.emit_insn(Insn::Eq { + // sqlite3VdbeChangeP5(v, zAff[0]); + } else if lhs_reg != rhs_reg { + program.emit_insn(Insn::Ne { lhs: lhs_reg, rhs: rhs_reg, target_pc: condition_metadata.jump_target_when_false, - flags: CmpInsFlags::default().jump_if_null(), + flags: CmpInsFlags::default(), collation: program.curr_collation(), }); - } - // If we got here, then none of the conditions were a match, so we jump to the 'jump_target_when_true' label if 'jump_if_condition_is_true'. - // If not, we can just fall through without emitting an unnecessary instruction. - if condition_metadata.jump_if_condition_is_true { - program.emit_insn(Insn::Goto { - target_pc: condition_metadata.jump_target_when_true, + } else { + program.emit_insn(Insn::IsNull { + reg: lhs_reg, + target_pc: condition_metadata.jump_target_when_false, }); } } - if !condition_metadata.jump_if_condition_is_true { - program.preassign_label_to_next_insn(jump_target_when_true); + if check_null_reg != 0 { + program.emit_insn(Insn::IsNull { + reg: check_null_reg, + target_pc: condition_metadata.jump_target_when_null, + }); + program.emit_insn(Insn::Goto { + target_pc: condition_metadata.jump_target_when_false, + }); } + program.resolve_label(label_ok, program.offset()); + + // by default if IN expression is true we just continue to the next instruction + if condition_metadata.jump_if_condition_is_true { + program.emit_insn(Insn::Goto { + target_pc: condition_metadata.jump_target_when_true, + }); + } + // todo: deallocate check_null_reg + Ok(()) } @@ -402,16 +402,55 @@ pub fn translate_condition_expr( translate_expr(program, Some(referenced_tables), expr, reg, resolver)?; emit_cond_jump(program, condition_metadata, reg); } + ast::Expr::InList { lhs, not, rhs } => { + let ConditionMetadata { + jump_if_condition_is_true, + jump_target_when_true, + jump_target_when_false, + jump_target_when_null, + } = condition_metadata; + + // Adjust targets if `NOT IN` + let (adjusted_metadata, not_true_label, not_false_label) = if *not { + let not_true_label = program.allocate_label(); + let not_false_label = program.allocate_label(); + ( + ConditionMetadata { + jump_if_condition_is_true, + jump_target_when_true: not_true_label, + jump_target_when_false: not_false_label, + jump_target_when_null, + }, + Some(not_true_label), + Some(not_false_label), + ) + } else { + (condition_metadata, None, None) + }; + translate_in_list( program, Some(referenced_tables), lhs, rhs, - *not, - condition_metadata, + adjusted_metadata, resolver, )?; + + if *not { + // When IN is TRUE (match found), NOT IN should be FALSE + program.resolve_label(not_true_label.unwrap(), program.offset()); + program.emit_insn(Insn::Goto { + target_pc: jump_target_when_false, + }); + + // When IN is FALSE (no match), NOT IN should be TRUE + program.resolve_label(not_false_label.unwrap(), program.offset()); + program.emit_insn(Insn::Goto { + target_pc: jump_target_when_true, + }); + } } ast::Expr::Like { not, .. } => { let cur_reg = program.alloc_register(); @@ -879,6 +918,14 @@ pub fn translate_expr( emit_function_call(program, func_ctx, &[start_reg], target_register)?; Ok(target_register) } + VectorFunc::Vector32Sparse => { + let args = expect_arguments_exact!(args, 1, vector_func); + let start_reg = program.alloc_register(); + translate_expr(program, referenced_tables, &args[0], start_reg, resolver)?; + + emit_function_call(program, func_ctx, &[start_reg], target_register)?; + Ok(target_register) + } VectorFunc::Vector64 => { let args = expect_arguments_exact!(args, 1, vector_func); let start_reg = program.alloc_register(); @@ -904,7 +951,16 @@ pub fn translate_expr( emit_function_call(program, func_ctx, &[regs, regs + 1], target_register)?; Ok(target_register) } - VectorFunc::VectorDistanceEuclidean => { + VectorFunc::VectorDistanceL2 => { + let args = expect_arguments_exact!(args, 2, vector_func); + let regs = program.alloc_registers(2); + translate_expr(program, referenced_tables, &args[0], regs, resolver)?; + translate_expr(program, referenced_tables, &args[1], regs + 1, resolver)?; + + emit_function_call(program, func_ctx, &[regs, regs + 1], target_register)?; + Ok(target_register) + } + VectorFunc::VectorDistanceJaccard => { let args = expect_arguments_exact!(args, 2, vector_func); let regs = program.alloc_registers(2); translate_expr(program, referenced_tables, &args[0], regs, resolver)?; @@ -1087,51 +1143,66 @@ pub fn translate_expr( Ok(target_register) } ScalarFunc::Iif => { - if args.len() != 3 { - crate::bail_parse_error!( - "{} requires exactly 3 arguments", - srf.to_string() - ); + let args = expect_arguments_min!(args, 2, srf); + + let iif_end_label = program.allocate_label(); + let condition_reg = program.alloc_register(); + + for pair in args.chunks_exact(2) { + let condition_expr = &pair[0]; + let value_expr = &pair[1]; + let next_check_label = program.allocate_label(); + + translate_expr_no_constant_opt( + program, + referenced_tables, + condition_expr, + condition_reg, + resolver, + NoConstantOptReason::RegisterReuse, + )?; + + program.emit_insn(Insn::IfNot { + reg: condition_reg, + target_pc: next_check_label, + jump_if_null: true, + }); + + translate_expr_no_constant_opt( + program, + referenced_tables, + value_expr, + target_register, + resolver, + NoConstantOptReason::RegisterReuse, + )?; + program.emit_insn(Insn::Goto { + target_pc: iif_end_label, + }); + + program.preassign_label_to_next_insn(next_check_label); } - let temp_reg = program.alloc_register(); - translate_expr_no_constant_opt( - program, - referenced_tables, - &args[0], - temp_reg, - resolver, - NoConstantOptReason::RegisterReuse, - )?; - let jump_target_when_false = program.allocate_label(); - program.emit_insn(Insn::IfNot { - reg: temp_reg, - target_pc: jump_target_when_false, - jump_if_null: true, - }); - translate_expr_no_constant_opt( - program, - referenced_tables, - &args[1], - target_register, - resolver, - NoConstantOptReason::RegisterReuse, - )?; - let jump_target_result = program.allocate_label(); - program.emit_insn(Insn::Goto { - target_pc: jump_target_result, - }); - program.preassign_label_to_next_insn(jump_target_when_false); - translate_expr_no_constant_opt( - program, - referenced_tables, - &args[2], - target_register, - resolver, - NoConstantOptReason::RegisterReuse, - )?; - program.preassign_label_to_next_insn(jump_target_result); + + if args.len() % 2 != 0 { + translate_expr_no_constant_opt( + program, + referenced_tables, + args.last().unwrap(), + target_register, + resolver, + NoConstantOptReason::RegisterReuse, + )?; + } else { + program.emit_insn(Insn::Null { + dest: target_register, + dest_end: None, + }); + } + + program.preassign_label_to_next_insn(iif_end_label); Ok(target_register) } + ScalarFunc::Glob | ScalarFunc::Like => { if args.len() < 2 { crate::bail_parse_error!( @@ -1499,7 +1570,9 @@ pub fn translate_expr( Ok(target_register) } - ScalarFunc::SqliteVersion => { + ScalarFunc::SqliteVersion + | ScalarFunc::TursoVersion + | ScalarFunc::SqliteSourceId => { if !args.is_empty() { crate::bail_parse_error!("sqlite_version function with arguments"); } @@ -1519,28 +1592,6 @@ pub fn translate_expr( }); Ok(target_register) } - ScalarFunc::SqliteSourceId => { - if !args.is_empty() { - crate::bail_parse_error!( - "sqlite_source_id function with arguments" - ); - } - - let output_register = program.alloc_register(); - program.emit_insn(Insn::Function { - constant_mask: 0, - start_reg: output_register, - dest: output_register, - func: func_ctx, - }); - - program.emit_insn(Insn::Copy { - src_reg: output_register, - dst_reg: target_register, - extra_amount: 0, - }); - Ok(target_register) - } ScalarFunc::Replace => { if !args.len() == 3 { crate::bail_parse_error!( @@ -1819,8 +1870,55 @@ pub fn translate_expr( Func::AlterTable(_) => unreachable!(), } } - ast::Expr::FunctionCallStar { .. } => { - crate::bail_parse_error!("FunctionCallStar in WHERE clause is not supported") + ast::Expr::FunctionCallStar { name, filter_over } => { + // Handle func(*) syntax as a function call with 0 arguments + // This is equivalent to func() for functions that accept 0 arguments + let args_count = 0; + let func_type = resolver.resolve_function(name.as_str(), args_count); + + if func_type.is_none() { + crate::bail_parse_error!("unknown function {}", name.as_str()); + } + + let func_ctx = FuncCtx { + func: func_type.unwrap(), + arg_count: args_count, + }; + + // Check if this function supports the (*) syntax by verifying it can be called with 0 args + match &func_ctx.func { + Func::Agg(_) => { + crate::bail_parse_error!( + "misuse of {} function {}(*)", + if filter_over.over_clause.is_some() { + "window" + } else { + "aggregate" + }, + name.as_str() + ) + } + // For supported functions, delegate to the existing FunctionCall logic + // by creating a synthetic FunctionCall with empty args + _ => { + let synthetic_call = ast::Expr::FunctionCall { + name: name.clone(), + distinctness: None, + args: vec![], // Empty args for func(*) + filter_over: filter_over.clone(), + order_by: vec![], // Empty order_by for func(*) + }; + + // Recursively call translate_expr with the synthetic function call + translate_expr( + program, + referenced_tables, + &synthetic_call, + target_register, + resolver, + ) + } + } } ast::Expr::Id(id) => { // Treat double-quoted identifiers as string literals (SQLite compatibility) @@ -1853,7 +1951,12 @@ pub fn translate_expr( let table = referenced_tables .unwrap() .find_table_by_internal_id(*table_ref_id) - .expect("table reference should be found"); + .unwrap_or_else(|| { + unreachable!( + "table reference should be found: {} (referenced_tables: {:?})", + table_ref_id, referenced_tables + ) + }); let Some(table_column) = table.get_column_at(*column) else { crate::bail_parse_error!("column index out of bounds"); @@ -1982,26 +2085,29 @@ pub fn translate_expr( // but wrap it with appropriate expression context handling let result_reg = target_register; - // Set result to NULL initially (matches SQLite behavior) - program.emit_insn(Insn::Null { - dest: result_reg, + let dest_if_false = program.allocate_label(); + let dest_if_null = program.allocate_label(); + let dest_if_true = program.allocate_label(); + + // Ideally we wouldn't need a tmp register, but currently if an IN expression + // is used inside an aggregator the target_register is cleared on every iteration, + // losing the state of the aggregator. + let tmp = program.alloc_register(); + program.emit_no_constant_insn(Insn::Null { + dest: tmp, dest_end: None, }); - let dest_if_false = program.allocate_label(); - let label_integer_conversion = program.allocate_label(); - - // Call the core InList logic with expression-appropriate condition metadata translate_in_list( program, referenced_tables, lhs, rhs, - *not, ConditionMetadata { jump_if_condition_is_true: false, - jump_target_when_true: label_integer_conversion, // will be resolved below + jump_target_when_true: dest_if_true, jump_target_when_false: dest_if_false, + jump_target_when_null: dest_if_null, }, resolver, )?; @@ -2009,27 +2115,30 @@ pub fn translate_expr( // condition true: set result to 1 program.emit_insn(Insn::Integer { value: 1, - dest: result_reg, - }); - program.emit_insn(Insn::Goto { - target_pc: label_integer_conversion, + dest: tmp, }); // False path: set result to 0 program.resolve_label(dest_if_false, program.offset()); - program.emit_insn(Insn::Integer { - value: 0, - dest: result_reg, - }); - - program.resolve_label(label_integer_conversion, program.offset()); // Force integer conversion with AddImm 0 program.emit_insn(Insn::AddImm { - register: result_reg, + register: tmp, value: 0, }); + if *not { + program.emit_insn(Insn::Not { + reg: tmp, + dest: tmp, + }); + } + program.resolve_label(dest_if_null, program.offset()); + program.emit_insn(Insn::Copy { + src_reg: tmp, + dst_reg: result_reg, + extra_amount: 0, + }); Ok(result_reg) } ast::Expr::InSelect { .. } => { @@ -3272,11 +3381,14 @@ impl ParamState { /// TryCanonicalColumnsFirst means that canonical columns take precedence over result columns. This is used for e.g. WHERE clauses. /// /// ResultColumnsNotAllowed means that referring to result columns is not allowed. This is used e.g. for DML statements. +/// +/// AllowUnboundIdentifiers means that unbound identifiers are allowed. This is used for INSERT ... ON CONFLICT DO UPDATE SET ... where binding is handled later than this phase. #[derive(Debug, Clone, PartialEq, Eq)] pub enum BindingBehavior { TryResultColumnsFirst, TryCanonicalColumnsFirst, ResultColumnsNotAllowed, + AllowUnboundIdentifiers, } /// Rewrite ast::Expr in place, binding Column references/rewriting Expr::Id -> Expr::Column @@ -3326,242 +3438,264 @@ pub fn bind_and_rewrite_expr<'a>( } _ => {} } - if let Some(referenced_tables) = &mut referenced_tables { - match expr { - Expr::Id(id) => { - let normalized_id = normalize_ident(id.as_str()); + match expr { + Expr::Id(id) => { + let Some(referenced_tables) = &mut referenced_tables else { + if binding_behavior == BindingBehavior::AllowUnboundIdentifiers { + return Ok(WalkControl::Continue); + } + crate::bail_parse_error!("no such column: {}", id.as_str()); + }; + let normalized_id = normalize_ident(id.as_str()); - if binding_behavior == BindingBehavior::TryResultColumnsFirst { - if let Some(result_columns) = result_columns { - for result_column in result_columns.iter() { - if result_column.name(referenced_tables).is_some_and(|name| { - name.eq_ignore_ascii_case(&normalized_id) - }) { + if binding_behavior == BindingBehavior::TryResultColumnsFirst { + if let Some(result_columns) = result_columns { + for result_column in result_columns.iter() { + if let Some(alias) = &result_column.alias { + if alias.eq_ignore_ascii_case(&normalized_id) { *expr = result_column.expr.clone(); return Ok(WalkControl::Continue); } } } } - let mut match_result = None; + } + let mut match_result = None; - // First check joined tables - for joined_table in referenced_tables.joined_tables().iter() { - let col_idx = joined_table.table.columns().iter().position(|c| { + // First check joined tables + for joined_table in referenced_tables.joined_tables().iter() { + let col_idx = joined_table.table.columns().iter().position(|c| { + c.name + .as_ref() + .is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id)) + }); + if col_idx.is_some() { + if match_result.is_some() { + let mut ok = false; + // Column name ambiguity is ok if it is in the USING clause because then it is deduplicated + // and the left table is used. + if let Some(join_info) = &joined_table.join_info { + if join_info.using.iter().any(|using_col| { + using_col.as_str().eq_ignore_ascii_case(&normalized_id) + }) { + ok = true; + } + } + if !ok { + crate::bail_parse_error!("Column {} is ambiguous", id.as_str()); + } + } else { + let col = + joined_table.table.columns().get(col_idx.unwrap()).unwrap(); + match_result = Some(( + joined_table.internal_id, + col_idx.unwrap(), + col.is_rowid_alias, + )); + } + // only if we haven't found a match, check for explicit rowid reference + } else { + let is_btree_table = matches!(joined_table.table, Table::BTree(_)); + if is_btree_table { + if let Some(row_id_expr) = parse_row_id( + &normalized_id, + referenced_tables.joined_tables()[0].internal_id, + || referenced_tables.joined_tables().len() != 1, + )? { + *expr = row_id_expr; + return Ok(WalkControl::Continue); + } + } + } + } + + // Then check outer query references, if we still didn't find something. + // Normally finding multiple matches for a non-qualified column is an error (column x is ambiguous) + // but in the case of subqueries, the inner query takes precedence. + // For example: + // SELECT * FROM t WHERE x = (SELECT x FROM t2) + // In this case, there is no ambiguity: + // - x in the outer query refers to t.x, + // - x in the inner query refers to t2.x. + if match_result.is_none() { + for outer_ref in referenced_tables.outer_query_refs().iter() { + let col_idx = outer_ref.table.columns().iter().position(|c| { c.name .as_ref() .is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id)) }); if col_idx.is_some() { if match_result.is_some() { - let mut ok = false; - // Column name ambiguity is ok if it is in the USING clause because then it is deduplicated - // and the left table is used. - if let Some(join_info) = &joined_table.join_info { - if join_info.using.iter().any(|using_col| { - using_col.as_str().eq_ignore_ascii_case(&normalized_id) - }) { - ok = true; - } - } - if !ok { - crate::bail_parse_error!( - "Column {} is ambiguous", - id.as_str() - ); - } - } else { - let col = - joined_table.table.columns().get(col_idx.unwrap()).unwrap(); - match_result = Some(( - joined_table.internal_id, - col_idx.unwrap(), - col.is_rowid_alias, - )); + crate::bail_parse_error!("Column {} is ambiguous", id.as_str()); } - // only if we haven't found a match, check for explicit rowid reference - } else if let Some(row_id_expr) = parse_row_id( - &normalized_id, - referenced_tables.joined_tables()[0].internal_id, - || referenced_tables.joined_tables().len() != 1, - )? { - *expr = row_id_expr; - - return Ok(WalkControl::Continue); + let col = outer_ref.table.columns().get(col_idx.unwrap()).unwrap(); + match_result = Some(( + outer_ref.internal_id, + col_idx.unwrap(), + col.is_rowid_alias, + )); } } + } - // Then check outer query references, if we still didn't find something. - // Normally finding multiple matches for a non-qualified column is an error (column x is ambiguous) - // but in the case of subqueries, the inner query takes precedence. - // For example: - // SELECT * FROM t WHERE x = (SELECT x FROM t2) - // In this case, there is no ambiguity: - // - x in the outer query refers to t.x, - // - x in the inner query refers to t2.x. - if match_result.is_none() { - for outer_ref in referenced_tables.outer_query_refs().iter() { - let col_idx = outer_ref.table.columns().iter().position(|c| { - c.name.as_ref().is_some_and(|name| { - name.eq_ignore_ascii_case(&normalized_id) - }) - }); - if col_idx.is_some() { - if match_result.is_some() { - crate::bail_parse_error!( - "Column {} is ambiguous", - id.as_str() - ); - } - let col = - outer_ref.table.columns().get(col_idx.unwrap()).unwrap(); - match_result = Some(( - outer_ref.internal_id, - col_idx.unwrap(), - col.is_rowid_alias, - )); - } - } - } + if let Some((table_id, col_idx, is_rowid_alias)) = match_result { + *expr = Expr::Column { + database: None, // TODO: support different databases + table: table_id, + column: col_idx, + is_rowid_alias, + }; + referenced_tables.mark_column_used(table_id, col_idx); + return Ok(WalkControl::Continue); + } - if let Some((table_id, col_idx, is_rowid_alias)) = match_result { - *expr = Expr::Column { - database: None, // TODO: support different databases - table: table_id, - column: col_idx, - is_rowid_alias, - }; - referenced_tables.mark_column_used(table_id, col_idx); - return Ok(WalkControl::Continue); - } - - if binding_behavior == BindingBehavior::TryCanonicalColumnsFirst { - if let Some(result_columns) = result_columns { - for result_column in result_columns.iter() { - if result_column.name(referenced_tables).is_some_and(|name| { - name.eq_ignore_ascii_case(&normalized_id) - }) { + if binding_behavior == BindingBehavior::TryCanonicalColumnsFirst { + if let Some(result_columns) = result_columns { + for result_column in result_columns.iter() { + if let Some(alias) = &result_column.alias { + if alias.eq_ignore_ascii_case(&normalized_id) { *expr = result_column.expr.clone(); return Ok(WalkControl::Continue); } } } } - - // SQLite behavior: Only double-quoted identifiers get fallback to string literals - // Single quotes are handled as literals earlier, unquoted identifiers must resolve to columns - if id.quoted_with('"') { - // Convert failed double-quoted identifier to string literal - *expr = Expr::Literal(ast::Literal::String(id.as_literal())); - return Ok(WalkControl::Continue); - } else { - // Unquoted identifiers must resolve to columns - no fallback - crate::bail_parse_error!("no such column: {}", id.as_str()) - } } - Expr::Qualified(tbl, id) => { - tracing::debug!("bind_and_rewrite_expr({:?}, {:?})", tbl, id); - let normalized_table_name = normalize_ident(tbl.as_str()); - let matching_tbl = referenced_tables - .find_table_and_internal_id_by_identifier(&normalized_table_name); - if matching_tbl.is_none() { - crate::bail_parse_error!("no such table: {}", normalized_table_name); - } - let (tbl_id, tbl) = matching_tbl.unwrap(); - let normalized_id = normalize_ident(id.as_str()); - let col_idx = tbl.columns().iter().position(|c| { - c.name - .as_ref() - .is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id)) - }); - if let Some(row_id_expr) = parse_row_id(&normalized_id, tbl_id, || false)? { - *expr = row_id_expr; + // SQLite behavior: Only double-quoted identifiers get fallback to string literals + // Single quotes are handled as literals earlier, unquoted identifiers must resolve to columns + if id.quoted_with('"') { + // Convert failed double-quoted identifier to string literal + *expr = Expr::Literal(ast::Literal::String(id.as_literal())); + return Ok(WalkControl::Continue); + } else { + // Unquoted identifiers must resolve to columns - no fallback + crate::bail_parse_error!("no such column: {}", id.as_str()) + } + } + Expr::Qualified(tbl, id) => { + tracing::debug!("bind_and_rewrite_expr({:?}, {:?})", tbl, id); + let Some(referenced_tables) = &mut referenced_tables else { + if binding_behavior == BindingBehavior::AllowUnboundIdentifiers { return Ok(WalkControl::Continue); } - let Some(col_idx) = col_idx else { - crate::bail_parse_error!("no such column: {}", normalized_id); - }; - let col = tbl.columns().get(col_idx).unwrap(); - *expr = Expr::Column { - database: None, // TODO: support different databases - table: tbl_id, - column: col_idx, - is_rowid_alias: col.is_rowid_alias, - }; - tracing::debug!("rewritten to column"); - referenced_tables.mark_column_used(tbl_id, col_idx); + crate::bail_parse_error!( + "no such column: {}.{}", + tbl.as_str(), + id.as_str() + ); + }; + let normalized_table_name = normalize_ident(tbl.as_str()); + let matching_tbl = referenced_tables + .find_table_and_internal_id_by_identifier(&normalized_table_name); + if matching_tbl.is_none() { + crate::bail_parse_error!("no such table: {}", normalized_table_name); + } + let (tbl_id, tbl) = matching_tbl.unwrap(); + let normalized_id = normalize_ident(id.as_str()); + let col_idx = tbl.columns().iter().position(|c| { + c.name + .as_ref() + .is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id)) + }); + if let Some(row_id_expr) = parse_row_id(&normalized_id, tbl_id, || false)? { + *expr = row_id_expr; + return Ok(WalkControl::Continue); } - Expr::DoublyQualified(db_name, tbl_name, col_name) => { - let normalized_col_name = normalize_ident(col_name.as_str()); + let Some(col_idx) = col_idx else { + crate::bail_parse_error!("no such column: {}", normalized_id); + }; + let col = tbl.columns().get(col_idx).unwrap(); + *expr = Expr::Column { + database: None, // TODO: support different databases + table: tbl_id, + column: col_idx, + is_rowid_alias: col.is_rowid_alias, + }; + tracing::debug!("rewritten to column"); + referenced_tables.mark_column_used(tbl_id, col_idx); + return Ok(WalkControl::Continue); + } + Expr::DoublyQualified(db_name, tbl_name, col_name) => { + let Some(referenced_tables) = &mut referenced_tables else { + if binding_behavior == BindingBehavior::AllowUnboundIdentifiers { + return Ok(WalkControl::Continue); + } + crate::bail_parse_error!( + "no such column: {}.{}.{}", + db_name.as_str(), + tbl_name.as_str(), + col_name.as_str() + ); + }; + let normalized_col_name = normalize_ident(col_name.as_str()); - // Create a QualifiedName and use existing resolve_database_id method - let qualified_name = ast::QualifiedName { - db_name: Some(db_name.clone()), - name: tbl_name.clone(), - alias: None, + // Create a QualifiedName and use existing resolve_database_id method + let qualified_name = ast::QualifiedName { + db_name: Some(db_name.clone()), + name: tbl_name.clone(), + alias: None, + }; + let database_id = connection.resolve_database_id(&qualified_name)?; + + // Get the table from the specified database + let table = connection + .with_schema(database_id, |schema| schema.get_table(tbl_name.as_str())) + .ok_or_else(|| { + crate::LimboError::ParseError(format!( + "no such table: {}.{}", + db_name.as_str(), + tbl_name.as_str() + )) + })?; + + // Find the column in the table + let col_idx = table + .columns() + .iter() + .position(|c| { + c.name + .as_ref() + .is_some_and(|name| name.eq_ignore_ascii_case(&normalized_col_name)) + }) + .ok_or_else(|| { + crate::LimboError::ParseError(format!( + "Column: {}.{}.{} not found", + db_name.as_str(), + tbl_name.as_str(), + col_name.as_str() + )) + })?; + + let col = table.columns().get(col_idx).unwrap(); + + // Check if this is a rowid alias + let is_rowid_alias = col.is_rowid_alias; + + // Convert to Column expression - since this is a cross-database reference, + // we need to create a synthetic table reference for it + // For now, we'll error if the table isn't already in the referenced tables + let normalized_tbl_name = normalize_ident(tbl_name.as_str()); + let matching_tbl = referenced_tables + .find_table_and_internal_id_by_identifier(&normalized_tbl_name); + + if let Some((tbl_id, _)) = matching_tbl { + // Table is already in referenced tables, use existing internal ID + *expr = Expr::Column { + database: Some(database_id), + table: tbl_id, + column: col_idx, + is_rowid_alias, }; - let database_id = connection.resolve_database_id(&qualified_name)?; - - // Get the table from the specified database - let table = connection - .with_schema(database_id, |schema| schema.get_table(tbl_name.as_str())) - .ok_or_else(|| { - crate::LimboError::ParseError(format!( - "no such table: {}.{}", - db_name.as_str(), - tbl_name.as_str() - )) - })?; - - // Find the column in the table - let col_idx = table - .columns() - .iter() - .position(|c| { - c.name.as_ref().is_some_and(|name| { - name.eq_ignore_ascii_case(&normalized_col_name) - }) - }) - .ok_or_else(|| { - crate::LimboError::ParseError(format!( - "Column: {}.{}.{} not found", - db_name.as_str(), - tbl_name.as_str(), - col_name.as_str() - )) - })?; - - let col = table.columns().get(col_idx).unwrap(); - - // Check if this is a rowid alias - let is_rowid_alias = col.is_rowid_alias; - - // Convert to Column expression - since this is a cross-database reference, - // we need to create a synthetic table reference for it - // For now, we'll error if the table isn't already in the referenced tables - let normalized_tbl_name = normalize_ident(tbl_name.as_str()); - let matching_tbl = referenced_tables - .find_table_and_internal_id_by_identifier(&normalized_tbl_name); - - if let Some((tbl_id, _)) = matching_tbl { - // Table is already in referenced tables, use existing internal ID - *expr = Expr::Column { - database: Some(database_id), - table: tbl_id, - column: col_idx, - is_rowid_alias, - }; - referenced_tables.mark_column_used(tbl_id, col_idx); - } else { - return Err(crate::LimboError::ParseError(format!( + referenced_tables.mark_column_used(tbl_id, col_idx); + } else { + return Err(crate::LimboError::ParseError(format!( "table {normalized_tbl_name} is not in FROM clause - cross-database column references require the table to be explicitly joined" ))); - } } - _ => {} } + _ => {} } Ok(WalkControl::Continue) }, diff --git a/core/translate/fkeys.rs b/core/translate/fkeys.rs new file mode 100644 index 000000000..32a7099ad --- /dev/null +++ b/core/translate/fkeys.rs @@ -0,0 +1,934 @@ +use turso_parser::ast::Expr; + +use super::ProgramBuilder; +use crate::{ + schema::{BTreeTable, ForeignKey, Index, ResolvedFkRef, ROWID_SENTINEL}, + translate::{emitter::Resolver, planner::ROWID_STRS}, + vdbe::{ + builder::CursorType, + insn::{CmpInsFlags, Insn}, + BranchOffset, + }, + Result, +}; +use std::{collections::HashSet, num::NonZeroUsize, sync::Arc}; + +#[inline] +pub fn emit_guarded_fk_decrement(program: &mut ProgramBuilder, label: BranchOffset) { + program.emit_insn(Insn::FkIfZero { + deferred: true, + target_pc: label, + }); + program.emit_insn(Insn::FkCounter { + increment_value: -1, + deferred: true, + }); +} + +/// Open a read cursor on an index and return its cursor id. +#[inline] +pub fn open_read_index(program: &mut ProgramBuilder, idx: &Arc) -> usize { + let icur = program.alloc_cursor_id(CursorType::BTreeIndex(idx.clone())); + program.emit_insn(Insn::OpenRead { + cursor_id: icur, + root_page: idx.root_page, + db: 0, + }); + icur +} + +/// Open a read cursor on a table and return its cursor id. +#[inline] +pub fn open_read_table(program: &mut ProgramBuilder, tbl: &Arc) -> usize { + let tcur = program.alloc_cursor_id(CursorType::BTreeTable(tbl.clone())); + program.emit_insn(Insn::OpenRead { + cursor_id: tcur, + root_page: tbl.root_page, + db: 0, + }); + tcur +} + +/// Copy `len` registers starting at `src_start` to a fresh block and apply index affinities. +/// Returns the destination start register. +#[inline] +fn copy_with_affinity( + program: &mut ProgramBuilder, + src_start: usize, + len: usize, + idx: &Index, + aff_from_tbl: &BTreeTable, +) -> usize { + let dst = program.alloc_registers(len); + for i in 0..len { + program.emit_insn(Insn::Copy { + src_reg: src_start + i, + dst_reg: dst + i, + extra_amount: 0, + }); + } + if let Some(count) = NonZeroUsize::new(len) { + program.emit_insn(Insn::Affinity { + start_reg: dst, + count, + affinities: build_index_affinity_string(idx, aff_from_tbl), + }); + } + dst +} + +/// Issue an index probe using `Found`/`NotFound` and route to `on_found`/`on_not_found`. +pub fn index_probe( + program: &mut ProgramBuilder, + icur: usize, + record_reg: usize, + num_regs: usize, + mut on_found: F, + mut on_not_found: G, +) -> Result<()> +where + F: FnMut(&mut ProgramBuilder) -> Result<()>, + G: FnMut(&mut ProgramBuilder) -> Result<()>, +{ + let lbl_found = program.allocate_label(); + let lbl_join = program.allocate_label(); + + program.emit_insn(Insn::Found { + cursor_id: icur, + target_pc: lbl_found, + record_reg, + num_regs, + }); + + // NOT FOUND path + on_not_found(program)?; + program.emit_insn(Insn::Goto { + target_pc: lbl_join, + }); + + // FOUND path + program.preassign_label_to_next_insn(lbl_found); + on_found(program)?; + + // Join & close once + program.preassign_label_to_next_insn(lbl_join); + program.emit_insn(Insn::Close { cursor_id: icur }); + Ok(()) +} + +/// Iterate a table and call `on_match` when all child columns equal the key at `parent_key_start`. +/// Skips rows where any FK column is NULL. If `self_exclude_rowid` is Some, the row with that rowid is skipped. +fn table_scan_match_any( + program: &mut ProgramBuilder, + child_tbl: &Arc, + child_cols: &[String], + parent_key_start: usize, + self_exclude_rowid: Option, + mut on_match: F, +) -> Result<()> +where + F: FnMut(&mut ProgramBuilder) -> Result<()>, +{ + let ccur = open_read_table(program, child_tbl); + let done = program.allocate_label(); + program.emit_insn(Insn::Rewind { + cursor_id: ccur, + pc_if_empty: done, + }); + + let loop_top = program.allocate_label(); + program.preassign_label_to_next_insn(loop_top); + let next_row = program.allocate_label(); + + // Compare each FK column to parent key component. + for (i, cname) in child_cols.iter().enumerate() { + let (pos, _) = child_tbl.get_column(cname).ok_or_else(|| { + crate::LimboError::InternalError(format!("child col {cname} missing")) + })?; + let tmp = program.alloc_register(); + program.emit_insn(Insn::Column { + cursor_id: ccur, + column: pos, + dest: tmp, + default: None, + }); + program.emit_insn(Insn::IsNull { + reg: tmp, + target_pc: next_row, + }); + + let cont = program.allocate_label(); + program.emit_insn(Insn::Eq { + lhs: tmp, + rhs: parent_key_start + i, + target_pc: cont, + flags: CmpInsFlags::default().jump_if_null(), + collation: Some(super::collate::CollationSeq::Binary), + }); + program.emit_insn(Insn::Goto { + target_pc: next_row, + }); + program.preassign_label_to_next_insn(cont); + } + + //self-reference exclusion on rowid + if let Some(parent_rowid) = self_exclude_rowid { + let child_rowid = program.alloc_register(); + let skip = program.allocate_label(); + program.emit_insn(Insn::RowId { + cursor_id: ccur, + dest: child_rowid, + }); + program.emit_insn(Insn::Eq { + lhs: child_rowid, + rhs: parent_rowid, + target_pc: skip, + flags: CmpInsFlags::default(), + collation: None, + }); + on_match(program)?; + program.preassign_label_to_next_insn(skip); + } else { + on_match(program)?; + } + + program.preassign_label_to_next_insn(next_row); + program.emit_insn(Insn::Next { + cursor_id: ccur, + pc_if_next: loop_top, + }); + + program.preassign_label_to_next_insn(done); + program.emit_insn(Insn::Close { cursor_id: ccur }); + Ok(()) +} + +/// Build the index affinity mask string (one char per indexed column). +#[inline] +pub fn build_index_affinity_string(idx: &Index, table: &BTreeTable) -> String { + idx.columns + .iter() + .map(|ic| table.columns[ic.pos_in_table].affinity().aff_mask()) + .collect() +} + +/// Increment a foreign key violation counter; for deferred FKs, this is a global counter +/// on the connection; for immediate FKs, this is a per-statement counter in the program state. +pub fn emit_fk_violation(program: &mut ProgramBuilder, fk: &ForeignKey) -> Result<()> { + program.emit_insn(Insn::FkCounter { + increment_value: 1, + deferred: fk.deferred, + }); + Ok(()) +} + +/// Stabilize the NEW row image for FK checks (UPDATE): +/// fill in unmodified PK columns from the current row so the NEW PK vector is complete. +pub fn stabilize_new_row_for_fk( + program: &mut ProgramBuilder, + table_btree: &BTreeTable, + set_clauses: &[(usize, Box)], + cursor_id: usize, + start: usize, + rowid_new_reg: usize, +) -> Result<()> { + if table_btree.primary_key_columns.is_empty() { + return Ok(()); + } + let set_cols: HashSet = set_clauses + .iter() + .filter_map(|(i, _)| if *i == ROWID_SENTINEL { None } else { Some(*i) }) + .collect(); + + for (pk_name, _) in &table_btree.primary_key_columns { + let (pos, col) = table_btree + .get_column(pk_name) + .ok_or_else(|| crate::LimboError::InternalError(format!("pk col {pk_name} missing")))?; + if !set_cols.contains(&pos) { + if col.is_rowid_alias { + program.emit_insn(Insn::Copy { + src_reg: rowid_new_reg, + dst_reg: start + pos, + extra_amount: 0, + }); + } else { + program.emit_insn(Insn::Column { + cursor_id, + column: pos, + dest: start + pos, + default: None, + }); + } + } + } + Ok(()) +} + +/// Parent-side checks when the parent key might change (UPDATE on parent): +/// Detect if any child references the OLD key (potential violation), and if any references the NEW key +/// (which cancels one potential violation). For composite keys this builds OLD/NEW vectors first. +#[allow(clippy::too_many_arguments)] +pub fn emit_parent_key_change_checks( + program: &mut ProgramBuilder, + resolver: &Resolver, + table_btree: &BTreeTable, + indexes_to_update: impl Iterator>, + cursor_id: usize, + old_rowid_reg: usize, + start: usize, + rowid_new_reg: usize, + rowid_set_clause_reg: Option, + set_clauses: &[(usize, Box)], +) -> Result<()> { + let updated_positions: HashSet = set_clauses.iter().map(|(i, _)| *i).collect(); + let incoming = resolver + .schema + .resolved_fks_referencing(&table_btree.name)?; + let affects_pk = incoming + .iter() + .any(|r| r.parent_key_may_change(&updated_positions, table_btree)); + if !affects_pk { + return Ok(()); + } + + let primary_key_is_rowid_alias = table_btree.get_rowid_alias_column().is_some(); + + if primary_key_is_rowid_alias || table_btree.primary_key_columns.is_empty() { + emit_rowid_pk_change_check( + program, + &incoming, + resolver, + old_rowid_reg, + rowid_set_clause_reg.unwrap_or(old_rowid_reg), + )?; + } + + for index in indexes_to_update { + emit_parent_index_key_change_checks( + program, + cursor_id, + start, + old_rowid_reg, + rowid_new_reg, + &incoming, + resolver, + table_btree, + index.as_ref(), + )?; + } + Ok(()) +} + +/// Rowid-table parent PK change: compare rowid OLD vs NEW; if changed, run two-pass counters. +pub fn emit_rowid_pk_change_check( + program: &mut ProgramBuilder, + incoming: &[ResolvedFkRef], + resolver: &Resolver, + old_rowid_reg: usize, + new_rowid_reg: usize, +) -> Result<()> { + let skip = program.allocate_label(); + program.emit_insn(Insn::Eq { + lhs: new_rowid_reg, + rhs: old_rowid_reg, + target_pc: skip, + flags: CmpInsFlags::default(), + collation: None, + }); + + let old_pk = program.alloc_register(); + let new_pk = program.alloc_register(); + program.emit_insn(Insn::Copy { + src_reg: old_rowid_reg, + dst_reg: old_pk, + extra_amount: 0, + }); + program.emit_insn(Insn::Copy { + src_reg: new_rowid_reg, + dst_reg: new_pk, + extra_amount: 0, + }); + + emit_fk_parent_pk_change_counters(program, incoming, resolver, old_pk, new_pk, 1)?; + program.preassign_label_to_next_insn(skip); + Ok(()) +} + +/// Foreign keys are only legal if the referenced parent key is: +/// 1. The rowid alias (no separate index) +/// 2. Part of a primary key / unique index (there is no practical difference between the two) +/// +/// If the foreign key references a composite key, all of the columns in the key must be referenced. +/// E.g. +/// CREATE TABLE parent (a, b, c, PRIMARY KEY (a, b, c)); +/// CREATE TABLE child (a, b, c, FOREIGN KEY (a, b, c) REFERENCES parent (a, b, c)); +/// +/// Whereas this is not allowed: +/// CREATE TABLE parent (a, b, c, PRIMARY KEY (a, b, c)); +/// CREATE TABLE child (a, b, c, FOREIGN KEY (a, b) REFERENCES parent (a, b, c)); +/// +/// This function checks if the parent key has changed by comparing the OLD and NEW values. +/// If the parent key has changed, it emits the counters for the foreign keys. +/// If the parent key has not changed, it does nothing. +#[allow(clippy::too_many_arguments)] +pub fn emit_parent_index_key_change_checks( + program: &mut ProgramBuilder, + cursor_id: usize, + new_values_start: usize, + old_rowid_reg: usize, + new_rowid_reg: usize, + incoming: &[ResolvedFkRef], + resolver: &Resolver, + table_btree: &BTreeTable, + index: &Index, +) -> Result<()> { + let idx_len = index.columns.len(); + + let old_key = program.alloc_registers(idx_len); + for (i, index_col) in index.columns.iter().enumerate() { + let pos_in_table = index_col.pos_in_table; + let column = &table_btree.columns[pos_in_table]; + if column.is_rowid_alias { + program.emit_insn(Insn::Copy { + src_reg: old_rowid_reg, + dst_reg: old_key + i, + extra_amount: 0, + }); + } else { + program.emit_insn(Insn::Column { + cursor_id, + column: pos_in_table, + dest: old_key + i, + default: None, + }); + } + } + let new_key = program.alloc_registers(idx_len); + for (i, index_col) in index.columns.iter().enumerate() { + let pos_in_table = index_col.pos_in_table; + let column = &table_btree.columns[pos_in_table]; + let src = if column.is_rowid_alias { + new_rowid_reg + } else { + new_values_start + pos_in_table + }; + program.emit_insn(Insn::Copy { + src_reg: src, + dst_reg: new_key + i, + extra_amount: 0, + }); + } + + let skip = program.allocate_label(); + let changed = program.allocate_label(); + for i in 0..idx_len { + let next = if i + 1 == idx_len { + None + } else { + Some(program.allocate_label()) + }; + program.emit_insn(Insn::Eq { + lhs: old_key + i, + rhs: new_key + i, + target_pc: next.unwrap_or(skip), + flags: CmpInsFlags::default(), + collation: None, + }); + program.emit_insn(Insn::Goto { target_pc: changed }); + if let Some(n) = next { + program.preassign_label_to_next_insn(n); + } + } + + program.preassign_label_to_next_insn(changed); + emit_fk_parent_pk_change_counters(program, incoming, resolver, old_key, new_key, idx_len)?; + program.preassign_label_to_next_insn(skip); + Ok(()) +} + +/// Two-pass parent-side maintenance for UPDATE of a parent key: +/// 1. Probe child for OLD key, increment deferred counter if any references exist. +/// 2. Probe child for NEW key, guarded decrement cancels exactly one increment if present +pub fn emit_fk_parent_pk_change_counters( + program: &mut ProgramBuilder, + incoming: &[ResolvedFkRef], + resolver: &Resolver, + old_pk_start: usize, + new_pk_start: usize, + n_cols: usize, +) -> Result<()> { + for fk_ref in incoming { + emit_fk_parent_key_probe( + program, + resolver, + fk_ref, + old_pk_start, + n_cols, + ParentProbePass::Old, + )?; + emit_fk_parent_key_probe( + program, + resolver, + fk_ref, + new_pk_start, + n_cols, + ParentProbePass::New, + )?; + } + Ok(()) +} + +#[derive(Clone, Copy)] +enum ParentProbePass { + Old, + New, +} + +/// Probe the child side for a given parent key +fn emit_fk_parent_key_probe( + program: &mut ProgramBuilder, + resolver: &Resolver, + fk_ref: &ResolvedFkRef, + parent_key_start: usize, + n_cols: usize, + pass: ParentProbePass, +) -> Result<()> { + let child_tbl = &fk_ref.child_table; + let child_cols = &fk_ref.fk.child_columns; + let is_deferred = fk_ref.fk.deferred; + + let on_match = |p: &mut ProgramBuilder| -> Result<()> { + match (is_deferred, pass) { + // OLD key referenced by a child + (_, ParentProbePass::Old) => { + emit_fk_violation(p, &fk_ref.fk)?; + } + + // NEW key referenced by a child (cancel one deferred violation) + (true, ParentProbePass::New) => { + // Guard to avoid underflow if OLD pass didn't increment. + let skip = p.allocate_label(); + emit_guarded_fk_decrement(p, skip); + p.preassign_label_to_next_insn(skip); + } + // Immediate FK on NEW pass: nothing to cancel; do nothing. + (false, ParentProbePass::New) => {} + } + Ok(()) + }; + + // Prefer exact child index on (child_cols...) + let idx = resolver.schema.get_indices(&child_tbl.name).find(|ix| { + ix.columns.len() == child_cols.len() + && ix + .columns + .iter() + .zip(child_cols.iter()) + .all(|(ic, cc)| ic.name.eq_ignore_ascii_case(cc)) + }); + + if let Some(ix) = idx { + let icur = open_read_index(program, ix); + let probe = copy_with_affinity(program, parent_key_start, n_cols, ix, child_tbl); + + // FOUND => on_match; NOT FOUND => no-op + index_probe(program, icur, probe, n_cols, on_match, |_p| Ok(()))?; + } else { + // Table scan fallback + table_scan_match_any( + program, + child_tbl, + child_cols, + parent_key_start, + None, + on_match, + )?; + } + + Ok(()) +} + +/// Build a parent key vector (in FK parent-column order) into `dest_start`. +/// Handles rowid aliasing and explicit ROWID names; uses current row for non-rowid columns. +fn build_parent_key( + program: &mut ProgramBuilder, + parent_bt: &BTreeTable, + parent_cols: &[String], + parent_cursor_id: usize, + parent_rowid_reg: usize, + dest_start: usize, +) -> Result<()> { + for (i, pcol) in parent_cols.iter().enumerate() { + let src = if ROWID_STRS.iter().any(|s| pcol.eq_ignore_ascii_case(s)) { + parent_rowid_reg + } else { + let (pos, col) = parent_bt + .get_column(pcol) + .ok_or_else(|| crate::LimboError::InternalError(format!("col {pcol} missing")))?; + if col.is_rowid_alias { + parent_rowid_reg + } else { + program.emit_insn(Insn::Column { + cursor_id: parent_cursor_id, + column: pos, + dest: dest_start + i, + default: None, + }); + continue; + } + }; + program.emit_insn(Insn::Copy { + src_reg: src, + dst_reg: dest_start + i, + extra_amount: 0, + }); + } + Ok(()) +} + +/// Child-side FK maintenance for UPDATE/UPSERT: +/// If any FK columns of this child row changed: +/// Pass 1 (OLD tuple): if OLD is non-NULL and parent is missing: decrement deferred counter (guarded). +/// Pass 2 (NEW tuple): if NEW is non-NULL and parent is missing: immediate error or deferred(+1). +#[allow(clippy::too_many_arguments)] +pub fn emit_fk_child_update_counters( + program: &mut ProgramBuilder, + resolver: &Resolver, + child_tbl: &BTreeTable, + child_table_name: &str, + child_cursor_id: usize, + new_start_reg: usize, + new_rowid_reg: usize, + updated_cols: &HashSet, +) -> crate::Result<()> { + // Helper: materialize OLD tuple for this FK; returns (start_reg, ncols) or None if any component is NULL. + let load_old_tuple = + |program: &mut ProgramBuilder, fk_cols: &[String]| -> Option<(usize, usize)> { + let n = fk_cols.len(); + let start = program.alloc_registers(n); + let null_jmp = program.allocate_label(); + + for (k, cname) in fk_cols.iter().enumerate() { + let (pos, _col) = match child_tbl.get_column(cname) { + Some(v) => v, + None => { + return None; + } + }; + program.emit_column_or_rowid(child_cursor_id, pos, start + k); + program.emit_insn(Insn::IsNull { + reg: start + k, + target_pc: null_jmp, + }); + } + + // No NULLs, proceed + let cont = program.allocate_label(); + program.emit_insn(Insn::Goto { target_pc: cont }); + // NULL encountered: invalidate tuple by jumping here + program.preassign_label_to_next_insn(null_jmp); + + program.preassign_label_to_next_insn(cont); + Some((start, n)) + }; + + for fk_ref in resolver.schema.resolved_fks_for_child(child_table_name)? { + // If the child-side FK columns did not change, there is nothing to do. + if !fk_ref.child_key_changed(updated_cols, child_tbl) { + continue; + } + + let ncols = fk_ref.child_cols.len(); + + // Pass 1: OLD tuple handling only for deferred FKs + if fk_ref.fk.deferred { + if let Some((old_start, _)) = load_old_tuple(program, &fk_ref.child_cols) { + if fk_ref.parent_uses_rowid { + // Parent key is rowid: probe parent table by rowid + let parent_tbl = resolver + .schema + .get_btree_table(&fk_ref.fk.parent_table) + .expect("parent btree"); + let pcur = open_read_table(program, &parent_tbl); + + // first FK col is the rowid value + let rid = program.alloc_register(); + program.emit_insn(Insn::Copy { + src_reg: old_start, + dst_reg: rid, + extra_amount: 0, + }); + program.emit_insn(Insn::MustBeInt { reg: rid }); + + // If NOT exists => decrement + let miss = program.allocate_label(); + program.emit_insn(Insn::NotExists { + cursor: pcur, + rowid_reg: rid, + target_pc: miss, + }); + // found: close & continue + let join = program.allocate_label(); + program.emit_insn(Insn::Close { cursor_id: pcur }); + program.emit_insn(Insn::Goto { target_pc: join }); + + // missing: guarded decrement + program.preassign_label_to_next_insn(miss); + program.emit_insn(Insn::Close { cursor_id: pcur }); + let skip = program.allocate_label(); + emit_guarded_fk_decrement(program, skip); + program.preassign_label_to_next_insn(skip); + + program.preassign_label_to_next_insn(join); + } else { + // Parent key is a unique index: use index probe and guarded decrement on NOT FOUND + let parent_tbl = resolver + .schema + .get_btree_table(&fk_ref.fk.parent_table) + .expect("parent btree"); + let idx = fk_ref + .parent_unique_index + .as_ref() + .expect("parent unique index required"); + let icur = open_read_index(program, idx); + + // Copy OLD tuple and apply parent index affinities + let probe = copy_with_affinity(program, old_start, ncols, idx, &parent_tbl); + // Found: nothing; Not found: guarded decrement + index_probe( + program, + icur, + probe, + ncols, + |_p| Ok(()), + |p| { + let skip = p.allocate_label(); + emit_guarded_fk_decrement(p, skip); + p.preassign_label_to_next_insn(skip); + Ok(()) + }, + )?; + } + } + } + + // Pass 2: NEW tuple handling + let fk_ok = program.allocate_label(); + for cname in &fk_ref.fk.child_columns { + let (i, col) = child_tbl.get_column(cname).unwrap(); + let src = if col.is_rowid_alias { + new_rowid_reg + } else { + new_start_reg + i + }; + program.emit_insn(Insn::IsNull { + reg: src, + target_pc: fk_ok, + }); + } + + if fk_ref.parent_uses_rowid { + let parent_tbl = resolver + .schema + .get_btree_table(&fk_ref.fk.parent_table) + .expect("parent btree"); + let pcur = open_read_table(program, &parent_tbl); + + // Take the first child column value from NEW image + let (i_child, col_child) = child_tbl.get_column(&fk_ref.child_cols[0]).unwrap(); + let val_reg = if col_child.is_rowid_alias { + new_rowid_reg + } else { + new_start_reg + i_child + }; + + let tmp = program.alloc_register(); + program.emit_insn(Insn::Copy { + src_reg: val_reg, + dst_reg: tmp, + extra_amount: 0, + }); + program.emit_insn(Insn::MustBeInt { reg: tmp }); + + let violation = program.allocate_label(); + program.emit_insn(Insn::NotExists { + cursor: pcur, + rowid_reg: tmp, + target_pc: violation, + }); + // found: close and continue + program.emit_insn(Insn::Close { cursor_id: pcur }); + program.emit_insn(Insn::Goto { target_pc: fk_ok }); + + // missing: violation (immediate HALT or deferred +1) + program.preassign_label_to_next_insn(violation); + program.emit_insn(Insn::Close { cursor_id: pcur }); + emit_fk_violation(program, &fk_ref.fk)?; + } else { + let parent_tbl = resolver + .schema + .get_btree_table(&fk_ref.fk.parent_table) + .expect("parent btree"); + let idx = fk_ref + .parent_unique_index + .as_ref() + .expect("parent unique index required"); + let icur = open_read_index(program, idx); + + // Build NEW probe (in FK child column order, aligns with parent index columns) + let probe = { + let start = program.alloc_registers(ncols); + for (k, cname) in fk_ref.child_cols.iter().enumerate() { + let (i, col) = child_tbl.get_column(cname).unwrap(); + program.emit_insn(Insn::Copy { + src_reg: if col.is_rowid_alias { + new_rowid_reg + } else { + new_start_reg + i + }, + dst_reg: start + k, + extra_amount: 0, + }); + } + // Apply affinities of the parent index/table + if let Some(cnt) = NonZeroUsize::new(ncols) { + program.emit_insn(Insn::Affinity { + start_reg: start, + count: cnt, + affinities: build_index_affinity_string(idx, &parent_tbl), + }); + } + start + }; + + // FOUND: ok; NOT FOUND: violation path + index_probe( + program, + icur, + probe, + ncols, + |_p| Ok(()), + |p| { + emit_fk_violation(p, &fk_ref.fk)?; + Ok(()) + }, + )?; + program.emit_insn(Insn::Goto { target_pc: fk_ok }); + } + + // Skip label for NEW tuple NULL short-circuit + program.preassign_label_to_next_insn(fk_ok); + } + + Ok(()) +} + +/// Prevent deleting a parent row that is still referenced by any child. +/// For each incoming FK referencing `parent_table_name`: +/// 1. Build the parent key vector from the current parent row (FK parent-column order, +/// or the table's PK columns when the FK omits parent columns). +/// 2. Look for referencing child rows: +/// - Prefer an exact child index on (child_columns...). If found, probe the index. +/// - Otherwise scan the child table. For self-referential FKs, exclude the current rowid. +/// 3. If a referencing child exists: +/// - Immediate FK: HALT with SQLITE_CONSTRAINT_FOREIGNKEY +/// - Deferred FK: FkCounter +1 +pub fn emit_fk_delete_parent_existence_checks( + program: &mut ProgramBuilder, + resolver: &Resolver, + parent_table_name: &str, + parent_cursor_id: usize, + parent_rowid_reg: usize, +) -> Result<()> { + let parent_bt = resolver + .schema + .get_btree_table(parent_table_name) + .ok_or_else(|| crate::LimboError::InternalError("parent not btree".into()))?; + + for fk_ref in resolver + .schema + .resolved_fks_referencing(parent_table_name)? + { + let is_self_ref = fk_ref + .child_table + .name + .eq_ignore_ascii_case(parent_table_name); + + // Build parent key in FK's parent-column order (or table PK columns if unspecified). + let parent_cols: Vec = if fk_ref.fk.parent_columns.is_empty() { + parent_bt + .primary_key_columns + .iter() + .map(|(n, _)| n.clone()) + .collect() + } else { + fk_ref.fk.parent_columns.clone() + }; + let ncols = parent_cols.len(); + + let parent_key_start = program.alloc_registers(ncols); + build_parent_key( + program, + &parent_bt, + &parent_cols, + parent_cursor_id, + parent_rowid_reg, + parent_key_start, + )?; + + // Try an exact child index on (child_columns...) if available and not self-ref + let child_cols = &fk_ref.fk.child_columns; + let child_idx = if !is_self_ref { + resolver + .schema + .get_indices(&fk_ref.child_table.name) + .find(|idx| { + idx.columns.len() == child_cols.len() + && idx + .columns + .iter() + .zip(child_cols.iter()) + .all(|(ic, cc)| ic.name.eq_ignore_ascii_case(cc)) + }) + } else { + None + }; + + if let Some(idx) = child_idx { + // Index probe: FOUND => violation; NOT FOUND => ok. + let icur = open_read_index(program, idx); + let probe = + copy_with_affinity(program, parent_key_start, ncols, idx, &fk_ref.child_table); + + index_probe( + program, + icur, + probe, + ncols, + |p| { + emit_fk_violation(p, &fk_ref.fk)?; + Ok(()) + }, + |_p| Ok(()), + )?; + } else { + // Table scan fallback; for self-ref, exclude the same parent row by rowid. + table_scan_match_any( + program, + &fk_ref.child_table, + child_cols, + parent_key_start, + if is_self_ref { + Some(parent_rowid_reg) + } else { + None + }, + |p| { + emit_fk_violation(p, &fk_ref.fk)?; + Ok(()) + }, + )?; + } + } + Ok(()) +} diff --git a/core/translate/group_by.rs b/core/translate/group_by.rs index 94d8dee03..19f91c5bc 100644 --- a/core/translate/group_by.rs +++ b/core/translate/group_by.rs @@ -786,6 +786,8 @@ pub fn group_by_emit_row_phase<'a>( jump_if_condition_is_true: false, jump_target_when_false: labels.label_group_by_end_without_emitting_row, jump_target_when_true: if_true_target, + // treat null result has false for now + jump_target_when_null: labels.label_group_by_end_without_emitting_row, }, &t_ctx.resolver, )?; diff --git a/core/translate/index.rs b/core/translate/index.rs index 8aced50d9..b680f560c 100644 --- a/core/translate/index.rs +++ b/core/translate/index.rs @@ -245,6 +245,7 @@ pub fn translate_create_index( jump_if_condition_is_true: false, jump_target_when_false: label, jump_target_when_true: BranchOffset::Placeholder, + jump_target_when_null: label, }, resolver, )?; @@ -594,6 +595,7 @@ pub fn translate_drop_index( program.emit_insn(Insn::Delete { cursor_id: sqlite_schema_cursor_id, table_name: "sqlite_schema".to_string(), + is_part_of_update: false, }); program.resolve_label(next_label, program.offset()); diff --git a/core/translate/insert.rs b/core/translate/insert.rs index e46a5607d..9209bdd2a 100644 --- a/core/translate/insert.rs +++ b/core/translate/insert.rs @@ -2,13 +2,12 @@ use std::num::NonZeroUsize; use std::sync::Arc; use turso_parser::ast::{ self, Expr, InsertBody, OneSelect, QualifiedName, ResolveType, ResultColumn, Upsert, UpsertDo, - With, }; use crate::error::{ SQLITE_CONSTRAINT_NOTNULL, SQLITE_CONSTRAINT_PRIMARYKEY, SQLITE_CONSTRAINT_UNIQUE, }; -use crate::schema::{self, Affinity, Index, Table}; +use crate::schema::{self, Affinity, BTreeTable, Index, ResolvedFkRef, Table}; use crate::translate::emitter::{ emit_cdc_insns, emit_cdc_patch_record, prepare_cdc_if_necessary, OperationMode, }; @@ -16,14 +15,18 @@ use crate::translate::expr::{ bind_and_rewrite_expr, emit_returning_results, process_returning_clause, walk_expr_mut, BindingBehavior, ReturningValueRegisters, WalkControl, }; -use crate::translate::plan::TableReferences; +use crate::translate::fkeys::{ + build_index_affinity_string, emit_fk_violation, emit_guarded_fk_decrement, index_probe, + open_read_index, open_read_table, +}; +use crate::translate::plan::{ResultSetColumn, TableReferences}; use crate::translate::planner::ROWID_STRS; use crate::translate::upsert::{ collect_set_clauses_for_upsert, emit_upsert, resolve_upsert_target, ResolvedUpsertTarget, }; use crate::util::normalize_ident; use crate::vdbe::builder::ProgramBuilderOpts; -use crate::vdbe::insn::{IdxInsertFlags, InsertFlags, RegisterOrLiteral}; +use crate::vdbe::insn::{CmpInsFlags, IdxInsertFlags, InsertFlags, RegisterOrLiteral}; use crate::vdbe::BranchOffset; use crate::{ schema::{Column, Schema}, @@ -32,22 +35,153 @@ use crate::{ insn::Insn, }, }; -use crate::{Result, VirtualTable}; +use crate::{Connection, Result, VirtualTable}; use super::emitter::Resolver; use super::expr::{translate_expr, translate_expr_no_constant_opt, NoConstantOptReason}; use super::plan::QueryDestination; use super::select::translate_select; -struct TempTableCtx { +/// Validate anything with this insert statement that should throw an early parse error +fn validate(table_name: &str, resolver: &Resolver, table: &Table) -> Result<()> { + // Check if this is a system table that should be protected from direct writes + if crate::schema::is_system_table(table_name) { + crate::bail_parse_error!("table {} may not be modified", table_name); + } + // Check if this table has any incompatible dependent views + let incompatible_views = resolver.schema.has_incompatible_dependent_views(table_name); + if !incompatible_views.is_empty() { + use crate::incremental::compiler::DBSP_CIRCUIT_VERSION; + crate::bail_parse_error!( + "Cannot INSERT into table '{}' because it has incompatible dependent materialized view(s): {}. \n\ + These views were created with a different DBSP version than the current version ({}). \n\ + Please DROP and recreate the view(s) before modifying this table.", + table_name, + incompatible_views.join(", "), + DBSP_CIRCUIT_VERSION + ); + } + + // Check if this is a materialized view + if resolver.schema.is_materialized_view(table_name) { + crate::bail_parse_error!("cannot modify materialized view {}", table_name); + } + if resolver.schema.table_has_indexes(table_name) && !resolver.schema.indexes_enabled() { + // Let's disable altering a table with indices altogether instead of checking column by + // column to be extra safe. + crate::bail_parse_error!( + "INSERT to table with indexes is disabled. Omit the `--experimental-indexes=false` flag to enable this feature." + ); + } + if table.btree().is_some_and(|t| !t.has_rowid) { + crate::bail_parse_error!("INSERT into WITHOUT ROWID table is not supported"); + } + + Ok(()) +} + +pub struct TempTableCtx { cursor_id: usize, loop_start_label: BranchOffset, loop_end_label: BranchOffset, } +#[allow(dead_code)] +pub struct InsertEmitCtx<'a> { + /// Parent table being inserted into + pub table: &'a Arc, + + /// Index cursors we need to populate for this table + /// (idx name, root_page, idx cursor id) + pub idx_cursors: Vec<(&'a String, i64, usize)>, + + /// Context for if the insert values are materialized first + /// into a temporary table + pub temp_table_ctx: Option, + /// on conflict, default to ABORT + pub on_conflict: ResolveType, + /// Arity of the insert values + pub num_values: usize, + /// The yield register, if a coroutine is used to yield multiple rows + pub yield_reg_opt: Option, + /// The register to hold the rowid of a conflicting row + pub conflict_rowid_reg: usize, + /// The cursor id of the table being inserted into + pub cursor_id: usize, + + /// Label to jump to on HALT + pub halt_label: BranchOffset, + /// Label to jump to when a row is done processing (either inserted or upserted) + pub row_done_label: BranchOffset, + /// Jump here at the complete end of the statement + pub stmt_epilogue: BranchOffset, + /// Beginning of the loop for multiple-row inserts + pub loop_start_label: BranchOffset, + /// Label to jump to when a generated key is ready for uniqueness check + pub key_ready_for_uniqueness_check_label: BranchOffset, + /// Label to jump to when no key is provided and one must be generated + pub key_generation_label: BranchOffset, + /// Jump here when the insert value SELECT source has been fully exhausted + pub select_exhausted_label: Option, + + /// CDC table info + pub cdc_table: Option<(usize, Arc)>, + /// Autoincrement sequence table info + pub autoincrement_meta: Option, +} + +impl<'a> InsertEmitCtx<'a> { + fn new( + program: &mut ProgramBuilder, + resolver: &'a Resolver, + table: &'a Arc, + on_conflict: Option, + cdc_table: Option<(usize, Arc)>, + num_values: usize, + temp_table_ctx: Option, + ) -> Self { + // allocate cursor id's for each btree index cursor we'll need to populate the indexes + let idx_cursors = resolver + .schema + .get_indices(table.name.as_str()) + .map(|idx| { + ( + &idx.name, + idx.root_page, + program.alloc_cursor_id(CursorType::BTreeIndex(idx.clone())), + ) + }) + .collect::>(); + let halt_label = program.allocate_label(); + let loop_start_label = program.allocate_label(); + let row_done_label = program.allocate_label(); + let stmt_epilogue = program.allocate_label(); + let key_ready_for_uniqueness_check_label = program.allocate_label(); + let key_generation_label = program.allocate_label(); + Self { + table, + idx_cursors, + temp_table_ctx, + on_conflict: on_conflict.unwrap_or(ResolveType::Abort), + yield_reg_opt: None, + conflict_rowid_reg: program.alloc_register(), + select_exhausted_label: None, + cursor_id: 0, // set later in emit_source_emission + halt_label, + row_done_label, + stmt_epilogue, + loop_start_label, + cdc_table, + num_values, + key_ready_for_uniqueness_check_label, + key_generation_label, + autoincrement_meta: None, + } + } +} + #[allow(clippy::too_many_arguments)] pub fn translate_insert( - with: Option, resolver: &Resolver, on_conflict: Option, tbl_name: QualifiedName, @@ -63,58 +197,15 @@ pub fn translate_insert( approx_num_labels: 5, }; program.extend(&opts); - if with.is_some() { - crate::bail_parse_error!("WITH clause is not supported"); - } - if on_conflict.is_some() { - crate::bail_parse_error!("ON CONFLICT clause is not supported"); - } - - if resolver - .schema - .table_has_indexes(&tbl_name.name.to_string()) - && !resolver.schema.indexes_enabled() - { - // Let's disable altering a table with indices altogether instead of checking column by - // column to be extra safe. - crate::bail_parse_error!( - "INSERT to table with indexes is disabled. Omit the `--experimental-indexes=false` flag to enable this feature." - ); - } let table_name = &tbl_name.name; - - // Check if this is a system table that should be protected from direct writes - if crate::schema::is_system_table(table_name.as_str()) { - crate::bail_parse_error!("table {} may not be modified", table_name); - } - let table = match resolver.schema.get_table(table_name.as_str()) { Some(table) => table, None => crate::bail_parse_error!("no such table: {}", table_name), }; + validate(table_name.as_str(), resolver, &table)?; - // Check if this is a materialized view - if resolver.schema.is_materialized_view(table_name.as_str()) { - crate::bail_parse_error!("cannot modify materialized view {}", table_name); - } - - // Check if this table has any incompatible dependent views - let incompatible_views = resolver - .schema - .has_incompatible_dependent_views(table_name.as_str()); - if !incompatible_views.is_empty() { - use crate::incremental::compiler::DBSP_CIRCUIT_VERSION; - crate::bail_parse_error!( - "Cannot INSERT into table '{}' because it has incompatible dependent materialized view(s): {}. \n\ - These views were created with a different DBSP version than the current version ({}). \n\ - Please DROP and recreate the view(s) before modifying this table.", - table_name, - incompatible_views.join(", "), - DBSP_CIRCUIT_VERSION - ); - } - + let fk_enabled = connection.foreign_keys_enabled(); if let Some(virtual_table) = &table.virtual_table() { program = translate_virtual_table_insert( program, @@ -130,99 +221,24 @@ pub fn translate_insert( let Some(btree_table) = table.btree() else { crate::bail_parse_error!("no such table: {}", table_name); }; - if !btree_table.has_rowid { - crate::bail_parse_error!("INSERT into WITHOUT ROWID table is not supported"); - } - let root_page = btree_table.root_page; - - let mut values: Option>> = None; - let mut upsert_actions: Vec<(ResolvedUpsertTarget, BranchOffset, Box)> = Vec::new(); - - let mut inserting_multiple_rows = false; - if let InsertBody::Select(select, upsert_opt) = &mut body { - match &mut select.body.select { - // TODO see how to avoid clone - OneSelect::Values(values_expr) if values_expr.len() <= 1 => { - if values_expr.is_empty() { - crate::bail_parse_error!("no values to insert"); - } - for expr in values_expr.iter_mut().flat_map(|v| v.iter_mut()) { - match expr.as_mut() { - Expr::Id(name) => { - if name.quoted_with('"') { - *expr = - Expr::Literal(ast::Literal::String(name.as_literal())).into(); - } else { - // an INSERT INTO ... VALUES (...) cannot reference columns - crate::bail_parse_error!("no such column: {name}"); - } - } - Expr::Qualified(first_name, second_name) => { - // an INSERT INTO ... VALUES (...) cannot reference columns - crate::bail_parse_error!("no such column: {first_name}.{second_name}"); - } - _ => {} - } - bind_and_rewrite_expr( - expr, - None, - None, - connection, - &mut program.param_ctx, - BindingBehavior::ResultColumnsNotAllowed, - )?; - } - values = values_expr.pop(); - } - _ => inserting_multiple_rows = true, - } - while let Some(mut upsert) = upsert_opt.take() { - if let UpsertDo::Set { - ref mut sets, - ref mut where_clause, - } = &mut upsert.do_clause - { - for set in sets.iter_mut() { - bind_and_rewrite_expr( - &mut set.expr, - None, - None, - connection, - &mut program.param_ctx, - BindingBehavior::ResultColumnsNotAllowed, - )?; - } - if let Some(ref mut where_expr) = where_clause { - bind_and_rewrite_expr( - where_expr, - None, - None, - connection, - &mut program.param_ctx, - BindingBehavior::ResultColumnsNotAllowed, - )?; - } - } - let next = upsert.next.take(); - upsert_actions.push(( - // resolve the constrained target for UPSERT in the chain - resolve_upsert_target(resolver.schema, &table, &upsert)?, - program.allocate_label(), - upsert, - )); - *upsert_opt = next; - } - } + let BoundInsertResult { + mut values, + mut upsert_actions, + inserting_multiple_rows, + } = bind_insert( + &mut program, + resolver, + &table, + &mut body, + connection, + on_conflict.unwrap_or(ResolveType::Abort), + )?; if inserting_multiple_rows && btree_table.has_autoincrement { ensure_sequence_initialized(&mut program, resolver.schema, &btree_table)?; } - let halt_label = program.allocate_label(); - let loop_start_label = program.allocate_label(); - let row_done_label = program.allocate_label(); - let cdc_table = prepare_cdc_if_necessary(&mut program, resolver.schema, table.get_name())?; // Process RETURNING clause using shared module @@ -233,314 +249,53 @@ pub fn translate_insert( &mut program, connection, )?; + let has_fks = fk_enabled + && (resolver.schema.has_child_fks(table_name.as_str()) + || resolver + .schema + .any_resolved_fks_referencing(table_name.as_str())); - let mut yield_reg_opt = None; - let mut temp_table_ctx = None; - let (num_values, cursor_id) = match body { - InsertBody::Select(select, _) => { - // Simple Common case of INSERT INTO VALUES (...) - if matches!(&select.body.select, OneSelect::Values(values) if values.len() <= 1) { - ( - values.as_ref().unwrap().len(), - program.alloc_cursor_id(CursorType::BTreeTable(btree_table.clone())), - ) - } else { - // Multiple rows - use coroutine for value population - let yield_reg = program.alloc_register(); - let jump_on_definition_label = program.allocate_label(); - let start_offset_label = program.allocate_label(); - program.emit_insn(Insn::InitCoroutine { - yield_reg, - jump_on_definition: jump_on_definition_label, - start_offset: start_offset_label, - }); + let mut ctx = InsertEmitCtx::new( + &mut program, + resolver, + &btree_table, + on_conflict, + cdc_table, + values.len(), + None, + ); - program.preassign_label_to_next_insn(start_offset_label); - - let query_destination = QueryDestination::CoroutineYield { - yield_reg, - coroutine_implementation_start: halt_label, - }; - program.incr_nesting(); - let result = - translate_select(select, resolver, program, query_destination, connection)?; - program = result.program; - program.decr_nesting(); - - program.emit_insn(Insn::EndCoroutine { yield_reg }); - program.preassign_label_to_next_insn(jump_on_definition_label); - - let cursor_id = - program.alloc_cursor_id(CursorType::BTreeTable(btree_table.clone())); - - // From SQLite - /* Set useTempTable to TRUE if the result of the SELECT statement - ** should be written into a temporary table (template 4). Set to - ** FALSE if each output row of the SELECT can be written directly into - ** the destination table (template 3). - ** - ** A temp table must be used if the table being updated is also one - ** of the tables being read by the SELECT statement. Also use a - ** temp table in the case of row triggers. - */ - if program.is_table_open(&table) { - let temp_cursor_id = - program.alloc_cursor_id(CursorType::BTreeTable(btree_table.clone())); - temp_table_ctx = Some(TempTableCtx { - cursor_id: temp_cursor_id, - loop_start_label: program.allocate_label(), - loop_end_label: program.allocate_label(), - }); - - program.emit_insn(Insn::OpenEphemeral { - cursor_id: temp_cursor_id, - is_table: true, - }); - - // Main loop - // FIXME: rollback is not implemented. E.g. if you insert 2 rows and one fails to unique constraint violation, - // the other row will still be inserted. - program.preassign_label_to_next_insn(loop_start_label); - - let yield_label = program.allocate_label(); - - program.emit_insn(Insn::Yield { - yield_reg, - end_offset: yield_label, - }); - let record_reg = program.alloc_register(); - - let affinity_str = if columns.is_empty() { - btree_table - .columns - .iter() - .filter(|col| !col.hidden) - .map(|col| col.affinity().aff_mask()) - .collect::() - } else { - columns - .iter() - .map(|col_name| { - let column_name = normalize_ident(col_name.as_str()); - if ROWID_STRS - .iter() - .any(|s| s.eq_ignore_ascii_case(&column_name)) - { - return Affinity::Integer.aff_mask(); - } - table - .get_column_by_name(&column_name) - .unwrap() - .1 - .affinity() - .aff_mask() - }) - .collect::() - }; - - program.emit_insn(Insn::MakeRecord { - start_reg: program.reg_result_cols_start.unwrap_or(yield_reg + 1), - count: result.num_result_cols, - dest_reg: record_reg, - index_name: None, - affinity_str: Some(affinity_str), - }); - - let rowid_reg = program.alloc_register(); - program.emit_insn(Insn::NewRowid { - cursor: temp_cursor_id, - rowid_reg, - prev_largest_reg: 0, - }); - - program.emit_insn(Insn::Insert { - cursor: temp_cursor_id, - key_reg: rowid_reg, - record_reg, - // since we are not doing an Insn::NewRowid or an Insn::NotExists here, we need to seek to ensure the insertion happens in the correct place. - flag: InsertFlags::new().require_seek(), - table_name: "".to_string(), - }); - - // loop back - program.emit_insn(Insn::Goto { - target_pc: loop_start_label, - }); - - program.preassign_label_to_next_insn(yield_label); - - program.emit_insn(Insn::OpenWrite { - cursor_id, - root_page: RegisterOrLiteral::Literal(root_page), - db: 0, - }); - } else { - program.emit_insn(Insn::OpenWrite { - cursor_id, - root_page: RegisterOrLiteral::Literal(root_page), - db: 0, - }); - - // Main loop - // FIXME: rollback is not implemented. E.g. if you insert 2 rows and one fails to unique constraint violation, - // the other row will still be inserted. - program.preassign_label_to_next_insn(loop_start_label); - program.emit_insn(Insn::Yield { - yield_reg, - end_offset: halt_label, - }); - } - - yield_reg_opt = Some(yield_reg); - (result.num_result_cols, cursor_id) - } - } - InsertBody::DefaultValues => { - let num_values = table.columns().len(); - values = Some( - table - .columns() - .iter() - .map(|c| { - c.default - .clone() - .unwrap_or(Box::new(ast::Expr::Literal(ast::Literal::Null))) - }) - .collect(), - ); - ( - num_values, - program.alloc_cursor_id(CursorType::BTreeTable(btree_table.clone())), - ) - } - }; + program = init_source_emission( + program, + &table, + connection, + &mut ctx, + resolver, + &mut values, + body, + &columns, + )?; let has_upsert = !upsert_actions.is_empty(); // Set up the program to return result columns if RETURNING is specified if !result_columns.is_empty() { program.result_columns = result_columns.clone(); } + let insertion = build_insertion(&mut program, &table, &columns, ctx.num_values)?; - // allocate cursor id's for each btree index cursor we'll need to populate the indexes - // (idx name, root_page, idx cursor id) - let idx_cursors = resolver - .schema - .get_indices(table_name.as_str()) - .map(|idx| { - ( - &idx.name, - idx.root_page, - program.alloc_cursor_id(CursorType::BTreeIndex(idx.clone())), - ) - }) - .collect::>(); - - let insertion = build_insertion(&mut program, &table, &columns, num_values)?; - - let conflict_rowid_reg = program.alloc_register(); - - if inserting_multiple_rows { - let select_result_start_reg = program - .reg_result_cols_start - .unwrap_or(yield_reg_opt.unwrap() + 1); - translate_rows_multiple( - &mut program, - &insertion, - select_result_start_reg, - resolver, - &temp_table_ctx, - )?; - } else { - // Single row - populate registers directly - program.emit_insn(Insn::OpenWrite { - cursor_id, - root_page: RegisterOrLiteral::Literal(root_page), - db: 0, - }); - - translate_rows_single(&mut program, &values.unwrap(), &insertion, resolver)?; - } - - // Open all the index btrees for writing - for idx_cursor in idx_cursors.iter() { - program.emit_insn(Insn::OpenWrite { - cursor_id: idx_cursor.2, - root_page: idx_cursor.1.into(), - db: 0, - }); - } + translate_rows_and_open_tables( + &mut program, + resolver, + &insertion, + &ctx, + &values, + inserting_multiple_rows, + )?; let has_user_provided_rowid = insertion.key.is_provided_by_user(); - let key_ready_for_uniqueness_check_label = program.allocate_label(); - let key_generation_label = program.allocate_label(); - let mut autoincrement_meta = None; - - if btree_table.has_autoincrement { - let seq_table = resolver - .schema - .get_btree_table("sqlite_sequence") - .ok_or_else(|| { - crate::error::LimboError::InternalError( - "sqlite_sequence table not found".to_string(), - ) - })?; - let seq_cursor_id = program.alloc_cursor_id(CursorType::BTreeTable(seq_table.clone())); - program.emit_insn(Insn::OpenWrite { - cursor_id: seq_cursor_id, - root_page: seq_table.root_page.into(), - db: 0, - }); - - let table_name_reg = program.emit_string8_new_reg(btree_table.name.clone()); - let r_seq = program.alloc_register(); - let r_seq_rowid = program.alloc_register(); - autoincrement_meta = Some((seq_cursor_id, r_seq, r_seq_rowid, table_name_reg)); - - program.emit_insn(Insn::Integer { - dest: r_seq, - value: 0, - }); - program.emit_insn(Insn::Null { - dest: r_seq_rowid, - dest_end: None, - }); - - let loop_start_label = program.allocate_label(); - let loop_end_label = program.allocate_label(); - let found_label = program.allocate_label(); - - program.emit_insn(Insn::Rewind { - cursor_id: seq_cursor_id, - pc_if_empty: loop_end_label, - }); - program.preassign_label_to_next_insn(loop_start_label); - - let name_col_reg = program.alloc_register(); - program.emit_column_or_rowid(seq_cursor_id, 0, name_col_reg); - program.emit_insn(Insn::Ne { - lhs: table_name_reg, - rhs: name_col_reg, - target_pc: found_label, - flags: Default::default(), - collation: None, - }); - - program.emit_column_or_rowid(seq_cursor_id, 1, r_seq); - program.emit_insn(Insn::RowId { - cursor_id: seq_cursor_id, - dest: r_seq_rowid, - }); - program.emit_insn(Insn::Goto { - target_pc: loop_end_label, - }); - - program.preassign_label_to_next_insn(found_label); - program.emit_insn(Insn::Next { - cursor_id: seq_cursor_id, - pc_if_next: loop_start_label, - }); - program.preassign_label_to_next_insn(loop_end_label); + if ctx.table.has_autoincrement { + init_autoincrement(&mut program, &mut ctx, resolver)?; } if has_user_provided_rowid { @@ -552,7 +307,7 @@ pub fn translate_insert( }); program.emit_insn(Insn::Goto { - target_pc: key_generation_label, + target_pc: ctx.key_generation_label, }); program.preassign_label_to_next_insn(must_be_int_label); @@ -561,18 +316,369 @@ pub fn translate_insert( }); program.emit_insn(Insn::Goto { - target_pc: key_ready_for_uniqueness_check_label, + target_pc: ctx.key_ready_for_uniqueness_check_label, }); } - program.preassign_label_to_next_insn(key_generation_label); - if let Some((_, r_seq, _, _)) = autoincrement_meta { + program.preassign_label_to_next_insn(ctx.key_generation_label); + + emit_rowid_generation(&mut program, resolver, &ctx, &insertion)?; + + program.preassign_label_to_next_insn(ctx.key_ready_for_uniqueness_check_label); + + if ctx.table.is_strict { + program.emit_insn(Insn::TypeCheck { + start_reg: insertion.first_col_register(), + count: insertion.col_mappings.len(), + check_generated: true, + table_reference: Arc::clone(ctx.table), + }); + } + + // Build a list of upsert constraints/indexes we need to run preflight + // checks against, in the proper order of evaluation, + let constraints = build_constraints_to_check( + resolver, + table_name.as_str(), + &upsert_actions, + has_user_provided_rowid, + ); + + // We need to separate index handling and insertion into a `preflight` and a + // `commit` phase, because in UPSERT mode we might need to skip the actual insertion, as we can + // have a naked ON CONFLICT DO NOTHING, so if we eagerly insert any indexes, we could insert + // invalid index entries before we hit a conflict down the line. + emit_preflight_constraint_checks( + &mut program, + &ctx, + resolver, + &insertion, + &upsert_actions, + &constraints, + )?; + + emit_notnulls(&mut program, &ctx, &insertion); + + // Create and insert the record + let affinity_str = insertion + .col_mappings + .iter() + .map(|col_mapping| col_mapping.column.affinity().aff_mask()) + .collect::(); + program.emit_insn(Insn::MakeRecord { + start_reg: insertion.first_col_register(), + count: insertion.col_mappings.len(), + dest_reg: insertion.record_register(), + index_name: None, + affinity_str: Some(affinity_str), + }); + + if has_upsert { + emit_commit_phase(&mut program, resolver, &insertion, &ctx)?; + } + + if has_fks { + // Child-side check must run before Insert (may HALT or increment deferred counter) + emit_fk_child_insert_checks( + &mut program, + resolver, + &btree_table, + insertion.first_col_register(), + insertion.key_register(), + )?; + } + + program.emit_insn(Insn::Insert { + cursor: ctx.cursor_id, + key_reg: insertion.key_register(), + record_reg: insertion.record_register(), + flag: InsertFlags::new(), + table_name: table_name.to_string(), + }); + + if has_fks { + // After the row is actually present, repair deferred counters for children referencing this NEW parent key. + emit_parent_side_fk_decrement_on_insert(&mut program, resolver, &btree_table, &insertion)?; + } + + if let Some(AutoincMeta { + seq_cursor_id, + r_seq, + r_seq_rowid, + table_name_reg, + }) = ctx.autoincrement_meta + { + let no_update_needed_label = program.allocate_label(); + program.emit_insn(Insn::Le { + lhs: insertion.key_register(), + rhs: r_seq, + target_pc: no_update_needed_label, + flags: Default::default(), + collation: None, + }); + + emit_update_sqlite_sequence( + &mut program, + resolver.schema, + seq_cursor_id, + r_seq_rowid, + table_name_reg, + insertion.key_register(), + )?; + + program.preassign_label_to_next_insn(no_update_needed_label); + program.emit_insn(Insn::Close { + cursor_id: seq_cursor_id, + }); + } + + // Emit update in the CDC table if necessary (after the INSERT updated the table) + if let Some((cdc_cursor_id, _)) = &ctx.cdc_table { + let cdc_has_after = program.capture_data_changes_mode().has_after(); + let after_record_reg = if cdc_has_after { + Some(emit_cdc_patch_record( + &mut program, + &table, + insertion.first_col_register(), + insertion.record_register(), + insertion.key_register(), + )) + } else { + None + }; + emit_cdc_insns( + &mut program, + resolver, + OperationMode::INSERT, + *cdc_cursor_id, + insertion.key_register(), + None, + after_record_reg, + None, + table_name.as_str(), + )?; + } + + // Emit RETURNING results if specified + if !result_columns.is_empty() { + let value_registers = ReturningValueRegisters { + rowid_register: insertion.key_register(), + columns_start_register: insertion.first_col_register(), + num_columns: table.columns().len(), + }; + + emit_returning_results(&mut program, &result_columns, &value_registers)?; + } + program.emit_insn(Insn::Goto { + target_pc: ctx.row_done_label, + }); + + resolve_upserts( + &mut program, + resolver, + &mut upsert_actions, + &ctx, + &insertion, + &table, + &mut result_columns, + connection, + )?; + + emit_epilogue(&mut program, &ctx, inserting_multiple_rows); + + program.set_needs_stmt_subtransactions(true); + Ok(program) +} + +fn emit_epilogue(program: &mut ProgramBuilder, ctx: &InsertEmitCtx, inserting_multiple_rows: bool) { + if inserting_multiple_rows { + if let Some(temp_table_ctx) = &ctx.temp_table_ctx { + program.resolve_label(ctx.row_done_label, program.offset()); + + program.emit_insn(Insn::Next { + cursor_id: temp_table_ctx.cursor_id, + pc_if_next: temp_table_ctx.loop_start_label, + }); + program.preassign_label_to_next_insn(temp_table_ctx.loop_end_label); + + program.emit_insn(Insn::Close { + cursor_id: temp_table_ctx.cursor_id, + }); + program.emit_insn(Insn::Goto { + target_pc: ctx.stmt_epilogue, + }); + } else { + // For multiple rows which not require a temp table, loop back + program.resolve_label(ctx.row_done_label, program.offset()); + program.emit_insn(Insn::Goto { + target_pc: ctx.loop_start_label, + }); + if let Some(sel_eof) = ctx.select_exhausted_label { + program.preassign_label_to_next_insn(sel_eof); + program.emit_insn(Insn::Goto { + target_pc: ctx.stmt_epilogue, + }); + } + } + } else { + program.resolve_label(ctx.row_done_label, program.offset()); + // single-row falls through to epilogue + program.emit_insn(Insn::Goto { + target_pc: ctx.stmt_epilogue, + }); + } + program.preassign_label_to_next_insn(ctx.stmt_epilogue); + program.resolve_label(ctx.halt_label, program.offset()); +} + +// COMMIT PHASE: no preflight jumps happened; emit the actual index writes now +// We re-check partial-index predicates against the NEW image, produce packed records, +// and insert into all applicable indexes, we do not re-probe uniqueness here, as preflight +// already guaranteed non-conflict. +fn emit_commit_phase( + program: &mut ProgramBuilder, + resolver: &Resolver, + insertion: &Insertion, + ctx: &InsertEmitCtx, +) -> Result<()> { + for index in resolver.schema.get_indices(ctx.table.name.as_str()) { + let idx_cursor_id = ctx + .idx_cursors + .iter() + .find(|(name, _, _)| *name == &index.name) + .map(|(_, _, c_id)| *c_id) + .expect("no cursor found for index"); + + // Re-evaluate partial predicate on the would-be inserted image + let commit_skip_label = if let Some(where_clause) = &index.where_clause { + let mut where_for_eval = where_clause.as_ref().clone(); + rewrite_partial_index_where(&mut where_for_eval, insertion)?; + let reg = program.alloc_register(); + translate_expr_no_constant_opt( + program, + Some(&TableReferences::new_empty()), + &where_for_eval, + reg, + resolver, + NoConstantOptReason::RegisterReuse, + )?; + let lbl = program.allocate_label(); + program.emit_insn(Insn::IfNot { + reg, + target_pc: lbl, + jump_if_null: true, + }); + Some(lbl) + } else { + None + }; + + let num_cols = index.columns.len(); + let idx_start_reg = program.alloc_registers(num_cols + 1); + + // Build [key cols..., rowid] from insertion registers + for (i, idx_col) in index.columns.iter().enumerate() { + let Some(cm) = insertion.get_col_mapping_by_name(&idx_col.name) else { + return Err(crate::LimboError::PlanningError( + "Column not found in INSERT (commit phase)".to_string(), + )); + }; + program.emit_insn(Insn::Copy { + src_reg: cm.register, + dst_reg: idx_start_reg + i, + extra_amount: 0, + }); + } + program.emit_insn(Insn::Copy { + src_reg: insertion.key_register(), + dst_reg: idx_start_reg + num_cols, + extra_amount: 0, + }); + + let record_reg = program.alloc_register(); + program.emit_insn(Insn::MakeRecord { + start_reg: idx_start_reg, + count: num_cols + 1, + dest_reg: record_reg, + index_name: Some(index.name.clone()), + affinity_str: None, + }); + program.emit_insn(Insn::IdxInsert { + cursor_id: idx_cursor_id, + record_reg, + unpacked_start: Some(idx_start_reg), + unpacked_count: Some((num_cols + 1) as u16), + flags: IdxInsertFlags::new().nchange(true), + }); + + if let Some(lbl) = commit_skip_label { + program.resolve_label(lbl, program.offset()); + } + } + Ok(()) +} + +fn translate_rows_and_open_tables( + program: &mut ProgramBuilder, + resolver: &Resolver, + insertion: &Insertion, + ctx: &InsertEmitCtx, + values: &[Box], + inserting_multiple_rows: bool, +) -> Result<()> { + if inserting_multiple_rows { + let select_result_start_reg = program + .reg_result_cols_start + .unwrap_or(ctx.yield_reg_opt.unwrap() + 1); + translate_rows_multiple( + program, + insertion, + select_result_start_reg, + resolver, + &ctx.temp_table_ctx, + )?; + } else { + // Single row - populate registers directly + program.emit_insn(Insn::OpenWrite { + cursor_id: ctx.cursor_id, + root_page: RegisterOrLiteral::Literal(ctx.table.root_page), + db: 0, + }); + + translate_rows_single(program, values, insertion, resolver)?; + } + + // Open all the index btrees for writing + for idx_cursor in ctx.idx_cursors.iter() { + program.emit_insn(Insn::OpenWrite { + cursor_id: idx_cursor.2, + root_page: idx_cursor.1.into(), + db: 0, + }); + } + Ok(()) +} + +fn emit_rowid_generation( + program: &mut ProgramBuilder, + resolver: &Resolver, + ctx: &InsertEmitCtx, + insertion: &Insertion, +) -> Result<()> { + if let Some(AutoincMeta { + r_seq, + seq_cursor_id, + r_seq_rowid, + table_name_reg, + .. + }) = ctx.autoincrement_meta + { let r_max = program.alloc_register(); let dummy_reg = program.alloc_register(); program.emit_insn(Insn::NewRowid { - cursor: cursor_id, + cursor: ctx.cursor_id, rowid_reg: dummy_reg, prev_largest_reg: r_max, }); @@ -613,295 +719,144 @@ pub fn translate_insert( value: 1, }); - if let Some((seq_cursor_id, _, r_seq_rowid, table_name_reg)) = autoincrement_meta { - emit_update_sqlite_sequence( - &mut program, - resolver.schema, - seq_cursor_id, - r_seq_rowid, - table_name_reg, - insertion.key_register(), - )?; - } + emit_update_sqlite_sequence( + program, + resolver.schema, + seq_cursor_id, + r_seq_rowid, + table_name_reg, + insertion.key_register(), + )?; } else { program.emit_insn(Insn::NewRowid { - cursor: cursor_id, + cursor: ctx.cursor_id, rowid_reg: insertion.key_register(), prev_largest_reg: 0, }); } + Ok(()) +} - program.preassign_label_to_next_insn(key_ready_for_uniqueness_check_label); +#[allow(clippy::too_many_arguments)] +fn resolve_upserts( + program: &mut ProgramBuilder, + resolver: &Resolver, + upsert_actions: &mut [(ResolvedUpsertTarget, BranchOffset, Box)], + ctx: &InsertEmitCtx, + insertion: &Insertion, + table: &Table, + result_columns: &mut [ResultSetColumn], + connection: &Arc, +) -> Result<()> { + for (_, label, upsert) in upsert_actions { + program.preassign_label_to_next_insn(*label); - match table.btree() { - Some(t) if t.is_strict => { - program.emit_insn(Insn::TypeCheck { - start_reg: insertion.first_col_register(), - count: insertion.col_mappings.len(), - check_generated: true, - table_reference: Arc::clone(&t), + if let UpsertDo::Set { + ref mut sets, + ref mut where_clause, + } = upsert.do_clause + { + // Normalize SET pairs once + let mut rewritten_sets = collect_set_clauses_for_upsert(table, sets)?; + + emit_upsert( + program, + table, + ctx, + insertion, + &mut rewritten_sets, + where_clause, + resolver, + result_columns, + connection, + )?; + } else { + // UpsertDo::Nothing case + program.emit_insn(Insn::Goto { + target_pc: ctx.row_done_label, }); } - _ => (), } + Ok(()) +} - let mut constraints_to_check = Vec::new(); - if has_user_provided_rowid { - // Check uniqueness constraint for rowid if it was provided by user. - // When the DB allocates it there are no need for separate uniqueness checks. - let position = upsert_actions - .iter() - .position(|(target, ..)| matches!(target, ResolvedUpsertTarget::PrimaryKey)); - constraints_to_check.push((ResolvedUpsertTarget::PrimaryKey, position)); - } - for index in resolver.schema.get_indices(table_name.as_str()) { - let position = upsert_actions - .iter() - .position(|(target, ..)| matches!(target, ResolvedUpsertTarget::Index(x) if Arc::ptr_eq(x, index))); - constraints_to_check.push((ResolvedUpsertTarget::Index(index.clone()), position)); - } - - constraints_to_check.sort_by(|(_, p1), (_, p2)| match (p1, p2) { - (Some(p1), Some(p2)) => p1.cmp(p2), - (Some(_), None) => std::cmp::Ordering::Less, - (None, Some(_)) => std::cmp::Ordering::Greater, - (None, None) => std::cmp::Ordering::Equal, +fn init_autoincrement( + program: &mut ProgramBuilder, + ctx: &mut InsertEmitCtx, + resolver: &Resolver, +) -> Result<()> { + let seq_table = resolver + .schema + .get_btree_table("sqlite_sequence") + .ok_or_else(|| { + crate::error::LimboError::InternalError("sqlite_sequence table not found".to_string()) + })?; + let seq_cursor_id = program.alloc_cursor_id(CursorType::BTreeTable(seq_table.clone())); + program.emit_insn(Insn::OpenWrite { + cursor_id: seq_cursor_id, + root_page: seq_table.root_page.into(), + db: 0, }); - let upsert_catch_all_position = - if let Some((ResolvedUpsertTarget::CatchAll, ..)) = upsert_actions.last() { - Some(upsert_actions.len() - 1) - } else { - None - }; + let table_name_reg = program.emit_string8_new_reg(ctx.table.name.clone()); + let r_seq = program.alloc_register(); + let r_seq_rowid = program.alloc_register(); - // We need to separate index handling and insertion into a `preflight` and a - // `commit` phase, because in UPSERT mode we might need to skip the actual insertion, as we can - // have a naked ON CONFLICT DO NOTHING, so if we eagerly insert any indexes, we could insert - // invalid index entries before we hit a conflict down the line. - // - // Preflight phase: evaluate each applicable UNIQUE constraint and probe with NoConflict. - // If any probe hits: - // DO NOTHING -> jump to row_done_label. - // - // DO UPDATE (matching target) -> fetch conflicting rowid and jump to `upsert_entry`. - // - // otherwise, raise SQLITE_CONSTRAINT_UNIQUE - for (constraint, position) in constraints_to_check { - match constraint { - ResolvedUpsertTarget::PrimaryKey => { - let make_record_label = program.allocate_label(); - program.emit_insn(Insn::NotExists { - cursor: cursor_id, - rowid_reg: insertion.key_register(), - target_pc: make_record_label, - }); - let rowid_column_name = insertion.key.column_name(); + ctx.autoincrement_meta = Some(AutoincMeta { + seq_cursor_id, + r_seq, + r_seq_rowid, + table_name_reg, + }); - // Conflict on rowid: attempt to route through UPSERT if it targets the PK, otherwise raise constraint. - // emit Halt for every case *except* when upsert handles the conflict - 'emit_halt: { - if let Some(position) = position.or(upsert_catch_all_position) { - // PK conflict: the conflicting rowid is exactly the attempted key - program.emit_insn(Insn::Copy { - src_reg: insertion.key_register(), - dst_reg: conflict_rowid_reg, - extra_amount: 0, - }); - program.emit_insn(Insn::Goto { - target_pc: upsert_actions[position].1, - }); - break 'emit_halt; - } - let mut description = String::with_capacity( - table_name.as_str().len() + rowid_column_name.len() + 2, - ); - description.push_str(table_name.as_str()); - description.push('.'); - description.push_str(rowid_column_name); - program.emit_insn(Insn::Halt { - err_code: SQLITE_CONSTRAINT_PRIMARYKEY, - description, - }); - } - program.preassign_label_to_next_insn(make_record_label); - } - ResolvedUpsertTarget::Index(index) => { - let column_mappings = index - .columns - .iter() - .map(|idx_col| insertion.get_col_mapping_by_name(&idx_col.name)); - // find which cursor we opened earlier for this index - let idx_cursor_id = idx_cursors - .iter() - .find(|(name, _, _)| *name == &index.name) - .map(|(_, _, c_id)| *c_id) - .expect("no cursor found for index"); + program.emit_insn(Insn::Integer { + dest: r_seq, + value: 0, + }); + program.emit_insn(Insn::Null { + dest: r_seq_rowid, + dest_end: None, + }); - let maybe_skip_probe_label = if let Some(where_clause) = &index.where_clause { - let mut where_for_eval = where_clause.as_ref().clone(); - rewrite_partial_index_where(&mut where_for_eval, &insertion)?; - let reg = program.alloc_register(); - translate_expr_no_constant_opt( - &mut program, - Some(&TableReferences::new_empty()), - &where_for_eval, - reg, - resolver, - NoConstantOptReason::RegisterReuse, - )?; - let lbl = program.allocate_label(); - program.emit_insn(Insn::IfNot { - reg, - target_pc: lbl, - jump_if_null: true, - }); - Some(lbl) - } else { - None - }; + let loop_start_label = program.allocate_label(); + let loop_end_label = program.allocate_label(); + let found_label = program.allocate_label(); - let num_cols = index.columns.len(); - // allocate scratch registers for the index columns plus rowid - let idx_start_reg = program.alloc_registers(num_cols + 1); + program.emit_insn(Insn::Rewind { + cursor_id: seq_cursor_id, + pc_if_empty: loop_end_label, + }); + program.preassign_label_to_next_insn(loop_start_label); - // build unpacked key [idx_start_reg .. idx_start_reg+num_cols-1], and rowid in last reg, - // copy each index column from the table's column registers into these scratch regs - for (i, column_mapping) in column_mappings.clone().enumerate() { - // copy from the table's column register over to the index's scratch register - let Some(col_mapping) = column_mapping else { - return Err(crate::LimboError::PlanningError( - "Column not found in INSERT".to_string(), - )); - }; - program.emit_insn(Insn::Copy { - src_reg: col_mapping.register, - dst_reg: idx_start_reg + i, - extra_amount: 0, - }); - } - // last register is the rowid - program.emit_insn(Insn::Copy { - src_reg: insertion.key_register(), - dst_reg: idx_start_reg + num_cols, - extra_amount: 0, - }); + let name_col_reg = program.alloc_register(); + program.emit_column_or_rowid(seq_cursor_id, 0, name_col_reg); + program.emit_insn(Insn::Ne { + lhs: table_name_reg, + rhs: name_col_reg, + target_pc: found_label, + flags: Default::default(), + collation: None, + }); - if index.unique { - let aff = index - .columns - .iter() - .map(|ic| table.columns()[ic.pos_in_table].affinity().aff_mask()) - .collect::(); - program.emit_insn(Insn::Affinity { - start_reg: idx_start_reg, - count: NonZeroUsize::new(num_cols).expect("nonzero col count"), - affinities: aff, - }); + program.emit_column_or_rowid(seq_cursor_id, 1, r_seq); + program.emit_insn(Insn::RowId { + cursor_id: seq_cursor_id, + dest: r_seq_rowid, + }); + program.emit_insn(Insn::Goto { + target_pc: loop_end_label, + }); - if has_upsert { - let next_check = program.allocate_label(); - program.emit_insn(Insn::NoConflict { - cursor_id: idx_cursor_id, - target_pc: next_check, - record_reg: idx_start_reg, - num_regs: num_cols, - }); - - // Conflict detected, figure out if this UPSERT handles the conflict - if let Some(position) = position.or(upsert_catch_all_position) { - match &upsert_actions[position].2.do_clause { - UpsertDo::Nothing => { - // Bail out without writing anything - program.emit_insn(Insn::Goto { - target_pc: row_done_label, - }); - } - UpsertDo::Set { .. } => { - // Route to DO UPDATE: capture conflicting rowid then jump - program.emit_insn(Insn::IdxRowId { - cursor_id: idx_cursor_id, - dest: conflict_rowid_reg, - }); - program.emit_insn(Insn::Goto { - target_pc: upsert_actions[position].1, - }); - } - } - } - // No matching UPSERT handler so we emit constraint error - // (if conflict clause matched - VM will jump to later instructions and skip halt) - program.emit_insn(Insn::Halt { - err_code: SQLITE_CONSTRAINT_UNIQUE, - description: format_unique_violation_desc(table_name.as_str(), &index), - }); - - // continue preflight with next constraint - program.preassign_label_to_next_insn(next_check); - } else { - // No UPSERT fast-path: probe and immediately insert - let ok = program.allocate_label(); - program.emit_insn(Insn::NoConflict { - cursor_id: idx_cursor_id, - target_pc: ok, - record_reg: idx_start_reg, - num_regs: num_cols, - }); - // Unique violation without ON CONFLICT clause -> error - program.emit_insn(Insn::Halt { - err_code: SQLITE_CONSTRAINT_UNIQUE, - description: format_unique_violation_desc(table_name.as_str(), &index), - }); - program.preassign_label_to_next_insn(ok); - - // In the non-UPSERT case, we insert the index - let record_reg = program.alloc_register(); - program.emit_insn(Insn::MakeRecord { - start_reg: idx_start_reg, - count: num_cols + 1, - dest_reg: record_reg, - index_name: Some(index.name.clone()), - affinity_str: None, - }); - program.emit_insn(Insn::IdxInsert { - cursor_id: idx_cursor_id, - record_reg, - unpacked_start: Some(idx_start_reg), - unpacked_count: Some((num_cols + 1) as u16), - flags: IdxInsertFlags::new().nchange(true), - }); - } - } else { - // Non-unique index: in UPSERT mode we postpone writes to commit phase. - if !has_upsert { - // eager insert for non-unique, no UPSERT - let record_reg = program.alloc_register(); - program.emit_insn(Insn::MakeRecord { - start_reg: idx_start_reg, - count: num_cols + 1, - dest_reg: record_reg, - index_name: Some(index.name.clone()), - affinity_str: None, - }); - program.emit_insn(Insn::IdxInsert { - cursor_id: idx_cursor_id, - record_reg, - unpacked_start: Some(idx_start_reg), - unpacked_count: Some((num_cols + 1) as u16), - flags: IdxInsertFlags::new().nchange(true), - }); - } - } - - // Close the partial-index skip (preflight) - if let Some(lbl) = maybe_skip_probe_label { - program.resolve_label(lbl, program.offset()); - } - } - ResolvedUpsertTarget::CatchAll => unreachable!(), - } - } + program.preassign_label_to_next_insn(found_label); + program.emit_insn(Insn::Next { + cursor_id: seq_cursor_id, + pc_if_next: loop_start_label, + }); + program.preassign_label_to_next_insn(loop_end_label); + Ok(()) +} +fn emit_notnulls(program: &mut ProgramBuilder, ctx: &InsertEmitCtx, insertion: &Insertion) { for column_mapping in insertion .col_mappings .iter() @@ -916,7 +871,7 @@ pub fn translate_insert( err_code: SQLITE_CONSTRAINT_NOTNULL, description: { let mut description = String::with_capacity( - table_name.as_str().len() + ctx.table.name.as_str().len() + column_mapping .column .name @@ -925,7 +880,7 @@ pub fn translate_insert( .len() + 2, ); - description.push_str(table_name.as_str()); + description.push_str(ctx.table.name.as_str()); description.push('.'); description.push_str( column_mapping @@ -938,238 +893,345 @@ pub fn translate_insert( }, }); } +} - // Create and insert the record - let affinity_str = insertion - .col_mappings - .iter() - .map(|col_mapping| col_mapping.column.affinity().aff_mask()) - .collect::(); +struct BoundInsertResult { + #[allow(clippy::vec_box)] + values: Vec>, + upsert_actions: Vec<(ResolvedUpsertTarget, BranchOffset, Box)>, + inserting_multiple_rows: bool, +} - program.emit_insn(Insn::MakeRecord { - start_reg: insertion.first_col_register(), - count: insertion.col_mappings.len(), - dest_reg: insertion.record_register(), - index_name: None, - affinity_str: Some(affinity_str), - }); - - if has_upsert { - // COMMIT PHASE: no preflight jumps happened; emit the actual index writes now - // We re-check partial-index predicates against the NEW image, produce packed records, - // and insert into all applicable indexes, we do not re-probe uniqueness here, as preflight - // already guaranteed non-conflict. - for index in resolver.schema.get_indices(table_name.as_str()) { - let idx_cursor_id = idx_cursors +fn bind_insert( + program: &mut ProgramBuilder, + resolver: &Resolver, + table: &Table, + body: &mut InsertBody, + connection: &Arc, + on_conflict: ResolveType, +) -> Result { + let mut values: Vec> = vec![]; + let mut upsert: Option> = None; + let mut upsert_actions: Vec<(ResolvedUpsertTarget, BranchOffset, Box)> = Vec::new(); + let mut inserting_multiple_rows = false; + match body { + InsertBody::DefaultValues => { + // Generate default values for the table + values = table + .columns() .iter() - .find(|(name, _, _)| *name == &index.name) - .map(|(_, _, c_id)| *c_id) - .expect("no cursor found for index"); - - // Re-evaluate partial predicate on the would-be inserted image - let commit_skip_label = if let Some(where_clause) = &index.where_clause { - let mut where_for_eval = where_clause.as_ref().clone(); - rewrite_partial_index_where(&mut where_for_eval, &insertion)?; - let reg = program.alloc_register(); - translate_expr_no_constant_opt( - &mut program, - Some(&TableReferences::new_empty()), - &where_for_eval, - reg, - resolver, - NoConstantOptReason::RegisterReuse, - )?; - let lbl = program.allocate_label(); - program.emit_insn(Insn::IfNot { - reg, - target_pc: lbl, - jump_if_null: true, - }); - Some(lbl) - } else { - None - }; - - let num_cols = index.columns.len(); - let idx_start_reg = program.alloc_registers(num_cols + 1); - - // Build [key cols..., rowid] from insertion registers - for (i, idx_col) in index.columns.iter().enumerate() { - let Some(cm) = insertion.get_col_mapping_by_name(&idx_col.name) else { - return Err(crate::LimboError::PlanningError( - "Column not found in INSERT (commit phase)".to_string(), - )); - }; - program.emit_insn(Insn::Copy { - src_reg: cm.register, - dst_reg: idx_start_reg + i, - extra_amount: 0, - }); - } - program.emit_insn(Insn::Copy { - src_reg: insertion.key_register(), - dst_reg: idx_start_reg + num_cols, - extra_amount: 0, - }); - - let record_reg = program.alloc_register(); - program.emit_insn(Insn::MakeRecord { - start_reg: idx_start_reg, - count: num_cols + 1, - dest_reg: record_reg, - index_name: Some(index.name.clone()), - affinity_str: None, - }); - program.emit_insn(Insn::IdxInsert { - cursor_id: idx_cursor_id, - record_reg, - unpacked_start: Some(idx_start_reg), - unpacked_count: Some((num_cols + 1) as u16), - flags: IdxInsertFlags::new().nchange(true), - }); - - if let Some(lbl) = commit_skip_label { - program.resolve_label(lbl, program.offset()); + .filter(|c| !c.hidden) + .map(|c| { + c.default + .clone() + .unwrap_or(Box::new(ast::Expr::Literal(ast::Literal::Null))) + }) + .collect(); + } + InsertBody::Select(select, upsert_opt) => { + match &mut select.body.select { + // TODO see how to avoid clone + OneSelect::Values(values_expr) if values_expr.len() <= 1 => { + if values_expr.is_empty() { + crate::bail_parse_error!("no values to insert"); + } + for expr in values_expr.iter_mut().flat_map(|v| v.iter_mut()) { + match expr.as_mut() { + Expr::Id(name) => { + if name.quoted_with('"') { + *expr = Expr::Literal(ast::Literal::String(name.as_literal())) + .into(); + } else { + // an INSERT INTO ... VALUES (...) cannot reference columns + crate::bail_parse_error!("no such column: {name}"); + } + } + Expr::Qualified(first_name, second_name) => { + // an INSERT INTO ... VALUES (...) cannot reference columns + crate::bail_parse_error!( + "no such column: {first_name}.{second_name}" + ); + } + _ => {} + } + bind_and_rewrite_expr( + expr, + None, + None, + connection, + &mut program.param_ctx, + BindingBehavior::ResultColumnsNotAllowed, + )?; + } + values = values_expr.pop().unwrap_or_else(Vec::new); + } + _ => inserting_multiple_rows = true, } + upsert = upsert_opt.take(); } } - - program.emit_insn(Insn::Insert { - cursor: cursor_id, - key_reg: insertion.key_register(), - record_reg: insertion.record_register(), - flag: InsertFlags::new(), - table_name: table_name.to_string(), - }); - - if let Some((seq_cursor_id, r_seq, r_seq_rowid, table_name_reg)) = autoincrement_meta { - let no_update_needed_label = program.allocate_label(); - program.emit_insn(Insn::Le { - lhs: insertion.key_register(), - rhs: r_seq, - target_pc: no_update_needed_label, - flags: Default::default(), - collation: None, - }); - - emit_update_sqlite_sequence( - &mut program, - resolver.schema, - seq_cursor_id, - r_seq_rowid, - table_name_reg, - insertion.key_register(), - )?; - - program.preassign_label_to_next_insn(no_update_needed_label); - program.emit_insn(Insn::Close { - cursor_id: seq_cursor_id, - }); + match on_conflict { + ResolveType::Ignore => { + upsert.replace(Box::new(ast::Upsert { + do_clause: UpsertDo::Nothing, + index: None, + next: None, + })); + } + ResolveType::Abort => { + // This is the default conflict resolution strategy for INSERT in SQLite. + } + _ => { + crate::bail_parse_error!( + "INSERT OR {} is only supported with UPSERT", + on_conflict.to_string() + ); + } } - - // Emit update in the CDC table if necessary (after the INSERT updated the table) - if let Some((cdc_cursor_id, _)) = &cdc_table { - let cdc_has_after = program.capture_data_changes_mode().has_after(); - let after_record_reg = if cdc_has_after { - Some(emit_cdc_patch_record( - &mut program, - &table, - insertion.first_col_register(), - insertion.record_register(), - insertion.key_register(), - )) - } else { - None - }; - emit_cdc_insns( - &mut program, - resolver, - OperationMode::INSERT, - *cdc_cursor_id, - insertion.key_register(), - None, - after_record_reg, - None, - table_name.as_str(), - )?; - } - - // Emit RETURNING results if specified - if !result_columns.is_empty() { - let value_registers = ReturningValueRegisters { - rowid_register: insertion.key_register(), - columns_start_register: insertion.first_col_register(), - num_columns: table.columns().len(), - }; - - emit_returning_results(&mut program, &result_columns, &value_registers)?; - } - program.emit_insn(Insn::Goto { - target_pc: row_done_label, - }); - - for (_, label, mut upsert) in upsert_actions { - program.preassign_label_to_next_insn(label); - + while let Some(mut upsert_opt) = upsert.take() { if let UpsertDo::Set { ref mut sets, ref mut where_clause, - } = upsert.do_clause + } = &mut upsert_opt.do_clause { - // Normalize SET pairs once - let mut rewritten_sets = collect_set_clauses_for_upsert(&table, sets)?; - - emit_upsert( - &mut program, - &table, - &insertion, - cursor_id, - conflict_rowid_reg, - &mut rewritten_sets, - where_clause, - resolver, - &idx_cursors, - &mut result_columns, - cdc_table.as_ref().map(|c| c.0), - row_done_label, - )?; - } else { - // UpsertDo::Nothing case - program.emit_insn(Insn::Goto { - target_pc: row_done_label, - }); + for set in sets.iter_mut() { + bind_and_rewrite_expr( + &mut set.expr, + None, + None, + connection, + &mut program.param_ctx, + BindingBehavior::AllowUnboundIdentifiers, + )?; + } + if let Some(ref mut where_expr) = where_clause { + bind_and_rewrite_expr( + where_expr, + None, + None, + connection, + &mut program.param_ctx, + BindingBehavior::AllowUnboundIdentifiers, + )?; + } } + let next = upsert_opt.next.take(); + upsert_actions.push(( + // resolve the constrained target for UPSERT in the chain + resolve_upsert_target(resolver.schema, table, &upsert_opt)?, + program.allocate_label(), + upsert_opt, + )); + upsert = next; } + Ok(BoundInsertResult { + values, + upsert_actions, + inserting_multiple_rows, + }) +} - if inserting_multiple_rows { - if let Some(temp_table_ctx) = temp_table_ctx { - program.resolve_label(row_done_label, program.offset()); +/// Depending on the InsertBody, we begin to initialize the source of the insert values +/// into registers using the following methods: +/// +/// Values with a single row, expressions are directly evaluated into registers, so nothing +/// is emitted here, we simply allocate the cursor ID and store the arity. +/// +/// Values with multiple rows, we use a coroutine to yield each row into registers directly. +/// +/// Select, we use a coroutine to yield each row from the SELECT into registers, +/// materializing into a temporary table if the target table is also read by the SELECT. +/// +/// For DefaultValues, we allocate the cursor and extend the empty values vector with either the +/// default expressions registered for the columns, or NULLs, so they can be translated into +/// registers later. +#[allow(clippy::too_many_arguments, clippy::vec_box)] +fn init_source_emission<'a>( + mut program: ProgramBuilder, + table: &Table, + connection: &Arc, + ctx: &mut InsertEmitCtx<'a>, + resolver: &Resolver, + values: &mut Vec>, + body: InsertBody, + columns: &'a [ast::Name], +) -> Result { + let (num_values, cursor_id) = match body { + InsertBody::Select(select, _) => { + // Simple Common case of INSERT INTO
VALUES (...) + if matches!(&select.body.select, OneSelect::Values(values) if values.len() <= 1) { + ( + values.len(), + program.alloc_cursor_id(CursorType::BTreeTable(ctx.table.clone())), + ) + } else { + // Multiple rows - use coroutine for value population + let yield_reg = program.alloc_register(); + let jump_on_definition_label = program.allocate_label(); + let start_offset_label = program.allocate_label(); + program.emit_insn(Insn::InitCoroutine { + yield_reg, + jump_on_definition: jump_on_definition_label, + start_offset: start_offset_label, + }); + program.preassign_label_to_next_insn(start_offset_label); - program.emit_insn(Insn::Next { - cursor_id: temp_table_ctx.cursor_id, - pc_if_next: temp_table_ctx.loop_start_label, - }); - program.preassign_label_to_next_insn(temp_table_ctx.loop_end_label); + let query_destination = QueryDestination::CoroutineYield { + yield_reg, + coroutine_implementation_start: ctx.halt_label, + }; + program.incr_nesting(); + let result = + translate_select(select, resolver, program, query_destination, connection)?; + program = result.program; + program.decr_nesting(); - program.emit_insn(Insn::Close { - cursor_id: temp_table_ctx.cursor_id, - }); - } else { - // For multiple rows which not require a temp table, loop back - program.resolve_label(row_done_label, program.offset()); - program.emit_insn(Insn::Goto { - target_pc: loop_start_label, - }); + program.emit_insn(Insn::EndCoroutine { yield_reg }); + program.preassign_label_to_next_insn(jump_on_definition_label); + + let cursor_id = program.alloc_cursor_id(CursorType::BTreeTable(ctx.table.clone())); + + // From SQLite + /* Set useTempTable to TRUE if the result of the SELECT statement + ** should be written into a temporary table (template 4). Set to + ** FALSE if each output row of the SELECT can be written directly into + ** the destination table (template 3). + ** + ** A temp table must be used if the table being updated is also one + ** of the tables being read by the SELECT statement. Also use a + ** temp table in the case of row triggers. + */ + if program.is_table_open(table) { + let temp_cursor_id = + program.alloc_cursor_id(CursorType::BTreeTable(ctx.table.clone())); + ctx.temp_table_ctx = Some(TempTableCtx { + cursor_id: temp_cursor_id, + loop_start_label: program.allocate_label(), + loop_end_label: program.allocate_label(), + }); + + program.emit_insn(Insn::OpenEphemeral { + cursor_id: temp_cursor_id, + is_table: true, + }); + + // Main loop + program.preassign_label_to_next_insn(ctx.loop_start_label); + let yield_label = program.allocate_label(); + program.emit_insn(Insn::Yield { + yield_reg, + end_offset: yield_label, // stays local, we’ll route at loop end + }); + + let record_reg = program.alloc_register(); + let affinity_str = if columns.is_empty() { + ctx.table + .columns + .iter() + .filter(|col| !col.hidden) + .map(|col| col.affinity().aff_mask()) + .collect::() + } else { + columns + .iter() + .map(|col_name| { + let column_name = normalize_ident(col_name.as_str()); + if ROWID_STRS + .iter() + .any(|s| s.eq_ignore_ascii_case(&column_name)) + { + return Affinity::Integer.aff_mask(); + } + table + .get_column_by_name(&column_name) + .unwrap() + .1 + .affinity() + .aff_mask() + }) + .collect::() + }; + + program.emit_insn(Insn::MakeRecord { + start_reg: program.reg_result_cols_start.unwrap_or(yield_reg + 1), + count: result.num_result_cols, + dest_reg: record_reg, + index_name: None, + affinity_str: Some(affinity_str), + }); + + let rowid_reg = program.alloc_register(); + program.emit_insn(Insn::NewRowid { + cursor: temp_cursor_id, + rowid_reg, + prev_largest_reg: 0, + }); + program.emit_insn(Insn::Insert { + cursor: temp_cursor_id, + key_reg: rowid_reg, + record_reg, + // since we are not doing an Insn::NewRowid or an Insn::NotExists here, we need to seek to ensure the insertion happens in the correct place. + flag: InsertFlags::new().require_seek(), + table_name: "".to_string(), + }); + // loop back + program.emit_insn(Insn::Goto { + target_pc: ctx.loop_start_label, + }); + program.preassign_label_to_next_insn(yield_label); + + program.emit_insn(Insn::OpenWrite { + cursor_id, + root_page: RegisterOrLiteral::Literal(ctx.table.root_page), + db: 0, + }); + } else { + program.emit_insn(Insn::OpenWrite { + cursor_id, + root_page: RegisterOrLiteral::Literal(ctx.table.root_page), + db: 0, + }); + + program.preassign_label_to_next_insn(ctx.loop_start_label); + + // on EOF, jump to select_exhausted to check FK constraints + let select_exhausted = program.allocate_label(); + ctx.select_exhausted_label = Some(select_exhausted); + program.emit_insn(Insn::Yield { + yield_reg, + end_offset: select_exhausted, + }); + } + + ctx.yield_reg_opt = Some(yield_reg); + (result.num_result_cols, cursor_id) + } } - } else { - program.resolve_label(row_done_label, program.offset()); - } - - program.resolve_label(halt_label, program.offset()); - + InsertBody::DefaultValues => { + let num_values = table.columns().len(); + values.extend(table.columns().iter().map(|c| { + c.default + .clone() + .unwrap_or(Box::new(ast::Expr::Literal(ast::Literal::Null))) + })); + ( + num_values, + program.alloc_cursor_id(CursorType::BTreeTable(ctx.table.clone())), + ) + } + }; + ctx.num_values = num_values; + ctx.cursor_id = cursor_id; Ok(program) } +pub struct AutoincMeta { + seq_cursor_id: usize, + r_seq: usize, + r_seq_rowid: usize, + table_name_reg: usize, +} + pub const ROWID_COLUMN: &Column = &Column { name: None, ty: schema::Type::Integer, @@ -1556,6 +1618,244 @@ fn translate_column( Ok(()) } +// Preflight phase: evaluate each applicable UNIQUE constraint and probe with NoConflict. +// If any probe hits: +// DO NOTHING -> jump to row_done_label. +// +// DO UPDATE (matching target) -> fetch conflicting rowid and jump to `upsert_entry`. +// +// otherwise, raise SQLITE_CONSTRAINT_UNIQUE +fn emit_preflight_constraint_checks( + program: &mut ProgramBuilder, + ctx: &InsertEmitCtx, + resolver: &Resolver, + insertion: &Insertion, + upsert_actions: &[(ResolvedUpsertTarget, BranchOffset, Box)], + constraints: &ConstraintsToCheck, +) -> Result<()> { + for (constraint, position) in &constraints.constraints_to_check { + match constraint { + ResolvedUpsertTarget::PrimaryKey => { + let make_record_label = program.allocate_label(); + program.emit_insn(Insn::NotExists { + cursor: ctx.cursor_id, + rowid_reg: insertion.key_register(), + target_pc: make_record_label, + }); + let rowid_column_name = insertion.key.column_name(); + + // Conflict on rowid: attempt to route through UPSERT if it targets the PK, otherwise raise constraint. + // emit Halt for every case *except* when upsert handles the conflict + 'emit_halt: { + if let Some(position) = position.or(constraints.upsert_catch_all_position) { + // PK conflict: the conflicting rowid is exactly the attempted key + program.emit_insn(Insn::Copy { + src_reg: insertion.key_register(), + dst_reg: ctx.conflict_rowid_reg, + extra_amount: 0, + }); + program.emit_insn(Insn::Goto { + target_pc: upsert_actions[position].1, + }); + break 'emit_halt; + } + let mut description = + String::with_capacity(ctx.table.name.len() + rowid_column_name.len() + 2); + description.push_str(ctx.table.name.as_str()); + description.push('.'); + description.push_str(rowid_column_name); + program.emit_insn(Insn::Halt { + err_code: SQLITE_CONSTRAINT_PRIMARYKEY, + description, + }); + } + program.preassign_label_to_next_insn(make_record_label); + } + ResolvedUpsertTarget::Index(index) => { + let column_mappings = index + .columns + .iter() + .map(|idx_col| insertion.get_col_mapping_by_name(&idx_col.name)); + // find which cursor we opened earlier for this index + let idx_cursor_id = ctx + .idx_cursors + .iter() + .find(|(name, _, _)| *name == &index.name) + .map(|(_, _, c_id)| *c_id) + .expect("no cursor found for index"); + + let maybe_skip_probe_label = if let Some(where_clause) = &index.where_clause { + let mut where_for_eval = where_clause.as_ref().clone(); + rewrite_partial_index_where(&mut where_for_eval, insertion)?; + let reg = program.alloc_register(); + translate_expr_no_constant_opt( + program, + Some(&TableReferences::new_empty()), + &where_for_eval, + reg, + resolver, + NoConstantOptReason::RegisterReuse, + )?; + let lbl = program.allocate_label(); + program.emit_insn(Insn::IfNot { + reg, + target_pc: lbl, + jump_if_null: true, + }); + Some(lbl) + } else { + None + }; + + let num_cols = index.columns.len(); + // allocate scratch registers for the index columns plus rowid + let idx_start_reg = program.alloc_registers(num_cols + 1); + + // build unpacked key [idx_start_reg .. idx_start_reg+num_cols-1], and rowid in last reg, + // copy each index column from the table's column registers into these scratch regs + for (i, column_mapping) in column_mappings.clone().enumerate() { + // copy from the table's column register over to the index's scratch register + let Some(col_mapping) = column_mapping else { + return Err(crate::LimboError::PlanningError( + "Column not found in INSERT".to_string(), + )); + }; + program.emit_insn(Insn::Copy { + src_reg: col_mapping.register, + dst_reg: idx_start_reg + i, + extra_amount: 0, + }); + } + // last register is the rowid + program.emit_insn(Insn::Copy { + src_reg: insertion.key_register(), + dst_reg: idx_start_reg + num_cols, + extra_amount: 0, + }); + + if index.unique { + let aff = index + .columns + .iter() + .map(|ic| ctx.table.columns[ic.pos_in_table].affinity().aff_mask()) + .collect::(); + program.emit_insn(Insn::Affinity { + start_reg: idx_start_reg, + count: NonZeroUsize::new(num_cols).expect("nonzero col count"), + affinities: aff, + }); + + if !upsert_actions.is_empty() { + let next_check = program.allocate_label(); + program.emit_insn(Insn::NoConflict { + cursor_id: idx_cursor_id, + target_pc: next_check, + record_reg: idx_start_reg, + num_regs: num_cols, + }); + + // Conflict detected, figure out if this UPSERT handles the conflict + if let Some(position) = position.or(constraints.upsert_catch_all_position) { + match &upsert_actions[position].2.do_clause { + UpsertDo::Nothing => { + // Bail out without writing anything + program.emit_insn(Insn::Goto { + target_pc: ctx.row_done_label, + }); + } + UpsertDo::Set { .. } => { + // Route to DO UPDATE: capture conflicting rowid then jump + program.emit_insn(Insn::IdxRowId { + cursor_id: idx_cursor_id, + dest: ctx.conflict_rowid_reg, + }); + program.emit_insn(Insn::Goto { + target_pc: upsert_actions[position].1, + }); + } + } + } + // No matching UPSERT handler so we emit constraint error + // (if conflict clause matched - VM will jump to later instructions and skip halt) + program.emit_insn(Insn::Halt { + err_code: SQLITE_CONSTRAINT_UNIQUE, + description: format_unique_violation_desc( + ctx.table.name.as_str(), + index, + ), + }); + + // continue preflight with next constraint + program.preassign_label_to_next_insn(next_check); + } else { + // No UPSERT fast-path: probe and immediately insert + let ok = program.allocate_label(); + program.emit_insn(Insn::NoConflict { + cursor_id: idx_cursor_id, + target_pc: ok, + record_reg: idx_start_reg, + num_regs: num_cols, + }); + // Unique violation without ON CONFLICT clause -> error + program.emit_insn(Insn::Halt { + err_code: SQLITE_CONSTRAINT_UNIQUE, + description: format_unique_violation_desc( + ctx.table.name.as_str(), + index, + ), + }); + program.preassign_label_to_next_insn(ok); + + // In the non-UPSERT case, we insert the index + let record_reg = program.alloc_register(); + program.emit_insn(Insn::MakeRecord { + start_reg: idx_start_reg, + count: num_cols + 1, + dest_reg: record_reg, + index_name: Some(index.name.clone()), + affinity_str: None, + }); + program.emit_insn(Insn::IdxInsert { + cursor_id: idx_cursor_id, + record_reg, + unpacked_start: Some(idx_start_reg), + unpacked_count: Some((num_cols + 1) as u16), + flags: IdxInsertFlags::new().nchange(true), + }); + } + } else { + // Non-unique index: in UPSERT mode we postpone writes to commit phase. + if upsert_actions.is_empty() { + // eager insert for non-unique, no UPSERT + let record_reg = program.alloc_register(); + program.emit_insn(Insn::MakeRecord { + start_reg: idx_start_reg, + count: num_cols + 1, + dest_reg: record_reg, + index_name: Some(index.name.clone()), + affinity_str: None, + }); + program.emit_insn(Insn::IdxInsert { + cursor_id: idx_cursor_id, + record_reg, + unpacked_start: Some(idx_start_reg), + unpacked_count: Some((num_cols + 1) as u16), + flags: IdxInsertFlags::new().nchange(true), + }); + } + } + + // Close the partial-index skip (preflight) + if let Some(lbl) = maybe_skip_probe_label { + program.resolve_label(lbl, program.offset()); + } + } + ResolvedUpsertTarget::CatchAll => unreachable!(), + } + } + Ok(()) +} + // TODO: comeback here later to apply the same improvements on select fn translate_virtual_table_insert( mut program: ProgramBuilder, @@ -1786,6 +2086,52 @@ pub fn rewrite_partial_index_where( ) } +struct ConstraintsToCheck { + constraints_to_check: Vec<(ResolvedUpsertTarget, Option)>, + upsert_catch_all_position: Option, +} + +fn build_constraints_to_check( + resolver: &Resolver, + table_name: &str, + upsert_actions: &[(ResolvedUpsertTarget, BranchOffset, Box)], + has_user_provided_rowid: bool, +) -> ConstraintsToCheck { + let mut constraints_to_check = Vec::new(); + if has_user_provided_rowid { + // Check uniqueness constraint for rowid if it was provided by user. + // When the DB allocates it there are no need for separate uniqueness checks. + let position = upsert_actions + .iter() + .position(|(target, ..)| matches!(target, ResolvedUpsertTarget::PrimaryKey)); + constraints_to_check.push((ResolvedUpsertTarget::PrimaryKey, position)); + } + for index in resolver.schema.get_indices(table_name) { + let position = upsert_actions + .iter() + .position(|(target, ..)| matches!(target, ResolvedUpsertTarget::Index(x) if Arc::ptr_eq(x, index))); + constraints_to_check.push((ResolvedUpsertTarget::Index(index.clone()), position)); + } + + constraints_to_check.sort_by(|(_, p1), (_, p2)| match (p1, p2) { + (Some(p1), Some(p2)) => p1.cmp(p2), + (Some(_), None) => std::cmp::Ordering::Less, + (None, Some(_)) => std::cmp::Ordering::Greater, + (None, None) => std::cmp::Ordering::Equal, + }); + + let upsert_catch_all_position = + if let Some((ResolvedUpsertTarget::CatchAll, ..)) = upsert_actions.last() { + Some(upsert_actions.len() - 1) + } else { + None + }; + ConstraintsToCheck { + constraints_to_check, + upsert_catch_all_position, + } +} + fn emit_update_sqlite_sequence( program: &mut ProgramBuilder, schema: &Schema, @@ -1857,3 +2203,385 @@ fn emit_update_sqlite_sequence( Ok(()) } + +/// Child-side FK checks for INSERT of a single row: +/// For each outgoing FK on `child_tbl`, if the NEW tuple's FK columns are all non-NULL, +/// verify that the referenced parent key exists. +pub fn emit_fk_child_insert_checks( + program: &mut ProgramBuilder, + resolver: &Resolver, + child_tbl: &BTreeTable, + new_start_reg: usize, + new_rowid_reg: usize, +) -> crate::Result<()> { + for fk_ref in resolver.schema.resolved_fks_for_child(&child_tbl.name)? { + let is_self_ref = fk_ref.fk.parent_table.eq_ignore_ascii_case(&child_tbl.name); + + // Short-circuit if any NEW component is NULL + let fk_ok = program.allocate_label(); + for cname in &fk_ref.child_cols { + let (i, col) = child_tbl.get_column(cname).unwrap(); + let src = if col.is_rowid_alias { + new_rowid_reg + } else { + new_start_reg + i + }; + program.emit_insn(Insn::IsNull { + reg: src, + target_pc: fk_ok, + }); + } + let parent_tbl = resolver + .schema + .get_btree_table(&fk_ref.fk.parent_table) + .expect("parent btree"); + if fk_ref.parent_uses_rowid { + let pcur = open_read_table(program, &parent_tbl); + + // first child col carries rowid + let (i_child, col_child) = child_tbl.get_column(&fk_ref.child_cols[0]).unwrap(); + let val_reg = if col_child.is_rowid_alias { + new_rowid_reg + } else { + new_start_reg + i_child + }; + + // Normalize rowid to integer for both the probe and the same-row fast path. + let tmp = program.alloc_register(); + program.emit_insn(Insn::Copy { + src_reg: val_reg, + dst_reg: tmp, + extra_amount: 0, + }); + program.emit_insn(Insn::MustBeInt { reg: tmp }); + + // If this is a self-reference *and* the child FK equals NEW rowid, + // the constraint will be satisfied once this row is inserted + if is_self_ref { + program.emit_insn(Insn::Eq { + lhs: tmp, + rhs: new_rowid_reg, + target_pc: fk_ok, + flags: CmpInsFlags::default(), + collation: None, + }); + } + + let violation = program.allocate_label(); + program.emit_insn(Insn::NotExists { + cursor: pcur, + rowid_reg: tmp, + target_pc: violation, + }); + program.emit_insn(Insn::Close { cursor_id: pcur }); + program.emit_insn(Insn::Goto { target_pc: fk_ok }); + + // Missing parent: immediate vs deferred as usual + program.preassign_label_to_next_insn(violation); + program.emit_insn(Insn::Close { cursor_id: pcur }); + emit_fk_violation(program, &fk_ref.fk)?; + program.preassign_label_to_next_insn(fk_ok); + } else { + let idx = fk_ref + .parent_unique_index + .as_ref() + .expect("parent unique index required"); + let icur = open_read_index(program, idx); + let ncols = fk_ref.child_cols.len(); + + // Build NEW child probe from child NEW values, apply parent-index affinities. + let probe = { + let start = program.alloc_registers(ncols); + for (k, cname) in fk_ref.child_cols.iter().enumerate() { + let (i, col) = child_tbl.get_column(cname).unwrap(); + program.emit_insn(Insn::Copy { + src_reg: if col.is_rowid_alias { + new_rowid_reg + } else { + new_start_reg + i + }, + dst_reg: start + k, + extra_amount: 0, + }); + } + if let Some(cnt) = NonZeroUsize::new(ncols) { + program.emit_insn(Insn::Affinity { + start_reg: start, + count: cnt, + affinities: build_index_affinity_string(idx, &parent_tbl), + }); + } + start + }; + if is_self_ref { + // Determine the parent column order to compare against: + let parent_cols: Vec<&str> = + idx.columns.iter().map(|ic| ic.name.as_str()).collect(); + + // Build new parent-key image from this same row’s new values, in the index order. + let parent_new = program.alloc_registers(ncols); + for (i, pname) in parent_cols.iter().enumerate() { + let (pos, col) = child_tbl.get_column(pname).unwrap(); + program.emit_insn(Insn::Copy { + src_reg: if col.is_rowid_alias { + new_rowid_reg + } else { + new_start_reg + pos + }, + dst_reg: parent_new + i, + extra_amount: 0, + }); + } + if let Some(cnt) = NonZeroUsize::new(ncols) { + program.emit_insn(Insn::Affinity { + start_reg: parent_new, + count: cnt, + affinities: build_index_affinity_string(idx, &parent_tbl), + }); + } + + // Compare child probe to NEW parent image column-by-column. + let mismatch = program.allocate_label(); + for i in 0..ncols { + let cont = program.allocate_label(); + program.emit_insn(Insn::Eq { + lhs: probe + i, + rhs: parent_new + i, + target_pc: cont, + flags: CmpInsFlags::default().jump_if_null(), + collation: Some(super::collate::CollationSeq::Binary), + }); + program.emit_insn(Insn::Goto { + target_pc: mismatch, + }); + program.preassign_label_to_next_insn(cont); + } + // All equal: same-row OK + program.emit_insn(Insn::Goto { target_pc: fk_ok }); + program.preassign_label_to_next_insn(mismatch); + } + index_probe( + program, + icur, + probe, + ncols, + // on_found: parent exists, FK satisfied + |_p| Ok(()), + // on_not_found: behave like a normal FK + |p| { + emit_fk_violation(p, &fk_ref.fk)?; + Ok(()) + }, + )?; + program.emit_insn(Insn::Goto { target_pc: fk_ok }); + program.preassign_label_to_next_insn(fk_ok); + } + } + Ok(()) +} + +/// Build NEW parent key image in FK parent-column order into a contiguous register block. +/// Handles 3 shapes: +/// - parent_uses_rowid: single "rowid" component +/// - explicit fk.parent_columns +/// - fk.parent_columns empty => use parent's declared PK columns (order-preserving) +fn build_parent_key_image_for_insert( + program: &mut ProgramBuilder, + parent_table: &BTreeTable, + pref: &ResolvedFkRef, + insertion: &Insertion, +) -> crate::Result<(usize, usize)> { + // Decide column list + let parent_cols: Vec = if pref.parent_uses_rowid { + vec!["rowid".to_string()] + } else if !pref.fk.parent_columns.is_empty() { + pref.fk.parent_columns.clone() + } else { + // fall back to the declared PK of the parent table, in schema order + parent_table + .primary_key_columns + .iter() + .map(|(n, _)| n.clone()) + .collect() + }; + + let ncols = parent_cols.len(); + let start = program.alloc_registers(ncols); + // Copy from the would-be parent insertion + for (i, pname) in parent_cols.iter().enumerate() { + let src = if pname.eq_ignore_ascii_case("rowid") { + insertion.key_register() + } else { + // For rowid-alias parents, get_col_mapping_by_name will return the key mapping, + // not the NULL placeholder in col_mappings. + insertion + .get_col_mapping_by_name(pname) + .ok_or_else(|| { + crate::LimboError::PlanningError(format!( + "Column '{}' not present in INSERT image for parent {}", + pname, parent_table.name + )) + })? + .register + }; + program.emit_insn(Insn::Copy { + src_reg: src, + dst_reg: start + i, + extra_amount: 0, + }); + } + + // Apply affinities of the parent columns (or integer for rowid) + let aff: String = if pref.parent_uses_rowid { + "i".to_string() + } else { + parent_cols + .iter() + .map(|name| { + let (_, col) = parent_table.get_column(name).ok_or_else(|| { + crate::LimboError::InternalError(format!("parent col {name} missing")) + })?; + Ok::<_, crate::LimboError>(col.affinity().aff_mask()) + }) + .collect::>()? + }; + if let Some(count) = NonZeroUsize::new(ncols) { + program.emit_insn(Insn::Affinity { + start_reg: start, + count, + affinities: aff, + }); + } + + Ok((start, ncols)) +} + +/// Parent-side: when inserting into the parent, decrement the counter +/// if any child rows reference the NEW parent key. +/// We *always* do this for deferred FKs, and we *also* do it for +/// self-referential FKs (even if immediate) because the insert can +/// “repair” a prior child-insert count recorded earlier in the same statement. +pub fn emit_parent_side_fk_decrement_on_insert( + program: &mut ProgramBuilder, + resolver: &Resolver, + parent_table: &BTreeTable, + insertion: &Insertion, +) -> crate::Result<()> { + for pref in resolver + .schema + .resolved_fks_referencing(&parent_table.name)? + { + let is_self_ref = pref + .child_table + .name + .eq_ignore_ascii_case(&parent_table.name); + // Skip only when it cannot repair anything: non-deferred and not self-referencing + if !pref.fk.deferred && !is_self_ref { + continue; + } + let (new_pk_start, n_cols) = + build_parent_key_image_for_insert(program, parent_table, &pref, insertion)?; + + let child_tbl = &pref.child_table; + let child_cols = &pref.fk.child_columns; + let idx = resolver.schema.get_indices(&child_tbl.name).find(|ix| { + ix.columns.len() == child_cols.len() + && ix + .columns + .iter() + .zip(child_cols.iter()) + .all(|(ic, cc)| ic.name.eq_ignore_ascii_case(cc)) + }); + + if let Some(ix) = idx { + let icur = open_read_index(program, ix); + // Copy key into probe regs and apply child-index affinities + let probe_start = program.alloc_registers(n_cols); + for i in 0..n_cols { + program.emit_insn(Insn::Copy { + src_reg: new_pk_start + i, + dst_reg: probe_start + i, + extra_amount: 0, + }); + } + if let Some(count) = NonZeroUsize::new(n_cols) { + program.emit_insn(Insn::Affinity { + start_reg: probe_start, + count, + affinities: build_index_affinity_string(ix, child_tbl), + }); + } + + let found = program.allocate_label(); + program.emit_insn(Insn::Found { + cursor_id: icur, + target_pc: found, + record_reg: probe_start, + num_regs: n_cols, + }); + + // Not found, nothing to decrement + program.emit_insn(Insn::Close { cursor_id: icur }); + let skip = program.allocate_label(); + program.emit_insn(Insn::Goto { target_pc: skip }); + + // Found: guarded counter decrement + program.resolve_label(found, program.offset()); + program.emit_insn(Insn::Close { cursor_id: icur }); + emit_guarded_fk_decrement(program, skip); + program.resolve_label(skip, program.offset()); + } else { + // fallback scan :( + let ccur = open_read_table(program, child_tbl); + let done = program.allocate_label(); + program.emit_insn(Insn::Rewind { + cursor_id: ccur, + pc_if_empty: done, + }); + let loop_top = program.allocate_label(); + let next_row = program.allocate_label(); + program.resolve_label(loop_top, program.offset()); + + for (i, child_name) in child_cols.iter().enumerate() { + let (pos, _) = child_tbl.get_column(child_name).ok_or_else(|| { + crate::LimboError::InternalError(format!("child col {child_name} missing")) + })?; + let tmp = program.alloc_register(); + program.emit_insn(Insn::Column { + cursor_id: ccur, + column: pos, + dest: tmp, + default: None, + }); + + program.emit_insn(Insn::IsNull { + reg: tmp, + target_pc: next_row, + }); + + let cont = program.allocate_label(); + program.emit_insn(Insn::Eq { + lhs: tmp, + rhs: new_pk_start + i, + target_pc: cont, + flags: CmpInsFlags::default().jump_if_null(), + collation: Some(super::collate::CollationSeq::Binary), + }); + program.emit_insn(Insn::Goto { + target_pc: next_row, + }); + program.resolve_label(cont, program.offset()); + } + // Matched one child row: guarded decrement of counter + emit_guarded_fk_decrement(program, next_row); + program.resolve_label(next_row, program.offset()); + program.emit_insn(Insn::Next { + cursor_id: ccur, + pc_if_next: loop_top, + }); + program.resolve_label(done, program.offset()); + program.emit_insn(Insn::Close { cursor_id: ccur }); + } + } + Ok(()) +} diff --git a/core/translate/integrity_check.rs b/core/translate/integrity_check.rs index fc750fef8..a2bafe3f0 100644 --- a/core/translate/integrity_check.rs +++ b/core/translate/integrity_check.rs @@ -16,6 +16,11 @@ pub fn translate_integrity_check( for table in schema.tables.values() { if let crate::schema::Table::BTree(table) = table.as_ref() { root_pages.push(table.root_page); + if let Some(indexes) = schema.indexes.get(table.name.as_str()) { + for index in indexes.iter() { + root_pages.push(index.root_page); + } + } }; } let message_register = program.alloc_register(); diff --git a/core/translate/logical.rs b/core/translate/logical.rs index 349b5f64b..bdd29972e 100644 --- a/core/translate/logical.rs +++ b/core/translate/logical.rs @@ -2389,6 +2389,7 @@ mod tests { name: "users".to_string(), root_page: 2, primary_key_columns: vec![("id".to_string(), turso_parser::ast::SortOrder::Asc)], + foreign_keys: vec![], columns: vec![ SchemaColumn { name: Some("id".to_string()), @@ -2444,7 +2445,9 @@ mod tests { has_autoincrement: false, unique_sets: vec![], }; - schema.add_btree_table(Arc::new(users_table)); + schema + .add_btree_table(Arc::new(users_table)) + .expect("Test setup: failed to add users table"); // Create orders table let orders_table = BTreeTable { @@ -2505,8 +2508,11 @@ mod tests { is_strict: false, has_autoincrement: false, unique_sets: vec![], + foreign_keys: vec![], }; - schema.add_btree_table(Arc::new(orders_table)); + schema + .add_btree_table(Arc::new(orders_table)) + .expect("Test setup: failed to add orders table"); // Create products table let products_table = BTreeTable { @@ -2567,8 +2573,11 @@ mod tests { is_strict: false, has_autoincrement: false, unique_sets: vec![], + foreign_keys: vec![], }; - schema.add_btree_table(Arc::new(products_table)); + schema + .add_btree_table(Arc::new(products_table)) + .expect("Test setup: failed to add products table"); schema } diff --git a/core/translate/main_loop.rs b/core/translate/main_loop.rs index 0b11d0c3e..32afc1c21 100644 --- a/core/translate/main_loop.rs +++ b/core/translate/main_loop.rs @@ -18,12 +18,14 @@ use super::{ Search, SeekDef, SelectPlan, TableReferences, WhereTerm, }, }; -use crate::translate::{collate::get_collseq_from_expr, window::emit_window_loop_source}; +use crate::translate::{ + collate::get_collseq_from_expr, emitter::UpdateRowSource, window::emit_window_loop_source, +}; use crate::{ schema::{Affinity, Index, IndexColumn, Table}, translate::{ emitter::prepare_cdc_if_necessary, - plan::{DistinctCtx, Distinctness, Scan}, + plan::{DistinctCtx, Distinctness, Scan, SeekKeyComponent}, result_row::emit_select_result, }, types::SeekOp, @@ -129,8 +131,8 @@ pub fn init_loop( ); if matches!( - mode, - OperationMode::INSERT | OperationMode::UPDATE | OperationMode::DELETE + &mode, + OperationMode::INSERT | OperationMode::UPDATE { .. } | OperationMode::DELETE ) { assert!(tables.joined_tables().len() == 1); let changed_table = &tables.joined_tables()[0].table; @@ -202,9 +204,9 @@ pub fn init_loop( } } let (table_cursor_id, index_cursor_id) = - table.open_cursors(program, mode, t_ctx.resolver.schema)?; + table.open_cursors(program, mode.clone(), t_ctx.resolver.schema)?; match &table.op { - Operation::Scan(Scan::BTreeTable { index, .. }) => match (mode, &table.table) { + Operation::Scan(Scan::BTreeTable { index, .. }) => match (&mode, &table.table) { (OperationMode::SELECT, Table::BTree(btree)) => { let root_page = btree.root_page; if let Some(cursor_id) = table_cursor_id { @@ -259,14 +261,28 @@ pub fn init_loop( } } } - (OperationMode::UPDATE, Table::BTree(btree)) => { + (OperationMode::UPDATE(update_mode), Table::BTree(btree)) => { let root_page = btree.root_page; - program.emit_insn(Insn::OpenWrite { - cursor_id: table_cursor_id - .expect("table cursor is always opened in OperationMode::UPDATE"), - root_page: root_page.into(), - db: table.database_id, - }); + match &update_mode { + UpdateRowSource::Normal => { + program.emit_insn(Insn::OpenWrite { + cursor_id: table_cursor_id.expect( + "table cursor is always opened in OperationMode::UPDATE", + ), + root_page: root_page.into(), + db: table.database_id, + }); + } + UpdateRowSource::PrebuiltEphemeralTable { target_table, .. } => { + let target_table_cursor_id = program + .resolve_cursor_id(&CursorKey::table(target_table.internal_id)); + program.emit_insn(Insn::OpenWrite { + cursor_id: target_table_cursor_id, + root_page: target_table.btree().unwrap().root_page.into(), + db: table.database_id, + }); + } + } if let Some(index_cursor_id) = index_cursor_id { program.emit_insn(Insn::OpenWrite { cursor_id: index_cursor_id, @@ -281,7 +297,9 @@ pub fn init_loop( if let Table::Virtual(tbl) = &table.table { let is_write = matches!( mode, - OperationMode::INSERT | OperationMode::UPDATE | OperationMode::DELETE + OperationMode::INSERT + | OperationMode::UPDATE { .. } + | OperationMode::DELETE ); if is_write && tbl.readonly() { return Err(crate::LimboError::ReadOnly); @@ -303,7 +321,7 @@ pub fn init_loop( }); } } - OperationMode::DELETE | OperationMode::UPDATE => { + OperationMode::DELETE | OperationMode::UPDATE { .. } => { let table_cursor_id = table_cursor_id.expect( "table cursor is always opened in OperationMode::DELETE or OperationMode::UPDATE", ); @@ -316,7 +334,7 @@ pub fn init_loop( // For DELETE, we need to open all the indexes for writing // UPDATE opens these in emit_program_for_update() separately - if mode == OperationMode::DELETE { + if matches!(mode, OperationMode::DELETE) { if let Some(indexes) = t_ctx.resolver.schema.indexes.get(table.table.get_name()) { @@ -361,7 +379,7 @@ pub fn init_loop( db: table.database_id, }); } - OperationMode::UPDATE | OperationMode::DELETE => { + OperationMode::UPDATE { .. } | OperationMode::DELETE => { program.emit_insn(Insn::OpenWrite { cursor_id: index_cursor_id .expect("index cursor is always opened in Seek with index"), @@ -388,6 +406,7 @@ pub fn init_loop( jump_if_condition_is_true: false, jump_target_when_true: jump_target, jump_target_when_false: t_ctx.label_main_loop_end.unwrap(), + jump_target_when_null: t_ctx.label_main_loop_end.unwrap(), }; translate_condition_expr(program, tables, &cond.expr, meta, &t_ctx.resolver)?; program.preassign_label_to_next_insn(jump_target); @@ -406,6 +425,7 @@ pub fn open_loop( join_order: &[JoinOrderMember], predicates: &[WhereTerm], temp_cursor_id: Option, + mode: OperationMode, ) -> Result<()> { for (join_index, join) in join_order.iter().enumerate() { let joined_table_index = join.original_idx; @@ -432,7 +452,7 @@ pub fn open_loop( } } - let (table_cursor_id, index_cursor_id) = table.resolve_cursors(program)?; + let (table_cursor_id, index_cursor_id) = table.resolve_cursors(program, mode.clone())?; match &table.op { Operation::Scan(scan) => { @@ -606,7 +626,10 @@ pub fn open_loop( ); }; - let start_reg = program.alloc_registers(seek_def.key.len()); + let max_registers = seek_def + .size(&seek_def.start) + .max(seek_def.size(&seek_def.end)); + let start_reg = program.alloc_registers(max_registers); emit_seek( program, table_references, @@ -710,6 +733,7 @@ fn emit_conditions( jump_if_condition_is_true: false, jump_target_when_true, jump_target_when_false: next, + jump_target_when_null: next, }; translate_condition_expr( program, @@ -724,7 +748,7 @@ fn emit_conditions( Ok(()) } -/// SQLite (and so Limbo) processes joins as a nested loop. +/// SQLite (and so Turso) processes joins as a nested loop. /// The loop may emit rows to various destinations depending on the query: /// - a GROUP BY sorter (grouping is done by sorting based on the GROUP BY keys and aggregating while the GROUP BY keys match) /// - a GROUP BY phase with no sorting (when the rows are already in the order required by the GROUP BY keys) @@ -982,7 +1006,7 @@ pub fn close_loop( t_ctx: &mut TranslateCtx, tables: &TableReferences, join_order: &[JoinOrderMember], - temp_cursor_id: Option, + mode: OperationMode, ) -> Result<()> { // We close the loops for all tables in reverse order, i.e. innermost first. // OPEN t1 @@ -1000,20 +1024,28 @@ pub fn close_loop( .get(table_index) .expect("source has no loop labels"); - let (table_cursor_id, index_cursor_id) = table.resolve_cursors(program)?; + let (table_cursor_id, index_cursor_id) = table.resolve_cursors(program, mode.clone())?; match &table.op { Operation::Scan(scan) => { program.resolve_label(loop_labels.next, program.offset()); match scan { Scan::BTreeTable { iter_dir, .. } => { - let iteration_cursor_id = temp_cursor_id.unwrap_or_else(|| { + let iteration_cursor_id = if let OperationMode::UPDATE( + UpdateRowSource::PrebuiltEphemeralTable { + ephemeral_table_cursor_id, + .. + }, + ) = &mode + { + *ephemeral_table_cursor_id + } else { index_cursor_id.unwrap_or_else(|| { table_cursor_id.expect( "Either ephemeral or index or table cursor must be opened", ) }) - }); + }; if *iter_dir == IterationDirection::Backwards { program.emit_insn(Insn::Prev { cursor_id: iteration_cursor_id, @@ -1050,12 +1082,19 @@ pub fn close_loop( "Subqueries do not support index seeks" ); program.resolve_label(loop_labels.next, program.offset()); - let iteration_cursor_id = temp_cursor_id.unwrap_or_else(|| { - index_cursor_id.unwrap_or_else(|| { - table_cursor_id - .expect("Either ephemeral or index or table cursor must be opened") - }) - }); + let iteration_cursor_id = + if let OperationMode::UPDATE(UpdateRowSource::PrebuiltEphemeralTable { + ephemeral_table_cursor_id, + .. + }) = &mode + { + *ephemeral_table_cursor_id + } else { + index_cursor_id.unwrap_or_else(|| { + table_cursor_id + .expect("Either ephemeral or index or table cursor must be opened") + }) + }; // Rowid equality point lookups are handled with a SeekRowid instruction which does not loop, so there is no need to emit a Next instruction. if !matches!(search, Search::RowidEq { .. }) { let iter_dir = match search { @@ -1146,7 +1185,8 @@ fn emit_seek( seek_index: Option<&Arc>, ) -> Result<()> { let is_index = seek_index.is_some(); - let Some(seek) = seek_def.seek.as_ref() else { + if seek_def.prefix.is_empty() && matches!(seek_def.start.last_component, SeekKeyComponent::None) + { // If there is no seek key, we start from the first or last row of the index, // depending on the iteration direction. // @@ -1196,43 +1236,34 @@ fn emit_seek( }; // We allocated registers for the full index key, but our seek key might not use the full index key. // See [crate::translate::optimizer::build_seek_def] for more details about in which cases we do and don't use the full index key. - for i in 0..seek_def.key.len() { + for (i, key) in seek_def.iter(&seek_def.start).enumerate() { let reg = start_reg + i; - if i >= seek.len { - if seek.null_pad { - program.emit_insn(Insn::Null { - dest: reg, - dest_end: None, - }); - } - } else { - let expr = &seek_def.key[i].0; - translate_expr_no_constant_opt( - program, - Some(tables), - expr, - reg, - &t_ctx.resolver, - NoConstantOptReason::RegisterReuse, - )?; - // If the seek key column is not verifiably non-NULL, we need check whether it is NULL, - // and if so, jump to the loop end. - // This is to avoid returning rows for e.g. SELECT * FROM t WHERE t.x > NULL, - // which would erroneously return all rows from t, as NULL is lower than any non-NULL value in index key comparisons. - if !expr.is_nonnull(tables) { - program.emit_insn(Insn::IsNull { + match key { + SeekKeyComponent::Expr(expr) => { + translate_expr_no_constant_opt( + program, + Some(tables), + expr, reg, - target_pc: loop_end, - }); + &t_ctx.resolver, + NoConstantOptReason::RegisterReuse, + )?; + // If the seek key column is not verifiably non-NULL, we need check whether it is NULL, + // and if so, jump to the loop end. + // This is to avoid returning rows for e.g. SELECT * FROM t WHERE t.x > NULL, + // which would erroneously return all rows from t, as NULL is lower than any non-NULL value in index key comparisons. + if !expr.is_nonnull(tables) { + program.emit_insn(Insn::IsNull { + reg, + target_pc: loop_end, + }); + } } + SeekKeyComponent::None => unreachable!("None component is not possible in iterator"), } } - let num_regs = if seek.null_pad { - seek_def.key.len() - } else { - seek.len - }; - match seek.op { + let num_regs = seek_def.size(&seek_def.start); + match seek_def.start.op { SeekOp::GE { eq_only } => program.emit_insn(Insn::SeekGE { is_index, cursor_id: seek_cursor_id, @@ -1289,7 +1320,7 @@ fn emit_seek_termination( seek_index: Option<&Arc>, ) -> Result<()> { let is_index = seek_index.is_some(); - let Some(termination) = seek_def.termination.as_ref() else { + if seek_def.prefix.is_empty() && matches!(seek_def.end.last_component, SeekKeyComponent::None) { program.preassign_label_to_next_insn(loop_start); // If we will encounter NULLs in the index at the end of iteration (Forward + Desc OR Backward + Asc) // then, we must explicitly stop before them as seek always has some bound condition over indexed column (e.g. c < ?, c >= ?, ...) @@ -1320,46 +1351,23 @@ fn emit_seek_termination( return Ok(()); }; - // How many non-NULL values were used for seeking. - let seek_len = seek_def.seek.as_ref().map_or(0, |seek| seek.len); + // For all index key values apart from the last one, we are guaranteed to use the same values + // as these values were emited from common prefix, so we don't need to emit them again. - // How many values will be used for the termination condition. - let num_regs = if termination.null_pad { - seek_def.key.len() - } else { - termination.len - }; - for i in 0..seek_def.key.len() { - let reg = start_reg + i; - let is_last = i == seek_def.key.len() - 1; - - // For all index key values apart from the last one, we are guaranteed to use the same values - // as were used for the seek, so we don't need to emit them again. - if i < seek_len && !is_last { - continue; - } - // For the last index key value, we need to emit a NULL if the termination condition is NULL-padded. - // See [SeekKey::null_pad] and [crate::translate::optimizer::build_seek_def] for why this is the case. - if i >= termination.len && !termination.null_pad { - continue; - } - if is_last && termination.null_pad { - program.emit_insn(Insn::Null { - dest: reg, - dest_end: None, - }); - // if the seek key is shorter than the termination key, we need to translate the remaining suffix of the termination key. - // if not, we just reuse what was emitted for the seek. - } else if seek_len < termination.len { + let num_regs = seek_def.size(&seek_def.end); + let last_reg = start_reg + seek_def.prefix.len(); + match &seek_def.end.last_component { + SeekKeyComponent::Expr(expr) => { translate_expr_no_constant_opt( program, Some(tables), - &seek_def.key[i].0, - reg, + expr, + last_reg, &t_ctx.resolver, NoConstantOptReason::RegisterReuse, )?; } + SeekKeyComponent::None => {} } program.preassign_label_to_next_insn(loop_start); let mut rowid_reg = None; @@ -1385,7 +1393,7 @@ fn emit_seek_termination( Some(Affinity::Numeric) }; } - match (is_index, termination.op) { + match (is_index, seek_def.end.op) { (true, SeekOp::GE { .. }) => program.emit_insn(Insn::IdxGE { cursor_id: seek_cursor_id, start_reg, diff --git a/core/translate/mod.rs b/core/translate/mod.rs index 690ad7c47..758031544 100644 --- a/core/translate/mod.rs +++ b/core/translate/mod.rs @@ -17,6 +17,7 @@ pub(crate) mod delete; pub(crate) mod display; pub(crate) mod emitter; pub(crate) mod expr; +pub(crate) mod fkeys; pub(crate) mod group_by; pub(crate) mod index; pub(crate) mod insert; @@ -283,17 +284,21 @@ pub fn translate_inner( columns, body, returning, - } => translate_insert( - with, - resolver, - or_conflict, - tbl_name, - columns, - body, - returning, - program, - connection, - )?, + } => { + if with.is_some() { + crate::bail_parse_error!("WITH clause is not supported"); + } + translate_insert( + resolver, + or_conflict, + tbl_name, + columns, + body, + returning, + program, + connection, + )? + } }; // Indicate write operations so that in the epilogue we can emit the correct type of transaction diff --git a/core/translate/optimizer/access_method.rs b/core/translate/optimizer/access_method.rs index 35e5e2718..883ae789c 100644 --- a/core/translate/optimizer/access_method.rs +++ b/core/translate/optimizer/access_method.rs @@ -3,7 +3,9 @@ use std::sync::Arc; use turso_ext::{ConstraintInfo, ConstraintUsage, ResultCode}; use turso_parser::ast::SortOrder; -use crate::translate::optimizer::constraints::{convert_to_vtab_constraint, Constraint}; +use crate::translate::optimizer::constraints::{ + convert_to_vtab_constraint, Constraint, RangeConstraintRef, +}; use crate::{ schema::{Index, Table}, translate::plan::{IterationDirection, JoinOrderMember, JoinedTable}, @@ -12,24 +14,24 @@ use crate::{ }; use super::{ - constraints::{usable_constraints_for_join_order, ConstraintRef, TableConstraints}, + constraints::{usable_constraints_for_join_order, TableConstraints}, cost::{estimate_cost_for_scan_or_seek, Cost, IndexInfo}, order::OrderTarget, }; #[derive(Debug, Clone)] /// Represents a way to access a table. -pub struct AccessMethod<'a> { +pub struct AccessMethod { /// The estimated number of page fetches. /// We are ignoring CPU cost for now. pub cost: Cost, /// Table-type specific access method details. - pub params: AccessMethodParams<'a>, + pub params: AccessMethodParams, } /// Table‑specific details of how an [`AccessMethod`] operates. #[derive(Debug, Clone)] -pub enum AccessMethodParams<'a> { +pub enum AccessMethodParams { BTreeTable { /// The direction of iteration for the access method. /// Typically this is backwards only if it helps satisfy an [OrderTarget]. @@ -39,7 +41,7 @@ pub enum AccessMethodParams<'a> { /// The constraint references that are being used, if any. /// An empty list of constraint refs means a scan (full table or index); /// a non-empty list means a search. - constraint_refs: &'a [ConstraintRef], + constraint_refs: Vec, }, VirtualTable { /// Index identifier returned by the table's `best_index` method. @@ -57,13 +59,13 @@ pub enum AccessMethodParams<'a> { } /// Return the best [AccessMethod] for a given join order. -pub fn find_best_access_method_for_join_order<'a>( +pub fn find_best_access_method_for_join_order( rhs_table: &JoinedTable, - rhs_constraints: &'a TableConstraints, + rhs_constraints: &TableConstraints, join_order: &[JoinOrderMember], maybe_order_target: Option<&OrderTarget>, input_cardinality: f64, -) -> Result>> { +) -> Result> { match &rhs_table.table { Table::BTree(_) => find_best_access_method_for_btree( rhs_table, @@ -85,19 +87,19 @@ pub fn find_best_access_method_for_join_order<'a>( } } -fn find_best_access_method_for_btree<'a>( +fn find_best_access_method_for_btree( rhs_table: &JoinedTable, - rhs_constraints: &'a TableConstraints, + rhs_constraints: &TableConstraints, join_order: &[JoinOrderMember], maybe_order_target: Option<&OrderTarget>, input_cardinality: f64, -) -> Result>> { +) -> Result> { let table_no = join_order.last().unwrap().table_id; let mut best_cost = estimate_cost_for_scan_or_seek(None, &[], &[], input_cardinality); let mut best_params = AccessMethodParams::BTreeTable { iter_dir: IterationDirection::Forwards, index: None, - constraint_refs: &[], + constraint_refs: vec![], }; let rowid_column_idx = rhs_table.columns().iter().position(|c| c.is_rowid_alias); @@ -123,7 +125,7 @@ fn find_best_access_method_for_btree<'a>( let cost = estimate_cost_for_scan_or_seek( Some(index_info), &rhs_constraints.constraints, - usable_constraint_refs, + &usable_constraint_refs, input_cardinality, ); @@ -192,12 +194,12 @@ fn find_best_access_method_for_btree<'a>( })) } -fn find_best_access_method_for_vtab<'a>( +fn find_best_access_method_for_vtab( vtab: &VirtualTable, constraints: &[Constraint], join_order: &[JoinOrderMember], input_cardinality: f64, -) -> Result>> { +) -> Result> { let vtab_constraints = convert_to_vtab_constraint(constraints, join_order); // TODO: get proper order_by information to pass to the vtab. diff --git a/core/translate/optimizer/constraints.rs b/core/translate/optimizer/constraints.rs index 62d77f3c9..486c36687 100644 --- a/core/translate/optimizer/constraints.rs +++ b/core/translate/optimizer/constraints.rs @@ -67,17 +67,17 @@ pub enum BinaryExprSide { } impl Constraint { - /// Get the constraining expression, e.g. '2+3' from 't.x = 2+3' - pub fn get_constraining_expr(&self, where_clause: &[WhereTerm]) -> ast::Expr { + /// Get the constraining expression and operator, e.g. ('>=', '2+3') from 't.x >= 2+3' + pub fn get_constraining_expr(&self, where_clause: &[WhereTerm]) -> (ast::Operator, ast::Expr) { let (idx, side) = self.where_clause_pos; let where_term = &where_clause[idx]; let Ok(Some((lhs, _, rhs))) = as_binary_components(&where_term.expr) else { panic!("Expected a valid binary expression"); }; if side == BinaryExprSide::Lhs { - lhs.clone() + (self.operator, lhs.clone()) } else { - rhs.clone() + (self.operator, rhs.clone()) } } @@ -108,19 +108,6 @@ pub struct ConstraintRef { pub sort_order: SortOrder, } -impl ConstraintRef { - /// Convert the constraint to a column usable in a [crate::translate::plan::SeekDef::key]. - pub fn as_seek_key_column( - &self, - constraints: &[Constraint], - where_clause: &[WhereTerm], - ) -> (ast::Expr, SortOrder) { - let constraint = &constraints[self.constraint_vec_pos]; - let constraining_expr = constraint.get_constraining_expr(where_clause); - (constraining_expr, self.sort_order) - } -} - /// A collection of [ConstraintRef]s for a given index, or if index is None, for the table's rowid index. /// For example, given a table `T (x,y,z)` with an index `T_I (y desc,z)`, take the following query: /// ```sql @@ -150,6 +137,7 @@ pub struct ConstraintUseCandidate { /// The index that may be used to satisfy the constraints. If none, the table's rowid index is used. pub index: Option>, /// References to the constraints that may be used as an access path for the index. + /// Refs are sorted by [ConstraintRef::index_col_pos] pub refs: Vec, } @@ -193,6 +181,9 @@ fn estimate_selectivity(column: &Column, op: ast::Operator) -> f64 { /// Precompute all potentially usable [Constraints] from a WHERE clause. /// The resulting list of [TableConstraints] is then used to evaluate the best access methods for various join orders. +/// +/// This method do not perform much filtering of constraints and delegate this tasks to the consumers of the method +/// Consumers must inspect [TableConstraints] and its candidates and pick best constraints for optimized access pub fn constraints_from_where_clause( where_clause: &[WhereTerm], table_references: &TableReferences, @@ -379,24 +370,6 @@ pub fn constraints_from_where_clause( for candidate in cs.candidates.iter_mut() { // Sort by index_col_pos, ascending -- index columns must be consumed in contiguous order. candidate.refs.sort_by_key(|cref| cref.index_col_pos); - // Deduplicate by position, keeping first occurrence (which will be equality if one exists, since the constraints vec is sorted that way) - candidate.refs.dedup_by_key(|cref| cref.index_col_pos); - // Truncate at first gap in positions -- again, index columns must be consumed in contiguous order. - let contiguous_len = candidate - .refs - .iter() - .enumerate() - .take_while(|(i, cref)| cref.index_col_pos == *i) - .count(); - candidate.refs.truncate(contiguous_len); - - // Truncate after the first inequality, since the left-prefix rule of indexes requires that all constraints but the last one must be equalities; - // again see: https://www.solarwinds.com/blog/the-left-prefix-index-rule - if let Some(first_inequality) = candidate.refs.iter().position(|cref| { - cs.constraints[cref.constraint_vec_pos].operator != ast::Operator::Equals - }) { - candidate.refs.truncate(first_inequality + 1); - } } cs.candidates.retain(|c| { if let Some(idx) = &c.index { @@ -413,6 +386,87 @@ pub fn constraints_from_where_clause( Ok(constraints) } +#[derive(Clone, Debug)] +/// A reference to a [Constraint]s in a [TableConstraints] for single column. +/// +/// This is specialized version of [ConstraintRef] which specifically holds range-like constraints: +/// - x = 10 (eq is set) +/// - x >= 10, x > 10 (lower_bound is set) +/// - x <= 10, x < 10 (upper_bound is set) +/// - x > 10 AND x < 20 (both lower_bound and upper_bound are set) +/// +/// eq, lower_bound and upper_bound holds None or position of the constraint in the [Constraint] array +pub struct RangeConstraintRef { + /// position of the column in the table definition + pub table_col_pos: usize, + /// position of the column in the index definition + pub index_col_pos: usize, + /// sort order for the column in the index definition + pub sort_order: SortOrder, + /// equality constraint + pub eq: Option, + /// lower bound constraint (either > or >=) + pub lower_bound: Option, + /// upper bound constraint (either < or <=) + pub upper_bound: Option, +} + +#[derive(Debug, Clone)] +/// Represent seek range which can be used in query planning to emit range scan over table or index +pub struct SeekRangeConstraint { + pub sort_order: SortOrder, + pub eq: Option<(ast::Operator, ast::Expr)>, + pub lower_bound: Option<(ast::Operator, ast::Expr)>, + pub upper_bound: Option<(ast::Operator, ast::Expr)>, +} + +impl SeekRangeConstraint { + pub fn new_eq(sort_order: SortOrder, eq: (ast::Operator, ast::Expr)) -> Self { + Self { + sort_order, + eq: Some(eq), + lower_bound: None, + upper_bound: None, + } + } + pub fn new_range( + sort_order: SortOrder, + lower_bound: Option<(ast::Operator, ast::Expr)>, + upper_bound: Option<(ast::Operator, ast::Expr)>, + ) -> Self { + assert!(lower_bound.is_some() || upper_bound.is_some()); + Self { + sort_order, + eq: None, + lower_bound, + upper_bound, + } + } +} + +impl RangeConstraintRef { + /// Convert the [RangeConstraintRef] to a [SeekRangeConstraint] usable in a [crate::translate::plan::SeekDef::key]. + pub fn as_seek_range_constraint( + &self, + constraints: &[Constraint], + where_clause: &[WhereTerm], + ) -> SeekRangeConstraint { + if let Some(eq) = self.eq { + return SeekRangeConstraint::new_eq( + self.sort_order, + constraints[eq].get_constraining_expr(where_clause), + ); + } + SeekRangeConstraint::new_range( + self.sort_order, + self.lower_bound + .map(|x| constraints[x].get_constraining_expr(where_clause)), + self.upper_bound + .map(|x| constraints[x].get_constraining_expr(where_clause)), + ) + } +} + /// Find which [Constraint]s are usable for a given join order. /// Returns a slice of the references to the constraints that are usable. /// A constraint is considered usable for a given table if all of the other tables referenced by the constraint @@ -421,28 +475,102 @@ pub fn usable_constraints_for_join_order<'a>( constraints: &'a [Constraint], refs: &'a [ConstraintRef], join_order: &[JoinOrderMember], -) -> &'a [ConstraintRef] { +) -> Vec { + debug_assert!(refs.is_sorted_by_key(|x| x.index_col_pos)); + let table_idx = join_order.last().unwrap().original_idx; - let mut usable_until = 0; + let lhs_mask = TableMask::from_table_number_iter( + join_order + .iter() + .take(join_order.len() - 1) + .map(|j| j.original_idx), + ); + let mut usable: Vec = Vec::new(); + let mut current_required_column_pos = 0; for cref in refs.iter() { let constraint = &constraints[cref.constraint_vec_pos]; let other_side_refers_to_self = constraint.lhs_mask.contains_table(table_idx); if other_side_refers_to_self { break; } - let lhs_mask = TableMask::from_table_number_iter( - join_order - .iter() - .take(join_order.len() - 1) - .map(|j| j.original_idx), - ); let all_required_tables_are_on_left_side = lhs_mask.contains_all(&constraint.lhs_mask); if !all_required_tables_are_on_left_side { break; } - usable_until += 1; + if Some(cref.index_col_pos) == usable.last().map(|x| x.index_col_pos) { + // Two constraints on the same index column can be combined into a single range constraint. + assert_eq!(cref.sort_order, usable.last().unwrap().sort_order); + assert_eq!(cref.index_col_pos, usable.last().unwrap().index_col_pos); + assert_eq!( + constraints[cref.constraint_vec_pos].table_col_pos, + usable.last().unwrap().table_col_pos + ); + // if we already have eq constraint - we must not add anything to it + // otherwise, we can incorrectly consume filters which will not be used in the access path + if usable.last().unwrap().eq.is_some() { + continue; + } + match constraints[cref.constraint_vec_pos].operator { + ast::Operator::Greater | ast::Operator::GreaterEquals => { + usable.last_mut().unwrap().lower_bound = Some(cref.constraint_vec_pos); + } + ast::Operator::Less | ast::Operator::LessEquals => { + usable.last_mut().unwrap().upper_bound = Some(cref.constraint_vec_pos); + } + _ => {} + } + continue; + } + if cref.index_col_pos != current_required_column_pos { + // Index columns must be consumed contiguously in the order they appear in the index. + break; + } + if usable.last().is_some_and(|x| x.eq.is_none()) { + // Usable index key must have 0-n equalities and then a maximum of 1 range constraint with one or both bounds set. + // If we already have a range constraint before this one, we must not add anything to it + break; + } + let operator = constraints[cref.constraint_vec_pos].operator; + let table_col_pos = constraints[cref.constraint_vec_pos].table_col_pos; + if operator == ast::Operator::Equals + && usable + .last() + .is_some_and(|x| x.table_col_pos == table_col_pos) + { + // If we already have an equality constraint for this column, we can't use it again + continue; + } + let constraint_group = match operator { + ast::Operator::Equals => RangeConstraintRef { + table_col_pos, + index_col_pos: cref.index_col_pos, + sort_order: cref.sort_order, + eq: Some(cref.constraint_vec_pos), + lower_bound: None, + upper_bound: None, + }, + ast::Operator::Greater | ast::Operator::GreaterEquals => RangeConstraintRef { + table_col_pos, + index_col_pos: cref.index_col_pos, + sort_order: cref.sort_order, + eq: None, + lower_bound: Some(cref.constraint_vec_pos), + upper_bound: None, + }, + ast::Operator::Less | ast::Operator::LessEquals => RangeConstraintRef { + table_col_pos, + index_col_pos: cref.index_col_pos, + sort_order: cref.sort_order, + eq: None, + lower_bound: None, + upper_bound: Some(cref.constraint_vec_pos), + }, + _ => continue, + }; + usable.push(constraint_group); + current_required_column_pos += 1; } - &refs[..usable_until] + usable } fn can_use_partial_index(index: &Index, query_where_clause: &[WhereTerm]) -> bool { diff --git a/core/translate/optimizer/cost.rs b/core/translate/optimizer/cost.rs index 460fa9b0a..c96947e5d 100644 --- a/core/translate/optimizer/cost.rs +++ b/core/translate/optimizer/cost.rs @@ -1,4 +1,6 @@ -use super::constraints::{Constraint, ConstraintRef}; +use crate::translate::optimizer::constraints::RangeConstraintRef; + +use super::constraints::Constraint; /// A simple newtype wrapper over a f64 that represents the cost of an operation. /// @@ -43,7 +45,7 @@ pub fn estimate_page_io_cost(rowcount: f64) -> Cost { pub fn estimate_cost_for_scan_or_seek( index_info: Option, constraints: &[Constraint], - usable_constraint_refs: &[ConstraintRef], + usable_constraint_refs: &[RangeConstraintRef], input_cardinality: f64, ) -> Cost { let Some(index_info) = index_info else { @@ -55,8 +57,18 @@ pub fn estimate_cost_for_scan_or_seek( let selectivity_multiplier: f64 = usable_constraint_refs .iter() .map(|cref| { - let constraint = &constraints[cref.constraint_vec_pos]; - constraint.selectivity + if let Some(eq) = cref.eq { + let constraint = &constraints[eq]; + return constraint.selectivity; + } + let mut selectivity = 1.0; + if let Some(lower_bound) = cref.lower_bound { + selectivity *= constraints[lower_bound].selectivity; + } + if let Some(upper_bound) = cref.upper_bound { + selectivity *= constraints[upper_bound].selectivity; + } + selectivity }) .product(); diff --git a/core/translate/optimizer/join.rs b/core/translate/optimizer/join.rs index db5e71000..79b80174a 100644 --- a/core/translate/optimizer/join.rs +++ b/core/translate/optimizer/join.rs @@ -49,7 +49,7 @@ pub fn join_lhs_and_rhs<'a>( rhs_constraints: &'a TableConstraints, join_order: &[JoinOrderMember], maybe_order_target: Option<&OrderTarget>, - access_methods_arena: &'a RefCell>>, + access_methods_arena: &'a RefCell>, cost_upper_bound: Cost, ) -> Result> { // The input cardinality for this join is the output cardinality of the previous join. @@ -125,7 +125,7 @@ pub fn compute_best_join_order<'a>( joined_tables: &[JoinedTable], maybe_order_target: Option<&OrderTarget>, constraints: &'a [TableConstraints], - access_methods_arena: &'a RefCell>>, + access_methods_arena: &'a RefCell>, ) -> Result> { // Skip work if we have no tables to consider. if joined_tables.is_empty() { @@ -403,7 +403,7 @@ pub fn compute_best_join_order<'a>( pub fn compute_naive_left_deep_plan<'a>( joined_tables: &[JoinedTable], maybe_order_target: Option<&OrderTarget>, - access_methods_arena: &'a RefCell>>, + access_methods_arena: &'a RefCell>, constraints: &'a [TableConstraints], ) -> Result> { let n = joined_tables.len(); @@ -509,9 +509,9 @@ mod tests { use crate::{ schema::{BTreeTable, Column, Index, IndexColumn, Table, Type}, translate::{ - optimizer::access_method::AccessMethodParams, - optimizer::constraints::{ - constraints_from_where_clause, BinaryExprSide, ConstraintRef, + optimizer::{ + access_method::AccessMethodParams, + constraints::{constraints_from_where_clause, BinaryExprSide, RangeConstraintRef}, }, plan::{ ColumnUsedMask, IterationDirection, JoinInfo, Operation, TableReferences, WhereTerm, @@ -632,8 +632,7 @@ mod tests { assert!(iter_dir == IterationDirection::Forwards); assert!(constraint_refs.len() == 1); assert!( - table_constraints[0].constraints[constraint_refs[0].constraint_vec_pos] - .where_clause_pos + table_constraints[0].constraints[constraint_refs[0].eq.unwrap()].where_clause_pos == (0, BinaryExprSide::Rhs) ); } @@ -701,8 +700,7 @@ mod tests { assert!(index.as_ref().unwrap().name == "sqlite_autoindex_test_table_1"); assert!(constraint_refs.len() == 1); assert!( - table_constraints[0].constraints[constraint_refs[0].constraint_vec_pos] - .where_clause_pos + table_constraints[0].constraints[constraint_refs[0].eq.unwrap()].where_clause_pos == (0, BinaryExprSide::Rhs) ); } @@ -784,8 +782,7 @@ mod tests { assert!(index.as_ref().unwrap().name == "index1"); assert!(constraint_refs.len() == 1); assert!( - table_constraints[TABLE1].constraints[constraint_refs[0].constraint_vec_pos] - .where_clause_pos + table_constraints[TABLE1].constraints[constraint_refs[0].eq.unwrap()].where_clause_pos == (0, BinaryExprSide::Rhs) ); } @@ -960,8 +957,8 @@ mod tests { assert!(iter_dir == IterationDirection::Forwards); assert!(index.as_ref().unwrap().name == "sqlite_autoindex_customers_1"); assert!(constraint_refs.len() == 1); - let constraint = &table_constraints[TABLE_NO_CUSTOMERS].constraints - [constraint_refs[0].constraint_vec_pos]; + let constraint = + &table_constraints[TABLE_NO_CUSTOMERS].constraints[constraint_refs[0].eq.unwrap()]; assert!(constraint.lhs_mask.is_empty()); let access_method = &access_methods_arena.borrow()[best_plan.data[1].1]; @@ -970,7 +967,7 @@ mod tests { assert!(index.as_ref().unwrap().name == "orders_customer_id_idx"); assert!(constraint_refs.len() == 1); let constraint = - &table_constraints[TABLE_NO_ORDERS].constraints[constraint_refs[0].constraint_vec_pos]; + &table_constraints[TABLE_NO_ORDERS].constraints[constraint_refs[0].eq.unwrap()]; assert!(constraint.lhs_mask.contains_table(TABLE_NO_CUSTOMERS)); let access_method = &access_methods_arena.borrow()[best_plan.data[2].1]; @@ -978,8 +975,8 @@ mod tests { assert!(iter_dir == IterationDirection::Forwards); assert!(index.as_ref().unwrap().name == "order_items_order_id_idx"); assert!(constraint_refs.len() == 1); - let constraint = &table_constraints[TABLE_NO_ORDER_ITEMS].constraints - [constraint_refs[0].constraint_vec_pos]; + let constraint = + &table_constraints[TABLE_NO_ORDER_ITEMS].constraints[constraint_refs[0].eq.unwrap()]; assert!(constraint.lhs_mask.contains_table(TABLE_NO_ORDERS)); } @@ -1187,8 +1184,8 @@ mod tests { assert!(iter_dir == IterationDirection::Forwards); assert!(index.is_none()); assert!(constraint_refs.len() == 1); - let constraint = &table_constraints[*table_number].constraints - [constraint_refs[0].constraint_vec_pos]; + let constraint = + &table_constraints[*table_number].constraints[constraint_refs[0].eq.unwrap()]; assert!(constraint.lhs_mask.contains_table(FACT_TABLE_IDX)); assert!(constraint.operator == ast::Operator::Equals); } @@ -1280,7 +1277,7 @@ mod tests { assert!(iter_dir == IterationDirection::Forwards); assert!(index.is_none()); assert!(constraint_refs.len() == 1); - let constraint = &table_constraints.constraints[constraint_refs[0].constraint_vec_pos]; + let constraint = &table_constraints.constraints[constraint_refs[0].eq.unwrap()]; assert!(constraint.lhs_mask.contains_table(i - 1)); assert!(constraint.operator == ast::Operator::Equals); } @@ -1481,7 +1478,7 @@ mod tests { let (_, index, constraint_refs) = _as_btree(access_method); assert!(index.as_ref().is_some_and(|i| i.name == "idx1")); assert!(constraint_refs.len() == 1); - let constraint = &table_constraints[0].constraints[constraint_refs[0].constraint_vec_pos]; + let constraint = &table_constraints[0].constraints[constraint_refs[0].eq.unwrap()]; assert!(constraint.operator == ast::Operator::Equals); assert!(constraint.table_col_pos == 0); // c1 } @@ -1608,10 +1605,10 @@ mod tests { let (_, index, constraint_refs) = _as_btree(access_method); assert!(index.as_ref().is_some_and(|i| i.name == "idx1")); assert!(constraint_refs.len() == 2); - let constraint = &table_constraints[0].constraints[constraint_refs[0].constraint_vec_pos]; + let constraint = &table_constraints[0].constraints[constraint_refs[0].eq.unwrap()]; assert!(constraint.operator == ast::Operator::Equals); assert!(constraint.table_col_pos == 0); // c1 - let constraint = &table_constraints[0].constraints[constraint_refs[1].constraint_vec_pos]; + let constraint = &table_constraints[0].constraints[constraint_refs[1].lower_bound.unwrap()]; assert!(constraint.operator == ast::Operator::Greater); assert!(constraint.table_col_pos == 1); // c2 } @@ -1664,6 +1661,7 @@ mod tests { has_rowid: true, is_strict: false, unique_sets: vec![], + foreign_keys: vec![], }) } @@ -1710,9 +1708,13 @@ mod tests { Expr::Literal(ast::Literal::Numeric(value.to_string())) } - fn _as_btree<'a>( - access_method: &AccessMethod<'a>, - ) -> (IterationDirection, Option>, &'a [ConstraintRef]) { + fn _as_btree( + access_method: &AccessMethod, + ) -> ( + IterationDirection, + Option>, + &'_ [RangeConstraintRef], + ) { match &access_method.params { AccessMethodParams::BTreeTable { iter_dir, diff --git a/core/translate/optimizer/mod.rs b/core/translate/optimizer/mod.rs index bd4ecdd2d..798cd50f3 100644 --- a/core/translate/optimizer/mod.rs +++ b/core/translate/optimizer/mod.rs @@ -16,12 +16,19 @@ use turso_ext::{ConstraintInfo, ConstraintUsage}; use turso_parser::ast::{self, Expr, SortOrder}; use crate::{ - schema::{Index, IndexColumn, Schema, Table}, + schema::{BTreeTable, Column, Index, IndexColumn, Schema, Table, Type, ROWID_SENTINEL}, translate::{ - optimizer::access_method::AccessMethodParams, optimizer::constraints::TableConstraints, - plan::Scan, plan::TerminationKey, + optimizer::{ + access_method::AccessMethodParams, + constraints::{RangeConstraintRef, SeekRangeConstraint, TableConstraints}, + }, + plan::{ + ColumnUsedMask, OuterQueryReference, QueryDestination, ResultSetColumn, Scan, + SeekKeyComponent, + }, }, types::SeekOp, + vdbe::builder::{CursorKey, CursorType, ProgramBuilder}, LimboError, Result, }; @@ -41,11 +48,11 @@ pub(crate) mod lift_common_subexpressions; pub(crate) mod order; #[tracing::instrument(skip_all, level = tracing::Level::DEBUG)] -pub fn optimize_plan(plan: &mut Plan, schema: &Schema) -> Result<()> { +pub fn optimize_plan(program: &mut ProgramBuilder, plan: &mut Plan, schema: &Schema) -> Result<()> { match plan { Plan::Select(plan) => optimize_select_plan(plan, schema)?, Plan::Delete(plan) => optimize_delete_plan(plan, schema)?, - Plan::Update(plan) => optimize_update_plan(plan, schema)?, + Plan::Update(plan) => optimize_update_plan(program, plan, schema)?, Plan::CompoundSelect { left, right_most, .. } => { @@ -112,7 +119,11 @@ fn optimize_delete_plan(plan: &mut DeletePlan, schema: &Schema) -> Result<()> { Ok(()) } -fn optimize_update_plan(plan: &mut UpdatePlan, schema: &Schema) -> Result<()> { +fn optimize_update_plan( + program: &mut ProgramBuilder, + plan: &mut UpdatePlan, + schema: &Schema, +) -> Result<()> { lift_common_subexpressions_from_binary_or_terms(&mut plan.where_clause)?; if let ConstantConditionEliminationResult::ImpossibleCondition = eliminate_constant_conditions(&mut plan.where_clause)? @@ -129,28 +140,170 @@ fn optimize_update_plan(plan: &mut UpdatePlan, schema: &Schema) -> Result<()> { &mut None, )?; - // It is not safe to use an index that is going to be updated as the iteration index for a table. - // In these cases, we will fall back to a table scan. - // FIXME: this should probably be incorporated into the optimizer itself, but it's a smaller fix this way. let table_ref = &mut plan.table_references.joined_tables_mut()[0]; - // No index, OK. - let Some(index) = table_ref.op.index() else { - return Ok(()); + // An ephemeral table is required if the UPDATE modifies any column that is present in the key of the + // btree used to iterate over the table. + // For regular table scans or seeks, this is just the rowid or the rowid alias column (INTEGER PRIMARY KEY) + // For index scans and seeks, this is any column in the index used. + let requires_ephemeral_table = 'requires: { + let Some(btree_table) = table_ref.table.btree() else { + break 'requires false; + }; + let Some(index) = table_ref.op.index() else { + let rowid_alias_used = plan.set_clauses.iter().fold(false, |accum, (idx, _)| { + accum || (*idx != ROWID_SENTINEL && btree_table.columns[*idx].is_rowid_alias) + }); + if rowid_alias_used { + break 'requires true; + } + let direct_rowid_update = plan + .set_clauses + .iter() + .any(|(idx, _)| *idx == ROWID_SENTINEL); + if direct_rowid_update { + break 'requires true; + } + break 'requires false; + }; + + plan.set_clauses + .iter() + .any(|(idx, _)| index.columns.iter().any(|c| c.pos_in_table == *idx)) }; - // Iteration index not affected by update, OK. - if !plan.indexes_to_update.iter().any(|i| Arc::ptr_eq(index, i)) { + + if !requires_ephemeral_table { return Ok(()); } - // Otherwise, fall back to a table scan. - table_ref.op = Operation::Scan(Scan::BTreeTable { - iter_dir: IterationDirection::Forwards, - index: None, + + add_ephemeral_table_to_update_plan(program, plan) +} + +/// An ephemeral table is required if the UPDATE modifies any column that is present in the key of the +/// btree used to iterate over the table. +/// For regular table scans or seeks, the key is the rowid or the rowid alias column (INTEGER PRIMARY KEY). +/// For index scans and seeks, the key is any column in the index used. +/// +/// The ephemeral table will accumulate all the rowids of the rows that are affected by the UPDATE, +/// and then the temp table will be iterated over and the actual row updates performed. +/// +/// This is necessary because an UPDATE is implemented as a DELETE-then-INSERT operation, which could +/// mess up the iteration order of the rows by changing the keys in the table/index that the iteration +/// is performed over. The ephemeral table ensures stable iteration because it is not modified during +/// the UPDATE loop. +fn add_ephemeral_table_to_update_plan( + program: &mut ProgramBuilder, + plan: &mut UpdatePlan, +) -> Result<()> { + let internal_id = program.table_reference_counter.next(); + let ephemeral_table = Arc::new(BTreeTable { + root_page: 0, // Not relevant for ephemeral table definition + name: "ephemeral_scratch".to_string(), + has_rowid: true, + has_autoincrement: false, + primary_key_columns: vec![], + columns: vec![Column { + name: Some("rowid".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: true, + is_rowid_alias: false, + notnull: true, + default: None, + unique: false, + collation: None, + hidden: false, + }], + is_strict: false, + unique_sets: vec![], + foreign_keys: vec![], }); - // Revert the decision to use a WHERE clause term as an index constraint. - plan.where_clause - .iter_mut() - .for_each(|term| term.consumed = false); + + let temp_cursor_id = program.alloc_cursor_id_keyed( + CursorKey::table(internal_id), + CursorType::BTreeTable(ephemeral_table.clone()), + ); + + // The actual update loop will use the ephemeral table as the single [JoinedTable] which it then loops over. + let table_references_update = TableReferences::new( + vec![JoinedTable { + table: Table::BTree(ephemeral_table.clone()), + identifier: "ephemeral_scratch".to_string(), + internal_id, + op: Operation::Scan(Scan::BTreeTable { + iter_dir: IterationDirection::Forwards, + index: None, + }), + join_info: None, + col_used_mask: ColumnUsedMask::default(), + database_id: 0, + }], + vec![], + ); + + // Building the ephemeral table will use the TableReferences from the original plan -- i.e. if we chose an index scan originally, + // we will build the ephemeral table by using the same index scan and using the same WHERE filters. + let table_references_ephemeral_select = + std::mem::replace(&mut plan.table_references, table_references_update); + + for table in table_references_ephemeral_select.joined_tables() { + // The update loop needs to reference columns from the original source table, so we add it as an outer query reference. + plan.table_references + .add_outer_query_reference(OuterQueryReference { + identifier: table.identifier.clone(), + internal_id: table.internal_id, + table: table.table.clone(), + col_used_mask: table.col_used_mask.clone(), + }); + } + + let join_order = table_references_ephemeral_select + .joined_tables() + .iter() + .enumerate() + .map(|(i, t)| JoinOrderMember { + table_id: t.internal_id, + original_idx: i, + is_outer: t + .join_info + .as_ref() + .is_some_and(|join_info| join_info.outer), + }) + .collect(); + let rowid_internal_id = table_references_ephemeral_select + .joined_tables() + .first() + .unwrap() + .internal_id; + + let ephemeral_plan = SelectPlan { + table_references: table_references_ephemeral_select, + result_columns: vec![ResultSetColumn { + expr: Expr::RowId { + database: None, + table: rowid_internal_id, + }, + alias: None, + contains_aggregates: false, + }], + where_clause: plan.where_clause.drain(..).collect(), + group_by: None, // N/A + order_by: vec![], // N/A + aggregates: vec![], // N/A + limit: None, // N/A + query_destination: QueryDestination::EphemeralTable { + cursor_id: temp_cursor_id, + table: ephemeral_table, + }, + join_order, + offset: None, + contains_constant_false_condition: false, + distinctness: super::plan::Distinctness::NonDistinct, + values: vec![], + window: None, + }; + + plan.ephemeral_plan = Some(ephemeral_plan); Ok(()) } @@ -184,6 +337,12 @@ fn optimize_table_access( order_by: &mut Vec<(Box, SortOrder)>, group_by: &mut Option, ) -> Result>> { + if table_references.joined_tables().len() > TableReferences::MAX_JOINED_TABLES { + crate::bail_parse_error!( + "Only up to {} tables can be joined", + TableReferences::MAX_JOINED_TABLES + ); + } let access_methods_arena = RefCell::new(Vec::new()); let maybe_order_target = compute_order_target(order_by, group_by.as_mut()); let constraints_per_table = @@ -343,13 +502,15 @@ fn optimize_table_access( .filter(|c| c.usable) .cloned() .collect::>(); - let temp_constraint_refs = (0..usable_constraints.len()) + let mut temp_constraint_refs = (0..usable_constraints.len()) .map(|i| ConstraintRef { constraint_vec_pos: i, - index_col_pos: usable_constraints[i].table_col_pos, + index_col_pos: i, sort_order: SortOrder::Asc, }) .collect::>(); + temp_constraint_refs.sort_by_key(|x| x.index_col_pos); + let usable_constraint_refs = usable_constraints_for_join_order( &usable_constraints, &temp_constraint_refs, @@ -362,17 +523,14 @@ fn optimize_table_access( }); continue; } - let ephemeral_index = ephemeral_index_build( - &joined_tables[table_idx], - &usable_constraints, - usable_constraint_refs, - ); + let ephemeral_index = + ephemeral_index_build(&joined_tables[table_idx], &usable_constraint_refs); let ephemeral_index = Arc::new(ephemeral_index); joined_tables[table_idx].op = Operation::Search(Search::Seek { index: Some(ephemeral_index), seek_def: build_seek_def_from_constraints( - &usable_constraints, - usable_constraint_refs, + &table_constraints.constraints, + &usable_constraint_refs, *iter_dir, where_clause, )?, @@ -383,25 +541,29 @@ fn optimize_table_access( .as_ref() .is_some_and(|join_info| join_info.outer); for cref in constraint_refs.iter() { - let constraint = - &constraints_per_table[table_idx].constraints[cref.constraint_vec_pos]; - let where_term = &mut where_clause[constraint.where_clause_pos.0]; - assert!( - !where_term.consumed, - "trying to consume a where clause term twice: {where_term:?}", - ); - if is_outer_join && where_term.from_outer_join.is_none() { - // Don't consume WHERE terms from outer joins if the where term is not part of the outer join condition. Consider: - // - SELECT * FROM t1 LEFT JOIN t2 ON false WHERE t2.id = 5 - // - there is no row in t2 where t2.id = 5 - // This should never produce any rows with null columns for t2 (because NULL != 5), but if we consume 't2.id = 5' to use it as a seek key, - // this will cause a null row to be emitted for EVERY row of t1. - // Note: in most cases like this, the LEFT JOIN could just be converted into an INNER JOIN (because e.g. t2.id=5 statically excludes any null rows), - // but that optimization should not be done here - it should be done before the join order optimization happens. - continue; + for constraint_vec_pos in &[cref.eq, cref.lower_bound, cref.upper_bound] { + let Some(constraint_vec_pos) = constraint_vec_pos else { + continue; + }; + let constraint = + &constraints_per_table[table_idx].constraints[*constraint_vec_pos]; + let where_term = &mut where_clause[constraint.where_clause_pos.0]; + assert!( + !where_term.consumed, + "trying to consume a where clause term twice: {where_term:?}", + ); + if is_outer_join && where_term.from_outer_join.is_none() { + // Don't consume WHERE terms from outer joins if the where term is not part of the outer join condition. Consider: + // - SELECT * FROM t1 LEFT JOIN t2 ON false WHERE t2.id = 5 + // - there is no row in t2 where t2.id = 5 + // This should never produce any rows with null columns for t2 (because NULL != 5), but if we consume 't2.id = 5' to use it as a seek key, + // this will cause a null row to be emitted for EVERY row of t1. + // Note: in most cases like this, the LEFT JOIN could just be converted into an INNER JOIN (because e.g. t2.id=5 statically excludes any null rows), + // but that optimization should not be done here - it should be done before the join order optimization happens. + continue; + } + where_term.consumed = true; } - - where_clause[constraint.where_clause_pos.0].consumed = true; } if let Some(index) = &index { joined_tables[table_idx].op = Operation::Search(Search::Seek { @@ -419,13 +581,14 @@ fn optimize_table_access( constraint_refs.len() == 1, "expected exactly one constraint for rowid seek, got {constraint_refs:?}" ); - let constraint = &constraints_per_table[table_idx].constraints - [constraint_refs[0].constraint_vec_pos]; - joined_tables[table_idx].op = match constraint.operator { - ast::Operator::Equals => Operation::Search(Search::RowidEq { - cmp_expr: constraint.get_constraining_expr(where_clause), - }), - _ => Operation::Search(Search::Seek { + joined_tables[table_idx].op = if let Some(eq) = constraint_refs[0].eq { + Operation::Search(Search::RowidEq { + cmp_expr: constraints_per_table[table_idx].constraints[eq] + .get_constraining_expr(where_clause) + .1, + }) + } else { + Operation::Search(Search::Seek { index: None, seek_def: build_seek_def_from_constraints( &constraints_per_table[table_idx].constraints, @@ -433,7 +596,7 @@ fn optimize_table_access( *iter_dir, where_clause, )?, - }), + }) }; } } @@ -505,7 +668,7 @@ fn build_vtab_scan_op( if usage.omit { where_clause[constraint.where_clause_pos.0].consumed = true; } - let expr = constraint.get_constraining_expr(where_clause); + let (_, expr) = constraint.get_constraining_expr(where_clause); constraints[zero_based_argv_index] = Some(expr); arg_count += 1; } @@ -864,8 +1027,7 @@ impl Optimizable for ast::Expr { fn ephemeral_index_build( table_reference: &JoinedTable, - constraints: &[Constraint], - constraint_refs: &[ConstraintRef], + constraint_refs: &[RangeConstraintRef], ) -> Index { let mut ephemeral_columns: Vec = table_reference .columns() @@ -886,11 +1048,11 @@ fn ephemeral_index_build( let a_constraint = constraint_refs .iter() .enumerate() - .find(|(_, c)| constraints[c.constraint_vec_pos].table_col_pos == a.pos_in_table); + .find(|(_, c)| c.table_col_pos == a.pos_in_table); let b_constraint = constraint_refs .iter() .enumerate() - .find(|(_, c)| constraints[c.constraint_vec_pos].table_col_pos == b.pos_in_table); + .find(|(_, c)| c.table_col_pos == b.pos_in_table); match (a_constraint, b_constraint) { (Some(_), None) => Ordering::Less, (None, Some(_)) => Ordering::Greater, @@ -922,7 +1084,7 @@ fn ephemeral_index_build( /// Build a [SeekDef] for a given list of [Constraint]s pub fn build_seek_def_from_constraints( constraints: &[Constraint], - constraint_refs: &[ConstraintRef], + constraint_refs: &[RangeConstraintRef], iter_dir: IterationDirection, where_clause: &[WhereTerm], ) -> Result { @@ -933,472 +1095,294 @@ pub fn build_seek_def_from_constraints( // Extract the key values and operators let key = constraint_refs .iter() - .map(|cref| cref.as_seek_key_column(constraints, where_clause)) + .map(|cref| cref.as_seek_range_constraint(constraints, where_clause)) .collect(); - // We know all but potentially the last term is an equality, so we can use the operator of the last term - // to form the SeekOp - let op = constraints[constraint_refs.last().unwrap().constraint_vec_pos].operator; - - let seek_def = build_seek_def(op, iter_dir, key)?; + let seek_def = build_seek_def(iter_dir, key)?; Ok(seek_def) } -/// Build a [SeekDef] for a given comparison operator and index key. +/// Build a [SeekDef] for a given [SeekRangeConstraint] and [IterationDirection]. /// To be usable as a seek key, all but potentially the last term must be equalities. -/// The last term can be a nonequality. -/// The comparison operator referred to by `op` is the operator of the last term. +/// The last term can be a nonequality (range with potentially one unbounded range). /// /// There are two parts to the seek definition: -/// 1. The [SeekKey], which specifies the key that we will use to seek to the first row that matches the index key. -/// 2. The [TerminationKey], which specifies the key that we will use to terminate the index scan that follows the seek. +/// 1. start [SeekKey], which specifies the key that we will use to seek to the first row that matches the index key. +/// 2. end [SeekKey], which specifies the key that we will use to terminate the index scan that follows the seek. /// -/// There are some nuances to how, and which parts of, the index key can be used in the [SeekKey] and [TerminationKey], +/// There are some nuances to how, and which parts of, the index key can be used in the start and end [SeekKey]s, /// depending on the operator and iteration order. This function explains those nuances inline when dealing with /// each case. /// /// But to illustrate the general idea, consider the following examples: /// /// 1. For example, having two conditions like (x>10 AND y>20) cannot be used as a valid [SeekKey] GT(x:10, y:20) -/// because the first row greater than (x:10, y:20) might be (x:10, y:21), which does not satisfy the where clause. +/// because the first row greater than (x:10, y:20) might be (x:11, y:19), which does not satisfy the where clause. /// In this case, only GT(x:10) must be used as the [SeekKey], and rows with y <= 20 must be filtered as a regular condition expression for each value of x. /// /// 2. In contrast, having (x=10 AND y>20) forms a valid index key GT(x:10, y:20) because after the seek, we can simply terminate as soon as x > 10, -/// i.e. use GT(x:10, y:20) as the [SeekKey] and GT(x:10) as the [TerminationKey]. +/// i.e. use GT(x:10, y:20) as the start [SeekKey] and GT(x:10) as the end. /// /// The preceding examples are for an ascending index. The logic is similar for descending indexes, but an important distinction is that /// since a descending index is laid out in reverse order, the comparison operators are reversed, e.g. LT becomes GT, LE becomes GE, etc. /// So when you see e.g. a SeekOp::GT below for a descending index, it actually means that we are seeking the first row where the index key is LESS than the seek key. /// fn build_seek_def( - op: ast::Operator, iter_dir: IterationDirection, - key: Vec<(ast::Expr, SortOrder)>, + mut key: Vec, ) -> Result { - let key_len = key.len(); - let sort_order_of_last_key = key.last().unwrap().1; + assert!(!key.is_empty()); + let last = key.last().unwrap(); + + // if we searching for exact key - emit definition immediately with prefix as a full key + if last.eq.is_some() { + let (start_op, end_op) = match iter_dir { + IterationDirection::Forwards => (SeekOp::GE { eq_only: true }, SeekOp::GT), + IterationDirection::Backwards => (SeekOp::LE { eq_only: true }, SeekOp::LT), + }; + return Ok(SeekDef { + prefix: key, + iter_dir, + start: SeekKey { + last_component: SeekKeyComponent::None, + op: start_op, + }, + end: SeekKey { + last_component: SeekKeyComponent::None, + op: end_op, + }, + }); + } + assert!(last.lower_bound.is_some() || last.upper_bound.is_some()); + + // pop last key as we will do some form of range search + let last = key.pop().unwrap(); + + // after that all key components must be equality constraints + debug_assert!(key.iter().all(|k| k.eq.is_some())); // For the commented examples below, keep in mind that since a descending index is laid out in reverse order, the comparison operators are reversed, e.g. LT becomes GT, LE becomes GE, etc. // Also keep in mind that index keys are compared based on the number of columns given, so for example: // - if key is GT(x:10), then (x=10, y=usize::MAX) is not GT because only X is compared. (x=11, y=) is GT. // - if key is GT(x:10, y:20), then (x=10, y=21) is GT because both X and Y are compared. // - if key is GT(x:10, y:NULL), then (x=10, y=0) is GT because NULL is always LT in index key comparisons. - Ok(match (iter_dir, op) { - // Forwards, EQ: - // Example: (x=10 AND y=20) - // Seek key: start from the first GE(x:10, y:20) - // Termination key: end at the first GT(x:10, y:20) - // Ascending vs descending doesn't matter because all the comparisons are equalities. - (IterationDirection::Forwards, ast::Operator::Equals) => SeekDef { - key, - iter_dir, - seek: Some(SeekKey { - len: key_len, - null_pad: false, - op: SeekOp::GE { eq_only: true }, - }), - termination: Some(TerminationKey { - len: key_len, - null_pad: false, - op: SeekOp::GT, - }), - }, - // Forwards, GT: - // Ascending index example: (x=10 AND y>20) - // Seek key: start from the first GT(x:10, y:20), e.g. (x=10, y=21) - // Termination key: end at the first GT(x:10), e.g. (x=11, y=0) - // - // Descending index example: (x=10 AND y>20) - // Seek key: start from the first LE(x:10), e.g. (x=10, y=usize::MAX), so reversed -> GE(x:10) - // Termination key: end at the first LE(x:10, y:20), e.g. (x=10, y=20) so reversed -> GE(x:10, y:20) - (IterationDirection::Forwards, ast::Operator::Greater) => { - let (seek_key_len, termination_key_len, seek_op, termination_op) = - if sort_order_of_last_key == SortOrder::Asc { - (key_len, key_len - 1, SeekOp::GT, SeekOp::GT) - } else { - ( - key_len - 1, - key_len, - SeekOp::LE { eq_only: false }.reverse(), - SeekOp::LE { eq_only: false }.reverse(), - ) - }; + Ok(match iter_dir { + IterationDirection::Forwards => { + let (start, end) = match last.sort_order { + SortOrder::Asc => { + let start = match last.lower_bound { + // Forwards, Asc, GT: (x=10 AND y>20) + // Start key: start from the first GT(x:10, y:20) + Some((ast::Operator::Greater, bound)) => SeekKey { + last_component: SeekKeyComponent::Expr(bound), + op: SeekOp::GT, + }, + // Forwards, Asc, GE: (x=10 AND y>=20) + // Start key: start from the first GE(x:10, y:20) + Some((ast::Operator::GreaterEquals, bound)) => SeekKey { + last_component: SeekKeyComponent::Expr(bound), + op: SeekOp::GE { eq_only: false }, + }, + // Forwards, Asc, None, (x=10 AND y<30) + // Start key: start from the first GE(x:10) + None => SeekKey { + last_component: SeekKeyComponent::None, + op: SeekOp::GE { eq_only: false }, + }, + Some((op, _)) => { + crate::bail_parse_error!("build_seek_def: invalid operator: {:?}", op,) + } + }; + let end = match last.upper_bound { + // Forwards, Asc, LT, (x=10 AND y<30) + // End key: end at first GE(x:10, y:30) + Some((ast::Operator::Less, bound)) => SeekKey { + last_component: SeekKeyComponent::Expr(bound), + op: SeekOp::GE { eq_only: false }, + }, + // Forwards, Asc, LE, (x=10 AND y<=30) + // End key: end at first GT(x:10, y:30) + Some((ast::Operator::LessEquals, bound)) => SeekKey { + last_component: SeekKeyComponent::Expr(bound), + op: SeekOp::GT, + }, + // Forwards, Asc, None, (x=10 AND y>20) + // End key: end at first GT(x:10) + None => SeekKey { + last_component: SeekKeyComponent::None, + op: SeekOp::GT, + }, + Some((op, _)) => { + crate::bail_parse_error!("build_seek_def: invalid operator: {:?}", op,) + } + }; + (start, end) + } + SortOrder::Desc => { + let start = match last.upper_bound { + // Forwards, Desc, LT: (x=10 AND y<30) + // Start key: start from the first GT(x:10, y:30) + Some((ast::Operator::Less, bound)) => SeekKey { + last_component: SeekKeyComponent::Expr(bound), + op: SeekOp::GT, + }, + // Forwards, Desc, LE: (x=10 AND y<=30) + // Start key: start from the first GE(x:10, y:30) + Some((ast::Operator::LessEquals, bound)) => SeekKey { + last_component: SeekKeyComponent::Expr(bound), + op: SeekOp::GE { eq_only: false }, + }, + // Forwards, Desc, None: (x=10 AND y>20) + // Start key: start from the first GE(x:10) + None => SeekKey { + last_component: SeekKeyComponent::None, + op: SeekOp::GE { eq_only: false }, + }, + Some((op, _)) => { + crate::bail_parse_error!("build_seek_def: invalid operator: {:?}", op,) + } + }; + let end = match last.lower_bound { + // Forwards, Asc, GT, (x=10 AND y>20) + // End key: end at first GE(x:10, y:20) + Some((ast::Operator::Greater, bound)) => SeekKey { + last_component: SeekKeyComponent::Expr(bound), + op: SeekOp::GE { eq_only: false }, + }, + // Forwards, Asc, GE, (x=10 AND y>=20) + // End key: end at first GT(x:10, y:20) + Some((ast::Operator::GreaterEquals, bound)) => SeekKey { + last_component: SeekKeyComponent::Expr(bound), + op: SeekOp::GT, + }, + // Forwards, Asc, None, (x=10 AND y<30) + // End key: end at first GT(x:10) + None => SeekKey { + last_component: SeekKeyComponent::None, + op: SeekOp::GT, + }, + Some((op, _)) => { + crate::bail_parse_error!("build_seek_def: invalid operator: {:?}", op,) + } + }; + (start, end) + } + }; SeekDef { - key, + prefix: key, iter_dir, - seek: if seek_key_len > 0 { - Some(SeekKey { - len: seek_key_len, - op: seek_op, - null_pad: false, - }) - } else { - None - }, - termination: if termination_key_len > 0 { - Some(TerminationKey { - len: termination_key_len, - op: termination_op, - null_pad: false, - }) - } else { - None - }, + start, + end, } } - // Forwards, GE: - // Ascending index example: (x=10 AND y>=20) - // Seek key: start from the first GE(x:10, y:20), e.g. (x=10, y=20) - // Termination key: end at the first GT(x:10), e.g. (x=11, y=0) - // - // Descending index example: (x=10 AND y>=20) - // Seek key: start from the first LE(x:10), e.g. (x=10, y=usize::MAX), so reversed -> GE(x:10) - // Termination key: end at the first LT(x:10, y:20), e.g. (x=10, y=19), so reversed -> GT(x:10, y:20) - (IterationDirection::Forwards, ast::Operator::GreaterEquals) => { - let (seek_key_len, termination_key_len, seek_op, termination_op) = - if sort_order_of_last_key == SortOrder::Asc { - ( - key_len, - key_len - 1, - SeekOp::GE { eq_only: false }, - SeekOp::GT, - ) - } else { - ( - key_len - 1, - key_len, - SeekOp::LE { eq_only: false }.reverse(), - SeekOp::LT.reverse(), - ) - }; + IterationDirection::Backwards => { + let (start, end) = match last.sort_order { + SortOrder::Asc => { + let start = match last.upper_bound { + // Backwards, Asc, LT: (x=10 AND y<30) + // Start key: start from the first LT(x:10, y:30) + Some((ast::Operator::Less, bound)) => SeekKey { + last_component: SeekKeyComponent::Expr(bound), + op: SeekOp::LT, + }, + // Backwards, Asc, LT: (x=10 AND y<=30) + // Start key: start from the first LE(x:10, y:30) + Some((ast::Operator::LessEquals, bound)) => SeekKey { + last_component: SeekKeyComponent::Expr(bound), + op: SeekOp::LE { eq_only: false }, + }, + // Backwards, Asc, None: (x=10 AND y>20) + // Start key: start from the first LE(x:10) + None => SeekKey { + last_component: SeekKeyComponent::None, + op: SeekOp::LE { eq_only: false }, + }, + Some((op, _)) => { + crate::bail_parse_error!("build_seek_def: invalid operator: {:?}", op) + } + }; + let end = match last.lower_bound { + // Backwards, Asc, GT, (x=10 AND y>20) + // End key: end at first LE(x:10, y:20) + Some((ast::Operator::Greater, bound)) => SeekKey { + last_component: SeekKeyComponent::Expr(bound), + op: SeekOp::LE { eq_only: false }, + }, + // Backwards, Asc, GT, (x=10 AND y>=20) + // End key: end at first LT(x:10, y:20) + Some((ast::Operator::GreaterEquals, bound)) => SeekKey { + last_component: SeekKeyComponent::Expr(bound), + op: SeekOp::LT, + }, + // Backwards, Asc, None, (x=10 AND y<30) + // End key: end at first LT(x:10) + None => SeekKey { + last_component: SeekKeyComponent::None, + op: SeekOp::LT, + }, + Some((op, _)) => { + crate::bail_parse_error!("build_seek_def: invalid operator: {:?}", op,) + } + }; + (start, end) + } + SortOrder::Desc => { + let start = match last.lower_bound { + // Backwards, Desc, LT: (x=10 AND y>20) + // Start key: start from the first LT(x:10, y:20) + Some((ast::Operator::Greater, bound)) => SeekKey { + last_component: SeekKeyComponent::Expr(bound), + op: SeekOp::LT, + }, + // Backwards, Desc, LE: (x=10 AND y>=20) + // Start key: start from the first LE(x:10, y:20) + Some((ast::Operator::GreaterEquals, bound)) => SeekKey { + last_component: SeekKeyComponent::Expr(bound), + op: SeekOp::LE { eq_only: false }, + }, + // Backwards, Desc, LE: (x=10 AND y<30) + // Start key: start from the first LE(x:10) + None => SeekKey { + last_component: SeekKeyComponent::None, + op: SeekOp::LE { eq_only: false }, + }, + Some((op, _)) => { + crate::bail_parse_error!("build_seek_def: invalid operator: {:?}", op,) + } + }; + let end = match last.upper_bound { + // Backwards, Desc, LT, (x=10 AND y<30) + // End key: end at first LE(x:10, y:30) + Some((ast::Operator::Less, bound)) => SeekKey { + last_component: SeekKeyComponent::Expr(bound), + op: SeekOp::LE { eq_only: false }, + }, + // Backwards, Desc, LT, (x=10 AND y<=30) + // End key: end at first LT(x:10, y:30) + Some((ast::Operator::LessEquals, bound)) => SeekKey { + last_component: SeekKeyComponent::Expr(bound), + op: SeekOp::LT, + }, + // Backwards, Desc, LT, (x=10 AND y>20) + // End key: end at first LT(x:10) + None => SeekKey { + last_component: SeekKeyComponent::None, + op: SeekOp::LT, + }, + Some((op, _)) => { + crate::bail_parse_error!("build_seek_def: invalid operator: {:?}", op,) + } + }; + (start, end) + } + }; SeekDef { - key, + prefix: key, iter_dir, - seek: if seek_key_len > 0 { - Some(SeekKey { - len: seek_key_len, - op: seek_op, - null_pad: false, - }) - } else { - None - }, - termination: if termination_key_len > 0 { - Some(TerminationKey { - len: termination_key_len, - op: termination_op, - null_pad: false, - }) - } else { - None - }, + start, + end, } } - // Forwards, LT: - // Ascending index example: (x=10 AND y<20) - // Seek key: start from the first GT(x:10, y: NULL), e.g. (x=10, y=0) - // Termination key: end at the first GE(x:10, y:20), e.g. (x=10, y=20) - // - // Descending index example: (x=10 AND y<20) - // Seek key: start from the first LT(x:10, y:20), e.g. (x=10, y=19) so reversed -> GT(x:10, y:20) - // Termination key: end at the first LT(x:10), e.g. (x=9, y=usize::MAX), so reversed -> GE(x:10, NULL); i.e. GE the smallest possible (x=10, y) combination (NULL is always LT) - (IterationDirection::Forwards, ast::Operator::Less) => { - let (seek_key_len, termination_key_len, seek_op, termination_op) = - if sort_order_of_last_key == SortOrder::Asc { - ( - key_len - 1, - key_len, - SeekOp::GT, - SeekOp::GE { eq_only: false }, - ) - } else { - ( - key_len, - key_len - 1, - SeekOp::GT, - SeekOp::GE { eq_only: false }, - ) - }; - SeekDef { - key, - iter_dir, - seek: if seek_key_len > 0 { - Some(SeekKey { - len: seek_key_len, - op: seek_op, - null_pad: sort_order_of_last_key == SortOrder::Asc, - }) - } else { - None - }, - termination: if termination_key_len > 0 { - Some(TerminationKey { - len: termination_key_len, - op: termination_op, - null_pad: sort_order_of_last_key == SortOrder::Desc, - }) - } else { - None - }, - } - } - // Forwards, LE: - // Ascending index example: (x=10 AND y<=20) - // Seek key: start from the first GE(x:10, y:NULL), e.g. (x=10, y=0) - // Termination key: end at the first GT(x:10, y:20), e.g. (x=10, y=21) - // - // Descending index example: (x=10 AND y<=20) - // Seek key: start from the first LE(x:10, y:20), e.g. (x=10, y=20) so reversed -> GE(x:10, y:20) - // Termination key: end at the first LT(x:10), e.g. (x=9, y=usize::MAX), so reversed -> GE(x:10, NULL); i.e. GE the smallest possible (x=10, y) combination (NULL is always LT) - (IterationDirection::Forwards, ast::Operator::LessEquals) => { - let (seek_key_len, termination_key_len, seek_op, termination_op) = - if sort_order_of_last_key == SortOrder::Asc { - (key_len - 1, key_len, SeekOp::GT, SeekOp::GT) - } else { - ( - key_len, - key_len - 1, - SeekOp::LE { eq_only: false }.reverse(), - SeekOp::LE { eq_only: false }.reverse(), - ) - }; - SeekDef { - key, - iter_dir, - seek: if seek_key_len > 0 { - Some(SeekKey { - len: seek_key_len, - op: seek_op, - null_pad: sort_order_of_last_key == SortOrder::Asc, - }) - } else { - None - }, - termination: if termination_key_len > 0 { - Some(TerminationKey { - len: termination_key_len, - op: termination_op, - null_pad: sort_order_of_last_key == SortOrder::Desc, - }) - } else { - None - }, - } - } - // Backwards, EQ: - // Example: (x=10 AND y=20) - // Seek key: start from the last LE(x:10, y:20) - // Termination key: end at the first LT(x:10, y:20) - // Ascending vs descending doesn't matter because all the comparisons are equalities. - (IterationDirection::Backwards, ast::Operator::Equals) => SeekDef { - key, - iter_dir, - seek: Some(SeekKey { - len: key_len, - op: SeekOp::LE { eq_only: true }, - null_pad: false, - }), - termination: Some(TerminationKey { - len: key_len, - op: SeekOp::LT, - null_pad: false, - }), - }, - // Backwards, LT: - // Ascending index example: (x=10 AND y<20) - // Seek key: start from the last LT(x:10, y:20), e.g. (x=10, y=19) - // Termination key: end at the first LE(x:10, NULL), e.g. (x=9, y=usize::MAX) - // - // Descending index example: (x=10 AND y<20) - // Seek key: start from the last GT(x:10, y:NULL), e.g. (x=10, y=0) so reversed -> LT(x:10, NULL) - // Termination key: end at the first GE(x:10, y:20), e.g. (x=10, y=20) so reversed -> LE(x:10, y:20) - (IterationDirection::Backwards, ast::Operator::Less) => { - let (seek_key_len, termination_key_len, seek_op, termination_op) = - if sort_order_of_last_key == SortOrder::Asc { - ( - key_len, - key_len - 1, - SeekOp::LT, - SeekOp::LE { eq_only: false }, - ) - } else { - ( - key_len - 1, - key_len, - SeekOp::GT.reverse(), - SeekOp::GE { eq_only: false }.reverse(), - ) - }; - SeekDef { - key, - iter_dir, - seek: if seek_key_len > 0 { - Some(SeekKey { - len: seek_key_len, - op: seek_op, - null_pad: sort_order_of_last_key == SortOrder::Desc, - }) - } else { - None - }, - termination: if termination_key_len > 0 { - Some(TerminationKey { - len: termination_key_len, - op: termination_op, - null_pad: sort_order_of_last_key == SortOrder::Asc, - }) - } else { - None - }, - } - } - // Backwards, LE: - // Ascending index example: (x=10 AND y<=20) - // Seek key: start from the last LE(x:10, y:20), e.g. (x=10, y=20) - // Termination key: end at the first LT(x:10, NULL), e.g. (x=9, y=usize::MAX) - // - // Descending index example: (x=10 AND y<=20) - // Seek key: start from the last GT(x:10, NULL), e.g. (x=10, y=0) so reversed -> LT(x:10, NULL) - // Termination key: end at the first GT(x:10, y:20), e.g. (x=10, y=21) so reversed -> LT(x:10, y:20) - (IterationDirection::Backwards, ast::Operator::LessEquals) => { - let (seek_key_len, termination_key_len, seek_op, termination_op) = - if sort_order_of_last_key == SortOrder::Asc { - ( - key_len, - key_len - 1, - SeekOp::LE { eq_only: false }, - SeekOp::LE { eq_only: false }, - ) - } else { - ( - key_len - 1, - key_len, - SeekOp::GT.reverse(), - SeekOp::GT.reverse(), - ) - }; - SeekDef { - key, - iter_dir, - seek: if seek_key_len > 0 { - Some(SeekKey { - len: seek_key_len, - op: seek_op, - null_pad: sort_order_of_last_key == SortOrder::Desc, - }) - } else { - None - }, - termination: if termination_key_len > 0 { - Some(TerminationKey { - len: termination_key_len, - op: termination_op, - null_pad: sort_order_of_last_key == SortOrder::Asc, - }) - } else { - None - }, - } - } - // Backwards, GT: - // Ascending index example: (x=10 AND y>20) - // Seek key: start from the last LE(x:10), e.g. (x=10, y=usize::MAX) - // Termination key: end at the first LE(x:10, y:20), e.g. (x=10, y=20) - // - // Descending index example: (x=10 AND y>20) - // Seek key: start from the last GT(x:10, y:20), e.g. (x=10, y=21) so reversed -> LT(x:10, y:20) - // Termination key: end at the first GT(x:10), e.g. (x=11, y=0) so reversed -> LT(x:10) - (IterationDirection::Backwards, ast::Operator::Greater) => { - let (seek_key_len, termination_key_len, seek_op, termination_op) = - if sort_order_of_last_key == SortOrder::Asc { - ( - key_len - 1, - key_len, - SeekOp::LE { eq_only: false }, - SeekOp::LE { eq_only: false }, - ) - } else { - ( - key_len, - key_len - 1, - SeekOp::GT.reverse(), - SeekOp::GT.reverse(), - ) - }; - SeekDef { - key, - iter_dir, - seek: if seek_key_len > 0 { - Some(SeekKey { - len: seek_key_len, - op: seek_op, - null_pad: false, - }) - } else { - None - }, - termination: if termination_key_len > 0 { - Some(TerminationKey { - len: termination_key_len, - op: termination_op, - null_pad: false, - }) - } else { - None - }, - } - } - // Backwards, GE: - // Ascending index example: (x=10 AND y>=20) - // Seek key: start from the last LE(x:10), e.g. (x=10, y=usize::MAX) - // Termination key: end at the first LT(x:10, y:20), e.g. (x=10, y=19) - // - // Descending index example: (x=10 AND y>=20) - // Seek key: start from the last GE(x:10, y:20), e.g. (x=10, y=20) so reversed -> LE(x:10, y:20) - // Termination key: end at the first GT(x:10), e.g. (x=11, y=0) so reversed -> LT(x:10) - (IterationDirection::Backwards, ast::Operator::GreaterEquals) => { - let (seek_key_len, termination_key_len, seek_op, termination_op) = - if sort_order_of_last_key == SortOrder::Asc { - ( - key_len - 1, - key_len, - SeekOp::LE { eq_only: false }, - SeekOp::LT, - ) - } else { - ( - key_len, - key_len - 1, - SeekOp::GE { eq_only: false }.reverse(), - SeekOp::GT.reverse(), - ) - }; - SeekDef { - key, - iter_dir, - seek: if seek_key_len > 0 { - Some(SeekKey { - len: seek_key_len, - op: seek_op, - null_pad: false, - }) - } else { - None - }, - termination: if termination_key_len > 0 { - Some(TerminationKey { - len: termination_key_len, - op: termination_op, - null_pad: false, - }) - } else { - None - }, - } - } - (_, op) => { - crate::bail_parse_error!("build_seek_def: invalid operator: {:?}", op,) - } }) } diff --git a/core/translate/order_by.rs b/core/translate/order_by.rs index aafb32a21..2dd7a3c8a 100644 --- a/core/translate/order_by.rs +++ b/core/translate/order_by.rs @@ -1,8 +1,10 @@ +use std::sync::Arc; + use turso_parser::ast::{self, SortOrder}; use crate::{ emit_explain, - schema::PseudoCursorType, + schema::{Index, IndexColumn, PseudoCursorType}, translate::{ collate::{get_collseq_from_expr, CollationSeq}, group_by::is_orderby_agg_or_const, @@ -11,7 +13,7 @@ use crate::{ util::exprs_are_equivalent, vdbe::{ builder::{CursorType, ProgramBuilder}, - insn::Insn, + insn::{IdxInsertFlags, Insn}, }, QueryMode, Result, }; @@ -39,9 +41,12 @@ pub struct SortMetadata { /// aggregates/constants, so that rows that tie on ORDER BY terms are output in /// the same relative order the underlying row stream produced them. pub has_sequence: bool, + /// Whether to use heap-sort with BTreeIndex instead of full-collection sort through Sorter + pub use_heap_sort: bool, } /// Initialize resources needed for ORDER BY processing +#[allow(clippy::too_many_arguments)] pub fn init_order_by( program: &mut ProgramBuilder, t_ctx: &mut TranslateCtx, @@ -49,54 +54,113 @@ pub fn init_order_by( order_by: &[(Box, SortOrder)], referenced_tables: &TableReferences, has_group_by: bool, + has_distinct: bool, aggregates: &[Aggregate], ) -> Result<()> { - let sort_cursor = program.alloc_cursor_id(CursorType::Sorter); let only_aggs = order_by .iter() .all(|(e, _)| is_orderby_agg_or_const(&t_ctx.resolver, e, aggregates)); - // only emit sequence column if we have GROUP BY and ORDER BY is not only aggregates or constants - let has_sequence = has_group_by && !only_aggs; + let use_heap_sort = !has_distinct && !has_group_by && t_ctx.limit_ctx.is_some(); + + // only emit sequence column if (we have GROUP BY and ORDER BY is not only aggregates or constants) OR (we decided to use heap-sort) + let has_sequence = (has_group_by && !only_aggs) || use_heap_sort; + + let remappings = order_by_deduplicate_result_columns(order_by, result_columns, has_sequence); + let sort_cursor = if use_heap_sort { + let index_name = format!("heap_sort_{}", program.offset().as_offset_int()); // we don't really care about the name that much, just enough that we don't get name collisions + let mut index_columns = Vec::with_capacity(order_by.len() + result_columns.len()); + for (column, order) in order_by { + let collation = get_collseq_from_expr(column, referenced_tables)?; + let pos_in_table = index_columns.len(); + index_columns.push(IndexColumn { + name: pos_in_table.to_string(), + order: *order, + pos_in_table, + collation, + default: None, + }) + } + let pos_in_table = index_columns.len(); + // add sequence number between ORDER BY columns and result column + index_columns.push(IndexColumn { + name: pos_in_table.to_string(), + order: SortOrder::Asc, + pos_in_table, + collation: Some(CollationSeq::Binary), + default: None, + }); + for _ in remappings.iter().filter(|r| !r.deduplicated) { + let pos_in_table = index_columns.len(); + index_columns.push(IndexColumn { + name: pos_in_table.to_string(), + order: SortOrder::Asc, + pos_in_table, + collation: None, + default: None, + }) + } + let index = Arc::new(Index { + name: index_name.clone(), + table_name: String::new(), + ephemeral: true, + root_page: 0, + columns: index_columns, + unique: false, + has_rowid: false, + where_clause: None, + }); + program.alloc_cursor_id(CursorType::BTreeIndex(index)) + } else { + program.alloc_cursor_id(CursorType::Sorter) + }; t_ctx.meta_sort = Some(SortMetadata { sort_cursor, reg_sorter_data: program.alloc_register(), - remappings: order_by_deduplicate_result_columns(order_by, result_columns, has_sequence), + remappings, has_sequence, + use_heap_sort, }); - /* - * Terms of the ORDER BY clause that is part of a SELECT statement may be assigned a collating sequence using the COLLATE operator, - * in which case the specified collating function is used for sorting. - * Otherwise, if the expression sorted by an ORDER BY clause is a column, - * then the collating sequence of the column is used to determine sort order. - * If the expression is not a column and has no COLLATE clause, then the BINARY collating sequence is used. - */ - let mut collations = order_by - .iter() - .map(|(expr, _)| get_collseq_from_expr(expr, referenced_tables)) - .collect::>>()?; + if use_heap_sort { + program.emit_insn(Insn::OpenEphemeral { + cursor_id: sort_cursor, + is_table: false, + }); + } else { + /* + * Terms of the ORDER BY clause that is part of a SELECT statement may be assigned a collating sequence using the COLLATE operator, + * in which case the specified collating function is used for sorting. + * Otherwise, if the expression sorted by an ORDER BY clause is a column, + * then the collating sequence of the column is used to determine sort order. + * If the expression is not a column and has no COLLATE clause, then the BINARY collating sequence is used. + */ + let mut collations = order_by + .iter() + .map(|(expr, _)| get_collseq_from_expr(expr, referenced_tables)) + .collect::>>()?; - if has_sequence { - // sequence column uses BINARY collation - collations.push(Some(CollationSeq::default())); + if has_sequence { + // sequence column uses BINARY collation + collations.push(Some(CollationSeq::default())); + } + + let key_len = order_by.len() + if has_sequence { 1 } else { 0 }; + + program.emit_insn(Insn::SorterOpen { + cursor_id: sort_cursor, + columns: key_len, + order: { + let mut ord: Vec = order_by.iter().map(|(_, d)| *d).collect(); + if has_sequence { + // sequence is ascending tiebreaker + ord.push(SortOrder::Asc); + } + ord + }, + collations, + }); } - - let key_len = order_by.len() + if has_sequence { 1 } else { 0 }; - - program.emit_insn(Insn::SorterOpen { - cursor_id: sort_cursor, - columns: key_len, - order: { - let mut ord: Vec = order_by.iter().map(|(_, d)| *d).collect(); - if has_sequence { - // sequence is ascending tiebreaker - ord.push(SortOrder::Asc); - } - ord - }, - collations, - }); Ok(()) } @@ -118,6 +182,7 @@ pub fn emit_order_by( reg_sorter_data, ref remappings, has_sequence, + use_heap_sort, } = *t_ctx.meta_sort.as_ref().unwrap(); let sorter_column_count = order_by.len() @@ -128,33 +193,44 @@ pub fn emit_order_by( // to emit correct explain output. emit_explain!(program, false, "USE TEMP B-TREE FOR ORDER BY".to_owned()); - let pseudo_cursor = program.alloc_cursor_id(CursorType::Pseudo(PseudoCursorType { - column_count: sorter_column_count, - })); + let cursor_id = if !use_heap_sort { + let pseudo_cursor = program.alloc_cursor_id(CursorType::Pseudo(PseudoCursorType { + column_count: sorter_column_count, + })); - program.emit_insn(Insn::OpenPseudo { - cursor_id: pseudo_cursor, - content_reg: reg_sorter_data, - num_fields: sorter_column_count, - }); + program.emit_insn(Insn::OpenPseudo { + cursor_id: pseudo_cursor, + content_reg: reg_sorter_data, + num_fields: sorter_column_count, + }); + + program.emit_insn(Insn::SorterSort { + cursor_id: sort_cursor, + pc_if_empty: sort_loop_end_label, + }); + pseudo_cursor + } else { + program.emit_insn(Insn::Rewind { + cursor_id: sort_cursor, + pc_if_empty: sort_loop_end_label, + }); + sort_cursor + }; - program.emit_insn(Insn::SorterSort { - cursor_id: sort_cursor, - pc_if_empty: sort_loop_end_label, - }); program.preassign_label_to_next_insn(sort_loop_start_label); emit_offset(program, sort_loop_next_label, t_ctx.reg_offset); - program.emit_insn(Insn::SorterData { - cursor_id: sort_cursor, - dest_reg: reg_sorter_data, - pseudo_cursor, - }); + if !use_heap_sort { + program.emit_insn(Insn::SorterData { + cursor_id: sort_cursor, + dest_reg: reg_sorter_data, + pseudo_cursor: cursor_id, + }); + } // We emit the columns in SELECT order, not sorter order (sorter always has the sort keys first). // This is tracked in sort_metadata.remappings. - let cursor_id = pseudo_cursor; let start_reg = t_ctx.reg_result_cols_start.unwrap(); for i in 0..result_columns.len() { let reg = start_reg + i; @@ -171,14 +247,25 @@ pub fn emit_order_by( plan, start_reg, t_ctx.limit_ctx, - Some(sort_loop_end_label), + if !use_heap_sort { + Some(sort_loop_end_label) + } else { + None + }, )?; program.resolve_label(sort_loop_next_label, program.offset()); - program.emit_insn(Insn::SorterNext { - cursor_id: sort_cursor, - pc_if_next: sort_loop_start_label, - }); + if !use_heap_sort { + program.emit_insn(Insn::SorterNext { + cursor_id: sort_cursor, + pc_if_next: sort_loop_start_label, + }); + } else { + program.emit_insn(Insn::Next { + cursor_id: sort_cursor, + pc_if_next: sort_loop_start_label, + }); + } program.preassign_label_to_next_insn(sort_loop_end_label); Ok(()) @@ -237,6 +324,46 @@ pub fn order_by_sorter_insert( )?; } } + + let SortMetadata { + sort_cursor, + reg_sorter_data, + use_heap_sort, + .. + } = sort_metadata; + + let skip_label = if *use_heap_sort { + // skip records which greater than current top-k maintained in a separate BTreeIndex + let insert_label = program.allocate_label(); + let skip_label = program.allocate_label(); + let limit = t_ctx.limit_ctx.as_ref().expect("limit must be set"); + let limit_reg = t_ctx.reg_limit_offset_sum.unwrap_or(limit.reg_limit); + program.emit_insn(Insn::IfPos { + reg: limit_reg, + target_pc: insert_label, + decrement_by: 1, + }); + program.emit_insn(Insn::Last { + cursor_id: *sort_cursor, + pc_if_empty: insert_label, + }); + program.emit_insn(Insn::IdxLE { + cursor_id: *sort_cursor, + start_reg, + num_regs: orderby_sorter_column_count, + target_pc: skip_label, + }); + program.emit_insn(Insn::Delete { + cursor_id: *sort_cursor, + table_name: "".to_string(), + is_part_of_update: false, + }); + program.preassign_label_to_next_insn(insert_label); + Some(skip_label) + } else { + None + }; + let mut cur_reg = start_reg + order_by_len; if sort_metadata.has_sequence { program.emit_insn(Insn::Sequence { @@ -330,19 +457,31 @@ pub fn order_by_sorter_insert( } } - let SortMetadata { - sort_cursor, - reg_sorter_data, - .. - } = sort_metadata; - - sorter_insert( - program, - start_reg, - orderby_sorter_column_count, - *sort_cursor, - *reg_sorter_data, - ); + if *use_heap_sort { + program.emit_insn(Insn::MakeRecord { + start_reg, + count: orderby_sorter_column_count, + dest_reg: *reg_sorter_data, + index_name: None, + affinity_str: None, + }); + program.emit_insn(Insn::IdxInsert { + cursor_id: *sort_cursor, + record_reg: *reg_sorter_data, + unpacked_start: None, + unpacked_count: None, + flags: IdxInsertFlags::new(), + }); + program.preassign_label_to_next_insn(skip_label.unwrap()); + } else { + sorter_insert( + program, + start_reg, + orderby_sorter_column_count, + *sort_cursor, + *reg_sorter_data, + ); + } Ok(()) } diff --git a/core/translate/plan.rs b/core/translate/plan.rs index ec556f3f9..aa506ca76 100644 --- a/core/translate/plan.rs +++ b/core/translate/plan.rs @@ -4,7 +4,10 @@ use turso_parser::ast::{self, FrameBound, FrameClause, FrameExclude, FrameMode, use crate::{ function::AggFunc, schema::{BTreeTable, Column, FromClauseSubquery, Index, Schema, Table}, - translate::collate::get_collseq_from_expr, + translate::{ + collate::get_collseq_from_expr, emitter::UpdateRowSource, + optimizer::constraints::SeekRangeConstraint, + }, vdbe::{ builder::{CursorKey, CursorType, ProgramBuilder}, insn::{IdxInsertFlags, Insn}, @@ -440,7 +443,10 @@ pub struct UpdatePlan { // whether the WHERE clause is always false pub contains_constant_false_condition: bool, pub indexes_to_update: Vec>, - // If the table's rowid alias is used, gather all the target rowids into an ephemeral table, and then use that table as the single JoinedTable for the actual UPDATE loop. + // If the UPDATE modifies any column that is present in the key of the btree used to iterate over the table (either the table itself or an index), + // gather all the target rowids into an ephemeral table, and then use that table as the single JoinedTable for the actual UPDATE loop. + // This ensures the keys of the btree used to iterate cannot be changed during the UPDATE loop itself, ensuring all the intended rows actually get + // updated and none are skipped. pub ephemeral_plan: Option, // For ALTER TABLE turso-db emits appropriate DDL statement in the "updates" cell of CDC table // This field is present only for update plan created for ALTER TABLE when CDC mode has "updates" values @@ -583,6 +589,11 @@ pub struct TableReferences { } impl TableReferences { + /// The maximum number of tables that can be joined together in a query. + /// This limit is arbitrary, although we currently use a u128 to represent the [crate::translate::planner::TableMask], + /// which can represent up to 128 tables. + /// Even at 63 tables we currently cannot handle the optimization performantly, hence the arbitrary cap. + pub const MAX_JOINED_TABLES: usize = 63; pub fn new( joined_tables: Vec, outer_query_refs: Vec, @@ -608,6 +619,11 @@ impl TableReferences { self.joined_tables.push(joined_table); } + /// Add a new [OuterQueryReference] to the query plan. + pub fn add_outer_query_reference(&mut self, outer_query_reference: OuterQueryReference) { + self.outer_query_refs.push(outer_query_reference); + } + /// Returns an immutable reference to the [JoinedTable]s in the query plan. pub fn joined_tables(&self) -> &[JoinedTable] { &self.joined_tables @@ -752,33 +768,25 @@ impl TableReferences { } } -#[derive(Clone, Debug, Default, PartialEq, Eq)] +#[derive(Clone, Debug, Default, PartialEq)] #[repr(transparent)] -pub struct ColumnUsedMask(u128); +pub struct ColumnUsedMask(roaring::RoaringBitmap); impl ColumnUsedMask { pub fn set(&mut self, index: usize) { - assert!( - index < 128, - "ColumnUsedMask only supports up to 128 columns" - ); - self.0 |= 1 << index; + self.0.insert(index as u32); } pub fn get(&self, index: usize) -> bool { - assert!( - index < 128, - "ColumnUsedMask only supports up to 128 columns" - ); - self.0 & (1 << index) != 0 + self.0.contains(index as u32) } pub fn contains_all_set_bits_of(&self, other: &Self) -> bool { - self.0 & other.0 == other.0 + other.0.is_subset(&self.0) } pub fn is_empty(&self) -> bool { - self.0 == 0 + self.0.is_empty() } } @@ -905,10 +913,28 @@ impl JoinedTable { Table::BTree(btree) => { let use_covering_index = self.utilizes_covering_index(); let index_is_ephemeral = index.is_some_and(|index| index.ephemeral); - let table_not_required = - OperationMode::SELECT == mode && use_covering_index && !index_is_ephemeral; + let table_not_required = matches!(mode, OperationMode::SELECT) + && use_covering_index + && !index_is_ephemeral; let table_cursor_id = if table_not_required { None + } else if let OperationMode::UPDATE(UpdateRowSource::PrebuiltEphemeralTable { + target_table, + .. + }) = &mode + { + // The cursor for the ephemeral table was already allocated earlier. Let's allocate one for the target table, + // in case it wasn't already allocated when populating the ephemeral table. + Some(program.alloc_cursor_id_keyed_if_not_exists( + CursorKey::table(target_table.internal_id), + match &target_table.table { + Table::BTree(btree) => CursorType::BTreeTable(btree.clone()), + Table::Virtual(virtual_table) => { + CursorType::VirtualTable(virtual_table.clone()) + } + _ => unreachable!("target table must be a btree or virtual table"), + }, + )) } else { // Check if this is a materialized view let cursor_type = @@ -922,6 +948,7 @@ impl JoinedTable { .alloc_cursor_id_keyed(CursorKey::table(self.internal_id), cursor_type), ) }; + let index_cursor_id = index.map(|index| { program.alloc_cursor_id_keyed( CursorKey::index(self.internal_id, index.clone()), @@ -946,9 +973,19 @@ impl JoinedTable { pub fn resolve_cursors( &self, program: &mut ProgramBuilder, + mode: OperationMode, ) -> Result<(Option, Option)> { let index = self.op.index(); - let table_cursor_id = program.resolve_cursor_id_safe(&CursorKey::table(self.internal_id)); + let table_cursor_id = + if let OperationMode::UPDATE(UpdateRowSource::PrebuiltEphemeralTable { + target_table, + .. + }) = &mode + { + program.resolve_cursor_id_safe(&CursorKey::table(target_table.internal_id)) + } else { + program.resolve_cursor_id_safe(&CursorKey::table(self.internal_id)) + }; let index_cursor_id = index.map(|index| { program.resolve_cursor_id(&CursorKey::index(self.internal_id, index.clone())) }); @@ -1004,54 +1041,91 @@ impl JoinedTable { /// A definition of a rowid/index search. /// /// [SeekKey] is the condition that is used to seek to a specific row in a table/index. -/// [TerminationKey] is the condition that is used to terminate the search after a seek. +/// [SeekKey] also used to represent range scan termination condition. #[derive(Debug, Clone)] pub struct SeekDef { - /// The key to use when seeking and when terminating the scan that follows the seek. + /// Common prefix of the key which is shared between start/end fields /// For example, given: /// - CREATE INDEX i ON t (x, y desc) /// - SELECT * FROM t WHERE x = 1 AND y >= 30 /// - /// The key is [(1, ASC), (30, DESC)] - pub key: Vec<(ast::Expr, SortOrder)>, + /// Then, prefix=[(eq=1, ASC)], start=Some((ge, Expr(30))), end=Some((gt, Sentinel)) + pub prefix: Vec, /// The condition to use when seeking. See [SeekKey] for more details. - pub seek: Option, - /// The condition to use when terminating the scan that follows the seek. See [TerminationKey] for more details. - pub termination: Option, + pub start: SeekKey, + /// The condition to use when terminating the scan that follows the seek. See [SeekKey] for more details. + pub end: SeekKey, /// The direction of the scan that follows the seek. pub iter_dir: IterationDirection, } +pub struct SeekDefKeyIterator<'a> { + seek_def: &'a SeekDef, + seek_key: &'a SeekKey, + pos: usize, +} + +impl<'a> Iterator for SeekDefKeyIterator<'a> { + type Item = SeekKeyComponent<&'a ast::Expr>; + + fn next(&mut self) -> Option { + let result = if self.pos < self.seek_def.prefix.len() { + Some(SeekKeyComponent::Expr( + &self.seek_def.prefix[self.pos].eq.as_ref().unwrap().1, + )) + } else if self.pos == self.seek_def.prefix.len() { + match &self.seek_key.last_component { + SeekKeyComponent::Expr(expr) => Some(SeekKeyComponent::Expr(expr)), + SeekKeyComponent::None => None, + } + } else { + None + }; + self.pos += 1; + result + } +} + +impl SeekDef { + /// returns amount of values in the given seek key + /// - so, for SELECT * FROM t WHERE x = 10 AND y = 20 AND y >= 30 there will be 3 values (10, 20, 30) + pub fn size(&self, key: &SeekKey) -> usize { + self.prefix.len() + + match key.last_component { + SeekKeyComponent::Expr(_) => 1, + SeekKeyComponent::None => 0, + } + } + /// iterate over value expressions in the given seek key + pub fn iter<'a>(&'a self, key: &'a SeekKey) -> SeekDefKeyIterator<'a> { + SeekDefKeyIterator { + seek_def: self, + seek_key: key, + pos: 0, + } + } +} + +/// [SeekKeyComponent] enum represents optional last_component of the [SeekKey] +/// +/// This component represented by separate enum instead of Option because before there were third Sentinel value +/// For now - we don't need this and it's enough to just either use some user-provided expression or omit last component of the key completely +/// But as separate enum is almost never a harm - I decided to keep it here. +/// +/// This enum accepts generic argument E in order to use both SeekKeyComponent and SeekKeyComponent<&ast::Expr> +#[derive(Debug, Clone)] +pub enum SeekKeyComponent { + Expr(E), + None, +} + /// A condition to use when seeking. #[derive(Debug, Clone)] pub struct SeekKey { - /// How many columns from [SeekDef::key] are used in seeking. - pub len: usize, - /// Whether to NULL pad the last column of the seek key to match the length of [SeekDef::key]. - /// The reason it is done is that sometimes our full index key is not used in seeking, - /// but we want to find the lowest value that matches the non-null prefix of the key. - /// For example, given: - /// - CREATE INDEX i ON t (x, y) - /// - SELECT * FROM t WHERE x = 1 AND y < 30 - /// - /// We want to seek to the first row where x = 1, and then iterate forwards. - /// In this case, the seek key is GT(1, NULL) since NULL is always LT in index key comparisons. - /// We can't use just GT(1) because in index key comparisons, only the given number of columns are compared, - /// so this means any index keys with (x=1) will compare equal, e.g. (x=1, y=usize::MAX) will compare equal to the seek key (x:1) - pub null_pad: bool, - /// The comparison operator to use when seeking. - pub op: SeekOp, -} + /// Complete key must be constructed from common [SeekDef::prefix] and optional last_component + pub last_component: SeekKeyComponent, -#[derive(Debug, Clone)] -/// A condition to use when terminating the scan that follows a seek. -pub struct TerminationKey { - /// How many columns from [SeekDef::key] are used in terminating the scan that follows the seek. - pub len: usize, - /// Whether to NULL pad the last column of the termination key to match the length of [SeekDef::key]. - /// See [SeekKey::null_pad]. - pub null_pad: bool, - /// The comparison operator to use when terminating the scan that follows the seek. + /// The comparison operator to use when seeking. pub op: SeekOp, } @@ -1224,3 +1298,127 @@ pub struct WindowFunction { /// The expression from which the function was resolved. pub original_expr: Expr, } + +#[cfg(test)] +mod tests { + use super::*; + use rand_chacha::{ + rand_core::{RngCore, SeedableRng}, + ChaCha8Rng, + }; + + #[test] + fn test_column_used_mask_empty() { + let mask = ColumnUsedMask::default(); + assert!(mask.is_empty()); + + let mut mask2 = ColumnUsedMask::default(); + mask2.set(0); + assert!(!mask2.is_empty()); + } + + #[test] + fn test_column_used_mask_set_and_get() { + let mut mask = ColumnUsedMask::default(); + + let max_columns = 10000; + let mut set_indices = Vec::new(); + let mut rng = ChaCha8Rng::seed_from_u64( + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(), + ); + + for i in 0..max_columns { + if rng.next_u32() % 3 == 0 { + set_indices.push(i); + mask.set(i); + } + } + + // Verify set bits are present + for &i in &set_indices { + assert!(mask.get(i), "Expected bit {i} to be set"); + } + + // Verify unset bits are not present + for i in 0..max_columns { + if !set_indices.contains(&i) { + assert!(!mask.get(i), "Expected bit {i} to not be set"); + } + } + } + + #[test] + fn test_column_used_mask_subset_relationship() { + let mut full_mask = ColumnUsedMask::default(); + let mut subset_mask = ColumnUsedMask::default(); + + let max_columns = 5000; + let mut rng = ChaCha8Rng::seed_from_u64( + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(), + ); + + // Create a pattern where subset has fewer bits + for i in 0..max_columns { + if rng.next_u32() % 5 == 0 { + full_mask.set(i); + if i % 2 == 0 { + subset_mask.set(i); + } + } + } + + // full_mask contains all bits of subset_mask + assert!(full_mask.contains_all_set_bits_of(&subset_mask)); + + // subset_mask does not contain all bits of full_mask + assert!(!subset_mask.contains_all_set_bits_of(&full_mask)); + + // A mask contains itself + assert!(full_mask.contains_all_set_bits_of(&full_mask)); + assert!(subset_mask.contains_all_set_bits_of(&subset_mask)); + } + + #[test] + fn test_column_used_mask_empty_subset() { + let mut mask = ColumnUsedMask::default(); + for i in (0..1000).step_by(7) { + mask.set(i); + } + + let empty_mask = ColumnUsedMask::default(); + + // Empty mask is subset of everything + assert!(mask.contains_all_set_bits_of(&empty_mask)); + assert!(empty_mask.contains_all_set_bits_of(&empty_mask)); + } + + #[test] + fn test_column_used_mask_sparse_indices() { + let mut sparse_mask = ColumnUsedMask::default(); + + // Test with very sparse, large indices + let sparse_indices = vec![0, 137, 1042, 5389, 10000, 50000, 100000, 500000, 1000000]; + + for &idx in &sparse_indices { + sparse_mask.set(idx); + } + + for &idx in &sparse_indices { + assert!(sparse_mask.get(idx), "Expected bit {idx} to be set"); + } + + // Check some indices that shouldn't be set + let unset_indices = vec![1, 100, 1000, 5000, 25000, 75000, 250000, 750000]; + for &idx in &unset_indices { + assert!(!sparse_mask.get(idx), "Expected bit {idx} to not be set"); + } + + assert!(!sparse_mask.is_empty()); + } +} diff --git a/core/translate/planner.rs b/core/translate/planner.rs index 589b45f3f..f7d930bd6 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -180,7 +180,26 @@ pub fn resolve_window_and_aggregate_functions( name.as_str() ); } - crate::bail_parse_error!("Invalid aggregate function: {}", name.as_str()); + + // Check if the function supports (*) syntax using centralized logic + match crate::function::Func::resolve_function(name.as_str(), 0) { + Ok(func) => { + if func.supports_star_syntax() { + return Ok(WalkControl::Continue); + } else { + crate::bail_parse_error!( + "wrong number of arguments to function {}()", + name.as_str() + ); + } + } + Err(_) => { + crate::bail_parse_error!( + "wrong number of arguments to function {}()", + name.as_str() + ); + } + } } Err(e) => match e { crate::LimboError::ParseError(e) => { @@ -298,11 +317,26 @@ fn parse_from_clause_table( ) } ast::SelectTable::Select(subselect, maybe_alias) => { + let outer_query_refs_for_subquery = table_references + .outer_query_refs() + .iter() + .cloned() + .chain( + ctes.iter() + .cloned() + .map(|t: JoinedTable| OuterQueryReference { + identifier: t.identifier, + internal_id: t.internal_id, + table: t.table, + col_used_mask: ColumnUsedMask::default(), + }), + ) + .collect::>(); let Plan::Select(subplan) = prepare_select_plan( subselect, resolver, program, - table_references.outer_query_refs(), + &outer_query_refs_for_subquery, QueryDestination::placeholder_for_subquery(), connection, )? @@ -478,6 +512,7 @@ fn parse_table( has_autoincrement: false, unique_sets: vec![], + foreign_keys: vec![], }); drop(view_guard); @@ -951,9 +986,29 @@ fn parse_join( crate::bail_parse_error!("NATURAL JOIN cannot be combined with ON or USING clause"); } + // this is called once for each join, so we only need to check the rightmost table + // against all previous tables for duplicates + let rightmost_table = table_references.joined_tables().last().unwrap(); + let has_duplicate = table_references + .joined_tables() + .iter() + .take(table_references.joined_tables().len() - 1) + .any(|t| t.identifier == rightmost_table.identifier); + + if has_duplicate + && !natural + && constraint + .as_ref() + .is_none_or(|c| !matches!(c, ast::JoinConstraint::Using(_))) + { + // Duplicate table names are only allowed for NATURAL or USING joins + crate::bail_parse_error!( + "table name {} specified more than once - use an alias to disambiguate", + rightmost_table.identifier + ); + } let constraint = if natural { assert!(table_references.joined_tables().len() >= 2); - let rightmost_table = table_references.joined_tables().last().unwrap(); // NATURAL JOIN is first transformed into a USING join with the common columns let mut distinct_names: Vec = vec![]; // TODO: O(n^2) maybe not great for large tables or big multiway joins diff --git a/core/translate/pragma.rs b/core/translate/pragma.rs index d8b26143a..8e996bae9 100644 --- a/core/translate/pragma.rs +++ b/core/translate/pragma.rs @@ -95,6 +95,20 @@ fn update_pragma( connection: Arc, mut program: ProgramBuilder, ) -> crate::Result<(ProgramBuilder, TransactionMode)> { + let parse_pragma_enabled = |expr: &ast::Expr| -> bool { + if let Expr::Literal(Literal::Numeric(n)) = expr { + return !matches!(n.as_str(), "0"); + }; + let name_bytes = match expr { + Expr::Literal(Literal::Keyword(name)) => name.as_bytes(), + Expr::Name(name) | Expr::Id(name) => name.as_str().as_bytes(), + _ => "".as_bytes(), + }; + match_ignore_ascii_case!(match name_bytes { + b"ON" | b"TRUE" | b"YES" | b"1" => true, + _ => false, + }) + }; match pragma { PragmaName::ApplicationId => { let data = parse_signed_number(&value)?; @@ -343,39 +357,32 @@ fn update_pragma( } PragmaName::Synchronous => { use crate::SyncMode; - - let mode = match value { - Expr::Name(name) => { - let name_bytes = name.as_str().as_bytes(); - match_ignore_ascii_case!(match name_bytes { - b"OFF" | b"FALSE" | b"NO" | b"0" => SyncMode::Off, - _ => SyncMode::Full, - }) - } - Expr::Literal(Literal::Numeric(n)) => match n.as_str() { - "0" => SyncMode::Off, - _ => SyncMode::Full, - }, - _ => SyncMode::Full, + let mode = match parse_pragma_enabled(&value) { + true => SyncMode::Full, + false => SyncMode::Off, }; - connection.set_sync_mode(mode); Ok((program, TransactionMode::None)) } PragmaName::DataSyncRetry => { - let retry_enabled = match value { - Expr::Name(name) => { - let name_bytes = name.as_str().as_bytes(); - match_ignore_ascii_case!(match name_bytes { - b"ON" | b"TRUE" | b"YES" | b"1" => true, - _ => false, - }) - } - Expr::Literal(Literal::Numeric(n)) => !matches!(n.as_str(), "0"), - _ => false, + let retry_enabled = parse_pragma_enabled(&value); + connection.set_data_sync_retry(retry_enabled); + Ok((program, TransactionMode::None)) + } + PragmaName::MvccCheckpointThreshold => { + let threshold = match parse_signed_number(&value)? { + Value::Integer(size) if size >= -1 => size, + _ => bail_parse_error!( + "mvcc_checkpoint_threshold must be -1, 0, or a positive integer" + ), }; - connection.set_data_sync_retry(retry_enabled); + connection.set_mvcc_checkpoint_threshold(threshold)?; + Ok((program, TransactionMode::None)) + } + PragmaName::ForeignKeys => { + let enabled = parse_pragma_enabled(&value); + connection.set_foreign_keys_enabled(enabled); Ok((program, TransactionMode::None)) } } @@ -687,6 +694,22 @@ fn query_pragma( program.add_pragma_result_column(pragma.to_string()); Ok((program, TransactionMode::None)) } + PragmaName::MvccCheckpointThreshold => { + let threshold = connection.mvcc_checkpoint_threshold()?; + let register = program.alloc_register(); + program.emit_int(threshold, register); + program.emit_result_row(register, 1); + program.add_pragma_result_column(pragma.to_string()); + Ok((program, TransactionMode::None)) + } + PragmaName::ForeignKeys => { + let enabled = connection.foreign_keys_enabled(); + let register = program.alloc_register(); + program.emit_int(enabled as i64, register); + program.emit_result_row(register, 1); + program.add_pragma_result_column(pragma.to_string()); + Ok((program, TransactionMode::None)) + } } } diff --git a/core/translate/result_row.rs b/core/translate/result_row.rs index c087a0abf..e29ea3531 100644 --- a/core/translate/result_row.rs +++ b/core/translate/result_row.rs @@ -1,5 +1,3 @@ -use turso_parser::ast::{Expr, Literal, Operator, UnaryOperator}; - use crate::{ vdbe::{ builder::ProgramBuilder, @@ -133,7 +131,9 @@ pub fn emit_result_row_and_limit( key_reg: result_columns_start_reg + (plan.result_columns.len() - 1), // Rowid reg is the last register record_reg, // since we are not doing an Insn::NewRowid or an Insn::NotExists here, we need to seek to ensure the insertion happens in the correct place. - flag: InsertFlags::new().require_seek(), + flag: InsertFlags::new() + .require_seek() + .is_ephemeral_table_insert(), table_name: table.name.clone(), }); } @@ -172,36 +172,3 @@ pub fn emit_offset(program: &mut ProgramBuilder, jump_to: BranchOffset, reg_offs decrement_by: 1, }); } - -#[allow(clippy::borrowed_box)] -pub fn try_fold_expr_to_i64(expr: &Box) -> Option { - match expr.as_ref() { - Expr::Literal(Literal::Numeric(n)) => n.parse::().ok(), - Expr::Literal(Literal::Null) => Some(0), - Expr::Id(name) if !name.quoted() => { - let lowered = name.as_str(); - if lowered == "true" { - Some(1) - } else if lowered == "false" { - Some(0) - } else { - None - } - } - Expr::Unary(UnaryOperator::Negative, inner) => try_fold_expr_to_i64(inner).map(|v| -v), - Expr::Unary(UnaryOperator::Positive, inner) => try_fold_expr_to_i64(inner), - Expr::Binary(left, op, right) => { - let l = try_fold_expr_to_i64(left)?; - let r = try_fold_expr_to_i64(right)?; - match op { - Operator::Add => Some(l.saturating_add(r)), - Operator::Subtract => Some(l.saturating_sub(r)), - Operator::Multiply => Some(l.saturating_mul(r)), - Operator::Divide if r != 0 => Some(l.saturating_div(r)), - _ => None, - } - } - - _ => None, - } -} diff --git a/core/translate/schema.rs b/core/translate/schema.rs index ce85756ff..0442b5542 100644 --- a/core/translate/schema.rs +++ b/core/translate/schema.rs @@ -26,6 +26,55 @@ use crate::{bail_parse_error, Result}; use turso_ext::VTabKind; +fn validate(body: &ast::CreateTableBody, connection: &Connection) -> Result<()> { + if let ast::CreateTableBody::ColumnsAndConstraints { + options, columns, .. + } = &body + { + if options.contains(ast::TableOptions::STRICT) && !connection.experimental_strict_enabled() + { + bail_parse_error!( + "STRICT tables are an experimental feature. Enable them with --experimental-strict flag" + ); + } + for i in 0..columns.len() { + let col_i = &columns[i]; + for constraint in &col_i.constraints { + // don't silently ignore CHECK constraints, throw parse error for now + match constraint.constraint { + ast::ColumnConstraint::Check { .. } => { + bail_parse_error!("CHECK constraints are not supported yet"); + } + ast::ColumnConstraint::Generated { .. } => { + bail_parse_error!("GENERATED columns are not supported yet"); + } + ast::ColumnConstraint::NotNull { + conflict_clause, .. + } + | ast::ColumnConstraint::PrimaryKey { + conflict_clause, .. + } if conflict_clause.is_some() => { + bail_parse_error!( + "ON CONFLICT clauses are not supported yet in column definitions" + ); + } + _ => {} + } + } + for j in &columns[(i + 1)..] { + if col_i + .col_name + .as_str() + .eq_ignore_ascii_case(j.col_name.as_str()) + { + bail_parse_error!("duplicate column name: {}", j.col_name.as_str()); + } + } + } + } + Ok(()) +} + pub fn translate_create_table( tbl_name: ast::QualifiedName, resolver: &Resolver, @@ -39,16 +88,8 @@ pub fn translate_create_table( if temporary { bail_parse_error!("TEMPORARY table not supported yet"); } + validate(&body, connection)?; - // Check for STRICT mode without experimental flag - if let ast::CreateTableBody::ColumnsAndConstraints { options, .. } = &body { - if options.contains(ast::TableOptions::STRICT) && !connection.experimental_strict_enabled() - { - bail_parse_error!( - "STRICT tables are an experimental feature. Enable them with --experimental-strict flag" - ); - } - } let opts = ProgramBuilderOpts { num_cursors: 1, approx_num_insns: 30, @@ -630,7 +671,8 @@ pub fn translate_drop_table( let null_reg = program.alloc_register(); // r1 program.emit_null(null_reg, None); let table_name_and_root_page_register = program.alloc_register(); // r2, this register is special because it's first used to track table name and then moved root page - let table_reg = program.emit_string8_new_reg(tbl_name.name.as_str().to_string()); // r3 + let table_reg = + program.emit_string8_new_reg(normalize_ident(tbl_name.name.as_str()).to_string()); // r3 program.mark_last_insn_constant(); let table_type = program.emit_string8_new_reg("trigger".to_string()); // r4 program.mark_last_insn_constant(); @@ -728,6 +770,7 @@ pub fn translate_drop_table( program.emit_insn(Insn::Delete { cursor_id: sqlite_schema_cursor_id_0, table_name: SQLITE_TABLEID.to_string(), + is_part_of_update: false, }); program.resolve_label(next_label, program.offset()); @@ -812,6 +855,7 @@ pub fn translate_drop_table( }], is_strict: false, unique_sets: vec![], + foreign_keys: vec![], }); // cursor id 2 let ephemeral_cursor_id = program.alloc_cursor_id(CursorType::BTreeTable(simple_table_rc)); @@ -927,6 +971,7 @@ pub fn translate_drop_table( program.emit_insn(Insn::Delete { cursor_id: sqlite_schema_cursor_id_1, table_name: SQLITE_TABLEID.to_string(), + is_part_of_update: false, }); program.emit_insn(Insn::Insert { cursor: sqlite_schema_cursor_id_1, @@ -987,6 +1032,7 @@ pub fn translate_drop_table( program.emit_insn(Insn::Delete { cursor_id: seq_cursor_id, table_name: "sqlite_sequence".to_string(), + is_part_of_update: false, }); program.resolve_label(continue_loop_label, program.offset()); diff --git a/core/translate/select.rs b/core/translate/select.rs index 915c64242..59b6ff6cb 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -4,7 +4,7 @@ use super::plan::{ Search, TableReferences, WhereTerm, Window, }; use crate::schema::Table; -use crate::translate::emitter::Resolver; +use crate::translate::emitter::{OperationMode, Resolver}; use crate::translate::expr::{bind_and_rewrite_expr, BindingBehavior, ParamState}; use crate::translate::group_by::compute_group_by_sort_order; use crate::translate::optimizer::optimize_plan; @@ -43,7 +43,7 @@ pub fn translate_select( query_destination, connection, )?; - optimize_plan(&mut select_plan, resolver.schema)?; + optimize_plan(&mut program, &mut select_plan, resolver.schema)?; let num_result_cols; let opts = match &select_plan { Plan::Select(select) => { @@ -505,7 +505,7 @@ fn prepare_one_select_plan( // Return the unoptimized query plan Ok(plan) } - ast::OneSelect::Values(values) => { + ast::OneSelect::Values(mut values) => { if !order_by.is_empty() { crate::bail_parse_error!("ORDER BY clause is not allowed with VALUES clause"); } @@ -522,6 +522,21 @@ fn prepare_one_select_plan( contains_aggregates: false, }); } + + for value_row in values.iter_mut() { + for value in value_row.iter_mut() { + bind_and_rewrite_expr( + value, + None, + None, + connection, + &mut program.param_ctx, + // Allow sqlite quirk of inserting "double-quoted" literals (which our AST maps as identifiers) + BindingBehavior::AllowUnboundIdentifiers, + )?; + } + } + let plan = SelectPlan { join_order: vec![], table_references: TableReferences::new(vec![], vec![]), @@ -674,7 +689,7 @@ pub fn emit_simple_count( .joined_tables() .first() .unwrap() - .resolve_cursors(program)?; + .resolve_cursors(program, OperationMode::SELECT)?; let cursor_id = { match cursors { diff --git a/core/translate/update.rs b/core/translate/update.rs index f89ddedff..1aac4c745 100644 --- a/core/translate/update.rs +++ b/core/translate/update.rs @@ -1,15 +1,13 @@ use std::collections::{HashMap, HashSet}; use std::sync::Arc; -use crate::schema::{BTreeTable, Column, Type, ROWID_SENTINEL}; +use crate::schema::ROWID_SENTINEL; use crate::translate::emitter::Resolver; use crate::translate::expr::{ bind_and_rewrite_expr, walk_expr, BindingBehavior, ParamState, WalkControl, }; -use crate::translate::optimizer::optimize_select_plan; -use crate::translate::plan::{Operation, QueryDestination, Scan, Search, SelectPlan}; +use crate::translate::plan::{Operation, Scan}; use crate::translate::planner::{parse_limit, ROWID_STRS}; -use crate::vdbe::builder::CursorType; use crate::{ bail_parse_error, schema::{Schema, Table}, @@ -22,8 +20,7 @@ use super::emitter::emit_program; use super::expr::process_returning_clause; use super::optimizer::optimize_plan; use super::plan::{ - ColumnUsedMask, IterationDirection, JoinedTable, Plan, ResultSetColumn, TableReferences, - UpdatePlan, + ColumnUsedMask, IterationDirection, JoinedTable, Plan, TableReferences, UpdatePlan, }; use super::planner::parse_where; /* @@ -62,7 +59,7 @@ pub fn translate_update( connection: &Arc, ) -> crate::Result { let mut plan = prepare_update_plan(&mut program, resolver.schema, body, connection, false)?; - optimize_plan(&mut plan, resolver.schema)?; + optimize_plan(&mut program, &mut plan, resolver.schema)?; let opts = ProgramBuilderOpts { num_cursors: 1, approx_num_insns: 20, @@ -89,7 +86,7 @@ pub fn translate_update_for_schema_change( } } - optimize_plan(&mut plan, resolver.schema)?; + optimize_plan(&mut program, &mut plan, resolver.schema)?; let opts = ProgramBuilderOpts { num_cursors: 1, approx_num_insns: 20, @@ -185,7 +182,10 @@ pub fn prepare_update_plan( Table::BTree(btree_table) => Table::BTree(btree_table.clone()), _ => unreachable!(), }, - identifier: table_name.to_string(), + identifier: body.tbl_name.alias.as_ref().map_or_else( + || table_name.to_string(), + |alias| alias.as_str().to_string(), + ), internal_id: program.table_reference_counter.next(), op: build_scan_op(&table, iter_dir), join_info: None, @@ -298,119 +298,16 @@ pub fn prepare_update_plan( // https://github.com/sqlite/sqlite/blob/master/src/update.c#L395 // https://github.com/sqlite/sqlite/blob/master/src/update.c#L670 let columns = table.columns(); - - let rowid_alias_used = set_clauses.iter().fold(false, |accum, (idx, _)| { - accum || (*idx != ROWID_SENTINEL && columns[*idx].is_rowid_alias) - }); - let direct_rowid_update = set_clauses.iter().any(|(idx, _)| *idx == ROWID_SENTINEL); - - let (ephemeral_plan, mut where_clause) = if rowid_alias_used || direct_rowid_update { - let mut where_clause = vec![]; - let internal_id = program.table_reference_counter.next(); - - let joined_tables = vec![JoinedTable { - table: match table.as_ref() { - Table::Virtual(vtab) => Table::Virtual(vtab.clone()), - Table::BTree(btree_table) => Table::BTree(btree_table.clone()), - _ => unreachable!(), - }, - identifier: table_name.to_string(), - internal_id, - op: build_scan_op(&table, iter_dir), - join_info: None, - col_used_mask: ColumnUsedMask::default(), - database_id: 0, - }]; - let mut table_references = TableReferences::new(joined_tables, vec![]); - - // Parse the WHERE clause - parse_where( - body.where_clause.as_deref(), - &mut table_references, - Some(&result_columns), - &mut where_clause, - connection, - &mut program.param_ctx, - )?; - - let table = Arc::new(BTreeTable { - root_page: 0, // Not relevant for ephemeral table definition - name: "ephemeral_scratch".to_string(), - has_rowid: true, - has_autoincrement: false, - primary_key_columns: vec![], - columns: vec![Column { - name: Some("rowid".to_string()), - ty: Type::Integer, - ty_str: "INTEGER".to_string(), - primary_key: true, - is_rowid_alias: false, - notnull: true, - default: None, - unique: false, - collation: None, - hidden: false, - }], - is_strict: false, - unique_sets: vec![], - }); - - let temp_cursor_id = program.alloc_cursor_id(CursorType::BTreeTable(table.clone())); - - let mut ephemeral_plan = SelectPlan { - table_references, - result_columns: vec![ResultSetColumn { - expr: Expr::RowId { - database: None, - table: internal_id, - }, - alias: None, - contains_aggregates: false, - }], - where_clause, // original WHERE terms from the UPDATE clause - group_by: None, // N/A - order_by: vec![], // N/A - aggregates: vec![], // N/A - limit: None, // N/A - query_destination: QueryDestination::EphemeralTable { - cursor_id: temp_cursor_id, - table, - }, - join_order: vec![], - offset: None, - contains_constant_false_condition: false, - distinctness: super::plan::Distinctness::NonDistinct, - values: vec![], - window: None, - }; - - optimize_select_plan(&mut ephemeral_plan, schema)?; - let table = ephemeral_plan - .table_references - .joined_tables() - .first() - .unwrap(); - // We do not need to emit an ephemeral plan if we are not going to loop over the table values - if matches!(table.op, Operation::Search(Search::RowidEq { .. })) { - (None, vec![]) - } else { - (Some(ephemeral_plan), vec![]) - } - } else { - (None, vec![]) - }; - - if ephemeral_plan.is_none() { - // Parse the WHERE clause - parse_where( - body.where_clause.as_deref(), - &mut table_references, - Some(&result_columns), - &mut where_clause, - connection, - &mut program.param_ctx, - )?; - }; + let mut where_clause = vec![]; + // Parse the WHERE clause + parse_where( + body.where_clause.as_deref(), + &mut table_references, + Some(&result_columns), + &mut where_clause, + connection, + &mut program.param_ctx, + )?; // Parse the LIMIT/OFFSET clause let (limit, offset) = body.limit.as_mut().map_or(Ok((None, None)), |l| { @@ -481,7 +378,7 @@ pub fn prepare_update_plan( offset, contains_constant_false_condition: false, indexes_to_update, - ephemeral_plan, + ephemeral_plan: None, cdc_update_alter_statement: None, })) } diff --git a/core/translate/upsert.rs b/core/translate/upsert.rs index ffcff23e5..85f304283 100644 --- a/core/translate/upsert.rs +++ b/core/translate/upsert.rs @@ -5,10 +5,14 @@ use std::{collections::HashMap, sync::Arc}; use turso_parser::ast::{self, Upsert}; use crate::error::SQLITE_CONSTRAINT_PRIMARYKEY; +use crate::schema::ROWID_SENTINEL; +use crate::translate::emitter::UpdateRowSource; use crate::translate::expr::{walk_expr, WalkControl}; -use crate::translate::insert::format_unique_violation_desc; +use crate::translate::fkeys::{emit_fk_child_update_counters, emit_parent_key_change_checks}; +use crate::translate::insert::{format_unique_violation_desc, InsertEmitCtx}; use crate::translate::planner::ROWID_STRS; use crate::vdbe::insn::CmpInsFlags; +use crate::Connection; use crate::{ bail_parse_error, error::SQLITE_CONSTRAINT_NOTNULL, @@ -28,7 +32,6 @@ use crate::{ vdbe::{ builder::ProgramBuilder, insn::{IdxInsertFlags, InsertFlags, Insn}, - BranchOffset, }, }; @@ -336,34 +339,31 @@ pub fn resolve_upsert_target( pub fn emit_upsert( program: &mut ProgramBuilder, table: &Table, + ctx: &InsertEmitCtx, insertion: &Insertion, - tbl_cursor_id: usize, - conflict_rowid_reg: usize, set_pairs: &mut [(usize, Box)], where_clause: &mut Option>, resolver: &Resolver, - idx_cursors: &[(&String, i64, usize)], returning: &mut [ResultSetColumn], - cdc_cursor_id: Option, - row_done_label: BranchOffset, + connection: &Arc, ) -> crate::Result<()> { // Seek & snapshot CURRENT program.emit_insn(Insn::SeekRowid { - cursor_id: tbl_cursor_id, - src_reg: conflict_rowid_reg, - target_pc: row_done_label, + cursor_id: ctx.cursor_id, + src_reg: ctx.conflict_rowid_reg, + target_pc: ctx.row_done_label, }); - let num_cols = table.columns().len(); + let num_cols = ctx.table.columns.len(); let current_start = program.alloc_registers(num_cols); - for (i, col) in table.columns().iter().enumerate() { + for (i, col) in ctx.table.columns.iter().enumerate() { if col.is_rowid_alias { program.emit_insn(Insn::RowId { - cursor_id: tbl_cursor_id, + cursor_id: ctx.cursor_id, dest: current_start + i, }); } else { program.emit_insn(Insn::Column { - cursor_id: tbl_cursor_id, + cursor_id: ctx.cursor_id, column: i, dest: current_start + i, default: None, @@ -372,7 +372,7 @@ pub fn emit_upsert( } // BEFORE for index maintenance / CDC - let before_start = if cdc_cursor_id.is_some() || !idx_cursors.is_empty() { + let before_start = if ctx.cdc_table.is_some() || !ctx.idx_cursors.is_empty() { let s = program.alloc_registers(num_cols); program.emit_insn(Insn::Copy { src_reg: current_start, @@ -398,7 +398,7 @@ pub fn emit_upsert( pred, table, current_start, - conflict_rowid_reg, + ctx.conflict_rowid_reg, Some(table.get_name()), Some(insertion), true, @@ -407,7 +407,7 @@ pub fn emit_upsert( translate_expr(program, None, pred, pr, resolver)?; program.emit_insn(Insn::IfNot { reg: pr, - target_pc: row_done_label, + target_pc: ctx.row_done_label, jump_if_null: true, }); } @@ -419,7 +419,7 @@ pub fn emit_upsert( expr, table, current_start, - conflict_rowid_reg, + ctx.conflict_rowid_reg, Some(table.get_name()), Some(insertion), true, @@ -464,11 +464,59 @@ pub fn emit_upsert( } } + let (changed_cols, rowid_changed) = collect_changed_cols(table, set_pairs); + let rowid_alias_idx = table.columns().iter().position(|c| c.is_rowid_alias); + let has_direct_rowid_update = set_pairs + .iter() + .any(|(idx, _)| *idx == rowid_alias_idx.unwrap_or(ROWID_SENTINEL)); + let has_user_provided_rowid = if let Some(i) = rowid_alias_idx { + set_pairs.iter().any(|(idx, _)| *idx == i) || has_direct_rowid_update + } else { + has_direct_rowid_update + }; + + let rowid_set_clause_reg = if has_user_provided_rowid { + Some(new_rowid_reg.unwrap_or(ctx.conflict_rowid_reg)) + } else { + None + }; + if let Some(bt) = table.btree() { + if connection.foreign_keys_enabled() { + let rowid_new_reg = new_rowid_reg.unwrap_or(ctx.conflict_rowid_reg); + + // Child-side checks + if resolver.schema.has_child_fks(bt.name.as_str()) { + emit_fk_child_update_counters( + program, + resolver, + &bt, + table.get_name(), + ctx.cursor_id, + new_start, + rowid_new_reg, + &changed_cols, + )?; + } + emit_parent_key_change_checks( + program, + resolver, + &bt, + resolver.schema.get_indices(table.get_name()).filter(|idx| { + upsert_index_is_affected(table, idx, &changed_cols, rowid_changed) + }), + ctx.cursor_id, + ctx.conflict_rowid_reg, + new_start, + new_rowid_reg.unwrap_or(ctx.conflict_rowid_reg), + rowid_set_clause_reg, + set_pairs, + )?; + } + } + // Index rebuild (DELETE old, INSERT new), honoring partial-index WHEREs if let Some(before) = before_start { - let (changed_cols, rowid_changed) = collect_changed_cols(table, set_pairs); - - for (idx_name, _root, idx_cid) in idx_cursors { + for (idx_name, _root, idx_cid) in &ctx.idx_cursors { let idx_meta = resolver .schema .get_index(table.get_name(), idx_name) @@ -484,10 +532,10 @@ pub fn emit_upsert( table, idx_meta, before, - conflict_rowid_reg, + ctx.conflict_rowid_reg, resolver, ); - let new_rowid = new_rowid_reg.unwrap_or(conflict_rowid_reg); + let new_rowid = new_rowid_reg.unwrap_or(ctx.conflict_rowid_reg); let new_pred_reg = eval_partial_pred_for_row_image( program, table, idx_meta, new_start, new_rowid, resolver, ); @@ -514,7 +562,7 @@ pub fn emit_upsert( }); } program.emit_insn(Insn::Copy { - src_reg: conflict_rowid_reg, + src_reg: ctx.conflict_rowid_reg, dst_reg: del + k, extra_amount: 0, }); @@ -645,7 +693,7 @@ pub fn emit_upsert( // If equal to old rowid, skip uniqueness probe program.emit_insn(Insn::Eq { lhs: rnew, - rhs: conflict_rowid_reg, + rhs: ctx.conflict_rowid_reg, target_pc: ok, flags: CmpInsFlags::default(), collation: program.curr_collation(), @@ -653,7 +701,7 @@ pub fn emit_upsert( // If another row already has rnew -> constraint program.emit_insn(Insn::NotExists { - cursor: tbl_cursor_id, + cursor: ctx.cursor_id, rowid_reg: rnew, target_pc: ok, }); @@ -674,11 +722,12 @@ pub fn emit_upsert( // Now replace the row program.emit_insn(Insn::Delete { - cursor_id: tbl_cursor_id, + cursor_id: ctx.cursor_id, table_name: table.get_name().to_string(), + is_part_of_update: true, }); program.emit_insn(Insn::Insert { - cursor: tbl_cursor_id, + cursor: ctx.cursor_id, key_reg: rnew, record_reg: rec, flag: InsertFlags::new().require_seek().update_rowid_change(), @@ -686,8 +735,8 @@ pub fn emit_upsert( }); } else { program.emit_insn(Insn::Insert { - cursor: tbl_cursor_id, - key_reg: conflict_rowid_reg, + cursor: ctx.cursor_id, + key_reg: ctx.conflict_rowid_reg, record_reg: rec, flag: InsertFlags::new(), table_name: table.get_name().to_string(), @@ -695,16 +744,16 @@ pub fn emit_upsert( } // emit CDC instructions - if let Some(cdc_id) = cdc_cursor_id { - let new_rowid = new_rowid_reg.unwrap_or(conflict_rowid_reg); + if let Some((cdc_id, _)) = ctx.cdc_table { + let new_rowid = new_rowid_reg.unwrap_or(ctx.conflict_rowid_reg); if new_rowid_reg.is_some() { // DELETE (before) let before_rec = if program.capture_data_changes_mode().has_before() { Some(emit_cdc_full_record( program, table.columns(), - tbl_cursor_id, - conflict_rowid_reg, + ctx.cursor_id, + ctx.conflict_rowid_reg, )) } else { None @@ -714,7 +763,7 @@ pub fn emit_upsert( resolver, OperationMode::DELETE, cdc_id, - conflict_rowid_reg, + ctx.conflict_rowid_reg, before_rec, None, None, @@ -747,7 +796,7 @@ pub fn emit_upsert( table, new_start, rec, - conflict_rowid_reg, + ctx.conflict_rowid_reg, )) } else { None @@ -756,8 +805,8 @@ pub fn emit_upsert( Some(emit_cdc_full_record( program, table.columns(), - tbl_cursor_id, - conflict_rowid_reg, + ctx.cursor_id, + ctx.conflict_rowid_reg, )) } else { None @@ -765,9 +814,9 @@ pub fn emit_upsert( emit_cdc_insns( program, resolver, - OperationMode::UPDATE, + OperationMode::UPDATE(UpdateRowSource::Normal), cdc_id, - conflict_rowid_reg, + ctx.conflict_rowid_reg, before_rec, after_rec, None, @@ -779,7 +828,7 @@ pub fn emit_upsert( // RETURNING from NEW image + final rowid if !returning.is_empty() { let regs = ReturningValueRegisters { - rowid_register: new_rowid_reg.unwrap_or(conflict_rowid_reg), + rowid_register: new_rowid_reg.unwrap_or(ctx.conflict_rowid_reg), columns_start_register: new_start, num_columns: num_cols, }; @@ -787,7 +836,7 @@ pub fn emit_upsert( } program.emit_insn(Insn::Goto { - target_pc: row_done_label, + target_pc: ctx.row_done_label, }); Ok(()) } diff --git a/core/translate/view.rs b/core/translate/view.rs index 399664ab1..47f0822d7 100644 --- a/core/translate/view.rs +++ b/core/translate/view.rs @@ -80,6 +80,7 @@ pub fn translate_create_materialized_view( has_autoincrement: false, unique_sets: vec![], + foreign_keys: vec![], }); // Allocate a cursor for writing to the view's btree during population @@ -112,6 +113,7 @@ pub fn translate_create_materialized_view( program.emit_insn(Insn::Delete { cursor_id: view_cursor_id, table_name: normalized_view_name.clone(), + is_part_of_update: false, }); program.emit_insn(Insn::Next { cursor_id: view_cursor_id, @@ -408,6 +410,7 @@ pub fn translate_drop_view( program.emit_insn(Insn::Delete { cursor_id: sqlite_schema_cursor_id, table_name: "sqlite_schema".to_string(), + is_part_of_update: false, }); program.resolve_label(skip_delete_label, program.offset()); diff --git a/core/translate/window.rs b/core/translate/window.rs index 91d783ff0..7ab80207d 100644 --- a/core/translate/window.rs +++ b/core/translate/window.rs @@ -505,6 +505,7 @@ pub fn init_window<'a>( is_strict: false, unique_sets: vec![], has_autoincrement: false, + foreign_keys: vec![], }); let cursor_buffer_read = program.alloc_cursor_id(CursorType::BTreeTable(buffer_table.clone())); let cursor_buffer_write = program.alloc_cursor_id(CursorType::BTreeTable(buffer_table.clone())); diff --git a/core/types.rs b/core/types.rs index dc4b2162e..a7259f8b1 100644 --- a/core/types.rs +++ b/core/types.rs @@ -8,7 +8,7 @@ use crate::ext::{ExtValue, ExtValueType}; use crate::numeric::format_float; use crate::pseudo::PseudoCursor; use crate::schema::Index; -use crate::storage::btree::BTreeCursor; +use crate::storage::btree::CursorTrait; use crate::storage::sqlite3_ondisk::{read_integer, read_value, read_varint, write_varint}; use crate::translate::collate::CollationSeq; use crate::translate::plan::IterationDirection; @@ -17,6 +17,7 @@ use crate::vdbe::Register; use crate::vtab::VirtualTableCursor; use crate::{turso_assert, Completion, CompletionError, Result, IO}; use std::fmt::{Debug, Display}; +use std::task::Waker; /// SQLite by default uses 2000 as maximum numbers in a row. /// It controlld by the constant called SQLITE_MAX_COLUMN @@ -68,12 +69,6 @@ impl Display for Text { } } -#[derive(Debug, Clone, PartialEq)] -pub struct TextRef { - pub value: RawSlice, - pub subtype: TextSubtype, -} - impl Text { pub fn new(value: &str) -> Self { Self { @@ -119,24 +114,12 @@ pub trait AnyText: AsRef { fn subtype(&self) -> TextSubtype; } -impl AsRef for TextRef { - fn as_ref(&self) -> &str { - self.as_str() - } -} - impl AnyText for Text { fn subtype(&self) -> TextSubtype { self.subtype } } -impl AnyText for TextRef { - fn subtype(&self) -> TextSubtype { - self.subtype - } -} - impl AnyText for &str { fn subtype(&self) -> TextSubtype { TextSubtype::Text @@ -147,12 +130,6 @@ pub trait AnyBlob { fn as_slice(&self) -> &[u8]; } -impl AnyBlob for RawSlice { - fn as_slice(&self) -> &[u8] { - self.to_slice() - } -} - impl AnyBlob for Vec { fn as_slice(&self) -> &[u8] { self.as_slice() @@ -195,22 +172,6 @@ impl From for String { } } -impl Display for TextRef { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.as_str()) - } -} - -impl TextRef { - pub fn create_from(value: &[u8], subtype: TextSubtype) -> Self { - let value = RawSlice::create_from(value); - Self { value, subtype } - } - pub fn as_str(&self) -> &str { - unsafe { std::str::from_utf8_unchecked(self.value.to_slice()) } - } -} - #[cfg(feature = "serde")] fn float_to_string(float: &f64, serializer: S) -> Result where @@ -249,30 +210,24 @@ pub enum Value { Blob(Vec), } -#[derive(Debug, Clone, PartialEq)] -pub struct RawSlice { - data: *const u8, - len: usize, -} - -#[derive(PartialEq, Clone)] -pub enum RefValue { +#[derive(PartialEq, Clone, Copy)] +pub enum ValueRef<'a> { Null, Integer(i64), Float(f64), - Text(TextRef), - Blob(RawSlice), + Text(&'a [u8], TextSubtype), + Blob(&'a [u8]), } -impl Debug for RefValue { +impl Debug for ValueRef<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - RefValue::Null => write!(f, "Null"), - RefValue::Integer(i) => f.debug_tuple("Integer").field(i).finish(), - RefValue::Float(float) => f.debug_tuple("Float").field(float).finish(), - RefValue::Text(text_ref) => { + ValueRef::Null => write!(f, "Null"), + ValueRef::Integer(i) => f.debug_tuple("Integer").field(i).finish(), + ValueRef::Float(float) => f.debug_tuple("Float").field(float).finish(), + ValueRef::Text(text_ref, _) => { // truncate string to at most 256 chars - let text = text_ref.as_str(); + let text = String::from_utf8_lossy(text_ref); let max_len = text.len().min(256); f.debug_struct("Text") .field("data", &&text[0..max_len]) @@ -280,9 +235,8 @@ impl Debug for RefValue { .field("truncated", &(text.len() > max_len)) .finish() } - RefValue::Blob(raw_slice) => { + ValueRef::Blob(blob) => { // truncate blob_slice to at most 32 bytes - let blob = raw_slice.to_slice(); let max_len = blob.len().min(32); f.debug_struct("Blob") .field("data", &&blob[0..max_len]) @@ -295,6 +249,16 @@ impl Debug for RefValue { } impl Value { + pub fn as_ref<'a>(&'a self) -> ValueRef<'a> { + match self { + Value::Null => ValueRef::Null, + Value::Integer(v) => ValueRef::Integer(*v), + Value::Float(v) => ValueRef::Float(*v), + Value::Text(v) => ValueRef::Text(v.value.as_slice(), v.subtype), + Value::Blob(v) => ValueRef::Blob(v.as_slice()), + } + } + // A helper function that makes building a text Value easier. pub fn build_text(text: impl AsRef) -> Self { Self::Text(Text::new(text.as_ref())) @@ -833,34 +797,36 @@ impl std::ops::DivAssign for Value { } } -impl<'a> TryFrom<&'a RefValue> for i64 { +impl TryFrom> for i64 { type Error = LimboError; - fn try_from(value: &'a RefValue) -> Result { + fn try_from(value: ValueRef<'_>) -> Result { match value { - RefValue::Integer(i) => Ok(*i), + ValueRef::Integer(i) => Ok(i), _ => Err(LimboError::ConversionError("Expected integer value".into())), } } } -impl<'a> TryFrom<&'a RefValue> for String { +impl TryFrom> for String { type Error = LimboError; - fn try_from(value: &'a RefValue) -> Result { + fn try_from(value: ValueRef<'_>) -> Result { match value { - RefValue::Text(s) => Ok(s.as_str().to_string()), + ValueRef::Text(s, _) => Ok(String::from_utf8_lossy(s).to_string()), _ => Err(LimboError::ConversionError("Expected text value".into())), } } } -impl<'a> TryFrom<&'a RefValue> for &'a str { +impl<'a> TryFrom> for &'a str { type Error = LimboError; - fn try_from(value: &'a RefValue) -> Result { + fn try_from(value: ValueRef<'a>) -> Result { match value { - RefValue::Text(s) => Ok(s.as_str()), + ValueRef::Text(s, _) => Ok(str::from_utf8(s).map_err(|_| { + LimboError::ConversionError("Expected a valid UTF8 string".to_string()) + })?), _ => Err(LimboError::ConversionError("Expected text value".into())), } } @@ -987,7 +953,7 @@ impl ImmutableRecord { // TODO: inline the complete record parsing code here. // Its probably more efficient. - pub fn get_values(&self) -> Vec { + pub fn get_values<'a>(&'a self) -> Vec> { let mut cursor = RecordCursor::new(); cursor.get_values(self).unwrap_or_default() } @@ -1007,7 +973,6 @@ impl ImmutableRecord { values: impl IntoIterator + Clone, len: usize, ) -> Self { - let mut ref_values = Vec::with_capacity(len); let mut serials = Vec::with_capacity(len); let mut size_header = 0; let mut size_values = 0; @@ -1044,13 +1009,9 @@ impl ImmutableRecord { // write content for value in values { - let start_offset = writer.pos; match value { - Value::Null => { - ref_values.push(RefValue::Null); - } + Value::Null => {} Value::Integer(i) => { - ref_values.push(RefValue::Integer(*i)); let serial_type = SerialType::from(value); match serial_type.kind() { SerialTypeKind::ConstInt0 | SerialTypeKind::ConstInt1 => {} @@ -1065,27 +1026,12 @@ impl ImmutableRecord { other => panic!("Serial type is not an integer: {other:?}"), } } - Value::Float(f) => { - ref_values.push(RefValue::Float(*f)); - writer.extend_from_slice(&f.to_be_bytes()) - } + Value::Float(f) => writer.extend_from_slice(&f.to_be_bytes()), Value::Text(t) => { writer.extend_from_slice(&t.value); - let end_offset = writer.pos; - let len = end_offset - start_offset; - let ptr = unsafe { writer.buf.as_ptr().add(start_offset) }; - let value = RefValue::Text(TextRef { - value: RawSlice::new(ptr, len), - subtype: t.subtype, - }); - ref_values.push(value); } Value::Blob(b) => { writer.extend_from_slice(b); - let end_offset = writer.pos; - let len = end_offset - start_offset; - let ptr = unsafe { writer.buf.as_ptr().add(start_offset) }; - ref_values.push(RefValue::Blob(RawSlice::new(ptr, len))); } }; } @@ -1132,7 +1078,10 @@ impl ImmutableRecord { // TODO: its probably better to not instantiate the RecordCurosr. Instead do the deserialization // inside the function. - pub fn last_value(&self, record_cursor: &mut RecordCursor) -> Option> { + pub fn last_value<'a>( + &'a self, + record_cursor: &mut RecordCursor, + ) -> Option>> { if self.is_invalidated() { return Some(Err(LimboError::InternalError( "Record is invalidated".into(), @@ -1143,12 +1092,12 @@ impl ImmutableRecord { Some(record_cursor.get_value(self, last_idx)) } - pub fn get_value(&self, idx: usize) -> Result { + pub fn get_value<'a>(&'a self, idx: usize) -> Result> { let mut cursor = RecordCursor::new(); cursor.get_value(self, idx) } - pub fn get_value_opt(&self, idx: usize) -> Option { + pub fn get_value_opt<'a>(&'a self, idx: usize) -> Option> { if self.is_invalidated() { return None; } @@ -1314,23 +1263,27 @@ impl RecordCursor { /// # Special Cases /// /// - Returns `RefValue::Null` for out-of-bounds indices - pub fn deserialize_column(&self, record: &ImmutableRecord, idx: usize) -> Result { + pub fn deserialize_column<'a>( + &self, + record: &'a ImmutableRecord, + idx: usize, + ) -> Result> { if idx >= self.serial_types.len() { - return Ok(RefValue::Null); + return Ok(ValueRef::Null); } let serial_type = self.serial_types[idx]; let serial_type_obj = SerialType::try_from(serial_type)?; match serial_type_obj.kind() { - SerialTypeKind::Null => return Ok(RefValue::Null), - SerialTypeKind::ConstInt0 => return Ok(RefValue::Integer(0)), - SerialTypeKind::ConstInt1 => return Ok(RefValue::Integer(1)), + SerialTypeKind::Null => return Ok(ValueRef::Null), + SerialTypeKind::ConstInt0 => return Ok(ValueRef::Integer(0)), + SerialTypeKind::ConstInt1 => return Ok(ValueRef::Integer(1)), _ => {} // continue } if idx + 1 >= self.offsets.len() { - return Ok(RefValue::Null); + return Ok(ValueRef::Null); } let start = self.offsets[idx]; @@ -1358,7 +1311,11 @@ impl RecordCursor { /// * `Err(LimboError)` - Access failed due to invalid record or parsing error /// #[inline(always)] - pub fn get_value(&mut self, record: &ImmutableRecord, idx: usize) -> Result { + pub fn get_value<'a>( + &mut self, + record: &'a ImmutableRecord, + idx: usize, + ) -> Result> { if record.is_invalidated() { return Err(LimboError::InternalError("Record not initialized".into())); } @@ -1380,11 +1337,11 @@ impl RecordCursor { /// * `Some(Err(LimboError))` - Parsing succeeded but deserialization failed /// * `None` - Record is invalid or index is out of bounds /// - pub fn get_value_opt( + pub fn get_value_opt<'a>( &mut self, - record: &ImmutableRecord, + record: &'a ImmutableRecord, idx: usize, - ) -> Option> { + ) -> Option>> { if record.is_invalidated() { return None; } @@ -1443,7 +1400,7 @@ impl RecordCursor { /// * `Ok(Vec)` - All values in column order /// * `Err(LimboError)` - Parsing or deserialization failed /// - pub fn get_values(&mut self, record: &ImmutableRecord) -> Result> { + pub fn get_values<'a>(&mut self, record: &'a ImmutableRecord) -> Result>> { if record.is_invalidated() { return Ok(Vec::new()); } @@ -1459,62 +1416,62 @@ impl RecordCursor { } } -impl RefValue { +impl<'a> ValueRef<'a> { pub fn to_ffi(&self) -> ExtValue { match self { Self::Null => ExtValue::null(), Self::Integer(i) => ExtValue::from_integer(*i), Self::Float(fl) => ExtValue::from_float(*fl), - Self::Text(text) => ExtValue::from_text( - std::str::from_utf8(text.value.to_slice()) - .unwrap() - .to_string(), - ), - Self::Blob(blob) => ExtValue::from_blob(blob.to_slice().to_vec()), + Self::Text(text, _) => { + ExtValue::from_text(std::str::from_utf8(text).unwrap().to_string()) + } + Self::Blob(blob) => ExtValue::from_blob(blob.to_vec()), + } + } + + pub fn to_blob(&self) -> Option<&'a [u8]> { + match self { + Self::Blob(blob) => Some(*blob), + _ => None, } } pub fn to_owned(&self) -> Value { match self { - RefValue::Null => Value::Null, - RefValue::Integer(i) => Value::Integer(*i), - RefValue::Float(f) => Value::Float(*f), - RefValue::Text(text_ref) => Value::Text(Text { - value: text_ref.value.to_slice().to_vec(), - subtype: text_ref.subtype, + ValueRef::Null => Value::Null, + ValueRef::Integer(i) => Value::Integer(*i), + ValueRef::Float(f) => Value::Float(*f), + ValueRef::Text(text, subtype) => Value::Text(Text { + value: text.to_vec(), + subtype: *subtype, }), - RefValue::Blob(b) => Value::Blob(b.to_slice().to_vec()), - } - } - pub fn to_blob(&self) -> Option<&[u8]> { - match self { - Self::Blob(blob) => Some(blob.to_slice()), - _ => None, + ValueRef::Blob(b) => Value::Blob(b.to_vec()), } } } -impl Display for RefValue { +impl Display for ValueRef<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Null => write!(f, "NULL"), Self::Integer(i) => write!(f, "{i}"), Self::Float(fl) => write!(f, "{fl:?}"), - Self::Text(s) => write!(f, "{}", s.as_str()), - Self::Blob(b) => write!(f, "{}", String::from_utf8_lossy(b.to_slice())), + Self::Text(s, _) => write!(f, "{}", String::from_utf8_lossy(s)), + Self::Blob(b) => write!(f, "{}", String::from_utf8_lossy(b)), } } } -impl Eq for RefValue {} -impl Ord for RefValue { +impl Eq for ValueRef<'_> {} + +impl Ord for ValueRef<'_> { fn cmp(&self, other: &Self) -> std::cmp::Ordering { self.partial_cmp(other).unwrap() } } #[allow(clippy::non_canonical_partial_ord_impl)] -impl PartialOrd for RefValue { +impl<'a> PartialOrd> for ValueRef<'a> { fn partial_cmp(&self, other: &Self) -> Option { match (self, other) { (Self::Integer(int_left), Self::Integer(int_right)) => int_left.partial_cmp(int_right), @@ -1528,24 +1485,21 @@ impl PartialOrd for RefValue { float_left.partial_cmp(float_right) } // Numeric vs Text/Blob - (Self::Integer(_) | Self::Float(_), Self::Text(_) | Self::Blob(_)) => { + (Self::Integer(_) | Self::Float(_), Self::Text(_, _) | Self::Blob(_)) => { Some(std::cmp::Ordering::Less) } - (Self::Text(_) | Self::Blob(_), Self::Integer(_) | Self::Float(_)) => { + (Self::Text(_, _) | Self::Blob(_), Self::Integer(_) | Self::Float(_)) => { Some(std::cmp::Ordering::Greater) } - (Self::Text(text_left), Self::Text(text_right)) => text_left - .value - .to_slice() - .partial_cmp(text_right.value.to_slice()), - // Text vs Blob - (Self::Text(_), Self::Blob(_)) => Some(std::cmp::Ordering::Less), - (Self::Blob(_), Self::Text(_)) => Some(std::cmp::Ordering::Greater), - - (Self::Blob(blob_left), Self::Blob(blob_right)) => { - blob_left.to_slice().partial_cmp(blob_right.to_slice()) + (Self::Text(text_left, _), Self::Text(text_right, _)) => { + text_left.partial_cmp(text_right) } + // Text vs Blob + (Self::Text(_, _), Self::Blob(_)) => Some(std::cmp::Ordering::Less), + (Self::Blob(_), Self::Text(_, _)) => Some(std::cmp::Ordering::Greater), + + (Self::Blob(blob_left), Self::Blob(blob_right)) => blob_left.partial_cmp(blob_right), (Self::Null, Self::Null) => Some(std::cmp::Ordering::Equal), (Self::Null, _) => Some(std::cmp::Ordering::Less), (_, Self::Null) => Some(std::cmp::Ordering::Greater), @@ -1635,8 +1589,8 @@ impl IndexInfo { } pub fn compare_immutable( - l: &[RefValue], - r: &[RefValue], + l: &[ValueRef], + r: &[ValueRef], column_info: &[KeyInfo], ) -> std::cmp::Ordering { assert_eq!(l.len(), r.len()); @@ -1645,9 +1599,10 @@ pub fn compare_immutable( let column_order = column_info[i].sort_order; let collation = column_info[i].collation; let cmp = match (l, r) { - (RefValue::Text(left), RefValue::Text(right)) => { - collation.compare_strings(left.as_str(), right.as_str()) - } + (ValueRef::Text(left, _), ValueRef::Text(right, _)) => collation.compare_strings( + &String::from_utf8_lossy(left), + &String::from_utf8_lossy(right), + ), _ => l.partial_cmp(r).unwrap(), }; if !cmp.is_eq() { @@ -1671,7 +1626,7 @@ impl RecordCompare { pub fn compare( &self, serialized: &ImmutableRecord, - unpacked: &[RefValue], + unpacked: &[ValueRef], index_info: &IndexInfo, skip: usize, tie_breaker: std::cmp::Ordering, @@ -1690,11 +1645,11 @@ impl RecordCompare { } } -pub fn find_compare(unpacked: &[RefValue], index_info: &IndexInfo) -> RecordCompare { +pub fn find_compare(unpacked: &[ValueRef], index_info: &IndexInfo) -> RecordCompare { if !unpacked.is_empty() && index_info.num_cols <= 13 { match &unpacked[0] { - RefValue::Integer(_) => RecordCompare::Int, - RefValue::Text(_) if index_info.key_info[0].collation == CollationSeq::Binary => { + ValueRef::Integer(_) => RecordCompare::Int, + ValueRef::Text(_, _) if index_info.key_info[0].collation == CollationSeq::Binary => { RecordCompare::String } _ => RecordCompare::Generic, @@ -1760,7 +1715,7 @@ pub fn get_tie_breaker_from_seek_op(seek_op: SeekOp) -> std::cmp::Ordering { /// delegates to `compare_records_generic()` with `skip=1` fn compare_records_int( serialized: &ImmutableRecord, - unpacked: &[RefValue], + unpacked: &[ValueRef], index_info: &IndexInfo, tie_breaker: std::cmp::Ordering, ) -> Result { @@ -1794,7 +1749,7 @@ fn compare_records_int( let data_start = header_size; let lhs_int = read_integer(&payload[data_start..], first_serial_type as u8)?; - let RefValue::Integer(rhs_int) = unpacked[0] else { + let ValueRef::Integer(rhs_int) = unpacked[0] else { return compare_records_generic(serialized, unpacked, index_info, 0, tie_breaker); }; let comparison = match index_info.key_info[0].sort_order { @@ -1853,7 +1808,7 @@ fn compare_records_int( /// delegates to `compare_records_generic()` with `skip=1` fn compare_records_string( serialized: &ImmutableRecord, - unpacked: &[RefValue], + unpacked: &[ValueRef], index_info: &IndexInfo, tie_breaker: std::cmp::Ordering, ) -> Result { @@ -1884,7 +1839,7 @@ fn compare_records_string( return compare_records_generic(serialized, unpacked, index_info, 0, tie_breaker); } - let RefValue::Text(rhs_text) = &unpacked[0] else { + let ValueRef::Text(rhs_text, _) = &unpacked[0] else { return compare_records_generic(serialized, unpacked, index_info, 0, tie_breaker); }; @@ -1896,12 +1851,15 @@ fn compare_records_string( let serial_type = SerialType::try_from(first_serial_type)?; let (lhs_value, _) = read_value(&payload[data_start..], serial_type)?; - let RefValue::Text(lhs_text) = lhs_value else { + let ValueRef::Text(lhs_text, _) = lhs_value else { return compare_records_generic(serialized, unpacked, index_info, 0, tie_breaker); }; let collation = index_info.key_info[0].collation; - let comparison = collation.compare_strings(lhs_text.as_str(), rhs_text.as_str()); + let comparison = collation.compare_strings( + &String::from_utf8_lossy(lhs_text), + &String::from_utf8_lossy(rhs_text), + ); let final_comparison = match index_info.key_info[0].sort_order { SortOrder::Asc => comparison, @@ -1910,7 +1868,7 @@ fn compare_records_string( match final_comparison { std::cmp::Ordering::Equal => { - let len_cmp = lhs_text.value.len.cmp(&rhs_text.value.len); + let len_cmp = lhs_text.len().cmp(&rhs_text.len()); if len_cmp != std::cmp::Ordering::Equal { let adjusted = match index_info.key_info[0].sort_order { SortOrder::Asc => len_cmp, @@ -1962,7 +1920,7 @@ fn compare_records_string( /// `tie_breaker` is returned. pub fn compare_records_generic( serialized: &ImmutableRecord, - unpacked: &[RefValue], + unpacked: &[ValueRef], index_info: &IndexInfo, skip: usize, tie_breaker: std::cmp::Ordering, @@ -2009,9 +1967,9 @@ pub fn compare_records_generic( let rhs_value = &unpacked[field_idx]; let lhs_value = match serial_type.kind() { - SerialTypeKind::ConstInt0 => RefValue::Integer(0), - SerialTypeKind::ConstInt1 => RefValue::Integer(1), - SerialTypeKind::Null => RefValue::Null, + SerialTypeKind::ConstInt0 => ValueRef::Integer(0), + SerialTypeKind::ConstInt1 => ValueRef::Integer(1), + SerialTypeKind::Null => ValueRef::Null, _ => { let (value, field_size) = read_value(&payload[data_pos..], serial_type)?; data_pos += field_size; @@ -2020,15 +1978,18 @@ pub fn compare_records_generic( }; let comparison = match (&lhs_value, rhs_value) { - (RefValue::Text(lhs_text), RefValue::Text(rhs_text)) => index_info.key_info[field_idx] - .collation - .compare_strings(lhs_text.as_str(), rhs_text.as_str()), + (ValueRef::Text(lhs_text, _), ValueRef::Text(rhs_text, _)) => { + index_info.key_info[field_idx].collation.compare_strings( + &String::from_utf8_lossy(lhs_text), + &String::from_utf8_lossy(rhs_text), + ) + } - (RefValue::Integer(lhs_int), RefValue::Float(rhs_float)) => { + (ValueRef::Integer(lhs_int), ValueRef::Float(rhs_float)) => { sqlite_int_float_compare(*lhs_int, *rhs_float) } - (RefValue::Float(lhs_float), RefValue::Integer(rhs_int)) => { + (ValueRef::Float(lhs_float), ValueRef::Integer(rhs_int)) => { sqlite_int_float_compare(*rhs_int, *lhs_float).reverse() } @@ -2309,7 +2270,7 @@ impl Record { } pub enum Cursor { - BTree(Box), + BTree(Box), Pseudo(PseudoCursor), Sorter(Sorter), Virtual(VirtualTableCursor), @@ -2329,8 +2290,8 @@ impl Debug for Cursor { } impl Cursor { - pub fn new_btree(cursor: BTreeCursor) -> Self { - Self::BTree(Box::new(cursor)) + pub fn new_btree(cursor: Box) -> Self { + Self::BTree(cursor) } pub fn new_pseudo(cursor: PseudoCursor) -> Self { @@ -2347,9 +2308,9 @@ impl Cursor { Self::MaterializedView(Box::new(cursor)) } - pub fn as_btree_mut(&mut self) -> &mut BTreeCursor { + pub fn as_btree_mut(&mut self) -> &mut dyn CursorTrait { match self { - Self::BTree(cursor) => cursor, + Self::BTree(cursor) => cursor.as_mut(), _ => panic!("Cursor is not a btree"), } } @@ -2389,7 +2350,6 @@ impl Cursor { #[must_use] pub enum IOCompletions { Single(Completion), - Many(Vec), } impl IOCompletions { @@ -2397,26 +2357,12 @@ impl IOCompletions { pub fn wait(self, io: &I) -> Result<()> { match self { IOCompletions::Single(c) => io.wait_for_completion(c), - IOCompletions::Many(completions) => { - let mut completions = completions.into_iter(); - while let Some(c) = completions.next() { - let res = io.wait_for_completion(c); - if res.is_err() { - for c in completions { - c.abort(); - } - return res; - } - } - Ok(()) - } } } pub fn finished(&self) -> bool { match self { IOCompletions::Single(c) => c.finished(), - IOCompletions::Many(completions) => completions.iter().all(|c| c.finished()), } } @@ -2424,14 +2370,20 @@ impl IOCompletions { pub fn abort(&self) { match self { IOCompletions::Single(c) => c.abort(), - IOCompletions::Many(completions) => completions.iter().for_each(|c| c.abort()), } } pub fn get_error(&self) -> Option { match self { IOCompletions::Single(c) => c.get_error(), - IOCompletions::Many(completions) => completions.iter().find_map(|c| c.get_error()), + } + } + + pub fn set_waker(&self, waker: Option<&Waker>) { + if let Some(waker) = waker { + match self { + IOCompletions::Single(c) => c.set_waker(waker), + } } } } @@ -2553,27 +2505,6 @@ pub enum SeekKey<'a> { IndexKey(&'a ImmutableRecord), } -impl RawSlice { - pub fn create_from(value: &[u8]) -> Self { - if value.is_empty() { - RawSlice::new(std::ptr::null(), 0) - } else { - let ptr = &value[0] as *const u8; - RawSlice::new(ptr, value.len()) - } - } - pub fn new(data: *const u8, len: usize) -> Self { - Self { data, len } - } - pub fn to_slice(&self) -> &[u8] { - if self.data.is_null() { - &[] - } else { - unsafe { std::slice::from_raw_parts(self.data, self.len) } - } - } -} - #[derive(Debug)] pub enum DatabaseChangeType { Delete, @@ -2623,8 +2554,8 @@ mod tests { use crate::translate::collate::CollationSeq; pub fn compare_immutable_for_testing( - l: &[RefValue], - r: &[RefValue], + l: &[ValueRef], + r: &[ValueRef], index_key_info: &[KeyInfo], tie_breaker: std::cmp::Ordering, ) -> std::cmp::Ordering { @@ -2635,9 +2566,10 @@ mod tests { let collation = index_key_info[i].collation; let cmp = match (&l[i], &r[i]) { - (RefValue::Text(left), RefValue::Text(right)) => { - collation.compare_strings(left.as_str(), right.as_str()) - } + (ValueRef::Text(left, _), ValueRef::Text(right, _)) => collation.compare_strings( + &String::from_utf8_lossy(left), + &String::from_utf8_lossy(right), + ), _ => l[i].partial_cmp(&r[i]).unwrap_or(std::cmp::Ordering::Equal), }; @@ -2676,47 +2608,16 @@ mod tests { } } - fn value_to_ref_value(value: &Value) -> RefValue { - match value { - Value::Null => RefValue::Null, - Value::Integer(i) => RefValue::Integer(*i), - Value::Float(f) => RefValue::Float(*f), - Value::Text(text) => RefValue::Text(TextRef { - value: RawSlice::from_slice(&text.value), - subtype: text.subtype, - }), - Value::Blob(blob) => RefValue::Blob(RawSlice::from_slice(blob)), - } - } - - impl TextRef { - fn from_str(s: &str) -> Self { - TextRef { - value: RawSlice::from_slice(s.as_bytes()), - subtype: crate::types::TextSubtype::Text, - } - } - } - - impl RawSlice { - fn from_slice(data: &[u8]) -> Self { - Self { - data: data.as_ptr(), - len: data.len(), - } - } - } - fn assert_compare_matches_full_comparison( serialized_values: Vec, - unpacked_values: Vec, + unpacked_values: Vec, index_info: &IndexInfo, test_name: &str, ) { let serialized = create_record(serialized_values.clone()); - let serialized_ref_values: Vec = - serialized_values.iter().map(value_to_ref_value).collect(); + let serialized_ref_values: Vec = + serialized_values.iter().map(Value::as_ref).collect(); let tie_breaker = std::cmp::Ordering::Equal; @@ -2815,52 +2716,52 @@ mod tests { let test_cases = vec![ ( vec![Value::Integer(42)], - vec![RefValue::Integer(42)], + vec![ValueRef::Integer(42)], "equal_integers", ), ( vec![Value::Integer(10)], - vec![RefValue::Integer(20)], + vec![ValueRef::Integer(20)], "less_than_integers", ), ( vec![Value::Integer(30)], - vec![RefValue::Integer(20)], + vec![ValueRef::Integer(20)], "greater_than_integers", ), ( vec![Value::Integer(0)], - vec![RefValue::Integer(0)], + vec![ValueRef::Integer(0)], "zero_integers", ), ( vec![Value::Integer(-5)], - vec![RefValue::Integer(-5)], + vec![ValueRef::Integer(-5)], "negative_integers", ), ( vec![Value::Integer(i64::MAX)], - vec![RefValue::Integer(i64::MAX)], + vec![ValueRef::Integer(i64::MAX)], "max_integers", ), ( vec![Value::Integer(i64::MIN)], - vec![RefValue::Integer(i64::MIN)], + vec![ValueRef::Integer(i64::MIN)], "min_integers", ), ( vec![Value::Integer(42), Value::Text(Text::new("hello"))], vec![ - RefValue::Integer(42), - RefValue::Text(TextRef::from_str("hello")), + ValueRef::Integer(42), + ValueRef::Text(b"hello", TextSubtype::Text), ], "integer_text_equal", ), ( vec![Value::Integer(42), Value::Text(Text::new("hello"))], vec![ - RefValue::Integer(42), - RefValue::Text(TextRef::from_str("world")), + ValueRef::Integer(42), + ValueRef::Text(b"world", TextSubtype::Text), ], "integer_equal_text_different", ), @@ -2887,43 +2788,43 @@ mod tests { let test_cases = vec![ ( vec![Value::Text(Text::new("hello"))], - vec![RefValue::Text(TextRef::from_str("hello"))], + vec![ValueRef::Text(b"hello", TextSubtype::Text)], "equal_strings", ), ( vec![Value::Text(Text::new("abc"))], - vec![RefValue::Text(TextRef::from_str("def"))], + vec![ValueRef::Text(b"def", TextSubtype::Text)], "less_than_strings", ), ( vec![Value::Text(Text::new("xyz"))], - vec![RefValue::Text(TextRef::from_str("abc"))], + vec![ValueRef::Text(b"abc", TextSubtype::Text)], "greater_than_strings", ), ( vec![Value::Text(Text::new(""))], - vec![RefValue::Text(TextRef::from_str(""))], + vec![ValueRef::Text(b"", TextSubtype::Text)], "empty_strings", ), ( vec![Value::Text(Text::new("a"))], - vec![RefValue::Text(TextRef::from_str("aa"))], + vec![ValueRef::Text(b"aa", TextSubtype::Text)], "prefix_strings", ), // Multi-field with string first ( vec![Value::Text(Text::new("hello")), Value::Integer(42)], vec![ - RefValue::Text(TextRef::from_str("hello")), - RefValue::Integer(42), + ValueRef::Text(b"hello", TextSubtype::Text), + ValueRef::Integer(42), ], "string_integer_equal", ), ( vec![Value::Text(Text::new("hello")), Value::Integer(42)], vec![ - RefValue::Text(TextRef::from_str("hello")), - RefValue::Integer(99), + ValueRef::Text(b"hello", TextSubtype::Text), + ValueRef::Integer(99), ], "string_equal_integer_different", ), @@ -2948,65 +2849,65 @@ mod tests { // NULL vs others ( vec![Value::Null], - vec![RefValue::Integer(42)], + vec![ValueRef::Integer(42)], "null_vs_integer", ), ( vec![Value::Null], - vec![RefValue::Float(64.4)], + vec![ValueRef::Float(64.4)], "null_vs_float", ), ( vec![Value::Null], - vec![RefValue::Text(TextRef::from_str("hello"))], + vec![ValueRef::Text(b"hello", TextSubtype::Text)], "null_vs_text", ), ( vec![Value::Null], - vec![RefValue::Blob(RawSlice::from_slice(b"blob"))], + vec![ValueRef::Blob(b"blob")], "null_vs_blob", ), // Numbers vs Text/Blob ( vec![Value::Integer(42)], - vec![RefValue::Text(TextRef::from_str("hello"))], + vec![ValueRef::Text(b"hello", TextSubtype::Text)], "integer_vs_text", ), ( vec![Value::Float(64.4)], - vec![RefValue::Text(TextRef::from_str("hello"))], + vec![ValueRef::Text(b"hello", TextSubtype::Text)], "float_vs_text", ), ( vec![Value::Integer(42)], - vec![RefValue::Blob(RawSlice::from_slice(b"blob"))], + vec![ValueRef::Blob(b"blob")], "integer_vs_blob", ), ( vec![Value::Float(64.4)], - vec![RefValue::Blob(RawSlice::from_slice(b"blob"))], + vec![ValueRef::Blob(b"blob")], "float_vs_blob", ), // Text vs Blob ( vec![Value::Text(Text::new("hello"))], - vec![RefValue::Blob(RawSlice::from_slice(b"blob"))], + vec![ValueRef::Blob(b"blob")], "text_vs_blob", ), // Integer vs Float (affinity conversion) ( vec![Value::Integer(42)], - vec![RefValue::Float(42.0)], + vec![ValueRef::Float(42.0)], "integer_vs_equal_float", ), ( vec![Value::Integer(42)], - vec![RefValue::Float(42.5)], + vec![ValueRef::Float(42.5)], "integer_vs_different_float", ), ( vec![Value::Float(42.5)], - vec![RefValue::Integer(42)], + vec![ValueRef::Integer(42)], "float_vs_integer", ), ]; @@ -3033,20 +2934,20 @@ mod tests { // DESC order should reverse first field comparison ( vec![Value::Integer(10)], - vec![RefValue::Integer(20)], + vec![ValueRef::Integer(20)], "desc_integer_reversed", ), ( vec![Value::Text(Text::new("abc"))], - vec![RefValue::Text(TextRef::from_str("def"))], + vec![ValueRef::Text(b"def", TextSubtype::Text)], "desc_string_reversed", ), // Mixed sort orders ( vec![Value::Integer(10), Value::Text(Text::new("hello"))], vec![ - RefValue::Integer(20), - RefValue::Text(TextRef::from_str("hello")), + ValueRef::Integer(20), + ValueRef::Text(b"hello", TextSubtype::Text), ], "desc_first_asc_second", ), @@ -3071,38 +2972,38 @@ mod tests { ( vec![Value::Integer(42)], vec![ - RefValue::Integer(42), - RefValue::Text(TextRef::from_str("extra")), + ValueRef::Integer(42), + ValueRef::Text(b"extra", TextSubtype::Text), ], "fewer_serialized_fields", ), ( vec![Value::Integer(42), Value::Text(Text::new("extra"))], - vec![RefValue::Integer(42)], + vec![ValueRef::Integer(42)], "fewer_unpacked_fields", ), (vec![], vec![], "both_empty"), - (vec![], vec![RefValue::Integer(42)], "empty_serialized"), + (vec![], vec![ValueRef::Integer(42)], "empty_serialized"), ( (0..15).map(Value::Integer).collect(), - (0..15).map(RefValue::Integer).collect(), + (0..15).map(ValueRef::Integer).collect(), "large_field_count", ), ( vec![Value::Blob(vec![1, 2, 3])], - vec![RefValue::Blob(RawSlice::from_slice(&[1, 2, 3]))], + vec![ValueRef::Blob(&[1, 2, 3])], "blob_first_field", ), ( vec![Value::Text(Text::new("hello")), Value::Integer(5)], - vec![RefValue::Text(TextRef::from_str("hello"))], + vec![ValueRef::Text(b"hello", TextSubtype::Text)], "equal_text_prefix_but_more_serialized_fields", ), ( vec![Value::Text(Text::new("same")), Value::Integer(5)], vec![ - RefValue::Text(TextRef::from_str("same")), - RefValue::Integer(5), + ValueRef::Text(b"same", TextSubtype::Text), + ValueRef::Integer(5), ], "equal_text_then_equal_int", ), @@ -3132,9 +3033,9 @@ mod tests { Value::Integer(3), ]); let unpacked = vec![ - RefValue::Integer(1), - RefValue::Integer(99), - RefValue::Integer(3), + ValueRef::Integer(1), + ValueRef::Integer(99), + ValueRef::Integer(3), ]; let tie_breaker = std::cmp::Ordering::Equal; @@ -3160,8 +3061,8 @@ mod tests { let index_info_large = create_index_info(15, vec![SortOrder::Asc; 15], collations_large); let int_values = vec![ - RefValue::Integer(42), - RefValue::Text(TextRef::from_str("hello")), + ValueRef::Integer(42), + ValueRef::Text(b"hello", TextSubtype::Text), ]; assert!(matches!( find_compare(&int_values, &index_info_small), @@ -3169,21 +3070,21 @@ mod tests { )); let string_values = vec![ - RefValue::Text(TextRef::from_str("hello")), - RefValue::Integer(42), + ValueRef::Text(b"hello", TextSubtype::Text), + ValueRef::Integer(42), ]; assert!(matches!( find_compare(&string_values, &index_info_small), RecordCompare::String )); - let large_values: Vec = (0..15).map(RefValue::Integer).collect(); + let large_values: Vec = (0..15).map(ValueRef::Integer).collect(); assert!(matches!( find_compare(&large_values, &index_info_large), RecordCompare::Generic )); - let blob_values = vec![RefValue::Blob(RawSlice::from_slice(&[1, 2, 3]))]; + let blob_values = vec![ValueRef::Blob(&[1, 2, 3])]; assert!(matches!( find_compare(&blob_values, &index_info_small), RecordCompare::Generic diff --git a/core/util.rs b/core/util.rs index b35fbdaa6..77062fd7d 100644 --- a/core/util.rs +++ b/core/util.rs @@ -11,6 +11,7 @@ use crate::{ LimboError, OpenFlags, Result, Statement, StepResult, SymbolTable, }; use crate::{Connection, MvStore, IO}; +use std::sync::atomic::AtomicU8; use std::{ collections::HashMap, rc::Rc, @@ -29,12 +30,6 @@ macro_rules! io_yield_one { return Ok(IOResult::IO(IOCompletions::Single($c))); }; } -#[macro_export] -macro_rules! io_yield_many { - ($v:expr) => { - return Ok(IOResult::IO(IOCompletions::Many($v))); - }; -} #[macro_export] macro_rules! eq_ignore_ascii_case { @@ -312,7 +307,7 @@ pub fn module_args_from_sql(sql: &str) -> Result> { pub fn check_literal_equivalency(lhs: &Literal, rhs: &Literal) -> bool { match (lhs, rhs) { (Literal::Numeric(n1), Literal::Numeric(n2)) => cmp_numeric_strings(n1, n2), - (Literal::String(s1), Literal::String(s2)) => check_ident_equivalency(s1, s2), + (Literal::String(s1), Literal::String(s2)) => s1 == s2, (Literal::Blob(b1), Literal::Blob(b2)) => b1 == b2, (Literal::Keyword(k1), Literal::Keyword(k2)) => check_ident_equivalency(k1, k2), (Literal::Null, Literal::Null) => true, @@ -1337,6 +1332,39 @@ pub fn extract_view_columns( Ok(ViewColumnSchema { tables, columns }) } +pub fn rewrite_fk_parent_cols_if_self_ref( + clause: &mut ast::ForeignKeyClause, + table: &str, + from: &str, + to: &str, +) { + if normalize_ident(clause.tbl_name.as_str()) == normalize_ident(table) { + for c in &mut clause.columns { + if normalize_ident(c.col_name.as_str()) == normalize_ident(from) { + c.col_name = ast::Name::exact(to.to_owned()); + } + } + } +} + +/// Update a column-level REFERENCES (col,...) constraint +pub fn rewrite_column_references_if_needed( + col: &mut ast::ColumnDefinition, + table: &str, + from: &str, + to: &str, +) { + for cc in &mut col.constraints { + if let ast::NamedColumnConstraint { + constraint: ast::ColumnConstraint::ForeignKey { clause, .. }, + .. + } = cc + { + rewrite_fk_parent_cols_if_self_ref(clause, table, from, to); + } + } +} + #[cfg(test)] pub mod tests { use super::*; diff --git a/core/vdbe/builder.rs b/core/vdbe/builder.rs index 3d1a333ec..3ac364226 100644 --- a/core/vdbe/builder.rs +++ b/core/vdbe/builder.rs @@ -124,6 +124,10 @@ pub struct ProgramBuilder { current_parent_explain_idx: Option, pub param_ctx: ParamState, pub(crate) reg_result_cols_start: Option, + /// Whether the program needs to use statement subtransactions, + /// i.e. the individual statement may need to be aborted due to a constraint conflict, etc. + /// instead of the entire transaction. + needs_stmt_subtransactions: bool, } #[derive(Debug, Clone)] @@ -211,9 +215,14 @@ impl ProgramBuilder { current_parent_explain_idx: None, param_ctx: ParamState::default(), reg_result_cols_start: None, + needs_stmt_subtransactions: false, } } + pub fn set_needs_stmt_subtransactions(&mut self, needs_stmt_subtransactions: bool) { + self.needs_stmt_subtransactions = needs_stmt_subtransactions; + } + pub fn capture_data_changes_mode(&self) -> &CaptureDataChangesMode { &self.capture_data_changes_mode } @@ -311,6 +320,18 @@ impl ProgramBuilder { self._alloc_cursor_id(Some(key), cursor_type) } + pub fn alloc_cursor_id_keyed_if_not_exists( + &mut self, + key: CursorKey, + cursor_type: CursorType, + ) -> usize { + if let Some(cursor_id) = self.resolve_cursor_id_safe(&key) { + cursor_id + } else { + self._alloc_cursor_id(Some(key), cursor_type) + } + } + pub fn alloc_cursor_id(&mut self, cursor_type: CursorType) -> usize { self._alloc_cursor_id(None, cursor_type) } @@ -341,6 +362,14 @@ impl ProgramBuilder { self.insns.push((insn, self.insns.len())); } + /// Emit an instruction that is guaranteed not to be in any constant span. + /// This ensures the instruction won't be hoisted when emit_constant_insns is called. + #[instrument(skip(self), level = Level::DEBUG)] + pub fn emit_no_constant_insn(&mut self, insn: Insn) { + self.constant_span_end_all(); + self.emit_insn(insn); + } + pub fn close_cursors(&mut self, cursors: &[CursorID]) { for cursor in cursors { self.emit_insn(Insn::Close { cursor_id: *cursor }); @@ -792,6 +821,9 @@ impl ProgramBuilder { Insn::NotFound { target_pc, .. } => { resolve(target_pc, "NotFound"); } + Insn::FkIfZero { target_pc, .. } => { + resolve(target_pc, "FkIfZero"); + } _ => {} } } @@ -1006,6 +1038,11 @@ impl ProgramBuilder { self.resolve_labels(); self.parameters.list.dedup(); + + if !self.table_references.is_empty() && matches!(self.txn_mode, TransactionMode::Write) { + self.needs_stmt_subtransactions = true; + } + Program { max_registers: self.next_free_register, insns: self.insns, @@ -1019,6 +1056,7 @@ impl ProgramBuilder { table_references: self.table_references, sql: sql.to_string(), accesses_db: !matches!(self.txn_mode, TransactionMode::None), + needs_stmt_subtransactions: self.needs_stmt_subtransactions, } } } diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index 87fbaec0a..da7518def 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -6,21 +6,23 @@ use crate::numeric::{NullableInteger, Numeric}; use crate::schema::Table; use crate::state_machine::StateMachine; use crate::storage::btree::{ - integrity_check, IntegrityCheckError, IntegrityCheckState, PageCategory, + integrity_check, CursorTrait, IntegrityCheckError, IntegrityCheckState, PageCategory, }; use crate::storage::database::DatabaseFile; use crate::storage::page_cache::PageCache; use crate::storage::pager::{AtomicDbState, CreateBTreeFlags, DbState}; -use crate::storage::sqlite3_ondisk::{read_varint, DatabaseHeader, PageSize}; +use crate::storage::sqlite3_ondisk::{read_varint_fast, DatabaseHeader, PageSize}; use crate::translate::collate::CollationSeq; use crate::types::{ compare_immutable, compare_records_generic, Extendable, IOCompletions, ImmutableRecord, SeekResult, Text, }; -use crate::util::normalize_ident; +use crate::util::{ + normalize_ident, rewrite_column_references_if_needed, rewrite_fk_parent_cols_if_self_ref, +}; use crate::vdbe::insn::InsertFlags; -use crate::vdbe::registers_to_ref_values; -use crate::vector::{vector_concat, vector_slice}; +use crate::vdbe::{registers_to_ref_values, EndStatement, TxnCleanup}; +use crate::vector::{vector32_sparse, vector_concat, vector_distance_jaccard, vector_slice}; use crate::{ error::{ LimboError, SQLITE_CONSTRAINT, SQLITE_CONSTRAINT_NOTNULL, SQLITE_CONSTRAINT_PRIMARYKEY, @@ -35,7 +37,8 @@ use crate::{ }, translate::emitter::TransactionMode, }; -use crate::{get_cursor, CheckpointMode, MvCursor}; +use crate::{get_cursor, CheckpointMode, Connection, MvCursor}; +use std::any::Any; use std::env::temp_dir; use std::ops::DerefMut; use std::{ @@ -65,15 +68,14 @@ use crate::{ vector::{vector32, vector64, vector_distance_cos, vector_distance_l2, vector_extract}, }; -use crate::{info, turso_assert, OpenFlags, RefValue, Row, TransactionState}; +use crate::{info, turso_assert, OpenFlags, Row, TransactionState, ValueRef}; use super::{ insn::{Cookie, RegisterOrLiteral}, CommitState, }; use parking_lot::RwLock; -use rand::{thread_rng, Rng}; -use turso_parser::ast::{self, Name, SortOrder}; +use turso_parser::ast::{self, ForeignKeyClause, Name, SortOrder}; use turso_parser::parser::Parser; use super::{ @@ -157,7 +159,6 @@ pub enum InsnFunctionStepResult { Done, IO(IOCompletions), Row, - Interrupt, Step, } @@ -407,7 +408,7 @@ pub fn op_checkpoint_inner( let step_result = program .connection .pager - .write() + .load() .wal_checkpoint_start(*checkpoint_mode); match step_result { Ok(IOResult::Done(result)) => { @@ -428,7 +429,7 @@ pub fn op_checkpoint_inner( let step_result = program .connection .pager - .write() + .load() .wal_checkpoint_finish(result.as_mut().unwrap()); match step_result { Ok(IOResult::Done(())) => { @@ -1045,16 +1046,9 @@ pub fn op_open_read( let pager = program.get_pager_from_database_index(db); let (_, cursor_type) = program.cursor_ref.get(*cursor_id).unwrap(); - let mv_cursor = if let Some(tx_id) = program.connection.get_mv_tx_id() { - let mv_store = mv_store.unwrap().clone(); - let mv_cursor = Arc::new(RwLock::new( - MvCursor::new(mv_store, tx_id, *root_page, pager.clone()).unwrap(), - )); - Some(mv_cursor) - } else { + if program.connection.get_mv_tx_id().is_none() { assert!(*root_page >= 0, ""); - None - }; + } let cursors = &mut state.cursors; let num_columns = match cursor_type { CursorType::BTreeTable(table_rc) => table_rc.columns.len(), @@ -1063,16 +1057,33 @@ pub fn op_open_read( _ => unreachable!("This should not have happened"), }; + let maybe_promote_to_mvcc_cursor = + |btree_cursor: Box| -> Result> { + if let Some(tx_id) = program.connection.get_mv_tx_id() { + let mv_store = mv_store.unwrap().clone(); + Ok(Box::new(MvCursor::new( + mv_store, + tx_id, + *root_page, + pager.clone(), + btree_cursor, + )?)) + } else { + Ok(btree_cursor) + } + }; + match cursor_type { CursorType::MaterializedView(_, view_mutex) => { // This is a materialized view with storage // Create btree cursor for reading the persistent data + let btree_cursor = Box::new(BTreeCursor::new_table( - mv_cursor, pager.clone(), *root_page, num_columns, )); + let cursor = maybe_promote_to_mvcc_cursor(btree_cursor)?; // Get the view name and look up or create its transaction state let view_name = view_mutex.lock().unwrap().name().to_string(); @@ -1083,7 +1094,7 @@ pub fn op_open_read( // Create materialized view cursor with this view's transaction state let mv_cursor = crate::incremental::cursor::MaterializedViewCursor::new( - btree_cursor, + cursor, view_mutex.clone(), pager.clone(), tx_state, @@ -1096,20 +1107,25 @@ pub fn op_open_read( } CursorType::BTreeTable(_) => { // Regular table - let cursor = BTreeCursor::new_table(mv_cursor, pager.clone(), *root_page, num_columns); + let btree_cursor = Box::new(BTreeCursor::new_table( + pager.clone(), + *root_page, + num_columns, + )); + let cursor = maybe_promote_to_mvcc_cursor(btree_cursor)?; cursors .get_mut(*cursor_id) .unwrap() .replace(Cursor::new_btree(cursor)); } CursorType::BTreeIndex(index) => { - let cursor = BTreeCursor::new_index( - mv_cursor, + let btree_cursor = Box::new(BTreeCursor::new_index( pager.clone(), *root_page, index.as_ref(), num_columns, - ); + )); + let cursor = maybe_promote_to_mvcc_cursor(btree_cursor)?; cursors .get_mut(*cursor_id) .unwrap() @@ -1470,44 +1486,6 @@ pub fn op_last( Ok(InsnFunctionStepResult::Step) } -/// Fast varint reader optimized for the common cases of 1-byte and 2-byte varints. -/// -/// This function is a performance-optimized version of `read_varint()` that handles -/// the most common varint cases inline before falling back to the full implementation. -/// It follows the same varint encoding as SQLite. -/// -/// # Optimized Cases -/// -/// - **Single-byte case**: Values 0-127 (0x00-0x7F) are returned immediately -/// - **Two-byte case**: Values 128-16383 (0x80-0x3FFF) are handled inline -/// - **Multi-byte case**: Larger values fall back to the full `read_varint()` implementation -/// -/// This function is similar to `sqlite3GetVarint32` -#[inline(always)] -fn read_varint_fast(buf: &[u8]) -> Result<(u64, usize)> { - // Fast path: Single-byte varint - if let Some(&first_byte) = buf.first() { - if first_byte & 0x80 == 0 { - return Ok((first_byte as u64, 1)); - } - } else { - crate::bail_corrupt_error!("Invalid varint"); - } - - // Fast path: Two-byte varint - if let Some(&second_byte) = buf.get(1) { - if second_byte & 0x80 == 0 { - let v = (((buf[0] & 0x7f) as u64) << 7) + (second_byte as u64); - return Ok((v, 2)); - } - } else { - crate::bail_corrupt_error!("Invalid varint"); - } - - //Fallback: Multi-byte varint - read_varint(buf) -} - #[derive(Debug, Clone, Copy)] pub enum OpColumnState { Start, @@ -1631,7 +1609,7 @@ pub fn op_column( break 'ifnull; }; - let mut record_cursor = cursor.record_cursor.borrow_mut(); + let mut record_cursor = cursor.record_cursor_mut(); if record_cursor.offsets.is_empty() { let (header_size, header_len_bytes) = read_varint_fast(payload)?; @@ -1876,11 +1854,7 @@ pub fn op_column( let value = { let cursor = state.get_cursor(*cursor_id); let cursor = cursor.as_pseudo_mut(); - if let Some(record) = cursor.record() { - record.get_value(*column)?.to_owned() - } else { - Value::Null - } + cursor.get_value(*column)? }; state.registers[*dest] = Register::Value(value); } @@ -2137,8 +2111,8 @@ pub fn halt( description: &str, ) -> Result { if err_code > 0 { - // invalidate page cache in case of error - pager.clear_page_cache(); + // Any non-FK constraint violation causes the statement subtransaction to roll back. + state.end_statement(&program.connection, pager, EndStatement::RollbackSavepoint)?; } match err_code { 0 => {} @@ -2166,11 +2140,46 @@ pub fn halt( let auto_commit = program.connection.auto_commit.load(Ordering::SeqCst); tracing::trace!("halt(auto_commit={})", auto_commit); + + // Check for immediate foreign key violations. + // Any immediate violation causes the statement subtransaction to roll back. + if program.connection.foreign_keys_enabled() + && state + .fk_immediate_violations_during_stmt + .load(Ordering::Acquire) + > 0 + { + state.end_statement(&program.connection, pager, EndStatement::RollbackSavepoint)?; + return Err(LimboError::Constraint( + "foreign key constraint failed".to_string(), + )); + } + if auto_commit { + // In autocommit mode, a statement that leaves deferred violations must fail here, + // and it also ends the transaction. + if program.connection.foreign_keys_enabled() { + let deferred_violations = program + .connection + .fk_deferred_violations + .swap(0, Ordering::AcqRel); + if deferred_violations > 0 { + pager.rollback_tx(&program.connection); + program.connection.set_tx_state(TransactionState::None); + program.connection.auto_commit.store(true, Ordering::SeqCst); + return Err(LimboError::Constraint( + "foreign key constraint failed".to_string(), + )); + } + } + state.end_statement(&program.connection, pager, EndStatement::ReleaseSavepoint)?; program .commit_txn(pager.clone(), state, mv_store, false) .map(Into::into) } else { + // Even if deferred violations are present, the statement subtransaction completes successfully when + // it is part of an interactive transaction. + state.end_statement(&program.connection, pager, EndStatement::ReleaseSavepoint)?; Ok(InsnFunctionStepResult::Done) } } @@ -2219,6 +2228,7 @@ pub fn op_halt_if_null( pub enum OpTransactionState { Start, CheckSchemaCookie, + BeginStatement, } pub fn op_transaction( @@ -2258,7 +2268,7 @@ pub fn op_transaction_inner( OpTransactionState::Start => { let conn = program.connection.clone(); let write = matches!(tx_mode, TransactionMode::Write); - if write && conn.db.open_flags.contains(OpenFlags::ReadOnly) { + if write && conn.db.open_flags.get().contains(OpenFlags::ReadOnly) { return Err(LimboError::ReadOnly); } @@ -2328,7 +2338,7 @@ pub fn op_transaction_inner( | TransactionMode::Read | TransactionMode::Concurrent => mv_store.begin_tx(pager.clone())?, TransactionMode::Write => { - return_if_io!(mv_store.begin_exclusive_tx(pager.clone(), None)) + mv_store.begin_exclusive_tx(pager.clone(), None)? } }; *program.connection.mv_tx.write() = Some((tx_id, *tx_mode)); @@ -2343,7 +2353,7 @@ pub fn op_transaction_inner( if matches!(new_transaction_state, TransactionState::Write { .. }) && matches!(actual_tx_mode, TransactionMode::Write) { - return_if_io!(mv_store.begin_exclusive_tx(pager.clone(), Some(tx_id))); + mv_store.begin_exclusive_tx(pager.clone(), Some(tx_id))?; } } } else { @@ -2359,6 +2369,7 @@ pub fn op_transaction_inner( "nested stmt should not begin a new read transaction" ); pager.begin_read_tx()?; + state.auto_txn_cleanup = TxnCleanup::RollbackTxn; } if updated && matches!(new_transaction_state, TransactionState::Write { .. }) { @@ -2372,8 +2383,9 @@ pub fn op_transaction_inner( // That is, if the transaction had not started, end the read transaction so that next time we // start a new one. if matches!(current_state, TransactionState::None) { - pager.end_read_tx()?; + pager.end_read_tx(); conn.set_tx_state(TransactionState::None); + state.auto_txn_cleanup = TxnCleanup::None; } assert_eq!(conn.get_tx_state(), current_state); return Err(LimboError::Busy); @@ -2400,9 +2412,7 @@ pub fn op_transaction_inner( // Can only read header if page 1 has been allocated already // begin_write_tx that happens, but not begin_read_tx OpTransactionState::CheckSchemaCookie => { - let res = with_header(&pager, mv_store, program, |header| { - header.schema_cookie.get() - }); + let res = get_schema_cookie(&pager, mv_store, program); match res { Ok(IOResult::Done(header_schema_cookie)) => { if header_schema_cookie != *schema_cookie { @@ -2422,6 +2432,17 @@ pub fn op_transaction_inner( } } + state.op_transaction_state = OpTransactionState::BeginStatement; + } + OpTransactionState::BeginStatement => { + if program.needs_stmt_subtransactions && mv_store.is_none() { + let write = matches!(tx_mode, TransactionMode::Write); + let res = state.begin_statement(&program.connection, &pager, write)?; + if let IOResult::IO(io) = res { + return Ok(InsnFunctionStepResult::IO(io)); + } + } + state.pc += 1; state.op_transaction_state = OpTransactionState::Start; return Ok(InsnFunctionStepResult::Step); @@ -2440,40 +2461,74 @@ pub fn op_auto_commit( load_insn!( AutoCommit { auto_commit, - rollback, + rollback }, insn ); + let conn = program.connection.clone(); + let fk_on = conn.foreign_keys_enabled(); + let had_autocommit = conn.auto_commit.load(Ordering::SeqCst); // true, not in tx + + // Drive any multi-step commit/rollback that’s already in progress. if matches!(state.commit_state, CommitState::Committing) { - return program + let res = program .commit_txn(pager.clone(), state, mv_store, *rollback) .map(Into::into); + // Only clear after a final, successful non-rollback COMMIT. + if fk_on + && !*rollback + && matches!( + res, + Ok(InsnFunctionStepResult::Step | InsnFunctionStepResult::Done) + ) + { + conn.clear_deferred_foreign_key_violations(); + } + return res; } - if *auto_commit != conn.auto_commit.load(Ordering::SeqCst) { - if *rollback { - // TODO(pere): add rollback I/O logic once we implement rollback journal + // The logic in this opcode can be a bit confusing, so to make things a bit clearer lets be + // very explicit about the currently existing and requested state. + let requested_autocommit = *auto_commit; + let requested_rollback = *rollback; + let changed = requested_autocommit != had_autocommit; + + // what the requested operation is + let is_begin_req = had_autocommit && !requested_autocommit && !requested_rollback; + let is_commit_req = !had_autocommit && requested_autocommit && !requested_rollback; + let is_rollback_req = !had_autocommit && requested_autocommit && requested_rollback; + + if changed { + if requested_rollback { + // ROLLBACK transition if let Some(mv_store) = mv_store { if let Some(tx_id) = conn.get_mv_tx_id() { - mv_store.rollback_tx(tx_id, pager.clone(), &conn)?; + mv_store.rollback_tx(tx_id, pager.clone(), &conn); } } else { - return_if_io!(pager.end_tx(true, &conn)); + pager.rollback_tx(&conn); } conn.set_tx_state(TransactionState::None); conn.auto_commit.store(true, Ordering::SeqCst); } else { - conn.auto_commit.store(*auto_commit, Ordering::SeqCst); + // BEGIN (true->false) or COMMIT (false->true) + if is_commit_req { + // Pre-check deferred FKs; leave tx open and do NOT clear violations + check_deferred_fk_on_commit(&conn)?; + } + conn.auto_commit + .store(requested_autocommit, Ordering::SeqCst); } } else { - let mvcc_tx_active = program.connection.get_mv_tx().is_some(); + // No autocommit flip + let mvcc_tx_active = conn.get_mv_tx().is_some(); if !mvcc_tx_active { - if !*auto_commit { + if !requested_autocommit { return Err(LimboError::TxError( "cannot start a transaction within a transaction".to_string(), )); - } else if *rollback { + } else if requested_rollback { return Err(LimboError::TxError( "cannot rollback - no transaction is active".to_string(), )); @@ -2482,19 +2537,41 @@ pub fn op_auto_commit( "cannot commit - no transaction is active".to_string(), )); } - } else { - let is_begin = !*auto_commit && !*rollback; - if is_begin { - return Err(LimboError::TxError( - "cannot use BEGIN after BEGIN CONCURRENT".to_string(), - )); - } + } else if is_begin_req { + return Err(LimboError::TxError( + "cannot use BEGIN after BEGIN CONCURRENT".to_string(), + )); } } - program - .commit_txn(pager.clone(), state, mv_store, *rollback) - .map(Into::into) + let res = program + .commit_txn(pager.clone(), state, mv_store, requested_rollback) + .map(Into::into); + + // Clear deferred FK counters only after FINAL success of COMMIT/ROLLBACK. + if fk_on + && matches!( + res, + Ok(InsnFunctionStepResult::Step | InsnFunctionStepResult::Done) + ) + && (is_rollback_req || is_commit_req) + { + conn.clear_deferred_foreign_key_violations(); + } + + res +} + +fn check_deferred_fk_on_commit(conn: &Connection) -> Result<()> { + if !conn.foreign_keys_enabled() { + return Ok(()); + } + if conn.get_deferred_foreign_key_violations() > 0 { + return Err(LimboError::Constraint( + "FOREIGN KEY constraint failed".into(), + )); + } + Ok(()) } pub fn op_goto( @@ -2700,11 +2777,11 @@ pub fn op_row_id( let index_cursor = index_cursor.as_btree_mut(); let record = return_if_io!(index_cursor.record()); let record = record.as_ref().unwrap(); - let mut record_cursor_ref = index_cursor.record_cursor.borrow_mut(); + let mut record_cursor_ref = index_cursor.record_cursor_mut(); let record_cursor = record_cursor_ref.deref_mut(); let rowid = record.last_value(record_cursor).unwrap(); match rowid { - Ok(RefValue::Integer(rowid)) => rowid, + Ok(ValueRef::Integer(rowid)) => rowid, _ => unreachable!(), } }; @@ -3217,7 +3294,7 @@ pub fn seek_internal( // this same logic applies for indexes, but the next/prev record is expected to be found in the parent page's // divider cell. turso_assert!( - !cursor.skip_advance.get(), + !cursor.get_skip_advance(), "skip_advance should not be true in the middle of a seek operation" ); let result = match op { @@ -3227,7 +3304,7 @@ pub fn seek_internal( }; match result { IOResult::Done(found) => { - cursor.has_record.set(found); + cursor.set_has_record(found); cursor.invalidate_record(); found } @@ -3341,9 +3418,9 @@ pub fn op_idx_ge( registers_to_ref_values(&state.registers[*start_reg..*start_reg + *num_regs]); let tie_breaker = get_tie_breaker_from_idx_comp_op(insn); let ord = compare_records_generic( - &idx_record, // The serialized record from the index - &values, // The record built from registers - cursor.index_info.as_ref().unwrap(), // Sort order flags + &idx_record, // The serialized record from the index + &values, // The record built from registers + cursor.get_index_info(), // Sort order flags 0, tie_breaker, )?; @@ -3411,7 +3488,7 @@ pub fn op_idx_le( let ord = compare_records_generic( &idx_record, &values, - cursor.index_info.as_ref().unwrap(), + cursor.get_index_info(), 0, tie_breaker, )?; @@ -3462,7 +3539,7 @@ pub fn op_idx_gt( let ord = compare_records_generic( &idx_record, &values, - cursor.index_info.as_ref().unwrap(), + cursor.get_index_info(), 0, tie_breaker, )?; @@ -3514,7 +3591,7 @@ pub fn op_idx_lt( let ord = compare_records_generic( &idx_record, &values, - cursor.index_info.as_ref().unwrap(), + cursor.get_index_info(), 0, tie_breaker, )?; @@ -4274,7 +4351,14 @@ pub fn op_sorter_compare( &record.get_values()[..*num_regs] }; - let cursor = state.get_cursor(*cursor_id); + // Inlined `state.get_cursor` to prevent borrowing conflit with `state.registers` + let cursor = state + .cursors + .get_mut(*cursor_id) + .unwrap_or_else(|| panic!("cursor id {cursor_id} out of bounds")) + .as_mut() + .unwrap_or_else(|| panic!("cursor id {cursor_id} is None")); + let cursor = cursor.as_sorter_mut(); let Some(current_sorter_record) = cursor.record() else { return Err(LimboError::InternalError( @@ -4286,7 +4370,7 @@ pub fn op_sorter_compare( // If the current sorter record has a NULL in any of the significant fields, the comparison is not equal. let is_equal = current_sorter_values .iter() - .all(|v| !matches!(v, RefValue::Null)) + .all(|v| !matches!(v, ValueRef::Null)) && compare_immutable( previous_sorter_values, current_sorter_values, @@ -4728,7 +4812,9 @@ pub fn op_function( ScalarFunc::Typeof => Some(reg_value.exec_typeof()), ScalarFunc::Unicode => Some(reg_value.exec_unicode()), ScalarFunc::Quote => Some(reg_value.exec_quote()), - ScalarFunc::RandomBlob => Some(reg_value.exec_randomblob()), + ScalarFunc::RandomBlob => { + Some(reg_value.exec_randomblob(|dest| pager.io.fill_bytes(dest))) + } ScalarFunc::ZeroBlob => Some(reg_value.exec_zeroblob()), ScalarFunc::Soundex => Some(reg_value.exec_soundex()), _ => unreachable!(), @@ -4753,7 +4839,8 @@ pub fn op_function( state.registers[*dest] = Register::Value(result); } ScalarFunc::Random => { - state.registers[*dest] = Register::Value(Value::exec_random()); + state.registers[*dest] = + Register::Value(Value::exec_random(|| pager.io.generate_random_number())); } ScalarFunc::Trim => { let reg_value = &state.registers[*start_reg]; @@ -4892,7 +4979,7 @@ pub fn op_function( } } } - ScalarFunc::SqliteVersion => { + ScalarFunc::TursoVersion => { if !program.connection.is_db_initialized() { state.registers[*dest] = Register::Value(Value::build_text(info::build::PKG_VERSION)); @@ -4900,10 +4987,14 @@ pub fn op_function( let version_integer = return_if_io!(pager.with_header(|header| header.version_number)).get() as i64; - let version = execute_sqlite_version(version_integer); + let version = execute_turso_version(version_integer); state.registers[*dest] = Register::Value(Value::build_text(version)); } } + ScalarFunc::SqliteVersion => { + let version = execute_sqlite_version(); + state.registers[*dest] = Register::Value(Value::build_text(version)); + } ScalarFunc::SqliteSourceId => { let src_id = format!( "{} {}", @@ -4952,7 +5043,7 @@ pub fn op_function( } #[cfg(feature = "json")] { - use crate::types::{TextRef, TextSubtype}; + use crate::types::TextSubtype; let table = state.registers[*start_reg].get_value(); let Value::Text(table) = table else { @@ -4977,10 +5068,7 @@ pub fn op_function( for column in table.columns() { let name = column.name.as_ref().unwrap(); let name_json = json::convert_ref_dbtype_to_jsonb( - &RefValue::Text(TextRef::create_from( - name.as_str().as_bytes(), - TextSubtype::Text, - )), + ValueRef::Text(name.as_bytes(), TextSubtype::Text), json::Conv::ToString, )?; json.append_jsonb_to_end(name_json.data()); @@ -5048,13 +5136,13 @@ pub fn op_function( json.append_jsonb_to_end(column_name.data()); let val = record_cursor.get_value(&record, i)?; - if let RefValue::Blob(..) = val { + if let ValueRef::Blob(..) = val { return Err(LimboError::InvalidArgument( "bin_record_json_object: formatting of BLOB values stored in binary record is not supported".to_string() )); } let val_json = - json::convert_ref_dbtype_to_jsonb(&val, json::Conv::NotStrict)?; + json::convert_ref_dbtype_to_jsonb(val, json::Conv::NotStrict)?; json.append_jsonb_to_end(val_json.data()); } json.finalize_unsafe(json::jsonb::ElementType::OBJECT)?; @@ -5121,6 +5209,10 @@ pub fn op_function( let result = vector32(&state.registers[*start_reg..*start_reg + arg_count])?; state.registers[*dest] = Register::Value(result); } + VectorFunc::Vector32Sparse => { + let result = vector32_sparse(&state.registers[*start_reg..*start_reg + arg_count])?; + state.registers[*dest] = Register::Value(result); + } VectorFunc::Vector64 => { let result = vector64(&state.registers[*start_reg..*start_reg + arg_count])?; state.registers[*dest] = Register::Value(result); @@ -5134,11 +5226,16 @@ pub fn op_function( vector_distance_cos(&state.registers[*start_reg..*start_reg + arg_count])?; state.registers[*dest] = Register::Value(result); } - VectorFunc::VectorDistanceEuclidean => { + VectorFunc::VectorDistanceL2 => { let result = vector_distance_l2(&state.registers[*start_reg..*start_reg + arg_count])?; state.registers[*dest] = Register::Value(result); } + VectorFunc::VectorDistanceJaccard => { + let result = + vector_distance_jaccard(&state.registers[*start_reg..*start_reg + arg_count])?; + state.registers[*dest] = Register::Value(result); + } VectorFunc::VectorConcat => { let result = vector_concat(&state.registers[*start_reg..*start_reg + arg_count])?; state.registers[*dest] = Register::Value(result); @@ -5378,11 +5475,9 @@ pub fn op_function( .parse_column_definition(true) .unwrap(); - let new_sql = 'sql: { - if table != tbl_name { - break 'sql None; - } + let rename_to = normalize_ident(column_def.col_name.as_str()); + let new_sql = 'sql: { let Value::Text(sql) = sql else { break 'sql None; }; @@ -5436,34 +5531,160 @@ pub fn op_function( temporary, if_not_exists, } => { - if table != normalize_ident(tbl_name.name.as_str()) { - break 'sql None; - } - let ast::CreateTableBody::ColumnsAndConstraints { mut columns, - constraints, + mut constraints, options, } = body else { todo!() }; - let column = columns - .iter_mut() - .find(|column| { - column.col_name.as_str() == original_rename_from.as_str() - }) - .expect("column being renamed should be present"); + let normalized_tbl_name = normalize_ident(tbl_name.name.as_str()); - match alter_func { - AlterTableFunc::AlterColumn => *column = column_def, - AlterTableFunc::RenameColumn => { - column.col_name = column_def.col_name + if normalized_tbl_name == table { + // This is the table being altered - update its column + let column = columns + .iter_mut() + .find(|column| { + column.col_name.as_str() + == original_rename_from.as_str() + }) + .expect("column being renamed should be present"); + + match alter_func { + AlterTableFunc::AlterColumn => *column = column_def.clone(), + AlterTableFunc::RenameColumn => { + column.col_name = column_def.col_name.clone() + } + _ => unreachable!(), } - _ => unreachable!(), - } + // Update table-level constraints (PRIMARY KEY, UNIQUE, FOREIGN KEY) + for constraint in &mut constraints { + match &mut constraint.constraint { + ast::TableConstraint::PrimaryKey { + columns: pk_cols, + .. + } => { + for col in pk_cols { + let (ast::Expr::Name(ref name) + | ast::Expr::Id(ref name)) = *col.expr + else { + return Err(LimboError::ParseError("Unexpected expression in PRIMARY KEY constraint".to_string())); + }; + if normalize_ident(name.as_str()) == rename_from + { + *col.expr = ast::Expr::Name(Name::exact( + column_def.col_name.as_str().to_owned(), + )); + } + } + } + ast::TableConstraint::Unique { + columns: uniq_cols, + .. + } => { + for col in uniq_cols { + let (ast::Expr::Name(ref name) + | ast::Expr::Id(ref name)) = *col.expr + else { + return Err(LimboError::ParseError("Unexpected expression in UNIQUE constraint".to_string())); + }; + if normalize_ident(name.as_str()) == rename_from + { + *col.expr = ast::Expr::Name(Name::exact( + column_def.col_name.as_str().to_owned(), + )); + } + } + } + ast::TableConstraint::ForeignKey { + columns: child_cols, + clause, + .. + } => { + // Update child columns in this table's FK definitions + for child_col in child_cols { + if normalize_ident(child_col.col_name.as_str()) + == rename_from + { + child_col.col_name = Name::exact( + column_def.col_name.as_str().to_owned(), + ); + } + } + rewrite_fk_parent_cols_if_self_ref( + clause, + &normalized_tbl_name, + &rename_from, + column_def.col_name.as_str(), + ); + } + _ => {} + } + + for col in &mut columns { + rewrite_column_references_if_needed( + col, + &normalized_tbl_name, + &rename_from, + column_def.col_name.as_str(), + ); + } + } + } else { + // This is a different table, check if it has FKs referencing the renamed column + let mut fk_updated = false; + + for constraint in &mut constraints { + if let ast::TableConstraint::ForeignKey { + columns: _, + clause: + ForeignKeyClause { + tbl_name, + columns: parent_cols, + .. + }, + .. + } = &mut constraint.constraint + { + // Check if this FK references the table being altered + if normalize_ident(tbl_name.as_str()) == table { + // Update parent column references if they match the renamed column + for parent_col in parent_cols { + if normalize_ident(parent_col.col_name.as_str()) + == rename_from + { + parent_col.col_name = Name::exact( + column_def.col_name.as_str().to_owned(), + ); + fk_updated = true; + } + } + } + } + } + for col in &mut columns { + let before = fk_updated; + let mut local_col = col.clone(); + rewrite_column_references_if_needed( + &mut local_col, + &table, + &rename_from, + column_def.col_name.as_str(), + ); + if local_col != *col { + *col = local_col; + fk_updated = true; + } + } + + // Only return updated SQL if we actually changed something + if !fk_updated { + break 'sql None; + } + } Some( ast::Stmt::CreateTable { tbl_name, @@ -5478,7 +5699,7 @@ pub fn op_function( .to_string(), ) } - _ => todo!(), + _ => None, } }; @@ -5519,8 +5740,9 @@ pub fn op_sequence( }, insn ); - let cursor = state.get_cursor(*cursor_id).as_sorter_mut(); - let seq_num = cursor.next_sequence(); + let cursor_seq = state.cursor_seqs.get_mut(*cursor_id).unwrap(); + let seq_num = *cursor_seq; + *cursor_seq += 1; state.registers[*target_reg] = Register::Value(Value::Integer(seq_num)); state.pc += 1; Ok(InsnFunctionStepResult::Step) @@ -5541,8 +5763,10 @@ pub fn op_sequence_test( }, insn ); - let cursor = state.get_cursor(*cursor_id).as_sorter_mut(); - state.pc = if cursor.seq_beginning() { + let cursor_seq = state.cursor_seqs.get_mut(*cursor_id).unwrap(); + let was_zero = *cursor_seq == 0; + *cursor_seq += 1; + state.pc = if was_zero { target_pc.as_offset_int() } else { state.pc + 1 @@ -5695,31 +5919,52 @@ pub fn op_insert( turso_assert!(!flag.has(InsertFlags::REQUIRE_SEEK), "to capture old record accurately, we must be located at the correct position in the table"); + // Get the key we're going to insert + let insert_key = match &state.registers[*key_reg].get_value() { + Value::Integer(i) => *i, + _ => { + // If key is not an integer, we can't check - assume no old record + state.op_insert_state.old_record = None; + state.op_insert_state.sub_state = if flag.has(InsertFlags::REQUIRE_SEEK) { + OpInsertSubState::Seek + } else { + OpInsertSubState::Insert + }; + continue; + } + }; + let old_record = { let cursor = state.get_cursor(*cursor_id); let cursor = cursor.as_btree_mut(); // Get the current key - for INSERT operations, there may not be a current row let maybe_key = return_if_io!(cursor.rowid()); if let Some(key) = maybe_key { - // Get the current record before deletion and extract values - let maybe_record = return_if_io!(cursor.record()); - if let Some(record) = maybe_record { - let mut values = record - .get_values() - .into_iter() - .map(|v| v.to_owned()) - .collect::>(); + // Only capture as old record if the cursor is at the position we're inserting to + if key == insert_key { + // Get the current record before deletion and extract values + let maybe_record = return_if_io!(cursor.record()); + if let Some(record) = maybe_record { + let mut values = record + .get_values() + .into_iter() + .map(|v| v.to_owned()) + .collect::>(); - // Fix rowid alias columns: replace Null with actual rowid value - if let Some(table) = schema.get_table(table_name) { - for (i, col) in table.columns().iter().enumerate() { - if col.is_rowid_alias && i < values.len() { - values[i] = Value::Integer(key); + // Fix rowid alias columns: replace Null with actual rowid value + if let Some(table) = schema.get_table(table_name) { + for (i, col) in table.columns().iter().enumerate() { + if col.is_rowid_alias && i < values.len() { + values[i] = Value::Integer(key); + } } } + Some((key, values)) + } else { + None } - Some((key, values)) } else { + // Cursor is at wrong position - this is a fresh INSERT, not a replacement None } } else { @@ -5786,7 +6031,10 @@ pub fn op_insert( let cursor = cursor.as_btree_mut(); cursor.root_page() }; - if root_page != 1 && table_name != "sqlite_sequence" { + if root_page != 1 + && table_name != "sqlite_sequence" + && !flag.has(InsertFlags::EPHEMERAL_TABLE_INSERT) + { state.op_insert_state.sub_state = OpInsertSubState::UpdateLastRowid; } else { let schema = program.connection.schema.read(); @@ -5806,7 +6054,6 @@ pub fn op_insert( }; if let Some(rowid) = maybe_rowid { program.connection.update_last_rowid(rowid); - program .n_change .fetch_add(1, std::sync::atomic::Ordering::SeqCst); @@ -5940,7 +6187,8 @@ pub fn op_delete( load_insn!( Delete { cursor_id, - table_name + table_name, + is_part_of_update, }, insn ); @@ -6025,9 +6273,13 @@ pub fn op_delete( } state.op_delete_state.sub_state = OpDeleteSubState::MaybeCaptureRecord; - program - .n_change - .fetch_add(1, std::sync::atomic::Ordering::SeqCst); + if !is_part_of_update { + // DELETEs do not count towards the total changes if they are part of an UPDATE statement, + // i.e. the DELETE and subsequent INSERT of a row are the same "change". + program + .n_change + .fetch_add(1, std::sync::atomic::Ordering::SeqCst); + } state.pc += 1; Ok(InsnFunctionStepResult::Step) } @@ -6094,7 +6346,7 @@ pub fn op_idx_delete( .map(|i| &state.registers[i]) .collect::>(); return Err(LimboError::Corrupt(format!( - "IdxDelete: no matching index entry found for key {reg_values:?}" + "IdxDelete: no matching index entry found for key {reg_values:?} while seeking" ))); } state.pc += 1; @@ -6115,7 +6367,7 @@ pub fn op_idx_delete( .map(|i| &state.registers[i]) .collect::>(); return Err(LimboError::Corrupt(format!( - "IdxDelete: no matching index entry found for key {reg_values:?}" + "IdxDelete: no matching index entry found for key while verifying: {reg_values:?}" ))); } state.op_idx_delete_state = Some(OpIdxDeleteState::Deleting); @@ -6226,10 +6478,10 @@ pub fn op_idx_insert( // Cursor is pointing at a record; if the index has a rowid, exclude it from the comparison since it's a pointer to the table row; // UNIQUE indexes disallow duplicates like (a=1,b=2,rowid=1) and (a=1,b=2,rowid=2). let existing_key = if cursor.has_rowid() { - let count = cursor.record_cursor.borrow_mut().count(record); - record.get_values()[..count.saturating_sub(1)].to_vec() + let count = cursor.record_cursor_mut().count(record); + &record.get_values()[..count.saturating_sub(1)] } else { - record.get_values().to_vec() + &record.get_values()[..] }; let inserted_key_vals = &record_to_insert.get_values(); if existing_key.len() != inserted_key_vals.len() { @@ -6237,9 +6489,9 @@ pub fn op_idx_insert( } let conflict = compare_immutable( - existing_key.as_slice(), + existing_key, inserted_key_vals, - &cursor.index_info.as_ref().unwrap().key_info, + &cursor.get_index_info().key_info, ) == std::cmp::Ordering::Equal; if conflict { if flags.has(IdxInsertFlags::NO_OP_DUPLICATE) { @@ -6306,17 +6558,17 @@ pub fn op_new_rowid( NewRowid { cursor, rowid_reg, - .. + prev_largest_reg, }, insn ); - if let Some(mv_store) = mv_store { + // With MVCC we can't simply find last rowid and get rowid + 1 as a result. To not have two conflicting rowids concurrently we need to call `get_next_rowid` + // which will make sure we don't collide. let rowid = { let cursor = state.get_cursor(*cursor); - let cursor = cursor.as_btree_mut(); - let mvcc_cursor = cursor.get_mvcc_cursor(); - let mut mvcc_cursor = mvcc_cursor.write(); + let cursor = cursor.as_btree_mut() as &mut dyn Any; + let mvcc_cursor = cursor.downcast_mut::().unwrap(); mvcc_cursor.get_next_rowid() }; state.registers[*rowid_reg] = Register::Value(Value::Integer(rowid)); @@ -6349,6 +6601,11 @@ pub fn op_new_rowid( return_if_io!(cursor.rowid()) }; + if *prev_largest_reg > 0 { + state.registers[*prev_largest_reg] = + Register::Value(Value::Integer(current_max.unwrap_or(0))); + } + match current_max { Some(rowid) if rowid < MAX_ROWID => { // Can use sequential @@ -6378,8 +6635,7 @@ pub fn op_new_rowid( // Generate a random i64 and constrain it to the lower half of the rowid range. // We use the lower half (1 to MAX_ROWID/2) because we're in random mode only // when sequential allocation reached MAX_ROWID, meaning the upper range is full. - let mut rng = thread_rng(); - let mut random_rowid: i64 = rng.gen(); + let mut random_rowid: i64 = pager.io.generate_random_number(); random_rowid &= MAX_ROWID >> 1; // Mask to keep value in range [0, MAX_ROWID/2] random_rowid += 1; // Ensure positive @@ -6441,21 +6697,17 @@ pub fn op_must_be_int( Value::Integer(_) => {} Value::Float(f) => match cast_real_to_integer(*f) { Ok(i) => state.registers[*reg] = Register::Value(Value::Integer(i)), - Err(_) => crate::bail_parse_error!( - "MustBeInt: the value in register cannot be cast to integer" - ), + Err(_) => crate::bail_parse_error!("datatype mismatch"), }, Value::Text(text) => match checked_cast_text_to_numeric(text.as_str()) { Ok(Value::Integer(i)) => state.registers[*reg] = Register::Value(Value::Integer(i)), Ok(Value::Float(f)) => { state.registers[*reg] = Register::Value(Value::Integer(f as i64)) } - _ => crate::bail_parse_error!( - "MustBeInt: the value in register cannot be cast to integer" - ), + _ => crate::bail_parse_error!("datatype mismatch"), }, _ => { - crate::bail_parse_error!("MustBeInt: the value in register cannot be cast to integer"); + crate::bail_parse_error!("datatype mismatch"); } }; state.pc += 1; @@ -6528,7 +6780,7 @@ pub fn op_no_conflict( record .get_values() .iter() - .any(|val| matches!(val, RefValue::Null)) + .any(|val| matches!(val, ValueRef::Null)) } RecordSource::Unpacked { start_reg, @@ -6592,16 +6844,10 @@ pub fn op_not_exists( }, insn ); - let exists = if let Some(mv_store) = mv_store { - let cursor = must_be_btree_cursor!(*cursor, program.cursor_ref, state, "NotExists"); - let cursor = cursor.as_btree_mut(); - let mvcc_cursor = cursor.get_mvcc_cursor(); - false - } else { - let cursor = must_be_btree_cursor!(*cursor, program.cursor_ref, state, "NotExists"); - let cursor = cursor.as_btree_mut(); - return_if_io!(cursor.exists(state.registers[*rowid_reg].get_value())) - }; + let cursor = must_be_btree_cursor!(*cursor, program.cursor_ref, state, "NotExists"); + let cursor = cursor.as_btree_mut(); + let exists = return_if_io!(cursor.exists(state.registers[*rowid_reg].get_value())); + if exists { state.pc += 1; } else { @@ -6709,48 +6955,71 @@ pub fn op_open_write( CursorType::BTreeIndex(index) => Some(index), _ => None, }; - let mv_cursor = if let Some(tx_id) = program.connection.get_mv_tx_id() { - let mv_store = mv_store.unwrap().clone(); - let mv_cursor = Arc::new(RwLock::new( - MvCursor::new(mv_store.clone(), tx_id, root_page, pager.clone()).unwrap(), - )); - Some(mv_cursor) + + // Check if we can reuse the existing cursor + let can_reuse_cursor = if let Some(Some(Cursor::BTree(btree_cursor))) = cursors.get(*cursor_id) + { + // Reuse if the root_page matches (same table/index) + btree_cursor.root_page() == root_page } else { - None + false }; - if let Some(index) = maybe_index { - let conn = program.connection.clone(); - let schema = conn.schema.read(); - let table = schema - .get_table(&index.table_name) - .and_then(|table| table.btree()); - let num_columns = index.columns.len(); - let cursor = BTreeCursor::new_index( - mv_cursor, - pager.clone(), - root_page, - index.as_ref(), - num_columns, - ); - cursors - .get_mut(*cursor_id) - .unwrap() - .replace(Cursor::new_btree(cursor)); - } else { - let num_columns = match cursor_type { - CursorType::BTreeTable(table_rc) => table_rc.columns.len(), - CursorType::MaterializedView(table_rc, _) => table_rc.columns.len(), - _ => unreachable!( - "Expected BTreeTable or MaterializedView. This should not have happened." - ), - }; + if !can_reuse_cursor { + let maybe_promote_to_mvcc_cursor = + |btree_cursor: Box| -> Result> { + if let Some(tx_id) = program.connection.get_mv_tx_id() { + let mv_store = mv_store.unwrap().clone(); + Ok(Box::new(MvCursor::new( + mv_store, + tx_id, + root_page, + pager.clone(), + btree_cursor, + )?)) + } else { + Ok(btree_cursor) + } + }; + if let Some(index) = maybe_index { + let conn = program.connection.clone(); + let schema = conn.schema.read(); + let table = schema + .get_table(&index.table_name) + .and_then(|table| table.btree()); - let cursor = BTreeCursor::new_table(mv_cursor, pager.clone(), root_page, num_columns); - cursors - .get_mut(*cursor_id) - .unwrap() - .replace(Cursor::new_btree(cursor)); + let num_columns = index.columns.len(); + let btree_cursor = Box::new(BTreeCursor::new_index( + pager.clone(), + root_page, + index.as_ref(), + num_columns, + )); + let cursor = maybe_promote_to_mvcc_cursor(btree_cursor)?; + cursors + .get_mut(*cursor_id) + .unwrap() + .replace(Cursor::new_btree(cursor)); + } else { + let num_columns = match cursor_type { + CursorType::BTreeTable(table_rc) => table_rc.columns.len(), + CursorType::MaterializedView(table_rc, _) => table_rc.columns.len(), + _ => unreachable!( + "Expected BTreeTable or MaterializedView. This should not have happened." + ), + }; + + let btree_cursor = Box::new(BTreeCursor::new_table( + pager.clone(), + root_page, + num_columns, + )); + let cursor = maybe_promote_to_mvcc_cursor(btree_cursor)?; + cursors + .get_mut(*cursor_id) + .unwrap() + .replace(Cursor::new_btree(cursor)); + } } state.pc += 1; Ok(InsnFunctionStepResult::Step) @@ -6810,6 +7079,11 @@ pub fn op_create_btree( Ok(InsnFunctionStepResult::Step) } +pub enum OpDestroyState { + CreateCursor, + DestroyBtree(Arc>), +} + pub fn op_destroy( program: &Program, state: &mut ProgramState, @@ -6833,15 +7107,26 @@ pub fn op_destroy( state.pc += 1; return Ok(InsnFunctionStepResult::Step); } - // TODO not sure if should be BTreeCursor::new_table or BTreeCursor::new_index here or neither and just pass an emtpy vec - let mut cursor = BTreeCursor::new(None, pager.clone(), *root, 0); - let former_root_page_result = cursor.btree_destroy()?; - if let IOResult::Done(former_root_page) = former_root_page_result { - state.registers[*former_root_reg] = - Register::Value(Value::Integer(former_root_page.unwrap_or(0) as i64)); + + loop { + match state.op_destroy_state { + OpDestroyState::CreateCursor => { + // Destroy doesn't do anything meaningful with the table/index distinction so we can just use a + // table btree cursor for both. + let cursor = BTreeCursor::new(pager.clone(), *root, 0); + state.op_destroy_state = + OpDestroyState::DestroyBtree(Arc::new(RwLock::new(cursor))); + } + OpDestroyState::DestroyBtree(ref mut cursor) => { + let maybe_former_root_page = return_if_io!(cursor.write().btree_destroy()); + state.registers[*former_root_reg] = + Register::Value(Value::Integer(maybe_former_root_page.unwrap_or(0) as i64)); + state.op_destroy_state = OpDestroyState::CreateCursor; + state.pc += 1; + return Ok(InsnFunctionStepResult::Step); + } + } } - state.pc += 1; - Ok(InsnFunctionStepResult::Step) } pub fn op_reset_sorter( @@ -7392,7 +7677,7 @@ pub enum OpOpenEphemeralState { // clippy complains this variant is too big when compared to the rest of the variants // so it says we need to box it here Rewind { - cursor: Box, + cursor: Box, }, } pub fn op_open_ephemeral( @@ -7416,7 +7701,7 @@ pub fn op_open_ephemeral( let page_size = return_if_io!(with_header(pager, mv_store, program, |header| header.page_size)); let conn = program.connection.clone(); - let io = conn.pager.read().io.clone(); + let io = conn.pager.load().io.clone(); let rand_num = io.generate_random_number(); let db_file; let db_file_io: Arc; @@ -7493,9 +7778,9 @@ pub fn op_open_ephemeral( }; let cursor = if let CursorType::BTreeIndex(index) = cursor_type { - BTreeCursor::new_index(None, pager.clone(), root_page, index, num_columns) + BTreeCursor::new_index(pager.clone(), root_page, index, num_columns) } else { - BTreeCursor::new_table(None, pager.clone(), root_page, num_columns) + BTreeCursor::new_table(pager.clone(), root_page, num_columns) }; state.op_open_ephemeral_state = OpOpenEphemeralState::Rewind { cursor: Box::new(cursor), @@ -7520,13 +7805,13 @@ pub fn op_open_ephemeral( cursors .get_mut(cursor_id) .unwrap() - .replace(Cursor::new_btree(*cursor)); + .replace(Cursor::new_btree(cursor)); } CursorType::BTreeIndex(_) => { cursors .get_mut(cursor_id) .unwrap() - .replace(Cursor::new_btree(*cursor)); + .replace(Cursor::new_btree(cursor)); } CursorType::Pseudo(_) => { panic!("OpenEphemeral on pseudo cursor"); @@ -7572,26 +7857,29 @@ pub fn op_open_dup( // We use the pager from the original cursor instead of the one attached to // the connection because each ephemeral table creates its own pager (and // a separate database file). - let pager = &original_cursor.pager; - - let mv_cursor = if let Some(tx_id) = program.connection.get_mv_tx_id() { - let mv_store = mv_store.unwrap().clone(); - let mv_cursor = Arc::new(RwLock::new(MvCursor::new( - mv_store, - tx_id, - root_page, - pager.clone(), - )?)); - Some(mv_cursor) - } else { - None - }; + let pager = original_cursor.get_pager(); let (_, cursor_type) = program.cursor_ref.get(*original_cursor_id).unwrap(); match cursor_type { CursorType::BTreeTable(table) => { - let cursor = - BTreeCursor::new_table(mv_cursor, pager.clone(), root_page, table.columns.len()); + let cursor = Box::new(BTreeCursor::new_table( + pager.clone(), + root_page, + table.columns.len(), + )); + let cursor: Box = + if let Some(tx_id) = program.connection.get_mv_tx_id() { + let mv_store = mv_store.unwrap().clone(); + Box::new(MvCursor::new( + mv_store, + tx_id, + root_page, + pager.clone(), + cursor, + )?) + } else { + cursor + }; let cursors = &mut state.cursors; cursors .get_mut(*new_cursor_id) @@ -7795,12 +8083,13 @@ pub fn op_integrity_check( ); match &mut state.op_integrity_check_state { OpIntegrityCheckState::Start => { - let freelist_trunk_page = - return_if_io!(with_header(pager, mv_store, program, |header| header - .freelist_trunk_page - .get())); + let (freelist_trunk_page, db_size) = + return_if_io!(with_header(pager, mv_store, program, |header| ( + header.freelist_trunk_page.get(), + header.database_size.get() + ))); let mut errors = Vec::new(); - let mut integrity_check_state = IntegrityCheckState::new(); + let mut integrity_check_state = IntegrityCheckState::new(db_size as usize); let mut current_root_idx = 0; // check freelist pages first, if there are any for database if freelist_trunk_page > 0 { @@ -7843,6 +8132,16 @@ pub fn op_integrity_check( expected_count: integrity_check_state.freelist_count.expected_count, }); } + for page_number in 2..=integrity_check_state.db_size { + if !integrity_check_state + .page_reference + .contains_key(&(page_number as i64)) + { + errors.push(IntegrityCheckError::PageNeverUsed { + page_id: page_number as i64, + }); + } + } let message = if errors.is_empty() { "ok".to_string() } else { @@ -8018,7 +8317,7 @@ pub fn op_drop_column( let schema = conn.schema.read(); for (view_name, view) in schema.views.iter() { let view_select_sql = format!("SELECT * FROM {view_name}"); - conn.prepare(view_select_sql.as_str()).map_err(|e| { + let _ = conn.prepare(view_select_sql.as_str()).map_err(|e| { LimboError::ParseError(format!( "cannot drop column \"{}\": referenced in VIEW {view_name}: {}", column_name, view.sql, @@ -8097,43 +8396,94 @@ pub fn op_alter_column( .clone() }; let new_column = crate::schema::Column::from(definition); + let new_name = definition.col_name.as_str().to_owned(); conn.with_schema_mut(|schema| { - let table = schema + let table_arc = schema .tables .get_mut(&normalized_table_name) - .expect("table being renamed should be in schema"); + .expect("table being ALTERed should be in schema"); + let table = Arc::make_mut(table_arc); - let table = Arc::make_mut(table); - - let Table::BTree(btree) = table else { - panic!("only btree tables can be renamed"); + let Table::BTree(ref mut btree_arc) = table else { + panic!("only btree tables can be altered"); }; - - let btree = Arc::make_mut(btree); - - let column = btree + let btree = Arc::make_mut(btree_arc); + let col = btree .columns .get_mut(*column_index) - .expect("renamed column should be in schema"); + .expect("column being ALTERed should be in schema"); - if let Some(indexes) = schema.indexes.get_mut(&normalized_table_name) { - for index in indexes { - let index = Arc::make_mut(index); - for index_column in &mut index.columns { - if index_column.name - == *column.name.as_ref().expect("btree column should be named") - { - index_column.name = definition.col_name.as_str().to_owned(); + // Update indexes on THIS table that name the old column (you already had this) + if let Some(idxs) = schema.indexes.get_mut(&normalized_table_name) { + for idx in idxs { + let idx = Arc::make_mut(idx); + for ic in &mut idx.columns { + if ic.name.eq_ignore_ascii_case( + col.name.as_ref().expect("btree column should be named"), + ) { + ic.name = new_name.clone(); + } + } + } + } + if *rename { + col.name = Some(new_name.clone()); + } else { + *col = new_column.clone(); + } + + // Keep primary_key_columns consistent (names may change on rename) + for (pk_name, _ord) in &mut btree.primary_key_columns { + if pk_name.eq_ignore_ascii_case(&old_column_name) { + *pk_name = new_name.clone(); + } + } + + // Maintain rowid-alias bit after change/rename (INTEGER PRIMARY KEY) + if !*rename { + // recompute alias from `new_column` + btree.columns[*column_index].is_rowid_alias = new_column.is_rowid_alias; + } + + // Update this table’s OWN foreign keys + for fk_arc in &mut btree.foreign_keys { + let fk = Arc::make_mut(fk_arc); + // child side: rename child column if it matches + for cc in &mut fk.child_columns { + if cc.eq_ignore_ascii_case(&old_column_name) { + *cc = new_name.clone(); + } + } + // parent side: if self-referencing, rename parent column too + if normalize_ident(&fk.parent_table) == normalized_table_name { + for pc in &mut fk.parent_columns { + if pc.eq_ignore_ascii_case(&old_column_name) { + *pc = new_name.clone(); } } } } - if *rename { - column.name = new_column.name; - } else { - *column = new_column; + // fix OTHER tables that reference this table as parent + for (tname, t_arc) in schema.tables.iter_mut() { + if normalize_ident(tname) == normalized_table_name { + continue; + } + if let Table::BTree(ref mut child_btree_arc) = Arc::make_mut(t_arc) { + let child_btree = Arc::make_mut(child_btree_arc); + for fk_arc in &mut child_btree.foreign_keys { + if normalize_ident(&fk_arc.parent_table) != normalized_table_name { + continue; + } + let fk = Arc::make_mut(fk_arc); + for pc in &mut fk.parent_columns { + if pc.eq_ignore_ascii_case(&old_column_name) { + *pc = new_name.clone(); + } + } + } + } } }); @@ -8149,7 +8499,7 @@ pub fn op_alter_column( for (view_name, view) in schema.views.iter() { let view_select_sql = format!("SELECT * FROM {view_name}"); // FIXME: this should rewrite the view to reference the new column name - conn.prepare(view_select_sql.as_str()).map_err(|e| { + let _ = conn.prepare(view_select_sql.as_str()).map_err(|e| { LimboError::ParseError(format!( "cannot rename column \"{}\": referenced in VIEW {view_name}: {}", old_column_name, view.sql, @@ -8234,6 +8584,76 @@ fn handle_text_sum(acc: &mut Value, sum_state: &mut SumAggState, parsed_number: } } +pub fn op_fk_counter( + program: &Program, + state: &mut ProgramState, + insn: &Insn, + pager: &Arc, + mv_store: Option<&Arc>, +) -> Result { + load_insn!( + FkCounter { + increment_value, + deferred, + }, + insn + ); + if !*deferred { + state + .fk_immediate_violations_during_stmt + .fetch_add(*increment_value, Ordering::AcqRel); + } else { + // Transaction-level counter: add/subtract for deferred FKs. + program + .connection + .fk_deferred_violations + .fetch_add(*increment_value, Ordering::AcqRel); + } + + state.pc += 1; + Ok(InsnFunctionStepResult::Step) +} + +pub fn op_fk_if_zero( + program: &Program, + state: &mut ProgramState, + insn: &Insn, + _pager: &Arc, + _mv_store: Option<&Arc>, +) -> Result { + load_insn!( + FkIfZero { + deferred, + target_pc, + }, + insn + ); + let fk_enabled = program.connection.foreign_keys_enabled(); + + // Jump if any: + // Foreign keys are disabled globally + // p1 is true AND deferred constraint counter is zero + // p1 is false AND deferred constraint counter is non-zero + if !fk_enabled { + state.pc = target_pc.as_offset_int(); + return Ok(InsnFunctionStepResult::Step); + } + let v = if *deferred { + program.connection.get_deferred_foreign_key_violations() + } else { + state + .fk_immediate_violations_during_stmt + .load(Ordering::Acquire) + }; + + state.pc = if v == 0 { + target_pc.as_offset_int() + } else { + state.pc + 1 + }; + Ok(InsnFunctionStepResult::Step) +} + mod cmath { extern "C" { pub fn exp(x: f64) -> f64; @@ -8440,14 +8860,17 @@ impl Value { }) } - pub fn exec_random() -> Self { - let mut buf = [0u8; 8]; - getrandom::getrandom(&mut buf).unwrap(); - let random_number = i64::from_ne_bytes(buf); - Value::Integer(random_number) + pub fn exec_random(generate_random_number: F) -> Self + where + F: Fn() -> i64, + { + Value::Integer(generate_random_number()) } - pub fn exec_randomblob(&self) -> Value { + pub fn exec_randomblob(&self, fill_bytes: F) -> Value + where + F: Fn(&mut [u8]), + { let length = match self { Value::Integer(i) => *i, Value::Float(f) => *f as i64, @@ -8457,7 +8880,7 @@ impl Value { .max(1) as usize; let mut blob: Vec = vec![0; length]; - getrandom::getrandom(&mut blob).expect("Failed to generate random blob"); + fill_bytes(&mut blob); Value::Blob(blob) } @@ -9318,7 +9741,12 @@ fn try_float_to_integer_affinity(value: &mut Value, fl: f64) -> bool { false } -fn execute_sqlite_version(version_integer: i64) -> String { +// Compat for applications that test for SQLite. +fn execute_sqlite_version() -> String { + "3.50.4".to_string() +} + +fn execute_turso_version(version_integer: i64) -> String { let major = version_integer / 1_000_000; let minor = (version_integer % 1_000_000) / 1_000; let release = version_integer % 1_000; @@ -9807,8 +10235,25 @@ where } } +fn get_schema_cookie( + pager: &Arc, + mv_store: Option<&Arc>, + program: &Program, +) -> Result> { + if let Some(mv_store) = mv_store { + let tx_id = program.connection.get_mv_tx_id(); + mv_store + .with_header(|header| header.schema_cookie.get(), tx_id.as_ref()) + .map(IOResult::Done) + } else { + pager.get_schema_cookie() + } +} + #[cfg(test)] mod tests { + use rand::{Rng, RngCore}; + use super::*; use crate::types::Value; @@ -10202,7 +10647,7 @@ mod tests { use crate::vdbe::{Bitfield, Register}; - use super::{exec_char, execute_sqlite_version}; + use super::{exec_char, execute_turso_version}; use std::collections::HashMap; #[test] @@ -10569,7 +11014,7 @@ mod tests { #[test] fn test_random() { - match Value::exec_random() { + match Value::exec_random(|| rand::rng().random()) { Value::Integer(value) => { // Check that the value is within the range of i64 assert!( @@ -10632,7 +11077,9 @@ mod tests { ]; for test_case in &test_cases { - let result = test_case.input.exec_randomblob(); + let result = test_case.input.exec_randomblob(|dest| { + rand::rng().fill_bytes(dest); + }); match result { Value::Blob(blob) => { assert_eq!(blob.len(), test_case.expected_len); @@ -10990,7 +11437,7 @@ mod tests { fn test_execute_sqlite_version() { let version_integer = 3046001; let expected = "3.46.1"; - assert_eq!(execute_sqlite_version(version_integer), expected); + assert_eq!(execute_turso_version(version_integer), expected); } #[test] diff --git a/core/vdbe/explain.rs b/core/vdbe/explain.rs index f480a8a4c..4f09e2ea0 100644 --- a/core/vdbe/explain.rs +++ b/core/vdbe/explain.rs @@ -6,7 +6,11 @@ use super::{Insn, InsnReference, Program, Value}; use crate::function::{Func, ScalarFunc}; pub const EXPLAIN_COLUMNS: [&str; 8] = ["addr", "opcode", "p1", "p2", "p3", "p4", "p5", "comment"]; +pub const EXPLAIN_COLUMNS_TYPE: [&str; 8] = [ + "INTEGER", "TEXT", "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER", "TEXT", +]; pub const EXPLAIN_QUERY_PLAN_COLUMNS: [&str; 4] = ["id", "parent", "notused", "detail"]; +pub const EXPLAIN_QUERY_PLAN_COLUMNS_TYPE: [&str; 4] = ["INTEGER", "INTEGER", "INTEGER", "TEXT"]; pub fn insn_to_row( program: &Program, @@ -1137,7 +1141,7 @@ pub fn insn_to_row( flag.0 as u16, format!("intkey=r[{key_reg}] data=r[{record_reg}]"), ), - Insn::Delete { cursor_id, table_name } => ( + Insn::Delete { cursor_id, table_name, .. } => ( "Delete", *cursor_id as i32, 0, @@ -1800,7 +1804,25 @@ pub fn insn_to_row( 0, String::new(), ), - } + Insn::FkCounter{increment_value, deferred } => ( + "FkCounter", + *increment_value as i32, + *deferred as i32, + 0, + Value::build_text(""), + 0, + String::new(), + ), + Insn::FkIfZero{target_pc, deferred } => ( + "FkIfZero", + target_pc.as_debug_int(), + *deferred as i32, + 0, + Value::build_text(""), + 0, + String::new(), + ), + } } pub fn insn_to_row_with_comment( diff --git a/core/vdbe/insn.rs b/core/vdbe/insn.rs index 67e1b784d..c949d804f 100644 --- a/core/vdbe/insn.rs +++ b/core/vdbe/insn.rs @@ -112,6 +112,7 @@ pub struct InsertFlags(pub u8); impl InsertFlags { pub const UPDATE_ROWID_CHANGE: u8 = 0x01; // Flag indicating this is part of an UPDATE statement where the row's rowid is changed pub const REQUIRE_SEEK: u8 = 0x02; // Flag indicating that a seek is required to insert the row + pub const EPHEMERAL_TABLE_INSERT: u8 = 0x04; // Flag indicating that this is an insert into an ephemeral table pub fn new() -> Self { InsertFlags(0) @@ -130,6 +131,11 @@ impl InsertFlags { self.0 |= InsertFlags::UPDATE_ROWID_CHANGE; self } + + pub fn is_ephemeral_table_insert(mut self) -> Self { + self.0 |= InsertFlags::EPHEMERAL_TABLE_INSERT; + self + } } #[derive(Clone, Copy, Debug)] @@ -791,6 +797,8 @@ pub enum Insn { Delete { cursor_id: CursorID, table_name: String, + /// Whether the DELETE is part of an UPDATE statement. If so, it doesn't count towards the change counter. + is_part_of_update: bool, }, /// If P5 is not zero, then raise an SQLITE_CORRUPT_INDEX error if no matching index entry @@ -853,10 +861,12 @@ pub enum Insn { db: usize, }, + /// Make a copy of register src..src+extra_amount into dst..dst+extra_amount. Copy { src_reg: usize, dst_reg: usize, - extra_amount: usize, // 0 extra_amount means we include src_reg, dst_reg..=dst_reg+amount = src_reg..=src_reg+amount + /// 0 extra_amount means we include src_reg, dst_reg..=dst_reg+amount = src_reg..=src_reg+amount + extra_amount: usize, }, /// Allocate a new b-tree. @@ -1169,6 +1179,20 @@ pub enum Insn { p2: Option, // P2: address of parent explain instruction detail: String, // P4: detail text }, + // Increment a "constraint counter" by P2 (P2 may be negative or positive). + // If P1 is non-zero, the database constraint counter is incremented (deferred foreign key constraints). + // Otherwise, if P1 is zero, the statement counter is incremented (immediate foreign key constraints). + FkCounter { + increment_value: isize, + deferred: bool, + }, + // This opcode tests if a foreign key constraint-counter is currently zero. If so, jump to instruction P2. Otherwise, fall through to the next instruction. + // If P1 is non-zero, then the jump is taken if the database constraint-counter is zero (the one that counts deferred constraint violations). + // If P1 is zero, the jump is taken if the statement constraint-counter is zero (immediate foreign key constraint violations). + FkIfZero { + deferred: bool, + target_pc: BranchOffset, + }, } const fn get_insn_virtual_table() -> [InsnFunction; InsnVariants::COUNT] { @@ -1335,6 +1359,8 @@ impl InsnVariants { InsnVariants::MemMax => execute::op_mem_max, InsnVariants::Sequence => execute::op_sequence, InsnVariants::SequenceTest => execute::op_sequence_test, + InsnVariants::FkCounter => execute::op_fk_counter, + InsnVariants::FkIfZero => execute::op_fk_if_zero, } } } diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index fa1a88df8..8eec50ba3 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -29,19 +29,20 @@ use crate::{ error::LimboError, function::{AggFunc, FuncCtx}, mvcc::{database::CommitStateMachine, LocalClock}, + return_if_io, state_machine::StateMachine, - storage::sqlite3_ondisk::SmallVec, + storage::{pager::PagerCommitResult, sqlite3_ondisk::SmallVec}, translate::{collate::CollationSeq, plan::TableReferences}, - types::{IOCompletions, IOResult, RawSlice, TextRef}, + types::{IOCompletions, IOResult}, vdbe::{ execute::{ - OpCheckpointState, OpColumnState, OpDeleteState, OpDeleteSubState, OpIdxInsertState, - OpInsertState, OpInsertSubState, OpNewRowidState, OpNoConflictState, OpRowIdState, - OpSeekState, OpTransactionState, + OpCheckpointState, OpColumnState, OpDeleteState, OpDeleteSubState, OpDestroyState, + OpIdxInsertState, OpInsertState, OpInsertSubState, OpNewRowidState, OpNoConflictState, + OpRowIdState, OpSeekState, OpTransactionState, }, metrics::StatementMetrics, }, - IOExt, RefValue, + ValueRef, }; use crate::{ @@ -66,9 +67,10 @@ use std::{ collections::HashMap, num::NonZero, sync::{ - atomic::{AtomicI64, Ordering}, + atomic::{AtomicI64, AtomicIsize, Ordering}, Arc, }, + task::Waker, }; use tracing::{instrument, Level}; @@ -178,18 +180,6 @@ pub enum StepResult { Busy, } -/// If there is I/O, the instruction is restarted. -/// Evaluate a Result>, if IO return Ok(StepResult::IO). -#[macro_export] -macro_rules! return_step_if_io { - ($expr:expr) => { - match $expr? { - IOResult::Ok(v) => v, - IOResult::IO => return Ok(StepResult::IO), - } - }; -} - struct RegexCache { like: HashMap, glob: HashMap, @@ -265,11 +255,23 @@ pub struct Row { count: usize, } +// SAFETY: This needs to be audited for thread safety. +// See: https://github.com/tursodatabase/turso/issues/1552 +unsafe impl Send for Row {} +unsafe impl Sync for Row {} + +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum TxnCleanup { + None, + RollbackTxn, +} + /// The program state describes the environment in which the program executes. pub struct ProgramState { pub io_completions: Option, pub pc: InsnReference, cursors: Vec>, + cursor_seqs: Vec, registers: Vec, pub(crate) result_row: Option, last_compare: Option, @@ -284,6 +286,7 @@ pub struct ProgramState { #[cfg(feature = "json")] json_cache: JsonCacheCell, op_delete_state: OpDeleteState, + op_destroy_state: OpDestroyState, op_idx_delete_state: Option, op_integrity_check_state: OpIntegrityCheckState, /// Metrics collected during statement execution @@ -302,16 +305,34 @@ pub struct ProgramState { op_checkpoint_state: OpCheckpointState, /// State machine for committing view deltas with I/O handling view_delta_state: ViewDeltaCommitState, + /// Marker which tells about auto transaction cleanup necessary for that connection in case of reset + /// This is used when statement in auto-commit mode reseted after previous uncomplete execution - in which case we may need to rollback transaction started on previous attempt + /// Note, that MVCC transactions are always explicit - so they do not update auto_txn_cleanup marker + pub(crate) auto_txn_cleanup: TxnCleanup, + /// Number of deferred foreign key violations when the statement started. + /// When a statement subtransaction rolls back, the connection's deferred foreign key violations counter + /// is reset to this value. + fk_deferred_violations_when_stmt_started: AtomicIsize, + /// Number of immediate foreign key violations that occurred during the active statement. If nonzero, + /// the statement subtransactionwill roll back. + fk_immediate_violations_during_stmt: AtomicIsize, } +// SAFETY: This needs to be audited for thread safety. +// See: https://github.com/tursodatabase/turso/issues/1552 +unsafe impl Send for ProgramState {} +unsafe impl Sync for ProgramState {} + impl ProgramState { pub fn new(max_registers: usize, max_cursors: usize) -> Self { let cursors: Vec> = (0..max_cursors).map(|_| None).collect(); + let cursor_seqs = vec![0i64; max_cursors]; let registers = vec![Register::Value(Value::Null); max_registers]; Self { io_completions: None, pc: 0, cursors, + cursor_seqs, registers, result_row: None, last_compare: None, @@ -328,6 +349,7 @@ impl ProgramState { sub_state: OpDeleteSubState::MaybeCaptureRecord, deleted_record: None, }, + op_destroy_state: OpDestroyState::CreateCursor, op_idx_delete_state: None, op_integrity_check_state: OpIntegrityCheckState::Start, metrics: StatementMetrics::new(), @@ -346,6 +368,9 @@ impl ProgramState { op_transaction_state: OpTransactionState::Start, op_checkpoint_state: OpCheckpointState::StartCheckpoint, view_delta_state: ViewDeltaCommitState::NotStarted, + auto_txn_cleanup: TxnCleanup::None, + fk_deferred_violations_when_stmt_started: AtomicIsize::new(0), + fk_immediate_violations_during_stmt: AtomicIsize::new(0), } } @@ -390,6 +415,7 @@ impl ProgramState { if let Some(max_cursors) = max_cursors { self.cursors.resize_with(max_cursors, || None); + self.cursor_seqs.resize(max_cursors, 0); } if let Some(max_resgisters) = max_registers { self.registers @@ -428,6 +454,11 @@ impl ProgramState { self.op_column_state = OpColumnState::Start; self.op_row_id_state = OpRowIdState::Start; self.view_delta_state = ViewDeltaCommitState::NotStarted; + self.auto_txn_cleanup = TxnCleanup::None; + self.fk_immediate_violations_during_stmt + .store(0, Ordering::SeqCst); + self.fk_deferred_violations_when_stmt_started + .store(0, Ordering::SeqCst); } pub fn get_cursor(&mut self, cursor_id: CursorID) -> &mut Cursor { @@ -437,6 +468,63 @@ impl ProgramState { .as_mut() .unwrap_or_else(|| panic!("cursor id {cursor_id} is None")) } + + /// Begin a statement subtransaction. + pub fn begin_statement( + &mut self, + connection: &Connection, + pager: &Arc, + write: bool, + ) -> Result> { + // Store the deferred foreign key violations counter at the start of the statement. + // This is used to ensure that if an interactive transaction had deferred FK violations and a statement subtransaction rolls back, + // the deferred FK violations are not lost. + self.fk_deferred_violations_when_stmt_started.store( + connection.fk_deferred_violations.load(Ordering::Acquire), + Ordering::SeqCst, + ); + // Reset the immediate foreign key violations counter to 0. If this is nonzero when the statement completes, the statement subtransaction will roll back. + self.fk_immediate_violations_during_stmt + .store(0, Ordering::SeqCst); + if write { + let db_size = return_if_io!(pager.with_header(|header| header.database_size.get())); + pager.begin_statement(db_size)?; + } + Ok(IOResult::Done(())) + } + + /// End a statement subtransaction. + pub fn end_statement( + &mut self, + connection: &Connection, + pager: &Arc, + end_statement: EndStatement, + ) -> Result<()> { + match end_statement { + EndStatement::ReleaseSavepoint => pager.release_savepoint(), + EndStatement::RollbackSavepoint => { + pager.rollback_to_newest_savepoint()?; + // Reset the deferred foreign key violations counter to the value it had at the start of the statement. + // This is used to ensure that if an interactive transaction had deferred FK violations, they are not lost. + connection.fk_deferred_violations.store( + self.fk_deferred_violations_when_stmt_started + .load(Ordering::Acquire), + Ordering::SeqCst, + ); + Ok(()) + } + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +/// Action to take at the end of a statement subtransaction. +pub enum EndStatement { + /// Release (commit) the savepoint -- effectively removing the savepoint as it is no longer needed for undo purposes. + ReleaseSavepoint, + /// Rollback (abort) to the newest savepoint: read pages from the subjournal and restore them to the page cache. + /// This is used to undo the changes made by the statement. + RollbackSavepoint, } impl Register { @@ -501,6 +589,10 @@ pub struct Program { /// Used to determine whether we need to check for schema changes when /// starting a transaction. pub accesses_db: bool, + /// In SQLite, whether statement subtransactions will be used for executing a program (`usesStmtJournal`) + /// is determined by the parser flags "mayAbort" and "isMultiWrite". Essentially this means that the individual + /// statement may need to be aborted due to a constraint conflict, etc. instead of the entire transaction. + pub needs_stmt_subtransactions: bool, } impl Program { @@ -514,9 +606,10 @@ impl Program { mv_store: Option<&Arc>, pager: Arc, query_mode: QueryMode, + waker: Option<&Waker>, ) -> Result { match query_mode { - QueryMode::Normal => self.normal_step(state, mv_store, pager), + QueryMode::Normal => self.normal_step(state, mv_store, pager, waker), QueryMode::Explain => self.explain_step(state, mv_store, pager), QueryMode::ExplainQueryPlan => self.explain_query_plan_step(state, mv_store, pager), } @@ -533,7 +626,7 @@ impl Program { // Connection is closed for whatever reason, rollback the transaction. let state = self.connection.get_tx_state(); if let TransactionState::Write { .. } = state { - pager.io.block(|| pager.end_tx(true, &self.connection))?; + pager.rollback_tx(&self.connection); } return Err(LimboError::InternalError("Connection closed".to_string())); } @@ -588,7 +681,7 @@ impl Program { // Connection is closed for whatever reason, rollback the transaction. let state = self.connection.get_tx_state(); if let TransactionState::Write { .. } = state { - pager.io.block(|| pager.end_tx(true, &self.connection))?; + pager.rollback_tx(&self.connection); } return Err(LimboError::InternalError("Connection closed".to_string())); } @@ -629,6 +722,7 @@ impl Program { state: &mut ProgramState, mv_store: Option<&Arc>, pager: Arc, + waker: Option<&Waker>, ) -> Result { let enable_tracing = tracing::enabled!(tracing::Level::TRACE); loop { @@ -636,20 +730,22 @@ impl Program { // Connection is closed for whatever reason, rollback the transaction. let state = self.connection.get_tx_state(); if let TransactionState::Write { .. } = state { - pager.io.block(|| pager.end_tx(true, &self.connection))?; + pager.rollback_tx(&self.connection); } return Err(LimboError::InternalError("Connection closed".to_string())); } if state.is_interrupted() { + self.abort(mv_store, &pager, None, &mut state.auto_txn_cleanup); return Ok(StepResult::Interrupt); } if let Some(io) = &state.io_completions { if !io.finished() { + io.set_waker(waker); return Ok(StepResult::IO); } if let Some(err) = io.get_error() { let err = err.into(); - handle_program_error(&pager, &self.connection, &err, mv_store)?; + self.abort(mv_store, &pager, Some(&err), &mut state.auto_txn_cleanup); return Err(err); } state.io_completions = None; @@ -672,10 +768,12 @@ impl Program { Ok(InsnFunctionStepResult::Done) => { // Instruction completed execution state.metrics.insn_executed = state.metrics.insn_executed.saturating_add(1); + state.auto_txn_cleanup = TxnCleanup::None; return Ok(StepResult::Done); } Ok(InsnFunctionStepResult::IO(io)) => { // Instruction not complete - waiting for I/O, will resume at same PC + io.set_waker(waker); state.io_completions = Some(io); return Ok(StepResult::IO); } @@ -684,16 +782,12 @@ impl Program { state.metrics.insn_executed = state.metrics.insn_executed.saturating_add(1); return Ok(StepResult::Row); } - Ok(InsnFunctionStepResult::Interrupt) => { - // Instruction interrupted - may resume at same PC - return Ok(StepResult::Interrupt); - } Err(LimboError::Busy) => { // Instruction blocked - will retry at same PC return Ok(StepResult::Busy); } Err(err) => { - handle_program_error(&pager, &self.connection, &err, mv_store)?; + self.abort(mv_store, &pager, Some(&err), &mut state.auto_txn_cleanup); return Err(err); } } @@ -818,7 +912,6 @@ impl Program { // Reset state for next use program_state.view_delta_state = ViewDeltaCommitState::NotStarted; - if self.connection.get_tx_state() == TransactionState::None { // No need to do any work here if not in tx. Current MVCC logic doesn't work with this assumption, // hence the mv_store.is_none() check. @@ -888,7 +981,7 @@ impl Program { ), TransactionState::Read => { connection.set_tx_state(TransactionState::None); - pager.end_read_tx()?; + pager.end_read_tx(); Ok(IOResult::Done(())) } TransactionState::None => Ok(IOResult::Done(())), @@ -914,7 +1007,12 @@ impl Program { connection: &Connection, rollback: bool, ) -> Result> { - let cacheflush_status = pager.end_tx(rollback, connection)?; + let cacheflush_status = if !rollback { + pager.commit_tx(connection)? + } else { + pager.rollback_tx(connection); + IOResult::Done(PagerCommitResult::Rollback) + }; match cacheflush_status { IOResult::Done(_) => { if self.change_cnt_on { @@ -941,6 +1039,47 @@ impl Program { ) -> Result> { commit_state.step(mv_store) } + + /// Aborts the program due to various conditions (explicit error, interrupt or reset of unfinished statement) by rolling back the transaction + /// This method is no-op if program was already finished (either aborted or executed to completion) + pub fn abort( + &self, + mv_store: Option<&Arc>, + pager: &Arc, + err: Option<&LimboError>, + cleanup: &mut TxnCleanup, + ) { + // Errors from nested statements are handled by the parent statement. + if !self.connection.is_nested_stmt.load(Ordering::SeqCst) { + match err { + // Transaction errors, e.g. trying to start a nested transaction, do not cause a rollback. + Some(LimboError::TxError(_)) => {} + // Table locked errors, e.g. trying to checkpoint in an interactive transaction, do not cause a rollback. + Some(LimboError::TableLocked) => {} + // Busy errors do not cause a rollback. + Some(LimboError::Busy) => {} + // Constraint errors do not cause a rollback of the transaction by default; + // Instead individual statement subtransactions will roll back and these are handled in op_auto_commit + // and op_halt. + Some(LimboError::Constraint(_)) => {} + _ => { + if *cleanup != TxnCleanup::None || err.is_some() { + if let Some(mv_store) = mv_store { + if let Some(tx_id) = self.connection.get_mv_tx_id() { + self.connection.auto_commit.store(true, Ordering::SeqCst); + mv_store.rollback_tx(tx_id, pager.clone(), &self.connection); + } + } else { + pager.rollback_tx(&self.connection); + self.connection.auto_commit.store(true, Ordering::SeqCst); + } + self.connection.set_tx_state(TransactionState::None); + } + } + } + } + *cleanup = TxnCleanup::None; + } } fn make_record(registers: &[Register], start_reg: &usize, count: &usize) -> ImmutableRecord { @@ -948,22 +1087,10 @@ fn make_record(registers: &[Register], start_reg: &usize, count: &usize) -> Immu ImmutableRecord::from_registers(regs, regs.len()) } -pub fn registers_to_ref_values(registers: &[Register]) -> Vec { +pub fn registers_to_ref_values<'a>(registers: &'a [Register]) -> Vec> { registers .iter() - .map(|reg| { - let value = reg.get_value(); - match value { - Value::Null => RefValue::Null, - Value::Integer(i) => RefValue::Integer(*i), - Value::Float(f) => RefValue::Float(*f), - Value::Text(t) => RefValue::Text(TextRef { - value: RawSlice::new(t.value.as_ptr(), t.value.len()), - subtype: t.subtype, - }), - Value::Blob(b) => RefValue::Blob(RawSlice::new(b.as_ptr(), b.len())), - } - }) + .map(|reg| reg.get_value().as_ref()) .collect() } @@ -1063,42 +1190,3 @@ impl Row { self.count } } - -/// Handle a program error by rolling back the transaction -pub fn handle_program_error( - pager: &Arc, - connection: &Connection, - err: &LimboError, - mv_store: Option<&Arc>, -) -> Result<()> { - if connection.is_nested_stmt.load(Ordering::SeqCst) { - // Errors from nested statements are handled by the parent statement. - return Ok(()); - } - match err { - // Transaction errors, e.g. trying to start a nested transaction, do not cause a rollback. - LimboError::TxError(_) => {} - // Table locked errors, e.g. trying to checkpoint in an interactive transaction, do not cause a rollback. - LimboError::TableLocked => {} - // Busy errors do not cause a rollback. - LimboError::Busy => {} - _ => { - if let Some(mv_store) = mv_store { - if let Some(tx_id) = connection.get_mv_tx_id() { - connection.set_tx_state(TransactionState::None); - connection.auto_commit.store(true, Ordering::SeqCst); - mv_store.rollback_tx(tx_id, pager.clone(), connection)?; - } - } else { - pager - .io - .block(|| pager.end_tx(true, connection)) - .inspect_err(|e| { - tracing::error!("end_tx failed: {e}"); - })?; - } - connection.set_tx_state(TransactionState::None); - } - } - Ok(()) -} diff --git a/core/vdbe/sorter.rs b/core/vdbe/sorter.rs index f2ef80c69..81abde222 100644 --- a/core/vdbe/sorter.rs +++ b/core/vdbe/sorter.rs @@ -1,6 +1,5 @@ use turso_parser::ast::SortOrder; -use std::cell::RefCell; use std::cmp::{Eq, Ord, Ordering, PartialEq, PartialOrd, Reverse}; use std::collections::BinaryHeap; use std::rc::Rc; @@ -8,17 +7,16 @@ use std::sync::{atomic, Arc, RwLock}; use tempfile; use crate::types::IOCompletions; -use crate::util::IOExt; use crate::{ error::LimboError, - io::{Buffer, Completion, File, OpenFlags, IO}, + io::{Buffer, Completion, CompletionGroup, File, OpenFlags, IO}, storage::sqlite3_ondisk::{read_varint, varint_len, write_varint}, translate::collate::CollationSeq, turso_assert, - types::{IOResult, ImmutableRecord, KeyInfo, RecordCursor, RefValue}, + types::{IOResult, ImmutableRecord, KeyInfo, RecordCursor, ValueRef}, Result, }; -use crate::{io_yield_many, io_yield_one, return_if_io, CompletionError}; +use crate::{io_yield_one, return_if_io, CompletionError}; #[derive(Debug, Clone, Copy)] enum SortState { @@ -88,7 +86,7 @@ pub struct Sorter { insert_state: InsertState, /// State machine for [Sorter::init_chunk_heap] init_chunk_heap_state: InitChunkHeapState, - seq_count: i64, + pending_completions: Vec, } impl Sorter { @@ -126,7 +124,7 @@ impl Sorter { sort_state: SortState::Start, insert_state: InsertState::Start, init_chunk_heap_state: InitChunkHeapState::Start, - seq_count: 0, + pending_completions: Vec::new(), } } @@ -138,21 +136,6 @@ impl Sorter { self.current.is_some() } - /// Get current sequence count and increment it - pub fn next_sequence(&mut self) -> i64 { - let current = self.seq_count; - self.seq_count += 1; - current - } - - /// Test if at beginning of sequence (count == 0) and increment - /// Returns true if this was the first call (seq_count was 0) - pub fn seq_beginning(&mut self) -> bool { - let was_zero = self.seq_count == 0; - self.seq_count += 1; - was_zero - } - // We do the sorting here since this is what is called by the SorterSort instruction pub fn sort(&mut self) -> Result> { loop { @@ -204,7 +187,7 @@ impl Sorter { }; match record { Some(record) => { - if let Some(error) = record.deserialization_error.replace(None) { + if let Some(error) = record.deserialization_error { // If there was a key deserialization error during the comparison, return the error. return Err(error); } @@ -227,20 +210,13 @@ impl Sorter { self.insert_state = InsertState::Insert; if self.current_buffer_size + payload_size > self.max_buffer_size { if let Some(c) = self.flush()? { - io_yield_one!(c); + if !c.succeeded() { + io_yield_one!(c); + } } } } InsertState::Insert => { - turso_assert!( - !self.chunks.iter().any(|chunk| { - matches!( - *chunk.io_state.read().unwrap(), - SortedChunkIOState::WaitingForWrite - ) - }), - "chunks should have written" - ); self.records.push(SortableImmutableRecord::new( record.clone(), self.key_len, @@ -259,20 +235,21 @@ impl Sorter { fn init_chunk_heap(&mut self) -> Result> { match self.init_chunk_heap_state { InitChunkHeapState::Start => { - let mut completions: Vec = Vec::with_capacity(self.chunks.len()); + let mut group = CompletionGroup::new(|_| {}); for chunk in self.chunks.iter_mut() { match chunk.read() { Err(e) => { tracing::error!("Failed to read chunk: {e}"); - self.io.cancel(&completions)?; + group.cancel(); self.io.drain()?; return Err(e); } - Ok(c) => completions.push(c), + Ok(c) => group.add(&c), }; } self.init_chunk_heap_state = InitChunkHeapState::PushChunk; - io_yield_many!(completions); + let completion = group.build(); + io_yield_one!(completion); } InitChunkHeapState::PushChunk => { // Make sure all chunks read at least one record into their buffer. @@ -285,51 +262,61 @@ impl Sorter { ); self.chunk_heap.reserve(self.chunks.len()); // TODO: blocking will be unnecessary here with IO completions - let io = self.io.clone(); + let mut group = CompletionGroup::new(|_| {}); for chunk_idx in 0..self.chunks.len() { - io.block(|| self.push_to_chunk_heap(chunk_idx))?; + if let Some(c) = self.push_to_chunk_heap(chunk_idx)? { + group.add(&c); + }; } self.init_chunk_heap_state = InitChunkHeapState::Start; - Ok(IOResult::Done(())) + let completion = group.build(); + if completion.finished() { + Ok(IOResult::Done(())) + } else { + io_yield_one!(completion); + } } } } fn next_from_chunk_heap(&mut self) -> Result>> { + if !self.pending_completions.is_empty() { + let mut group = CompletionGroup::new(|_| {}); + for c in self.pending_completions.drain(..) { + group.add(&c); + } + return Ok(IOResult::IO(IOCompletions::Single(group.build()))); + } // Make sure all chunks read at least one record into their buffer. - turso_assert!( - !self.chunks.iter().any(|chunk| matches!( - *chunk.io_state.read().unwrap(), - SortedChunkIOState::WaitingForRead - )), - "chunks should have been read" - ); - if let Some((next_record, next_chunk_idx)) = self.chunk_heap.pop() { // TODO: blocking will be unnecessary here with IO completions - let io = self.io.clone(); - io.block(|| self.push_to_chunk_heap(next_chunk_idx))?; + if let Some(c) = self.push_to_chunk_heap(next_chunk_idx)? { + self.pending_completions.push(c); + } Ok(IOResult::Done(Some(next_record.0))) } else { Ok(IOResult::Done(None)) } } - fn push_to_chunk_heap(&mut self, chunk_idx: usize) -> Result> { + fn push_to_chunk_heap(&mut self, chunk_idx: usize) -> Result> { let chunk = &mut self.chunks[chunk_idx]; - if let Some(record) = return_if_io!(chunk.next()) { - self.chunk_heap.push(( - Reverse(SortableImmutableRecord::new( - record, - self.key_len, - self.index_key_info.clone(), - )?), - chunk_idx, - )); + match chunk.next()? { + ChunkNextResult::Done(Some(record)) => { + self.chunk_heap.push(( + Reverse(SortableImmutableRecord::new( + record, + self.key_len, + self.index_key_info.clone(), + )?), + chunk_idx, + )); + Ok(None) + } + ChunkNextResult::Done(None) => Ok(None), + ChunkNextResult::IO(io) => Ok(Some(io)), } - - Ok(IOResult::Done(())) } fn flush(&mut self) -> Result> { @@ -413,6 +400,11 @@ struct SortedChunk { next_state: NextState, } +enum ChunkNextResult { + Done(Option), + IO(Completion), +} + impl SortedChunk { fn new(file: Arc, start_offset: usize, buffer_size: usize) -> Self { Self { @@ -436,13 +428,13 @@ impl SortedChunk { self.buffer_len.store(len, atomic::Ordering::SeqCst); } - fn next(&mut self) -> Result>> { + fn next(&mut self) -> Result { loop { match self.next_state { NextState::Start => { let mut buffer_len = self.buffer_len(); if self.records.is_empty() && buffer_len == 0 { - return Ok(IOResult::Done(None)); + return Ok(ChunkNextResult::Done(None)); } if self.records.is_empty() { @@ -506,13 +498,15 @@ impl SortedChunk { *self.io_state.write().unwrap() = SortedChunkIOState::ReadEOF; } else { let c = self.read()?; - io_yield_one!(c); + if !c.succeeded() { + return Ok(ChunkNextResult::IO(c)); + } } } } NextState::Finish => { self.next_state = NextState::Start; - return Ok(IOResult::Done(self.records.pop())); + return Ok(ChunkNextResult::Done(self.records.pop())); } } } @@ -614,11 +608,14 @@ impl SortedChunk { struct SortableImmutableRecord { record: ImmutableRecord, cursor: RecordCursor, - key_values: RefCell>, + /// SAFETY: borrows from self + /// These are precomputed on record construction so that they can be reused during + /// sorting comparisons. + key_values: Vec>, key_len: usize, index_key_info: Rc>, /// The key deserialization error, if any. - deserialization_error: RefCell>, + deserialization_error: Option, } impl SortableImmutableRecord { @@ -633,40 +630,40 @@ impl SortableImmutableRecord { index_key_info.len() >= cursor.serial_types.len(), "index_key_info.len() < cursor.serial_types.len()" ); + + // Pre-compute all key values upfront + let mut key_values = Vec::with_capacity(key_len); + let mut deserialization_error = None; + + for i in 0..key_len { + match cursor.deserialize_column(&record, i) { + Ok(value) => { + // SAFETY: We're storing the value with 'static lifetime but it's actually bound to the record + // This is safe because the record lives as long as this struct + let value: ValueRef<'static> = unsafe { std::mem::transmute(value) }; + key_values.push(value); + } + Err(err) => { + deserialization_error = Some(err); + break; + } + } + } + Ok(Self { record, cursor, - key_values: RefCell::new(Vec::with_capacity(key_len)), + key_values, index_key_info, - deserialization_error: RefCell::new(None), + deserialization_error, key_len, }) } - - /// Attempts to deserialize the key value at the given index. - /// If the key value has already been deserialized, this does nothing. - /// The deserialized key value is stored in the `key_values` field. - /// In case of an error, the error is stored in the `deserialization_error` field. - fn try_deserialize_key(&self, idx: usize) { - let mut key_values = self.key_values.borrow_mut(); - if idx < key_values.len() { - // The key value with this index has already been deserialized. - return; - } - match self.cursor.deserialize_column(&self.record, idx) { - Ok(value) => key_values.push(value), - Err(error) => { - self.deserialization_error.replace(Some(error)); - } - } - } } impl Ord for SortableImmutableRecord { fn cmp(&self, other: &Self) -> Ordering { - if self.deserialization_error.borrow().is_some() - || other.deserialization_error.borrow().is_some() - { + if self.deserialization_error.is_some() || other.deserialization_error.is_some() { // If one of the records has a deserialization error, circumvent the comparison and return early. return Ordering::Equal; } @@ -674,34 +671,21 @@ impl Ord for SortableImmutableRecord { self.cursor.serial_types.len(), other.cursor.serial_types.len() ); - let this_key_values_len = self.key_values.borrow().len(); - let other_key_values_len = other.key_values.borrow().len(); for i in 0..self.key_len { - // Lazily deserialize the key values if they haven't been deserialized already. - if i >= this_key_values_len { - self.try_deserialize_key(i); - if self.deserialization_error.borrow().is_some() { - return Ordering::Equal; - } - } - if i >= other_key_values_len { - other.try_deserialize_key(i); - if other.deserialization_error.borrow().is_some() { - return Ordering::Equal; - } - } + let this_key_value = self.key_values[i]; + let other_key_value = other.key_values[i]; - let this_key_value = &self.key_values.borrow()[i]; - let other_key_value = &other.key_values.borrow()[i]; let column_order = self.index_key_info[i].sort_order; let collation = self.index_key_info[i].collation; let cmp = match (this_key_value, other_key_value) { - (RefValue::Text(left), RefValue::Text(right)) => { - collation.compare_strings(left.as_str(), right.as_str()) - } - _ => this_key_value.partial_cmp(other_key_value).unwrap(), + (ValueRef::Text(left, _), ValueRef::Text(right, _)) => collation.compare_strings( + // SAFETY: these were checked to be valid UTF-8 on construction. + unsafe { std::str::from_utf8_unchecked(left) }, + unsafe { std::str::from_utf8_unchecked(right) }, + ), + _ => this_key_value.partial_cmp(&other_key_value).unwrap(), }; if !cmp.is_eq() { return match column_order { @@ -742,7 +726,7 @@ enum SortedChunkIOState { mod tests { use super::*; use crate::translate::collate::CollationSeq; - use crate::types::{ImmutableRecord, RefValue, Value, ValueType}; + use crate::types::{ImmutableRecord, Value, ValueRef, ValueType}; use crate::util::IOExt; use crate::PlatformIO; use rand_chacha::{ @@ -806,7 +790,7 @@ mod tests { for i in 0..num_records { assert!(sorter.has_more()); let record = sorter.record().unwrap(); - assert_eq!(record.get_values()[0], RefValue::Integer(i)); + assert_eq!(record.get_values()[0], ValueRef::Integer(i)); // Check that the record remained unchanged after sorting. assert_eq!(record, &initial_records[(num_records - i - 1) as usize]); diff --git a/core/vector/distance.rs b/core/vector/distance.rs deleted file mode 100644 index e61c24c70..000000000 --- a/core/vector/distance.rs +++ /dev/null @@ -1,25 +0,0 @@ -use super::vector_types::Vector; -use crate::Result; - -pub(crate) mod euclidean; - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[non_exhaustive] -pub enum DistanceType { - /// Euclidean distance. This is a very common distance metric that - /// accounts for both magnitude and direction when determining the distance - /// between vectors. Euclidean distance has a range of [0, ∞). - Euclidean, - - // TODO(asukamilet): Refactor the current `vector_types.rs` to integrate - #[allow(dead_code)] - /// Cosine distance. This is a measure of similarity between two vectors - Cosine, -} - -pub trait DistanceCalculator { - #[allow(unused)] - fn distance_type() -> DistanceType; - - fn calculate(v1: &Vector, v2: &Vector) -> Result; -} diff --git a/core/vector/distance/euclidean.rs b/core/vector/distance/euclidean.rs deleted file mode 100644 index f8e4f048d..000000000 --- a/core/vector/distance/euclidean.rs +++ /dev/null @@ -1,72 +0,0 @@ -use super::{DistanceCalculator, DistanceType}; -use crate::vector::vector_types::{Vector, VectorType}; -use crate::Result; - -#[derive(Debug, Clone)] -pub struct Euclidean; - -impl DistanceCalculator for Euclidean { - fn distance_type() -> DistanceType { - DistanceType::Euclidean - } - - fn calculate(v1: &Vector, v2: &Vector) -> Result { - match v1.vector_type { - VectorType::Float32 => Ok(euclidean_distance_f32(v1.as_f32_slice(), v2.as_f32_slice())), - VectorType::Float64 => Ok(euclidean_distance_f64(v1.as_f64_slice(), v2.as_f64_slice())), - } - } -} - -fn euclidean_distance_f32(v1: &[f32], v2: &[f32]) -> f64 { - let sum = v1 - .iter() - .zip(v2.iter()) - .map(|(a, b)| (a - b).powi(2)) - .sum::() as f64; - sum.sqrt() -} - -fn euclidean_distance_f64(v1: &[f64], v2: &[f64]) -> f64 { - let sum = v1 - .iter() - .zip(v2.iter()) - .map(|(a, b)| (a - b).powi(2)) - .sum::(); - sum.sqrt() -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_euclidean_distance_f32() { - let vectors = [ - (0..8).map(|x| x as f32).collect::>(), - (1..9).map(|x| x as f32).collect::>(), - (2..10).map(|x| x as f32).collect::>(), - (3..11).map(|x| x as f32).collect::>(), - ]; - let query = (2..10).map(|x| x as f32).collect::>(); - - let expected: Vec = vec![ - 32.0_f64.sqrt(), - 8.0_f64.sqrt(), - 0.0_f64.sqrt(), - 8.0_f64.sqrt(), - ]; - let results = vectors - .iter() - .map(|v| euclidean_distance_f32(&query, v)) - .collect::>(); - assert_eq!(results, expected); - } - - #[test] - fn test_odd_len() { - let v = (0..5).map(|x| x as f32).collect::>(); - let query = (2..7).map(|x| x as f32).collect::>(); - assert_eq!(euclidean_distance_f32(&v, &query), 20.0_f64.sqrt()); - } -} diff --git a/core/vector/mod.rs b/core/vector/mod.rs index 2fc960849..14cb2a462 100644 --- a/core/vector/mod.rs +++ b/core/vector/mod.rs @@ -1,28 +1,53 @@ use crate::types::Value; +use crate::types::ValueType; use crate::vdbe::Register; -use crate::vector::distance::{euclidean::Euclidean, DistanceCalculator}; use crate::LimboError; use crate::Result; -pub mod distance; +pub mod operations; pub mod vector_types; use vector_types::*; +pub fn parse_vector(value: &Register, type_hint: Option) -> Result { + match value.get_value().value_type() { + ValueType::Text => operations::text::vector_from_text( + type_hint.unwrap_or(VectorType::Float32Dense), + value.get_value().to_text().expect("value must be text"), + ), + ValueType::Blob => { + let Some(blob) = value.get_value().to_blob() else { + return Err(LimboError::ConversionError( + "Invalid vector value".to_string(), + )); + }; + Vector::from_slice(blob) + } + _ => Err(LimboError::ConversionError( + "Invalid vector type".to_string(), + )), + } +} + pub fn vector32(args: &[Register]) -> Result { if args.len() != 1 { return Err(LimboError::ConversionError( "vector32 requires exactly one argument".to_string(), )); } - let x = parse_vector(&args[0], Some(VectorType::Float32))?; - // Extract the Vec from Value - if let Value::Blob(data) = vector_serialize_f32(x) { - Ok(Value::Blob(data)) - } else { - Err(LimboError::ConversionError( - "Failed to serialize vector".to_string(), - )) + let vector = parse_vector(&args[0], Some(VectorType::Float32Dense))?; + let vector = operations::convert::vector_convert(vector, VectorType::Float32Dense)?; + Ok(operations::serialize::vector_serialize(vector)) +} + +pub fn vector32_sparse(args: &[Register]) -> Result { + if args.len() != 1 { + return Err(LimboError::ConversionError( + "vector32_sparse requires exactly one argument".to_string(), + )); } + let vector = parse_vector(&args[0], Some(VectorType::Float32Sparse))?; + let vector = operations::convert::vector_convert(vector, VectorType::Float32Sparse)?; + Ok(operations::serialize::vector_serialize(vector)) } pub fn vector64(args: &[Register]) -> Result { @@ -31,15 +56,9 @@ pub fn vector64(args: &[Register]) -> Result { "vector64 requires exactly one argument".to_string(), )); } - let x = parse_vector(&args[0], Some(VectorType::Float64))?; - // Extract the Vec from Value - if let Value::Blob(data) = vector_serialize_f64(x) { - Ok(Value::Blob(data)) - } else { - Err(LimboError::ConversionError( - "Failed to serialize vector".to_string(), - )) - } + let vector = parse_vector(&args[0], Some(VectorType::Float64Dense))?; + let vector = operations::convert::vector_convert(vector, VectorType::Float64Dense)?; + Ok(operations::serialize::vector_serialize(vector)) } pub fn vector_extract(args: &[Register]) -> Result { @@ -62,9 +81,8 @@ pub fn vector_extract(args: &[Register]) -> Result { return Ok(Value::build_text("[]")); } - let vector_type = vector_type(blob)?; - let vector = vector_deserialize(vector_type, blob)?; - Ok(Value::build_text(vector_to_text(&vector))) + let vector = Vector::from_vec(blob.to_vec())?; + Ok(Value::build_text(operations::text::vector_to_text(&vector))) } pub fn vector_distance_cos(args: &[Register]) -> Result { @@ -76,7 +94,7 @@ pub fn vector_distance_cos(args: &[Register]) -> Result { let x = parse_vector(&args[0], None)?; let y = parse_vector(&args[1], None)?; - let dist = do_vector_distance_cos(&x, &y)?; + let dist = operations::distance_cos::vector_distance_cos(&x, &y)?; Ok(Value::Float(dist)) } @@ -89,19 +107,20 @@ pub fn vector_distance_l2(args: &[Register]) -> Result { let x = parse_vector(&args[0], None)?; let y = parse_vector(&args[1], None)?; - // Validate that both vectors have the same dimensions and type - if x.dims != y.dims { + let dist = operations::distance_l2::vector_distance_l2(&x, &y)?; + Ok(Value::Float(dist)) +} + +pub fn vector_distance_jaccard(args: &[Register]) -> Result { + if args.len() != 2 { return Err(LimboError::ConversionError( - "Vectors must have the same dimensions".to_string(), - )); - } - if x.vector_type != y.vector_type { - return Err(LimboError::ConversionError( - "Vectors must be of the same type".to_string(), + "distance_jaccard requires exactly two arguments".to_string(), )); } - let dist = Euclidean::calculate(&x, &y)?; + let x = parse_vector(&args[0], None)?; + let y = parse_vector(&args[1], None)?; + let dist = operations::jaccard::vector_distance_jaccard(&x, &y)?; Ok(Value::Float(dist)) } @@ -114,18 +133,8 @@ pub fn vector_concat(args: &[Register]) -> Result { let x = parse_vector(&args[0], None)?; let y = parse_vector(&args[1], None)?; - - if x.vector_type != y.vector_type { - return Err(LimboError::InvalidArgument( - "Vectors must be of the same type".into(), - )); - } - - let vector = vector_types::vector_concat(&x, &y)?; - match vector.vector_type { - VectorType::Float32 => Ok(vector_serialize_f32(vector)), - VectorType::Float64 => Ok(vector_serialize_f64(vector)), - } + let vector = operations::concat::vector_concat(&x, &y)?; + Ok(operations::serialize::vector_serialize(vector)) } pub fn vector_slice(args: &[Register]) -> Result { @@ -153,10 +162,8 @@ pub fn vector_slice(args: &[Register]) -> Result { )); } - let result = vector_types::vector_slice(&vector, start_index as usize, end_index as usize)?; + let result = + operations::slice::vector_slice(&vector, start_index as usize, end_index as usize)?; - Ok(match result.vector_type { - VectorType::Float32 => vector_serialize_f32(result), - VectorType::Float64 => vector_serialize_f64(result), - }) + Ok(operations::serialize::vector_serialize(result)) } diff --git a/core/vector/operations/concat.rs b/core/vector/operations/concat.rs new file mode 100644 index 000000000..8823568ff --- /dev/null +++ b/core/vector/operations/concat.rs @@ -0,0 +1,119 @@ +use crate::{ + vector::vector_types::{Vector, VectorType}, + LimboError, Result, +}; + +pub fn vector_concat(v1: &Vector, v2: &Vector) -> Result> { + if v1.vector_type != v2.vector_type { + return Err(LimboError::ConversionError( + "Mismatched vector types".into(), + )); + } + + let data = match v1.vector_type { + VectorType::Float32Dense | VectorType::Float64Dense => { + let mut data = Vec::with_capacity(v1.bin_len() + v2.bin_len()); + data.extend_from_slice(v1.bin_data()); + data.extend_from_slice(v2.bin_data()); + data + } + VectorType::Float32Sparse => { + let mut data = Vec::with_capacity(v1.bin_len() + v2.bin_len()); + data.extend_from_slice(&v1.bin_data()[..v1.bin_len() / 2]); + data.extend_from_slice(&v2.bin_data()[..v2.bin_len() / 2]); + data.extend_from_slice(&v1.bin_data()[v1.bin_len() / 2..]); + data.extend_from_slice(&v2.bin_data()[v2.bin_len() / 2..]); + data + } + }; + + Ok(Vector { + vector_type: v1.vector_type, + dims: v1.dims + v2.dims, + owned: Some(data), + refer: None, + }) +} + +#[cfg(test)] +mod tests { + use crate::vector::{ + operations::concat::vector_concat, + vector_types::{Vector, VectorType}, + }; + + fn float32_vec_from(slice: &[f32]) -> Vector<'static> { + let mut data = Vec::new(); + for &v in slice { + data.extend_from_slice(&v.to_le_bytes()); + } + + Vector { + vector_type: VectorType::Float32Dense, + dims: slice.len(), + owned: Some(data), + refer: None, + } + } + + fn f32_slice_from_vector(vector: &Vector) -> Vec { + vector.as_f32_slice().to_vec() + } + + #[test] + fn test_vector_concat_normal_case() { + let v1 = float32_vec_from(&[1.0, 2.0, 3.0]); + let v2 = float32_vec_from(&[4.0, 5.0, 6.0]); + + let result = vector_concat(&v1, &v2).unwrap(); + + assert_eq!(result.dims, 6); + assert_eq!(result.vector_type, VectorType::Float32Dense); + assert_eq!( + f32_slice_from_vector(&result), + vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0] + ); + } + + #[test] + fn test_vector_concat_empty_left() { + let v1 = float32_vec_from(&[]); + let v2 = float32_vec_from(&[4.0, 5.0]); + + let result = vector_concat(&v1, &v2).unwrap(); + + assert_eq!(result.dims, 2); + assert_eq!(f32_slice_from_vector(&result), vec![4.0, 5.0]); + } + + #[test] + fn test_vector_concat_empty_right() { + let v1 = float32_vec_from(&[1.0, 2.0]); + let v2 = float32_vec_from(&[]); + + let result = vector_concat(&v1, &v2).unwrap(); + + assert_eq!(result.dims, 2); + assert_eq!(f32_slice_from_vector(&result), vec![1.0, 2.0]); + } + + #[test] + fn test_vector_concat_both_empty() { + let v1 = float32_vec_from(&[]); + let v2 = float32_vec_from(&[]); + let result = vector_concat(&v1, &v2).unwrap(); + assert_eq!(result.dims, 0); + assert_eq!(f32_slice_from_vector(&result), Vec::::new()); + } + + #[test] + fn test_vector_concat_different_lengths() { + let v1 = float32_vec_from(&[1.0]); + let v2 = float32_vec_from(&[2.0, 3.0, 4.0]); + + let result = vector_concat(&v1, &v2).unwrap(); + + assert_eq!(result.dims, 4); + assert_eq!(f32_slice_from_vector(&result), vec![1.0, 2.0, 3.0, 4.0]); + } +} diff --git a/core/vector/operations/convert.rs b/core/vector/operations/convert.rs new file mode 100644 index 000000000..1db0ab99e --- /dev/null +++ b/core/vector/operations/convert.rs @@ -0,0 +1,123 @@ +use crate::{ + vector::vector_types::{Vector, VectorType}, + Result, +}; + +pub fn vector_convert(v: Vector, target_type: VectorType) -> Result { + match (v.vector_type, target_type) { + (VectorType::Float32Dense, VectorType::Float32Dense) + | (VectorType::Float64Dense, VectorType::Float64Dense) + | (VectorType::Float32Sparse, VectorType::Float32Sparse) => Ok(v), + (VectorType::Float32Dense, VectorType::Float64Dense) => Ok(Vector::from_f64( + v.as_f32_slice().iter().map(|&x| x as f64).collect(), + )), + (VectorType::Float64Dense, VectorType::Float32Dense) => Ok(Vector::from_f32( + v.as_f64_slice().iter().map(|&x| x as f32).collect(), + )), + (VectorType::Float32Dense, VectorType::Float32Sparse) => { + let (mut idx, mut values) = (Vec::new(), Vec::new()); + for (i, &value) in v.as_f32_slice().iter().enumerate() { + if value == 0.0 { + continue; + } + idx.push(i as u32); + values.push(value); + } + Ok(Vector::from_f32_sparse(v.dims, values, idx)) + } + (VectorType::Float64Dense, VectorType::Float32Sparse) => { + let (mut idx, mut values) = (Vec::new(), Vec::new()); + for (i, &value) in v.as_f64_slice().iter().enumerate() { + if value == 0.0 { + continue; + } + idx.push(i as u32); + values.push(value as f32); + } + Ok(Vector::from_f32_sparse(v.dims, values, idx)) + } + (VectorType::Float32Sparse, VectorType::Float32Dense) => { + let sparse = v.as_f32_sparse(); + let mut data = vec![0f32; v.dims]; + for (&i, &value) in sparse.idx.iter().zip(sparse.values.iter()) { + data[i as usize] = value; + } + Ok(Vector::from_f32(data)) + } + (VectorType::Float32Sparse, VectorType::Float64Dense) => { + let sparse = v.as_f32_sparse(); + let mut data = vec![0f64; v.dims]; + for (&i, &value) in sparse.idx.iter().zip(sparse.values.iter()) { + data[i as usize] = value as f64; + } + Ok(Vector::from_f64(data)) + } + } +} + +#[cfg(test)] +mod tests { + use crate::vector::{ + operations::convert::vector_convert, + vector_types::{Vector, VectorType}, + }; + + fn concat(data: &[[u8; N]]) -> Vec { + data.iter().flatten().cloned().collect::>() + } + + fn assert_vectors(v1: &Vector, v2: &Vector) { + assert_eq!(v1.vector_type, v2.vector_type); + assert_eq!(v1.dims, v2.dims); + assert_eq!(v1.bin_data(), v2.bin_data()); + } + + #[test] + pub fn test_vector_convert() { + let vf32 = Vector { + vector_type: VectorType::Float32Dense, + dims: 3, + owned: Some(concat(&[ + 1.0f32.to_le_bytes(), + 0.0f32.to_le_bytes(), + 2.0f32.to_le_bytes(), + ])), + refer: None, + }; + let vf64 = Vector { + vector_type: VectorType::Float64Dense, + dims: 3, + owned: Some(concat(&[ + 1.0f64.to_le_bytes(), + 0.0f64.to_le_bytes(), + 2.0f64.to_le_bytes(), + ])), + refer: None, + }; + let vf32_sparse = Vector { + vector_type: VectorType::Float32Sparse, + dims: 3, + owned: Some(concat(&[ + 1.0f32.to_le_bytes(), + 2.0f32.to_le_bytes(), + 0u32.to_le_bytes(), + 2u32.to_le_bytes(), + ])), + refer: None, + }; + + let vectors = [vf32, vf64, vf32_sparse]; + for v1 in &vectors { + for v2 in &vectors { + println!("{:?} -> {:?}", v1.vector_type, v2.vector_type); + let v_copy = Vector { + vector_type: v1.vector_type, + dims: v1.dims, + owned: v1.owned.clone(), + refer: None, + }; + assert_vectors(&vector_convert(v_copy, v2.vector_type).unwrap(), v2); + } + } + } +} diff --git a/core/vector/operations/distance_cos.rs b/core/vector/operations/distance_cos.rs new file mode 100644 index 000000000..e7a56a923 --- /dev/null +++ b/core/vector/operations/distance_cos.rs @@ -0,0 +1,217 @@ +use crate::{ + vector::vector_types::{Vector, VectorSparse, VectorType}, + LimboError, Result, +}; +use simsimd::SpatialSimilarity; + +pub fn vector_distance_cos(v1: &Vector, v2: &Vector) -> Result { + if v1.dims != v2.dims { + return Err(LimboError::ConversionError( + "Vectors must have the same dimensions".to_string(), + )); + } + if v1.vector_type != v2.vector_type { + return Err(LimboError::ConversionError( + "Vectors must be of the same type".to_string(), + )); + } + match v1.vector_type { + #[cfg(not(target_family = "wasm"))] + VectorType::Float32Dense => Ok(vector_f32_distance_cos_simsimd( + v1.as_f32_slice(), + v2.as_f32_slice(), + )), + #[cfg(target_family = "wasm")] + VectorType::Float32Dense => Ok(vector_f32_distance_cos_rust( + v1.as_f32_slice(), + v2.as_f32_slice(), + )), + #[cfg(not(target_family = "wasm"))] + VectorType::Float64Dense => Ok(vector_f64_distance_cos_simsimd( + v1.as_f64_slice(), + v2.as_f64_slice(), + )), + #[cfg(target_family = "wasm")] + VectorType::Float64Dense => Ok(vector_f64_distance_cos_rust( + v1.as_f64_slice(), + v2.as_f64_slice(), + )), + VectorType::Float32Sparse => Ok(vector_f32_sparse_distance_cos( + v1.as_f32_sparse(), + v2.as_f32_sparse(), + )), + } +} + +#[allow(dead_code)] +fn vector_f32_distance_cos_simsimd(v1: &[f32], v2: &[f32]) -> f64 { + f32::cosine(v1, v2).unwrap_or(f64::NAN) +} + +// SimSIMD do not support WASM for now, so we have alternative implementation: https://github.com/ashvardanian/SimSIMD/issues/189 +#[allow(dead_code)] +fn vector_f32_distance_cos_rust(v1: &[f32], v2: &[f32]) -> f64 { + let (mut dot, mut norm1, mut norm2) = (0.0, 0.0, 0.0); + for (a, b) in v1.iter().zip(v2.iter()) { + dot += a * b; + norm1 += a * a; + norm2 += b * b; + } + if norm1 == 0.0 || norm2 == 0.0 { + return 0.0; + } + (1.0 - dot / (norm1 * norm2).sqrt()) as f64 +} + +#[allow(dead_code)] +fn vector_f64_distance_cos_simsimd(v1: &[f64], v2: &[f64]) -> f64 { + f64::cosine(v1, v2).unwrap_or(f64::NAN) +} + +// SimSIMD do not support WASM for now, so we have alternative implementation: https://github.com/ashvardanian/SimSIMD/issues/189 +#[allow(dead_code)] +fn vector_f64_distance_cos_rust(v1: &[f64], v2: &[f64]) -> f64 { + let (mut dot, mut norm1, mut norm2) = (0.0, 0.0, 0.0); + for (a, b) in v1.iter().zip(v2.iter()) { + dot += a * b; + norm1 += a * a; + norm2 += b * b; + } + if norm1 == 0.0 || norm2 == 0.0 { + return 0.0; + } + 1.0 - dot / (norm1 * norm2).sqrt() +} + +fn vector_f32_sparse_distance_cos(v1: VectorSparse, v2: VectorSparse) -> f64 { + let mut v1_pos = 0; + let mut v2_pos = 0; + let (mut dot, mut norm1, mut norm2) = (0.0, 0.0, 0.0); + while v1_pos < v1.idx.len() && v2_pos < v2.idx.len() { + let e1 = v1.values[v1_pos]; + let e2 = v2.values[v2_pos]; + if v1.idx[v1_pos] == v2.idx[v2_pos] { + dot += e1 * e2; + norm1 += e1 * e1; + norm2 += e2 * e2; + v1_pos += 1; + v2_pos += 1; + } else if v1.idx[v1_pos] < v2.idx[v2_pos] { + norm1 += e1 * e1; + v1_pos += 1; + } else { + norm2 += e2 * e2; + v2_pos += 1; + } + } + + while v1_pos < v1.idx.len() { + norm1 += v1.values[v1_pos] * v1.values[v1_pos]; + v1_pos += 1; + } + while v2_pos < v2.idx.len() { + norm2 += v2.values[v2_pos] * v2.values[v2_pos]; + v2_pos += 1; + } + + // Check for zero norms + if norm1 == 0.0f32 || norm2 == 0.0f32 { + return f64::NAN; + } + + (1.0f32 - (dot / (norm1 * norm2).sqrt())) as f64 +} + +#[cfg(test)] +mod tests { + use crate::vector::{ + operations::convert::vector_convert, vector_types::tests::ArbitraryVector, + }; + + use super::*; + use quickcheck_macros::quickcheck; + + #[test] + fn test_vector_distance_cos_f32() { + assert_eq!(vector_f32_distance_cos_simsimd(&[], &[]), 0.0); + assert_eq!( + vector_f32_distance_cos_simsimd(&[1.0, 2.0], &[0.0, 0.0]), + 1.0 + ); + assert!(vector_f32_distance_cos_simsimd(&[1.0, 2.0], &[1.0, 2.0]).abs() < 1e-6); + assert!((vector_f32_distance_cos_simsimd(&[1.0, 2.0], &[-1.0, -2.0]) - 2.0).abs() < 1e-6); + assert!((vector_f32_distance_cos_simsimd(&[1.0, 2.0], &[-2.0, 1.0]) - 1.0).abs() < 1e-6); + } + + #[test] + fn test_vector_distance_cos_f64() { + assert_eq!(vector_f64_distance_cos_simsimd(&[], &[]), 0.0); + assert_eq!( + vector_f64_distance_cos_simsimd(&[1.0, 2.0], &[0.0, 0.0]), + 1.0 + ); + assert!(vector_f64_distance_cos_simsimd(&[1.0, 2.0], &[1.0, 2.0]).abs() < 1e-6); + assert!((vector_f64_distance_cos_simsimd(&[1.0, 2.0], &[-1.0, -2.0]) - 2.0).abs() < 1e-6); + assert!((vector_f64_distance_cos_simsimd(&[1.0, 2.0], &[-2.0, 1.0]) - 1.0).abs() < 1e-6); + } + + #[test] + fn test_vector_distance_cos_f32_sparse() { + assert!( + (vector_f32_sparse_distance_cos( + VectorSparse { + idx: &[0, 1], + values: &[1.0, 2.0] + }, + VectorSparse { + idx: &[1, 2], + values: &[1.0, 3.0] + }, + ) - vector_f32_distance_cos_simsimd(&[1.0, 2.0, 0.0], &[0.0, 1.0, 3.0])) + .abs() + < 1e-7 + ); + } + + #[quickcheck] + fn prop_vector_distance_cos_dense_vs_sparse( + v1: ArbitraryVector<100>, + v2: ArbitraryVector<100>, + ) -> bool { + let v1 = vector_convert(v1.into(), VectorType::Float32Dense).unwrap(); + let v2 = vector_convert(v2.into(), VectorType::Float32Dense).unwrap(); + let d1 = vector_distance_cos(&v1, &v2).unwrap(); + + let sparse1 = vector_convert(v1, VectorType::Float32Sparse).unwrap(); + let sparse2 = vector_convert(v2, VectorType::Float32Sparse).unwrap(); + let d2 = vector_f32_sparse_distance_cos(sparse1.as_f32_sparse(), sparse2.as_f32_sparse()); + + (d1.is_nan() && d2.is_nan()) || (d1 - d2).abs() < 1e-6 + } + + #[quickcheck] + fn prop_vector_distance_cos_rust_vs_simsimd_f32( + v1: ArbitraryVector<100>, + v2: ArbitraryVector<100>, + ) -> bool { + let v1 = vector_convert(v1.into(), VectorType::Float32Dense).unwrap(); + let v2 = vector_convert(v2.into(), VectorType::Float32Dense).unwrap(); + let d1 = vector_f32_distance_cos_rust(v1.as_f32_slice(), v2.as_f32_slice()); + let d2 = vector_f32_distance_cos_simsimd(v1.as_f32_slice(), v2.as_f32_slice()); + println!("d1 vs d2: {d1} vs {d2}"); + (d1.is_nan() && d2.is_nan()) || (d1 - d2).abs() < 1e-4 + } + + #[quickcheck] + fn prop_vector_distance_cos_rust_vs_simsimd_f64( + v1: ArbitraryVector<100>, + v2: ArbitraryVector<100>, + ) -> bool { + let v1 = vector_convert(v1.into(), VectorType::Float64Dense).unwrap(); + let v2 = vector_convert(v2.into(), VectorType::Float64Dense).unwrap(); + let d1 = vector_f64_distance_cos_rust(v1.as_f64_slice(), v2.as_f64_slice()); + let d2 = vector_f64_distance_cos_simsimd(v1.as_f64_slice(), v2.as_f64_slice()); + println!("d1 vs d2: {d1} vs {d2}"); + (d1.is_nan() && d2.is_nan()) || (d1 - d2).abs() < 1e-6 + } +} diff --git a/core/vector/operations/distance_l2.rs b/core/vector/operations/distance_l2.rs new file mode 100644 index 000000000..84f21db14 --- /dev/null +++ b/core/vector/operations/distance_l2.rs @@ -0,0 +1,239 @@ +use crate::{ + vector::vector_types::{Vector, VectorSparse, VectorType}, + LimboError, Result, +}; +use simsimd::SpatialSimilarity; + +pub fn vector_distance_l2(v1: &Vector, v2: &Vector) -> Result { + if v1.dims != v2.dims { + return Err(LimboError::ConversionError( + "Vectors must have the same dimensions".to_string(), + )); + } + if v1.vector_type != v2.vector_type { + return Err(LimboError::ConversionError( + "Vectors must be of the same type".to_string(), + )); + } + match v1.vector_type { + #[cfg(not(target_family = "wasm"))] + VectorType::Float32Dense => Ok(vector_f32_distance_l2_simsimd( + v1.as_f32_slice(), + v2.as_f32_slice(), + )), + #[cfg(target_family = "wasm")] + VectorType::Float32Dense => Ok(vector_f32_distance_l2_rust( + v1.as_f32_slice(), + v2.as_f32_slice(), + )), + #[cfg(not(target_family = "wasm"))] + VectorType::Float64Dense => Ok(vector_f64_distance_l2_simsimd( + v1.as_f64_slice(), + v2.as_f64_slice(), + )), + #[cfg(target_family = "wasm")] + VectorType::Float64Dense => Ok(vector_f64_distance_l2_rust( + v1.as_f64_slice(), + v2.as_f64_slice(), + )), + VectorType::Float32Sparse => Ok(vector_f32_sparse_distance_l2( + v1.as_f32_sparse(), + v2.as_f32_sparse(), + )), + } +} + +#[allow(dead_code)] +fn vector_f32_distance_l2_simsimd(v1: &[f32], v2: &[f32]) -> f64 { + f32::euclidean(v1, v2).unwrap_or(f64::NAN) +} + +// SimSIMD do not support WASM for now, so we have alternative implementation: https://github.com/ashvardanian/SimSIMD/issues/189 +#[allow(dead_code)] +fn vector_f32_distance_l2_rust(v1: &[f32], v2: &[f32]) -> f64 { + let sum = v1 + .iter() + .zip(v2.iter()) + .map(|(a, b)| (a - b).powi(2)) + .sum::() as f64; + sum.sqrt() +} + +#[allow(dead_code)] +fn vector_f64_distance_l2_simsimd(v1: &[f64], v2: &[f64]) -> f64 { + f64::euclidean(v1, v2).unwrap_or(f64::NAN) +} + +// SimSIMD do not support WASM for now, so we have alternative implementation: https://github.com/ashvardanian/SimSIMD/issues/189 +#[allow(dead_code)] +fn vector_f64_distance_l2_rust(v1: &[f64], v2: &[f64]) -> f64 { + let sum = v1 + .iter() + .zip(v2.iter()) + .map(|(a, b)| (a - b).powi(2)) + .sum::(); + sum.sqrt() +} + +fn vector_f32_sparse_distance_l2(v1: VectorSparse, v2: VectorSparse) -> f64 { + let mut v1_pos = 0; + let mut v2_pos = 0; + let mut sum = 0.0; + while v1_pos < v1.idx.len() && v2_pos < v2.idx.len() { + if v1.idx[v1_pos] == v2.idx[v2_pos] { + sum += (v1.values[v1_pos] - v2.values[v2_pos]).powi(2); + v1_pos += 1; + v2_pos += 1; + } else if v1.idx[v1_pos] < v2.idx[v2_pos] { + sum += v1.values[v1_pos].powi(2); + v1_pos += 1; + } else { + sum += v2.values[v2_pos].powi(2); + v2_pos += 1; + } + } + while v1_pos < v1.idx.len() { + sum += v1.values[v1_pos].powi(2); + v1_pos += 1; + } + while v2_pos < v2.idx.len() { + sum += v2.values[v2_pos].powi(2); + v2_pos += 1; + } + (sum as f64).sqrt() +} + +#[cfg(test)] +mod tests { + use quickcheck_macros::quickcheck; + + use crate::vector::{ + operations::convert::vector_convert, vector_types::tests::ArbitraryVector, + }; + + use super::*; + + #[test] + fn test_vector_distance_l2_f32_another() { + let vectors = [ + (0..8).map(|x| x as f32).collect::>(), + (1..9).map(|x| x as f32).collect::>(), + (2..10).map(|x| x as f32).collect::>(), + (3..11).map(|x| x as f32).collect::>(), + ]; + let query = (2..10).map(|x| x as f32).collect::>(); + + let expected: Vec = vec![ + 32.0_f64.sqrt(), + 8.0_f64.sqrt(), + 0.0_f64.sqrt(), + 8.0_f64.sqrt(), + ]; + let results = vectors + .iter() + .map(|v| vector_f32_distance_l2_rust(&query, v)) + .collect::>(); + assert_eq!(results, expected); + } + + #[test] + fn test_vector_distance_l2_odd_len() { + let v = (0..5).map(|x| x as f32).collect::>(); + let query = (2..7).map(|x| x as f32).collect::>(); + assert_eq!(vector_f32_distance_l2_rust(&v, &query), 20.0_f64.sqrt()); + } + + #[test] + fn test_vector_distance_l2_f32() { + assert_eq!(vector_f32_distance_l2_rust(&[], &[]), 0.0); + assert_eq!( + vector_f32_distance_l2_rust(&[1.0, 2.0], &[0.0, 0.0]), + (1f64 + 2f64 * 2f64).sqrt() + ); + assert_eq!(vector_f32_distance_l2_rust(&[1.0, 2.0], &[1.0, 2.0]), 0.0); + assert_eq!( + vector_f32_distance_l2_rust(&[1.0, 2.0], &[-1.0, -2.0]), + (2f64 * 2f64 + 4f64 * 4f64).sqrt() + ); + assert_eq!( + vector_f32_distance_l2_rust(&[1.0, 2.0], &[-2.0, 1.0]), + (3f64 * 3f64 + 1f64 * 1f64).sqrt() + ); + } + + #[test] + fn test_vector_distance_l2_f64() { + assert_eq!(vector_f64_distance_l2_rust(&[], &[]), 0.0); + assert_eq!( + vector_f64_distance_l2_rust(&[1.0, 2.0], &[0.0, 0.0]), + (1f64 + 2f64 * 2f64).sqrt() + ); + assert_eq!(vector_f64_distance_l2_rust(&[1.0, 2.0], &[1.0, 2.0]), 0.0); + assert_eq!( + vector_f64_distance_l2_rust(&[1.0, 2.0], &[-1.0, -2.0]), + (2f64 * 2f64 + 4f64 * 4f64).sqrt() + ); + assert_eq!( + vector_f64_distance_l2_rust(&[1.0, 2.0], &[-2.0, 1.0]), + (3f64 * 3f64 + 1f64 * 1f64).sqrt() + ); + } + + #[test] + fn test_vector_distance_l2_f32_sparse() { + assert!( + (vector_f32_sparse_distance_l2( + VectorSparse { + idx: &[0, 1], + values: &[1.0, 2.0] + }, + VectorSparse { + idx: &[1, 2], + values: &[1.0, 3.0] + }, + ) - vector_f32_distance_l2_rust(&[1.0, 2.0, 0.0], &[0.0, 1.0, 3.0])) + .abs() + < 1e-7 + ); + } + + #[quickcheck] + fn prop_vector_distance_l2_dense_vs_sparse( + v1: ArbitraryVector<100>, + v2: ArbitraryVector<100>, + ) -> bool { + let v1 = vector_convert(v1.into(), VectorType::Float32Dense).unwrap(); + let v2 = vector_convert(v2.into(), VectorType::Float32Dense).unwrap(); + let d1 = vector_distance_l2(&v1, &v2).unwrap(); + + let sparse1 = vector_convert(v1, VectorType::Float32Sparse).unwrap(); + let sparse2 = vector_convert(v2, VectorType::Float32Sparse).unwrap(); + let d2 = vector_f32_sparse_distance_l2(sparse1.as_f32_sparse(), sparse2.as_f32_sparse()); + + (d1.is_nan() && d2.is_nan()) || (d1 - d2).abs() < 1e-6 + } + + #[quickcheck] + fn prop_vector_distance_l2_rust_vs_simsimd_f32( + v1: ArbitraryVector<100>, + v2: ArbitraryVector<100>, + ) -> bool { + let v1 = vector_convert(v1.into(), VectorType::Float32Dense).unwrap(); + let v2 = vector_convert(v2.into(), VectorType::Float32Dense).unwrap(); + let d1 = vector_f32_distance_l2_rust(v1.as_f32_slice(), v2.as_f32_slice()); + let d2 = vector_f32_distance_l2_simsimd(v1.as_f32_slice(), v2.as_f32_slice()); + (d1.is_nan() && d2.is_nan()) || (d1 - d2).abs() < 1e-4 + } + + #[quickcheck] + fn prop_vector_distance_l2_rust_vs_simsimd_f64( + v1: ArbitraryVector<100>, + v2: ArbitraryVector<100>, + ) -> bool { + let v1 = vector_convert(v1.into(), VectorType::Float64Dense).unwrap(); + let v2 = vector_convert(v2.into(), VectorType::Float64Dense).unwrap(); + let d1 = vector_f64_distance_l2_rust(v1.as_f64_slice(), v2.as_f64_slice()); + let d2 = vector_f64_distance_l2_simsimd(v1.as_f64_slice(), v2.as_f64_slice()); + (d1.is_nan() && d2.is_nan()) || (d1 - d2).abs() < 1e-6 + } +} diff --git a/core/vector/operations/jaccard.rs b/core/vector/operations/jaccard.rs new file mode 100644 index 000000000..f7f73c2a0 --- /dev/null +++ b/core/vector/operations/jaccard.rs @@ -0,0 +1,162 @@ +use crate::{ + vector::vector_types::{Vector, VectorSparse, VectorType}, + LimboError, Result, +}; + +pub fn vector_distance_jaccard(v1: &Vector, v2: &Vector) -> Result { + if v1.dims != v2.dims { + return Err(LimboError::ConversionError( + "Vectors must have the same dimensions".to_string(), + )); + } + if v1.vector_type != v2.vector_type { + return Err(LimboError::ConversionError( + "Vectors must be of the same type".to_string(), + )); + } + match v1.vector_type { + VectorType::Float32Dense => Ok(vector_f32_distance_jaccard( + v1.as_f32_slice(), + v2.as_f32_slice(), + )), + VectorType::Float64Dense => Ok(vector_f64_distance_jaccard( + v1.as_f64_slice(), + v2.as_f64_slice(), + )), + VectorType::Float32Sparse => Ok(vector_f32_sparse_distance_jaccard( + v1.as_f32_sparse(), + v2.as_f32_sparse(), + )), + } +} + +fn vector_f32_distance_jaccard(v1: &[f32], v2: &[f32]) -> f64 { + let (mut min_sum, mut max_sum) = (0.0, 0.0); + for (&a, &b) in v1.iter().zip(v2.iter()) { + min_sum += a.min(b); + max_sum += a.max(b); + } + if max_sum == 0.0 { + return f64::NAN; + } + 1. - (min_sum / max_sum) as f64 +} + +fn vector_f64_distance_jaccard(v1: &[f64], v2: &[f64]) -> f64 { + let (mut min_sum, mut max_sum) = (0.0, 0.0); + for (&a, &b) in v1.iter().zip(v2.iter()) { + min_sum += a.min(b); + max_sum += a.max(b); + } + if max_sum == 0.0 { + return f64::NAN; + } + 1. - min_sum / max_sum +} + +fn vector_f32_sparse_distance_jaccard(v1: VectorSparse, v2: VectorSparse) -> f64 { + let mut v1_pos = 0; + let mut v2_pos = 0; + let (mut min_sum, mut max_sum) = (0.0, 0.0); + while v1_pos < v1.idx.len() && v2_pos < v2.idx.len() { + if v1.idx[v1_pos] == v2.idx[v2_pos] { + min_sum += v1.values[v1_pos].min(v2.values[v2_pos]); + max_sum += v1.values[v1_pos].max(v2.values[v2_pos]); + v1_pos += 1; + v2_pos += 1; + } else if v1.idx[v1_pos] < v2.idx[v2_pos] { + min_sum += v1.values[v1_pos].min(0.); + max_sum += v1.values[v1_pos].max(0.); + v1_pos += 1; + } else { + min_sum += v2.values[v2_pos].min(0.); + max_sum += v2.values[v2_pos].max(0.); + v2_pos += 1; + } + } + while v1_pos < v1.idx.len() { + min_sum += v1.values[v1_pos].min(0.); + max_sum += v1.values[v1_pos].max(0.); + v1_pos += 1; + } + while v2_pos < v2.idx.len() { + min_sum += v2.values[v2_pos].min(0.); + max_sum += v2.values[v2_pos].max(0.); + v2_pos += 1; + } + if max_sum == 0.0 { + return f64::NAN; + } + 1. - (min_sum / max_sum) as f64 +} + +#[cfg(test)] +mod tests { + use quickcheck_macros::quickcheck; + + use crate::vector::{ + operations::convert::vector_convert, vector_types::tests::ArbitraryVector, + }; + + use super::*; + + #[test] + fn test_vector_distance_jaccard_f32() { + assert!(vector_f32_distance_jaccard(&[0.0, 0.0, 0.0], &[0.0, 0.0, 0.0]).is_nan()); + assert_eq!(vector_f32_distance_jaccard(&[1.0, 2.0], &[0.0, 0.0]), 1.0); + assert_eq!(vector_f32_distance_jaccard(&[1.0, 2.0], &[1.0, 2.0]), 0.0); + assert_eq!( + vector_f32_distance_jaccard(&[1.0, 2.0], &[2.0, 1.0]), + 1. - (1.0 + 1.0) / (2.0 + 2.0) + ); + } + + #[test] + fn test_vector_distance_jaccard_f64() { + assert!(vector_f64_distance_jaccard(&[], &[]).is_nan()); + assert!(vector_f64_distance_jaccard(&[0.0, 0.0, 0.0], &[0.0, 0.0, 0.0]).is_nan()); + assert_eq!(vector_f64_distance_jaccard(&[1.0, 2.0], &[0.0, 0.0]), 1.0); + assert_eq!(vector_f64_distance_jaccard(&[1.0, 2.0], &[1.0, 2.0]), 0.0); + assert_eq!( + vector_f64_distance_jaccard(&[1.0, 2.0], &[2.0, 1.0]), + 1. - (1.0 + 1.0) / (2.0 + 2.0) + ); + } + + #[test] + fn test_vector_distance_jaccard_f32_sparse() { + assert!( + (vector_f32_sparse_distance_jaccard( + VectorSparse { + idx: &[0, 1], + values: &[1.0, 2.0] + }, + VectorSparse { + idx: &[1, 2], + values: &[1.0, 3.0] + }, + ) - vector_f32_distance_jaccard(&[1.0, 2.0, 0.0], &[0.0, 1.0, 3.0])) + .abs() + < 1e-7 + ); + } + + #[quickcheck] + fn prop_vector_distance_jaccard_dense_vs_sparse( + v1: ArbitraryVector<100>, + v2: ArbitraryVector<100>, + ) -> bool { + let v1 = vector_convert(v1.into(), VectorType::Float32Dense).unwrap(); + let v2 = vector_convert(v2.into(), VectorType::Float32Dense).unwrap(); + let d1 = vector_distance_jaccard(&v1, &v2).unwrap(); + println!("v1: {:?}, v2: {:?}", v1.as_f32_slice(), v2.as_f32_slice()); + + let sparse1 = vector_convert(v1, VectorType::Float32Sparse).unwrap(); + let sparse2 = vector_convert(v2, VectorType::Float32Sparse).unwrap(); + let d2 = + vector_f32_sparse_distance_jaccard(sparse1.as_f32_sparse(), sparse2.as_f32_sparse()); + + println!("d1: {}, d2: {}, delta: {}", d1, d2, (d1 - d2).abs()); + (d1.is_nan() && d2.is_nan()) || (d1 - d2).abs() < 1e-6 + } +} diff --git a/core/vector/operations/mod.rs b/core/vector/operations/mod.rs new file mode 100644 index 000000000..9b1a20ada --- /dev/null +++ b/core/vector/operations/mod.rs @@ -0,0 +1,8 @@ +pub mod concat; +pub mod convert; +pub mod distance_cos; +pub mod distance_l2; +pub mod jaccard; +pub mod serialize; +pub mod slice; +pub mod text; diff --git a/core/vector/operations/serialize.rs b/core/vector/operations/serialize.rs new file mode 100644 index 000000000..8f8c3af9d --- /dev/null +++ b/core/vector/operations/serialize.rs @@ -0,0 +1,22 @@ +use crate::{ + vector::vector_types::{Vector, VectorType}, + Value, +}; + +pub fn vector_serialize(x: Vector) -> Value { + match x.vector_type { + VectorType::Float32Dense => Value::from_blob(x.bin_eject()), + VectorType::Float64Dense => { + let mut data = x.bin_eject(); + data.push(2); + Value::from_blob(data) + } + VectorType::Float32Sparse => { + let dims = x.dims; + let mut data = x.bin_eject(); + data.extend_from_slice(&(dims as u32).to_le_bytes()); + data.push(9); + Value::from_blob(data) + } + } +} diff --git a/core/vector/operations/slice.rs b/core/vector/operations/slice.rs new file mode 100644 index 000000000..a1f6b99d8 --- /dev/null +++ b/core/vector/operations/slice.rs @@ -0,0 +1,140 @@ +use crate::{ + vector::vector_types::{Vector, VectorType}, + LimboError, Result, +}; + +pub fn vector_slice(vector: &Vector, start: usize, end: usize) -> Result> { + if start > end { + return Err(LimboError::InvalidArgument( + "start index must not be greater than end index".into(), + )); + } + if end > vector.dims || end < start { + return Err(LimboError::ConversionError( + "vector_slice range out of bounds".into(), + )); + } + match vector.vector_type { + VectorType::Float32Dense => Ok(Vector { + vector_type: vector.vector_type, + dims: end - start, + owned: Some(vector.bin_data()[start * 4..end * 4].to_vec()), + refer: None, + }), + VectorType::Float64Dense => Ok(Vector { + vector_type: vector.vector_type, + dims: end - start, + owned: Some(vector.bin_data()[start * 8..end * 8].to_vec()), + refer: None, + }), + VectorType::Float32Sparse => { + let mut values = Vec::new(); + let mut idx = Vec::new(); + let sparse = vector.as_f32_sparse(); + for (&i, &value) in sparse.idx.iter().zip(sparse.values.iter()) { + let i = i as usize; + if i < start || i >= end { + continue; + } + values.extend_from_slice(&value.to_le_bytes()); + idx.extend_from_slice(&i.to_le_bytes()); + } + values.extend_from_slice(&idx); + Ok(Vector { + vector_type: vector.vector_type, + dims: end - start, + owned: Some(values), + refer: None, + }) + } + } +} + +#[cfg(test)] +mod tests { + use crate::vector::{ + operations::slice::vector_slice, + vector_types::{Vector, VectorType}, + }; + + fn float32_vec_from(slice: &[f32]) -> Vector { + let mut data = Vec::new(); + for &v in slice { + data.extend_from_slice(&v.to_le_bytes()); + } + + Vector { + vector_type: VectorType::Float32Dense, + dims: slice.len(), + owned: Some(data), + refer: None, + } + } + + fn f32_slice_from_vector(vector: &Vector) -> Vec { + vector.as_f32_slice().to_vec() + } + + #[test] + fn test_vector_slice_normal_case() { + let input_vec = float32_vec_from(&[1.0, 2.0, 3.0, 4.0, 5.0]); + let result = vector_slice(&input_vec, 1, 4).unwrap(); + + assert_eq!(result.dims, 3); + assert_eq!(f32_slice_from_vector(&result), vec![2.0, 3.0, 4.0]); + } + + #[test] + fn test_vector_slice_full_range() { + let input_vec = float32_vec_from(&[10.0, 20.0, 30.0]); + let result = vector_slice(&input_vec, 0, 3).unwrap(); + + assert_eq!(result.dims, 3); + assert_eq!(f32_slice_from_vector(&result), vec![10.0, 20.0, 30.0]); + } + + #[test] + fn test_vector_slice_single_element() { + let input_vec = float32_vec_from(&[4.40, 2.71]); + let result = vector_slice(&input_vec, 1, 2).unwrap(); + + assert_eq!(result.dims, 1); + assert_eq!(f32_slice_from_vector(&result), vec![2.71]); + } + + #[test] + fn test_vector_slice_empty_list() { + let input_vec = float32_vec_from(&[1.0, 2.0]); + let result = vector_slice(&input_vec, 2, 2).unwrap(); + + assert_eq!(result.dims, 0); + } + + #[test] + fn test_vector_slice_zero_length() { + let input_vec = float32_vec_from(&[1.0, 2.0, 3.0]); + let err = vector_slice(&input_vec, 2, 1); + assert!(err.is_err(), "Expected error on zero-length range"); + } + + #[test] + fn test_vector_slice_out_of_bounds() { + let input_vec = float32_vec_from(&[1.0, 2.0]); + let err = vector_slice(&input_vec, 0, 5); + assert!(err.is_err()); + } + + #[test] + fn test_vector_slice_start_out_of_bounds() { + let input_vec = float32_vec_from(&[1.0, 2.0]); + let err = vector_slice(&input_vec, 5, 5); + assert!(err.is_err()); + } + + #[test] + fn test_vector_slice_end_out_of_bounds() { + let input_vec = float32_vec_from(&[1.0, 2.0]); + let err = vector_slice(&input_vec, 1, 3); + assert!(err.is_err()); + } +} diff --git a/core/vector/operations/text.rs b/core/vector/operations/text.rs new file mode 100644 index 000000000..810bd8bc0 --- /dev/null +++ b/core/vector/operations/text.rs @@ -0,0 +1,144 @@ +use crate::{ + vector::vector_types::{Vector, VectorType}, + LimboError, Result, +}; + +pub fn vector_to_text(vector: &Vector) -> String { + match vector.vector_type { + VectorType::Float32Dense => format_text(vector.as_f32_slice().iter()), + VectorType::Float64Dense => format_text(vector.as_f64_slice().iter()), + VectorType::Float32Sparse => { + let mut dense = vec![0.0f32; vector.dims]; + let sparse = vector.as_f32_sparse(); + tracing::info!("{:?}", sparse); + for (&idx, &value) in sparse.idx.iter().zip(sparse.values.iter()) { + dense[idx as usize] = value; + } + format_text(dense.iter()) + } + } +} + +fn format_text(values: impl Iterator) -> String { + let mut text = String::new(); + text.push('['); + let mut first = true; + for value in values { + if !first { + text.push(','); + } + first = false; + text.push_str(&value.to_string()); + } + text.push(']'); + text +} + +/// Parse a vector in text representation into a Vector. +/// +/// The format of a vector in text representation looks as follows: +/// +/// ```console +/// [1.0, 2.0, 3.0] +/// ``` +pub fn vector_from_text(vector_type: VectorType, text: &str) -> Result { + let text = text.trim(); + let mut chars = text.chars(); + if chars.next() != Some('[') || chars.last() != Some(']') { + return Err(LimboError::ConversionError( + "Invalid vector value".to_string(), + )); + } + let text = &text[1..text.len() - 1]; + if text.trim().is_empty() { + return Ok(match vector_type { + VectorType::Float32Dense | VectorType::Float64Dense | VectorType::Float32Sparse => { + Vector { + vector_type, + dims: 0, + owned: Some(Vec::new()), + refer: None, + } + } + }); + } + let tokens = text.split(',').map(|x| x.trim()); + match vector_type { + VectorType::Float32Dense => vector32_from_text(tokens), + VectorType::Float64Dense => vector64_from_text(tokens), + VectorType::Float32Sparse => vector32_sparse_from_text(tokens), + } +} + +fn vector32_from_text<'a>(tokens: impl Iterator) -> Result> { + let mut data = Vec::new(); + for token in tokens { + let value = token + .parse::() + .map_err(|_| LimboError::ConversionError("Invalid vector value".to_string()))?; + if !value.is_finite() { + return Err(LimboError::ConversionError( + "Invalid vector value".to_string(), + )); + } + data.extend_from_slice(&value.to_le_bytes()); + } + Ok(Vector { + vector_type: VectorType::Float32Dense, + dims: data.len() / 4, + owned: Some(data), + refer: None, + }) +} + +fn vector64_from_text<'a>(tokens: impl Iterator) -> Result> { + let mut data = Vec::new(); + for token in tokens { + let value = token + .parse::() + .map_err(|_| LimboError::ConversionError("Invalid vector value".to_string()))?; + if !value.is_finite() { + return Err(LimboError::ConversionError( + "Invalid vector value".to_string(), + )); + } + data.extend_from_slice(&value.to_le_bytes()); + } + Ok(Vector { + vector_type: VectorType::Float64Dense, + dims: data.len() / 8, + owned: Some(data), + refer: None, + }) +} + +fn vector32_sparse_from_text<'a>(tokens: impl Iterator) -> Result> { + let mut idx = Vec::new(); + let mut values = Vec::new(); + let mut dims = 0u32; + for token in tokens { + let value = token + .parse::() + .map_err(|_| LimboError::ConversionError("Invalid vector value".to_string()))?; + if !value.is_finite() { + return Err(LimboError::ConversionError( + "Invalid vector value".to_string(), + )); + } + + dims += 1; + if value == 0.0 { + continue; + } + idx.extend_from_slice(&(dims - 1).to_le_bytes()); + values.extend_from_slice(&value.to_le_bytes()); + } + + values.extend_from_slice(&idx); + Ok(Vector { + vector_type: VectorType::Float32Sparse, + dims: dims as usize, + owned: Some(values), + refer: None, + }) +} diff --git a/core/vector/vector_types.rs b/core/vector/vector_types.rs index 20d268dc0..8b70fdbda 100644 --- a/core/vector/vector_types.rs +++ b/core/vector/vector_types.rs @@ -1,30 +1,205 @@ -use crate::types::{Value, ValueType}; -use crate::vdbe::Register; use crate::{LimboError, Result}; #[derive(Debug, Clone, PartialEq, Copy)] pub enum VectorType { - Float32, - Float64, -} - -impl VectorType { - pub fn size_to_dims(&self, size: usize) -> usize { - match self { - VectorType::Float32 => size / 4, - VectorType::Float64 => size / 8, - } - } + Float32Dense, + Float64Dense, + Float32Sparse, } #[derive(Debug)] -pub struct Vector { +pub struct Vector<'a> { pub vector_type: VectorType, pub dims: usize, - pub data: Vec, + pub owned: Option>, + pub refer: Option<&'a [u8]>, } -impl Vector { +#[derive(Debug)] +pub struct VectorSparse<'a, T: std::fmt::Debug> { + pub idx: &'a [u32], + pub values: &'a [T], +} + +impl<'a> Vector<'a> { + pub fn vector_type(blob: &[u8]) -> Result<(VectorType, usize)> { + // Even-sized blobs are always float32. + if blob.len() % 2 == 0 { + return Ok((VectorType::Float32Dense, blob.len())); + } + // Odd-sized blobs have type byte at the end + let vector_type = blob[blob.len() - 1]; + /* + vector types used by LibSQL: + (see https://github.com/tursodatabase/libsql/blob/a55bf61192bdb89e97568de593c4af5b70d24bde/libsql-sqlite3/src/vectorInt.h#L52) + #define VECTOR_TYPE_FLOAT32 1 + #define VECTOR_TYPE_FLOAT64 2 + #define VECTOR_TYPE_FLOAT1BIT 3 + #define VECTOR_TYPE_FLOAT8 4 + #define VECTOR_TYPE_FLOAT16 5 + #define VECTOR_TYPE_FLOATB16 6 + */ + match vector_type { + 1 => Ok((VectorType::Float32Dense, blob.len() - 1)), + 2 => Ok((VectorType::Float64Dense, blob.len() - 1)), + 3..=6 => Err(LimboError::ConversionError( + "unsupported vector type from LibSQL".to_string(), + )), + 9 => Ok((VectorType::Float32Sparse, blob.len() - 1)), + _ => Err(LimboError::ConversionError(format!( + "unknown vector type: {vector_type}" + ))), + } + } + pub fn from_f32(mut values_f32: Vec) -> Self { + let dims = values_f32.len(); + let values = unsafe { + Vec::from_raw_parts( + values_f32.as_mut_ptr() as *mut u8, + values_f32.len() * 4, + values_f32.capacity() * 4, + ) + }; + std::mem::forget(values_f32); + Self { + vector_type: VectorType::Float32Dense, + dims, + owned: Some(values), + refer: None, + } + } + pub fn from_f64(mut values_f64: Vec) -> Self { + let dims = values_f64.len(); + let values = unsafe { + Vec::from_raw_parts( + values_f64.as_mut_ptr() as *mut u8, + values_f64.len() * 8, + values_f64.capacity() * 8, + ) + }; + std::mem::forget(values_f64); + Self { + vector_type: VectorType::Float64Dense, + dims, + owned: Some(values), + refer: None, + } + } + pub fn from_f32_sparse(dims: usize, mut values_f32: Vec, mut idx_u32: Vec) -> Self { + let mut values = unsafe { + Vec::from_raw_parts( + values_f32.as_mut_ptr() as *mut u8, + values_f32.len() * 4, + values_f32.capacity() * 4, + ) + }; + std::mem::forget(values_f32); + + let idx = unsafe { + Vec::from_raw_parts( + idx_u32.as_mut_ptr() as *mut u8, + idx_u32.len() * 4, + idx_u32.capacity() * 4, + ) + }; + std::mem::forget(idx_u32); + + values.extend_from_slice(&idx); + Self { + vector_type: VectorType::Float32Sparse, + dims, + owned: Some(values), + refer: None, + } + } + pub fn from_vec(mut blob: Vec) -> Result { + let (vector_type, len) = Self::vector_type(&blob)?; + blob.truncate(len); + Self::from_data(vector_type, Some(blob), None) + } + pub fn from_slice(blob: &'a [u8]) -> Result { + let (vector_type, len) = Self::vector_type(blob)?; + Self::from_data(vector_type, None, Some(&blob[..len])) + } + pub fn from_data( + vector_type: VectorType, + owned: Option>, + refer: Option<&'a [u8]>, + ) -> Result { + let owned_slice = owned.as_deref(); + let refer_slice = refer.as_ref().map(|&x| x); + let data = owned_slice.unwrap_or_else(|| refer_slice.unwrap()); + match vector_type { + VectorType::Float32Dense => { + if data.len() % 4 != 0 { + return Err(LimboError::InvalidArgument(format!( + "f32 dense vector unexpected data length: {}", + data.len(), + ))); + } + Ok(Vector { + vector_type, + dims: data.len() / 4, + owned, + refer, + }) + } + VectorType::Float64Dense => { + if data.len() % 8 != 0 { + return Err(LimboError::InvalidArgument(format!( + "f64 dense vector unexpected data length: {}", + data.len(), + ))); + } + Ok(Vector { + vector_type, + dims: data.len() / 8, + owned, + refer, + }) + } + VectorType::Float32Sparse => { + if data.is_empty() || data.len() % 4 != 0 || (data.len() - 4) % 8 != 0 { + return Err(LimboError::InvalidArgument(format!( + "f32 sparse vector unexpected data length: {}", + data.len(), + ))); + } + let original_len = data.len(); + let dims_bytes = &data[original_len - 4..]; + let dims = u32::from_le_bytes(dims_bytes.try_into().unwrap()) as usize; + let owned = owned.map(|mut x| { + x.truncate(original_len - 4); + x + }); + let refer = refer.map(|x| &x[0..original_len - 4]); + let vector = Vector { + vector_type, + dims, + owned, + refer, + }; + Ok(vector) + } + } + } + + pub fn bin_len(&self) -> usize { + let owned = self.owned.as_ref().map(|x| x.len()); + let refer = self.refer.as_ref().map(|x| x.len()); + owned.unwrap_or_else(|| refer.unwrap()) + } + + pub fn bin_data(&'a self) -> &'a [u8] { + let owned = self.owned.as_deref(); + let refer = self.refer.as_ref().map(|&x| x); + owned.unwrap_or_else(|| refer.unwrap()) + } + + pub fn bin_eject(self) -> Vec { + self.owned.unwrap_or_else(|| self.refer.unwrap().to_vec()) + } + /// # Safety /// /// This method is used to reinterpret the underlying `Vec` data @@ -32,17 +207,18 @@ impl Vector { /// - The buffer is correctly aligned for `f32` /// - The length of the buffer is exactly `dims * size_of::()` pub fn as_f32_slice(&self) -> &[f32] { + debug_assert!(self.vector_type == VectorType::Float32Dense); if self.dims == 0 { return &[]; } assert_eq!( - self.data.len(), + self.bin_len(), self.dims * std::mem::size_of::(), "data length must equal dims * size_of::()" ); - let ptr = self.data.as_ptr(); + let ptr = self.bin_data().as_ptr(); let align = std::mem::align_of::(); assert_eq!( ptr.align_offset(align), @@ -60,17 +236,18 @@ impl Vector { /// - The buffer is correctly aligned for `f64` /// - The length of the buffer is exactly `dims * size_of::()` pub fn as_f64_slice(&self) -> &[f64] { + debug_assert!(self.vector_type == VectorType::Float64Dense); if self.dims == 0 { return &[]; } assert_eq!( - self.data.len(), + self.bin_len(), self.dims * std::mem::size_of::(), "data length must equal dims * size_of::()" ); - let ptr = self.data.as_ptr(); + let ptr = self.bin_data().as_ptr(); let align = std::mem::align_of::(); assert_eq!( ptr.align_offset(align), @@ -78,360 +255,37 @@ impl Vector { "data pointer must be aligned to {align} bytes for f64 access" ); - unsafe { std::slice::from_raw_parts(self.data.as_ptr() as *const f64, self.dims) } - } -} - -/// Parse a vector in text representation into a Vector. -/// -/// The format of a vector in text representation looks as follows: -/// -/// ```console -/// [1.0, 2.0, 3.0] -/// ``` -pub fn parse_string_vector(vector_type: VectorType, value: &Value) -> Result { - let Some(text) = value.to_text() else { - return Err(LimboError::ConversionError( - "Invalid vector value".to_string(), - )); - }; - let text = text.trim(); - let mut chars = text.chars(); - if chars.next() != Some('[') || chars.last() != Some(']') { - return Err(LimboError::ConversionError( - "Invalid vector value".to_string(), - )); - } - let mut data: Vec = Vec::new(); - let text = &text[1..text.len() - 1]; - if text.trim().is_empty() { - return Ok(Vector { - vector_type, - dims: 0, - data, - }); - } - let xs = text.split(','); - for x in xs { - let x = x.trim(); - if x.is_empty() { - return Err(LimboError::ConversionError( - "Invalid vector value".to_string(), - )); - } - match vector_type { - VectorType::Float32 => { - let x = x - .parse::() - .map_err(|_| LimboError::ConversionError("Invalid vector value".to_string()))?; - if !x.is_finite() { - return Err(LimboError::ConversionError( - "Invalid vector value".to_string(), - )); - } - data.extend_from_slice(&x.to_le_bytes()); - } - VectorType::Float64 => { - let x = x - .parse::() - .map_err(|_| LimboError::ConversionError("Invalid vector value".to_string()))?; - if !x.is_finite() { - return Err(LimboError::ConversionError( - "Invalid vector value".to_string(), - )); - } - data.extend_from_slice(&x.to_le_bytes()); - } - }; - } - let dims = vector_type.size_to_dims(data.len()); - Ok(Vector { - vector_type, - dims, - data, - }) -} - -pub fn parse_vector(value: &Register, vec_ty: Option) -> Result { - match value.get_value().value_type() { - ValueType::Text => { - parse_string_vector(vec_ty.unwrap_or(VectorType::Float32), value.get_value()) - } - ValueType::Blob => { - let Some(blob) = value.get_value().to_blob() else { - return Err(LimboError::ConversionError( - "Invalid vector value".to_string(), - )); - }; - let vector_type = vector_type(blob)?; - if let Some(vec_ty) = vec_ty { - if vec_ty != vector_type { - return Err(LimboError::ConversionError( - "Invalid vector type".to_string(), - )); - } - } - vector_deserialize(vector_type, blob) - } - _ => Err(LimboError::ConversionError( - "Invalid vector type".to_string(), - )), - } -} - -pub fn vector_to_text(vector: &Vector) -> String { - let mut text = String::new(); - text.push('['); - match vector.vector_type { - VectorType::Float32 => { - let data = vector.as_f32_slice(); - for (i, value) in data.iter().enumerate().take(vector.dims) { - text.push_str(&value.to_string()); - if i < vector.dims - 1 { - text.push(','); - } - } - } - VectorType::Float64 => { - let data = vector.as_f64_slice(); - for (i, value) in data.iter().enumerate().take(vector.dims) { - text.push_str(&value.to_string()); - if i < vector.dims - 1 { - text.push(','); - } - } - } - } - text.push(']'); - text -} - -pub fn vector_deserialize(vector_type: VectorType, blob: &[u8]) -> Result { - match vector_type { - VectorType::Float32 => vector_deserialize_f32(blob), - VectorType::Float64 => vector_deserialize_f64(blob), - } -} - -pub fn vector_serialize_f64(x: Vector) -> Value { - let mut blob = Vec::with_capacity(x.dims * 8 + 1); - blob.extend_from_slice(&x.data); - blob.push(2); - Value::from_blob(blob) -} - -pub fn vector_deserialize_f64(blob: &[u8]) -> Result { - Ok(Vector { - vector_type: VectorType::Float64, - dims: (blob.len() - 1) / 8, - data: blob[..blob.len() - 1].to_vec(), - }) -} - -pub fn vector_serialize_f32(x: Vector) -> Value { - Value::from_blob(x.data) -} - -pub fn vector_deserialize_f32(blob: &[u8]) -> Result { - Ok(Vector { - vector_type: VectorType::Float32, - dims: blob.len() / 4, - data: blob.to_vec(), - }) -} - -pub fn do_vector_distance_cos(v1: &Vector, v2: &Vector) -> Result { - match v1.vector_type { - VectorType::Float32 => vector_f32_distance_cos(v1, v2), - VectorType::Float64 => vector_f64_distance_cos(v1, v2), - } -} - -pub fn vector_f32_distance_cos(v1: &Vector, v2: &Vector) -> Result { - if v1.dims != v2.dims { - return Err(LimboError::ConversionError( - "Invalid vector dimensions".to_string(), - )); - } - if v1.vector_type != v2.vector_type { - return Err(LimboError::ConversionError( - "Invalid vector type".to_string(), - )); - } - let (mut dot, mut norm1, mut norm2) = (0.0, 0.0, 0.0); - let v1_data = v1.as_f32_slice(); - let v2_data = v2.as_f32_slice(); - - // Check for non-finite values - if v1_data.iter().any(|x| !x.is_finite()) || v2_data.iter().any(|x| !x.is_finite()) { - return Err(LimboError::ConversionError( - "Invalid vector value".to_string(), - )); + unsafe { std::slice::from_raw_parts(ptr as *const f64, self.dims) } } - for i in 0..v1.dims { - let e1 = v1_data[i]; - let e2 = v2_data[i]; - dot += e1 * e2; - norm1 += e1 * e1; - norm2 += e2 * e2; + pub fn as_f32_sparse(&self) -> VectorSparse<'_, f32> { + debug_assert!(self.vector_type == VectorType::Float32Sparse); + let ptr = self.bin_data().as_ptr(); + let align = std::mem::align_of::(); + assert_eq!( + ptr.align_offset(align), + 0, + "data pointer must be aligned to {align} bytes for f32 access" + ); + let length = self.bin_data().len() / 4 / 2; + let values = unsafe { std::slice::from_raw_parts(ptr as *const f32, length) }; + let idx = unsafe { std::slice::from_raw_parts((ptr as *const u32).add(length), length) }; + debug_assert!(idx.is_sorted()); + VectorSparse { idx, values } } - - // Check for zero norms to avoid division by zero - if norm1 == 0.0 || norm2 == 0.0 { - return Err(LimboError::ConversionError( - "Invalid vector value".to_string(), - )); - } - - Ok(1.0 - (dot / (norm1 * norm2).sqrt()) as f64) -} - -pub fn vector_f64_distance_cos(v1: &Vector, v2: &Vector) -> Result { - if v1.dims != v2.dims { - return Err(LimboError::ConversionError( - "Invalid vector dimensions".to_string(), - )); - } - if v1.vector_type != v2.vector_type { - return Err(LimboError::ConversionError( - "Invalid vector type".to_string(), - )); - } - let (mut dot, mut norm1, mut norm2) = (0.0, 0.0, 0.0); - let v1_data = v1.as_f64_slice(); - let v2_data = v2.as_f64_slice(); - - // Check for non-finite values - if v1_data.iter().any(|x| !x.is_finite()) || v2_data.iter().any(|x| !x.is_finite()) { - return Err(LimboError::ConversionError( - "Invalid vector value".to_string(), - )); - } - - for i in 0..v1.dims { - let e1 = v1_data[i]; - let e2 = v2_data[i]; - dot += e1 * e2; - norm1 += e1 * e1; - norm2 += e2 * e2; - } - - // Check for zero norms - if norm1 == 0.0 || norm2 == 0.0 { - return Err(LimboError::ConversionError( - "Invalid vector value".to_string(), - )); - } - - Ok(1.0 - (dot / (norm1 * norm2).sqrt())) -} - -pub fn vector_type(blob: &[u8]) -> Result { - // Even-sized blobs are always float32. - if blob.len() % 2 == 0 { - return Ok(VectorType::Float32); - } - // Odd-sized blobs have type byte at the end - let (data_blob, type_byte) = blob.split_at(blob.len() - 1); - let vector_type = type_byte[0]; - match vector_type { - 1 => { - if data_blob.len() % 4 != 0 { - return Err(LimboError::ConversionError( - "Invalid vector value".to_string(), - )); - } - Ok(VectorType::Float32) - } - 2 => { - if data_blob.len() % 8 != 0 { - return Err(LimboError::ConversionError( - "Invalid vector value".to_string(), - )); - } - Ok(VectorType::Float64) - } - _ => Err(LimboError::ConversionError( - "Invalid vector type".to_string(), - )), - } -} - -pub fn vector_concat(v1: &Vector, v2: &Vector) -> Result { - if v1.vector_type != v2.vector_type { - return Err(LimboError::ConversionError( - "Mismatched vector types".into(), - )); - } - - let mut data = Vec::with_capacity(v1.data.len() + v2.data.len()); - data.extend_from_slice(&v1.data); - data.extend_from_slice(&v2.data); - - Ok(Vector { - vector_type: v1.vector_type, - dims: v1.dims + v2.dims, - data, - }) -} - -pub fn vector_slice(vector: &Vector, start_idx: usize, end_idx: usize) -> Result { - fn extract_bytes( - slice: &[T], - start: usize, - end: usize, - to_bytes: impl Fn(&T) -> [u8; N], - ) -> Result> { - if start > end { - return Err(LimboError::InvalidArgument( - "start index must not be greater than end index".into(), - )); - } - if end > slice.len() || end < start { - return Err(LimboError::ConversionError( - "vector_slice range out of bounds".into(), - )); - } - - let mut buf = Vec::with_capacity((end - start) * N); - for item in &slice[start..end] { - buf.extend_from_slice(&to_bytes(item)); - } - Ok(buf) - } - - let (vector_type, data) = match vector.vector_type { - VectorType::Float32 => ( - VectorType::Float32, - extract_bytes::(vector.as_f32_slice(), start_idx, end_idx, |v| { - v.to_le_bytes() - })?, - ), - VectorType::Float64 => ( - VectorType::Float64, - extract_bytes::(vector.as_f64_slice(), start_idx, end_idx, |v| { - v.to_le_bytes() - })?, - ), - }; - - Ok(Vector { - vector_type, - dims: end_idx - start_idx, - data, - }) } #[cfg(test)] -mod tests { +pub(crate) mod tests { + use crate::vector::operations; + use super::*; use quickcheck::{Arbitrary, Gen}; use quickcheck_macros::quickcheck; // Helper to generate arbitrary vectors of specific type and dimensions #[derive(Debug, Clone)] - struct ArbitraryVector { + pub struct ArbitraryVector { vector_type: VectorType, data: Vec, } @@ -442,6 +296,10 @@ mod tests { (0..DIMS) .map(|_| { loop { + // generate zeroes with some probability since we have support for sparse vectors + if bool::arbitrary(g) { + return 0.0; + } let f = f32::arbitrary(g); // f32::arbitrary() can generate "problem values" like NaN, infinity, and very small values // Skip these values @@ -458,6 +316,10 @@ mod tests { (0..DIMS) .map(|_| { loop { + // generate zeroes with some probability since we have support for sparse vectors + if bool::arbitrary(g) { + return 0.0; + } let f = f64::arbitrary(g); // f64::arbitrary() can generate "problem values" like NaN, infinity, and very small values // Skip these values @@ -472,12 +334,13 @@ mod tests { } /// Convert an ArbitraryVector to a Vector. - impl From> for Vector { + impl From> for Vector<'static> { fn from(v: ArbitraryVector) -> Self { Vector { vector_type: v.vector_type, dims: DIMS, - data: v.data, + owned: Some(v.data), + refer: None, } } } @@ -486,20 +349,21 @@ mod tests { impl Arbitrary for ArbitraryVector { fn arbitrary(g: &mut Gen) -> Self { let vector_type = if bool::arbitrary(g) { - VectorType::Float32 + VectorType::Float32Dense } else { - VectorType::Float64 + VectorType::Float64Dense }; let data = match vector_type { - VectorType::Float32 => { + VectorType::Float32Dense => { let floats = Self::generate_f32_vector(g); floats.iter().flat_map(|f| f.to_le_bytes()).collect() } - VectorType::Float64 => { + VectorType::Float64Dense => { let floats = Self::generate_f64_vector(g); floats.iter().flat_map(|f| f.to_le_bytes()).collect() } + _ => unreachable!(), }; ArbitraryVector { vector_type, data } @@ -534,14 +398,10 @@ mod tests { /// Test if the vector type identification is correct for a given vector. fn test_vector_type(v: Vector) -> bool { let vtype = v.vector_type; - let value = match &vtype { - VectorType::Float32 => vector_serialize_f32(v), - VectorType::Float64 => vector_serialize_f64(v), - }; - - let blob = value.to_blob().unwrap(); - match vector_type(blob) { - Ok(detected_type) => detected_type == vtype, + let value = operations::serialize::vector_serialize(v); + let blob = value.to_blob().unwrap().to_vec(); + match Vector::vector_type(&blob) { + Ok((detected_type, _)) => detected_type == vtype, Err(_) => false, } } @@ -576,52 +436,20 @@ mod tests { /// - The data length is correct (4 bytes per float for f32, 8 bytes per float for f64) fn test_slice_conversion(v: Vector) -> bool { match v.vector_type { - VectorType::Float32 => { + VectorType::Float32Dense => { let slice = v.as_f32_slice(); // Check if the slice length matches the dimensions and the data length is correct (4 bytes per float) - slice.len() == DIMS && (slice.len() * 4 == v.data.len()) + slice.len() == DIMS && (slice.len() * 4 == v.bin_len()) } - VectorType::Float64 => { + VectorType::Float64Dense => { let slice = v.as_f64_slice(); // Check if the slice length matches the dimensions and the data length is correct (8 bytes per float) - slice.len() == DIMS && (slice.len() * 8 == v.data.len()) + slice.len() == DIMS && (slice.len() * 8 == v.bin_len()) } + _ => unreachable!(), } } - // Test size_to_dims calculation with different dimensions - #[quickcheck] - fn prop_size_to_dims_calculation_2d(v: ArbitraryVector<2>) -> bool { - test_size_to_dims::<2>(v.into()) - } - - #[quickcheck] - fn prop_size_to_dims_calculation_3d(v: ArbitraryVector<3>) -> bool { - test_size_to_dims::<3>(v.into()) - } - - #[quickcheck] - fn prop_size_to_dims_calculation_4d(v: ArbitraryVector<4>) -> bool { - test_size_to_dims::<4>(v.into()) - } - - #[quickcheck] - fn prop_size_to_dims_calculation_100d(v: ArbitraryVector<100>) -> bool { - test_size_to_dims::<100>(v.into()) - } - - #[quickcheck] - fn prop_size_to_dims_calculation_1536d(v: ArbitraryVector<1536>) -> bool { - test_size_to_dims::<1536>(v.into()) - } - - /// Test if the size_to_dims calculation is correct for a given vector. - fn test_size_to_dims(v: Vector) -> bool { - let size = v.data.len(); - let calculated_dims = v.vector_type.size_to_dims(size); - calculated_dims == DIMS - } - #[quickcheck] fn prop_vector_distance_safety_2d(v1: ArbitraryVector<2>, v2: ArbitraryVector<2>) -> bool { test_vector_distance::<2>(&v1.into(), &v2.into()) @@ -658,171 +486,56 @@ mod tests { /// - Assumes vectors are well-formed (same type and dimension) /// - The distance must be between 0 and 2 fn test_vector_distance(v1: &Vector, v2: &Vector) -> bool { - match do_vector_distance_cos(v1, v2) { - Ok(distance) => (0.0..=2.0).contains(&distance), + match operations::distance_cos::vector_distance_cos(v1, v2) { + Ok(distance) => distance.is_nan() || (0.0 - 1e-6..=2.0 + 1e-6).contains(&distance), Err(_) => true, } } #[test] - fn parse_string_vector_zero_length() { - let value = Value::from_text("[]"); - let vector = parse_string_vector(VectorType::Float32, &value).unwrap(); - assert_eq!(vector.dims, 0); - assert_eq!(vector.vector_type, VectorType::Float32); - } - - #[test] - fn test_parse_string_vector_valid_whitespace() { - let value = Value::from_text(" [ 1.0 , 2.0 , 3.0 ] "); - let vector = parse_string_vector(VectorType::Float32, &value).unwrap(); - assert_eq!(vector.dims, 3); - assert_eq!(vector.vector_type, VectorType::Float32); - } - - #[test] - fn test_parse_string_vector_valid() { - let value = Value::from_text("[1.0, 2.0, 3.0]"); - let vector = parse_string_vector(VectorType::Float32, &value).unwrap(); - assert_eq!(vector.dims, 3); - assert_eq!(vector.vector_type, VectorType::Float32); - } - - fn float32_vec_from(slice: &[f32]) -> Vector { - let mut data = Vec::new(); - for &v in slice { - data.extend_from_slice(&v.to_le_bytes()); - } - - Vector { - vector_type: VectorType::Float32, - dims: slice.len(), - data, - } - } - - fn f32_slice_from_vector(vector: &Vector) -> Vec { - vector.as_f32_slice().to_vec() - } - - #[test] - fn test_vector_concat_normal_case() { - let v1 = float32_vec_from(&[1.0, 2.0, 3.0]); - let v2 = float32_vec_from(&[4.0, 5.0, 6.0]); - - let result = vector_concat(&v1, &v2).unwrap(); - - assert_eq!(result.dims, 6); - assert_eq!(result.vector_type, VectorType::Float32); - assert_eq!( - f32_slice_from_vector(&result), - vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0] + fn test_vector_some_cosine_dist() { + let a = Vector { + vector_type: VectorType::Float32Dense, + dims: 2, + owned: Some(vec![0, 0, 0, 0, 52, 208, 106, 63]), + refer: None, + }; + let b = Vector { + vector_type: VectorType::Float32Dense, + dims: 2, + owned: Some(vec![0, 0, 0, 0, 58, 100, 45, 192]), + refer: None, + }; + assert!( + (operations::distance_cos::vector_distance_cos(&a, &b).unwrap() - 2.0).abs() <= 1e-6 ); } #[test] - fn test_vector_concat_empty_left() { - let v1 = float32_vec_from(&[]); - let v2 = float32_vec_from(&[4.0, 5.0]); - - let result = vector_concat(&v1, &v2).unwrap(); - - assert_eq!(result.dims, 2); - assert_eq!(f32_slice_from_vector(&result), vec![4.0, 5.0]); + fn parse_string_vector_zero_length() { + let vector = operations::text::vector_from_text(VectorType::Float32Dense, "[]").unwrap(); + assert_eq!(vector.dims, 0); + assert_eq!(vector.vector_type, VectorType::Float32Dense); } #[test] - fn test_vector_concat_empty_right() { - let v1 = float32_vec_from(&[1.0, 2.0]); - let v2 = float32_vec_from(&[]); - - let result = vector_concat(&v1, &v2).unwrap(); - - assert_eq!(result.dims, 2); - assert_eq!(f32_slice_from_vector(&result), vec![1.0, 2.0]); + fn test_parse_string_vector_valid_whitespace() { + let vector = operations::text::vector_from_text( + VectorType::Float32Dense, + " [ 1.0 , 2.0 , 3.0 ] ", + ) + .unwrap(); + assert_eq!(vector.dims, 3); + assert_eq!(vector.vector_type, VectorType::Float32Dense); } #[test] - fn test_vector_concat_both_empty() { - let v1 = float32_vec_from(&[]); - let v2 = float32_vec_from(&[]); - let result = vector_concat(&v1, &v2).unwrap(); - assert_eq!(result.dims, 0); - assert_eq!(f32_slice_from_vector(&result), Vec::::new()); - } - - #[test] - fn test_vector_concat_different_lengths() { - let v1 = float32_vec_from(&[1.0]); - let v2 = float32_vec_from(&[2.0, 3.0, 4.0]); - - let result = vector_concat(&v1, &v2).unwrap(); - - assert_eq!(result.dims, 4); - assert_eq!(f32_slice_from_vector(&result), vec![1.0, 2.0, 3.0, 4.0]); - } - - #[test] - fn test_vector_slice_normal_case() { - let input_vec = float32_vec_from(&[1.0, 2.0, 3.0, 4.0, 5.0]); - let result = vector_slice(&input_vec, 1, 4).unwrap(); - - assert_eq!(result.dims, 3); - assert_eq!(f32_slice_from_vector(&result), vec![2.0, 3.0, 4.0]); - } - - #[test] - fn test_vector_slice_full_range() { - let input_vec = float32_vec_from(&[10.0, 20.0, 30.0]); - let result = vector_slice(&input_vec, 0, 3).unwrap(); - - assert_eq!(result.dims, 3); - assert_eq!(f32_slice_from_vector(&result), vec![10.0, 20.0, 30.0]); - } - - #[test] - fn test_vector_slice_single_element() { - let input_vec = float32_vec_from(&[4.40, 2.71]); - let result = vector_slice(&input_vec, 1, 2).unwrap(); - - assert_eq!(result.dims, 1); - assert_eq!(f32_slice_from_vector(&result), vec![2.71]); - } - - #[test] - fn test_vector_slice_empty_list() { - let input_vec = float32_vec_from(&[1.0, 2.0]); - let result = vector_slice(&input_vec, 2, 2).unwrap(); - - assert_eq!(result.dims, 0); - } - - #[test] - fn test_vector_slice_zero_length() { - let input_vec = float32_vec_from(&[1.0, 2.0, 3.0]); - let err = vector_slice(&input_vec, 2, 1); - assert!(err.is_err(), "Expected error on zero-length range"); - } - - #[test] - fn test_vector_slice_out_of_bounds() { - let input_vec = float32_vec_from(&[1.0, 2.0]); - let err = vector_slice(&input_vec, 0, 5); - assert!(err.is_err()); - } - - #[test] - fn test_vector_slice_start_out_of_bounds() { - let input_vec = float32_vec_from(&[1.0, 2.0]); - let err = vector_slice(&input_vec, 5, 5); - assert!(err.is_err()); - } - - #[test] - fn test_vector_slice_end_out_of_bounds() { - let input_vec = float32_vec_from(&[1.0, 2.0]); - let err = vector_slice(&input_vec, 1, 3); - assert!(err.is_err()); + fn test_parse_string_vector_valid() { + let vector = + operations::text::vector_from_text(VectorType::Float32Dense, "[1.0, 2.0, 3.0]") + .unwrap(); + assert_eq!(vector.dims, 3); + assert_eq!(vector.vector_type, VectorType::Float32Dense); } #[quickcheck] @@ -853,11 +566,10 @@ mod tests { /// Test that a vector can be converted to text and back without loss of precision fn test_vector_text_roundtrip(v: Vector) -> bool { // Convert to text - let text = vector_to_text(&v); + let text = operations::text::vector_to_text(&v); // Parse back from text - let value = Value::from_text(&text); - let parsed = parse_string_vector(v.vector_type, &value); + let parsed = operations::text::vector_from_text(v.vector_type, &text); match parsed { Ok(parsed_vector) => { @@ -867,16 +579,17 @@ mod tests { } match v.vector_type { - VectorType::Float32 => { + VectorType::Float32Dense => { let original = v.as_f32_slice(); let parsed = parsed_vector.as_f32_slice(); original.iter().zip(parsed.iter()).all(|(a, b)| a == b) } - VectorType::Float64 => { + VectorType::Float64Dense => { let original = v.as_f64_slice(); let parsed = parsed_vector.as_f64_slice(); original.iter().zip(parsed.iter()).all(|(a, b)| a == b) } + _ => unreachable!(), } } Err(_) => false, diff --git a/dist-workspace.toml b/dist-workspace.toml index 71bd31d3f..daa295ace 100644 --- a/dist-workspace.toml +++ b/dist-workspace.toml @@ -10,7 +10,7 @@ ci = "github" # The installers to generate for each app installers = ["shell", "powershell"] # Target platforms to build apps for (Rust target-triple syntax) -targets = ["aarch64-apple-darwin", "x86_64-apple-darwin", "x86_64-unknown-linux-gnu", "x86_64-pc-windows-msvc"] +targets = ["aarch64-apple-darwin", "aarch64-unknown-linux-gnu", "x86_64-apple-darwin", "x86_64-unknown-linux-gnu", "x86_64-pc-windows-msvc"] # Which actions to run on pull requests pr-run-mode = "plan" # Path that installers should place binaries in diff --git a/docs/manual.md b/docs/manual.md index c0df2efe8..638ebad69 100644 --- a/docs/manual.md +++ b/docs/manual.md @@ -40,6 +40,7 @@ Welcome to Turso database manual! - [WAL manipulation](#wal-manipulation) - [`libsql_wal_frame_count`](#libsql_wal_frame_count) - [Encryption](#encryption) + - [Vector search](#vector-search) - [CDC](#cdc-early-preview) - [Appendix A: Turso Internals](#appendix-a-turso-internals) - [Frontend](#frontend) @@ -396,7 +397,7 @@ ROLLBACK [ TRANSACTION ] SELECT expression [ FROM table-or-subquery ] [ WHERE condition ] - [ GROU BY expression ] + [ GROUP BY expression ] ``` **Example:** @@ -609,6 +610,167 @@ $ cargo run -- --experimental-encryption \ ``` +## Vector search + +Turso supports vector search for building workloads such as semantic search, recommendation systems, and similarity matching. Vector embeddings can be stored and queried using specialized functions for distance calculations. + +### Vector types + +Turso supports both **dense** and **sparse** vector representations: + +#### Dense vectors + +Dense vectors store a value for every dimension. Turso provides two precision levels: + +* **Float32 dense vectors** (`vector32`): 32-bit floating-point values, suitable for most machine learning embeddings (e.g., OpenAI embeddings, sentence transformers). Uses 4 bytes per dimension. +* **Float64 dense vectors** (`vector64`): 64-bit floating-point values for applications requiring higher precision. Uses 8 bytes per dimension. + +Dense vectors are ideal for embeddings from neural networks where most dimensions contain non-zero values. + +#### Sparse vectors + +Sparse vectors only store non-zero values and their indices, making them memory-efficient for high-dimensional data with many zero values: + +* **Float32 sparse vectors** (`vector32_sparse`): Stores only non-zero 32-bit float values along with their dimension indices. + +Sparse vectors are ideal for TF-IDF representations, bag-of-words models, and other scenarios where most dimensions are zero. + +### Vector functions + +#### Creating and converting vectors + +**`vector32(value)`** + +Converts a text or blob value into a 32-bit dense vector. + +```sql +SELECT vector32('[1.0, 2.0, 3.0]'); +``` + +**`vector32_sparse(value)`** + +Converts a text or blob value into a 32-bit sparse vector. + +```sql +SELECT vector32_sparse('[0.0, 1.5, 0.0, 2.3, 0.0]'); +``` + +**`vector64(value)`** + +Converts a text or blob value into a 64-bit dense vector. + +```sql +SELECT vector64('[1.0, 2.0, 3.0]'); +``` + +**`vector_extract(blob)`** + +Extracts and displays a vector blob as human-readable text. + +```sql +SELECT vector_extract(embedding) FROM documents; +``` + +#### Distance functions + +Turso provides three distance metrics for measuring vector similarity: + +**`vector_distance_cos(v1, v2)`** + +Computes the cosine distance between two vectors. Returns a value between 0 (identical direction) and 2 (opposite direction). Cosine distance is computed as `1 - cosine_similarity`. + +Cosine distance is ideal for: +- Text embeddings where magnitude is less important than direction +- Comparing document similarity + +```sql +SELECT name, vector_distance_cos(embedding, vector32('[0.1, 0.5, 0.3]')) AS distance +FROM documents +ORDER BY distance +LIMIT 10; +``` + +**`vector_distance_l2(v1, v2)`** + +Computes the Euclidean (L2) distance between two vectors. Returns the straight-line distance in n-dimensional space. + +L2 distance is ideal for: +- Image embeddings where absolute differences matter +- Spatial data and geometric problems +- When embeddings are not normalized + +```sql +SELECT name, vector_distance_l2(embedding, vector32('[0.1, 0.5, 0.3]')) AS distance +FROM documents +ORDER BY distance +LIMIT 10; +``` + +**`vector_distance_jaccard(v1, v2)`** + +Computes the weighted Jaccard distance between two vectors, measuring dissimilarity based on the ratio of minimum to maximum values across dimensions. Note that this is different from the ordinary Jaccard distance, which is defined only for binary vectors. + +Weighted Jaccard distance is ideal for: +- Sparse vectors with many zero values +- Set-like comparisons +- TF-IDF and bag-of-words representations + +```sql +SELECT name, vector_distance_jaccard(sparse_embedding, vector32_sparse('[0.0, 1.0, 0.0, 2.0]')) AS distance +FROM documents +ORDER BY distance +LIMIT 10; +``` + +#### Utility functions + +**`vector_concat(v1, v2)`** + +Concatenates two vectors into a single vector. The resulting vector has dimensions equal to the sum of both input vectors. + +```sql +SELECT vector_concat(vector32('[1.0, 2.0]'), vector32('[3.0, 4.0]')); +-- Results in a 4-dimensional vector: [1.0, 2.0, 3.0, 4.0] +``` + +**`vector_slice(vector, start_index, end_index)`** + +Extracts a slice of a vector from `start_index` to `end_index` (exclusive). + +```sql +SELECT vector_slice(vector32('[1.0, 2.0, 3.0, 4.0, 5.0]'), 1, 4); +-- Results in: [2.0, 3.0, 4.0] +``` + +### Example: Semantic search + +Here's a complete example of building a semantic search system: + +```sql +-- Create a table for documents with embeddings +CREATE TABLE documents ( + id INTEGER PRIMARY KEY, + name TEXT, + content TEXT, + embedding BLOB +); + +-- Insert documents with precomputed embeddings +INSERT INTO documents (name, content, embedding) VALUES + ('Doc 1', 'Machine learning basics', vector32('[0.2, 0.5, 0.1, 0.8]')), + ('Doc 2', 'Database fundamentals', vector32('[0.1, 0.3, 0.9, 0.2]')), + ('Doc 3', 'Neural networks guide', vector32('[0.3, 0.6, 0.2, 0.7]')); + +-- Find documents similar to a query embedding +SELECT + name, + content, + vector_distance_cos(embedding, vector32('[0.25, 0.55, 0.15, 0.75]')) AS similarity +FROM documents +ORDER BY similarity +LIMIT 5; +``` + ## CDC (Early Preview) Turso supports [Change Data Capture](https://en.wikipedia.org/wiki/Change_data_capture), a powerful pattern for tracking and recording changes to your database in real-time. Instead of periodically scanning tables to find what changed, CDC automatically logs every insert, update, and delete as it happens per connection. diff --git a/macros/src/atomic_enum.rs b/macros/src/atomic_enum.rs new file mode 100644 index 000000000..14dd97ac0 --- /dev/null +++ b/macros/src/atomic_enum.rs @@ -0,0 +1,290 @@ +use proc_macro::TokenStream; +use quote::quote; +use syn::{parse_macro_input, Data, DeriveInput, Fields, Type}; + +pub(crate) fn derive_atomic_enum_inner(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + let name = &input.ident; + let atomic_name = syn::Ident::new(&format!("Atomic{name}"), name.span()); + + let variants = match &input.data { + Data::Enum(data) => &data.variants, + _ => { + return syn::Error::new_spanned(input, "AtomicEnum can only be derived for enums") + .to_compile_error() + .into(); + } + }; + + // get info about variants to determine how we have to encode them + let mut has_bool_field = false; + let mut has_u8_field = false; + let mut max_discriminant = 0u8; + + for (idx, variant) in variants.iter().enumerate() { + max_discriminant = idx as u8; + match &variant.fields { + Fields::Unit => {} + Fields::Named(fields) if fields.named.len() == 1 => { + let field = &fields.named[0]; + if is_bool_type(&field.ty) { + has_bool_field = true; + } else if is_u8_or_i8_type(&field.ty) { + has_u8_field = true; + } else { + return syn::Error::new_spanned( + field, + "AtomicEnum only supports bool, u8, or i8 fields", + ) + .to_compile_error() + .into(); + } + } + Fields::Unnamed(fields) if fields.unnamed.len() == 1 => { + let field = &fields.unnamed[0]; + if is_bool_type(&field.ty) { + has_bool_field = true; + } else if is_u8_or_i8_type(&field.ty) { + has_u8_field = true; + } else { + return syn::Error::new_spanned( + field, + "AtomicEnum only supports bool, u8, or i8 fields", + ) + .to_compile_error() + .into(); + } + } + _ => { + return syn::Error::new_spanned( + variant, + "AtomicEnum only supports unit variants or variants with a single field", + ) + .to_compile_error() + .into(); + } + } + } + + let (storage_type, atomic_type) = if has_u8_field || (has_bool_field && max_discriminant > 127) + { + // Need u16: 8 bits for discriminant, 8 bits for data + (quote! { u16 }, quote! { ::std::sync::atomic::AtomicU16 }) + } else { + // Can use u8: 7 bits for discriminant, 1 bit for bool (if any) + (quote! { u8 }, quote! { ::std::sync::atomic::AtomicU8 }) + }; + + let use_u16 = has_u8_field || (has_bool_field && max_discriminant > 127); + + let to_storage = variants.iter().enumerate().map(|(idx, variant)| { + let var_name = &variant.ident; + let disc = idx as u8; // The discriminant here is just the variant's index + + match &variant.fields { + // Simple unit variant, just store the discriminant + Fields::Unit => { + if use_u16 { + quote! { #name::#var_name => #disc as u16 } + } else { + quote! { #name::#var_name => #disc } + } + } + Fields::Named(fields) => { + // Named field variant like `Write { schema_did_change: bool }` + let field = &fields.named[0]; + let field_name = &field.ident; + + if is_bool_type(&field.ty) { + if use_u16 { + // Pack as: [discriminant_byte | bool_as_byte] + // Example: Write {true} with disc=3 becomes: b100000011 + quote! { + #name::#var_name { ref #field_name } => { + (#disc as u16) | ((*#field_name as u16) << 8) + } + } + } else { + // Same as above but with u8, so only 1 bit for bool + // Example: Write{true} with disc=3 becomes: b10000011 + quote! { + #name::#var_name { ref #field_name } => { + #disc | ((*#field_name as u8) << 7) + } + } + } + } else { + // u8/i8 field always uses u16 to have enough bits + // Pack as: [discriminant_byte | value_byte] + quote! { + #name::#var_name { ref #field_name } => { + (#disc as u16) | ((*#field_name as u16) << 8) + } + } + } + } + Fields::Unnamed(_) => { + // same strategy as above, but for tuple variants like `Write(bool)` + if is_bool_type(&variant.fields.iter().next().unwrap().ty) { + if use_u16 { + quote! { + #name::#var_name(ref val) => { + (#disc as u16) | ((*val as u16) << 8) + } + } + } else { + quote! { + #name::#var_name(ref val) => { + #disc | ((*val as u8) << 7) + } + } + } + } else { + quote! { + #name::#var_name(ref val) => { + (#disc as u16) | ((*val as u16) << 8) + } + } + } + } + } + }); + + // Generate the match arms for decoding the storage representation back to enum + let from_storage = variants.iter().enumerate().map(|(idx, variant)| { + let var_name = &variant.ident; + let disc = idx as u8; + + match &variant.fields { + Fields::Unit => quote! { #disc => #name::#var_name }, + Fields::Named(fields) => { + let field = &fields.named[0]; + let field_name = &field.ident; + + if is_bool_type(&field.ty) { + if use_u16 { + // Extract bool from high byte: check if non-zero + quote! { + #disc => #name::#var_name { + #field_name: (val >> 8) != 0 + } + } + } else { + // check single bool value at bit 7 + quote! { + #disc => #name::#var_name { + #field_name: (val & 0x80) != 0 + } + } + } + } else { + quote! { + #disc => #name::#var_name { + // Extract u8/i8 from high byte and cast to appropriate type + #field_name: (val >> 8) as _ + } + } + } + } + Fields::Unnamed(_) => { + if is_bool_type(&variant.fields.iter().next().unwrap().ty) { + if use_u16 { + quote! { #disc => #name::#var_name((val >> 8) != 0) } + } else { + quote! { #disc => #name::#var_name((val & 0x80) != 0) } + } + } else { + quote! { #disc => #name::#var_name((val >> 8) as _) } + } + } + } + }); + + let discriminant_mask = if use_u16 { + quote! { 0xFF } + } else { + quote! { 0x7F } + }; + let to_storage_arms_copy = to_storage.clone(); + + let expanded = quote! { + #[derive(Debug)] + /// Atomic wrapper for #name + pub struct #atomic_name(#atomic_type); + + impl #atomic_name { + /// Encode enum into storage representation + /// Discriminant in lower bits, field data in upper bits + #[inline] + fn to_storage(val: &#name) -> #storage_type { + match val { + #(#to_storage_arms_copy),* + } + } + + /// Decode storage representation into enum + /// Panics on invalid discriminant + #[inline] + fn from_storage(val: #storage_type) -> #name { + let discriminant = (val & #discriminant_mask) as u8; + match discriminant { + #(#from_storage,)* + _ => panic!(concat!("Invalid ", stringify!(#name), " discriminant: {}"), discriminant), + } + } + + /// Create new atomic enum with initial value + #[inline] + pub const fn new(val: #name) -> Self { + // Can't call to_storage in const context, so inline it + let storage = match val { + #(#to_storage),* + }; + Self(#atomic_type::new(storage)) + } + + #[inline] + /// Load and convert the current value to expected enum + pub fn get(&self) -> #name { + Self::from_storage(self.0.load(::std::sync::atomic::Ordering::SeqCst)) + } + + #[inline] + /// Convert and store new value + pub fn set(&self, val: #name) { + self.0.store(Self::to_storage(&val), ::std::sync::atomic::Ordering::SeqCst) + } + + #[inline] + /// Store new value and return previous value + pub fn swap(&self, val: #name) -> #name { + let prev = self.0.swap(Self::to_storage(&val), ::std::sync::atomic::Ordering::SeqCst); + Self::from_storage(prev) + } + } + + impl From<#name> for #atomic_name { + fn from(val: #name) -> Self { + Self::new(val) + } + } + }; + + TokenStream::from(expanded) +} + +fn is_bool_type(ty: &Type) -> bool { + if let Type::Path(path) = ty { + path.path.is_ident("bool") + } else { + false + } +} + +fn is_u8_or_i8_type(ty: &Type) -> bool { + if let Type::Path(path) = ty { + path.path.is_ident("u8") || path.path.is_ident("i8") + } else { + false + } +} diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 8df01da22..9db89bad2 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -1,5 +1,6 @@ mod ext; extern crate proc_macro; +mod atomic_enum; use proc_macro::{token_stream::IntoIter, Group, TokenStream, TokenTree}; use std::collections::HashMap; @@ -464,3 +465,30 @@ pub fn derive_vfs_module(input: TokenStream) -> TokenStream { pub fn match_ignore_ascii_case(input: TokenStream) -> TokenStream { ext::match_ignore_ascci_case(input) } + +/// Derive macro for creating atomic wrappers for enums +/// +/// Supports: +/// - Unit variants +/// - Variants with single bool/u8/i8 fields +/// - Named or unnamed fields +/// +/// Algorithm: +/// - Uses u8 representation, splitting bits for variant discriminant and field data +/// - For bool fields: high bit for bool, lower 7 bits for discriminant +/// - For u8/i8 fields: uses u16 internally (8 bits discriminant, 8 bits data) +/// +/// Example: +/// ```ignore +/// #[derive(AtomicEnum)] +/// enum TransactionState { +/// Write { schema_did_change: bool }, +/// Read, +/// PendingUpgrade, +/// None, +/// } +/// ``` +#[proc_macro_derive(AtomicEnum)] +pub fn derive_atomic_enum(input: TokenStream) -> TokenStream { + atomic_enum::derive_atomic_enum_inner(input) +} diff --git a/parser/Cargo.toml b/parser/Cargo.toml index 6f9720bc8..a140f4e44 100644 --- a/parser/Cargo.toml +++ b/parser/Cargo.toml @@ -13,6 +13,7 @@ name = "turso_parser" [features] default = [] serde = ["dep:serde", "bitflags/serde"] +simulator = [] [dependencies] bitflags = { workspace = true } diff --git a/parser/src/ast.rs b/parser/src/ast.rs index ced93b722..5f4449afb 100644 --- a/parser/src/ast.rs +++ b/parser/src/ast.rs @@ -542,6 +542,17 @@ impl Expr { pub fn raise(resolve_type: ResolveType, expr: Option) -> Expr { Expr::Raise(resolve_type, expr.map(Box::new)) } + + pub fn can_be_null(&self) -> bool { + // todo: better handling columns. Check sqlite3ExprCanBeNull + match self { + Expr::Literal(literal) => !matches!( + literal, + Literal::Numeric(_) | Literal::String(_) | Literal::Blob(_) + ), + _ => true, + } + } } /// SQL literal @@ -1121,6 +1132,11 @@ pub struct NamedColumnConstraint { // https://sqlite.org/syntax/column-constraint.html #[derive(Clone, Debug, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "simulator", derive(strum::EnumDiscriminants))] +#[cfg_attr( + feature = "simulator", + strum_discriminants(derive(strum::VariantArray)) +)] pub enum ColumnConstraint { /// `PRIMARY KEY` PrimaryKey { @@ -1416,6 +1432,8 @@ pub enum PragmaName { Encoding, /// Current free page count. FreelistCount, + /// Enable or disable foreign key constraint enforcement + ForeignKeys, /// Run integrity check on the database file IntegrityCheck, /// `journal_mode` pragma @@ -1449,6 +1467,8 @@ pub enum PragmaName { UserVersion, /// trigger a checkpoint to run on database(s) if WAL is enabled WalCheckpoint, + /// Sets or queries the threshold (in bytes) at which MVCC triggers an automatic checkpoint. + MvccCheckpointThreshold, } /// `CREATE TRIGGER` time diff --git a/parser/src/error.rs b/parser/src/error.rs index 63d4aef78..26353ffb6 100644 --- a/parser/src/error.rs +++ b/parser/src/error.rs @@ -6,45 +6,101 @@ use crate::token::TokenType; #[diagnostic()] pub enum Error { /// Lexer error - #[error("unrecognized token at {0:?}")] - UnrecognizedToken(#[label("here")] miette::SourceSpan), + #[error("unrecognized token '{token_text}' at offset {offset}")] + UnrecognizedToken { + #[label("here")] + span: miette::SourceSpan, + token_text: String, + offset: usize, + }, /// Missing quote or double-quote or backtick - #[error("non-terminated literal at {0:?}")] - UnterminatedLiteral(#[label("here")] miette::SourceSpan), + #[error("non-terminated literal '{token_text}' at offset {offset}")] + UnterminatedLiteral { + #[label("here")] + span: miette::SourceSpan, + token_text: String, + offset: usize, + }, /// Missing `]` - #[error("non-terminated bracket at {0:?}")] - UnterminatedBracket(#[label("here")] miette::SourceSpan), + #[error("non-terminated bracket '{token_text}' at offset {offset}")] + UnterminatedBracket { + #[label("here")] + span: miette::SourceSpan, + token_text: String, + offset: usize, + }, + /// Missing `*/` + #[error("non-terminated block comment '{token_text}' at offset {offset}")] + UnterminatedBlockComment { + #[label("here")] + span: miette::SourceSpan, + token_text: String, + offset: usize, + }, /// Invalid parameter name - #[error("bad variable name at {0:?}")] - BadVariableName(#[label("here")] miette::SourceSpan), + #[error("bad variable name '{token_text}' at offset {offset}")] + BadVariableName { + #[label("here")] + span: miette::SourceSpan, + token_text: String, + offset: usize, + }, /// Invalid number format - #[error("bad number at {0:?}")] - BadNumber(#[label("here")] miette::SourceSpan), + #[error("bad number '{token_text}' at offset {offset}")] + BadNumber { + #[label("here")] + span: miette::SourceSpan, + token_text: String, + offset: usize, + }, // Bad fractional part of a number - #[error("bad fractional part at {0:?}")] - BadFractionalPart(#[label("here")] miette::SourceSpan), + #[error("bad fractional part '{token_text}' at offset {offset}")] + BadFractionalPart { + #[label("here")] + span: miette::SourceSpan, + token_text: String, + offset: usize, + }, // Bad exponent part of a number - #[error("bad exponent part at {0:?}")] - BadExponentPart(#[label("here")] miette::SourceSpan), + #[error("bad exponent part '{token_text}' at offset {offset}")] + BadExponentPart { + #[label("here")] + span: miette::SourceSpan, + token_text: String, + offset: usize, + }, /// Invalid or missing sign after `!` - #[error("expected = sign at {0:?}")] - ExpectedEqualsSign(#[label("here")] miette::SourceSpan), + #[error("expected = sign '{token_text}' at offset {offset}")] + ExpectedEqualsSign { + #[label("here")] + span: miette::SourceSpan, + token_text: String, + offset: usize, + }, /// Hexadecimal integer literals follow the C-language notation of "0x" or "0X" followed by hexadecimal digits. - #[error("malformed hex integer at {0:?}")] - MalformedHexInteger(#[label("here")] miette::SourceSpan), + #[error("malformed hex integer '{token_text}' at offset {offset}")] + MalformedHexInteger { + #[label("here")] + span: miette::SourceSpan, + token_text: String, + offset: usize, + }, // parse errors // Unexpected end of file #[error("unexpected end of file")] ParseUnexpectedEOF, // Unexpected token - #[error("unexpected token at {parsed_offset:?}")] - #[diagnostic(help("expected {expected:?} but found {got:?}"))] + #[error("unexpected token '{token_text}' at offset {offset}")] + #[diagnostic(help("expected {expected_display} but found '{token_text}'"))] ParseUnexpectedToken { #[label("here")] parsed_offset: miette::SourceSpan, got: TokenType, expected: &'static [TokenType], + token_text: String, + offset: usize, + expected_display: String, }, // Custom error message #[error("{0}")] diff --git a/parser/src/lexer.rs b/parser/src/lexer.rs index 0876e4103..33429917a 100644 --- a/parser/src/lexer.rs +++ b/parser/src/lexer.rs @@ -297,14 +297,27 @@ impl<'a> Lexer<'a> { if start == self.offset { // before the underscore, there was no digit - return Err(Error::BadNumber((start, self.offset - start).into())); + let token_text = + String::from_utf8_lossy(&self.input[start..self.offset]).to_string(); + return Err(Error::BadNumber { + span: (start, self.offset - start).into(), + token_text, + offset: start, + }); } match self.peek() { Some(b) if b.is_ascii_digit() => continue, // Continue if next is a digit _ => { // after the underscore, there is no digit - return Err(Error::BadNumber((start, self.offset - start).into())); + let token_text = + String::from_utf8_lossy(&self.input[start..self.offset]) + .to_string(); + return Err(Error::BadNumber { + span: (start, self.offset - start).into(), + token_text, + offset: start, + }); } } } @@ -321,7 +334,13 @@ impl<'a> Lexer<'a> { Some(b'_') => { if start == self.offset { // before the underscore, there was no digit - return Err(Error::BadNumber((start, self.offset - start).into())); + let token_text = + String::from_utf8_lossy(&self.input[start..self.offset]).to_string(); + return Err(Error::BadNumber { + span: (start, self.offset - start).into(), + token_text, + offset: start, + }); } self.eat_and_assert(|b| b == b'_'); @@ -329,7 +348,14 @@ impl<'a> Lexer<'a> { Some(b) if b.is_ascii_hexdigit() => continue, // Continue if next is a digit _ => { // after the underscore, there is no digit - return Err(Error::BadNumber((start, self.offset - start).into())); + let token_text = + String::from_utf8_lossy(&self.input[start..self.offset]) + .to_string(); + return Err(Error::BadNumber { + span: (start, self.offset - start).into(), + token_text, + offset: start, + }); } } } @@ -514,9 +540,13 @@ impl<'a> Lexer<'a> { self.eat_and_assert(|b| b == b'='); } _ => { - return Err(Error::ExpectedEqualsSign( - (start, self.offset - start).into(), - )) + let token_text = + String::from_utf8_lossy(&self.input[start..self.offset]).to_string(); + return Err(Error::ExpectedEqualsSign { + span: (start, self.offset - start).into(), + token_text, + offset: start, + }); } } @@ -567,9 +597,13 @@ impl<'a> Lexer<'a> { } } None => { - return Err(Error::UnterminatedLiteral( - (start, self.offset - start).into(), - )) + let token_text = + String::from_utf8_lossy(&self.input[start..self.offset]).to_string(); + return Err(Error::UnterminatedLiteral { + span: (start, self.offset - start).into(), + token_text, + offset: start, + }); } _ => unreachable!(), }; @@ -598,9 +632,15 @@ impl<'a> Lexer<'a> { token_type: Some(TokenType::TK_FLOAT), }) } - Some(b) if is_identifier_start(b) => Err(Error::BadFractionalPart( - (start, self.offset - start).into(), - )), + Some(b) if is_identifier_start(b) => { + let token_text = + String::from_utf8_lossy(&self.input[start..self.offset]).to_string(); + Err(Error::BadFractionalPart { + span: (start, self.offset - start).into(), + token_text, + offset: start, + }) + } _ => Ok(Token { value: &self.input[start..self.offset], token_type: Some(TokenType::TK_FLOAT), @@ -627,11 +667,21 @@ impl<'a> Lexer<'a> { let start_num = self.offset; self.eat_while_number_digit()?; if start_num == self.offset { - return Err(Error::BadExponentPart((start, self.offset - start).into())); + let token_text = String::from_utf8_lossy(&self.input[start..self.offset]).to_string(); + return Err(Error::BadExponentPart { + span: (start, self.offset - start).into(), + token_text, + offset: start, + }); } if self.peek().is_some() && is_identifier_start(self.peek().unwrap()) { - return Err(Error::BadExponentPart((start, self.offset - start).into())); + let token_text = String::from_utf8_lossy(&self.input[start..self.offset]).to_string(); + return Err(Error::BadExponentPart { + span: (start, self.offset - start).into(), + token_text, + offset: start, + }); } Ok(Token { @@ -654,13 +704,23 @@ impl<'a> Lexer<'a> { self.eat_while_number_hexdigit()?; if start_hex == self.offset { - return Err(Error::MalformedHexInteger( - (start, self.offset - start).into(), - )); + let token_text = + String::from_utf8_lossy(&self.input[start..self.offset]).to_string(); + return Err(Error::MalformedHexInteger { + span: (start, self.offset - start).into(), + token_text, + offset: start, + }); } if self.peek().is_some() && is_identifier_start(self.peek().unwrap()) { - return Err(Error::BadNumber((start, self.offset - start).into())); + let token_text = + String::from_utf8_lossy(&self.input[start..self.offset]).to_string(); + return Err(Error::BadNumber { + span: (start, self.offset - start).into(), + token_text, + offset: start, + }); } return Ok(Token { @@ -689,7 +749,13 @@ impl<'a> Lexer<'a> { }) } Some(b) if is_identifier_start(b) => { - Err(Error::BadNumber((start, self.offset - start).into())) + let token_text = + String::from_utf8_lossy(&self.input[start..self.offset]).to_string(); + Err(Error::BadNumber { + span: (start, self.offset - start).into(), + token_text, + offset: start, + }) } _ => Ok(Token { value: &self.input[start..self.offset], @@ -710,9 +776,15 @@ impl<'a> Lexer<'a> { token_type: Some(TokenType::TK_ID), }) } - None => Err(Error::UnterminatedBracket( - (start, self.offset - start).into(), - )), + None => { + let token_text = + String::from_utf8_lossy(&self.input[start..self.offset]).to_string(); + Err(Error::UnterminatedBracket { + span: (start, self.offset - start).into(), + token_text, + offset: start, + }) + } _ => unreachable!(), // We should not reach here } } @@ -737,7 +809,13 @@ impl<'a> Lexer<'a> { // empty variable name if start_id == self.offset { - return Err(Error::BadVariableName((start, self.offset - start).into())); + let token_text = + String::from_utf8_lossy(&self.input[start..self.offset]).to_string(); + return Err(Error::BadVariableName { + span: (start, self.offset - start).into(), + token_text, + offset: start, + }); } Ok(Token { @@ -767,9 +845,14 @@ impl<'a> Lexer<'a> { self.eat_and_assert(|b| b == b'\''); if (end_hex - start_hex) % 2 != 0 { - return Err(Error::UnrecognizedToken( - (start, self.offset - start).into(), - )); + let token_text = + String::from_utf8_lossy(&self.input[start..self.offset]) + .to_string(); + return Err(Error::UnrecognizedToken { + span: (start, self.offset - start).into(), + token_text, + offset: start, + }); } Ok(Token { @@ -777,9 +860,15 @@ impl<'a> Lexer<'a> { token_type: Some(TokenType::TK_BLOB), }) } - _ => Err(Error::UnterminatedLiteral( - (start, self.offset - start).into(), - )), + _ => { + let token_text = + String::from_utf8_lossy(&self.input[start..self.offset]).to_string(); + Err(Error::UnterminatedLiteral { + span: (start, self.offset - start).into(), + token_text, + offset: start, + }) + } } } _ => { @@ -796,9 +885,12 @@ impl<'a> Lexer<'a> { fn eat_unrecognized(&mut self) -> Result> { let start = self.offset; self.eat_while(|b| b.is_some() && !b.unwrap().is_ascii_whitespace()); - Err(Error::UnrecognizedToken( - (start, self.offset - start).into(), - )) + let token_text = String::from_utf8_lossy(&self.input[start..self.offset]).to_string(); + Err(Error::UnrecognizedToken { + span: (start, self.offset - start).into(), + token_text, + offset: start, + }) } } diff --git a/parser/src/parser.rs b/parser/src/parser.rs index b4188df44..02a57fadc 100644 --- a/parser/src/parser.rs +++ b/parser/src/parser.rs @@ -28,12 +28,17 @@ macro_rules! peek_expect { match (TK_ID, tt.fallback_id_if_ok()) { $(($x, TK_ID) => token,)* _ => { + let token_text = String::from_utf8_lossy(token.value).to_string(); + let offset = $parser.offset(); return Err(Error::ParseUnexpectedToken { parsed_offset: ($parser.offset(), token_len).into(), expected: &[ $($x,)* ], got: tt, + token_text: token_text.clone(), + offset, + expected_display: crate::token::TokenType::format_expected_tokens(&[$($x,)*]), }) } } @@ -242,10 +247,17 @@ impl<'a> Parser<'a> { Some(token) => { if !found_semi { let tt = token.token_type.unwrap(); + let token_text = String::from_utf8_lossy(token.value).to_string(); + let offset = self.offset(); return Err(Error::ParseUnexpectedToken { - parsed_offset: (self.offset(), 1).into(), + parsed_offset: (offset, 1).into(), expected: &[TK_SEMI], got: tt, + token_text: token_text.clone(), + offset, + expected_display: crate::token::TokenType::format_expected_tokens(&[ + TK_SEMI, + ]), }); } @@ -1495,10 +1507,18 @@ impl<'a> Parser<'a> { Some(self.parse_nm()?) } else if tok.token_type == Some(TK_LP) { if can_be_lit_str { + let token = self.peek_no_eof()?; + let token_text = String::from_utf8_lossy(token.value).to_string(); + let offset = self.offset(); return Err(Error::ParseUnexpectedToken { parsed_offset: (self.offset() - name.len(), name.len()).into(), got: TK_STRING, expected: &[TK_ID, TK_INDEXED, TK_JOIN_KW], + token_text: token_text.clone(), + offset, + expected_display: crate::token::TokenType::format_expected_tokens( + &[TK_ID, TK_INDEXED, TK_JOIN_KW], + ), }); } // can not be literal string in function name @@ -1723,11 +1743,23 @@ impl<'a> Parser<'a> { _ => { let exprs = self.parse_expr_list()?; eat_expect!(self, TK_RP); - Box::new(Expr::InList { - lhs: result, - not, - rhs: exprs, - }) + // Expressions in the form: + // lhs IN () + // lhs NOT IN () + // can be simplified to constants 0 (false) and 1 (true), respectively. + // + // todo: should check if lhs has a function. If so, this optimization cannot + // be done. + if exprs.is_empty() { + let name = if not { "1" } else { "0" }; + Box::new(Expr::Literal(Literal::Numeric(name.into()))) + } else { + Box::new(Expr::InList { + lhs: result, + rhs: exprs, + not, + }) + } } } } @@ -3108,9 +3140,17 @@ impl<'a> Parser<'a> { TK_NULL | TK_BLOB | TK_STRING | TK_FLOAT | TK_INTEGER | TK_CTIME_KW => { Ok(ColumnConstraint::Default(self.parse_term()?)) } - _ => Ok(ColumnConstraint::Default(Box::new(Expr::Id( - self.parse_nm()?, - )))), + _ => { + let name = self.parse_nm()?; + let expr = if name.as_str().eq_ignore_ascii_case("true") { + Expr::Literal(Literal::Numeric("1".into())) + } else if name.as_str().eq_ignore_ascii_case("false") { + Expr::Literal(Literal::Numeric("0".into())) + } else { + Expr::Id(name) + }; + Ok(ColumnConstraint::Default(Box::new(expr))) + } } } diff --git a/parser/src/token.rs b/parser/src/token.rs index ed8f416c5..0f0719741 100644 --- a/parser/src/token.rs +++ b/parser/src/token.rs @@ -548,4 +548,47 @@ impl TokenType { _ => self, } } + + /// Get user-friendly display name for error messages + pub fn user_friendly_name(&self) -> &'static str { + match self.as_str() { + Some(s) => s, + None => match self { + TokenType::TK_ID => "identifier", + TokenType::TK_STRING => "string", + TokenType::TK_INTEGER => "integer", + TokenType::TK_FLOAT => "float", + TokenType::TK_BLOB => "blob", + TokenType::TK_VARIABLE => "variable", + TokenType::TK_ILLEGAL => "illegal token", + TokenType::TK_EOF => "end of file", + TokenType::TK_LIKE_KW => "LIKE", + TokenType::TK_JOIN_KW => "JOIN", + TokenType::TK_CTIME_KW => "datetime function", + TokenType::TK_ISNOT => "IS NOT", + TokenType::TK_ISNULL => "ISNULL", + TokenType::TK_NOTNULL => "NOTNULL", + TokenType::TK_PTR => "->", + _ => "unknown token", + }, + } + } + + /// Format multiple tokens for error messages + pub fn format_expected_tokens(tokens: &[TokenType]) -> String { + if tokens.is_empty() { + return "nothing".to_string(); + } + if tokens.len() == 1 { + return tokens[0].user_friendly_name().to_string(); + } + + let names: Vec<&str> = tokens.iter().map(|t| t.user_friendly_name()).collect(); + if names.len() == 2 { + format!("{} or {}", names[0], names[1]) + } else { + let (last, rest) = names.split_last().unwrap(); + format!("{}, or {}", rest.join(", "), last) + } + } } diff --git a/perf/throughput/rusqlite/scripts/bench.sh b/perf/throughput/rusqlite/scripts/bench.sh index a4efcc6e5..7d9f13207 100755 --- a/perf/throughput/rusqlite/scripts/bench.sh +++ b/perf/throughput/rusqlite/scripts/bench.sh @@ -6,6 +6,7 @@ echo "system,threads,batch_size,compute,throughput" for threads in 1 2 4 8; do for compute in 0 100 500 1000; do + rm -f write_throughput_test.db* ../../../target/release/write-throughput-sqlite --threads ${threads} --batch-size 100 --compute ${compute} -i 1000 done done diff --git a/perf/throughput/rusqlite/src/main.rs b/perf/throughput/rusqlite/src/main.rs index c926f2a44..e709e8d87 100644 --- a/perf/throughput/rusqlite/src/main.rs +++ b/perf/throughput/rusqlite/src/main.rs @@ -94,6 +94,7 @@ fn setup_database(db_path: &str) -> Result { conn.pragma_update(None, "journal_mode", "WAL")?; conn.pragma_update(None, "synchronous", "FULL")?; + conn.pragma_update(None, "fullfsync", "true")?; conn.execute( "CREATE TABLE IF NOT EXISTS test_table ( @@ -114,15 +115,18 @@ fn worker_thread( start_barrier: Arc, compute_usec: u64, ) -> Result { - let conn = Connection::open(&db_path)?; - - conn.busy_timeout(std::time::Duration::from_secs(30))?; - start_barrier.wait(); let mut total_inserts = 0; for iteration in 0..iterations { + let conn = Connection::open(&db_path)?; + + conn.pragma_update(None, "synchronous", "FULL")?; + conn.pragma_update(None, "fullfsync", "true")?; + + conn.busy_timeout(std::time::Duration::from_secs(30))?; + let mut stmt = conn.prepare("INSERT INTO test_table (id, data) VALUES (?, ?)")?; conn.execute("BEGIN", [])?; diff --git a/perf/throughput/turso/Cargo.toml b/perf/throughput/turso/Cargo.toml index d275cc223..8f958add1 100644 --- a/perf/throughput/turso/Cargo.toml +++ b/perf/throughput/turso/Cargo.toml @@ -7,9 +7,13 @@ edition = "2021" name = "write-throughput" path = "src/main.rs" +[features] +console = ["dep:console-subscriber" ,"tokio/tracing"] + [dependencies] turso = { workspace = true } clap = { workspace = true, features = ["derive"] } tokio = { workspace = true, default-features = true, features = ["full"] } futures = { workspace = true } tracing-subscriber = { workspace = true, features = ["env-filter"] } +console-subscriber = { workspace = true, optional = true } diff --git a/perf/throughput/turso/scripts/bench.sh b/perf/throughput/turso/scripts/bench.sh index 6d6f75d4b..27ad51e0b 100755 --- a/perf/throughput/turso/scripts/bench.sh +++ b/perf/throughput/turso/scripts/bench.sh @@ -6,6 +6,7 @@ echo "system,threads,batch_size,compute,throughput" for threads in 1 2 4 8; do for compute in 0 100 500 1000; do + rm -f write_throughput_test.db* ../../../target/release/write-throughput --threads ${threads} --batch-size 100 --compute ${compute} -i 1000 --mode concurrent done done diff --git a/perf/throughput/turso/src/main.rs b/perf/throughput/turso/src/main.rs index dbe2318f1..6f7ec4a0e 100644 --- a/perf/throughput/turso/src/main.rs +++ b/perf/throughput/turso/src/main.rs @@ -2,7 +2,9 @@ use clap::{Parser, ValueEnum}; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::{Arc, Barrier}; use std::time::{Duration, Instant}; -use tracing_subscriber::EnvFilter; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; +use tracing_subscriber::{EnvFilter, Layer}; use turso::{Builder, Database, Result}; #[derive(Debug, Clone, Copy, ValueEnum)] @@ -53,11 +55,18 @@ struct Args { } fn main() -> Result<()> { - tracing_subscriber::fmt() - .with_env_filter(EnvFilter::from_default_env()) + #[cfg(feature = "console")] + let console_layer = console_subscriber::spawn(); + let fmt_layer = tracing_subscriber::fmt::layer() .with_ansi(false) .with_thread_ids(true) - .init(); + .with_filter(EnvFilter::from_default_env()); + let registry = tracing_subscriber::registry(); + #[cfg(feature = "console")] + let registry = registry.with(console_layer); + let registry = registry.with(fmt_layer); + + registry.init(); let args = Args::parse(); let rt = tokio::runtime::Builder::new_multi_thread() diff --git a/scripts/antithesis/launch.sh b/scripts/antithesis/launch.sh index 2f95c3975..83ec5caf6 100755 --- a/scripts/antithesis/launch.sh +++ b/scripts/antithesis/launch.sh @@ -3,7 +3,7 @@ curl --fail -u "$ANTITHESIS_USER:$ANTITHESIS_PASSWD" \ -X POST https://$ANTITHESIS_TENANT.antithesis.com/api/v1/launch/limbo \ -d "{\"params\": { \"antithesis.description\":\"basic_test on main\", - \"custom.duration\":\"8\", + \"custom.duration\":\"4\", \"antithesis.config_image\":\"$ANTITHESIS_DOCKER_REPO/limbo-config:antithesis-latest\", \"antithesis.images\":\"$ANTITHESIS_DOCKER_REPO/limbo-workload:antithesis-latest\", \"antithesis.report.recipients\":\"$ANTITHESIS_EMAIL\" diff --git a/simulator/Cargo.toml b/simulator/Cargo.toml index 09401f2cc..7fd5dbeff 100644 --- a/simulator/Cargo.toml +++ b/simulator/Cargo.toml @@ -46,3 +46,4 @@ either = "1.15.0" similar = { workspace = true } similar-asserts = { workspace = true } bitmaps = { workspace = true } +bitflags.workspace = true diff --git a/simulator/README.md b/simulator/README.md index 3de0afb99..b3c93e00a 100644 --- a/simulator/README.md +++ b/simulator/README.md @@ -118,6 +118,17 @@ For development purposes, you can run `make sim-schema` to generate a JsonSchema } ``` +## Run simulator using the Miri interpreter + +Miri is a deterministic Rust interpreter designed to identify undefined behavior. To run the simulator under Miri, use +```bash +MIRIFLAGS="-Zmiri-disable-isolation -Zmiri-disable-stacked-borrows" RUST_LOG=limbo_sim=debug cargo +nightly miri run --bin limbo_sim -- --disable-integrity-check +```` +Notes: +- `-Zmiri-disable-isolation` is needed for host access (like opening a file) +- `-Zmiri-disable-stacked-borrows` this alias checking is experimental, so disabled for now +- `--disable-integrity-check` is needed since we can't run sqlite via the FFI in Miri + ## Resources - [(reading) TigerBeetle Deterministic Simulation Testing](https://docs.tigerbeetle.com/about/vopr/) diff --git a/simulator/generation/mod.rs b/simulator/generation/mod.rs index 80a2d0cff..6206c9a62 100644 --- a/simulator/generation/mod.rs +++ b/simulator/generation/mod.rs @@ -1,3 +1,6 @@ +use rand::distr::weighted::WeightedIndex; +use sql_generation::generation::GenerationContext; + use crate::runner::env::ShadowTablesMut; pub mod plan; @@ -17,3 +20,15 @@ pub(crate) trait Shadow { type Result; fn shadow(&self, tables: &mut ShadowTablesMut<'_>) -> Self::Result; } + +pub(super) trait WeightedDistribution { + type Item; + type GenItem; + fn items(&self) -> &[Self::Item]; + fn weights(&self) -> &WeightedIndex; + fn sample( + &self, + rng: &mut R, + context: &C, + ) -> Self::GenItem; +} diff --git a/simulator/generation/plan.rs b/simulator/generation/plan.rs index 3c07d73e3..6c7e66384 100644 --- a/simulator/generation/plan.rs +++ b/simulator/generation/plan.rs @@ -11,12 +11,11 @@ use indexmap::IndexSet; use serde::{Deserialize, Serialize}; use sql_generation::{ - generation::{Arbitrary, ArbitraryFrom, GenerationContext, frequency, query::SelectFree}, + generation::{Arbitrary, ArbitraryFrom, GenerationContext, frequency}, model::{ query::{ - Create, CreateIndex, Delete, Drop, Insert, Select, + Create, transaction::{Begin, Commit}, - update::Update, }, table::SimValue, }, @@ -26,7 +25,11 @@ use turso_core::{Connection, Result, StepResult}; use crate::{ SimulatorEnv, - generation::Shadow, + generation::{ + Shadow, WeightedDistribution, + property::PropertyDistribution, + query::{QueryDistribution, possible_queries}, + }, model::Query, runner::env::{ShadowTablesMut, SimConnection, SimulationType}, }; @@ -55,14 +58,17 @@ impl InteractionPlan { pub fn new_with(plan: Vec, mvcc: bool) -> Self { let len = plan .iter() - .filter(|interaction| !interaction.is_transaction()) + .filter(|interaction| !interaction.ignore()) .count(); Self { plan, mvcc, len } } #[inline] - pub fn plan(&self) -> &[Interactions] { - &self.plan + fn new_len(&self) -> usize { + self.plan + .iter() + .filter(|interaction| !interaction.ignore()) + .count() } /// Length of interactions that are not transaction statements @@ -72,12 +78,59 @@ impl InteractionPlan { } pub fn push(&mut self, interactions: Interactions) { - if !interactions.is_transaction() { + if !interactions.ignore() { self.len += 1; } self.plan.push(interactions); } + pub fn remove(&mut self, index: usize) -> Interactions { + let interactions = self.plan.remove(index); + if !interactions.ignore() { + self.len -= 1; + } + interactions + } + + pub fn truncate(&mut self, len: usize) { + self.plan.truncate(len); + self.len = self.new_len(); + } + + pub fn retain_mut(&mut self, mut f: F) + where + F: FnMut(&mut Interactions) -> bool, + { + let f = |t: &mut Interactions| { + let ignore = t.ignore(); + let retain = f(t); + // removed an interaction that was not previously ignored + if !retain && !ignore { + self.len -= 1; + } + retain + }; + self.plan.retain_mut(f); + } + + #[expect(dead_code)] + pub fn retain(&mut self, mut f: F) + where + F: FnMut(&Interactions) -> bool, + { + let f = |t: &Interactions| { + let ignore = t.ignore(); + let retain = f(t); + // removed an interaction that was not previously ignored + if !retain && !ignore { + self.len -= 1; + } + retain + }; + self.plan.retain(f); + self.len = self.new_len(); + } + /// Compute via diff computes a a plan from a given `.plan` file without the need to parse /// sql. This is possible because there are two versions of the plan file, one that is human /// readable and one that is serialized as JSON. Under watch mode, the users will be able to @@ -167,18 +220,7 @@ impl InteractionPlan { } pub(crate) fn stats(&self) -> InteractionStats { - let mut stats = InteractionStats { - select_count: 0, - insert_count: 0, - delete_count: 0, - update_count: 0, - create_count: 0, - create_index_count: 0, - drop_count: 0, - begin_count: 0, - commit_count: 0, - rollback_count: 0, - }; + let mut stats = InteractionStats::default(); fn query_stat(q: &Query, stats: &mut InteractionStats) { match q { @@ -192,11 +234,19 @@ impl InteractionPlan { Query::Begin(_) => stats.begin_count += 1, Query::Commit(_) => stats.commit_count += 1, Query::Rollback(_) => stats.rollback_count += 1, + Query::AlterTable(_) => stats.alter_table_count += 1, + Query::DropIndex(_) => stats.drop_index_count += 1, + Query::Placeholder => {} } } for interactions in &self.plan { match &interactions.interactions { InteractionsType::Property(property) => { + if matches!(property, Property::AllTableHaveExpectedContent { .. }) { + // Skip Property::AllTableHaveExpectedContent when counting stats + // this allows us to generate more relevant interactions as we count less Select's to the Stats + continue; + } for interaction in &property.interactions(interactions.connection_index) { if let InteractionType::Query(query) = &interaction.interaction { query_stat(query, &mut stats); @@ -235,6 +285,28 @@ impl InteractionPlan { env: &mut SimulatorEnv, ) -> Option> { let num_interactions = env.opts.max_interactions as usize; + // If last interaction needs to check all db tables, generate the Property to do so + if let Some(i) = self.plan.last() + && i.check_tables() + { + let check_all_tables = Interactions::new( + i.connection_index, + InteractionsType::Property(Property::AllTableHaveExpectedContent { + tables: env + .connection_context(i.connection_index) + .tables() + .iter() + .map(|t| t.name.clone()) + .collect(), + }), + ); + + let out_interactions = check_all_tables.interactions(); + + self.push(check_all_tables); + return Some(out_interactions); + } + if self.len() < num_interactions { let conn_index = env.choose_conn(rng); let interactions = if self.mvcc && !env.conn_in_transaction(conn_index) { @@ -289,16 +361,7 @@ impl InteractionPlan { self.push(interactions); Some(out_interactions) } else { - // after we generated all interactions if some connection is still in a transaction, commit - (0..env.connections.len()) - .find(|idx| env.conn_in_transaction(*idx)) - .map(|conn_index| { - let query = Query::Commit(Commit); - let interaction = Interactions::new(conn_index, InteractionsType::Query(query)); - let out_interactions = interaction.interactions(); - self.push(interaction); - out_interactions - }) + None } } @@ -310,6 +373,7 @@ impl InteractionPlan { let iter = interactions.into_iter(); PlanGenerator { plan: self, + peek: None, iter, rng, } @@ -379,28 +443,146 @@ impl InteractionPlanIterator for &mut T { pub struct PlanGenerator<'a, R: rand::Rng> { plan: &'a mut InteractionPlan, + peek: Option, iter: as IntoIterator>::IntoIter, rng: &'a mut R, } +impl<'a, R: rand::Rng> PlanGenerator<'a, R> { + fn next_interaction(&mut self, env: &mut SimulatorEnv) -> Option { + self.iter + .next() + .or_else(|| { + // Iterator ended, try to create a new iterator + // This will not be an infinte sequence because generate_next_interaction will eventually + // stop generating + let mut iter = self + .plan + .generate_next_interaction(self.rng, env) + .map_or(Vec::new().into_iter(), |interactions| { + interactions.into_iter() + }); + let next = iter.next(); + self.iter = iter; + + next + }) + .map(|interaction| { + // Certain properties can generate intermediate queries + // we need to generate them here and substitute + if let InteractionType::Query(Query::Placeholder) = &interaction.interaction { + let stats = self.plan.stats(); + + let conn_ctx = env.connection_context(interaction.connection_index); + + let remaining_ = remaining( + env.opts.max_interactions, + &env.profile.query, + &stats, + env.profile.experimental_mvcc, + &conn_ctx, + ); + + let InteractionsType::Property(property) = + &mut self.plan.last_mut().unwrap().interactions + else { + unreachable!("only properties have extensional queries"); + }; + + let queries = possible_queries(conn_ctx.tables()); + let query_distr = QueryDistribution::new(queries, &remaining_); + + let query_gen = property.get_extensional_query_gen_function(); + + let mut count = 0; + let new_query = loop { + if count > 1_000_000 { + panic!("possible infinite loop in query generation"); + } + if let Some(new_query) = + (query_gen)(self.rng, &conn_ctx, &query_distr, property) + { + let queries = property.get_extensional_queries().unwrap(); + let query = queries + .iter_mut() + .find(|query| matches!(query, Query::Placeholder)) + .expect("Placeholder should be present in extensional queries"); + *query = new_query.clone(); + break new_query; + } + count += 1; + }; + Interaction::new( + interaction.connection_index, + InteractionType::Query(new_query), + ) + } else { + interaction + } + }) + } + + fn peek(&mut self, env: &mut SimulatorEnv) -> Option<&Interaction> { + if self.peek.is_none() { + self.peek = self.next_interaction(env); + } + self.peek.as_ref() + } +} + impl<'a, R: rand::Rng> InteractionPlanIterator for PlanGenerator<'a, R> { /// try to generate the next [Interactions] and store it fn next(&mut self, env: &mut SimulatorEnv) -> Option { - self.iter.next().or_else(|| { - // Iterator ended, try to create a new iterator - // This will not be an infinte sequence because generate_next_interaction will eventually - // stop generating - let mut iter = self - .plan - .generate_next_interaction(self.rng, env) - .map_or(Vec::new().into_iter(), |interactions| { - interactions.into_iter() - }); - let next = iter.next(); - self.iter = iter; + let mvcc = self.plan.mvcc; + match self.peek(env) { + Some(peek_interaction) => { + if mvcc && peek_interaction.is_ddl() { + // try to commit a transaction as we cannot execute DDL statements in concurrent mode - next - }) + let commit_connection = (0..env.connections.len()) + .find(|idx| env.conn_in_transaction(*idx)) + .map(|conn_index| { + let query = Query::Commit(Commit); + let interaction = Interactions::new( + conn_index, + InteractionsType::Query(query.clone()), + ); + + // Connections are queued for commit on `generate_next_interaction` if Interactions::Query or Interactions::Property produce a DDL statement. + // This means that the only way we will reach here, is if the DDL statement was created later in the extensional query of a Property + let queries = self + .plan + .last_mut() + .unwrap() + .get_extensional_queries() + .unwrap(); + queries.insert(0, query.clone()); + + self.plan.push(interaction); + + Interaction::new(conn_index, InteractionType::Query(query)) + }); + if commit_connection.is_some() { + return commit_connection; + } + } + + self.peek.take() + } + None => { + // after we generated all interactions if some connection is still in a transaction, commit + (0..env.connections.len()) + .find(|idx| env.conn_in_transaction(*idx)) + .map(|conn_index| { + let query = Query::Commit(Commit); + let interaction = + Interactions::new(conn_index, InteractionsType::Query(query)); + self.plan.push(interaction); + + Interaction::new(conn_index, InteractionType::Query(Query::Commit(Commit))) + }) + } + } } } @@ -448,6 +630,25 @@ impl Interactions { InteractionsType::Query(..) | InteractionsType::Fault(..) => None, } } + + /// Whether the interaction needs to check the database tables + pub fn check_tables(&self) -> bool { + match &self.interactions { + InteractionsType::Property(property) => property.check_tables(), + InteractionsType::Query(..) | InteractionsType::Fault(..) => false, + } + } + + /// Interactions that are not counted/ignored in the InteractionPlan. + /// Used in InteractionPlan to not count certain interactions to its length, as they are just auxiliary. This allows more + /// meaningful interactions to be generation + fn ignore(&self) -> bool { + self.is_transaction() + || matches!( + self.interactions, + InteractionsType::Property(Property::AllTableHaveExpectedContent { .. }) + ) + } } impl Deref for Interactions { @@ -481,14 +682,6 @@ impl InteractionsType { } impl Interactions { - pub(crate) fn name(&self) -> Option<&str> { - match &self.interactions { - InteractionsType::Property(property) => Some(property.name()), - InteractionsType::Query(_) => None, - InteractionsType::Fault(_) => None, - } - } - pub(crate) fn interactions(&self) -> Vec { match &self.interactions { InteractionsType::Property(property) => property.interactions(self.connection_index), @@ -564,36 +757,27 @@ impl Display for InteractionPlan { } } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, Default)] pub(crate) struct InteractionStats { - pub(crate) select_count: u32, - pub(crate) insert_count: u32, - pub(crate) delete_count: u32, - pub(crate) update_count: u32, - pub(crate) create_count: u32, - pub(crate) create_index_count: u32, - pub(crate) drop_count: u32, - pub(crate) begin_count: u32, - pub(crate) commit_count: u32, - pub(crate) rollback_count: u32, -} - -impl InteractionStats { - pub fn total_writes(&self) -> u32 { - self.insert_count - + self.delete_count - + self.update_count - + self.create_count - + self.create_index_count - + self.drop_count - } + pub select_count: u32, + pub insert_count: u32, + pub delete_count: u32, + pub update_count: u32, + pub create_count: u32, + pub create_index_count: u32, + pub drop_count: u32, + pub begin_count: u32, + pub commit_count: u32, + pub rollback_count: u32, + pub alter_table_count: u32, + pub drop_index_count: u32, } impl Display for InteractionStats { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, - "Read: {}, Write: {}, Delete: {}, Update: {}, Create: {}, CreateIndex: {}, Drop: {}, Begin: {}, Commit: {}, Rollback: {}", + "Read: {}, Insert: {}, Delete: {}, Update: {}, Create: {}, CreateIndex: {}, Drop: {}, Begin: {}, Commit: {}, Rollback: {}, Alter Table: {}, Drop Index: {}", self.select_count, self.insert_count, self.delete_count, @@ -604,16 +788,14 @@ impl Display for InteractionStats { self.begin_count, self.commit_count, self.rollback_count, + self.alter_table_count, + self.drop_index_count, ) } } type AssertionFunc = dyn Fn(&Vec, &mut SimulatorEnv) -> Result>; -enum AssertionAST { - Pick(), -} - #[derive(Clone)] pub struct Assertion { pub func: Rc, @@ -763,6 +945,11 @@ impl InteractionType { pub(crate) fn execute_query(&self, conn: &mut Arc) -> ResultSet { if let Self::Query(query) = self { + assert!( + !matches!(query, Query::Placeholder), + "simulation cannot have a placeholder Query for execution" + ); + let query_str = query.to_string(); let rows = conn.query(&query_str); if rows.is_err() { @@ -1064,118 +1251,22 @@ fn reopen_database(env: &mut SimulatorEnv) { }; } -fn random_create(rng: &mut R, env: &SimulatorEnv, conn_index: usize) -> Interactions { - let conn_ctx = env.connection_context(conn_index); - let mut create = Create::arbitrary(rng, &conn_ctx); - while conn_ctx - .tables() - .iter() - .any(|t| t.name == create.table.name) - { - create = Create::arbitrary(rng, &conn_ctx); - } - Interactions::new(conn_index, InteractionsType::Query(Query::Create(create))) -} - -fn random_read(rng: &mut R, env: &SimulatorEnv, conn_index: usize) -> Interactions { - Interactions::new( - conn_index, - InteractionsType::Query(Query::Select(Select::arbitrary( - rng, - &env.connection_context(conn_index), - ))), - ) -} - -fn random_expr(rng: &mut R, env: &SimulatorEnv, conn_index: usize) -> Interactions { - Interactions::new( - conn_index, - InteractionsType::Query(Query::Select( - SelectFree::arbitrary(rng, &env.connection_context(conn_index)).0, - )), - ) -} - -fn random_write(rng: &mut R, env: &SimulatorEnv, conn_index: usize) -> Interactions { - Interactions::new( - conn_index, - InteractionsType::Query(Query::Insert(Insert::arbitrary( - rng, - &env.connection_context(conn_index), - ))), - ) -} - -fn random_delete(rng: &mut R, env: &SimulatorEnv, conn_index: usize) -> Interactions { - Interactions::new( - conn_index, - InteractionsType::Query(Query::Delete(Delete::arbitrary( - rng, - &env.connection_context(conn_index), - ))), - ) -} - -fn random_update(rng: &mut R, env: &SimulatorEnv, conn_index: usize) -> Interactions { - Interactions::new( - conn_index, - InteractionsType::Query(Query::Update(Update::arbitrary( - rng, - &env.connection_context(conn_index), - ))), - ) -} - -fn random_drop(rng: &mut R, env: &SimulatorEnv, conn_index: usize) -> Interactions { - Interactions::new( - conn_index, - InteractionsType::Query(Query::Drop(Drop::arbitrary( - rng, - &env.connection_context(conn_index), - ))), - ) -} - -fn random_create_index( +fn random_fault( rng: &mut R, env: &SimulatorEnv, conn_index: usize, -) -> Option { - let conn_ctx = env.connection_context(conn_index); - if conn_ctx.tables().is_empty() { - return None; - } - let mut create_index = CreateIndex::arbitrary(rng, &conn_ctx); - while conn_ctx - .tables() - .iter() - .find(|t| t.name == create_index.table_name) - .expect("table should exist") - .indexes - .iter() - .any(|i| i == &create_index.index_name) - { - create_index = CreateIndex::arbitrary(rng, &conn_ctx); - } - - Some(Interactions::new( - conn_index, - InteractionsType::Query(Query::CreateIndex(create_index)), - )) -} - -fn random_fault(rng: &mut R, env: &SimulatorEnv) -> Interactions { +) -> Interactions { let faults = if env.opts.disable_reopen_database { vec![Fault::Disconnect] } else { vec![Fault::Disconnect, Fault::ReopenDatabase] }; let fault = faults[rng.random_range(0..faults.len())]; - Interactions::new(env.choose_conn(rng), InteractionsType::Fault(fault)) + Interactions::new(conn_index, InteractionsType::Fault(fault)) } impl ArbitraryFrom<(&SimulatorEnv, InteractionStats, usize)> for Interactions { - fn arbitrary_from( + fn arbitrary_from( rng: &mut R, conn_ctx: &C, (env, stats, conn_index): (&SimulatorEnv, InteractionStats, usize), @@ -1185,72 +1276,51 @@ impl ArbitraryFrom<(&SimulatorEnv, InteractionStats, usize)> for Interactions { &env.profile.query, &stats, env.profile.experimental_mvcc, + conn_ctx, ); - frequency( - vec![ - ( - u32::min(remaining_.select, remaining_.insert) + remaining_.create, - Box::new(|rng: &mut R| { - Interactions::new( - conn_index, - InteractionsType::Property(Property::arbitrary_from( - rng, - conn_ctx, - (env, &stats), - )), - ) - }), - ), - ( - remaining_.select, - Box::new(|rng: &mut R| random_read(rng, env, conn_index)), - ), - ( - remaining_.select / 3, - Box::new(|rng: &mut R| random_expr(rng, env, conn_index)), - ), - ( - remaining_.insert, - Box::new(|rng: &mut R| random_write(rng, env, conn_index)), - ), - ( - remaining_.create, - Box::new(|rng: &mut R| random_create(rng, env, conn_index)), - ), - ( - remaining_.create_index, - Box::new(|rng: &mut R| { - if let Some(interaction) = random_create_index(rng, env, conn_index) { - interaction - } else { - // if no tables exist, we can't create an index, so fallback to creating a table - random_create(rng, env, conn_index) - } - }), - ), - ( - remaining_.delete, - Box::new(|rng: &mut R| random_delete(rng, env, conn_index)), - ), - ( - remaining_.update, - Box::new(|rng: &mut R| random_update(rng, env, conn_index)), - ), - ( - // remaining_.drop, - 0, - Box::new(|rng: &mut R| random_drop(rng, env, conn_index)), - ), - ( - remaining_ - .select - .min(remaining_.insert) - .min(remaining_.create) - .max(1), - Box::new(|rng: &mut R| random_fault(rng, env)), - ), - ], - rng, - ) + + let queries = possible_queries(conn_ctx.tables()); + let query_distr = QueryDistribution::new(queries, &remaining_); + + #[expect(clippy::type_complexity)] + let mut choices: Vec<(u32, Box Interactions>)> = vec![ + ( + query_distr.weights().total_weight(), + Box::new(|rng: &mut R| { + Interactions::new( + conn_index, + InteractionsType::Query(Query::arbitrary_from(rng, conn_ctx, &query_distr)), + ) + }), + ), + ( + remaining_ + .select + .min(remaining_.insert) + .min(remaining_.create) + .max(1), + Box::new(|rng: &mut R| random_fault(rng, env, conn_index)), + ), + ]; + + if let Ok(property_distr) = + PropertyDistribution::new(env, &remaining_, &query_distr, conn_ctx) + { + choices.push(( + property_distr.weights().total_weight(), + Box::new(move |rng: &mut R| { + Interactions::new( + conn_index, + InteractionsType::Property(Property::arbitrary_from( + rng, + conn_ctx, + &property_distr, + )), + ) + }), + )); + }; + + frequency(choices, rng) } } diff --git a/simulator/generation/property.rs b/simulator/generation/property.rs index a113dc939..f2530dfeb 100644 --- a/simulator/generation/property.rs +++ b/simulator/generation/property.rs @@ -1,9 +1,18 @@ +//! FIXME: With the current API and generation logic in plan.rs, +//! for Properties that have intermediary queries we need to CLONE the current Context tables +//! to properly generate queries, as we need to shadow after each query generated to make sure we are generating +//! queries that are valid. This is specially valid with DROP and ALTER TABLE in the mix, because with outdated context +//! we can generate queries that reference tables that do not exist. This is not a correctness issue, but more of +//! an optimization issue that is good to point out for the future + +use rand::distr::{Distribution, weighted::WeightedIndex}; use serde::{Deserialize, Serialize}; use sql_generation::{ - generation::{Arbitrary, ArbitraryFrom, GenerationContext, frequency, pick, pick_index}, + generation::{Arbitrary, ArbitraryFrom, GenerationContext, pick, pick_index}, model::{ query::{ Create, Delete, Drop, Insert, Select, + alter_table::{AlterTable, AlterTableType}, predicate::Predicate, select::{CompoundOperator, CompoundSelect, ResultColumn, SelectBody, SelectInner}, transaction::{Begin, Commit, Rollback}, @@ -12,13 +21,16 @@ use sql_generation::{ table::SimValue, }, }; +use strum::IntoEnumIterator; use turso_core::{LimboError, types}; use turso_parser::ast::{self, Distinctness}; use crate::{ common::print_diff, - generation::{Shadow as _, plan::InteractionType}, - model::Query, + generation::{ + Shadow as _, WeightedDistribution, plan::InteractionType, query::QueryDistribution, + }, + model::{Query, QueryCapabilities, QueryDiscriminants}, profiles::query::QueryProfile, runner::env::SimulatorEnv, }; @@ -27,7 +39,8 @@ use super::plan::{Assertion, Interaction, InteractionStats, ResultSet}; /// Properties are representations of executable specifications /// about the database behavior. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, strum::EnumDiscriminants)] +#[strum_discriminants(derive(strum::EnumIter))] pub enum Property { /// Insert-Select is a property in which the inserted row /// must be in the resulting rows of a select query that has a @@ -79,6 +92,17 @@ pub enum Property { TableHasExpectedContent { table: String, }, + /// AllTablesHaveExpectedContent is a property in which the table + /// must have the expected content, i.e. all the insertions and + /// updates and deletions should have been persisted in the way + /// we think they were. + /// The execution of the property is as follows + /// SELECT * FROM + /// ASSERT + /// for each table in the simulator model + AllTableHaveExpectedContent { + tables: Vec, + }, /// Double Create Failure is a property in which creating /// the same table twice leads to an error. /// The execution of the property is as follows @@ -185,11 +209,9 @@ pub enum Property { /// FsyncNoWait { query: Query, - tables: Vec, }, FaultyQuery { query: Query, - tables: Vec, }, /// Property used to subsititute a property with its queries only Queries { @@ -203,12 +225,16 @@ pub struct InteractiveQueryInfo { end_with_commit: bool, } +type PropertyQueryGenFunc<'a, R, G> = + fn(&mut R, &G, &QueryDistribution, &Property) -> Option; + impl Property { pub(crate) fn name(&self) -> &str { match self { Property::InsertValuesSelect { .. } => "Insert-Values-Select", Property::ReadYourUpdatesBack { .. } => "Read-Your-Updates-Back", Property::TableHasExpectedContent { .. } => "Table-Has-Expected-Content", + Property::AllTableHaveExpectedContent { .. } => "All-Tables-Have-Expected-Content", Property::DoubleCreateFailure { .. } => "Double-Create-Failure", Property::SelectLimit { .. } => "Select-Limit", Property::DeleteSelect { .. } => "Delete-Select", @@ -222,6 +248,14 @@ impl Property { } } + /// Property Does some sort of fault injection + pub fn check_tables(&self) -> bool { + matches!( + self, + Property::FsyncNoWait { .. } | Property::FaultyQuery { .. } + ) + } + pub fn get_extensional_queries(&mut self) -> Option<&mut Vec> { match self { Property::InsertValuesSelect { queries, .. } @@ -235,7 +269,209 @@ impl Property { | Property::WhereTrueFalseNull { .. } | Property::UNIONAllPreservesCardinality { .. } | Property::ReadYourUpdatesBack { .. } - | Property::TableHasExpectedContent { .. } => None, + | Property::TableHasExpectedContent { .. } + | Property::AllTableHaveExpectedContent { .. } => None, + } + } + + pub(super) fn get_extensional_query_gen_function(&self) -> PropertyQueryGenFunc + where + R: rand::Rng + ?Sized, + G: GenerationContext, + { + match self { + Property::InsertValuesSelect { .. } => { + // - [x] There will be no errors in the middle interactions. (this constraint is impossible to check, so this is just best effort) + // - [x] The inserted row will not be deleted. + // - [x] The inserted row will not be updated. + // - [x] The table `t` will not be renamed, dropped, or altered. + |rng: &mut R, ctx: &G, query_distr: &QueryDistribution, property: &Property| { + let Property::InsertValuesSelect { + insert, row_index, .. + } = property + else { + unreachable!(); + }; + let query = Query::arbitrary_from(rng, ctx, query_distr); + let table_name = insert.table(); + let table = ctx + .tables() + .iter() + .find(|table| table.name == table_name) + .unwrap(); + + let rows = insert.rows(); + let row = &rows[*row_index]; + + match &query { + Query::Delete(Delete { + table: t, + predicate, + }) if t == &table.name && predicate.test(row, table) => { + // The inserted row will not be deleted. + None + } + Query::Create(Create { table: t }) if t.name == table.name => { + // There will be no errors in the middle interactions. + // - Creating the same table is an error + None + } + Query::Update(Update { + table: t, + set_values: _, + predicate, + }) if t == &table.name && predicate.test(row, table) => { + // The inserted row will not be updated. + None + } + Query::Drop(Drop { table: t }) if *t == table.name => { + // Cannot drop the table we are inserting + None + } + Query::AlterTable(AlterTable { table_name: t, .. }) if *t == table.name => { + // Cannot alter the table we are inserting + None + } + _ => Some(query), + } + } + } + Property::DoubleCreateFailure { .. } => { + // The interactions in the middle has the following constraints; + // - [x] There will be no errors in the middle interactions.(best effort) + // - [x] Table `t` will not be renamed or dropped. + |rng: &mut R, ctx: &G, query_distr: &QueryDistribution, property: &Property| { + let Property::DoubleCreateFailure { create, .. } = property else { + unreachable!() + }; + + let table_name = create.table.name.clone(); + let table = ctx + .tables() + .iter() + .find(|table| table.name == table_name) + .unwrap(); + + let query = Query::arbitrary_from(rng, ctx, query_distr); + match &query { + Query::Create(Create { table: t }) if t.name == table.name => { + // There will be no errors in the middle interactions. + // - Creating the same table is an error + None + } + Query::Drop(Drop { table: t }) if *t == table.name => { + // Cannot Drop the created table + None + } + Query::AlterTable(AlterTable { table_name: t, .. }) if *t == table.name => { + // Cannot alter the table we created + None + } + _ => Some(query), + } + } + } + Property::DeleteSelect { .. } => { + // - [x] There will be no errors in the middle interactions. (this constraint is impossible to check, so this is just best effort) + // - [x] A row that holds for the predicate will not be inserted. + // - [x] The table `t` will not be renamed, dropped, or altered. + + |rng, ctx, query_distr, property| { + let Property::DeleteSelect { + table: table_name, + predicate, + .. + } = property + else { + unreachable!() + }; + + let table_name = table_name.clone(); + let table = ctx + .tables() + .iter() + .find(|table| table.name == table_name) + .unwrap(); + let query = Query::arbitrary_from(rng, ctx, query_distr); + match &query { + Query::Insert(Insert::Values { table: t, values }) + if *t == table_name + && values.iter().any(|v| predicate.test(v, table)) => + { + // A row that holds for the predicate will not be inserted. + None + } + Query::Insert(Insert::Select { + table: t, + select: _, + }) if t == &table.name => { + // A row that holds for the predicate will not be inserted. + None + } + Query::Update(Update { table: t, .. }) if t == &table.name => { + // A row that holds for the predicate will not be updated. + None + } + Query::Create(Create { table: t }) if t.name == table.name => { + // There will be no errors in the middle interactions. + // - Creating the same table is an error + None + } + Query::Drop(Drop { table: t }) if *t == table.name => { + // Cannot Drop the same table + None + } + Query::AlterTable(AlterTable { table_name: t, .. }) if *t == table.name => { + // Cannot alter the same table + None + } + _ => Some(query), + } + } + } + Property::DropSelect { .. } => { + // - [x] There will be no errors in the middle interactions. (this constraint is impossible to check, so this is just best effort) + // - [x] The table `t` will not be created, no table will be renamed to `t`. + |rng, ctx, query_distr, property: &Property| { + let Property::DropSelect { + table: table_name, .. + } = property + else { + unreachable!() + }; + + let query = Query::arbitrary_from(rng, ctx, query_distr); + match &query { + Query::Create(Create { table: t }) if t.name == *table_name => { + // - The table `t` will not be created + None + } + Query::AlterTable(AlterTable { + table_name: t, + alter_table_type: AlterTableType::RenameTo { new_name }, + }) if t == table_name || new_name == table_name => { + // no table will be renamed to `t` + None + } + _ => Some(query), + } + } + } + Property::Queries { .. } => { + unreachable!("No extensional querie generation for `Property::Queries`") + } + Property::FsyncNoWait { .. } | Property::FaultyQuery { .. } => { + unreachable!("No extensional queries") + } + Property::SelectLimit { .. } + | Property::SelectSelectOptimizer { .. } + | Property::WhereTrueFalseNull { .. } + | Property::UNIONAllPreservesCardinality { .. } + | Property::ReadYourUpdatesBack { .. } + | Property::TableHasExpectedContent { .. } + | Property::AllTableHaveExpectedContent { .. } => { + unreachable!("No extensional queries") + } } } @@ -244,6 +480,9 @@ impl Property { /// and `interaction` cannot be serialized directly. pub(crate) fn interactions(&self, connection_index: usize) -> Vec { match self { + Property::AllTableHaveExpectedContent { tables } => { + assert_all_table_values(tables, connection_index).collect() + } Property::TableHasExpectedContent { table } => { let table = table.to_string(); let table_name = table.clone(); @@ -688,17 +927,17 @@ impl Property { Ok(success) => Ok(Err(format!( "expected table creation to fail but it succeeded: {success:?}" ))), - Err(e) => { - if e.to_string() - .contains(&format!("Table {table_name} does not exist")) + Err(e) => match e { + e if e + .to_string() + .contains(&format!("no such table: {table_name}")) => { Ok(Ok(())) - } else { - Ok(Err(format!( - "expected table does not exist error, got: {e}" - ))) } - } + _ => Ok(Err(format!( + "expected table does not exist error, got: {e}" + ))), + }, } }, )); @@ -719,7 +958,7 @@ impl Property { .into_iter() .map(|q| Interaction::new(connection_index, InteractionType::Query(q))), ); - interactions.push(Interaction::new(connection_index, select)); + interactions.push(Interaction::new_ignore_error(connection_index, select)); interactions.push(Interaction::new(connection_index, assertion)); interactions @@ -813,18 +1052,13 @@ impl Property { Interaction::new(connection_index, assertion), ] } - Property::FsyncNoWait { query, tables } => { - let checks = assert_all_table_values(tables, connection_index); - Vec::from_iter( - std::iter::once(Interaction::new( - connection_index, - InteractionType::FsyncQuery(query.clone()), - )) - .chain(checks), - ) + Property::FsyncNoWait { query } => { + vec![Interaction::new( + connection_index, + InteractionType::FsyncQuery(query.clone()), + )] } - Property::FaultyQuery { query, tables } => { - let checks = assert_all_table_values(tables, connection_index); + Property::FaultyQuery { query } => { let query_clone = query.clone(); // A fault may not occur as we first signal we want a fault injected, // then when IO is called the fault triggers. It may happen that a fault is injected @@ -851,13 +1085,13 @@ impl Property { } }, ); - let first = [ + [ InteractionType::FaultyQuery(query.clone()), InteractionType::Assertion(assert), ] .into_iter() - .map(|i| Interaction::new(connection_index, i)); - Vec::from_iter(first.chain(checks)) + .map(|i| Interaction::new(connection_index, i)) + .collect() } Property::WhereTrueFalseNull { select, predicate } => { let assumption = InteractionType::Assumption(Assertion::new( @@ -1137,29 +1371,26 @@ fn assert_all_table_values( } #[derive(Debug)] -pub(crate) struct Remaining { - pub(crate) select: u32, - pub(crate) insert: u32, - pub(crate) create: u32, - pub(crate) create_index: u32, - pub(crate) delete: u32, - pub(crate) update: u32, - pub(crate) drop: u32, +pub(super) struct Remaining { + pub select: u32, + pub insert: u32, + pub create: u32, + pub create_index: u32, + pub delete: u32, + pub update: u32, + pub drop: u32, + pub alter_table: u32, + pub drop_index: u32, } -pub(crate) fn remaining( +pub(super) fn remaining( max_interactions: u32, opts: &QueryProfile, stats: &InteractionStats, mvcc: bool, + context: &impl GenerationContext, ) -> Remaining { - let total_weight = opts.select_weight - + opts.create_table_weight - + opts.create_index_weight - + opts.insert_weight - + opts.update_weight - + opts.delete_weight - + opts.drop_table_weight; + let total_weight = opts.total_weight(); let total_select = (max_interactions * opts.select_weight) / total_weight; let total_insert = (max_interactions * opts.insert_weight) / total_weight; @@ -1168,6 +1399,8 @@ pub(crate) fn remaining( let total_delete = (max_interactions * opts.delete_weight) / total_weight; let total_update = (max_interactions * opts.update_weight) / total_weight; let total_drop = (max_interactions * opts.drop_table_weight) / total_weight; + let total_alter_table = (max_interactions * opts.alter_table_weight) / total_weight; + let total_drop_index = (max_interactions * opts.drop_index) / total_weight; let remaining_select = total_select .checked_sub(stats.select_count) @@ -1189,9 +1422,27 @@ pub(crate) fn remaining( .unwrap_or_default(); let remaining_drop = total_drop.checked_sub(stats.drop_count).unwrap_or_default(); + let remaining_alter_table = total_alter_table + .checked_sub(stats.alter_table_count) + .unwrap_or_default(); + + let mut remaining_drop_index = total_drop_index + .checked_sub(stats.alter_table_count) + .unwrap_or_default(); + if mvcc { // TODO: index not supported yet for mvcc remaining_create_index = 0; + remaining_drop_index = 0; + } + + // if there are no indexes do not allow creation of drop_index + if !context + .tables() + .iter() + .any(|table| !table.indexes.is_empty()) + { + remaining_drop_index = 0; } Remaining { @@ -1202,15 +1453,18 @@ pub(crate) fn remaining( delete: remaining_delete, drop: remaining_drop, update: remaining_update, + alter_table: remaining_alter_table, + drop_index: remaining_drop_index, } } -fn property_insert_values_select( +fn property_insert_values_select( rng: &mut R, - remaining: &Remaining, + _query_distr: &QueryDistribution, ctx: &impl GenerationContext, mvcc: bool, ) -> Property { + assert!(!ctx.tables().is_empty()); // Get a random table let table = pick(ctx.tables(), rng); // Generate rows to insert @@ -1223,10 +1477,10 @@ fn property_insert_values_select( let row = rows[row_index].clone(); // Insert the rows - let insert_query = Insert::Values { + let insert_query = Query::Insert(Insert::Values { table: table.name.clone(), values: rows, - }; + }); // Choose if we want queries to be executed in an interactive transaction let interactive = if !mvcc && rng.random_bool(0.5) { @@ -1237,12 +1491,11 @@ fn property_insert_values_select( } else { None }; - // Create random queries respecting the constraints - let mut queries = Vec::new(); - // - [x] There will be no errors in the middle interactions. (this constraint is impossible to check, so this is just best effort) - // - [x] The inserted row will not be deleted. - // - [x] The inserted row will not be updated. - // - [ ] The table `t` will not be renamed, dropped, or altered. (todo: add this constraint once ALTER or DROP is implemented) + + let amount = rng.random_range(0..3); + + let mut queries = Vec::with_capacity(amount + 2); + if let Some(ref interactive) = interactive { queries.push(Query::Begin(if interactive.start_with_immediate { Begin::Immediate @@ -1250,39 +1503,9 @@ fn property_insert_values_select( Begin::Deferred })); } - for _ in 0..rng.random_range(0..3) { - let query = Query::arbitrary_from(rng, ctx, remaining); - match &query { - Query::Delete(Delete { - table: t, - predicate, - }) => { - // The inserted row will not be deleted. - if t == &table.name && predicate.test(&row, table) { - continue; - } - } - Query::Create(Create { table: t }) => { - // There will be no errors in the middle interactions. - // - Creating the same table is an error - if t.name == table.name { - continue; - } - } - Query::Update(Update { - table: t, - set_values: _, - predicate, - }) => { - // The inserted row will not be updated. - if t == &table.name && predicate.test(&row, table) { - continue; - } - } - _ => (), - } - queries.push(query); - } + + queries.extend(std::iter::repeat_n(Query::Placeholder, amount)); + if let Some(ref interactive) = interactive { queries.push(if interactive.end_with_commit { Query::Commit(Commit) @@ -1298,7 +1521,7 @@ fn property_insert_values_select( ); Property::InsertValuesSelect { - insert: insert_query, + insert: insert_query.unwrap_insert(), row_index, queries, select: select_query, @@ -1306,9 +1529,11 @@ fn property_insert_values_select( } } -fn property_read_your_updates_back( +fn property_read_your_updates_back( rng: &mut R, + _query_distr: &QueryDistribution, ctx: &impl GenerationContext, + _mvcc: bool, ) -> Property { // e.g. UPDATE t SET a=1, b=2 WHERE c=1; let update = Update::arbitrary(rng, ctx); @@ -1328,10 +1553,13 @@ fn property_read_your_updates_back( Property::ReadYourUpdatesBack { update, select } } -fn property_table_has_expected_content( +fn property_table_has_expected_content( rng: &mut R, + _query_distr: &QueryDistribution, ctx: &impl GenerationContext, + _mvcc: bool, ) -> Property { + assert!(!ctx.tables().is_empty()); // Get a random table let table = pick(ctx.tables(), rng); Property::TableHasExpectedContent { @@ -1339,7 +1567,24 @@ fn property_table_has_expected_content( } } -fn property_select_limit(rng: &mut R, ctx: &impl GenerationContext) -> Property { +fn property_all_tables_have_expected_content( + _rng: &mut R, + _query_distr: &QueryDistribution, + ctx: &impl GenerationContext, + _mvcc: bool, +) -> Property { + Property::AllTableHaveExpectedContent { + tables: ctx.tables().iter().map(|t| t.name.clone()).collect(), + } +} + +fn property_select_limit( + rng: &mut R, + _query_distr: &QueryDistribution, + ctx: &impl GenerationContext, + _mvcc: bool, +) -> Property { + assert!(!ctx.tables().is_empty()); // Get a random table let table = pick(ctx.tables(), rng); // Select the table @@ -1353,31 +1598,18 @@ fn property_select_limit(rng: &mut R, ctx: &impl GenerationContext Property::SelectLimit { select } } -fn property_double_create_failure( +fn property_double_create_failure( rng: &mut R, - remaining: &Remaining, + _query_distr: &QueryDistribution, ctx: &impl GenerationContext, + _mvcc: bool, ) -> Property { // Create the table let create_query = Create::arbitrary(rng, ctx); - let table = &create_query.table; - // Create random queries respecting the constraints - let mut queries = Vec::new(); - // The interactions in the middle has the following constraints; - // - [x] There will be no errors in the middle interactions.(best effort) - // - [ ] Table `t` will not be renamed or dropped.(todo: add this constraint once ALTER or DROP is implemented) - for _ in 0..rng.random_range(0..3) { - let query = Query::arbitrary_from(rng, ctx, remaining); - if let Query::Create(Create { table: t }) = &query { - // There will be no errors in the middle interactions. - // - Creating the same table is an error - if t.name == table.name { - continue; - } - } - queries.push(query); - } + let amount = rng.random_range(0..3); + + let queries = vec![Query::Placeholder; amount]; Property::DoubleCreateFailure { create: create_query, @@ -1385,56 +1617,21 @@ fn property_double_create_failure( } } -fn property_delete_select( +fn property_delete_select( rng: &mut R, - remaining: &Remaining, + _query_distr: &QueryDistribution, ctx: &impl GenerationContext, + _mvcc: bool, ) -> Property { + assert!(!ctx.tables().is_empty()); // Get a random table let table = pick(ctx.tables(), rng); // Generate a random predicate let predicate = Predicate::arbitrary_from(rng, ctx, table); - // Create random queries respecting the constraints - let mut queries = Vec::new(); - // - [x] There will be no errors in the middle interactions. (this constraint is impossible to check, so this is just best effort) - // - [x] A row that holds for the predicate will not be inserted. - // - [ ] The table `t` will not be renamed, dropped, or altered. (todo: add this constraint once ALTER or DROP is implemented) - for _ in 0..rng.random_range(0..3) { - let query = Query::arbitrary_from(rng, ctx, remaining); - match &query { - Query::Insert(Insert::Values { table: t, values }) => { - // A row that holds for the predicate will not be inserted. - if t == &table.name && values.iter().any(|v| predicate.test(v, table)) { - continue; - } - } - Query::Insert(Insert::Select { - table: t, - select: _, - }) => { - // A row that holds for the predicate will not be inserted. - if t == &table.name { - continue; - } - } - Query::Update(Update { table: t, .. }) => { - // A row that holds for the predicate will not be updated. - if t == &table.name { - continue; - } - } - Query::Create(Create { table: t }) => { - // There will be no errors in the middle interactions. - // - Creating the same table is an error - if t.name == table.name { - continue; - } - } - _ => (), - } - queries.push(query); - } + let amount = rng.random_range(0..3); + + let queries = vec![Query::Placeholder; amount]; Property::DeleteSelect { table: table.name.clone(), @@ -1443,28 +1640,19 @@ fn property_delete_select( } } -fn property_drop_select( +fn property_drop_select( rng: &mut R, - remaining: &Remaining, + _query_distr: &QueryDistribution, ctx: &impl GenerationContext, + _mvcc: bool, ) -> Property { + assert!(!ctx.tables().is_empty()); // Get a random table let table = pick(ctx.tables(), rng); - // Create random queries respecting the constraints - let mut queries = Vec::new(); - // - [x] There will be no errors in the middle interactions. (this constraint is impossible to check, so this is just best effort) - // - [-] The table `t` will not be created, no table will be renamed to `t`. (todo: update this constraint once ALTER is implemented) - for _ in 0..rng.random_range(0..3) { - let query = Query::arbitrary_from(rng, ctx, remaining); - if let Query::Create(Create { table: t }) = &query { - // - The table `t` will not be created - if t.name == table.name { - continue; - } - } - queries.push(query); - } + let amount = rng.random_range(0..3); + + let queries = vec![Query::Placeholder; amount]; let select = Select::simple( table.name.clone(), @@ -1478,10 +1666,13 @@ fn property_drop_select( } } -fn property_select_select_optimizer( +fn property_select_select_optimizer( rng: &mut R, + _query_distr: &QueryDistribution, ctx: &impl GenerationContext, + _mvcc: bool, ) -> Property { + assert!(!ctx.tables().is_empty()); // Get a random table let table = pick(ctx.tables(), rng); // Generate a random predicate @@ -1499,10 +1690,13 @@ fn property_select_select_optimizer( } } -fn property_where_true_false_null( +fn property_where_true_false_null( rng: &mut R, + _query_distr: &QueryDistribution, ctx: &impl GenerationContext, + _mvcc: bool, ) -> Property { + assert!(!ctx.tables().is_empty()); // Get a random table let table = pick(ctx.tables(), rng); // Generate a random predicate @@ -1518,10 +1712,13 @@ fn property_where_true_false_null( } } -fn property_union_all_preserves_cardinality( +fn property_union_all_preserves_cardinality( rng: &mut R, + _query_distr: &QueryDistribution, ctx: &impl GenerationContext, + _mvcc: bool, ) -> Property { + assert!(!ctx.tables().is_empty()); // Get a random table let table = pick(ctx.tables(), rng); // Generate a random predicate @@ -1543,133 +1740,150 @@ fn property_union_all_preserves_cardinality( } } -fn property_fsync_no_wait( +fn property_fsync_no_wait( rng: &mut R, - remaining: &Remaining, + query_distr: &QueryDistribution, ctx: &impl GenerationContext, + _mvcc: bool, ) -> Property { Property::FsyncNoWait { - query: Query::arbitrary_from(rng, ctx, remaining), - tables: ctx.tables().iter().map(|t| t.name.clone()).collect(), + query: Query::arbitrary_from(rng, ctx, query_distr), } } -fn property_faulty_query( +fn property_faulty_query( rng: &mut R, - remaining: &Remaining, + query_distr: &QueryDistribution, ctx: &impl GenerationContext, + _mvcc: bool, ) -> Property { Property::FaultyQuery { - query: Query::arbitrary_from(rng, ctx, remaining), - tables: ctx.tables().iter().map(|t| t.name.clone()).collect(), + query: Query::arbitrary_from(rng, ctx, query_distr), } } -impl ArbitraryFrom<(&SimulatorEnv, &InteractionStats)> for Property { - fn arbitrary_from( - rng: &mut R, - conn_ctx: &C, - (env, stats): (&SimulatorEnv, &InteractionStats), - ) -> Self { - let opts = conn_ctx.opts(); - let remaining_ = remaining( - env.opts.max_interactions, - &env.profile.query, - stats, - env.profile.experimental_mvcc, - ); +type PropertyGenFunc = fn(&mut R, &QueryDistribution, &G, bool) -> Property; - #[allow(clippy::type_complexity)] - let choices: Vec<(_, Box Property>)> = vec![ - ( - if !env.opts.disable_insert_values_select { - u32::min(remaining_.select, remaining_.insert).max(1) +impl PropertyDiscriminants { + fn gen_function(&self) -> PropertyGenFunc + where + R: rand::Rng + ?Sized, + G: GenerationContext, + { + match self { + PropertyDiscriminants::InsertValuesSelect => property_insert_values_select, + PropertyDiscriminants::ReadYourUpdatesBack => property_read_your_updates_back, + PropertyDiscriminants::TableHasExpectedContent => property_table_has_expected_content, + PropertyDiscriminants::AllTableHaveExpectedContent => { + property_all_tables_have_expected_content + } + PropertyDiscriminants::DoubleCreateFailure => property_double_create_failure, + PropertyDiscriminants::SelectLimit => property_select_limit, + PropertyDiscriminants::DeleteSelect => property_delete_select, + PropertyDiscriminants::DropSelect => property_drop_select, + PropertyDiscriminants::SelectSelectOptimizer => property_select_select_optimizer, + PropertyDiscriminants::WhereTrueFalseNull => property_where_true_false_null, + PropertyDiscriminants::UNIONAllPreservesCardinality => { + property_union_all_preserves_cardinality + } + PropertyDiscriminants::FsyncNoWait => property_fsync_no_wait, + PropertyDiscriminants::FaultyQuery => property_faulty_query, + PropertyDiscriminants::Queries => { + unreachable!("should not try to generate queries property") + } + } + } + + fn weight( + &self, + env: &SimulatorEnv, + remaining: &Remaining, + ctx: &impl GenerationContext, + ) -> u32 { + let opts = ctx.opts(); + match self { + PropertyDiscriminants::InsertValuesSelect => { + if !env.opts.disable_insert_values_select && !ctx.tables().is_empty() { + u32::min(remaining.select, remaining.insert).max(1) } else { 0 - }, - Box::new(|rng: &mut R| { - property_insert_values_select( - rng, - &remaining_, - conn_ctx, - env.profile.experimental_mvcc, - ) - }), - ), - ( - remaining_.select.max(1), - Box::new(|rng: &mut R| property_table_has_expected_content(rng, conn_ctx)), - ), - ( - u32::min(remaining_.select, remaining_.insert).max(1), - Box::new(|rng: &mut R| property_read_your_updates_back(rng, conn_ctx)), - ), - ( + } + } + PropertyDiscriminants::ReadYourUpdatesBack => { + u32::min(remaining.select, remaining.insert).max(1) + } + PropertyDiscriminants::TableHasExpectedContent => { + if !ctx.tables().is_empty() { + remaining.select.max(1) + } else { + 0 + } + } + // AllTableHaveExpectedContent should only be generated by Properties that inject faults + PropertyDiscriminants::AllTableHaveExpectedContent => 0, + PropertyDiscriminants::DoubleCreateFailure => { if !env.opts.disable_double_create_failure { - remaining_.create / 2 + remaining.create / 2 } else { 0 - }, - Box::new(|rng: &mut R| property_double_create_failure(rng, &remaining_, conn_ctx)), - ), - ( - if !env.opts.disable_select_limit { - remaining_.select + } + } + PropertyDiscriminants::SelectLimit => { + if !env.opts.disable_select_limit && !ctx.tables().is_empty() { + remaining.select } else { 0 - }, - Box::new(|rng: &mut R| property_select_limit(rng, conn_ctx)), - ), - ( - if !env.opts.disable_delete_select { - u32::min(remaining_.select, remaining_.insert).min(remaining_.delete) + } + } + PropertyDiscriminants::DeleteSelect => { + if !env.opts.disable_delete_select && !ctx.tables().is_empty() { + u32::min(remaining.select, remaining.insert).min(remaining.delete) } else { 0 - }, - Box::new(|rng: &mut R| property_delete_select(rng, &remaining_, conn_ctx)), - ), - ( - if !env.opts.disable_drop_select { - // remaining_.drop - 0 + } + } + PropertyDiscriminants::DropSelect => { + if !env.opts.disable_drop_select && !ctx.tables().is_empty() { + remaining.drop } else { 0 - }, - Box::new(|rng: &mut R| property_drop_select(rng, &remaining_, conn_ctx)), - ), - ( - if !env.opts.disable_select_optimizer { - remaining_.select / 2 + } + } + PropertyDiscriminants::SelectSelectOptimizer => { + if !env.opts.disable_select_optimizer && !ctx.tables().is_empty() { + remaining.select / 2 } else { 0 - }, - Box::new(|rng: &mut R| property_select_select_optimizer(rng, conn_ctx)), - ), - ( - if opts.indexes && !env.opts.disable_where_true_false_null { - remaining_.select / 2 + } + } + PropertyDiscriminants::WhereTrueFalseNull => { + if opts.indexes + && !env.opts.disable_where_true_false_null + && !ctx.tables().is_empty() + { + remaining.select / 2 } else { 0 - }, - Box::new(|rng: &mut R| property_where_true_false_null(rng, conn_ctx)), - ), - ( - if opts.indexes && !env.opts.disable_union_all_preserves_cardinality { - remaining_.select / 3 + } + } + PropertyDiscriminants::UNIONAllPreservesCardinality => { + if opts.indexes + && !env.opts.disable_union_all_preserves_cardinality + && !ctx.tables().is_empty() + { + remaining.select / 3 } else { 0 - }, - Box::new(|rng: &mut R| property_union_all_preserves_cardinality(rng, conn_ctx)), - ), - ( + } + } + PropertyDiscriminants::FsyncNoWait => { if env.profile.io.enable && !env.opts.disable_fsync_no_wait { 50 // Freestyle number } else { 0 - }, - Box::new(|rng: &mut R| property_fsync_no_wait(rng, &remaining_, conn_ctx)), - ), - ( + } + } + PropertyDiscriminants::FaultyQuery => { if env.profile.io.enable && env.profile.io.fault.enable && !env.opts.disable_faulty_query @@ -1677,12 +1891,115 @@ impl ArbitraryFrom<(&SimulatorEnv, &InteractionStats)> for Property { 20 } else { 0 - }, - Box::new(|rng: &mut R| property_faulty_query(rng, &remaining_, conn_ctx)), - ), - ]; + } + } + PropertyDiscriminants::Queries => { + unreachable!("queries property should not be generated") + } + } + } - frequency(choices, rng) + fn can_generate(queries: &[QueryDiscriminants]) -> Vec { + let queries_capabilities = QueryCapabilities::from_list_queries(queries); + + PropertyDiscriminants::iter() + .filter(|property| { + !matches!(property, PropertyDiscriminants::Queries) + && queries_capabilities.contains(property.requirements()) + }) + .collect() + } + + pub const fn requirements(&self) -> QueryCapabilities { + match self { + PropertyDiscriminants::InsertValuesSelect => { + QueryCapabilities::SELECT.union(QueryCapabilities::INSERT) + } + PropertyDiscriminants::ReadYourUpdatesBack => { + QueryCapabilities::SELECT.union(QueryCapabilities::UPDATE) + } + PropertyDiscriminants::TableHasExpectedContent => QueryCapabilities::SELECT, + PropertyDiscriminants::AllTableHaveExpectedContent => QueryCapabilities::SELECT, + PropertyDiscriminants::DoubleCreateFailure => QueryCapabilities::CREATE, + PropertyDiscriminants::SelectLimit => QueryCapabilities::SELECT, + PropertyDiscriminants::DeleteSelect => { + QueryCapabilities::SELECT.union(QueryCapabilities::DELETE) + } + PropertyDiscriminants::DropSelect => { + QueryCapabilities::SELECT.union(QueryCapabilities::DROP) + } + PropertyDiscriminants::SelectSelectOptimizer => QueryCapabilities::SELECT, + PropertyDiscriminants::WhereTrueFalseNull => QueryCapabilities::SELECT, + PropertyDiscriminants::UNIONAllPreservesCardinality => QueryCapabilities::SELECT, + PropertyDiscriminants::FsyncNoWait => QueryCapabilities::all(), + PropertyDiscriminants::FaultyQuery => QueryCapabilities::all(), + PropertyDiscriminants::Queries => panic!("queries property should not be generated"), + } + } +} + +pub(super) struct PropertyDistribution<'a> { + properties: Vec, + weights: WeightedIndex, + query_distr: &'a QueryDistribution, + mvcc: bool, +} + +impl<'a> PropertyDistribution<'a> { + pub fn new( + env: &SimulatorEnv, + remaining: &Remaining, + query_distr: &'a QueryDistribution, + ctx: &impl GenerationContext, + ) -> Result { + let properties = PropertyDiscriminants::can_generate(query_distr.items()); + let weights = WeightedIndex::new( + properties + .iter() + .map(|property| property.weight(env, remaining, ctx)), + )?; + + Ok(Self { + properties, + weights, + query_distr, + mvcc: env.profile.experimental_mvcc, + }) + } +} + +impl<'a> WeightedDistribution for PropertyDistribution<'a> { + type Item = PropertyDiscriminants; + + type GenItem = Property; + + fn items(&self) -> &[Self::Item] { + &self.properties + } + + fn weights(&self) -> &WeightedIndex { + &self.weights + } + + fn sample( + &self, + rng: &mut R, + conn_ctx: &C, + ) -> Self::GenItem { + let properties = &self.properties; + let idx = self.weights.sample(rng); + let property_fn = properties[idx].gen_function(); + (property_fn)(rng, self.query_distr, conn_ctx, self.mvcc) + } +} + +impl<'a> ArbitraryFrom<&PropertyDistribution<'a>> for Property { + fn arbitrary_from( + rng: &mut R, + conn_ctx: &C, + property_distr: &PropertyDistribution<'a>, + ) -> Self { + property_distr.sample(rng, conn_ctx) } } diff --git a/simulator/generation/query.rs b/simulator/generation/query.rs index 72541c4d7..7445bd744 100644 --- a/simulator/generation/query.rs +++ b/simulator/generation/query.rs @@ -1,42 +1,221 @@ -use crate::model::Query; -use rand::Rng; +use crate::{ + generation::WeightedDistribution, + model::{Query, QueryDiscriminants}, +}; +use rand::{ + Rng, + distr::{Distribution, weighted::WeightedIndex}, +}; use sql_generation::{ - generation::{Arbitrary, ArbitraryFrom, GenerationContext, frequency}, - model::query::{Create, Delete, Insert, Select, update::Update}, + generation::{Arbitrary, ArbitraryFrom, GenerationContext, query::SelectFree}, + model::{ + query::{ + Create, CreateIndex, Delete, DropIndex, Insert, Select, alter_table::AlterTable, + update::Update, + }, + table::Table, + }, }; use super::property::Remaining; -impl ArbitraryFrom<&Remaining> for Query { - fn arbitrary_from( - rng: &mut R, - context: &C, - remaining: &Remaining, - ) -> Self { - frequency( - vec![ - ( - remaining.create, - Box::new(|rng| Self::Create(Create::arbitrary(rng, context))), - ), - ( - remaining.select, - Box::new(|rng| Self::Select(Select::arbitrary(rng, context))), - ), - ( - remaining.insert, - Box::new(|rng| Self::Insert(Insert::arbitrary(rng, context))), - ), - ( - remaining.update, - Box::new(|rng| Self::Update(Update::arbitrary(rng, context))), - ), - ( - remaining.insert.min(remaining.delete), - Box::new(|rng| Self::Delete(Delete::arbitrary(rng, context))), - ), - ], - rng, - ) +fn random_create(rng: &mut R, conn_ctx: &impl GenerationContext) -> Query { + let mut create = Create::arbitrary(rng, conn_ctx); + while conn_ctx + .tables() + .iter() + .any(|t| t.name == create.table.name) + { + create = Create::arbitrary(rng, conn_ctx); + } + Query::Create(create) +} + +fn random_select(rng: &mut R, conn_ctx: &impl GenerationContext) -> Query { + if !conn_ctx.tables().is_empty() && rng.random_bool(0.7) { + Query::Select(Select::arbitrary(rng, conn_ctx)) + } else { + // Random expression + Query::Select(SelectFree::arbitrary(rng, conn_ctx).0) + } +} + +fn random_insert(rng: &mut R, conn_ctx: &impl GenerationContext) -> Query { + assert!(!conn_ctx.tables().is_empty()); + Query::Insert(Insert::arbitrary(rng, conn_ctx)) +} + +fn random_delete(rng: &mut R, conn_ctx: &impl GenerationContext) -> Query { + assert!(!conn_ctx.tables().is_empty()); + Query::Delete(Delete::arbitrary(rng, conn_ctx)) +} + +fn random_update(rng: &mut R, conn_ctx: &impl GenerationContext) -> Query { + assert!(!conn_ctx.tables().is_empty()); + Query::Update(Update::arbitrary(rng, conn_ctx)) +} + +fn random_drop(rng: &mut R, conn_ctx: &impl GenerationContext) -> Query { + assert!(!conn_ctx.tables().is_empty()); + Query::Drop(sql_generation::model::query::Drop::arbitrary(rng, conn_ctx)) +} + +fn random_create_index( + rng: &mut R, + conn_ctx: &impl GenerationContext, +) -> Query { + assert!(!conn_ctx.tables().is_empty()); + + let mut create_index = CreateIndex::arbitrary(rng, conn_ctx); + while conn_ctx + .tables() + .iter() + .find(|t| t.name == create_index.table_name) + .expect("table should exist") + .indexes + .iter() + .any(|i| i.index_name == create_index.index_name) + { + create_index = CreateIndex::arbitrary(rng, conn_ctx); + } + + Query::CreateIndex(create_index) +} + +fn random_alter_table( + rng: &mut R, + conn_ctx: &impl GenerationContext, +) -> Query { + assert!(!conn_ctx.tables().is_empty()); + Query::AlterTable(AlterTable::arbitrary(rng, conn_ctx)) +} + +fn random_drop_index( + rng: &mut R, + conn_ctx: &impl GenerationContext, +) -> Query { + assert!( + conn_ctx + .tables() + .iter() + .any(|table| !table.indexes.is_empty()) + ); + Query::DropIndex(DropIndex::arbitrary(rng, conn_ctx)) +} + +/// Possible queries that can be generated given the table state +/// +/// Does not take into account transactional statements +pub const fn possible_queries(tables: &[Table]) -> &'static [QueryDiscriminants] { + if tables.is_empty() { + &[QueryDiscriminants::Select, QueryDiscriminants::Create] + } else { + QueryDiscriminants::ALL_NO_TRANSACTION + } +} + +type QueryGenFunc = fn(&mut R, &G) -> Query; + +impl QueryDiscriminants { + fn gen_function(&self) -> QueryGenFunc + where + R: rand::Rng + ?Sized, + G: GenerationContext, + { + match self { + QueryDiscriminants::Create => random_create, + QueryDiscriminants::Select => random_select, + QueryDiscriminants::Insert => random_insert, + QueryDiscriminants::Delete => random_delete, + QueryDiscriminants::Update => random_update, + QueryDiscriminants::Drop => random_drop, + QueryDiscriminants::CreateIndex => random_create_index, + QueryDiscriminants::AlterTable => random_alter_table, + QueryDiscriminants::DropIndex => random_drop_index, + QueryDiscriminants::Begin + | QueryDiscriminants::Commit + | QueryDiscriminants::Rollback => { + unreachable!("transactional queries should not be generated") + } + QueryDiscriminants::Placeholder => { + unreachable!("Query Placeholders should not be generated") + } + } + } + + fn weight(&self, remaining: &Remaining) -> u32 { + match self { + QueryDiscriminants::Create => remaining.create, + // remaining.select / 3 is for the random_expr generation + // have a max of 1 so that we always generate at least a non zero weight for `QueryDistribution` + QueryDiscriminants::Select => (remaining.select + remaining.select / 3).max(1), + QueryDiscriminants::Insert => remaining.insert, + QueryDiscriminants::Delete => remaining.delete, + QueryDiscriminants::Update => remaining.update, + QueryDiscriminants::Drop => remaining.drop, + QueryDiscriminants::CreateIndex => remaining.create_index, + QueryDiscriminants::AlterTable => remaining.alter_table, + QueryDiscriminants::DropIndex => remaining.drop_index, + QueryDiscriminants::Begin + | QueryDiscriminants::Commit + | QueryDiscriminants::Rollback => { + unreachable!("transactional queries should not be generated") + } + QueryDiscriminants::Placeholder => { + unreachable!("Query Placeholders should not be generated") + } + } + } +} + +#[derive(Debug)] +pub(super) struct QueryDistribution { + queries: &'static [QueryDiscriminants], + weights: WeightedIndex, +} + +impl QueryDistribution { + pub fn new(queries: &'static [QueryDiscriminants], remaining: &Remaining) -> Self { + let query_weights = + WeightedIndex::new(queries.iter().map(|query| query.weight(remaining))).unwrap(); + Self { + queries, + weights: query_weights, + } + } +} + +impl WeightedDistribution for QueryDistribution { + type Item = QueryDiscriminants; + type GenItem = Query; + + fn items(&self) -> &[Self::Item] { + self.queries + } + + fn weights(&self) -> &WeightedIndex { + &self.weights + } + + fn sample( + &self, + rng: &mut R, + ctx: &C, + ) -> Self::GenItem { + let weights = &self.weights; + + let idx = weights.sample(rng); + let query_fn = self.queries[idx].gen_function(); + (query_fn)(rng, ctx) + } +} + +impl ArbitraryFrom<&QueryDistribution> for Query { + fn arbitrary_from( + rng: &mut R, + context: &C, + query_distr: &QueryDistribution, + ) -> Self { + query_distr.sample(rng, context) } } diff --git a/simulator/main.rs b/simulator/main.rs index 15376a13b..a23157dda 100644 --- a/simulator/main.rs +++ b/simulator/main.rs @@ -1,4 +1,4 @@ -#![allow(clippy::arc_with_non_send_sync, dead_code)] +#![allow(clippy::arc_with_non_send_sync)] use anyhow::anyhow; use clap::Parser; use generation::plan::{InteractionPlan, InteractionPlanState}; @@ -421,6 +421,7 @@ enum SandboxedResult { error: String, last_execution: Execution, }, + #[expect(dead_code)] FoundBug { error: String, history: ExecutionHistory, @@ -610,13 +611,19 @@ fn run_simulation_default( tracing::info!("Simulation completed"); + env.io.persist_files().unwrap(); + if result.error.is_none() { - let ic = integrity_check(&env.get_db_path()); - if let Err(err) = ic { - tracing::error!("integrity check failed: {}", err); - result.error = Some(turso_core::LimboError::InternalError(err.to_string())); + if env.opts.disable_integrity_check { + tracing::info!("skipping integrity check (disabled by configuration)"); } else { - tracing::info!("integrity check passed"); + let ic = integrity_check(&env.get_db_path()); + if let Err(err) = ic { + tracing::error!("integrity check failed: {}", err); + result.error = Some(turso_core::LimboError::InternalError(err.to_string())); + } else { + tracing::info!("integrity check passed"); + } } } @@ -683,6 +690,7 @@ const BANNER: &str = r#" "#; fn integrity_check(db_path: &Path) -> anyhow::Result<()> { + assert!(db_path.exists()); let conn = rusqlite::Connection::open(db_path)?; let mut stmt = conn.prepare("SELECT * FROM pragma_integrity_check;")?; let mut rows = stmt.query(())?; diff --git a/simulator/model/mod.rs b/simulator/model/mod.rs index 551c08b1d..cdebd76fe 100644 --- a/simulator/model/mod.rs +++ b/simulator/model/mod.rs @@ -1,24 +1,26 @@ use std::fmt::Display; use anyhow::Context; +use bitflags::bitflags; use indexmap::IndexSet; use itertools::Itertools; use serde::{Deserialize, Serialize}; use sql_generation::model::{ query::{ - Create, CreateIndex, Delete, Drop, Insert, Select, + Create, CreateIndex, Delete, Drop, DropIndex, Insert, Select, + alter_table::{AlterTable, AlterTableType}, select::{CompoundOperator, FromClause, ResultColumn, SelectInner}, transaction::{Begin, Commit, Rollback}, update::Update, }, - table::{JoinTable, JoinType, SimValue, Table, TableContext}, + table::{Index, JoinTable, JoinType, SimValue, Table, TableContext}, }; use turso_parser::ast::Distinctness; use crate::{generation::Shadow, runner::env::ShadowTablesMut}; // This type represents the potential queries on the database. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, strum::EnumDiscriminants)] pub enum Query { Create(Create), Select(Select), @@ -27,12 +29,38 @@ pub enum Query { Update(Update), Drop(Drop), CreateIndex(CreateIndex), + AlterTable(AlterTable), + DropIndex(DropIndex), Begin(Begin), Commit(Commit), Rollback(Rollback), + /// Placeholder query that still needs to be generated + Placeholder, } impl Query { + pub fn as_create(&self) -> &Create { + match self { + Self::Create(create) => create, + _ => unreachable!(), + } + } + + pub fn unwrap_create(self) -> Create { + match self { + Self::Create(create) => create, + _ => unreachable!(), + } + } + + #[inline] + pub fn unwrap_insert(self) -> Insert { + match self { + Self::Insert(insert) => insert, + _ => unreachable!(), + } + } + pub fn dependencies(&self) -> IndexSet { match self { Query::Select(select) => select.dependencies(), @@ -41,11 +69,20 @@ impl Query { | Query::Insert(Insert::Values { table, .. }) | Query::Delete(Delete { table, .. }) | Query::Update(Update { table, .. }) - | Query::Drop(Drop { table, .. }) => IndexSet::from_iter([table.clone()]), - Query::CreateIndex(CreateIndex { table_name, .. }) => { - IndexSet::from_iter([table_name.clone()]) - } + | Query::Drop(Drop { table, .. }) + | Query::CreateIndex(CreateIndex { + index: Index { + table_name: table, .. + }, + }) + | Query::AlterTable(AlterTable { + table_name: table, .. + }) + | Query::DropIndex(DropIndex { + table_name: table, .. + }) => IndexSet::from_iter([table.clone()]), Query::Begin(_) | Query::Commit(_) | Query::Rollback(_) => IndexSet::new(), + Query::Placeholder => IndexSet::new(), } } pub fn uses(&self) -> Vec { @@ -56,9 +93,20 @@ impl Query { | Query::Insert(Insert::Values { table, .. }) | Query::Delete(Delete { table, .. }) | Query::Update(Update { table, .. }) - | Query::Drop(Drop { table, .. }) => vec![table.clone()], - Query::CreateIndex(CreateIndex { table_name, .. }) => vec![table_name.clone()], + | Query::Drop(Drop { table, .. }) + | Query::CreateIndex(CreateIndex { + index: Index { + table_name: table, .. + }, + }) + | Query::AlterTable(AlterTable { + table_name: table, .. + }) + | Query::DropIndex(DropIndex { + table_name: table, .. + }) => vec![table.clone()], Query::Begin(..) | Query::Commit(..) | Query::Rollback(..) => vec![], + Query::Placeholder => vec![], } } @@ -74,7 +122,11 @@ impl Query { pub fn is_ddl(&self) -> bool { matches!( self, - Self::Create(..) | Self::CreateIndex(..) | Self::Drop(..) + Self::Create(..) + | Self::CreateIndex(..) + | Self::Drop(..) + | Self::AlterTable(..) + | Self::DropIndex(..) ) } } @@ -89,9 +141,12 @@ impl Display for Query { Self::Update(update) => write!(f, "{update}"), Self::Drop(drop) => write!(f, "{drop}"), Self::CreateIndex(create_index) => write!(f, "{create_index}"), + Self::AlterTable(alter_table) => write!(f, "{alter_table}"), + Self::DropIndex(drop_index) => write!(f, "{drop_index}"), Self::Begin(begin) => write!(f, "{begin}"), Self::Commit(commit) => write!(f, "{commit}"), Self::Rollback(rollback) => write!(f, "{rollback}"), + Self::Placeholder => Ok(()), } } } @@ -108,13 +163,83 @@ impl Shadow for Query { Query::Update(update) => update.shadow(env), Query::Drop(drop) => drop.shadow(env), Query::CreateIndex(create_index) => Ok(create_index.shadow(env)), + Query::AlterTable(alter_table) => alter_table.shadow(env), + Query::DropIndex(drop_index) => drop_index.shadow(env), Query::Begin(begin) => Ok(begin.shadow(env)), Query::Commit(commit) => Ok(commit.shadow(env)), Query::Rollback(rollback) => Ok(rollback.shadow(env)), + Query::Placeholder => Ok(vec![]), } } } +bitflags! { + pub struct QueryCapabilities: u32 { + const CREATE = 1 << 0; + const SELECT = 1 << 1; + const INSERT = 1 << 2; + const DELETE = 1 << 3; + const UPDATE = 1 << 4; + const DROP = 1 << 5; + const CREATE_INDEX = 1 << 6; + const ALTER_TABLE = 1 << 7; + const DROP_INDEX = 1 << 8; + } +} + +impl QueryCapabilities { + // TODO: can be const fn in the future + pub fn from_list_queries(queries: &[QueryDiscriminants]) -> Self { + queries + .iter() + .fold(Self::empty(), |accum, q| accum.union(q.into())) + } +} + +impl From<&QueryDiscriminants> for QueryCapabilities { + fn from(value: &QueryDiscriminants) -> Self { + (*value).into() + } +} + +impl From for QueryCapabilities { + fn from(value: QueryDiscriminants) -> Self { + match value { + QueryDiscriminants::Create => Self::CREATE, + QueryDiscriminants::Select => Self::SELECT, + QueryDiscriminants::Insert => Self::INSERT, + QueryDiscriminants::Delete => Self::DELETE, + QueryDiscriminants::Update => Self::UPDATE, + QueryDiscriminants::Drop => Self::DROP, + QueryDiscriminants::CreateIndex => Self::CREATE_INDEX, + QueryDiscriminants::AlterTable => Self::ALTER_TABLE, + QueryDiscriminants::DropIndex => Self::DROP_INDEX, + QueryDiscriminants::Begin + | QueryDiscriminants::Commit + | QueryDiscriminants::Rollback => { + unreachable!("QueryCapabilities do not apply to transaction queries") + } + QueryDiscriminants::Placeholder => { + unreachable!("QueryCapabilities do not apply to query Placeholder") + } + } + } +} + +impl QueryDiscriminants { + pub const ALL_NO_TRANSACTION: &[QueryDiscriminants] = &[ + QueryDiscriminants::Select, + QueryDiscriminants::Create, + QueryDiscriminants::Insert, + QueryDiscriminants::Update, + QueryDiscriminants::Delete, + QueryDiscriminants::Drop, + QueryDiscriminants::CreateIndex, + QueryDiscriminants::AlterTable, + QueryDiscriminants::DropIndex, + ]; +} + impl Shadow for Create { type Result = anyhow::Result>>; @@ -138,7 +263,7 @@ impl Shadow for CreateIndex { .find(|t| t.name == self.table_name) .unwrap() .indexes - .push(self.index_name.clone()); + .push(self.index.clone()); vec![] } } @@ -169,6 +294,7 @@ impl Shadow for Drop { type Result = anyhow::Result>>; fn shadow(&self, tables: &mut ShadowTablesMut) -> Self::Result { + tracing::info!("dropping {:?}", self); if !tables.iter().any(|t| t.name == self.table) { // If the table does not exist, we return an error return Err(anyhow::anyhow!( @@ -430,3 +556,76 @@ impl Shadow for Update { Ok(vec![]) } } + +impl Shadow for AlterTable { + type Result = anyhow::Result>>; + + fn shadow(&self, tables: &mut ShadowTablesMut<'_>) -> Self::Result { + let table = tables + .iter_mut() + .find(|t| t.name == self.table_name) + .ok_or_else(|| anyhow::anyhow!("Table {} does not exist", self.table_name))?; + + match &self.alter_table_type { + AlterTableType::RenameTo { new_name } => { + table.name = new_name.clone(); + } + AlterTableType::AddColumn { column } => { + table.columns.push(column.clone()); + table.rows.iter_mut().for_each(|row| { + row.push(SimValue(turso_core::Value::Null)); + }); + } + AlterTableType::AlterColumn { old, new } => { + let col = table.columns.iter_mut().find(|c| c.name == *old).unwrap(); + *col = new.clone(); + table.indexes.iter_mut().for_each(|index| { + index.columns.iter_mut().for_each(|(col_name, _)| { + if col_name == old { + *col_name = new.name.clone(); + } + }); + }); + } + AlterTableType::RenameColumn { old, new } => { + let col = table.columns.iter_mut().find(|c| c.name == *old).unwrap(); + col.name = new.clone(); + table.indexes.iter_mut().for_each(|index| { + index.columns.iter_mut().for_each(|(col_name, _)| { + if col_name == old { + *col_name = new.clone(); + } + }); + }); + } + AlterTableType::DropColumn { column_name } => { + let col_idx = table + .columns + .iter() + .position(|c| c.name == *column_name) + .unwrap(); + table.columns.remove(col_idx); + table.rows.iter_mut().for_each(|row| { + row.remove(col_idx); + }); + } + }; + Ok(vec![]) + } +} + +impl Shadow for DropIndex { + type Result = anyhow::Result>>; + + fn shadow(&self, tables: &mut ShadowTablesMut<'_>) -> Self::Result { + let table = tables + .iter_mut() + .find(|t| t.name == self.table_name) + .ok_or_else(|| anyhow::anyhow!("Table {} does not exist", self.table_name))?; + + table + .indexes + .retain(|index| index.index_name != self.index_name); + Ok(vec![]) + } +} diff --git a/simulator/profiles/mod.rs b/simulator/profiles/mod.rs index 8c8d1f670..e4ea1dc06 100644 --- a/simulator/profiles/mod.rs +++ b/simulator/profiles/mod.rs @@ -93,11 +93,6 @@ impl Profile { }, ..Default::default() }, - query: QueryProfile { - create_table_weight: 0, - create_index_weight: 4, - ..Default::default() - }, ..Default::default() }; diff --git a/simulator/profiles/query.rs b/simulator/profiles/query.rs index a58c983e0..95bcf146a 100644 --- a/simulator/profiles/query.rs +++ b/simulator/profiles/query.rs @@ -22,6 +22,10 @@ pub struct QueryProfile { pub delete_weight: u32, #[garde(skip)] pub drop_table_weight: u32, + #[garde(skip)] + pub alter_table_weight: u32, + #[garde(skip)] + pub drop_index: u32, } impl Default for QueryProfile { @@ -35,10 +39,26 @@ impl Default for QueryProfile { update_weight: 20, delete_weight: 20, drop_table_weight: 2, + alter_table_weight: 2, + drop_index: 2, } } } +impl QueryProfile { + /// Attention: edit this function when another weight is added + pub fn total_weight(&self) -> u32 { + self.select_weight + + self.create_table_weight + + self.create_index_weight + + self.insert_weight + + self.update_weight + + self.delete_weight + + self.drop_table_weight + + self.alter_table_weight + } +} + #[derive(Debug, Clone, strum::VariantArray)] pub enum QueryTypes { CreateTable, diff --git a/simulator/run-miri.sh b/simulator/run-miri.sh new file mode 100755 index 000000000..e065f3c68 --- /dev/null +++ b/simulator/run-miri.sh @@ -0,0 +1,31 @@ +#!/bin/bash + +ARGS=("$@") + +# Intercept the seed if it's passed +while [[ $# -gt 0 ]]; do + case $1 in + -s=*|--seed=*) + seed="${1#*=}" + shift + ;; + -s|--seed) + seed="$2" + shift 2 + ;; + *) + shift + ;; + esac +done +# Otherwise make one up +if [ -z "$seed" ]; then + # Dump 8 bytes of /dev/random as decimal u64 + seed=$(od -An -N8 -tu8 /dev/random | tr -d ' ') + ARGS+=("--seed" "${seed}") + echo "Generated seed for Miri and simulator: ${seed}" +else + echo "Intercepted simulator seed to pass to Miri: ${seed}" +fi + +MIRIFLAGS="-Zmiri-disable-isolation -Zmiri-disable-stacked-borrows -Zmiri-seed=${seed}" cargo +nightly miri run --bin limbo_sim -- "${ARGS[@]}" diff --git a/simulator/runner/bugbase.rs b/simulator/runner/bugbase.rs index 179c292f1..89ad25e71 100644 --- a/simulator/runner/bugbase.rs +++ b/simulator/runner/bugbase.rs @@ -1,8 +1,9 @@ use std::{ collections::HashMap, - io::{self, Write}, - path::PathBuf, - process::Command, + env::current_dir, + fs::File, + io::{self, Read, Write}, + path::{Path, PathBuf}, time::SystemTime, }; @@ -49,6 +50,7 @@ pub(crate) struct BugRun { } impl Bug { + #[expect(dead_code)] /// Check if the bug is loaded. pub(crate) fn is_loaded(&self) -> bool { match self { @@ -130,6 +132,7 @@ impl BugBase { Err(anyhow!("failed to create bug base")) } + #[expect(dead_code)] /// Load the bug base from one of the potential paths. pub(crate) fn interactive_load() -> anyhow::Result { let potential_paths = vec![ @@ -291,22 +294,23 @@ impl BugBase { None => anyhow::bail!("No bugs found for seed {}", seed), Some(Bug::Unloaded { .. }) => { let plan = - std::fs::read_to_string(self.path.join(seed.to_string()).join("test.json")) + std::fs::read_to_string(self.path.join(seed.to_string()).join("plan.json")) .with_context(|| { format!( "should be able to read plan file at {}", - self.path.join(seed.to_string()).join("test.json").display() + self.path.join(seed.to_string()).join("plan.json").display() ) })?; let plan: InteractionPlan = serde_json::from_str(&plan) .with_context(|| "should be able to deserialize plan")?; - let shrunk_plan: Option = std::fs::read_to_string( - self.path.join(seed.to_string()).join("shrunk_test.json"), - ) - .with_context(|| "should be able to read shrunk plan file") - .and_then(|shrunk| serde_json::from_str(&shrunk).map_err(|e| anyhow!("{}", e))) - .ok(); + let shrunk_plan: Option = + std::fs::read_to_string(self.path.join(seed.to_string()).join("shrunk.json")) + .with_context(|| "should be able to read shrunk plan file") + .and_then(|shrunk| { + serde_json::from_str(&shrunk).map_err(|e| anyhow!("{}", e)) + }) + .ok(); let shrunk_plan: Option = shrunk_plan.and_then(|shrunk_plan| serde_json::from_str(&shrunk_plan).ok()); @@ -338,6 +342,7 @@ impl BugBase { } } + #[expect(dead_code)] pub(crate) fn mark_successful_run( &mut self, seed: u64, @@ -434,6 +439,7 @@ impl BugBase { } impl BugBase { + #[expect(dead_code)] /// Get the path to the bug base directory. pub(crate) fn path(&self) -> &PathBuf { &self.path @@ -448,28 +454,56 @@ impl BugBase { impl BugBase { pub(crate) fn get_current_commit_hash() -> anyhow::Result { - let output = Command::new("git") - .args(["rev-parse", "HEAD"]) - .output() - .with_context(|| "should be able to get the commit hash")?; - let commit_hash = String::from_utf8(output.stdout) - .with_context(|| "commit hash should be valid utf8")? - .trim() - .to_string(); - Ok(commit_hash) + let git_dir = find_git_dir(current_dir()?).with_context(|| "should be a git repo")?; + let hash = + resolve_head(&git_dir).with_context(|| "should be able to get the commit hash")?; + Ok(hash) } pub(crate) fn get_limbo_project_dir() -> anyhow::Result { - Ok(PathBuf::from( - String::from_utf8( - Command::new("git") - .args(["rev-parse", "--show-toplevel"]) - .output() - .with_context(|| "should be able to get the git path")? - .stdout, - ) - .with_context(|| "commit hash should be valid utf8")? - .trim(), - )) + let git_dir = find_git_dir(current_dir()?).with_context(|| "should be a git repo")?; + let workdir = git_dir + .parent() + .with_context(|| "work tree should be parent of .git")?; + Ok(workdir.to_path_buf()) } } + +fn find_git_dir(start_path: impl AsRef) -> Option { + let mut current = start_path.as_ref().to_path_buf(); + loop { + let git_path = current.join(".git"); + if git_path.is_dir() { + return Some(git_path); + } else if git_path.is_file() { + // Handle git worktrees - .git is a file containing "gitdir: " + if let Ok(contents) = read_to_string(&git_path) { + if let Some(gitdir) = contents.strip_prefix("gitdir: ") { + return Some(PathBuf::from(gitdir)); + } + } + } + if !current.pop() { + return None; + } + } +} + +fn resolve_head(git_dir: impl AsRef) -> anyhow::Result { + // HACK ignores stuff like packed-refs + let head_path = git_dir.as_ref().join("HEAD"); + let head_contents = read_to_string(&head_path)?; + if let Some(ref_path) = head_contents.strip_prefix("ref: ") { + let ref_file = git_dir.as_ref().join(ref_path); + read_to_string(&ref_file) + } else { + Ok(head_contents) + } +} + +fn read_to_string(path: impl AsRef) -> anyhow::Result { + let mut file = File::open(path)?; + let mut contents = String::new(); + file.read_to_string(&mut contents)?; + Ok(contents.trim().to_string()) +} diff --git a/simulator/runner/cli.rs b/simulator/runner/cli.rs index 6dec4b46d..97062dd2d 100644 --- a/simulator/runner/cli.rs +++ b/simulator/runner/cli.rs @@ -30,7 +30,7 @@ pub struct SimulatorCLI { short = 'n', long, help = "change the maximum size of the randomly generated sequence of interactions", - default_value_t = 5000, + default_value_t = normal_or_miri(5000, 50), value_parser = clap::value_parser!(u32).range(1..) )] pub maximum_tests: u32, @@ -38,7 +38,7 @@ pub struct SimulatorCLI { short = 'k', long, help = "change the minimum size of the randomly generated sequence of interactions", - default_value_t = 1000, + default_value_t = normal_or_miri(1000, 10), value_parser = clap::value_parser!(u32).range(1..) )] pub minimum_tests: u32, @@ -147,6 +147,12 @@ pub struct SimulatorCLI { default_value_t = false )] pub keep_files: bool, + #[clap( + long, + help = "Disable the SQLite integrity check at the end of a simulation", + default_value_t = normal_or_miri(false, true) + )] + pub disable_integrity_check: bool, #[clap( long, help = "Use memory IO for complex simulations", @@ -274,3 +280,7 @@ impl ValueParserFactory for ProfileType { ProfileTypeParser } } + +const fn normal_or_miri(normal_val: T, miri_val: T) -> T { + if cfg!(miri) { miri_val } else { normal_val } +} diff --git a/simulator/runner/env.rs b/simulator/runner/env.rs index 52b57052a..23267519b 100644 --- a/simulator/runner/env.rs +++ b/simulator/runner/env.rs @@ -300,7 +300,6 @@ impl SimulatorEnv { seed, ticks: rng .random_range(cli_opts.minimum_tests as usize..=cli_opts.maximum_tests as usize), - max_tables: rng.random_range(0..128), disable_select_optimizer: cli_opts.disable_select_optimizer, disable_insert_values_select: cli_opts.disable_insert_values_select, disable_double_create_failure: cli_opts.disable_double_create_failure, @@ -316,6 +315,7 @@ impl SimulatorEnv { max_interactions: rng.random_range(cli_opts.minimum_tests..=cli_opts.maximum_tests), max_time_simulation: cli_opts.maximum_time, disable_reopen_database: cli_opts.disable_reopen_database, + disable_integrity_check: cli_opts.disable_integrity_check, }; // Remove existing database file if it exists @@ -352,6 +352,9 @@ impl SimulatorEnv { profile.io.enable = false; // Disable limits due to differences in return order from turso and rusqlite opts.disable_select_limit = true; + + // There is no `ALTER COLUMN` in SQLite + profile.query.gen_opts.query.alter_table.alter_column = false; } profile.validate().unwrap(); @@ -516,14 +519,6 @@ impl SimulatorEnv { } } -pub trait ConnectionTrait -where - Self: std::marker::Sized + Clone, -{ - fn is_connected(&self) -> bool; - fn disconnect(&mut self); -} - pub(crate) enum SimConnection { LimboConnection(Arc), SQLiteConnection(rusqlite::Connection), @@ -572,7 +567,6 @@ impl Display for SimConnection { pub(crate) struct SimulatorOpts { pub(crate) seed: u64, pub(crate) ticks: usize, - pub(crate) max_tables: usize, pub(crate) disable_select_optimizer: bool, pub(crate) disable_insert_values_select: bool, @@ -585,6 +579,7 @@ pub(crate) struct SimulatorOpts { pub(crate) disable_fsync_no_wait: bool, pub(crate) disable_faulty_query: bool, pub(crate) disable_reopen_database: bool, + pub(crate) disable_integrity_check: bool, pub(crate) max_interactions: u32, pub(crate) page_size: usize, diff --git a/simulator/runner/execution.rs b/simulator/runner/execution.rs index e877a972f..1e30de520 100644 --- a/simulator/runner/execution.rs +++ b/simulator/runner/execution.rs @@ -46,6 +46,7 @@ impl ExecutionHistory { } pub struct ExecutionResult { + #[expect(dead_code)] pub history: ExecutionHistory, pub error: Option, } @@ -282,16 +283,13 @@ fn limbo_integrity_check(conn: &Arc) -> Result<()> { Ok(()) } +#[instrument(skip(env, interaction, stack), fields(conn_index = interaction.connection_index, interaction = %interaction))] fn execute_interaction_rusqlite( env: &mut SimulatorEnv, interaction: &Interaction, stack: &mut Vec, ) -> turso_core::Result { - tracing::trace!( - "execute_interaction_rusqlite(connection_index={}, interaction={})", - interaction.connection_index, - interaction - ); + tracing::info!(""); let SimConnection::SQLiteConnection(conn) = &mut env.connections[interaction.connection_index] else { unreachable!() @@ -343,14 +341,25 @@ fn execute_query_rusqlite( connection: &rusqlite::Connection, query: &Query, ) -> rusqlite::Result>> { + // https://sqlite.org/forum/forumpost/9fe5d047f0 + // Due to a bug in sqlite, we need to execute this query to clear the internal stmt cache so that schema changes become visible always to other connections + connection.query_one("SELECT * FROM pragma_user_version()", (), |_| Ok(()))?; match query { Query::Select(select) => { let mut stmt = connection.prepare(select.to_string().as_str())?; - let columns = stmt.column_count(); let rows = stmt.query_map([], |row| { let mut values = vec![]; - for i in 0..columns { - let value = row.get_unwrap(i); + for i in 0.. { + let value = match row.get(i) { + Ok(value) => value, + Err(err) => match err { + rusqlite::Error::InvalidColumnIndex(_) => break, + _ => { + tracing::error!(?err); + panic!("{err}") + } + }, + }; let value = match value { rusqlite::types::Value::Null => Value::Null, rusqlite::types::Value::Integer(i) => Value::Integer(i), @@ -368,6 +377,9 @@ fn execute_query_rusqlite( } Ok(result) } + Query::Placeholder => { + unreachable!("simulation cannot have a placeholder Query for execution") + } _ => { connection.execute(query.to_string().as_str(), ())?; Ok(vec![]) diff --git a/simulator/runner/io.rs b/simulator/runner/io.rs index 8eccd470a..baf4d9e98 100644 --- a/simulator/runner/io.rs +++ b/simulator/runner/io.rs @@ -3,7 +3,7 @@ use std::{ sync::Arc, }; -use rand::{RngCore, SeedableRng}; +use rand::{Rng, RngCore, SeedableRng}; use rand_chacha::ChaCha8Rng; use turso_core::{Clock, IO, Instant, OpenFlags, PlatformIO, Result}; @@ -79,6 +79,11 @@ impl SimIO for SimulatorIO { fn close_files(&self) { self.files.borrow_mut().clear() } + + fn persist_files(&self) -> anyhow::Result<()> { + // Files are persisted automatically + Ok(()) + } } impl Clock for SimulatorIO { @@ -131,6 +136,10 @@ impl IO for SimulatorIO { } fn generate_random_number(&self) -> i64 { - self.rng.borrow_mut().next_u64() as i64 + self.rng.borrow_mut().random() + } + + fn fill_bytes(&self, dest: &mut [u8]) { + self.rng.borrow_mut().fill_bytes(dest); } } diff --git a/simulator/runner/memory/io.rs b/simulator/runner/memory/io.rs index 007398a10..fc406e7c1 100644 --- a/simulator/runner/memory/io.rs +++ b/simulator/runner/memory/io.rs @@ -1,9 +1,9 @@ -use std::cell::{Cell, RefCell}; +use std::cell::RefCell; use std::sync::Arc; use indexmap::IndexMap; use parking_lot::Mutex; -use rand::{RngCore, SeedableRng}; +use rand::{Rng, RngCore, SeedableRng}; use rand_chacha::ChaCha8Rng; use turso_core::{Clock, Completion, IO, Instant, OpenFlags, Result}; @@ -121,7 +121,7 @@ pub struct MemorySimIO { timeouts: CallbackQueue, pub files: RefCell>>, pub rng: RefCell, - pub nr_run_once_faults: Cell, + #[expect(dead_code)] pub page_size: usize, seed: u64, latency_probability: u8, @@ -141,13 +141,11 @@ impl MemorySimIO { ) -> Self { let files = RefCell::new(IndexMap::new()); let rng = RefCell::new(ChaCha8Rng::seed_from_u64(seed)); - let nr_run_once_faults = Cell::new(0); Self { callbacks: Arc::new(Mutex::new(Vec::new())), timeouts: Arc::new(Mutex::new(Vec::new())), files, rng, - nr_run_once_faults, page_size, seed, latency_probability, @@ -192,6 +190,17 @@ impl SimIO for MemorySimIO { file.closed.set(true); } } + + fn persist_files(&self) -> anyhow::Result<()> { + let files = self.files.borrow(); + for (file_path, file) in files.iter() { + if file_path.ends_with(".db") || file_path.ends_with("wal") || file_path.ends_with("lg") + { + std::fs::write(file_path, &*file.buffer.borrow())?; + } + } + Ok(()) + } } impl Clock for MemorySimIO { @@ -260,7 +269,11 @@ impl IO for MemorySimIO { } fn generate_random_number(&self) -> i64 { - self.rng.borrow_mut().next_u64() as i64 + self.rng.borrow_mut().random() + } + + fn fill_bytes(&self, dest: &mut [u8]) { + self.rng.borrow_mut().fill_bytes(dest); } fn remove_file(&self, path: &str) -> Result<()> { diff --git a/simulator/runner/mod.rs b/simulator/runner/mod.rs index 7afbaa720..ed898100a 100644 --- a/simulator/runner/mod.rs +++ b/simulator/runner/mod.rs @@ -5,7 +5,7 @@ pub mod differential; pub mod doublecheck; pub mod env; pub mod execution; -#[allow(dead_code)] +#[expect(dead_code)] pub mod file; pub mod io; pub mod memory; @@ -20,4 +20,6 @@ pub trait SimIO: turso_core::IO { fn syncing(&self) -> bool; fn close_files(&self); + + fn persist_files(&self) -> anyhow::Result<()>; } diff --git a/simulator/shrink/plan.rs b/simulator/shrink/plan.rs index 93f2f1702..9f80f78cc 100644 --- a/simulator/shrink/plan.rs +++ b/simulator/shrink/plan.rs @@ -101,12 +101,6 @@ impl InteractionPlan { // Remove all properties that do not use the failing tables self.retain_mut(|interactions| { let retain = if idx == failing_interaction_index { - if let InteractionsType::Property( - Property::FsyncNoWait { tables, .. } | Property::FaultyQuery { tables, .. }, - ) = &mut interactions.interactions - { - tables.retain(|table| depending_tables.contains(table)); - } true } else { let mut has_table = interactions @@ -126,16 +120,14 @@ impl InteractionPlan { | Property::DeleteSelect { queries, .. } | Property::DropSelect { queries, .. } | Property::Queries { queries } => { + // Remove placeholder queries + queries.retain(|query| !matches!(query, Query::Placeholder)); extensional_queries.append(queries); } - Property::FsyncNoWait { tables, query } - | Property::FaultyQuery { tables, query } => { - if !query.uses().iter().any(|t| depending_tables.contains(t)) { - tables.clear(); - } else { - tables.retain(|table| depending_tables.contains(table)); - } + Property::AllTableHaveExpectedContent { tables } => { + tables.retain(|table| depending_tables.contains(table)); } + Property::FsyncNoWait { .. } | Property::FaultyQuery { .. } => {} Property::SelectLimit { .. } | Property::SelectSelectOptimizer { .. } | Property::WhereTrueFalseNull { .. } @@ -350,7 +342,8 @@ impl InteractionPlan { | Property::FaultyQuery { .. } | Property::FsyncNoWait { .. } | Property::ReadYourUpdatesBack { .. } - | Property::TableHasExpectedContent { .. } => {} + | Property::TableHasExpectedContent { .. } + | Property::AllTableHaveExpectedContent { .. } => {} } } } diff --git a/sql_generation/Cargo.toml b/sql_generation/Cargo.toml index d42668237..8d1084f24 100644 --- a/sql_generation/Cargo.toml +++ b/sql_generation/Cargo.toml @@ -13,7 +13,7 @@ path = "lib.rs" hex = { workspace = true } serde = { workspace = true, features = ["derive"] } turso_core = { workspace = true, features = ["simulator"] } -turso_parser = { workspace = true, features = ["serde"] } +turso_parser = { workspace = true, features = ["serde", "simulator"] } rand = { workspace = true } anarchist-readable-name-generator-lib = "0.2.0" itertools = { workspace = true } @@ -22,6 +22,7 @@ tracing = { workspace = true } schemars = { workspace = true } garde = { workspace = true, features = ["derive", "serde"] } indexmap = { workspace = true } +strum = { workspace = true } [dev-dependencies] rand_chacha = { workspace = true } diff --git a/sql_generation/generation/expr.rs b/sql_generation/generation/expr.rs index 25bdc0d97..e60baf78f 100644 --- a/sql_generation/generation/expr.rs +++ b/sql_generation/generation/expr.rs @@ -14,7 +14,7 @@ impl Arbitrary for Box where T: Arbitrary, { - fn arbitrary(rng: &mut R, context: &C) -> Self { + fn arbitrary(rng: &mut R, context: &C) -> Self { Box::from(T::arbitrary(rng, context)) } } @@ -23,7 +23,7 @@ impl ArbitrarySized for Box where T: ArbitrarySized, { - fn arbitrary_sized( + fn arbitrary_sized( rng: &mut R, context: &C, size: usize, @@ -36,7 +36,7 @@ impl ArbitrarySizedFrom for Box where T: ArbitrarySizedFrom, { - fn arbitrary_sized_from( + fn arbitrary_sized_from( rng: &mut R, context: &C, t: A, @@ -50,7 +50,7 @@ impl Arbitrary for Option where T: Arbitrary, { - fn arbitrary(rng: &mut R, context: &C) -> Self { + fn arbitrary(rng: &mut R, context: &C) -> Self { rng.random_bool(0.5).then_some(T::arbitrary(rng, context)) } } @@ -59,7 +59,7 @@ impl ArbitrarySizedFrom for Option where T: ArbitrarySizedFrom, { - fn arbitrary_sized_from( + fn arbitrary_sized_from( rng: &mut R, context: &C, t: A, @@ -74,7 +74,11 @@ impl ArbitraryFrom for Vec where T: ArbitraryFrom, { - fn arbitrary_from(rng: &mut R, context: &C, t: A) -> Self { + fn arbitrary_from( + rng: &mut R, + context: &C, + t: A, + ) -> Self { let size = rng.random_range(0..5); (0..size) .map(|_| T::arbitrary_from(rng, context, t)) @@ -84,7 +88,7 @@ where // Freestyling generation impl ArbitrarySized for Expr { - fn arbitrary_sized( + fn arbitrary_sized( rng: &mut R, context: &C, size: usize, @@ -188,7 +192,7 @@ impl ArbitrarySized for Expr { } impl Arbitrary for Operator { - fn arbitrary(rng: &mut R, _context: &C) -> Self { + fn arbitrary(rng: &mut R, _context: &C) -> Self { let choices = [ Operator::Add, Operator::And, @@ -219,7 +223,7 @@ impl Arbitrary for Operator { } impl Arbitrary for Type { - fn arbitrary(rng: &mut R, _context: &C) -> Self { + fn arbitrary(rng: &mut R, _context: &C) -> Self { let name = pick(&["INT", "INTEGER", "REAL", "TEXT", "BLOB", "ANY"], rng).to_string(); Self { name, @@ -229,7 +233,7 @@ impl Arbitrary for Type { } impl Arbitrary for QualifiedName { - fn arbitrary(rng: &mut R, context: &C) -> Self { + fn arbitrary(rng: &mut R, context: &C) -> Self { // TODO: for now just generate table name let table_idx = pick_index(context.tables().len(), rng); let table = &context.tables()[table_idx]; @@ -243,7 +247,7 @@ impl Arbitrary for QualifiedName { } impl Arbitrary for LikeOperator { - fn arbitrary(rng: &mut R, _t: &C) -> Self { + fn arbitrary(rng: &mut R, _t: &C) -> Self { let choice = rng.random_range(0..4); match choice { 0 => LikeOperator::Glob, @@ -257,7 +261,7 @@ impl Arbitrary for LikeOperator { // Current implementation does not take into account the columns affinity nor if table is Strict impl Arbitrary for ast::Literal { - fn arbitrary(rng: &mut R, _t: &C) -> Self { + fn arbitrary(rng: &mut R, _t: &C) -> Self { loop { let choice = rng.random_range(0..5); let lit = match choice { @@ -284,7 +288,7 @@ impl Arbitrary for ast::Literal { // Creates a litreal value impl ArbitraryFrom<&Vec<&SimValue>> for ast::Expr { - fn arbitrary_from( + fn arbitrary_from( rng: &mut R, _context: &C, values: &Vec<&SimValue>, @@ -299,7 +303,7 @@ impl ArbitraryFrom<&Vec<&SimValue>> for ast::Expr { } impl Arbitrary for UnaryOperator { - fn arbitrary(rng: &mut R, _t: &C) -> Self { + fn arbitrary(rng: &mut R, _t: &C) -> Self { let choice = rng.random_range(0..4); match choice { 0 => Self::BitwiseNot, diff --git a/sql_generation/generation/mod.rs b/sql_generation/generation/mod.rs index 25f353673..e67dc482b 100644 --- a/sql_generation/generation/mod.rs +++ b/sql_generation/generation/mod.rs @@ -8,6 +8,7 @@ pub mod opts; pub mod predicate; pub mod query; pub mod table; +pub mod value; pub use opts::*; @@ -19,7 +20,7 @@ type Choice<'a, R, T> = (usize, Box Option + 'a>); /// the possible values of the type, with a bias towards smaller values for /// practicality. pub trait Arbitrary { - fn arbitrary(rng: &mut R, context: &C) -> Self; + fn arbitrary(rng: &mut R, context: &C) -> Self; } /// ArbitrarySized trait for generating random values of a specific size @@ -29,8 +30,11 @@ pub trait Arbitrary { /// must fit in the given size. This is useful for generating values that are /// constrained by a specific size, such as integers or strings. pub trait ArbitrarySized { - fn arbitrary_sized(rng: &mut R, context: &C, size: usize) - -> Self; + fn arbitrary_sized( + rng: &mut R, + context: &C, + size: usize, + ) -> Self; } /// ArbitraryFrom trait for generating random values from a given value @@ -39,7 +43,11 @@ pub trait ArbitrarySized { /// such as generating an integer within an interval, or a value that fits in a table, /// or a predicate satisfying a given table row. pub trait ArbitraryFrom { - fn arbitrary_from(rng: &mut R, context: &C, t: T) -> Self; + fn arbitrary_from( + rng: &mut R, + context: &C, + t: T, + ) -> Self; } /// ArbitrarySizedFrom trait for generating random values from a given value @@ -51,7 +59,7 @@ pub trait ArbitraryFrom { /// This is useful for generating values that are constrained by a specific size, /// such as integers or strings, while still being dependent on the given value. pub trait ArbitrarySizedFrom { - fn arbitrary_sized_from( + fn arbitrary_sized_from( rng: &mut R, context: &C, t: T, @@ -61,7 +69,7 @@ pub trait ArbitrarySizedFrom { /// ArbitraryFromMaybe trait for fallibally generating random values from a given value pub trait ArbitraryFromMaybe { - fn arbitrary_from_maybe( + fn arbitrary_from_maybe( rng: &mut R, context: &C, t: T, @@ -77,7 +85,11 @@ pub trait ArbitraryFromMaybe { /// the operations we require for the implementation. // todo: switch to a simpler type signature that can accommodate all integer and float types, which // should be enough for our purposes. -pub fn frequency( +pub fn frequency< + T, + R: Rng + ?Sized, + N: Sum + PartialOrd + Copy + Default + SampleUniform + SubAssign, +>( choices: Vec<(N, ArbitraryFromFunc)>, rng: &mut R, ) -> T { @@ -95,7 +107,7 @@ pub fn frequency(choices: Vec>, rng: &mut R) -> T { +pub fn one_of(choices: Vec>, rng: &mut R) -> T { let index = rng.random_range(0..choices.len()); choices[index](rng) } @@ -103,7 +115,7 @@ pub fn one_of(choices: Vec>, rng: &mut R) -> /// backtrack is a helper function for composing different "failable" generators. /// The function takes a list of functions that return an Option, along with number of retries /// to make before giving up. -pub fn backtrack(mut choices: Vec>, rng: &mut R) -> Option { +pub fn backtrack(mut choices: Vec>, rng: &mut R) -> Option { loop { // If there are no more choices left, we give up let choices_ = choices @@ -129,20 +141,20 @@ pub fn backtrack(mut choices: Vec>, rng: &mut R) -> Opti } /// pick is a helper function for uniformly picking a random element from a slice -pub fn pick<'a, T, R: Rng>(choices: &'a [T], rng: &mut R) -> &'a T { +pub fn pick<'a, T, R: Rng + ?Sized>(choices: &'a [T], rng: &mut R) -> &'a T { let index = rng.random_range(0..choices.len()); &choices[index] } /// pick_index is typically used for picking an index from a slice to later refer to the element /// at that index. -pub fn pick_index(choices: usize, rng: &mut R) -> usize { +pub fn pick_index(choices: usize, rng: &mut R) -> usize { rng.random_range(0..choices) } /// pick_n_unique is a helper function for uniformly picking N unique elements from a range. /// The elements themselves are usize, typically representing indices. -pub fn pick_n_unique( +pub fn pick_n_unique( range: std::ops::Range, n: usize, rng: &mut R, @@ -155,7 +167,7 @@ pub fn pick_n_unique( /// gen_random_text uses `anarchist_readable_name_generator_lib` to generate random /// readable names for tables, columns, text values etc. -pub fn gen_random_text(rng: &mut T) -> String { +pub fn gen_random_text(rng: &mut R) -> String { let big_text = rng.random_ratio(1, 1000); if big_text { // let max_size: u64 = 2 * 1024 * 1024 * 1024; @@ -172,10 +184,10 @@ pub fn gen_random_text(rng: &mut T) -> String { } } -pub fn pick_unique<'a, T: PartialEq>( +pub fn pick_unique<'a, T: PartialEq, R: Rng + ?Sized>( items: &'a [T], count: usize, - rng: &mut impl rand::Rng, + rng: &mut R, ) -> impl Iterator { let mut picked: Vec<&T> = Vec::new(); while picked.len() < count { diff --git a/sql_generation/generation/opts.rs b/sql_generation/generation/opts.rs index 190033748..fcc818bbe 100644 --- a/sql_generation/generation/opts.rs +++ b/sql_generation/generation/opts.rs @@ -93,6 +93,8 @@ pub struct QueryOpts { pub from_clause: FromClauseOpts, #[garde(dive)] pub insert: InsertOpts, + #[garde(dive)] + pub alter_table: AlterTableOpts, } #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Validate)] @@ -198,6 +200,22 @@ impl Default for InsertOpts { } } +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Validate)] +#[serde(deny_unknown_fields)] +pub struct AlterTableOpts { + #[garde(skip)] + pub alter_column: bool, +} + +#[expect(clippy::derivable_impls)] +impl Default for AlterTableOpts { + fn default() -> Self { + Self { + alter_column: Default::default(), + } + } +} + fn range_struct_min( min: T, ) -> impl FnOnce(&Range, &()) -> garde::Result { @@ -217,7 +235,7 @@ fn range_struct_min( } } -#[allow(dead_code)] +#[expect(dead_code)] fn range_struct_max( max: T, ) -> impl FnOnce(&Range, &()) -> garde::Result { diff --git a/sql_generation/generation/predicate/binary.rs b/sql_generation/generation/predicate/binary.rs index 9867a561a..25d813c76 100644 --- a/sql_generation/generation/predicate/binary.rs +++ b/sql_generation/generation/predicate/binary.rs @@ -6,7 +6,7 @@ use crate::{ generation::{ backtrack, one_of, pick, predicate::{CompoundPredicate, SimplePredicate}, - table::{GTValue, LTValue, LikeValue}, + value::{GTValue, LTValue, LikeValue}, ArbitraryFrom, ArbitraryFromMaybe as _, GenerationContext, }, model::{ @@ -16,46 +16,8 @@ use crate::{ }; impl Predicate { - /// Generate an [ast::Expr::Binary] [Predicate] from a column and [SimValue] - pub fn from_column_binary( - rng: &mut R, - context: &C, - column_name: &str, - value: &SimValue, - ) -> Predicate { - let expr = one_of( - vec![ - Box::new(|_| { - Expr::Binary( - Box::new(Expr::Id(ast::Name::exact(column_name.to_string()))), - ast::Operator::Equals, - Box::new(Expr::Literal(value.into())), - ) - }), - Box::new(|rng| { - let gt_value = GTValue::arbitrary_from(rng, context, value).0; - Expr::Binary( - Box::new(Expr::Id(ast::Name::exact(column_name.to_string()))), - ast::Operator::Greater, - Box::new(Expr::Literal(gt_value.into())), - ) - }), - Box::new(|rng| { - let lt_value = LTValue::arbitrary_from(rng, context, value).0; - Expr::Binary( - Box::new(Expr::Id(ast::Name::exact(column_name.to_string()))), - ast::Operator::Less, - Box::new(Expr::Literal(lt_value.into())), - ) - }), - ], - rng, - ); - Predicate(expr) - } - /// Produces a true [ast::Expr::Binary] [Predicate] that is true for the provided row in the given table - pub fn true_binary( + pub fn true_binary( rng: &mut R, context: &C, t: &Table, @@ -117,7 +79,8 @@ impl Predicate { ( 1, Box::new(|rng| { - let lt_value = LTValue::arbitrary_from(rng, context, value).0; + let lt_value = + LTValue::arbitrary_from(rng, context, (value, column.column_type)).0; Some(Expr::Binary( Box::new(ast::Expr::Qualified( ast::Name::from_string(&table_name), @@ -131,7 +94,8 @@ impl Predicate { ( 1, Box::new(|rng| { - let gt_value = GTValue::arbitrary_from(rng, context, value).0; + let gt_value = + GTValue::arbitrary_from(rng, context, (value, column.column_type)).0; Some(Expr::Binary( Box::new(ast::Expr::Qualified( ast::Name::from_string(&table_name), @@ -168,7 +132,7 @@ impl Predicate { } /// Produces an [ast::Expr::Binary] [Predicate] that is false for the provided row in the given table - pub fn false_binary( + pub fn false_binary( rng: &mut R, context: &C, t: &Table, @@ -223,7 +187,8 @@ impl Predicate { ) }), Box::new(|rng| { - let gt_value = GTValue::arbitrary_from(rng, context, value).0; + let gt_value = + GTValue::arbitrary_from(rng, context, (value, column.column_type)).0; Expr::Binary( Box::new(ast::Expr::Qualified( ast::Name::from_string(&table_name), @@ -234,7 +199,8 @@ impl Predicate { ) }), Box::new(|rng| { - let lt_value = LTValue::arbitrary_from(rng, context, value).0; + let lt_value = + LTValue::arbitrary_from(rng, context, (value, column.column_type)).0; Expr::Binary( Box::new(ast::Expr::Qualified( ast::Name::from_string(&table_name), @@ -253,7 +219,7 @@ impl Predicate { impl SimplePredicate { /// Generates a true [ast::Expr::Binary] [SimplePredicate] from a [TableContext] for a row in the table - pub fn true_binary( + pub fn true_binary( rng: &mut R, context: &C, table: &T, @@ -283,7 +249,12 @@ impl SimplePredicate { ) }), Box::new(|rng| { - let lt_value = LTValue::arbitrary_from(rng, context, column_value).0; + let lt_value = LTValue::arbitrary_from( + rng, + context, + (column_value, column.column.column_type), + ) + .0; Expr::Binary( Box::new(Expr::Qualified( ast::Name::from_string(table_name), @@ -294,7 +265,12 @@ impl SimplePredicate { ) }), Box::new(|rng| { - let gt_value = GTValue::arbitrary_from(rng, context, column_value).0; + let gt_value = GTValue::arbitrary_from( + rng, + context, + (column_value, column.column.column_type), + ) + .0; Expr::Binary( Box::new(Expr::Qualified( ast::Name::from_string(table_name), @@ -311,7 +287,7 @@ impl SimplePredicate { } /// Generates a false [ast::Expr::Binary] [SimplePredicate] from a [TableContext] for a row in the table - pub fn false_binary( + pub fn false_binary( rng: &mut R, context: &C, table: &T, @@ -341,7 +317,12 @@ impl SimplePredicate { ) }), Box::new(|rng| { - let gt_value = GTValue::arbitrary_from(rng, context, column_value).0; + let gt_value = GTValue::arbitrary_from( + rng, + context, + (column_value, column.column.column_type), + ) + .0; Expr::Binary( Box::new(ast::Expr::Qualified( ast::Name::from_string(table_name), @@ -352,7 +333,12 @@ impl SimplePredicate { ) }), Box::new(|rng| { - let lt_value = LTValue::arbitrary_from(rng, context, column_value).0; + let lt_value = LTValue::arbitrary_from( + rng, + context, + (column_value, column.column.column_type), + ) + .0; Expr::Binary( Box::new(ast::Expr::Qualified( ast::Name::from_string(table_name), @@ -373,7 +359,7 @@ impl CompoundPredicate { /// Decide if you want to create an AND or an OR /// /// Creates a Compound Predicate that is TRUE or FALSE for at least a single row - pub fn from_table_binary( + pub fn from_table_binary( rng: &mut R, context: &C, table: &T, diff --git a/sql_generation/generation/predicate/mod.rs b/sql_generation/generation/predicate/mod.rs index 78fa30ae4..75546848a 100644 --- a/sql_generation/generation/predicate/mod.rs +++ b/sql_generation/generation/predicate/mod.rs @@ -21,7 +21,7 @@ struct CompoundPredicate(Predicate); struct SimplePredicate(Predicate); impl, T: TableContext> ArbitraryFrom<(&T, A, bool)> for SimplePredicate { - fn arbitrary_from( + fn arbitrary_from( rng: &mut R, context: &C, (table, row, predicate_value): (&T, A, bool), @@ -46,7 +46,7 @@ impl, T: TableContext> ArbitraryFrom<(&T, A, bool)> for Sim } impl ArbitraryFrom<(&T, bool)> for CompoundPredicate { - fn arbitrary_from( + fn arbitrary_from( rng: &mut R, context: &C, (table, predicate_value): (&T, bool), @@ -56,14 +56,18 @@ impl ArbitraryFrom<(&T, bool)> for CompoundPredicate { } impl ArbitraryFrom<&T> for Predicate { - fn arbitrary_from(rng: &mut R, context: &C, table: &T) -> Self { + fn arbitrary_from( + rng: &mut R, + context: &C, + table: &T, + ) -> Self { let predicate_value = rng.random_bool(0.5); Predicate::arbitrary_from(rng, context, (table, predicate_value)).parens() } } impl ArbitraryFrom<(&T, bool)> for Predicate { - fn arbitrary_from( + fn arbitrary_from( rng: &mut R, context: &C, (table, predicate_value): (&T, bool), @@ -72,18 +76,8 @@ impl ArbitraryFrom<(&T, bool)> for Predicate { } } -impl ArbitraryFrom<(&str, &SimValue)> for Predicate { - fn arbitrary_from( - rng: &mut R, - context: &C, - (column_name, value): (&str, &SimValue), - ) -> Self { - Predicate::from_column_binary(rng, context, column_name, value) - } -} - impl ArbitraryFrom<(&Table, &Vec)> for Predicate { - fn arbitrary_from( + fn arbitrary_from( rng: &mut R, context: &C, (t, row): (&Table, &Vec), diff --git a/sql_generation/generation/predicate/unary.rs b/sql_generation/generation/predicate/unary.rs index 1cc0e0d24..31dfc2a7a 100644 --- a/sql_generation/generation/predicate/unary.rs +++ b/sql_generation/generation/predicate/unary.rs @@ -17,7 +17,7 @@ use crate::{ pub struct TrueValue(pub SimValue); impl ArbitraryFromMaybe<&SimValue> for TrueValue { - fn arbitrary_from_maybe( + fn arbitrary_from_maybe( _rng: &mut R, _context: &C, value: &SimValue, @@ -31,7 +31,7 @@ impl ArbitraryFromMaybe<&SimValue> for TrueValue { } impl ArbitraryFromMaybe<&Vec<&SimValue>> for TrueValue { - fn arbitrary_from_maybe( + fn arbitrary_from_maybe( rng: &mut R, context: &C, values: &Vec<&SimValue>, @@ -51,7 +51,7 @@ impl ArbitraryFromMaybe<&Vec<&SimValue>> for TrueValue { pub struct FalseValue(pub SimValue); impl ArbitraryFromMaybe<&SimValue> for FalseValue { - fn arbitrary_from_maybe( + fn arbitrary_from_maybe( _rng: &mut R, _context: &C, value: &SimValue, @@ -65,7 +65,7 @@ impl ArbitraryFromMaybe<&SimValue> for FalseValue { } impl ArbitraryFromMaybe<&Vec<&SimValue>> for FalseValue { - fn arbitrary_from_maybe( + fn arbitrary_from_maybe( rng: &mut R, context: &C, values: &Vec<&SimValue>, @@ -86,7 +86,7 @@ impl ArbitraryFromMaybe<&Vec<&SimValue>> for FalseValue { pub struct BitNotValue(pub SimValue); impl ArbitraryFromMaybe<(&SimValue, bool)> for BitNotValue { - fn arbitrary_from_maybe( + fn arbitrary_from_maybe( _rng: &mut R, _context: &C, (value, predicate): (&SimValue, bool), @@ -101,7 +101,7 @@ impl ArbitraryFromMaybe<(&SimValue, bool)> for BitNotValue { } impl ArbitraryFromMaybe<(&Vec<&SimValue>, bool)> for BitNotValue { - fn arbitrary_from_maybe( + fn arbitrary_from_maybe( rng: &mut R, context: &C, (values, predicate): (&Vec<&SimValue>, bool), @@ -121,7 +121,7 @@ impl ArbitraryFromMaybe<(&Vec<&SimValue>, bool)> for BitNotValue { // TODO: have some more complex generation with columns names here as well impl SimplePredicate { /// Generates a true [ast::Expr::Unary] [SimplePredicate] from a [TableContext] for some values in the table - pub fn true_unary( + pub fn true_unary( rng: &mut R, context: &C, _table: &T, @@ -187,7 +187,7 @@ impl SimplePredicate { } /// Generates a false [ast::Expr::Unary] [SimplePredicate] from a [TableContext] for a row in the table - pub fn false_unary( + pub fn false_unary( rng: &mut R, context: &C, _table: &T, diff --git a/sql_generation/generation/query.rs b/sql_generation/generation/query.rs index a0e0e47b0..c4be7f2d8 100644 --- a/sql_generation/generation/query.rs +++ b/sql_generation/generation/query.rs @@ -1,24 +1,28 @@ use crate::generation::{ - gen_random_text, pick_n_unique, pick_unique, Arbitrary, ArbitraryFrom, ArbitrarySized, - GenerationContext, + gen_random_text, pick_index, pick_n_unique, pick_unique, Arbitrary, ArbitraryFrom, + ArbitrarySized, GenerationContext, }; +use crate::model::query::alter_table::{AlterTable, AlterTableType, AlterTableTypeDiscriminants}; use crate::model::query::predicate::Predicate; use crate::model::query::select::{ CompoundOperator, CompoundSelect, Distinctness, FromClause, OrderBy, ResultColumn, SelectBody, SelectInner, }; use crate::model::query::update::Update; -use crate::model::query::{Create, CreateIndex, Delete, Drop, Insert, Select}; -use crate::model::table::{JoinTable, JoinType, JoinedTable, SimValue, Table, TableContext}; +use crate::model::query::{Create, CreateIndex, Delete, Drop, DropIndex, Insert, Select}; +use crate::model::table::{ + Column, Index, JoinTable, JoinType, JoinedTable, Name, SimValue, Table, TableContext, +}; use indexmap::IndexSet; use itertools::Itertools; +use rand::seq::IndexedRandom; use rand::Rng; use turso_parser::ast::{Expr, SortOrder}; use super::{backtrack, pick}; impl Arbitrary for Create { - fn arbitrary(rng: &mut R, context: &C) -> Self { + fn arbitrary(rng: &mut R, context: &C) -> Self { Create { table: Table::arbitrary(rng, context), } @@ -26,7 +30,7 @@ impl Arbitrary for Create { } impl Arbitrary for FromClause { - fn arbitrary(rng: &mut R, context: &C) -> Self { + fn arbitrary(rng: &mut R, context: &C) -> Self { let opts = &context.opts().query.from_clause; let weights = opts.as_weighted_index(); let num_joins = opts.joins[rng.sample(weights)].num_joins; @@ -85,7 +89,7 @@ impl Arbitrary for FromClause { } impl Arbitrary for SelectInner { - fn arbitrary(rng: &mut R, env: &C) -> Self { + fn arbitrary(rng: &mut R, env: &C) -> Self { let from = FromClause::arbitrary(rng, env); let tables = env.tables().clone(); let join_table = from.into_join_table(&tables); @@ -144,7 +148,7 @@ impl Arbitrary for SelectInner { } impl ArbitrarySized for SelectInner { - fn arbitrary_sized( + fn arbitrary_sized( rng: &mut R, env: &C, num_result_columns: usize, @@ -179,7 +183,7 @@ impl ArbitrarySized for SelectInner { } impl Arbitrary for Distinctness { - fn arbitrary(rng: &mut R, _context: &C) -> Self { + fn arbitrary(rng: &mut R, _context: &C) -> Self { match rng.random_range(0..=5) { 0..4 => Distinctness::All, _ => Distinctness::Distinct, @@ -188,7 +192,7 @@ impl Arbitrary for Distinctness { } impl Arbitrary for CompoundOperator { - fn arbitrary(rng: &mut R, _context: &C) -> Self { + fn arbitrary(rng: &mut R, _context: &C) -> Self { match rng.random_range(0..=1) { 0 => CompoundOperator::Union, 1 => CompoundOperator::UnionAll, @@ -203,7 +207,7 @@ impl Arbitrary for CompoundOperator { pub struct SelectFree(pub Select); impl Arbitrary for SelectFree { - fn arbitrary(rng: &mut R, env: &C) -> Self { + fn arbitrary(rng: &mut R, env: &C) -> Self { let expr = Predicate(Expr::arbitrary_sized(rng, env, 8)); let select = Select::expr(expr); Self(select) @@ -211,7 +215,7 @@ impl Arbitrary for SelectFree { } impl Arbitrary for Select { - fn arbitrary(rng: &mut R, env: &C) -> Self { + fn arbitrary(rng: &mut R, env: &C) -> Self { // Generate a number of selects based on the query size // If experimental indexes are enabled, we can have selects with compounds // Otherwise, we just have a single select with no compounds @@ -259,7 +263,7 @@ impl Arbitrary for Select { } impl Arbitrary for Insert { - fn arbitrary(rng: &mut R, env: &C) -> Self { + fn arbitrary(rng: &mut R, env: &C) -> Self { let opts = &env.opts().query.insert; let gen_values = |rng: &mut R| { let table = pick(env.tables(), rng); @@ -300,7 +304,7 @@ impl Arbitrary for Insert { } impl Arbitrary for Delete { - fn arbitrary(rng: &mut R, env: &C) -> Self { + fn arbitrary(rng: &mut R, env: &C) -> Self { let table = pick(env.tables(), rng); Self { table: table.name.clone(), @@ -310,7 +314,7 @@ impl Arbitrary for Delete { } impl Arbitrary for Drop { - fn arbitrary(rng: &mut R, env: &C) -> Self { + fn arbitrary(rng: &mut R, env: &C) -> Self { let table = pick(env.tables(), rng); Self { table: table.name.clone(), @@ -319,7 +323,7 @@ impl Arbitrary for Drop { } impl Arbitrary for CreateIndex { - fn arbitrary(rng: &mut R, env: &C) -> Self { + fn arbitrary(rng: &mut R, env: &C) -> Self { assert!( !env.tables().is_empty(), "Cannot create an index when no tables exist in the environment." @@ -358,15 +362,17 @@ impl Arbitrary for CreateIndex { ); CreateIndex { - index_name, - table_name: table.name.clone(), - columns, + index: Index { + index_name, + table_name: table.name.clone(), + columns, + }, } } } impl Arbitrary for Update { - fn arbitrary(rng: &mut R, env: &C) -> Self { + fn arbitrary(rng: &mut R, env: &C) -> Self { let table = pick(env.tables(), rng); let num_cols = rng.random_range(1..=table.columns.len()); let columns = pick_unique(&table.columns, num_cols, rng); @@ -385,3 +391,166 @@ impl Arbitrary for Update { } } } + +const ALTER_TABLE_ALL: &[AlterTableTypeDiscriminants] = &[ + AlterTableTypeDiscriminants::RenameTo, + AlterTableTypeDiscriminants::AddColumn, + AlterTableTypeDiscriminants::AlterColumn, + AlterTableTypeDiscriminants::RenameColumn, + AlterTableTypeDiscriminants::DropColumn, +]; +const ALTER_TABLE_NO_DROP: &[AlterTableTypeDiscriminants] = &[ + AlterTableTypeDiscriminants::RenameTo, + AlterTableTypeDiscriminants::AddColumn, + AlterTableTypeDiscriminants::AlterColumn, + AlterTableTypeDiscriminants::RenameColumn, +]; +const ALTER_TABLE_NO_ALTER_COL: &[AlterTableTypeDiscriminants] = &[ + AlterTableTypeDiscriminants::RenameTo, + AlterTableTypeDiscriminants::AddColumn, + AlterTableTypeDiscriminants::RenameColumn, + AlterTableTypeDiscriminants::DropColumn, +]; +const ALTER_TABLE_NO_ALTER_COL_NO_DROP: &[AlterTableTypeDiscriminants] = &[ + AlterTableTypeDiscriminants::RenameTo, + AlterTableTypeDiscriminants::AddColumn, + AlterTableTypeDiscriminants::RenameColumn, +]; + +// TODO: Unfortunately this diff strategy allocates a couple of IndexSet's +// in the future maybe change this to be more efficient. This is currently acceptable because this function +// is only called for `DropColumn` +fn get_column_diff(table: &Table) -> IndexSet<&str> { + // Columns that are referenced in INDEXES cannot be dropped + let column_cannot_drop = table + .indexes + .iter() + .flat_map(|index| index.columns.iter().map(|(col_name, _)| col_name.as_str())) + .collect::>(); + if column_cannot_drop.len() == table.columns.len() { + // Optimization: all columns are present in indexes so we do not need to but the table column set + return IndexSet::new(); + } + + let column_set: IndexSet<_, std::hash::RandomState> = + IndexSet::from_iter(table.columns.iter().map(|col| col.name.as_str())); + + let diff = column_set + .difference(&column_cannot_drop) + .copied() + .collect::>(); + diff +} + +impl ArbitraryFrom<(&Table, &[AlterTableTypeDiscriminants])> for AlterTableType { + fn arbitrary_from( + rng: &mut R, + context: &C, + (table, choices): (&Table, &[AlterTableTypeDiscriminants]), + ) -> Self { + match choices.choose(rng).unwrap() { + AlterTableTypeDiscriminants::RenameTo => AlterTableType::RenameTo { + new_name: Name::arbitrary(rng, context).0, + }, + AlterTableTypeDiscriminants::AddColumn => AlterTableType::AddColumn { + column: Column::arbitrary(rng, context), + }, + AlterTableTypeDiscriminants::AlterColumn => { + let col_diff = get_column_diff(table); + + if col_diff.is_empty() { + // Generate a DropColumn if we can drop a column + return AlterTableType::arbitrary_from( + rng, + context, + ( + table, + if choices.contains(&AlterTableTypeDiscriminants::DropColumn) { + ALTER_TABLE_NO_ALTER_COL + } else { + ALTER_TABLE_NO_ALTER_COL_NO_DROP + }, + ), + ); + } + + let col_idx = pick_index(col_diff.len(), rng); + let col_name = col_diff.get_index(col_idx).unwrap(); + + AlterTableType::AlterColumn { + old: col_name.to_string(), + new: Column::arbitrary(rng, context), + } + } + AlterTableTypeDiscriminants::RenameColumn => AlterTableType::RenameColumn { + old: pick(&table.columns, rng).name.clone(), + new: Name::arbitrary(rng, context).0, + }, + AlterTableTypeDiscriminants::DropColumn => { + let col_diff = get_column_diff(table); + + if col_diff.is_empty() { + // Generate a DropColumn if we can drop a column + return AlterTableType::arbitrary_from( + rng, + context, + ( + table, + if context.opts().query.alter_table.alter_column { + ALTER_TABLE_NO_DROP + } else { + ALTER_TABLE_NO_ALTER_COL_NO_DROP + }, + ), + ); + } + + let col_idx = pick_index(col_diff.len(), rng); + let col_name = col_diff.get_index(col_idx).unwrap(); + + AlterTableType::DropColumn { + column_name: col_name.to_string(), + } + } + } + } +} + +impl Arbitrary for AlterTable { + fn arbitrary(rng: &mut R, context: &C) -> Self { + let table = pick(context.tables(), rng); + let choices = match ( + table.columns.len() > 1, + context.opts().query.alter_table.alter_column, + ) { + (true, true) => ALTER_TABLE_ALL, + (true, false) => ALTER_TABLE_NO_ALTER_COL, + (false, true) | (false, false) => ALTER_TABLE_NO_ALTER_COL_NO_DROP, + }; + + let alter_table_type = AlterTableType::arbitrary_from(rng, context, (table, choices)); + Self { + table_name: table.name.clone(), + alter_table_type, + } + } +} + +impl Arbitrary for DropIndex { + fn arbitrary(rng: &mut R, context: &C) -> Self { + let tables_with_indexes = context + .tables() + .iter() + .filter(|table| !table.indexes.is_empty()) + .collect::>(); + + // Cannot DROP INDEX if there is no index to drop + assert!(!tables_with_indexes.is_empty()); + let table = tables_with_indexes.choose(rng).unwrap(); + let index = table.indexes.choose(rng).unwrap(); + Self { + index_name: index.index_name.clone(), + table_name: table.name.clone(), + } + } +} diff --git a/sql_generation/generation/table.rs b/sql_generation/generation/table.rs index ce0ff97f4..e8b722180 100644 --- a/sql_generation/generation/table.rs +++ b/sql_generation/generation/table.rs @@ -1,36 +1,48 @@ +use std::sync::atomic::{AtomicU64, Ordering}; + use indexmap::IndexSet; use rand::Rng; -use turso_core::Value; -use crate::generation::{ - gen_random_text, pick, readable_name_custom, Arbitrary, ArbitraryFrom, GenerationContext, -}; -use crate::model::table::{Column, ColumnType, Name, SimValue, Table}; +use crate::generation::{pick, readable_name_custom, Arbitrary, GenerationContext}; +use crate::model::table::{Column, ColumnType, Name, Table}; -use super::ArbitraryFromMaybe; +static COUNTER: AtomicU64 = AtomicU64::new(0); impl Arbitrary for Name { - fn arbitrary(rng: &mut R, _c: &C) -> Self { - let name = readable_name_custom("_", rng); - Name(name.replace("-", "_")) + fn arbitrary(rng: &mut R, _c: &C) -> Self { + let base = readable_name_custom("_", rng).replace("-", "_"); + let id = COUNTER.fetch_add(1, Ordering::Relaxed); + Name(format!("{base}_{id}")) } } -impl Arbitrary for Table { - fn arbitrary(rng: &mut R, context: &C) -> Self { +impl Table { + /// Generate a table with some predefined columns + pub fn arbitrary_with_columns( + rng: &mut R, + context: &C, + name: String, + predefined_columns: Vec, + ) -> Self { let opts = context.opts().table.clone(); - let name = Name::arbitrary(rng, context).0; let large_table = opts.large_table.enable && rng.random_bool(opts.large_table.large_table_prob); - let column_size = if large_table { + let target_column_size = if large_table { rng.random_range(opts.large_table.column_range) } else { rng.random_range(opts.column_range) } as usize; - let mut column_set = IndexSet::with_capacity(column_size); + + // Start with predefined columns + let mut column_set = IndexSet::with_capacity(target_column_size); + for col in predefined_columns { + column_set.insert(col); + } + + // Generate additional columns to reach target size for col in std::iter::repeat_with(|| Column::arbitrary(rng, context)) { column_set.insert(col); - if column_set.len() == column_size { + if column_set.len() >= target_column_size { break; } } @@ -44,244 +56,27 @@ impl Arbitrary for Table { } } +impl Arbitrary for Table { + fn arbitrary(rng: &mut R, context: &C) -> Self { + let name = Name::arbitrary(rng, context).0; + Table::arbitrary_with_columns(rng, context, name, vec![]) + } +} + impl Arbitrary for Column { - fn arbitrary(rng: &mut R, context: &C) -> Self { + fn arbitrary(rng: &mut R, context: &C) -> Self { let name = Name::arbitrary(rng, context).0; let column_type = ColumnType::arbitrary(rng, context); Self { name, column_type, - primary: false, - unique: false, + constraints: vec![], // TODO: later implement arbitrary here for ColumnConstraint } } } impl Arbitrary for ColumnType { - fn arbitrary(rng: &mut R, _context: &C) -> Self { + fn arbitrary(rng: &mut R, _context: &C) -> Self { pick(&[Self::Integer, Self::Float, Self::Text, Self::Blob], rng).to_owned() } } - -impl ArbitraryFrom<&Table> for Vec { - fn arbitrary_from( - rng: &mut R, - context: &C, - table: &Table, - ) -> Self { - let mut row = Vec::new(); - for column in table.columns.iter() { - let value = SimValue::arbitrary_from(rng, context, &column.column_type); - row.push(value); - } - row - } -} - -impl ArbitraryFrom<&Vec<&SimValue>> for SimValue { - fn arbitrary_from( - rng: &mut R, - _context: &C, - values: &Vec<&Self>, - ) -> Self { - if values.is_empty() { - return Self(Value::Null); - } - - pick(values, rng).to_owned().clone() - } -} - -impl ArbitraryFrom<&ColumnType> for SimValue { - fn arbitrary_from( - rng: &mut R, - _context: &C, - column_type: &ColumnType, - ) -> Self { - let value = match column_type { - ColumnType::Integer => Value::Integer(rng.random_range(i64::MIN..i64::MAX)), - ColumnType::Float => Value::Float(rng.random_range(-1e10..1e10)), - ColumnType::Text => Value::build_text(gen_random_text(rng)), - ColumnType::Blob => Value::Blob(gen_random_text(rng).as_bytes().to_vec()), - }; - SimValue(value) - } -} - -pub struct LTValue(pub SimValue); - -impl ArbitraryFrom<&Vec<&SimValue>> for LTValue { - fn arbitrary_from( - rng: &mut R, - context: &C, - values: &Vec<&SimValue>, - ) -> Self { - if values.is_empty() { - return Self(SimValue(Value::Null)); - } - - // Get value less than all values - let value = Value::exec_min(values.iter().map(|value| &value.0)); - Self::arbitrary_from(rng, context, &SimValue(value)) - } -} - -impl ArbitraryFrom<&SimValue> for LTValue { - fn arbitrary_from( - rng: &mut R, - _context: &C, - value: &SimValue, - ) -> Self { - let new_value = match &value.0 { - Value::Integer(i) => Value::Integer(rng.random_range(i64::MIN..*i - 1)), - Value::Float(f) => Value::Float(f - rng.random_range(0.0..1e10)), - value @ Value::Text(..) => { - // Either shorten the string, or make at least one character smaller and mutate the rest - let mut t = value.to_string(); - if rng.random_bool(0.01) { - t.pop(); - Value::build_text(t) - } else { - let mut t = t.chars().map(|c| c as u32).collect::>(); - let index = rng.random_range(0..t.len()); - t[index] -= 1; - // Mutate the rest of the string - for val in t.iter_mut().skip(index + 1) { - *val = rng.random_range('a' as u32..='z' as u32); - } - let t = t - .into_iter() - .map(|c| char::from_u32(c).unwrap_or('z')) - .collect::(); - Value::build_text(t) - } - } - Value::Blob(b) => { - // Either shorten the blob, or make at least one byte smaller and mutate the rest - let mut b = b.clone(); - if rng.random_bool(0.01) { - b.pop(); - Value::Blob(b) - } else { - let index = rng.random_range(0..b.len()); - b[index] -= 1; - // Mutate the rest of the blob - for val in b.iter_mut().skip(index + 1) { - *val = rng.random_range(0..=255); - } - Value::Blob(b) - } - } - _ => unreachable!(), - }; - Self(SimValue(new_value)) - } -} - -pub struct GTValue(pub SimValue); - -impl ArbitraryFrom<&Vec<&SimValue>> for GTValue { - fn arbitrary_from( - rng: &mut R, - context: &C, - values: &Vec<&SimValue>, - ) -> Self { - if values.is_empty() { - return Self(SimValue(Value::Null)); - } - // Get value greater than all values - let value = Value::exec_max(values.iter().map(|value| &value.0)); - - Self::arbitrary_from(rng, context, &SimValue(value)) - } -} - -impl ArbitraryFrom<&SimValue> for GTValue { - fn arbitrary_from( - rng: &mut R, - _context: &C, - value: &SimValue, - ) -> Self { - let new_value = match &value.0 { - Value::Integer(i) => Value::Integer(rng.random_range(*i..i64::MAX)), - Value::Float(f) => Value::Float(rng.random_range(*f..1e10)), - value @ Value::Text(..) => { - // Either lengthen the string, or make at least one character smaller and mutate the rest - let mut t = value.to_string(); - if rng.random_bool(0.01) { - t.push(rng.random_range(0..=255) as u8 as char); - Value::build_text(t) - } else { - let mut t = t.chars().map(|c| c as u32).collect::>(); - let index = rng.random_range(0..t.len()); - t[index] += 1; - // Mutate the rest of the string - for val in t.iter_mut().skip(index + 1) { - *val = rng.random_range('a' as u32..='z' as u32); - } - let t = t - .into_iter() - .map(|c| char::from_u32(c).unwrap_or('a')) - .collect::(); - Value::build_text(t) - } - } - Value::Blob(b) => { - // Either lengthen the blob, or make at least one byte smaller and mutate the rest - let mut b = b.clone(); - if rng.random_bool(0.01) { - b.push(rng.random_range(0..=255)); - Value::Blob(b) - } else { - let index = rng.random_range(0..b.len()); - b[index] += 1; - // Mutate the rest of the blob - for val in b.iter_mut().skip(index + 1) { - *val = rng.random_range(0..=255); - } - Value::Blob(b) - } - } - _ => unreachable!(), - }; - Self(SimValue(new_value)) - } -} - -pub struct LikeValue(pub SimValue); - -impl ArbitraryFromMaybe<&SimValue> for LikeValue { - fn arbitrary_from_maybe( - rng: &mut R, - _context: &C, - value: &SimValue, - ) -> Option { - match &value.0 { - value @ Value::Text(..) => { - let t = value.to_string(); - let mut t = t.chars().collect::>(); - // Remove a number of characters, either insert `_` for each character removed, or - // insert one `%` for the whole substring - let mut i = 0; - while i < t.len() { - if rng.random_bool(0.1) { - t[i] = '_'; - } else if rng.random_bool(0.05) { - t[i] = '%'; - // skip a list of characters - for _ in 0..rng.random_range(0..=3.min(t.len() - i - 1)) { - t.remove(i + 1); - } - } - i += 1; - } - let index = rng.random_range(0..t.len()); - t.insert(index, '%'); - Some(Self(SimValue(Value::build_text( - t.into_iter().collect::(), - )))) - } - _ => None, - } - } -} diff --git a/sql_generation/generation/value/cmp.rs b/sql_generation/generation/value/cmp.rs new file mode 100644 index 000000000..400cfc14c --- /dev/null +++ b/sql_generation/generation/value/cmp.rs @@ -0,0 +1,183 @@ +use turso_core::Value; + +use crate::{ + generation::{ArbitraryFrom, GenerationContext}, + model::table::{ColumnType, SimValue}, +}; + +pub struct LTValue(pub SimValue); + +impl ArbitraryFrom<(&SimValue, ColumnType)> for LTValue { + fn arbitrary_from( + rng: &mut R, + _context: &C, + (value, _col_type): (&SimValue, ColumnType), + ) -> Self { + let new_value = match &value.0 { + Value::Integer(i) => Value::Integer(rng.random_range(i64::MIN..*i - 1)), + Value::Float(f) => Value::Float(f - rng.random_range(0.0..1e10)), + value @ Value::Text(..) => { + // Either shorten the string, or make at least one character smaller and mutate the rest + let mut t = value.to_string(); + if rng.random_bool(0.01) { + t.pop(); + Value::build_text(t) + } else { + Value::build_text(mutate_string(&t, rng, MutationType::Decrement)) + } + } + Value::Blob(b) => { + // Either shorten the blob, or make at least one byte smaller and mutate the rest + let mut b = b.clone(); + if rng.random_bool(0.01) { + b.pop(); + Value::Blob(b) + } else { + let index = rng.random_range(0..b.len()); + b[index] -= 1; + // Mutate the rest of the blob + for val in b.iter_mut().skip(index + 1) { + *val = rng.random_range(0..=255); + } + Value::Blob(b) + } + } + // A value with storage class NULL is considered less than any other value (including another value with storage class NULL) + Value::Null => Value::Null, + }; + Self(SimValue(new_value)) + } +} + +pub struct GTValue(pub SimValue); + +impl ArbitraryFrom<(&SimValue, ColumnType)> for GTValue { + fn arbitrary_from( + rng: &mut R, + context: &C, + (value, col_type): (&SimValue, ColumnType), + ) -> Self { + let new_value = match &value.0 { + Value::Integer(i) => Value::Integer(rng.random_range(*i..i64::MAX)), + Value::Float(f) => Value::Float(rng.random_range(*f..1e10)), + value @ Value::Text(..) => { + // Either lengthen the string, or make at least one character smaller and mutate the rest + let mut t = value.to_string(); + if rng.random_bool(0.01) { + if rng.random_bool(0.5) { + t.push(rng.random_range(UPPERCASE_A..=UPPERCASE_Z) as u8 as char); + } else { + t.push(rng.random_range(LOWERCASE_A..=LOWERCASE_Z) as u8 as char); + } + Value::build_text(t) + } else { + Value::build_text(mutate_string(&t, rng, MutationType::Increment)) + } + } + Value::Blob(b) => { + // Either lengthen the blob, or make at least one byte smaller and mutate the rest + let mut b = b.clone(); + if rng.random_bool(0.01) { + b.push(rng.random_range(0..=255)); + Value::Blob(b) + } else { + let index = rng.random_range(0..b.len()); + b[index] += 1; + // Mutate the rest of the blob + for val in b.iter_mut().skip(index + 1) { + *val = rng.random_range(0..=255); + } + Value::Blob(b) + } + } + Value::Null => { + // Any value is greater than NULL, except NULL + SimValue::arbitrary_from(rng, context, col_type).0 + } + }; + Self(SimValue(new_value)) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +enum MutationType { + Decrement, + Increment, +} + +const UPPERCASE_A: u32 = 'A' as u32; +const UPPERCASE_Z: u32 = 'Z' as u32; +const LOWERCASE_A: u32 = 'a' as u32; +const LOWERCASE_Z: u32 = 'z' as u32; + +fn mutate_string( + t: &str, + rng: &mut R, + mutation_type: MutationType, +) -> String { + let mut chars = t.chars().map(|c| c as u32).collect::>(); + let mut index; + let mut max_loops = 100; + loop { + index = rng.random_range(0..chars.len()); + if chars[index] > UPPERCASE_A && chars[index] < UPPERCASE_Z + || chars[index] > LOWERCASE_A && chars[index] < LOWERCASE_Z + { + break; + } + max_loops -= 1; + if max_loops == 0 { + panic!("Failed to find a printable character to decrement"); + } + } + + if mutation_type == MutationType::Decrement { + chars[index] -= 1; + } else { + chars[index] += 1; + } + + // Mutate the rest of the string with printable ASCII characters + for val in chars.iter_mut().skip(index + 1) { + if rng.random_bool(0.5) { + *val = rng.random_range(UPPERCASE_A..=UPPERCASE_Z); + } else { + *val = rng.random_range(LOWERCASE_A..=LOWERCASE_Z); + } + } + + chars + .into_iter() + .map(|c| char::from_u32(c).unwrap()) + .collect::() +} + +#[cfg(test)] +mod tests { + use anarchist_readable_name_generator_lib::readable_name; + + use super::*; + + #[test] + fn test_mutate_string_fuzz() { + let mut rng = rand::rng(); + for _ in 0..1000 { + let mut t = readable_name(); + while !t.is_ascii() { + t = readable_name(); + } + let t2 = mutate_string(&t, &mut rng, MutationType::Decrement); + assert!(t2.is_ascii(), "{}", t); + assert!(t2 < t); + } + for _ in 0..1000 { + let mut t = readable_name(); + while !t.is_ascii() { + t = readable_name(); + } + let t2 = mutate_string(&t, &mut rng, MutationType::Increment); + assert!(t2.is_ascii(), "{}", t); + assert!(t2 > t); + } + } +} diff --git a/sql_generation/generation/value/mod.rs b/sql_generation/generation/value/mod.rs new file mode 100644 index 000000000..5062c9e57 --- /dev/null +++ b/sql_generation/generation/value/mod.rs @@ -0,0 +1,68 @@ +use rand::Rng; +use turso_core::Value; + +use crate::{ + generation::{gen_random_text, pick, ArbitraryFrom, GenerationContext}, + model::table::{ColumnType, SimValue, Table}, +}; + +mod cmp; +mod pattern; + +pub use cmp::{GTValue, LTValue}; +pub use pattern::LikeValue; + +impl ArbitraryFrom<&Table> for Vec { + fn arbitrary_from( + rng: &mut R, + context: &C, + table: &Table, + ) -> Self { + let mut row = Vec::new(); + for column in table.columns.iter() { + let value = SimValue::arbitrary_from(rng, context, &column.column_type); + row.push(value); + } + row + } +} + +impl ArbitraryFrom<&Vec<&SimValue>> for SimValue { + fn arbitrary_from( + rng: &mut R, + _context: &C, + values: &Vec<&Self>, + ) -> Self { + if values.is_empty() { + return Self(Value::Null); + } + + pick(values, rng).to_owned().clone() + } +} + +impl ArbitraryFrom<&ColumnType> for SimValue { + fn arbitrary_from( + rng: &mut R, + _context: &C, + column_type: &ColumnType, + ) -> Self { + let value = match column_type { + ColumnType::Integer => Value::Integer(rng.random_range(i64::MIN..i64::MAX)), + ColumnType::Float => Value::Float(rng.random_range(-1e10..1e10)), + ColumnType::Text => Value::build_text(gen_random_text(rng)), + ColumnType::Blob => Value::Blob(gen_random_text(rng).into_bytes()), + }; + SimValue(value) + } +} + +impl ArbitraryFrom for SimValue { + fn arbitrary_from( + rng: &mut R, + context: &C, + column_type: ColumnType, + ) -> Self { + SimValue::arbitrary_from(rng, context, &column_type) + } +} diff --git a/sql_generation/generation/value/pattern.rs b/sql_generation/generation/value/pattern.rs new file mode 100644 index 000000000..3bf0d7a9f --- /dev/null +++ b/sql_generation/generation/value/pattern.rs @@ -0,0 +1,44 @@ +use turso_core::Value; + +use crate::{ + generation::{ArbitraryFromMaybe, GenerationContext}, + model::table::SimValue, +}; + +pub struct LikeValue(pub SimValue); + +impl ArbitraryFromMaybe<&SimValue> for LikeValue { + fn arbitrary_from_maybe( + rng: &mut R, + _context: &C, + value: &SimValue, + ) -> Option { + match &value.0 { + value @ Value::Text(..) => { + let t = value.to_string(); + let mut t = t.chars().collect::>(); + // Remove a number of characters, either insert `_` for each character removed, or + // insert one `%` for the whole substring + let mut i = 0; + while i < t.len() { + if rng.random_bool(0.1) { + t[i] = '_'; + } else if rng.random_bool(0.05) { + t[i] = '%'; + // skip a list of characters + for _ in 0..rng.random_range(0..=3.min(t.len() - i - 1)) { + t.remove(i + 1); + } + } + i += 1; + } + let index = rng.random_range(0..t.len()); + t.insert(index, '%'); + Some(Self(SimValue(Value::build_text( + t.into_iter().collect::(), + )))) + } + _ => None, + } + } +} diff --git a/sql_generation/model/query/alter_table.rs b/sql_generation/model/query/alter_table.rs new file mode 100644 index 000000000..684198b35 --- /dev/null +++ b/sql_generation/model/query/alter_table.rs @@ -0,0 +1,54 @@ +use std::fmt::Display; + +use serde::{Deserialize, Serialize}; + +use crate::model::table::Column; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct AlterTable { + pub table_name: String, + pub alter_table_type: AlterTableType, +} + +// TODO: in the future maybe use parser AST's when we test almost the entire SQL spectrum +// so we can repeat less code +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, strum::EnumDiscriminants)] +pub enum AlterTableType { + /// `RENAME TO`: new table name + RenameTo { new_name: String }, + /// `ADD COLUMN` + AddColumn { column: Column }, + /// `ALTER COLUMN` + AlterColumn { old: String, new: Column }, + /// `RENAME COLUMN` + RenameColumn { + /// old name + old: String, + /// new name + new: String, + }, + /// `DROP COLUMN` + DropColumn { column_name: String }, +} + +impl Display for AlterTable { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "ALTER TABLE {} {}", + self.table_name, self.alter_table_type + ) + } +} + +impl Display for AlterTableType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + AlterTableType::RenameTo { new_name } => write!(f, "RENAME TO {new_name}"), + AlterTableType::AddColumn { column } => write!(f, "ADD COLUMN {column}"), + AlterTableType::AlterColumn { old, new } => write!(f, "ALTER COLUMN {old} TO {new}"), + AlterTableType::RenameColumn { old, new } => write!(f, "RENAME COLUMN {old} TO {new}"), + AlterTableType::DropColumn { column_name } => write!(f, "DROP COLUMN {column_name}"), + } + } +} diff --git a/sql_generation/model/query/create.rs b/sql_generation/model/query/create.rs index 607d5fe8d..ee028e879 100644 --- a/sql_generation/model/query/create.rs +++ b/sql_generation/model/query/create.rs @@ -1,5 +1,6 @@ use std::fmt::Display; +use itertools::Itertools; use serde::{Deserialize, Serialize}; use crate::model::table::Table; @@ -13,13 +14,13 @@ impl Display for Create { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "CREATE TABLE {} (", self.table.name)?; - for (i, column) in self.table.columns.iter().enumerate() { - if i != 0 { - write!(f, ",")?; - } - write!(f, "{} {}", column.name, column.column_type)?; - } + let cols = self + .table + .columns + .iter() + .map(|column| column.to_string()) + .join(", "); - write!(f, ")") + write!(f, "{cols})") } } diff --git a/sql_generation/model/query/create_index.rs b/sql_generation/model/query/create_index.rs index db9d15a04..55548114e 100644 --- a/sql_generation/model/query/create_index.rs +++ b/sql_generation/model/query/create_index.rs @@ -1,11 +1,26 @@ +use std::ops::{Deref, DerefMut}; + use serde::{Deserialize, Serialize}; -use turso_parser::ast::SortOrder; + +use crate::model::table::Index; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] pub struct CreateIndex { - pub index_name: String, - pub table_name: String, - pub columns: Vec<(String, SortOrder)>, + pub index: Index, +} + +impl Deref for CreateIndex { + type Target = Index; + + fn deref(&self) -> &Self::Target { + &self.index + } +} + +impl DerefMut for CreateIndex { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.index + } } impl std::fmt::Display for CreateIndex { @@ -13,9 +28,10 @@ impl std::fmt::Display for CreateIndex { write!( f, "CREATE INDEX {} ON {} ({})", - self.index_name, - self.table_name, - self.columns + self.index.index_name, + self.index.table_name, + self.index + .columns .iter() .map(|(name, order)| format!("{name} {order}")) .collect::>() diff --git a/sql_generation/model/query/drop_index.rs b/sql_generation/model/query/drop_index.rs index 18cadb12d..670636efb 100644 --- a/sql_generation/model/query/drop_index.rs +++ b/sql_generation/model/query/drop_index.rs @@ -3,6 +3,7 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] pub struct DropIndex { pub index_name: String, + pub table_name: String, } impl std::fmt::Display for DropIndex { diff --git a/sql_generation/model/query/insert.rs b/sql_generation/model/query/insert.rs index d69921388..4e5994f14 100644 --- a/sql_generation/model/query/insert.rs +++ b/sql_generation/model/query/insert.rs @@ -24,6 +24,13 @@ impl Insert { Insert::Values { table, .. } | Insert::Select { table, .. } => table, } } + + pub fn rows(&self) -> &[Vec] { + match self { + Insert::Values { values, .. } => values, + Insert::Select { .. } => unreachable!(), + } + } } impl Display for Insert { diff --git a/sql_generation/model/query/mod.rs b/sql_generation/model/query/mod.rs index 98ec2bdfd..9876ffe54 100644 --- a/sql_generation/model/query/mod.rs +++ b/sql_generation/model/query/mod.rs @@ -6,6 +6,7 @@ pub use drop_index::DropIndex; pub use insert::Insert; pub use select::Select; +pub mod alter_table; pub mod create; pub mod create_index; pub mod delete; diff --git a/sql_generation/model/query/select.rs b/sql_generation/model/query/select.rs index 8add21bfb..3ba71c06f 100644 --- a/sql_generation/model/query/select.rs +++ b/sql_generation/model/query/select.rs @@ -243,28 +243,24 @@ impl FromClause { match join.join_type { JoinType::Inner => { // Implement inner join logic - let join_rows = joined_table - .rows - .iter() - .filter(|row| join.on.test(row, joined_table)) - .cloned() - .collect::>(); // take a cartesian product of the rows let all_row_pairs = join_table .rows .clone() .into_iter() - .cartesian_product(join_rows.iter()); + .cartesian_product(joined_table.rows.iter()); + let mut new_rows = Vec::new(); for (row1, row2) in all_row_pairs { let row = row1.iter().chain(row2.iter()).cloned().collect::>(); let is_in = join.on.test(&row, &join_table); if is_in { - join_table.rows.push(row); + new_rows.push(row); } } + join_table.rows = new_rows; } _ => todo!(), } diff --git a/sql_generation/model/table.rs b/sql_generation/model/table.rs index 87057b42b..dce2fdddf 100644 --- a/sql_generation/model/table.rs +++ b/sql_generation/model/table.rs @@ -1,8 +1,9 @@ use std::{fmt::Display, hash::Hash, ops::Deref}; +use itertools::Itertools; use serde::{Deserialize, Serialize}; use turso_core::{numeric::Numeric, types}; -use turso_parser::ast; +use turso_parser::ast::{self, ColumnConstraint, SortOrder}; use crate::model::query::predicate::Predicate; @@ -45,7 +46,7 @@ pub struct Table { pub name: String, pub columns: Vec, pub rows: Vec>, - pub indexes: Vec, + pub indexes: Vec, } impl Table { @@ -63,8 +64,7 @@ impl Table { pub struct Column { pub name: String, pub column_type: ColumnType, - pub primary: bool, - pub unique: bool, + pub constraints: Vec, } // Uniquely defined by name in this case @@ -82,7 +82,23 @@ impl PartialEq for Column { impl Eq for Column {} -#[derive(Debug, Clone, Serialize, Deserialize)] +impl Display for Column { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let constraints = self + .constraints + .iter() + .map(|constraint| constraint.to_string()) + .join(" "); + let mut col_string = format!("{} {}", self.name, self.column_type); + if !constraints.is_empty() { + col_string.push(' '); + col_string.push_str(&constraints); + } + write!(f, "{col_string}") + } +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] pub enum ColumnType { Integer, Float, @@ -101,6 +117,13 @@ impl Display for ColumnType { } } +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Index { + pub table_name: String, + pub index_name: String, + pub columns: Vec<(String, SortOrder)>, +} + #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct JoinedTable { /// table name @@ -163,19 +186,34 @@ impl Display for SimValue { impl SimValue { pub const FALSE: Self = SimValue(types::Value::Integer(0)); pub const TRUE: Self = SimValue(types::Value::Integer(1)); + pub const NULL: Self = SimValue(types::Value::Null); pub fn as_bool(&self) -> bool { Numeric::from(&self.0).try_into_bool().unwrap_or_default() } + #[inline] + fn is_null(&self) -> bool { + matches!(self.0, types::Value::Null) + } + + // The result of any binary operator is either a numeric value or NULL, except for the || concatenation operator, and the -> and ->> extract operators which can return values of any type. + // All operators generally evaluate to NULL when any operand is NULL, with specific exceptions as stated below. This is in accordance with the SQL92 standard. + // When paired with NULL: + // AND evaluates to 0 (false) when the other operand is false; and + // OR evaluates to 1 (true) when the other operand is true. + // The IS and IS NOT operators work like = and != except when one or both of the operands are NULL. In this case, if both operands are NULL, then the IS operator evaluates to 1 (true) and the IS NOT operator evaluates to 0 (false). If one operand is NULL and the other is not, then the IS operator evaluates to 0 (false) and the IS NOT operator is 1 (true). It is not possible for an IS or IS NOT expression to evaluate to NULL. + // The IS NOT DISTINCT FROM operator is an alternative spelling for the IS operator. Likewise, the IS DISTINCT FROM operator means the same thing as IS NOT. Standard SQL does not support the compact IS and IS NOT notation. Those compact forms are an SQLite extension. You must use the less readable IS NOT DISTINCT FROM and IS DISTINCT FROM operators in most other SQL database engines. + // TODO: support more predicates /// Returns a Result of a Binary Operation /// /// TODO: forget collations for now /// TODO: have the [ast::Operator::Equals], [ast::Operator::NotEquals], [ast::Operator::Greater], /// [ast::Operator::GreaterEquals], [ast::Operator::Less], [ast::Operator::LessEquals] function to be extracted - /// into its functions in turso_core so that it can be used here + /// into its functions in turso_core so that it can be used here. For now we just do the `not_null` check to avoid refactoring code in core pub fn binary_compare(&self, other: &Self, operator: ast::Operator) -> SimValue { + let not_null = !self.is_null() && !other.is_null(); match operator { ast::Operator::Add => self.0.exec_add(&other.0).into(), ast::Operator::And => self.0.exec_and(&other.0).into(), @@ -185,10 +223,10 @@ impl SimValue { ast::Operator::BitwiseOr => self.0.exec_bit_or(&other.0).into(), ast::Operator::BitwiseNot => todo!(), // TODO: Do not see any function usage of this operator in Core ast::Operator::Concat => self.0.exec_concat(&other.0).into(), - ast::Operator::Equals => (self == other).into(), + ast::Operator::Equals => not_null.then(|| self == other).into(), ast::Operator::Divide => self.0.exec_divide(&other.0).into(), - ast::Operator::Greater => (self > other).into(), - ast::Operator::GreaterEquals => (self >= other).into(), + ast::Operator::Greater => not_null.then(|| self > other).into(), + ast::Operator::GreaterEquals => not_null.then(|| self >= other).into(), // TODO: Test these implementations ast::Operator::Is => match (&self.0, &other.0) { (types::Value::Null, types::Value::Null) => true.into(), @@ -200,11 +238,11 @@ impl SimValue { .binary_compare(other, ast::Operator::Is) .unary_exec(ast::UnaryOperator::Not), ast::Operator::LeftShift => self.0.exec_shift_left(&other.0).into(), - ast::Operator::Less => (self < other).into(), - ast::Operator::LessEquals => (self <= other).into(), + ast::Operator::Less => not_null.then(|| self < other).into(), + ast::Operator::LessEquals => not_null.then(|| self <= other).into(), ast::Operator::Modulus => self.0.exec_remainder(&other.0).into(), ast::Operator::Multiply => self.0.exec_multiply(&other.0).into(), - ast::Operator::NotEquals => (self != other).into(), + ast::Operator::NotEquals => not_null.then(|| self != other).into(), ast::Operator::Or => self.0.exec_or(&other.0).into(), ast::Operator::RightShift => self.0.exec_shift_right(&other.0).into(), ast::Operator::Subtract => self.0.exec_subtract(&other.0).into(), @@ -349,7 +387,18 @@ impl From<&SimValue> for ast::Literal { } } +impl From> for SimValue { + #[inline] + fn from(value: Option) -> Self { + if value.is_none() { + return SimValue::NULL; + } + SimValue::from(value.unwrap()) + } +} + impl From for SimValue { + #[inline] fn from(value: bool) -> Self { if value { SimValue::TRUE diff --git a/sqlite3/src/lib.rs b/sqlite3/src/lib.rs index 037e5015b..2cfa52bbd 100644 --- a/sqlite3/src/lib.rs +++ b/sqlite3/src/lib.rs @@ -94,15 +94,25 @@ pub struct sqlite3_stmt { *mut ffi::c_void, )>, pub(crate) next: *mut sqlite3_stmt, + pub(crate) text_cache: Vec>, } impl sqlite3_stmt { pub fn new(db: *mut sqlite3, stmt: turso_core::Statement) -> Self { + let n_cols = stmt.num_columns(); Self { db, stmt, destructors: Vec::new(), next: std::ptr::null_mut(), + text_cache: vec![vec![]; n_cols], + } + } + #[inline] + fn clear_text_cache(&mut self) { + // Drop per-column buffers for the previous row + for r in &mut self.text_cache { + r.clear(); } } } @@ -323,7 +333,7 @@ pub unsafe extern "C" fn sqlite3_finalize(stmt: *mut sqlite3_stmt) -> ffi::c_int destructor_fn(ptr); } } - + stmt_ref.clear_text_cache(); let _ = Box::from_raw(stmt); SQLITE_OK } @@ -340,9 +350,15 @@ pub unsafe extern "C" fn sqlite3_step(stmt: *mut sqlite3_stmt) -> ffi::c_int { stmt.stmt.run_once().unwrap(); continue; } - turso_core::StepResult::Done => return SQLITE_DONE, + turso_core::StepResult::Done => { + stmt.clear_text_cache(); + return SQLITE_DONE; + } turso_core::StepResult::Interrupt => return SQLITE_INTERRUPT, - turso_core::StepResult::Row => return SQLITE_ROW, + turso_core::StepResult::Row => { + stmt.clear_text_cache(); + return SQLITE_ROW; + } turso_core::StepResult::Busy => return SQLITE_BUSY, } } else { @@ -364,31 +380,279 @@ type exec_callback = Option< pub unsafe extern "C" fn sqlite3_exec( db: *mut sqlite3, sql: *const ffi::c_char, - _callback: exec_callback, - _context: *mut ffi::c_void, - _err: *mut *mut ffi::c_char, + callback: exec_callback, + context: *mut ffi::c_void, + err: *mut *mut ffi::c_char, ) -> ffi::c_int { if db.is_null() || sql.is_null() { return SQLITE_MISUSE; } - let db: &mut sqlite3 = &mut *db; - let db = db.inner.lock().unwrap(); - let sql = CStr::from_ptr(sql); - let sql = match sql.to_str() { + + let db_ref: &mut sqlite3 = &mut *db; + let sql_cstr = CStr::from_ptr(sql); + let sql_str = match sql_cstr.to_str() { Ok(s) => s, Err(_) => return SQLITE_MISUSE, }; - trace!("sqlite3_exec(sql={})", sql); - match db.conn.execute(sql) { - Ok(_) => SQLITE_OK, - Err(_) => SQLITE_ERROR, + trace!("sqlite3_exec(sql={})", sql_str); + if !err.is_null() { + *err = std::ptr::null_mut(); } + let statements = split_sql_statements(sql_str); + for stmt_sql in statements { + let trimmed = stmt_sql.trim(); + if trimmed.is_empty() { + continue; + } + + let is_dql = is_query_statement(trimmed); + if !is_dql { + // For DML/DDL, use normal execute path + let db_inner = db_ref.inner.lock().unwrap(); + match db_inner.conn.execute(trimmed) { + Ok(_) => continue, + Err(e) => { + if !err.is_null() { + let err_msg = format!("SQL error: {e:?}"); + *err = CString::new(err_msg).unwrap().into_raw(); + } + return SQLITE_ERROR; + } + } + } else if callback.is_none() { + // DQL without callback provided, still execute but discard any result rows + let mut stmt_ptr: *mut sqlite3_stmt = std::ptr::null_mut(); + let rc = sqlite3_prepare_v2( + db, + CString::new(trimmed).unwrap().as_ptr(), + -1, + &mut stmt_ptr, + std::ptr::null_mut(), + ); + if rc != SQLITE_OK { + if !err.is_null() { + let err_msg = format!("Prepare failed: {rc}"); + *err = CString::new(err_msg).unwrap().into_raw(); + } + return rc; + } + loop { + let step_rc = sqlite3_step(stmt_ptr); + match step_rc { + SQLITE_ROW => continue, + SQLITE_DONE => break, + _ => { + sqlite3_finalize(stmt_ptr); + if !err.is_null() { + let err_msg = format!("Step failed: {step_rc}"); + *err = CString::new(err_msg).unwrap().into_raw(); + } + return step_rc; + } + } + } + sqlite3_finalize(stmt_ptr); + } else { + // DQL with callback + let rc = execute_query_with_callback(db, trimmed, callback, context, err); + if rc != SQLITE_OK { + return rc; + } + } + } + SQLITE_OK +} + +/// Detect if a SQL statement is DQL +fn is_query_statement(sql: &str) -> bool { + let trimmed = sql.trim_start(); + if trimmed.is_empty() { + return false; + } + let bytes = trimmed.as_bytes(); + + let starts_with_ignore_case = |keyword: &[u8]| -> bool { + if bytes.len() < keyword.len() { + return false; + } + // Check keyword matches + if !bytes[..keyword.len()].eq_ignore_ascii_case(keyword) { + return false; + } + // Ensure keyword is followed by whitespace or EOF + bytes.len() == keyword.len() || bytes[keyword.len()].is_ascii_whitespace() + }; + + // Check DQL keywords + if starts_with_ignore_case(b"SELECT") + || starts_with_ignore_case(b"VALUES") + || starts_with_ignore_case(b"WITH") + || starts_with_ignore_case(b"PRAGMA") + || starts_with_ignore_case(b"EXPLAIN") + { + return true; + } + + // Look for RETURNING as a whole word, that's not part of another identifier + let mut i = 0; + while i < bytes.len() { + if i + 9 <= bytes.len() && bytes[i..i + 9].eq_ignore_ascii_case(b"RETURNING") { + // Check it's a word boundary before and after + let is_word_start = + i == 0 || !bytes[i - 1].is_ascii_alphanumeric() && bytes[i - 1] != b'_'; + let is_word_end = i + 9 == bytes.len() + || !bytes[i + 9].is_ascii_alphanumeric() && bytes[i + 9] != b'_'; + if is_word_start && is_word_end { + return true; + } + } + i += 1; + } + false +} + +/// Execute a query statement with callback for each row +/// Only called when we know callback is Some +unsafe fn execute_query_with_callback( + db: *mut sqlite3, + sql: &str, + callback: exec_callback, + context: *mut ffi::c_void, + err: *mut *mut ffi::c_char, +) -> ffi::c_int { + let sql_cstring = match CString::new(sql) { + Ok(s) => s, + Err(_) => return SQLITE_MISUSE, + }; + + let mut stmt_ptr: *mut sqlite3_stmt = std::ptr::null_mut(); + let rc = sqlite3_prepare_v2( + db, + sql_cstring.as_ptr(), + -1, + &mut stmt_ptr, + std::ptr::null_mut(), + ); + + if rc != SQLITE_OK { + if !err.is_null() { + let err_msg = format!("Prepare failed: {rc}"); + *err = CString::new(err_msg).unwrap().into_raw(); + } + return rc; + } + + let stmt_ref = &*stmt_ptr; + let n_cols = stmt_ref.stmt.num_columns() as ffi::c_int; + let mut column_names: Vec = Vec::with_capacity(n_cols as usize); + + for i in 0..n_cols { + let name = stmt_ref.stmt.get_column_name(i as usize); + column_names.push(CString::new(name.as_bytes()).unwrap()); + } + + loop { + let step_rc = sqlite3_step(stmt_ptr); + + match step_rc { + SQLITE_ROW => { + // Safety: checked earlier + let callback = callback.unwrap(); + + let mut values: Vec = Vec::with_capacity(n_cols as usize); + let mut value_ptrs: Vec<*mut ffi::c_char> = Vec::with_capacity(n_cols as usize); + let mut col_ptrs: Vec<*mut ffi::c_char> = Vec::with_capacity(n_cols as usize); + + for i in 0..n_cols { + let val = stmt_ref.stmt.row().unwrap().get_value(i as usize); + values.push(CString::new(val.to_string().as_bytes()).unwrap()); + } + + for value in &values { + value_ptrs.push(value.as_ptr() as *mut ffi::c_char); + } + for name in &column_names { + col_ptrs.push(name.as_ptr() as *mut ffi::c_char); + } + + let cb_rc = callback( + context, + n_cols, + value_ptrs.as_mut_ptr(), + col_ptrs.as_mut_ptr(), + ); + + if cb_rc != 0 { + sqlite3_finalize(stmt_ptr); + return SQLITE_ABORT; + } + } + SQLITE_DONE => { + break; + } + _ => { + sqlite3_finalize(stmt_ptr); + if !err.is_null() { + let err_msg = format!("Step failed: {step_rc}"); + *err = CString::new(err_msg).unwrap().into_raw(); + } + return step_rc; + } + } + } + + sqlite3_finalize(stmt_ptr) +} + +/// Split SQL string into individual statements +/// Handles quoted strings properly and skips comments +fn split_sql_statements(sql: &str) -> Vec<&str> { + let mut statements = Vec::new(); + let mut current_start = 0; + let mut in_single_quote = false; + let mut in_double_quote = false; + let bytes = sql.as_bytes(); + let mut i = 0; + + while i < bytes.len() { + match bytes[i] { + // Check for escaped quotes first + b'\'' if !in_double_quote => { + if i + 1 < bytes.len() && bytes[i + 1] == b'\'' { + i += 2; + continue; + } + in_single_quote = !in_single_quote; + } + b'"' if !in_single_quote => { + if i + 1 < bytes.len() && bytes[i + 1] == b'"' { + i += 2; + continue; + } + in_double_quote = !in_double_quote; + } + b';' if !in_single_quote && !in_double_quote => { + // we found the statement boundary + statements.push(&sql[current_start..i]); + current_start = i + 1; + } + _ => {} + } + i += 1; + } + + if current_start < sql.len() { + statements.push(&sql[current_start..]); + } + + statements } #[no_mangle] pub unsafe extern "C" fn sqlite3_reset(stmt: *mut sqlite3_stmt) -> ffi::c_int { let stmt = &mut *stmt; stmt.stmt.reset(); + stmt.clear_text_cache(); SQLITE_OK } @@ -1048,14 +1312,30 @@ pub unsafe extern "C" fn sqlite3_column_text( stmt: *mut sqlite3_stmt, idx: ffi::c_int, ) -> *const ffi::c_uchar { + if stmt.is_null() || idx < 0 { + return std::ptr::null(); + } let stmt = &mut *stmt; let row = stmt.stmt.row(); let row = match row.as_ref() { Some(row) => row, None => return std::ptr::null(), }; - match row.get::<&Value>(idx as usize) { - Ok(turso_core::Value::Text(text)) => text.as_str().as_ptr(), + let i = idx as usize; + if i >= stmt.text_cache.len() { + return std::ptr::null(); + } + if !stmt.text_cache[i].is_empty() { + // we have already cached this value + return stmt.text_cache[i].as_ptr() as *const ffi::c_uchar; + } + match row.get::<&Value>(i) { + Ok(turso_core::Value::Text(text)) => { + let buf = &mut stmt.text_cache[i]; + buf.extend(text.as_str().as_bytes()); + buf.push(0); + buf.as_ptr() as *const ffi::c_uchar + } _ => std::ptr::null(), } } diff --git a/sqlite3/tests/compat/mod.rs b/sqlite3/tests/compat/mod.rs index d04b933e8..a1f016ddc 100644 --- a/sqlite3/tests/compat/mod.rs +++ b/sqlite3/tests/compat/mod.rs @@ -20,6 +20,21 @@ extern "C" { fn sqlite3_close(db: *mut sqlite3) -> i32; fn sqlite3_open(filename: *const libc::c_char, db: *mut *mut sqlite3) -> i32; fn sqlite3_db_filename(db: *mut sqlite3, db_name: *const libc::c_char) -> *const libc::c_char; + fn sqlite3_exec( + db: *mut sqlite3, + sql: *const libc::c_char, + callback: Option< + unsafe extern "C" fn( + arg1: *mut libc::c_void, + arg2: libc::c_int, + arg3: *mut *mut libc::c_char, + arg4: *mut *mut libc::c_char, + ) -> libc::c_int, + >, + arg: *mut libc::c_void, + errmsg: *mut *mut libc::c_char, + ) -> i32; + fn sqlite3_free(ptr: *mut libc::c_void); fn sqlite3_prepare_v2( db: *mut sqlite3, sql: *const libc::c_char, @@ -106,6 +121,7 @@ const SQLITE_CHECKPOINT_RESTART: i32 = 2; const SQLITE_CHECKPOINT_TRUNCATE: i32 = 3; const SQLITE_INTEGER: i32 = 1; const SQLITE_FLOAT: i32 = 2; +const SQLITE_ABORT: i32 = 4; const SQLITE_TEXT: i32 = 3; const SQLITE3_TEXT: i32 = 3; const SQLITE_BLOB: i32 = 4; @@ -412,6 +428,42 @@ mod tests { } } + #[test] + #[cfg(not(target_os = "windows"))] + fn column_text_is_nul_terminated_and_bytes_match() { + unsafe { + let mut db = std::ptr::null_mut(); + assert_eq!( + sqlite3_open(c"../testing/testing.db".as_ptr(), &mut db), + SQLITE_OK + ); + let mut stmt = std::ptr::null_mut(); + assert_eq!( + sqlite3_prepare_v2( + db, + c"SELECT first_name FROM users ORDER BY rowid ASC LIMIT 1;".as_ptr(), + -1, + &mut stmt, + std::ptr::null_mut() + ), + SQLITE_OK + ); + assert_eq!(sqlite3_step(stmt), SQLITE_ROW); + let p = sqlite3_column_text(stmt, 0); + assert!(!p.is_null()); + let bytes = sqlite3_column_bytes(stmt, 0) as usize; + // NUL at [bytes], and no extra counted + let slice = std::slice::from_raw_parts(p, bytes + 1); + assert_eq!(slice[bytes], 0); + assert_eq!(libc::strlen(p), bytes); + + let s = std::ffi::CStr::from_ptr(p).to_str().unwrap(); + assert_eq!(s, "Jamie"); + assert_eq!(sqlite3_finalize(stmt), SQLITE_OK); + assert_eq!(sqlite3_close(db), SQLITE_OK); + } + } + #[test] fn test_sqlite3_bind_text() { unsafe { @@ -726,6 +778,744 @@ mod tests { } } + #[test] + fn test_exec_multi_statement_dml() { + unsafe { + let temp_file = tempfile::NamedTempFile::with_suffix(".db").unwrap(); + let path = std::ffi::CString::new(temp_file.path().to_str().unwrap()).unwrap(); + let mut db = ptr::null_mut(); + assert_eq!(sqlite3_open(path.as_ptr(), &mut db), SQLITE_OK); + + // Multiple DML statements in one exec call + let rc = sqlite3_exec( + db, + c"CREATE TABLE bind_text(x TEXT);\ + INSERT INTO bind_text(x) VALUES('TEXT1');\ + INSERT INTO bind_text(x) VALUES('TEXT2');" + .as_ptr(), + None, + ptr::null_mut(), + ptr::null_mut(), + ); + assert_eq!(rc, SQLITE_OK); + + // Verify the data was inserted + let mut stmt = ptr::null_mut(); + assert_eq!( + sqlite3_prepare_v2( + db, + c"SELECT COUNT(*) FROM bind_text".as_ptr(), + -1, + &mut stmt, + ptr::null_mut(), + ), + SQLITE_OK + ); + assert_eq!(sqlite3_step(stmt), SQLITE_ROW); + assert_eq!(sqlite3_column_int(stmt, 0), 2); + assert_eq!(sqlite3_finalize(stmt), SQLITE_OK); + + assert_eq!(sqlite3_close(db), SQLITE_OK); + } + } + + #[test] + fn test_exec_multi_statement_with_semicolons_in_strings() { + unsafe { + let temp_file = tempfile::NamedTempFile::with_suffix(".db").unwrap(); + let path = std::ffi::CString::new(temp_file.path().to_str().unwrap()).unwrap(); + let mut db = ptr::null_mut(); + assert_eq!(sqlite3_open(path.as_ptr(), &mut db), SQLITE_OK); + + // Semicolons inside strings should not split statements + let rc = sqlite3_exec( + db, + c"CREATE TABLE test_semicolon(x TEXT);\ + INSERT INTO test_semicolon(x) VALUES('value;with;semicolons');\ + INSERT INTO test_semicolon(x) VALUES(\"another;value\");" + .as_ptr(), + None, + ptr::null_mut(), + ptr::null_mut(), + ); + assert_eq!(rc, SQLITE_OK); + + // Verify the values contain semicolons + let mut stmt = ptr::null_mut(); + assert_eq!( + sqlite3_prepare_v2( + db, + c"SELECT x FROM test_semicolon ORDER BY rowid".as_ptr(), + -1, + &mut stmt, + ptr::null_mut(), + ), + SQLITE_OK + ); + + assert_eq!(sqlite3_step(stmt), SQLITE_ROW); + let val1 = std::ffi::CStr::from_ptr(sqlite3_column_text(stmt, 0)) + .to_str() + .unwrap(); + assert_eq!(val1, "value;with;semicolons"); + + assert_eq!(sqlite3_step(stmt), SQLITE_ROW); + let val2 = std::ffi::CStr::from_ptr(sqlite3_column_text(stmt, 0)) + .to_str() + .unwrap(); + assert_eq!(val2, "another;value"); + + assert_eq!(sqlite3_finalize(stmt), SQLITE_OK); + assert_eq!(sqlite3_close(db), SQLITE_OK); + } + } + + #[test] + fn test_exec_multi_statement_with_escaped_quotes() { + unsafe { + let temp_file = tempfile::NamedTempFile::with_suffix(".db").unwrap(); + let path = std::ffi::CString::new(temp_file.path().to_str().unwrap()).unwrap(); + let mut db = ptr::null_mut(); + assert_eq!(sqlite3_open(path.as_ptr(), &mut db), SQLITE_OK); + + // Test escaped quotes + let rc = sqlite3_exec( + db, + c"CREATE TABLE test_quotes(x TEXT);\ + INSERT INTO test_quotes(x) VALUES('it''s working');\ + INSERT INTO test_quotes(x) VALUES(\"quote\"\"test\"\"\");" + .as_ptr(), + None, + ptr::null_mut(), + ptr::null_mut(), + ); + assert_eq!(rc, SQLITE_OK); + + let mut stmt = ptr::null_mut(); + assert_eq!( + sqlite3_prepare_v2( + db, + c"SELECT x FROM test_quotes ORDER BY rowid".as_ptr(), + -1, + &mut stmt, + ptr::null_mut(), + ), + SQLITE_OK + ); + + assert_eq!(sqlite3_step(stmt), SQLITE_ROW); + let val1 = std::ffi::CStr::from_ptr(sqlite3_column_text(stmt, 0)) + .to_str() + .unwrap(); + assert_eq!(val1, "it's working"); + + assert_eq!(sqlite3_step(stmt), SQLITE_ROW); + let val2 = std::ffi::CStr::from_ptr(sqlite3_column_text(stmt, 0)) + .to_str() + .unwrap(); + assert_eq!(val2, "quote\"test\""); + + assert_eq!(sqlite3_finalize(stmt), SQLITE_OK); + assert_eq!(sqlite3_close(db), SQLITE_OK); + } + } + + #[test] + fn test_exec_with_select_callback() { + unsafe { + // Callback that collects results + unsafe extern "C" fn exec_callback( + context: *mut std::ffi::c_void, + n_cols: std::ffi::c_int, + values: *mut *mut std::ffi::c_char, + _cols: *mut *mut std::ffi::c_char, + ) -> std::ffi::c_int { + let results = &mut *(context as *mut Vec>); + let mut row = Vec::new(); + + for i in 0..n_cols as isize { + let value_ptr = *values.offset(i); + let value = if value_ptr.is_null() { + String::from("NULL") + } else { + std::ffi::CStr::from_ptr(value_ptr) + .to_str() + .unwrap() + .to_owned() + }; + row.push(value); + } + results.push(row); + 0 // Continue + } + + let temp_file = tempfile::NamedTempFile::with_suffix(".db").unwrap(); + let path = std::ffi::CString::new(temp_file.path().to_str().unwrap()).unwrap(); + let mut db = ptr::null_mut(); + assert_eq!(sqlite3_open(path.as_ptr(), &mut db), SQLITE_OK); + + // Setup data + let rc = sqlite3_exec( + db, + c"CREATE TABLE test_select(id INTEGER, name TEXT);\ + INSERT INTO test_select VALUES(1, 'Alice');\ + INSERT INTO test_select VALUES(2, 'Bob');" + .as_ptr(), + None, + ptr::null_mut(), + ptr::null_mut(), + ); + assert_eq!(rc, SQLITE_OK); + + // Execute SELECT with callback + let mut results: Vec> = Vec::new(); + let rc = sqlite3_exec( + db, + c"SELECT id, name FROM test_select ORDER BY id".as_ptr(), + Some(exec_callback), + &mut results as *mut _ as *mut std::ffi::c_void, + ptr::null_mut(), + ); + assert_eq!(rc, SQLITE_OK); + + assert_eq!(results.len(), 2); + assert_eq!(results[0], vec!["1", "Alice"]); + assert_eq!(results[1], vec!["2", "Bob"]); + + assert_eq!(sqlite3_close(db), SQLITE_OK); + } + } + + #[test] + fn test_exec_multi_statement_mixed_dml_select() { + unsafe { + // Callback that counts invocations + unsafe extern "C" fn count_callback( + context: *mut std::ffi::c_void, + _n_cols: std::ffi::c_int, + _values: *mut *mut std::ffi::c_char, + _cols: *mut *mut std::ffi::c_char, + ) -> std::ffi::c_int { + let count = &mut *(context as *mut i32); + *count += 1; + 0 + } + + let temp_file = tempfile::NamedTempFile::with_suffix(".db").unwrap(); + let path = std::ffi::CString::new(temp_file.path().to_str().unwrap()).unwrap(); + let mut db = ptr::null_mut(); + assert_eq!(sqlite3_open(path.as_ptr(), &mut db), SQLITE_OK); + + let mut callback_count = 0; + + // Mix of DDL/DML/DQL + let rc = sqlite3_exec( + db, + c"CREATE TABLE mixed(x INTEGER);\ + INSERT INTO mixed VALUES(1);\ + INSERT INTO mixed VALUES(2);\ + SELECT x FROM mixed;\ + INSERT INTO mixed VALUES(3);\ + SELECT COUNT(*) FROM mixed;" + .as_ptr(), + Some(count_callback), + &mut callback_count as *mut _ as *mut std::ffi::c_void, + ptr::null_mut(), + ); + assert_eq!(rc, SQLITE_OK); + + // Callback should be called 3 times total: + // 2 times for first SELECT (2 rows) + // 1 time for second SELECT (1 row with COUNT) + assert_eq!(callback_count, 3); + + assert_eq!(sqlite3_close(db), SQLITE_OK); + } + } + + #[test] + fn test_exec_callback_abort() { + unsafe { + // Callback that aborts after first row + unsafe extern "C" fn abort_callback( + context: *mut std::ffi::c_void, + _n_cols: std::ffi::c_int, + _values: *mut *mut std::ffi::c_char, + _cols: *mut *mut std::ffi::c_char, + ) -> std::ffi::c_int { + let count = &mut *(context as *mut i32); + *count += 1; + if *count >= 1 { + return 1; // Abort + } + 0 + } + + let temp_file = tempfile::NamedTempFile::with_suffix(".db").unwrap(); + let path = std::ffi::CString::new(temp_file.path().to_str().unwrap()).unwrap(); + let mut db = ptr::null_mut(); + assert_eq!(sqlite3_open(path.as_ptr(), &mut db), SQLITE_OK); + + sqlite3_exec( + db, + c"CREATE TABLE test(x INTEGER);\ + INSERT INTO test VALUES(1),(2),(3);" + .as_ptr(), + None, + ptr::null_mut(), + ptr::null_mut(), + ); + + let mut count = 0; + let rc = sqlite3_exec( + db, + c"SELECT x FROM test".as_ptr(), + Some(abort_callback), + &mut count as *mut _ as *mut std::ffi::c_void, + ptr::null_mut(), + ); + + assert_eq!(rc, SQLITE_ABORT); + assert_eq!(count, 1); // Only processed one row before aborting + + assert_eq!(sqlite3_close(db), SQLITE_OK); + } + } + + #[test] + fn test_exec_error_stops_execution() { + unsafe { + let temp_file = tempfile::NamedTempFile::with_suffix(".db").unwrap(); + let path = std::ffi::CString::new(temp_file.path().to_str().unwrap()).unwrap(); + let mut db = ptr::null_mut(); + assert_eq!(sqlite3_open(path.as_ptr(), &mut db), SQLITE_OK); + + let mut err_msg = ptr::null_mut(); + + // Second statement has error, third should not execute + let rc = sqlite3_exec( + db, + c"CREATE TABLE test(x INTEGER);\ + INSERT INTO nonexistent VALUES(1);\ + CREATE TABLE should_not_exist(y INTEGER);" + .as_ptr(), + None, + ptr::null_mut(), + &mut err_msg, + ); + + assert_eq!(rc, SQLITE_ERROR); + + // Verify third statement didn't execute + let mut stmt = ptr::null_mut(); + let check_rc = sqlite3_prepare_v2( + db, + c"SELECT name FROM sqlite_master WHERE type='table' AND name='should_not_exist'" + .as_ptr(), + -1, + &mut stmt, + ptr::null_mut(), + ); + assert_eq!(check_rc, SQLITE_OK); + assert_eq!(sqlite3_step(stmt), SQLITE_DONE); // No rows = table doesn't exist + assert_eq!(sqlite3_finalize(stmt), SQLITE_OK); + + if !err_msg.is_null() { + sqlite3_free(err_msg as *mut std::ffi::c_void); + } + + assert_eq!(sqlite3_close(db), SQLITE_OK); + } + } + + #[test] + fn test_exec_empty_statements() { + unsafe { + let temp_file = tempfile::NamedTempFile::with_suffix(".db").unwrap(); + let path = std::ffi::CString::new(temp_file.path().to_str().unwrap()).unwrap(); + let mut db = ptr::null_mut(); + assert_eq!(sqlite3_open(path.as_ptr(), &mut db), SQLITE_OK); + + // Multiple semicolons and whitespace should be handled gracefully + let rc = sqlite3_exec( + db, + c"CREATE TABLE test(x INTEGER);;;\n\n;\t;INSERT INTO test VALUES(1);;;".as_ptr(), + None, + ptr::null_mut(), + ptr::null_mut(), + ); + assert_eq!(rc, SQLITE_OK); + + // Verify both statements executed + let mut stmt = ptr::null_mut(); + assert_eq!( + sqlite3_prepare_v2( + db, + c"SELECT x FROM test".as_ptr(), + -1, + &mut stmt, + ptr::null_mut(), + ), + SQLITE_OK + ); + assert_eq!(sqlite3_step(stmt), SQLITE_ROW); + assert_eq!(sqlite3_column_int(stmt, 0), 1); + assert_eq!(sqlite3_finalize(stmt), SQLITE_OK); + + assert_eq!(sqlite3_close(db), SQLITE_OK); + } + } + #[test] + fn test_exec_with_comments() { + unsafe { + let temp_file = tempfile::NamedTempFile::with_suffix(".db").unwrap(); + let path = std::ffi::CString::new(temp_file.path().to_str().unwrap()).unwrap(); + let mut db = ptr::null_mut(); + assert_eq!(sqlite3_open(path.as_ptr(), &mut db), SQLITE_OK); + + // SQL comments shouldn't affect statement splitting + let rc = sqlite3_exec( + db, + c"-- This is a comment\n\ + CREATE TABLE test(x INTEGER); -- inline comment\n\ + INSERT INTO test VALUES(1); -- semicolon in comment ;\n\ + INSERT INTO test VALUES(2) -- end with comment" + .as_ptr(), + None, + ptr::null_mut(), + ptr::null_mut(), + ); + assert_eq!(rc, SQLITE_OK); + + // Verify both inserts worked + let mut stmt = ptr::null_mut(); + assert_eq!( + sqlite3_prepare_v2( + db, + c"SELECT COUNT(*) FROM test".as_ptr(), + -1, + &mut stmt, + ptr::null_mut(), + ), + SQLITE_OK + ); + assert_eq!(sqlite3_step(stmt), SQLITE_ROW); + assert_eq!(sqlite3_column_int(stmt, 0), 2); + assert_eq!(sqlite3_finalize(stmt), SQLITE_OK); + assert_eq!(sqlite3_close(db), SQLITE_OK); + } + } + + #[test] + fn test_exec_nested_quotes() { + unsafe { + let temp_file = tempfile::NamedTempFile::with_suffix(".db").unwrap(); + let path = std::ffi::CString::new(temp_file.path().to_str().unwrap()).unwrap(); + let mut db = ptr::null_mut(); + assert_eq!(sqlite3_open(path.as_ptr(), &mut db), SQLITE_OK); + + // Mix of quote types and nesting + let rc = sqlite3_exec( + db, + c"CREATE TABLE test(x TEXT);\ + INSERT INTO test VALUES('single \"double\" inside');\ + INSERT INTO test VALUES(\"double 'single' inside\");\ + INSERT INTO test VALUES('mix;\"quote\";types');" + .as_ptr(), + None, + ptr::null_mut(), + ptr::null_mut(), + ); + assert_eq!(rc, SQLITE_OK); + + // Verify values + let mut stmt = ptr::null_mut(); + assert_eq!( + sqlite3_prepare_v2( + db, + c"SELECT x FROM test ORDER BY rowid".as_ptr(), + -1, + &mut stmt, + ptr::null_mut(), + ), + SQLITE_OK + ); + + assert_eq!(sqlite3_step(stmt), SQLITE_ROW); + let val1 = std::ffi::CStr::from_ptr(sqlite3_column_text(stmt, 0)) + .to_str() + .unwrap(); + assert_eq!(val1, "single \"double\" inside"); + + assert_eq!(sqlite3_step(stmt), SQLITE_ROW); + let val2 = std::ffi::CStr::from_ptr(sqlite3_column_text(stmt, 0)) + .to_str() + .unwrap(); + assert_eq!(val2, "double 'single' inside"); + + assert_eq!(sqlite3_step(stmt), SQLITE_ROW); + let val3 = std::ffi::CStr::from_ptr(sqlite3_column_text(stmt, 0)) + .to_str() + .unwrap(); + assert_eq!(val3, "mix;\"quote\";types"); + + assert_eq!(sqlite3_finalize(stmt), SQLITE_OK); + assert_eq!(sqlite3_close(db), SQLITE_OK); + } + } + + #[test] + fn test_exec_transaction_rollback() { + unsafe { + let temp_file = tempfile::NamedTempFile::with_suffix(".db").unwrap(); + let path = std::ffi::CString::new(temp_file.path().to_str().unwrap()).unwrap(); + let mut db = ptr::null_mut(); + assert_eq!(sqlite3_open(path.as_ptr(), &mut db), SQLITE_OK); + + // Test transaction rollback in multi-statement + let rc = sqlite3_exec( + db, + c"CREATE TABLE test(x INTEGER);\ + BEGIN TRANSACTION;\ + INSERT INTO test VALUES(1);\ + INSERT INTO test VALUES(2);\ + ROLLBACK;" + .as_ptr(), + None, + ptr::null_mut(), + ptr::null_mut(), + ); + assert_eq!(rc, SQLITE_OK); + + // Table should exist but be empty due to rollback + let mut stmt = ptr::null_mut(); + assert_eq!( + sqlite3_prepare_v2( + db, + c"SELECT COUNT(*) FROM test".as_ptr(), + -1, + &mut stmt, + ptr::null_mut(), + ), + SQLITE_OK + ); + assert_eq!(sqlite3_step(stmt), SQLITE_ROW); + assert_eq!(sqlite3_column_int(stmt, 0), 0); // No rows due to rollback + assert_eq!(sqlite3_finalize(stmt), SQLITE_OK); + assert_eq!(sqlite3_close(db), SQLITE_OK); + } + } + + #[test] + fn test_exec_with_pragma() { + unsafe { + let temp_file = tempfile::NamedTempFile::with_suffix(".db").unwrap(); + let path = std::ffi::CString::new(temp_file.path().to_str().unwrap()).unwrap(); + let mut db = ptr::null_mut(); + assert_eq!(sqlite3_open(path.as_ptr(), &mut db), SQLITE_OK); + + // Callback to capture pragma results + unsafe extern "C" fn pragma_callback( + context: *mut std::ffi::c_void, + _n_cols: std::ffi::c_int, + _values: *mut *mut std::ffi::c_char, + _cols: *mut *mut std::ffi::c_char, + ) -> std::ffi::c_int { + let count = &mut *(context as *mut i32); + *count += 1; + 0 + } + + let mut callback_count = 0; + + // PRAGMA should be treated as DQL when it returns results + let rc = sqlite3_exec( + db, + c"CREATE TABLE test(x INTEGER);\ + PRAGMA table_info(test);" + .as_ptr(), + Some(pragma_callback), + &mut callback_count as *mut _ as *mut std::ffi::c_void, + ptr::null_mut(), + ); + assert_eq!(rc, SQLITE_OK); + assert!(callback_count > 0); // PRAGMA should return at least one row + + // PRAGMA without callback should discard row + let mut err_msg = ptr::null_mut(); + let rc = sqlite3_exec( + db, + c"PRAGMA table_info(test)".as_ptr(), + None, + ptr::null_mut(), + &mut err_msg, + ); + assert_eq!(rc, SQLITE_OK); + if !err_msg.is_null() { + sqlite3_free(err_msg as *mut std::ffi::c_void); + } + + assert_eq!(sqlite3_close(db), SQLITE_OK); + } + } + + #[test] + fn test_exec_with_cte() { + unsafe { + // Callback that collects results + unsafe extern "C" fn exec_callback( + context: *mut std::ffi::c_void, + n_cols: std::ffi::c_int, + values: *mut *mut std::ffi::c_char, + _cols: *mut *mut std::ffi::c_char, + ) -> std::ffi::c_int { + let results = &mut *(context as *mut Vec>); + let mut row = Vec::new(); + for i in 0..n_cols as isize { + let value_ptr = *values.offset(i); + let value = if value_ptr.is_null() { + String::from("NULL") + } else { + std::ffi::CStr::from_ptr(value_ptr) + .to_str() + .unwrap() + .to_owned() + }; + row.push(value); + } + results.push(row); + 0 + } + + let temp_file = tempfile::NamedTempFile::with_suffix(".db").unwrap(); + let path = std::ffi::CString::new(temp_file.path().to_str().unwrap()).unwrap(); + let mut db = ptr::null_mut(); + assert_eq!(sqlite3_open(path.as_ptr(), &mut db), SQLITE_OK); + + // CTE should be recognized as DQL + let mut results: Vec> = Vec::new(); + let rc = sqlite3_exec( + db, + c"CREATE TABLE test(x INTEGER);\ + INSERT INTO test VALUES(1),(2),(3);\ + WITH cte AS (SELECT x FROM test WHERE x > 1) SELECT * FROM cte;" + .as_ptr(), + Some(exec_callback), + &mut results as *mut _ as *mut std::ffi::c_void, + ptr::null_mut(), + ); + assert_eq!(rc, SQLITE_OK); + assert_eq!(results.len(), 2); // Should get 2 and 3 + assert_eq!(results[0], vec!["2"]); + assert_eq!(results[1], vec!["3"]); + + assert_eq!(sqlite3_close(db), SQLITE_OK); + } + } + + #[test] + fn test_exec_with_returning_clause() { + unsafe { + // Callback for RETURNING results + unsafe extern "C" fn exec_callback( + context: *mut std::ffi::c_void, + n_cols: std::ffi::c_int, + values: *mut *mut std::ffi::c_char, + _cols: *mut *mut std::ffi::c_char, + ) -> std::ffi::c_int { + let results = &mut *(context as *mut Vec>); + let mut row = Vec::new(); + for i in 0..n_cols as isize { + let value_ptr = *values.offset(i); + let value = if value_ptr.is_null() { + String::from("NULL") + } else { + std::ffi::CStr::from_ptr(value_ptr) + .to_str() + .unwrap() + .to_owned() + }; + row.push(value); + } + results.push(row); + 0 + } + + let temp_file = tempfile::NamedTempFile::with_suffix(".db").unwrap(); + let path = std::ffi::CString::new(temp_file.path().to_str().unwrap()).unwrap(); + let mut db = ptr::null_mut(); + assert_eq!(sqlite3_open(path.as_ptr(), &mut db), SQLITE_OK); + + let mut results: Vec> = Vec::new(); + + // INSERT...RETURNING with callback should capture the returned values + let rc = sqlite3_exec( + db, + c"CREATE TABLE test(id INTEGER PRIMARY KEY, x INTEGER);\ + INSERT INTO test(x) VALUES(42) RETURNING id, x;" + .as_ptr(), + Some(exec_callback), + &mut results as *mut _ as *mut std::ffi::c_void, + ptr::null_mut(), + ); + assert_eq!(rc, SQLITE_OK); + assert_eq!(results.len(), 1); + assert_eq!(results[0][1], "42"); // x value + + // Add another row for testing + sqlite3_exec( + db, + c"INSERT INTO test(x) VALUES(99)".as_ptr(), + None, + ptr::null_mut(), + ptr::null_mut(), + ); + + // should still delete the row but discard the RETURNING results + let rc = sqlite3_exec( + db, + c"UPDATE test SET id = 3, x = 41 WHERE x=42 RETURNING id".as_ptr(), + None, + ptr::null_mut(), + ptr::null_mut(), + ); + assert_eq!(rc, SQLITE_OK); + + // Verify the row was actually updated + let mut stmt = ptr::null_mut(); + assert_eq!( + sqlite3_prepare_v2( + db, + c"SELECT COUNT(*) FROM test WHERE x=42".as_ptr(), + -1, + &mut stmt, + ptr::null_mut(), + ), + SQLITE_OK + ); + assert_eq!(sqlite3_step(stmt), SQLITE_ROW); + assert_eq!(sqlite3_column_int(stmt, 0), 0); // Should be 0 rows with x=42 + assert_eq!(sqlite3_finalize(stmt), SQLITE_OK); + + // Verify + assert_eq!( + sqlite3_prepare_v2( + db, + c"SELECT COUNT(*) FROM test".as_ptr(), + -1, + &mut stmt, + ptr::null_mut(), + ), + SQLITE_OK + ); + assert_eq!(sqlite3_step(stmt), SQLITE_ROW); + assert_eq!(sqlite3_column_int(stmt, 0), 2); + assert_eq!(sqlite3_finalize(stmt), SQLITE_OK); + + assert_eq!(sqlite3_close(db), SQLITE_OK); + } + } + #[cfg(not(feature = "sqlite3"))] mod libsql_ext { diff --git a/stress/Cargo.toml b/stress/Cargo.toml index 2e51f003c..7221e80ca 100644 --- a/stress/Cargo.toml +++ b/stress/Cargo.toml @@ -16,7 +16,7 @@ path = "main.rs" [features] default = ["experimental_indexes"] -antithesis = ["turso/antithesis"] +antithesis = ["turso/antithesis", "antithesis_sdk/full"] experimental_indexes = [] [dependencies] @@ -30,3 +30,4 @@ tracing = { workspace = true } tracing-appender = { workspace = true } tracing-subscriber = { workspace = true, features = ["env-filter"] } turso = { workspace = true } +rusqlite = { workspace = true } diff --git a/stress/main.rs b/stress/main.rs index 5091bd05b..879002547 100644 --- a/stress/main.rs +++ b/stress/main.rs @@ -450,6 +450,31 @@ pub fn init_tracing() -> Result { Ok(guard) } +fn integrity_check( + db_path: &std::path::Path, +) -> Result<(), Box> { + assert!(db_path.exists()); + let conn = rusqlite::Connection::open(db_path)?; + let mut stmt = conn.prepare("SELECT * FROM pragma_integrity_check;")?; + let mut rows = stmt.query(())?; + let mut result: Vec = Vec::new(); + + while let Some(row) = rows.next()? { + result.push(row.get(0)?); + } + if result.is_empty() { + return Err( + "simulation failed: integrity_check should return `ok` or a list of problems".into(), + ); + } + if !result[0].eq_ignore_ascii_case("ok") { + // Build a list of problems + result.iter_mut().for_each(|row| *row = format!("- {row}")); + return Err(format!("simulation failed: {}", result.join("\n")).into()); + } + Ok(()) +} + #[tokio::main] async fn main() -> Result<(), Box> { let _g = init_tracing()?; @@ -493,6 +518,8 @@ async fn main() -> Result<(), Box> { let plan = plan.clone(); let conn = db.lock().await.connect()?; + conn.busy_timeout(std::time::Duration::from_millis(opts.busy_timeout))?; + conn.execute("PRAGMA data_sync_retry = 1", ()).await?; // Apply each DDL statement individually @@ -525,6 +552,8 @@ async fn main() -> Result<(), Box> { let handle = tokio::spawn(async move { let mut conn = db.lock().await.connect()?; + conn.busy_timeout(std::time::Duration::from_millis(opts.busy_timeout))?; + conn.execute("PRAGMA data_sync_retry = 1", ()).await?; println!("\rExecuting queries..."); @@ -542,6 +571,7 @@ async fn main() -> Result<(), Box> { } *db_guard = builder.build().await?; conn = db_guard.connect()?; + conn.busy_timeout(std::time::Duration::from_millis(opts.busy_timeout))?; } else if gen_bool(0.0) { // disabled // Reconnect to the database @@ -550,6 +580,7 @@ async fn main() -> Result<(), Box> { } let db_guard = db.lock().await; conn = db_guard.connect()?; + conn.busy_timeout(std::time::Duration::from_millis(opts.busy_timeout))?; } let sql = &plan.queries_per_thread[thread][query_index]; if !opts.silent { @@ -598,6 +629,8 @@ async fn main() -> Result<(), Box> { } } } + // In case this thread is running an exclusive transaction, commit it so that it doesn't block other threads. + let _ = conn.execute("COMMIT", ()).await; Ok::<_, Box>(()) }); handles.push(handle); @@ -608,5 +641,12 @@ async fn main() -> Result<(), Box> { } println!("Done. SQL statements written to {}", opts.log_file); println!("Database file: {db_file}"); + + #[cfg(not(miri))] + { + println!("Running SQLite Integrity check"); + integrity_check(std::path::Path::new(&db_file))?; + } + Ok(()) } diff --git a/stress/opts.rs b/stress/opts.rs index fd53d7635..aac93f67d 100644 --- a/stress/opts.rs +++ b/stress/opts.rs @@ -21,7 +21,7 @@ pub struct Opts { short = 'i', long, help = "the number of iterations", - default_value_t = 100000 + default_value_t = normal_or_miri(100_000, 1000) )] pub nr_iterations: usize, @@ -66,4 +66,20 @@ pub struct Opts { /// Number of tables to use #[clap(long, help = "Select number of tables to create")] pub tables: Option, + + /// Busy timeout in milliseconds + #[clap( + long, + help = "Set busy timeout in milliseconds", + default_value_t = 5000 + )] + pub busy_timeout: u64, +} + +const fn normal_or_miri(normal_val: T, miri_val: T) -> T { + if cfg!(miri) { + miri_val + } else { + normal_val + } } diff --git a/stress/run-miri.sh b/stress/run-miri.sh new file mode 100755 index 000000000..2de0069e5 --- /dev/null +++ b/stress/run-miri.sh @@ -0,0 +1,4 @@ +#!/bin/bash + + +MIRIFLAGS="-Zmiri-disable-isolation -Zmiri-disable-stacked-borrows" cargo +nightly miri run -p turso_stress -- "$@" diff --git a/testing/all.test b/testing/all.test index 4d578e31d..602174abf 100755 --- a/testing/all.test +++ b/testing/all.test @@ -47,3 +47,4 @@ source $testdir/vtab.test source $testdir/upsert.test source $testdir/window.test source $testdir/partial_idx.test +source $testdir/foreign_keys.test diff --git a/testing/alter_column.test b/testing/alter_column.test index 3672497ab..1b4da6dd0 100755 --- a/testing/alter_column.test +++ b/testing/alter_column.test @@ -22,3 +22,219 @@ do_execsql_test_in_memory_any_error fail-alter-column-unique { CREATE TABLE t (a); ALTER TABLE t ALTER COLUMN a TO a UNIQUE; } + +do_execsql_test_on_specific_db {:memory:} alter-table-rename-pk-column { + CREATE TABLE customers (cust_id INTEGER PRIMARY KEY, cust_name TEXT); + INSERT INTO customers VALUES (1, 'Alice'), (2, 'Bob'); + + ALTER TABLE customers RENAME COLUMN cust_id TO customer_id; + + SELECT sql FROM sqlite_schema WHERE name = 'customers'; + SELECT customer_id, cust_name FROM customers ORDER BY customer_id; +} { + "CREATE TABLE customers (customer_id INTEGER PRIMARY KEY, cust_name TEXT)" + "1|Alice" + "2|Bob" +} + +do_execsql_test_on_specific_db {:memory:} alter-table-rename-composite-pk { + CREATE TABLE products (category TEXT, prod_code TEXT, name TEXT, PRIMARY KEY (category, prod_code)); + INSERT INTO products VALUES ('Electronics', 'E001', 'Laptop'); + + ALTER TABLE products RENAME COLUMN prod_code TO product_code; + + SELECT sql FROM sqlite_schema WHERE name = 'products'; + SELECT category, product_code, name FROM products; +} { + "CREATE TABLE products (category TEXT, product_code TEXT, name TEXT, PRIMARY KEY (category, product_code))" + "Electronics|E001|Laptop" +} + +# Foreign key child column rename +do_execsql_test_on_specific_db {:memory:} alter-table-rename-fk-child { + CREATE TABLE parent (id INTEGER PRIMARY KEY); + CREATE TABLE child (cid INTEGER PRIMARY KEY, pid INTEGER, FOREIGN KEY (pid) REFERENCES parent(id)); + INSERT INTO parent VALUES (1); + INSERT INTO child VALUES (1, 1); + + ALTER TABLE child RENAME COLUMN pid TO parent_id; + + SELECT sql FROM sqlite_schema WHERE name = 'child'; +} { + "CREATE TABLE child (cid INTEGER PRIMARY KEY, parent_id INTEGER, FOREIGN KEY (parent_id) REFERENCES parent (id))" +} + +# Foreign key parent column rename - critical test +do_execsql_test_on_specific_db {:memory:} alter-table-rename-fk-parent { + CREATE TABLE orders (order_id INTEGER PRIMARY KEY, date TEXT); + CREATE TABLE items (item_id INTEGER PRIMARY KEY, oid INTEGER, FOREIGN KEY (oid) REFERENCES orders(order_id)); + + ALTER TABLE orders RENAME COLUMN order_id TO ord_id; + + SELECT sql FROM sqlite_schema WHERE name = 'orders'; + SELECT sql FROM sqlite_schema WHERE name = 'items'; +} { + "CREATE TABLE orders (ord_id INTEGER PRIMARY KEY, date TEXT)" + "CREATE TABLE items (item_id INTEGER PRIMARY KEY, oid INTEGER, FOREIGN KEY (oid) REFERENCES orders (ord_id))" +} + +# Composite foreign key parent rename +do_execsql_test_on_specific_db {:memory:} alter-table-rename-composite-fk-parent { + CREATE TABLE products (cat TEXT, code TEXT, PRIMARY KEY (cat, code)); + CREATE TABLE inventory (id INTEGER PRIMARY KEY, cat TEXT, code TEXT, FOREIGN KEY (cat, code) REFERENCES products(cat, code)); + + ALTER TABLE products RENAME COLUMN code TO sku; + + SELECT sql FROM sqlite_schema WHERE name = 'products'; + SELECT sql FROM sqlite_schema WHERE name = 'inventory'; +} { + "CREATE TABLE products (cat TEXT, sku TEXT, PRIMARY KEY (cat, sku))" + "CREATE TABLE inventory (id INTEGER PRIMARY KEY, cat TEXT, code TEXT, FOREIGN KEY (cat, code) REFERENCES products (cat, sku))" +} + +# Multiple foreign keys to same parent +do_execsql_test_on_specific_db {:memory:} alter-table-rename-multiple-fks { + CREATE TABLE users (uid INTEGER PRIMARY KEY); + CREATE TABLE messages (mid INTEGER PRIMARY KEY, sender INTEGER, receiver INTEGER, + FOREIGN KEY (sender) REFERENCES users(uid), + FOREIGN KEY (receiver) REFERENCES users(uid)); + + ALTER TABLE users RENAME COLUMN uid TO user_id; + + SELECT sql FROM sqlite_schema WHERE name = 'messages'; +} { + "CREATE TABLE messages (mid INTEGER PRIMARY KEY, sender INTEGER, receiver INTEGER, FOREIGN KEY (sender) REFERENCES users (user_id), FOREIGN KEY (receiver) REFERENCES users (user_id))" +} + +# Self-referencing foreign key +do_execsql_test_on_specific_db {:memory:} alter-table-rename-self-ref-fk { + CREATE TABLE employees (emp_id INTEGER PRIMARY KEY, manager_id INTEGER, + FOREIGN KEY (manager_id) REFERENCES employees(emp_id)); + + ALTER TABLE employees RENAME COLUMN emp_id TO employee_id; + + SELECT sql FROM sqlite_schema WHERE name = 'employees'; +} { + "CREATE TABLE employees (employee_id INTEGER PRIMARY KEY, manager_id INTEGER, FOREIGN KEY (manager_id) REFERENCES employees (employee_id))" +} + +# Chain of FK renames - parent is both PK and referenced +do_execsql_test_on_specific_db {:memory:} alter-table-rename-fk-chain { + CREATE TABLE t1 (a INTEGER PRIMARY KEY); + CREATE TABLE t2 (b INTEGER PRIMARY KEY, a_ref INTEGER, FOREIGN KEY (a_ref) REFERENCES t1(a)); + CREATE TABLE t3 (c INTEGER PRIMARY KEY, b_ref INTEGER, FOREIGN KEY (b_ref) REFERENCES t2(b)); + + ALTER TABLE t1 RENAME COLUMN a TO a_new; + ALTER TABLE t2 RENAME COLUMN b TO b_new; + + SELECT sql FROM sqlite_schema WHERE name = 't2'; + SELECT sql FROM sqlite_schema WHERE name = 't3'; +} { + "CREATE TABLE t2 (b_new INTEGER PRIMARY KEY, a_ref INTEGER, FOREIGN KEY (a_ref) REFERENCES t1 (a_new))" + "CREATE TABLE t3 (c INTEGER PRIMARY KEY, b_ref INTEGER, FOREIGN KEY (b_ref) REFERENCES t2 (b_new))" +} + +# FK with ON DELETE/UPDATE actions +do_execsql_test_on_specific_db {:memory:} alter-table-rename-fk-actions { + CREATE TABLE parent (pid INTEGER PRIMARY KEY); + CREATE TABLE child (cid INTEGER PRIMARY KEY, pid INTEGER, + FOREIGN KEY (pid) REFERENCES parent(pid) ON DELETE CASCADE ON UPDATE RESTRICT); + + ALTER TABLE parent RENAME COLUMN pid TO parent_id; + + SELECT sql FROM sqlite_schema WHERE name = 'child'; +} { + "CREATE TABLE child (cid INTEGER PRIMARY KEY, pid INTEGER, FOREIGN KEY (pid) REFERENCES parent (parent_id) ON DELETE CASCADE ON UPDATE RESTRICT)" +} + +# FK with DEFERRABLE +do_execsql_test_on_specific_db {:memory:} alter-table-rename-fk-deferrable { + CREATE TABLE parent (id INTEGER PRIMARY KEY); + CREATE TABLE child (id INTEGER PRIMARY KEY, pid INTEGER, + FOREIGN KEY (pid) REFERENCES parent(id) DEFERRABLE INITIALLY DEFERRED); + + ALTER TABLE parent RENAME COLUMN id TO parent_id; + + SELECT sql FROM sqlite_schema WHERE name = 'child'; +} { + "CREATE TABLE child (id INTEGER PRIMARY KEY, pid INTEGER, FOREIGN KEY (pid) REFERENCES parent (parent_id) DEFERRABLE INITIALLY DEFERRED)" +} + +# Rename with quoted identifiers in FK +do_execsql_test_on_specific_db {:memory:} alter-table-rename-fk-quoted { + CREATE TABLE "parent table" ("parent id" INTEGER PRIMARY KEY); + CREATE TABLE child (id INTEGER PRIMARY KEY, pid INTEGER, + FOREIGN KEY (pid) REFERENCES "parent table"("parent id")); + + ALTER TABLE "parent table" RENAME COLUMN "parent id" TO "new id"; + + SELECT sql FROM sqlite_schema WHERE name = 'child'; +} { + "CREATE TABLE child (id INTEGER PRIMARY KEY, pid INTEGER, FOREIGN KEY (pid) REFERENCES \"parent table\" (\"new id\"))" +} + +# Verify FK constraint still works after rename +do_execsql_test_on_specific_db {:memory:} alter-table-fk-constraint-after-rename { + PRAGMA foreign_keys = ON; + CREATE TABLE parent (id INTEGER PRIMARY KEY); + CREATE TABLE child (id INTEGER PRIMARY KEY, pid INTEGER, FOREIGN KEY (pid) REFERENCES parent(id)); + INSERT INTO parent VALUES (1); + INSERT INTO child VALUES (1, 1); + + ALTER TABLE parent RENAME COLUMN id TO parent_id; + + -- This should work + INSERT INTO child VALUES (2, 1); + SELECT COUNT(*) FROM child; +} { + "2" +} + +# FK constraint violation after rename should still fail +do_execsql_test_in_memory_any_error alter-table-fk-violation-after-rename { + PRAGMA foreign_keys = ON; + CREATE TABLE parent (id INTEGER PRIMARY KEY); + CREATE TABLE child (id INTEGER PRIMARY KEY, pid INTEGER, FOREIGN KEY (pid) REFERENCES parent(id)); + INSERT INTO parent VALUES (1); + + ALTER TABLE parent RENAME COLUMN id TO parent_id; + + -- This should fail with FK violation + INSERT INTO child VALUES (1, 999); +} + +# Complex scenario with multiple table constraints +do_execsql_test_on_specific_db {:memory:} alter-table-rename-complex-constraints { + CREATE TABLE t ( + a INTEGER, + b TEXT, + c REAL, + PRIMARY KEY (a, b), + UNIQUE (b, c), + FOREIGN KEY (a) REFERENCES t(a) + ); + + ALTER TABLE t RENAME COLUMN a TO x; + ALTER TABLE t RENAME COLUMN b TO y; + + SELECT sql FROM sqlite_schema WHERE name = 't'; +} { + "CREATE TABLE t (x INTEGER, y TEXT, c REAL, PRIMARY KEY (x, y), UNIQUE (y, c), FOREIGN KEY (x) REFERENCES t (x))" +} + +# Rename column that appears in both PK and FK +do_execsql_test_on_specific_db {:memory:} alter-table-rename-pk-and-fk { + CREATE TABLE parent (id INTEGER PRIMARY KEY); + CREATE TABLE child ( + id INTEGER PRIMARY KEY, + parent_ref INTEGER, + FOREIGN KEY (id) REFERENCES parent(id), + FOREIGN KEY (parent_ref) REFERENCES parent(id) + ); + + ALTER TABLE parent RENAME COLUMN id TO pid; + + SELECT sql FROM sqlite_schema WHERE name = 'child'; +} { + "CREATE TABLE child (id INTEGER PRIMARY KEY, parent_ref INTEGER, FOREIGN KEY (id) REFERENCES parent (pid), FOREIGN KEY (parent_ref) REFERENCES parent (pid))" +} diff --git a/testing/attach.test b/testing/attach.test index 0d799a927..58d71eea1 100755 --- a/testing/attach.test +++ b/testing/attach.test @@ -73,3 +73,11 @@ do_execsql_test_error query-after-detach { DETACH DATABASE small; select * from small.sqlite_schema; } {(.*no such.*)} + +# regression test for https://github.com/tursodatabase/turso/issues/3540 +do_execsql_test_on_specifc_db {:memory:} attach-from-memory-db { + CREATE TABLE t(a); + INSERT INTO t SELECT value from generate_series(1,10); + ATTACH DATABASE 'testing/testing.db' as a; + SELECT * from a.products, t LIMIT 1; +} {1|hat|79.0|1} diff --git a/testing/autoincr.test b/testing/autoincr.test index cb28e8c5d..9b69b36a2 100755 --- a/testing/autoincr.test +++ b/testing/autoincr.test @@ -174,4 +174,17 @@ do_execsql_test_on_specific_db {:memory:} autoinc-conflict-on-nothing { INSERT INTO t (k) VALUES ('a') ON CONFLICT DO NOTHING; INSERT INTO t (k) VALUES ('b'); SELECT * FROM t ORDER BY id; -} {1|a 2|a 4|b} \ No newline at end of file +} {1|a 2|a 4|b} + +# https://github.com/tursodatabase/turso/issues/3664 +do_execsql_test_on_specific_db {:memory:} autoinc-skips-manually-updated-pk { + CREATE TABLE t(a INTEGER PRIMARY KEY AUTOINCREMENT); + INSERT INTO t DEFAULT VALUES; + select * from sqlite_sequence; + UPDATE t SET a = a + 1; + SELECT * FROM sqlite_sequence; + INSERT INTO t DEFAULT VALUES; + SELECT * FROM sqlite_sequence; +} {t|1 +t|1 +t|3} diff --git a/testing/changes.test b/testing/changes.test index ee03a2168..74d7833b2 100644 --- a/testing/changes.test +++ b/testing/changes.test @@ -32,3 +32,74 @@ do_execsql_test_on_specific_db {:memory:} changes-doesnt-track-indexes { UPDATE users SET name = 'young' where age < 40; select changes(); } {6} + +# https://github.com/tursodatabase/turso/issues/3688 +do_execsql_test_on_specific_db {:memory:} changes-1.69 { + create table t(id integer primary key, value text); + insert into t values (1, 'a'); + select changes(); + update t set id = id+10 where id = 1; + select changes(); +} {1 +1} + +do_execsql_test_on_specific_db {:memory:} changes-on-delete { + create table temp (t1 integer, primary key (t1)); + insert into temp values (1), (2), (3), (4), (5); + delete from temp where t1 > 2; + select changes(); +} {3} + +do_execsql_test_on_specific_db {:memory:} changes-on-update { + create table temp (t1 integer, t2 text, primary key (t1)); + insert into temp values (1, 'a'), (2, 'b'), (3, 'c'), (4, 'd'); + update temp set t2 = 'updated' where t1 <= 3; + select changes(); +} {3} + +do_execsql_test_on_specific_db {:memory:} changes-on-update-rowid { + create table temp (t1 integer primary key, t2 text); + insert into temp values (1, 'a'), (2, 'b'), (3, 'c'); + update temp set t1 = t1 + 10 where t1 = 2; + select changes(); +} {1} + + +do_execsql_test_on_specific_db {:memory:} changes-resets-after-select { + create table temp (t1 integer, primary key (t1)); + insert into temp values (1), (2), (3); + select * from temp; + select changes(); +} {1 +2 +3 +3} + +do_execsql_test_on_specific_db {:memory:} changes-on-delete-no-match { + create table temp (t1 integer, primary key (t1)); + insert into temp values (1), (2), (3); + delete from temp where t1 > 100; + select changes(); +} {0} + +do_execsql_test_on_specific_db {:memory:} changes-on-update-no-match { + create table temp (t1 integer, t2 text, primary key (t1)); + insert into temp values (1, 'a'), (2, 'b'); + update temp set t2 = 'updated' where t1 > 100; + select changes(); +} {0} + +do_execsql_test_on_specific_db {:memory:} changes-on-delete-all { + create table temp (t1 integer, primary key (t1)); + insert into temp values (1), (2), (3), (4), (5), (6); + delete from temp; + select changes(); +} {6} + +do_execsql_test_on_specific_db {:memory:} changes-mixed-operations { + create table temp (t1 integer, t2 text, primary key (t1)); + insert into temp values (1, 'a'), (2, 'b'), (3, 'c'); + update temp set t2 = 'updated' where t1 <= 2; + delete from temp where t1 = 1; + select changes(); +} {1} diff --git a/testing/cli_tests/cli_test_cases.py b/testing/cli_tests/cli_test_cases.py index fa07285d4..10ad90765 100755 --- a/testing/cli_tests/cli_test_cases.py +++ b/testing/cli_tests/cli_test_cases.py @@ -157,6 +157,7 @@ def test_output_file(): # Clean up os.remove(output_file) + shell.quit() def test_multi_line_single_line_comments_succession(): @@ -367,6 +368,16 @@ def test_parse_error(): lambda res: "Parse error: " in res, "Try to LIMIT using an identifier should trigger a Parse error", ) + turso.quit() + + +def test_tables_with_attached_db(): + shell = TestTursoShell() + shell.execute_dot(".open :memory:") + shell.execute_dot("CREATE TABLE orders(a);") + shell.execute_dot("ATTACH DATABASE 'testing/testing.db' AS attached;") + shell.run_test("tables-with-attached-database", ".tables", "orders attached.products attached.users") + shell.quit() def main(): @@ -393,6 +404,7 @@ def main(): test_copy_db_file() test_copy_memory_db_to_file() test_parse_error() + test_tables_with_attached_db() console.info("All tests have passed") diff --git a/testing/cli_tests/extensions.py b/testing/cli_tests/extensions.py index 5f90014a9..ad6e99687 100755 --- a/testing/cli_tests/extensions.py +++ b/testing/cli_tests/extensions.py @@ -792,12 +792,12 @@ def test_csv(): ) turso.run_test_fn( "create virtual table t1 using csv(data='1'\\'2');", - lambda res: "unrecognized token at" in res, + lambda res: "unrecognized token " in res, "Create CSV table with malformed escape sequence", ) turso.run_test_fn( "create virtual table t1 using csv(data=\"12');", - lambda res: "non-terminated literal at" in res, + lambda res: "non-terminated literal " in res, "Create CSV table with unterminated quoted string", ) diff --git a/testing/create_table.test b/testing/create_table.test index 7eb7bea7d..e0182828c 100755 --- a/testing/create_table.test +++ b/testing/create_table.test @@ -53,3 +53,65 @@ do_execsql_test_on_specific_db {:memory:} col-named-rowid { update t set rowid = 1; -- should allow regular update and not throw unique constraint select count(*) from t where rowid = 1; } {3} + +# https://github.com/tursodatabase/turso/issues/3637 +do_execsql_test_in_memory_any_error create_table_duplicate_column_names { + CREATE TABLE t(a, a); +} + +do_execsql_test_in_memory_any_error create_table_duplicate_column_names_case_insensitive { + CREATE TABLE t(A, a); +} + +do_execsql_test_in_memory_any_error create_table_duplicate_column_names_quoted { + CREATE TABLE t("a", a); +} + +# https://github.com/tursodatabase/turso/issues/3675 +do_execsql_test_in_memory_any_error create_table_view_collision-1 { + CREATE VIEW v_same AS SELECT 1; + CREATE TABLE v_same(x INT); +} + +do_execsql_test_in_memory_any_error create_view_table_collision-1 { + CREATE TABLE t_same(x INT); + CREATE VIEW t_same AS SELECT 1; +} + +do_execsql_test_in_memory_any_error create_index_view_collision-1 { + CREATE VIEW i_same AS SELECT 1; + CREATE TABLE t1(x); + CREATE INDEX i_same ON t1(x); +} + +do_execsql_test_in_memory_any_error create_index_table_collision-1 { + CREATE TABLE i_same(x INT); + CREATE TABLE t2(y); + CREATE INDEX i_same ON t2(y); +} + +do_execsql_test_in_memory_any_error create_table_index_collision-1 { + CREATE TABLE t3(z); + CREATE INDEX ix_same ON t3(z); + CREATE TABLE ix_same(x INT); +} + +do_execsql_test_in_memory_any_error create_view_index_collision-1 { + CREATE TABLE t4(w); + CREATE INDEX ix_same ON t4(w); + CREATE VIEW ix_same AS SELECT 1; +} + +# https://github.com/tursodatabase/turso/issues/3796 +do_execsql_test_on_specific_db {:memory:} col-default-true { + create table t(id integer primary key, a default true); + insert into t (id) values (1); + SELECT a from t; +} {1} + +# https://github.com/tursodatabase/turso/issues/3796 +do_execsql_test_on_specific_db {:memory:} col-default-false { + create table t(id integer primary key, a default false); + insert into t (id) values (1); + SELECT a from t; +} {0} diff --git a/testing/drop_table.test b/testing/drop_table.test index 9365b70d4..d7d02256c 100755 --- a/testing/drop_table.test +++ b/testing/drop_table.test @@ -12,6 +12,15 @@ do_execsql_test_on_specific_db {:memory:} drop-table-basic-1 { SELECT count(*) FROM sqlite_schema WHERE type='table' AND name='t1'; } {0} +# The table should be dropped irrespective of the case of the table name. +do_execsql_test_on_specific_db {:memory:} drop-table-case-insensitive { + CREATE TABLE test (x INTEGER PRIMARY KEY); + INSERT INTO test VALUES (1); + INSERT INTO test VALUES (2); + DROP TABLE TeSt; + SELECT count(*) FROM sqlite_schema WHERE type='table' AND name='test'; +} {0} + # Test DROP TABLE IF EXISTS on existing table do_execsql_test_on_specific_db {:memory:} drop-table-if-exists-1 { CREATE TABLE t2 (x INTEGER PRIMARY KEY); diff --git a/testing/foreign_keys.test b/testing/foreign_keys.test new file mode 100644 index 000000000..360ccbc8f --- /dev/null +++ b/testing/foreign_keys.test @@ -0,0 +1,1209 @@ +#!/usr/bin/env tclsh + +set testdir [file dirname $argv0] +source $testdir/tester.tcl +source $testdir/sqlite3/tester.tcl + +do_execsql_test_on_specific_db {:memory:} fk-basic-ok { + PRAGMA foreign_keys=ON; + CREATE TABLE t (id INTEGER PRIMARY KEY, a TEXT); + CREATE TABLE t2 (id INTEGER PRIMARY KEY, tid REFERENCES t(id)); + INSERT INTO t VALUES (1,'x'),(2,'y'); + INSERT INTO t2 VALUES (10,1),(11,NULL); -- NULL child ok + SELECT id,tid FROM t2 ORDER BY id; +} {10|1 +11|} + +do_execsql_test_in_memory_any_error fk-insert-child-missing-parent { + PRAGMA foreign_keys=ON; + CREATE TABLE t (id INTEGER PRIMARY KEY, a TEXT); + CREATE TABLE t2 (id INTEGER PRIMARY KEY, tid REFERENCES t(id)); + INSERT INTO t2 VALUES (20,99); +} + +do_execsql_test_in_memory_any_error fk-update-child-to-missing-parent { + PRAGMA foreign_keys=ON; + CREATE TABLE t (id INTEGER PRIMARY KEY, a TEXT); + CREATE TABLE t2 (id INTEGER PRIMARY KEY, tid REFERENCES t(id)); + INSERT INTO t VALUES (1,'x'); + INSERT INTO t2 VALUES (10,1); + UPDATE t2 SET tid = 42 WHERE id = 10; -- now missing +} + +do_execsql_test_on_specific_db {:memory:} fk-update-child-to-null-ok { + PRAGMA foreign_keys=ON; + CREATE TABLE t (id INTEGER PRIMARY KEY); + CREATE TABLE t2 (id INTEGER PRIMARY KEY, tid REFERENCES t(id)); + INSERT INTO t VALUES (1); + INSERT INTO t2 VALUES (7,1); + UPDATE t2 SET tid = NULL WHERE id = 7; + SELECT id, tid FROM t2; +} {7|} + +do_execsql_test_in_memory_any_error fk-delete-parent-blocked { + PRAGMA foreign_keys=ON; + CREATE TABLE t (id INTEGER PRIMARY KEY, a TEXT); + CREATE TABLE t2 (id INTEGER PRIMARY KEY, tid REFERENCES t(id)); + INSERT INTO t VALUES (1,'x'),(2,'y'); + INSERT INTO t2 VALUES (10,2); + DELETE FROM t WHERE id=2; +} + +do_execsql_test_on_specific_db {:memory:} fk-delete-parent-ok-when-no-child { + PRAGMA foreign_keys=ON; + CREATE TABLE t (id INTEGER PRIMARY KEY, a TEXT); + CREATE TABLE t2 (id INTEGER PRIMARY KEY, tid REFERENCES t(id)); + INSERT INTO t VALUES (1,'x'),(2,'y'); + INSERT INTO t2 VALUES (10,1); + DELETE FROM t WHERE id=2; + SELECT id FROM t ORDER BY id; +} {1} + + +do_execsql_test_on_specific_db {:memory:} fk-composite-pk-ok { + PRAGMA foreign_keys=ON; + CREATE TABLE p( + a INT NOT NULL, + b INT NOT NULL, + PRIMARY KEY(a,b) + ); + CREATE TABLE c( + id INT PRIMARY KEY, + x INT, y INT, + FOREIGN KEY(x,y) REFERENCES p(a,b) + ); + INSERT INTO p VALUES (1,1),(1,2); + INSERT INTO c VALUES (10,1,1),(11,1,2),(12,NULL,2); -- NULL in child allowed + SELECT id,x,y FROM c ORDER BY id; +} {10|1|1 +11|1|2 +12||2} + +do_execsql_test_in_memory_any_error fk-composite-pk-missing { + PRAGMA foreign_keys=ON; + CREATE TABLE p( + a INT NOT NULL, + b INT NOT NULL, + PRIMARY KEY(a,b) + ); + CREATE TABLE c( + id INT PRIMARY KEY, + x INT, y INT, + FOREIGN KEY(x,y) REFERENCES p(a,b) + ); + INSERT INTO p VALUES (1,1); + INSERT INTO c VALUES (20,1,2); -- (1,2) missing +} + +do_execsql_test_in_memory_any_error fk-composite-update-child-missing { + PRAGMA foreign_keys=ON; + CREATE TABLE p(a INT NOT NULL, b INT NOT NULL, PRIMARY KEY(a,b)); + CREATE TABLE c(id INT PRIMARY KEY, x INT, y INT, + FOREIGN KEY(x,y) REFERENCES p(a,b)); + INSERT INTO p VALUES (1,1),(2,2); + INSERT INTO c VALUES (5,1,1); + UPDATE c SET x=2,y=3 WHERE id=5; +} + +do_execsql_test_on_specific_db {:memory:} fk-composite-unique-ok { + PRAGMA foreign_keys=ON; + CREATE TABLE parent(u TEXT, v TEXT, pad INT, UNIQUE(u,v)); + CREATE TABLE child(id INT PRIMARY KEY, cu TEXT, cv TEXT, + FOREIGN KEY(cu,cv) REFERENCES parent(u,v)); + INSERT INTO parent VALUES ('A','B',0),('A','C',0); + INSERT INTO child VALUES (1,'A','B'); + SELECT id, cu, cv FROM child ORDER BY id; +} {1|A|B} + +do_execsql_test_in_memory_any_error fk-composite-unique-missing { + PRAGMA foreign_keys=ON; + CREATE TABLE parent(u TEXT, v TEXT, pad INT, UNIQUE(u,v)); + CREATE TABLE child(id INT PRIMARY KEY, cu TEXT, cv TEXT, + FOREIGN KEY(cu,cv) REFERENCES parent(u,v)); + INSERT INTO parent VALUES ('A','B',0); + INSERT INTO child VALUES (2,'A','X'); -- no ('A','X') in parent +} + +do_execsql_test_in_memory_any_error fk-rowid-alias-parent { + PRAGMA foreign_keys=ON; + CREATE TABLE t(id INTEGER PRIMARY KEY, a TEXT); + CREATE TABLE c(cid INTEGER PRIMARY KEY, rid REFERENCES t(rowid)); + INSERT INTO t VALUES (100,'x'); + INSERT INTO c VALUES (1, 100); +} + +do_execsql_test_in_memory_any_error fk-rowid-alias-parent-missing { + PRAGMA foreign_keys=ON; + CREATE TABLE t(id INTEGER PRIMARY KEY, a TEXT); + CREATE TABLE c(cid INTEGER PRIMARY KEY, rid REFERENCES t(rowid)); + INSERT INTO c VALUES (1, 9999); +} + +do_execsql_test_on_specific_db {:memory:} fk-update-child-noop-ok { + PRAGMA foreign_keys=ON; + CREATE TABLE p(id INTEGER PRIMARY KEY); + CREATE TABLE c(id INTEGER PRIMARY KEY, pid REFERENCES p(id)); + INSERT INTO p VALUES (1); + INSERT INTO c VALUES (10,1); + UPDATE c SET id = id WHERE id = 10; -- no FK column touched + SELECT id, pid FROM c; +} {10|1} + +do_execsql_test_in_memory_any_error fk-delete-parent-composite-scan { + PRAGMA foreign_keys=ON; + CREATE TABLE p(a INT NOT NULL, b INT NOT NULL, PRIMARY KEY(a,b)); + CREATE TABLE c(id INT PRIMARY KEY, x INT, y INT, + FOREIGN KEY(x,y) REFERENCES p(a,b)); + INSERT INTO p VALUES (1,2),(2,3); + INSERT INTO c VALUES (7,2,3); + DELETE FROM p WHERE a=2 AND b=3; +} + +do_execsql_test_on_specific_db {:memory:} fk-update-child-to-existing-ok { + PRAGMA foreign_keys=ON; + CREATE TABLE t(id INTEGER PRIMARY KEY); + CREATE TABLE t2(id INTEGER PRIMARY KEY, tid REFERENCES t(id)); + INSERT INTO t VALUES (1),(2); + INSERT INTO t2 VALUES (9,1); + UPDATE t2 SET tid = 2 WHERE id = 9; + SELECT id, tid FROM t2; +} {9|2} + +do_execsql_test_on_specific_db {:memory:} fk-composite-pk-delete-ok { + PRAGMA foreign_keys=ON; + CREATE TABLE p(a INT NOT NULL, b INT NOT NULL, PRIMARY KEY(a,b)); + CREATE TABLE c(id INT PRIMARY KEY, x INT, y INT, + FOREIGN KEY(x,y) REFERENCES p(a,b)); + INSERT INTO p VALUES (1,2),(2,3); + INSERT INTO c VALUES (7,2,3); + -- Deleting a non-referenced parent tuple is OK + DELETE FROM p WHERE a=1 AND b=2; + SELECT a,b FROM p ORDER BY a,b; +} {2|3} + +do_execsql_test_in_memory_any_error fk-composite-pk-delete-violate { + PRAGMA foreign_keys=ON; + CREATE TABLE p(a INT NOT NULL, b INT NOT NULL, PRIMARY KEY(a,b)); + CREATE TABLE c(id INT PRIMARY KEY, x INT, y INT, + FOREIGN KEY(x,y) REFERENCES p(a,b)); + INSERT INTO p VALUES (2,3); + INSERT INTO c VALUES (7,2,3); + -- Deleting the referenced tuple should fail + DELETE FROM p WHERE a=2 AND b=3; +} + +# Parent columns omitted: should default to parent's declared PRIMARY KEY (composite) +do_execsql_test_on_specific_db {:memory:} fk-default-parent-pk-composite-ok { + PRAGMA foreign_keys=ON; + CREATE TABLE p( + a INT NOT NULL, + b INT NOT NULL, + PRIMARY KEY(a,b) + ); + -- Parent columns omitted in REFERENCES p + CREATE TABLE c( + id INT PRIMARY KEY, + x INT, y INT, + FOREIGN KEY(x,y) REFERENCES p + ); + INSERT INTO p VALUES (1,1), (1,2); + INSERT INTO c VALUES (10,1,1), (11,1,2), (12,NULL,2); -- NULL in child allowed + SELECT id,x,y FROM c ORDER BY id; +} {10|1|1 +11|1|2 +12||2} + +do_execsql_test_in_memory_any_error fk-default-parent-pk-composite-missing { + PRAGMA foreign_keys=ON; + CREATE TABLE p(a INT NOT NULL, b INT NOT NULL, PRIMARY KEY(a,b)); + CREATE TABLE c(id INT PRIMARY KEY, x INT, y INT, + FOREIGN KEY(x,y) REFERENCES p); -- omit parent cols + INSERT INTO p VALUES (1,1); + INSERT INTO c VALUES (20,1,2); -- (1,2) missing in parent +} + +# Parent has no explicitly declared PK, so we throw parse error when referencing bare table +do_execsql_test_in_memory_any_error fk-default-parent-rowid-no-parent-pk { + PRAGMA foreign_keys=ON; + CREATE TABLE p_no_pk(v TEXT); + CREATE TABLE c_rowid(id INT PRIMARY KEY, + r REFERENCES p_no_pk); + INSERT INTO p_no_pk(v) VALUES ('a'), ('b'); + INSERT INTO c_rowid VALUES (1, 1); +} + +do_execsql_test_on_specific_db {:memory:} fk-parent-omit-cols-parent-has-pk { + PRAGMA foreign_keys=ON; + CREATE TABLE p_pk(id INTEGER PRIMARY KEY, v TEXT); + CREATE TABLE c_ok(id INT PRIMARY KEY, r REFERENCES p_pk); -- binds to p_pk(id) + INSERT INTO p_pk VALUES (1,'a'),(2,'b'); + INSERT INTO c_ok VALUES (10,1); + INSERT INTO c_ok VALUES (11,2); + SELECT id, r FROM c_ok ORDER BY id; +} {10|1 11|2} + + +# Self-reference (same table) with INTEGER PRIMARY KEY: single-row insert should pass +do_execsql_test_on_specific_db {:memory:} fk-self-ipk-single-ok { + PRAGMA foreign_keys=ON; + CREATE TABLE t( + id INTEGER PRIMARY KEY, + rid REFERENCES t(id) -- child->parent in same table + ); + INSERT INTO t(id,rid) VALUES(5,5); -- self-reference, single-row + SELECT id, rid FROM t; +} {5|5} + +# Self-reference with mismatched value: should fail immediately (no counter semantics used) +do_execsql_test_in_memory_any_error fk-self-ipk-single-mismatch { + PRAGMA foreign_keys=ON; + CREATE TABLE t( + id INTEGER PRIMARY KEY, + rid REFERENCES t(id) + ); + INSERT INTO t(id,rid) VALUES(5,4); -- rid!=id -> FK violation +} + +# Self-reference on composite PRIMARY KEY: single-row insert should pass +do_execsql_test_on_specific_db {:memory:} fk-self-composite-single-ok { + PRAGMA foreign_keys=ON; + CREATE TABLE t( + a INT NOT NULL, + b INT NOT NULL, + x INT, + y INT, + PRIMARY KEY(a,b), + FOREIGN KEY(x,y) REFERENCES t(a,b) + ); + INSERT INTO t(a,b,x,y) VALUES(1,2,1,2); -- self-reference matches PK + SELECT a,b,x,y FROM t; +} {1|2|1|2} + +# Rowid parent path: text '10' must be coerced to integer (MustBeInt) and succeed +do_execsql_test_on_specific_db {:memory:} fk-rowid-mustbeint-coercion-ok { + PRAGMA foreign_keys=ON; + CREATE TABLE p(id INTEGER PRIMARY KEY); + CREATE TABLE c(cid INTEGER PRIMARY KEY, pid REFERENCES p(id)); + INSERT INTO p(id) VALUES(10); + INSERT INTO c VALUES(1, '10'); -- text -> int via MustBeInt; should match + SELECT pid FROM c; +} {10} + +# Rowid parent path: non-numeric text cannot be coerced -> violation +do_execsql_test_in_memory_any_error fk-rowid-mustbeint-coercion-fail { + PRAGMA foreign_keys=ON; + CREATE TABLE p(id INTEGER PRIMARY KEY); + CREATE TABLE c(cid INTEGER PRIMARY KEY, pid REFERENCES p(id)); + INSERT INTO p(id) VALUES(10); + INSERT INTO c VALUES(2, 'abc'); -- MustBeInt fails to match any parent row +} + +# Parent match via UNIQUE index (non-rowid), success path +do_execsql_test_on_specific_db {:memory:} fk-parent-unique-index-ok { + PRAGMA foreign_keys=ON; + CREATE TABLE parent(u TEXT, v TEXT, pad INT, UNIQUE(u,v)); + CREATE TABLE child(id INT PRIMARY KEY, cu TEXT, cv TEXT, + FOREIGN KEY(cu,cv) REFERENCES parent(u,v)); + INSERT INTO parent VALUES ('A','B',0),('A','C',0); + INSERT INTO child VALUES (1,'A','B'); + SELECT id, cu, cv FROM child ORDER BY id; +} {1|A|B} + +# Parent UNIQUE index path: missing key -> immediate violation +do_execsql_test_in_memory_any_error fk-parent-unique-index-missing { + PRAGMA foreign_keys=ON; + CREATE TABLE parent(u TEXT, v TEXT, pad INT, UNIQUE(u,v)); + CREATE TABLE child(id INT PRIMARY KEY, cu TEXT, cv TEXT, + FOREIGN KEY(cu,cv) REFERENCES parent(u,v)); + INSERT INTO parent VALUES ('A','B',0); + INSERT INTO child VALUES (2,'A','X'); -- no ('A','X') in parent +} + +# NULL in child short-circuits FK check +do_execsql_test_on_specific_db {:memory:} fk-child-null-shortcircuit { + PRAGMA foreign_keys=ON; + CREATE TABLE p(id INTEGER PRIMARY KEY); + CREATE TABLE c(id INTEGER PRIMARY KEY, pid REFERENCES p(id)); + INSERT INTO c VALUES (1, NULL); -- NULL child is allowed + SELECT id, pid FROM c; +} {1|} + +do_execsql_test_on_specific_db {:memory:} fk-self-unique-ok { + PRAGMA foreign_keys=ON; + CREATE TABLE t( + u TEXT, + v TEXT, + cu TEXT, + cv TEXT, + UNIQUE(u,v), + FOREIGN KEY(cu,cv) REFERENCES t(u,v) + ); + -- Single row insert where child points to its own (u,v): allowed + INSERT INTO t(u,v,cu,cv) VALUES('A','B','A','B'); + SELECT u, v, cu, cv FROM t; +} {A|B|A|B} + +do_execsql_test_in_memory_any_error fk-self-unique-mismatch { + PRAGMA foreign_keys=ON; + CREATE TABLE t( + u TEXT, + v TEXT, + cu TEXT, + cv TEXT, + UNIQUE(u,v), + FOREIGN KEY(cu,cv) REFERENCES t(u,v) + ); + -- Child points to a different (u,v) that doesn't exist: must fail + INSERT INTO t(u,v,cu,cv) VALUES('A','B','A','X'); +} + +do_execsql_test_on_specific_db {:memory:} fk-self-unique-reference-existing-ok { + PRAGMA foreign_keys=ON; + CREATE TABLE t( + u TEXT, + v TEXT, + cu TEXT, + cv TEXT, + UNIQUE(u,v), + FOREIGN KEY(cu,cv) REFERENCES t(u,v) + ); + -- Insert a parent row first + INSERT INTO t(u,v,cu,cv) VALUES('P','Q',NULL,NULL); + -- Now insert a row whose FK references the existing ('P','Q'): OK + INSERT INTO t(u,v,cu,cv) VALUES('X','Y','P','Q'); + SELECT u, v, cu, cv FROM t ORDER BY u, v, cu, cv; +} {P|Q|| X|Y|P|Q} + +do_execsql_test_on_specific_db {:memory:} fk-self-unique-multirow-no-fastpath { + PRAGMA foreign_keys=ON; + CREATE TABLE t( + u TEXT, + v TEXT, + cu TEXT, + cv TEXT, + UNIQUE(u,v), + FOREIGN KEY(cu,cv) REFERENCES t(u,v) + ); + INSERT INTO t(u,v,cu,cv) VALUES + ('C','D','C','D'), + ('E','F','E','F'); +} {} + +do_execsql_test_in_memory_any_error fk-self-multirow-one-bad { + PRAGMA foreign_keys=ON; + CREATE TABLE t(id INTEGER PRIMARY KEY, rid INTEGER, + FOREIGN KEY(rid) REFERENCES t(id)); + INSERT INTO t(id,rid) VALUES (1,1),(3,99); -- 99 has no parent -> error +} + +# doesnt fail because tx is un-committed +do_execsql_test_on_specific_db {:memory:} fk-deferred-commit-doesnt-fail-early { + PRAGMA foreign_keys=ON; + CREATE TABLE p(id INTEGER PRIMARY KEY); + CREATE TABLE c(id INTEGER PRIMARY KEY, pid INT REFERENCES p(id) DEFERRABLE INITIALLY DEFERRED); + BEGIN; + INSERT INTO c VALUES(1, 99); -- shouldnt fail because we are mid-tx +} {} + +# it should fail here because we actuall COMMIT +do_execsql_test_in_memory_any_error fk-deferred-commit-fails { + PRAGMA foreign_keys=ON; + CREATE TABLE p(id INTEGER PRIMARY KEY); + CREATE TABLE c(id INTEGER PRIMARY KEY, pid INT REFERENCES p(id) DEFERRABLE INITIALLY DEFERRED); + BEGIN; + INSERT INTO c VALUES(1, 99); + COMMIT; +} + + +# If we fix it before COMMIT, COMMIT succeeds +do_execsql_test_on_specific_db {:memory:} fk-deferred-fix-before-commit-succeeds { + PRAGMA foreign_keys=ON; + CREATE TABLE p(id INTEGER PRIMARY KEY); + CREATE TABLE c( + id INTEGER PRIMARY KEY, + pid INT REFERENCES p(id) DEFERRABLE INITIALLY DEFERRED + ); + BEGIN; + INSERT INTO c VALUES(1, 99); -- temporary violation + INSERT INTO p VALUES(99); -- fix parent + COMMIT; + SELECT * FROM p ORDER BY 1; +} {99} + +# ROLLBACK clears deferred state; a new tx can still fail if violation persists +do_execsql_test_on_specific_db {:memory:} fk-deferred-rollback-clears { + PRAGMA foreign_keys=ON; + CREATE TABLE p(id INTEGER PRIMARY KEY); + CREATE TABLE c( + id INTEGER PRIMARY KEY, + pid INT REFERENCES p(id) DEFERRABLE INITIALLY DEFERRED + ); + BEGIN; + INSERT INTO c VALUES(1, 123); + ROLLBACK; + + -- Now start over and *fix* it, COMMIT should pass. + BEGIN; + INSERT INTO p VALUES(123); + INSERT INTO c VALUES(1, 123); + COMMIT; + SELECT * FROM c ORDER BY 1; +} {1|123} + + +do_execsql_test_on_specific_db {:memory:} fk-deferred-insert-parent-fixes-before-commit { + PRAGMA foreign_keys=ON; + CREATE TABLE p(id INTEGER PRIMARY KEY); + CREATE TABLE c( + id INTEGER PRIMARY KEY, + pid INT REFERENCES p(id) DEFERRABLE INITIALLY DEFERRED + ); + BEGIN; + INSERT INTO c VALUES(1, 50); -- violation + INSERT INTO p VALUES(50); -- resolve + COMMIT; + SELECT * FROM c ORDER BY 1; +} {1|50} + +do_execsql_test_on_specific_db {:memory:} fk-deferred-update-fixes-child-before-commit { + PRAGMA foreign_keys=ON; + CREATE TABLE p(id INTEGER PRIMARY KEY); + CREATE TABLE c( + id INTEGER PRIMARY KEY, + pid INT REFERENCES p(id) DEFERRABLE INITIALLY DEFERRED + ); + BEGIN; + INSERT INTO c VALUES(1, 50); -- violation + INSERT INTO p VALUES(32); + UPDATE c SET pid=32 WHERE id=1; -- resolve child + COMMIT; + SELECT * FROM c ORDER BY 1; +} {1|32} + +do_execsql_test_on_specific_db {:memory:} fk-deferred-delete-fixes-child-before-commit { + PRAGMA foreign_keys=ON; + CREATE TABLE p(id INTEGER PRIMARY KEY); + CREATE TABLE c( + id INTEGER PRIMARY KEY, + pid INT REFERENCES p(id) DEFERRABLE INITIALLY DEFERRED + ); + BEGIN; + INSERT INTO c VALUES(1, 50); -- violation + INSERT INTO p VALUES(32); + DELETE FROM c WHERE id=1; -- resolve by deleting child + COMMIT; + SELECT * FROM c ORDER BY 1; +} {} + +do_execsql_test_on_specific_db {:memory:} fk-deferred-update-fixes-parent-before-commit { + PRAGMA foreign_keys=ON; + CREATE TABLE p(id INTEGER PRIMARY KEY); + CREATE TABLE c( + id INTEGER PRIMARY KEY, + pid INT REFERENCES p(id) DEFERRABLE INITIALLY DEFERRED + ); + BEGIN; + INSERT INTO c VALUES(1, 50); -- violation + INSERT INTO p VALUES(32); + UPDATE p SET id=50 WHERE id=32; -- resolve via parent + COMMIT; + SELECT * FROM c ORDER BY 1; +} {1|50} + +# Self-referential: row referencing itself should succeed +do_execsql_test_on_specific_db {:memory:} fk-deferred-self-ref-succeeds { + PRAGMA foreign_keys=ON; + CREATE TABLE t( + id INTEGER PRIMARY KEY, + pid INT REFERENCES t(id) DEFERRABLE INITIALLY DEFERRED + ); + BEGIN; + INSERT INTO t VALUES(1, 1); -- self-match + COMMIT; + SELECT * FROM t ORDER BY 1; +} {1|1} + +# Two-step self-ref: insert invalid, then create parent before COMMIT +do_execsql_test_on_specific_db {:memory:} fk-deferred-self-ref-late-parent { + PRAGMA foreign_keys=ON; + CREATE TABLE t( + id INTEGER PRIMARY KEY, + pid INT REFERENCES t(id) DEFERRABLE INITIALLY DEFERRED + ); + BEGIN; + INSERT INTO t VALUES(2, 3); -- currently invalid + INSERT INTO t VALUES(3, 3); -- now parent exists + COMMIT; + SELECT * FROM t ORDER BY 1; +} {2|3 +3|3} + + +# counter must not be neutralized by later good statements +do_execsql_test_in_memory_any_error fk-deferred-neutralize.1 { + PRAGMA foreign_keys=ON; + CREATE TABLE parent(id INTEGER PRIMARY KEY); + CREATE TABLE parent_comp(a INT NOT NULL, b INT NOT NULL, PRIMARY KEY(a,b)); + CREATE TABLE child_deferred(id INTEGER PRIMARY KEY, pid INT, + FOREIGN KEY(pid) REFERENCES parent(id)); + + CREATE TABLE child_comp_deferred(id INTEGER PRIMARY KEY, ca INT, cb INT, + FOREIGN KEY(ca,cb) REFERENCES parent_comp(a,b)); + INSERT INTO parent_comp VALUES (4,-1); + BEGIN; + INSERT INTO child_deferred VALUES (1, 999); + INSERT INTO child_comp_deferred VALUES (2, 4, -1); + COMMIT; +} + +do_execsql_test_on_specific_db {:memory:} fk-deferred-upsert-late-parent { +PRAGMA foreign_keys=ON; + +CREATE TABLE p(id INTEGER PRIMARY KEY); +CREATE TABLE c( + id INTEGER PRIMARY KEY, + pid INT REFERENCES p(id) DEFERRABLE INITIALLY DEFERRED +); + +BEGIN; + INSERT INTO c VALUES(1, 50); -- deferred violation + INSERT INTO p VALUES(32); -- parent exists, but pid still 50 + INSERT INTO c(id,pid) VALUES(1,32) + ON CONFLICT(id) DO UPDATE SET pid=excluded.pid; -- resolve child via UPSERT +COMMIT; +-- Expect: row is (1,32) and no violations remain +SELECT * FROM c ORDER BY id; +} {1|32} + +do_execsql_test_on_specific_db {:memory:} fk-deferred-upsert-late-child { +PRAGMA foreign_keys=ON; + +CREATE TABLE p( + id INTEGER PRIMARY KEY, + u INT UNIQUE +); +CREATE TABLE c( + id INTEGER PRIMARY KEY, + pid INT REFERENCES p(id) DEFERRABLE INITIALLY DEFERRED +); +BEGIN; + INSERT INTO c VALUES(1, 50); -- deferred violation (no parent 50) + INSERT INTO p VALUES(32, 7); -- parent row with u=7 + -- Trigger DO UPDATE via conflict on p.u, then change the PK id to 50, + -- which satisfies the child reference. + INSERT INTO p(id,u) VALUES(999,7) + ON CONFLICT(u) DO UPDATE SET id=50; +COMMIT; +-- Expect: parent is now (50,7), child (1,50), no violations remain +SELECT p.id, c.id FROM p join c on c.pid = p.id; +} {50|1} + +do_execsql_test_in_memory_any_error fk-deferred-insert-commit-fails { + PRAGMA foreign_keys=ON; + CREATE TABLE p(id INTEGER PRIMARY KEY); + CREATE TABLE c( + id INTEGER PRIMARY KEY, + pid INTEGER REFERENCES p(id) DEFERRABLE INITIALLY DEFERRED + ); + BEGIN; + INSERT INTO c VALUES(1, 99); -- no parent -> deferred violation + COMMIT; -- must fail +} + +do_execsql_test_on_specific_db {:memory:} fk-deferred-insert-parent-fix-before-commit { + PRAGMA foreign_keys=ON; + CREATE TABLE p(id INTEGER PRIMARY KEY); + CREATE TABLE c( + id INTEGER PRIMARY KEY, + pid INTEGER REFERENCES p(id) DEFERRABLE INITIALLY DEFERRED + ); + BEGIN; + INSERT INTO c VALUES(1, 99); -- violation + INSERT INTO p VALUES(99); -- fix by inserting parent + COMMIT; + SELECT id, pid FROM c ORDER BY id; +} {1|99} + +do_execsql_test_on_specific_db {:memory:} fk-deferred-insert-multi-children-one-parent-fix { + PRAGMA foreign_keys=ON; + CREATE TABLE p(id INTEGER PRIMARY KEY); + CREATE TABLE c(id INTEGER PRIMARY KEY, pid INT REFERENCES p(id) DEFERRABLE INITIALLY DEFERRED); + BEGIN; + INSERT INTO c VALUES(1, 50); + INSERT INTO c VALUES(2, 50); -- two violations pointing to same parent + INSERT INTO p VALUES(50); -- one parent fixes both + COMMIT; + SELECT id, pid FROM c ORDER BY id; +} {1|50 2|50} + +do_execsql_test_on_specific_db {:memory:} fk-deferred-insert-then-delete-child-fix { + PRAGMA foreign_keys=ON; + CREATE TABLE p(id INTEGER PRIMARY KEY); + CREATE TABLE c(id INTEGER PRIMARY KEY, pid INT REFERENCES p(id) DEFERRABLE INITIALLY DEFERRED); + BEGIN; + INSERT INTO c VALUES(1, 77); -- violation + DELETE FROM c WHERE id=1; -- resolve by removing the child + COMMIT; + SELECT count(*) FROM c; +} {0} + +do_execsql_test_on_specific_db {:memory:} fk-deferred-insert-self-ref-succeeds { + PRAGMA foreign_keys=ON; + CREATE TABLE t( + id INTEGER PRIMARY KEY, + pid INT REFERENCES t(id) DEFERRABLE INITIALLY DEFERRED + ); + BEGIN; + INSERT INTO t VALUES(1, 1); -- self-reference, legal at COMMIT + COMMIT; + SELECT id, pid FROM t; +} {1|1} + +do_execsql_test_in_memory_any_error fk-deferred-update-child-breaks-commit-fails { + PRAGMA foreign_keys=ON; + CREATE TABLE p(id INTEGER PRIMARY KEY); + CREATE TABLE c(id INTEGER PRIMARY KEY, pid INT REFERENCES p(id) DEFERRABLE INITIALLY DEFERRED); + INSERT INTO p VALUES(10); + INSERT INTO c VALUES(1, 10); -- valid + BEGIN; + UPDATE c SET pid=99 WHERE id=1; -- create violation + COMMIT; -- must fail +} + +do_execsql_test_on_specific_db {:memory:} fk-deferred-update-child-fix-before-commit { + PRAGMA foreign_keys=ON; + CREATE TABLE p(id INTEGER PRIMARY KEY); + CREATE TABLE c(id INTEGER PRIMARY KEY, pid INT REFERENCES p(id) DEFERRABLE INITIALLY DEFERRED); + INSERT INTO p VALUES(10); + INSERT INTO c VALUES(1, 10); + BEGIN; + UPDATE c SET pid=99 WHERE id=1; -- violation + UPDATE c SET pid=10 WHERE id=1; -- fix child back + COMMIT; + SELECT id, pid FROM c; +} {1|10} + +do_execsql_test_on_specific_db {:memory:} fk-deferred-update-child-fix-by-inserting-parent { + PRAGMA foreign_keys=ON; + CREATE TABLE p(id INTEGER PRIMARY KEY); + CREATE TABLE c(id INTEGER PRIMARY KEY, pid INT REFERENCES p(id) DEFERRABLE INITIALLY DEFERRED); + INSERT INTO p VALUES(10); + INSERT INTO c VALUES(1, 10); + BEGIN; + UPDATE c SET pid=50 WHERE id=1; -- violation + INSERT INTO p VALUES(50); -- fix by adding parent + COMMIT; + SELECT id, pid FROM c; +} {1|50} + +do_execsql_test_in_memory_any_error fk-deferred-update-parent-breaks-commit-fails { + PRAGMA foreign_keys=ON; + CREATE TABLE p(id INTEGER PRIMARY KEY); + CREATE TABLE c(id INTEGER PRIMARY KEY, pid INT REFERENCES p(id) DEFERRABLE INITIALLY DEFERRED); + INSERT INTO p VALUES(32); + INSERT INTO c VALUES(1, 32); -- valid + BEGIN; + UPDATE p SET id=50 WHERE id=32; -- break child reference + COMMIT; -- must fail (no fix) +} + +do_execsql_test_on_specific_db {:memory:} fk-deferred-update-parent-fix-by-updating-child { + PRAGMA foreign_keys=ON; + CREATE TABLE p(id INTEGER PRIMARY KEY); + CREATE TABLE c(id INTEGER PRIMARY KEY, pid INT REFERENCES p(id) DEFERRABLE INITIALLY DEFERRED); + INSERT INTO p VALUES(32); + INSERT INTO c VALUES(1, 32); + BEGIN; + UPDATE p SET id=50 WHERE id=32; -- break + UPDATE c SET pid=50 WHERE id=1; -- fix child to new parent key + COMMIT; + SELECT id, pid FROM c; +} {1|50} + +do_execsql_test_on_specific_db {:memory:} fk-deferred-update-parent-fix-by-reverting-parent { + PRAGMA foreign_keys=ON; + CREATE TABLE p(id INTEGER PRIMARY KEY); + CREATE TABLE c(id INTEGER PRIMARY KEY, pid INT REFERENCES p(id) DEFERRABLE INITIALLY DEFERRED); + INSERT INTO p VALUES(32); + INSERT INTO c VALUES(1, 32); + BEGIN; + UPDATE p SET id=50 WHERE id=32; -- break + UPDATE p SET id=32 WHERE id=50; -- revert (fix) + COMMIT; + SELECT id, pid FROM c; +} {1|32} + +do_execsql_test_on_specific_db {:memory:} fk-deferred-update-self-ref-id-change-and-fix { + PRAGMA foreign_keys=ON; + CREATE TABLE t( + id INTEGER PRIMARY KEY, + pid INT REFERENCES t(id) DEFERRABLE INITIALLY DEFERRED + ); + INSERT INTO t VALUES(1,1); + BEGIN; + UPDATE t SET id=2 WHERE id=1; -- break self-ref + UPDATE t SET pid=2 WHERE id=2; -- fix to new self + COMMIT; + SELECT id, pid FROM t; +} {2|2} + +do_execsql_test_in_memory_any_error fk-deferred-delete-parent-commit-fails { + PRAGMA foreign_keys=ON; + CREATE TABLE p(id INTEGER PRIMARY KEY); + CREATE TABLE c(id INTEGER PRIMARY KEY, pid INT REFERENCES p(id) DEFERRABLE INITIALLY DEFERRED); + INSERT INTO p VALUES(10); + INSERT INTO c VALUES(1, 10); -- valid + BEGIN; + DELETE FROM p WHERE id=10; -- break reference + COMMIT; -- must fail +} + +do_execsql_test_on_specific_db {:memory:} fk-deferred-delete-parent-then-delete-child-fix { + PRAGMA foreign_keys=ON; + CREATE TABLE p(id INTEGER PRIMARY KEY); + CREATE TABLE c(id INTEGER PRIMARY KEY, pid INT REFERENCES p(id) DEFERRABLE INITIALLY DEFERRED); + INSERT INTO p VALUES(10); + INSERT INTO c VALUES(1, 10); + BEGIN; + DELETE FROM p WHERE id=10; -- break + DELETE FROM c WHERE id=1; -- fix by removing child + COMMIT; + SELECT count(*) FROM p, c; -- both empty +} {0} + +do_execsql_test_on_specific_db {:memory:} fk-deferred-delete-parent-then-reinsert-parent-fix { + PRAGMA foreign_keys=ON; + CREATE TABLE p(id INTEGER PRIMARY KEY); + CREATE TABLE c(id INTEGER PRIMARY KEY, pid INT REFERENCES p(id) DEFERRABLE INITIALLY DEFERRED); + INSERT INTO p VALUES(10); + INSERT INTO c VALUES(1, 10); + BEGIN; + DELETE FROM p WHERE id=10; -- break + INSERT INTO p VALUES(10); -- fix by re-creating parent + COMMIT; + SELECT id, pid FROM c; +} {1|10} + +do_execsql_test_on_specific_db {:memory:} fk-deferred-delete-self-ref-row-ok { + PRAGMA foreign_keys=ON; + CREATE TABLE t( + id INTEGER PRIMARY KEY, + pid INT REFERENCES t(id) DEFERRABLE INITIALLY DEFERRED + ); + INSERT INTO t VALUES(1,1); -- valid + BEGIN; + DELETE FROM t WHERE id=1; -- removes both child+parent (same row) + COMMIT; -- should succeed + SELECT count(*) FROM t; +} {0} + +do_execsql_test_on_specific_db {:memory:} fk-deferred-delete-parent-then-update-child-to-null-fix { + PRAGMA foreign_keys=ON; + CREATE TABLE p(id INTEGER PRIMARY KEY); + CREATE TABLE c( + id INTEGER PRIMARY KEY, + pid INT REFERENCES p(id) DEFERRABLE INITIALLY DEFERRED + ); + INSERT INTO p VALUES(5); + INSERT INTO c VALUES(1,5); + BEGIN; + DELETE FROM p WHERE id=5; -- break + UPDATE c SET pid=NULL WHERE id=1; -- fix (NULL never violates) + COMMIT; + SELECT id, pid FROM c; +} {1|} + +# AUTOCOMMIT: deferred FK still fails at end-of-statement +do_execsql_test_in_memory_any_error fk-deferred-autocommit-insert-missing-parent { + PRAGMA foreign_keys=ON; + CREATE TABLE parent(id INTEGER PRIMARY KEY); + CREATE TABLE child(id INTEGER PRIMARY KEY, pid INT REFERENCES parent(id) DEFERRABLE INITIALLY DEFERRED); + INSERT INTO child VALUES(1, 3); -- no BEGIN; should fail at statement end +} + +# AUTOCOMMIT: self-referential insert is OK (parent is same row) +do_execsql_test_on_specific_db {:memory:} fk-deferred-autocommit-selfref-ok { + PRAGMA foreign_keys=ON; + CREATE TABLE t(id INTEGER PRIMARY KEY, pid INT REFERENCES t(id) DEFERRABLE INITIALLY DEFERRED); + INSERT INTO t VALUES(1,1); + SELECT * FROM t; +} {1|1} + +# AUTOCOMMIT: deleting a parent that has a child → fails at statement end +do_execsql_test_in_memory_any_error fk-deferred-autocommit-delete-parent-fails { + PRAGMA foreign_keys=ON; + CREATE TABLE p(id INTEGER PRIMARY KEY); + CREATE TABLE c(id INTEGER PRIMARY KEY, pid INT REFERENCES p(id) DEFERRABLE INITIALLY DEFERRED); + INSERT INTO p VALUES(1); + INSERT INTO c VALUES(10,1); + DELETE FROM p WHERE id=1; -- no BEGIN; should fail at statement end +} + +# TX: delete a referenced parent then reinsert before COMMIT -> OK +do_execsql_test_on_specific_db {:memory:} fk-deferred-tx-delete-parent-then-reinsert-ok { + PRAGMA foreign_keys=ON; + CREATE TABLE p(id INTEGER PRIMARY KEY); + CREATE TABLE c(id INTEGER PRIMARY KEY, pid INT REFERENCES p(id) DEFERRABLE INITIALLY DEFERRED); + INSERT INTO p VALUES(5); + INSERT INTO c VALUES(1,5); + BEGIN; + DELETE FROM p WHERE id=5; -- violation (deferred) + INSERT INTO p VALUES(5); -- fix in same tx + COMMIT; + SELECT count(*) FROM p WHERE id=5; +} {1} + +# TX: multiple violating children, later insert parent, COMMIT -> OK +do_execsql_test_on_specific_db {:memory:} fk-deferred-tx-multi-children-fixed-ok { + PRAGMA foreign_keys=ON; + CREATE TABLE p(id INTEGER PRIMARY KEY); + CREATE TABLE c(id INTEGER PRIMARY KEY, pid INT REFERENCES p(id) DEFERRABLE INITIALLY DEFERRED); + BEGIN; + INSERT INTO c VALUES(1,99); + INSERT INTO c VALUES(2,99); + INSERT INTO p VALUES(99); + COMMIT; + SELECT id,pid FROM c ORDER BY id; +} {1|99 2|99} + +# one of several children left unfixed -> COMMIT fails +do_execsql_test_in_memory_any_error fk-deferred-tx-multi-children-one-left-fails { + PRAGMA foreign_keys=ON; + CREATE TABLE p(id INTEGER PRIMARY KEY); + CREATE TABLE c(id INTEGER PRIMARY KEY, pid INT REFERENCES p(id) DEFERRABLE INITIALLY DEFERRED); + BEGIN; + INSERT INTO c VALUES(1,42); + INSERT INTO c VALUES(2,42); + INSERT INTO p VALUES(42); + UPDATE c SET pid=777 WHERE id=2; -- reintroduce a bad reference + COMMIT; -- should fail +} + +# composite PK parent, fix via parent UPDATE before COMMIT -> OK +do_execsql_test_on_specific_db {:memory:} fk-deferred-composite-parent-update-fix { + PRAGMA foreign_keys=ON; + CREATE TABLE parent(a INT NOT NULL, b INT NOT NULL, PRIMARY KEY(a,b)); + CREATE TABLE child(id INT PRIMARY KEY, ca INT, cb INT, + FOREIGN KEY(ca,cb) REFERENCES parent(a,b) DEFERRABLE INITIALLY DEFERRED); + INSERT INTO parent VALUES(1,1); + BEGIN; + INSERT INTO child VALUES(10, 7, 7); -- violation + UPDATE parent SET a=7, b=7 WHERE a=1 AND b=1; -- fix composite PK + COMMIT; + SELECT id, ca, cb FROM child; +} {10|7|7} + +# TX: NULL in child FK -> never a violation +do_execsql_test_on_specific_db {:memory:} fk-deferred-null-fk-never-violates { + PRAGMA foreign_keys=ON; + CREATE TABLE p(id INTEGER PRIMARY KEY); + CREATE TABLE c(id INTEGER PRIMARY KEY, pid INT REFERENCES p(id) DEFERRABLE INITIALLY DEFERRED); + BEGIN; + INSERT INTO c VALUES(1, NULL); -- always OK + COMMIT; + SELECT id, pid FROM c; +} {1|} + +# TX: child UPDATE to NULL resolves before COMMIT +do_execsql_test_on_specific_db {:memory:} fk-deferred-update-child-null-resolves { + PRAGMA foreign_keys=ON; + CREATE TABLE p(id INTEGER PRIMARY KEY); + CREATE TABLE c(id INTEGER PRIMARY KEY, pid INT REFERENCES p(id) DEFERRABLE INITIALLY DEFERRED); + BEGIN; + INSERT INTO c VALUES(1, 500); -- violation + UPDATE c SET pid=NULL WHERE id=1; -- resolves + COMMIT; + SELECT * FROM c; +} {1|} + +# TX: delete violating child resolves before COMMIT +do_execsql_test_on_specific_db {:memory:} fk-deferred-delete-child-resolves { + PRAGMA foreign_keys=ON; + CREATE TABLE p(id INTEGER PRIMARY KEY); + CREATE TABLE c(id INTEGER PRIMARY KEY, pid INT REFERENCES p(id) DEFERRABLE INITIALLY DEFERRED); + BEGIN; + INSERT INTO c VALUES(1, 777); -- violation + DELETE FROM c WHERE id=1; -- resolves + COMMIT; + SELECT count(*) FROM c; +} {0} + +# TX: update parent PK to match child before COMMIT -> OK +do_execsql_test_on_specific_db {:memory:} fk-deferred-update-parent-pk-resolves { + PRAGMA foreign_keys=ON; + CREATE TABLE p(id INTEGER PRIMARY KEY); + CREATE TABLE c(id INTEGER PRIMARY KEY, pid INT REFERENCES p(id) DEFERRABLE INITIALLY DEFERRED); + INSERT INTO p VALUES(10); + BEGIN; + INSERT INTO c VALUES(1, 20); -- violation + UPDATE p SET id=20 WHERE id=10; -- resolve via parent + COMMIT; + SELECT * FROM c; +} {1|20} + +# Two-table cycle; both inserted before COMMIT -> OK +do_execsql_test_on_specific_db {:memory:} fk-deferred-cycle-two-tables-ok { + PRAGMA foreign_keys=ON; + CREATE TABLE a(id INT PRIMARY KEY, b_id INT, FOREIGN KEY(b_id) REFERENCES b(id) DEFERRABLE INITIALLY DEFERRED); + CREATE TABLE b(id INT PRIMARY KEY, a_id INT, FOREIGN KEY(a_id) REFERENCES a(id) DEFERRABLE INITIALLY DEFERRED); + BEGIN; + INSERT INTO a VALUES(1, 1); -- refers to b(1) (not yet present) + INSERT INTO b VALUES(1, 1); -- refers to a(1) + COMMIT; + SELECT count(b.id), count(a.id) FROM a, b; +} {1|1} + +# Delete a row that self-references (child==parent) within a tx -> OK +do_execsql_test_on_specific_db {:memory:} fk-deferred-selfref-delete-ok { + PRAGMA foreign_keys=ON; + CREATE TABLE t(id INTEGER PRIMARY KEY, pid INT REFERENCES t(id) DEFERRABLE INITIALLY DEFERRED); + INSERT INTO t VALUES(1,1); + BEGIN; + DELETE FROM t WHERE id=1; + COMMIT; + SELECT count(*) FROM t; +} {0} + + +do_execsql_test_on_specific_db {:memory:} fk-parentcomp-donothing-noconflict-ok { + PRAGMA foreign_keys=ON; + CREATE TABLE parent (id INTEGER PRIMARY KEY, a INT, b INT); + CREATE TABLE child_deferred ( + id INTEGER PRIMARY KEY, pid INT, x INT, + FOREIGN KEY(pid) REFERENCES parent(id) DEFERRABLE INITIALLY DEFERRED + ); + CREATE TABLE parent_comp (a INT NOT NULL, b INT NOT NULL, c INT, PRIMARY KEY(a,b)); + CREATE TABLE child_comp_deferred ( + id INTEGER PRIMARY KEY, ca INT, cb INT, z INT, + FOREIGN KEY (ca,cb) REFERENCES parent_comp(a,b) DEFERRABLE INITIALLY DEFERRED + ); + + -- No conflict on (a,b); should insert 1 row, no FK noise + INSERT INTO parent_comp VALUES (-1,-1,9) ON CONFLICT DO NOTHING; + SELECT a,b,c FROM parent_comp ORDER BY a,b; +} {-1|-1|9} + +do_execsql_test_on_specific_db {:memory:} fk-parentcomp-donothing-conflict-noop { + PRAGMA foreign_keys=ON; + CREATE TABLE parent_comp (a INT NOT NULL, b INT NOT NULL, c INT, PRIMARY KEY(a,b)); + CREATE TABLE child_comp_deferred ( + id INTEGER PRIMARY KEY, ca INT, cb INT, z INT, + FOREIGN KEY (ca,cb) REFERENCES parent_comp(a,b) DEFERRABLE INITIALLY DEFERRED + ); + + INSERT INTO parent_comp VALUES (10,20,1); + -- Conflicts with existing (10,20); must do nothing (no triggers, no FK scans that mutate counters) + INSERT INTO parent_comp VALUES (10,20,999) ON CONFLICT DO NOTHING; + SELECT a,b,c FROM parent_comp; +} {10|20|1} + +do_execsql_test_on_specific_db {:memory:} fk-parentcomp-donothing-unrelated-immediate-ok { + PRAGMA foreign_keys=ON; + CREATE TABLE parent (id INTEGER PRIMARY KEY); + CREATE TABLE child_immediate ( + id INTEGER PRIMARY KEY, pid INT, + FOREIGN KEY(pid) REFERENCES parent(id) -- IMMEDIATE + ); + CREATE TABLE parent_comp (a INT NOT NULL, b INT NOT NULL, c INT, PRIMARY KEY(a,b)); + CREATE TABLE child_comp_deferred ( + id INTEGER PRIMARY KEY, ca INT, cb INT, z INT, + FOREIGN KEY(ca,cb) REFERENCES parent_comp(a,b) DEFERRABLE INITIALLY DEFERRED + ); + + INSERT INTO parent_comp VALUES (-1,-1,9) ON CONFLICT DO NOTHING; + SELECT a,b,c FROM parent_comp; +} {-1|-1|9} + +do_execsql_test_on_specific_db {:memory:} fk-parentcomp-deferred-fix-inside-tx-ok { + PRAGMA foreign_keys=ON; + CREATE TABLE parent_comp (a INT NOT NULL, b INT NOT NULL, c INT, PRIMARY KEY(a,b)); + CREATE TABLE child_comp_deferred ( + id INTEGER PRIMARY KEY, ca INT, cb INT, + FOREIGN KEY(ca,cb) REFERENCES parent_comp(a,b) DEFERRABLE INITIALLY DEFERRED + ); + + BEGIN; + INSERT INTO child_comp_deferred VALUES (1, -5, -6); -- violation + INSERT INTO parent_comp VALUES (-5, -6, 9); -- fix via parent insert + COMMIT; + SELECT id,ca,cb FROM child_comp_deferred; +} {1|-5|-6} + +do_execsql_test_on_specific_db {:memory:} fk-parentcomp-autocommit-unrelated-children-ok { + PRAGMA foreign_keys=ON; + CREATE TABLE parent_comp (a INT NOT NULL, b INT NOT NULL, c INT, PRIMARY KEY(a,b)); + CREATE TABLE child_comp_deferred ( + id INTEGER PRIMARY KEY, ca INT, cb INT, + FOREIGN KEY(ca,cb) REFERENCES parent_comp(a,b) DEFERRABLE INITIALLY DEFERRED + ); + + INSERT INTO parent_comp VALUES (1,1,0); + INSERT INTO child_comp_deferred VALUES (10,1,1); -- valid + INSERT INTO parent_comp VALUES (2,2,0) ON CONFLICT DO NOTHING; -- unrelated insert; must not raise + SELECT a,b,c FROM parent_comp ORDER BY a,b; +} {1|1|0 +2|2|0} + +# ROLLBACK must clear any deferred state; next statement must not trip. +do_execsql_test_on_specific_db {:memory:} fk-rollback-clears-then-donothing-ok { + PRAGMA foreign_keys=ON; + CREATE TABLE p(id INTEGER PRIMARY KEY); + CREATE TABLE c(id INTEGER PRIMARY KEY, pid INT REFERENCES p(id) DEFERRABLE INITIALLY DEFERRED); + CREATE TABLE parent_comp(a INT NOT NULL, b INT NOT NULL, c INT, PRIMARY KEY(a,b)); + + BEGIN; + INSERT INTO c VALUES(1, 456); -- create deferred violation + ROLLBACK; -- must clear counters + + INSERT INTO parent_comp VALUES(-2,-2,0) ON CONFLICT DO NOTHING; + SELECT a,b,c FROM parent_comp; +} {-2|-2|0} + +# DO NOTHING conflict path must touch no FK maintenance at all. +do_execsql_test_on_specific_db {:memory:} fk-parentcomp-donothing-conflict-stays-quiet { + PRAGMA foreign_keys=ON; + CREATE TABLE parent_comp(a INT NOT NULL, b INT NOT NULL, c INT, PRIMARY KEY(a,b)); + CREATE TABLE child_comp_deferred( + id INTEGER PRIMARY KEY, ca INT, cb INT, + FOREIGN KEY(ca,cb) REFERENCES parent_comp(a,b) DEFERRABLE INITIALLY DEFERRED + ); + + INSERT INTO parent_comp VALUES(10,20,1); + -- This conflicts with (10,20) and must be a no-op; if counters move here, it’s a bug. + INSERT INTO parent_comp VALUES(10,20,999) ON CONFLICT DO NOTHING; + + -- Prove DB is sane afterwards (no stray FK error) + INSERT INTO parent_comp VALUES(11,22,3) ON CONFLICT DO NOTHING; + SELECT a,b FROM parent_comp ORDER BY a,b; +} {10|20 +11|22} + +# Two-statement fix inside an explicit transaction (separate statements). +#Insert child (violation), then insert parent in a new statement; commit must pass. +do_execsql_test_on_specific_db {:memory:} fk-deferred-two-stmt-fix-inside-tx-ok { + PRAGMA foreign_keys=ON; + CREATE TABLE p(id INTEGER PRIMARY KEY); + CREATE TABLE c(id INTEGER PRIMARY KEY, pid INT REFERENCES p(id) DEFERRABLE INITIALLY DEFERRED); + BEGIN; + INSERT INTO c VALUES(1, 777); -- violation recorded in tx + INSERT INTO p VALUES(777); -- next statement fixes it + COMMIT; + SELECT * FROM c; +} {1|777} + +do_execsql_test_on_specific_db {:memory:} fk-delete-composite-bounds { + PRAGMA foreign_keys=ON; + CREATE TABLE p(a INT NOT NULL, b INT NOT NULL, v INT, PRIMARY KEY(a,b)); + CREATE TABLE c(id INTEGER PRIMARY KEY, x INT, y INT, w INT, + FOREIGN KEY(x,y) REFERENCES p(a,b)); + + INSERT INTO p VALUES (5,1,0),(5,2,0),(5,4,0); + INSERT INTO c VALUES (1,5,4,0); -- child references (5,4) + + -- This should be a no-op (no row (5,3)), and MUST NOT error. + DELETE FROM p WHERE a=5 AND b=3; +} {} + +# Single column unique index on parent, FK referenced by child +do_execsql_test_in_memory_any_error fk-update-parent-unique-single-col { + PRAGMA foreign_keys=ON; + CREATE TABLE parent(id UNIQUE); + CREATE TABLE child(pid REFERENCES parent(id)); + INSERT INTO parent VALUES(1); + INSERT INTO child VALUES(1); + UPDATE parent SET id = 2 WHERE id = 1; +} + +# Single column with explicit CREATE UNIQUE INDEX +do_execsql_test_in_memory_any_error fk-update-parent-explicit-unique-single-col { + PRAGMA foreign_keys=ON; + CREATE TABLE parent(id); + CREATE UNIQUE INDEX parent_id_idx ON parent(id); + CREATE TABLE child(pid REFERENCES parent(id)); + INSERT INTO parent VALUES(1); + INSERT INTO child VALUES(1); + UPDATE parent SET id = 2 WHERE id = 1; +} + +# Multi-column unique index on parent, FK referenced by multi-column FK in child +do_execsql_test_in_memory_any_error fk-update-parent-unique-multi-col { + PRAGMA foreign_keys=ON; + CREATE TABLE parent(a, b, UNIQUE(a, b)); + CREATE TABLE child(ca, cb, FOREIGN KEY(ca, cb) REFERENCES parent(a, b)); + INSERT INTO parent VALUES(1, 2); + INSERT INTO child VALUES(1, 2); + UPDATE parent SET a = 3 WHERE a = 1 AND b = 2; +} + +# Multi-column unique index on parent, FK referenced by multi-column FK in child +do_execsql_test_in_memory_any_error fk-update-parent-unique-multi-col-2 { + PRAGMA foreign_keys=ON; + CREATE TABLE parent(a, b, UNIQUE(a, b)); + CREATE TABLE child(ca, cb, FOREIGN KEY(ca, cb) REFERENCES parent(a, b)); + INSERT INTO parent VALUES(1, 2); + INSERT INTO child VALUES(1, 2); + UPDATE parent SET b = 3 WHERE a = 1 AND b = 2; +} + +# Multi-column index defined explicitly as CREATE UNIQUE INDEX +do_execsql_test_in_memory_any_error fk-update-parent-explicit-unique-multi-col { + PRAGMA foreign_keys=ON; + CREATE TABLE parent(a, b); + CREATE UNIQUE INDEX parent_ab_idx ON parent(a, b); + CREATE TABLE child(ca, cb, FOREIGN KEY(ca, cb) REFERENCES parent(a, b)); + INSERT INTO parent VALUES(1, 2); + INSERT INTO child VALUES(1, 2); + UPDATE parent SET a = 3 WHERE a = 1 AND b = 2; +} + +# Multi-column index defined explicitly as CREATE UNIQUE INDEX +do_execsql_test_in_memory_any_error fk-update-parent-explicit-unique-multi-col-2 { + PRAGMA foreign_keys=ON; + CREATE TABLE parent(a, b); + CREATE UNIQUE INDEX parent_ab_idx ON parent(a, b); + CREATE TABLE child(ca, cb, FOREIGN KEY(ca, cb) REFERENCES parent(a, b)); + INSERT INTO parent VALUES(1, 2); + INSERT INTO child VALUES(1, 2); + UPDATE parent SET b = 3 WHERE a = 1 AND b = 2; +} + +# Single column INTEGER PRIMARY KEY +do_execsql_test_in_memory_any_error fk-update-parent-int-pk { + PRAGMA foreign_keys=ON; + CREATE TABLE parent(id INTEGER PRIMARY KEY); + CREATE TABLE child(pid REFERENCES parent(id)); + INSERT INTO parent VALUES(1); + INSERT INTO child VALUES(1); + UPDATE parent SET id = 2 WHERE id = 1; +} + +# Single column TEXT PRIMARY KEY +do_execsql_test_in_memory_any_error fk-update-parent-text-pk { + PRAGMA foreign_keys=ON; + CREATE TABLE parent(id PRIMARY KEY); + CREATE TABLE child(pid REFERENCES parent(id)); + INSERT INTO parent VALUES('key1'); + INSERT INTO child VALUES('key1'); + UPDATE parent SET id = 'key2' WHERE id = 'key1'; +} + +# Multi-column PRIMARY KEY +do_execsql_test_in_memory_any_error fk-update-parent-multi-col-pk { + PRAGMA foreign_keys=ON; + CREATE TABLE parent(a, b, PRIMARY KEY(a, b)); + CREATE TABLE child(ca, cb, FOREIGN KEY(ca, cb) REFERENCES parent(a, b)); + INSERT INTO parent VALUES(1, 2); + INSERT INTO child VALUES(1, 2); + UPDATE parent SET a = 3 WHERE a = 1 AND b = 2; +} + +# Multi-column PRIMARY KEY +do_execsql_test_in_memory_any_error fk-update-parent-multi-col-pk-2 { + PRAGMA foreign_keys=ON; + CREATE TABLE parent(a, b, PRIMARY KEY(a, b)); + CREATE TABLE child(ca, cb, FOREIGN KEY(ca, cb) REFERENCES parent(a, b)); + INSERT INTO parent VALUES(1, 2); + INSERT INTO child VALUES(1, 2); + UPDATE parent SET b = 3 WHERE a = 1 AND b = 2; +} diff --git a/testing/insert.test b/testing/insert.test index 4419bcafb..75b8630b3 100755 --- a/testing/insert.test +++ b/testing/insert.test @@ -529,6 +529,13 @@ do_execsql_test_on_specific_db {:memory:} null-value-insert-null-type-column { SELECT * FROM test; } {1|} +# https://github.com/tursodatabase/turso/issues/1710 +do_execsql_test_in_memory_error_content uniq_constraint { + CREATE TABLE test (id INTEGER unique); + insert into test values (1); + insert into test values (1); +} {UNIQUE constraint failed: test.id (19)} + do_execsql_test_in_memory_error_content insert-explicit-rowid-conflict { create table t (x); insert into t(rowid, x) values (1, 1); @@ -678,4 +685,93 @@ do_execsql_test_on_specific_db {:memory:} insert-rowid-select-rowid-success { INSERT INTO t(a) SELECT rowid FROM t; SELECT * FROM t; } {2 -1} \ No newline at end of file +1} + + +# Due to a bug in SQLite, this check is needed to maintain backwards compatibility with rowid alias +# SQLite docs: https://sqlite.org/lang_createtable.html#rowids_and_the_integer_primary_key +# Issue: https://github.com/tursodatabase/turso/issues/3665 +do_execsql_test_on_specific_db {:memory:} insert-rowid-backwards-compability { + CREATE TABLE t(a INTEGER PRIMARY KEY DESC); + INSERT INTO t(a) VALUES (123); + SELECT rowid, * FROM t; +} {1|123} + +do_execsql_test_on_specific_db {:memory:} insert-rowid-backwards-compability-2 { + CREATE TABLE t(a INTEGER, PRIMARY KEY (a DESC)); + INSERT INTO t(a) VALUES (123); + SELECT rowid, * FROM t; +} {123|123} + + +do_execsql_test_on_specific_db {:memory:} ignore-pk-conflict { + CREATE TABLE t(a INTEGER PRIMARY KEY); + INSERT INTO t VALUES (1),(2),(3); + INSERT OR IGNORE INTO t VALUES (2); + SELECT a FROM t ORDER BY a; +} {1 +2 +3} + +do_execsql_test_on_specific_db {:memory:} ignore-unique-conflict { + CREATE TABLE t(a INTEGER, b TEXT UNIQUE); + INSERT INTO t VALUES (1,'x'),(2,'y'); + INSERT OR IGNORE INTO t VALUES (3,'y'); + SELECT a,b FROM t ORDER BY a; +} {1|x +2|y} + +do_execsql_test_on_specific_db {:memory:} ignore-multi-unique-conflict { + CREATE TABLE t(a UNIQUE, b UNIQUE, c); + INSERT INTO t VALUES (1,10,100),(2,20,200); + INSERT OR IGNORE INTO t VALUES (1,30,300); -- conflicts on a + INSERT OR IGNORE INTO t VALUES (3,20,300); -- conflicts on b + INSERT OR IGNORE INTO t VALUES (1,20,300); -- conflicts on both + SELECT a,b,c FROM t ORDER BY a; +} {1|10|100 +2|20|200} + +do_execsql_test_on_specific_db {:memory:} ignore-some-conflicts-multirow { + CREATE TABLE t(a INTEGER UNIQUE); + INSERT INTO t VALUES (2),(4); + INSERT OR IGNORE INTO t VALUES (1),(2),(3),(4),(5); + SELECT a FROM t ORDER BY a; +} {1 +2 +3 +4 +5} + +do_execsql_test_on_specific_db {:memory:} ignore-from-select { + CREATE TABLE src(x); + INSERT INTO src VALUES (1),(2),(2),(3); + CREATE TABLE dst(a INTEGER UNIQUE); + INSERT INTO dst VALUES (2); + INSERT OR IGNORE INTO dst SELECT x FROM src; + SELECT a FROM dst ORDER BY a; +} {1 +2 +3} + +do_execsql_test_on_specific_db {:memory:} ignore-null-in-unique { + CREATE TABLE t(a INTEGER UNIQUE); + INSERT INTO t VALUES (1),(NULL),(NULL); + INSERT OR IGNORE INTO t VALUES (1),(NULL); + SELECT COUNT(*) FROM t WHERE a IS NULL; +} {3} + +do_execsql_test_on_specific_db {:memory:} ignore-preserves-rowid { + CREATE TABLE t(data TEXT UNIQUE); + INSERT INTO t VALUES ('x'),('y'),('z'); + SELECT rowid, data FROM t WHERE data='y'; + INSERT OR IGNORE INTO t VALUES ('y'); + SELECT rowid, data FROM t WHERE data='y'; +} {2|y +2|y} + +do_execsql_test_on_specific_db {:memory:} ignore-intra-statement-dups { + CREATE TABLE t(a INTEGER PRIMARY KEY, b TEXT); + INSERT OR IGNORE INTO t VALUES (5,'first'),(6,'x'),(5,'second'),(5,'third'); + SELECT a,b FROM t ORDER BY a; +} {5|first +6|x} diff --git a/testing/join.test b/testing/join.test index c14b96ab4..6c6aa7314 100755 --- a/testing/join.test +++ b/testing/join.test @@ -384,3 +384,10 @@ do_execsql_test_on_specific_db {:memory:} left-join-using-null { select a, b from t left join s using (a, b); } {1| 2|} + +# Regression test for: https://github.com/tursodatabase/turso/issues/3656 +do_execsql_test_on_specific_db {:memory:} redundant-join-condition { + create table t(x); + insert into t values ('lol'); + select t1.x from t t1 join t t2 on t1.x=t2.x where t1.x=t2.x; +} {lol} \ No newline at end of file diff --git a/testing/materialized_views.test b/testing/materialized_views.test index dd2652d7d..2827f30d4 100755 --- a/testing/materialized_views.test +++ b/testing/materialized_views.test @@ -3,6 +3,35 @@ set testdir [file dirname $argv0] source $testdir/tester.tcl +# Test that INSERT with reused primary keys maintains correct aggregate counts +# When a row is deleted and a new row is inserted with the same primary key +# but different group value, all groups should maintain correct counts +do_execsql_test_on_specific_db {:memory:} matview-insert-reused-key-maintains-all-groups { + CREATE TABLE t(id INTEGER PRIMARY KEY, val TEXT); + INSERT INTO t VALUES (1, 'A'), (2, 'B'); + + CREATE MATERIALIZED VIEW v AS + SELECT val, COUNT(*) as cnt + FROM t + GROUP BY val; + + -- Initial state: A=1, B=1 + SELECT * FROM v ORDER BY val; + + -- Delete id=1 (which has 'A') + DELETE FROM t WHERE id = 1; + SELECT * FROM v ORDER BY val; + + -- Insert id=1 with different value 'C' + -- This should NOT affect group 'B' + INSERT INTO t VALUES (1, 'C'); + SELECT * FROM v ORDER BY val; +} {A|1 +B|1 +B|1 +B|1 +C|1} + do_execsql_test_on_specific_db {:memory:} matview-basic-filter-population { CREATE TABLE products(id INTEGER, name TEXT, price INTEGER, category TEXT); INSERT INTO products VALUES @@ -37,9 +66,9 @@ do_execsql_test_on_specific_db {:memory:} matview-aggregation-population { GROUP BY day; SELECT * FROM daily_totals ORDER BY day; -} {1|7|2 -2|2|2 -3|4|2} +} {1|7.0|2 +2|2.0|2 +3|4.0|2} do_execsql_test_on_specific_db {:memory:} matview-filter-with-groupby { CREATE TABLE t(a INTEGER, b INTEGER); @@ -52,9 +81,9 @@ do_execsql_test_on_specific_db {:memory:} matview-filter-with-groupby { GROUP BY b; SELECT * FROM v ORDER BY yourb; -} {3|3|1 -6|6|1 -7|7|1} +} {3|3.0|1 +6|6.0|1 +7|7.0|1} do_execsql_test_on_specific_db {:memory:} matview-insert-maintenance { CREATE TABLE t(a INTEGER, b INTEGER); @@ -72,12 +101,12 @@ do_execsql_test_on_specific_db {:memory:} matview-insert-maintenance { INSERT INTO t VALUES (1,1), (2,2); SELECT * FROM v ORDER BY b; -} {3|3|1 -6|6|1 -3|7|2 -6|11|2 -3|7|2 -6|11|2} +} {3|3.0|1 +6|6.0|1 +3|7.0|2 +6|11.0|2 +3|7.0|2 +6|11.0|2} do_execsql_test_on_specific_db {:memory:} matview-delete-maintenance { CREATE TABLE items(id INTEGER, category TEXT, amount INTEGER); @@ -100,11 +129,11 @@ do_execsql_test_on_specific_db {:memory:} matview-delete-maintenance { DELETE FROM items WHERE category = 'B'; SELECT * FROM category_sums ORDER BY category; -} {A|90|3 -B|60|2 -A|60|2 -B|60|2 -A|60|2} +} {A|90.0|3 +B|60.0|2 +A|60.0|2 +B|60.0|2 +A|60.0|2} do_execsql_test_on_specific_db {:memory:} matview-update-maintenance { CREATE TABLE records(id INTEGER, value INTEGER, status INTEGER); @@ -126,12 +155,12 @@ do_execsql_test_on_specific_db {:memory:} matview-update-maintenance { UPDATE records SET status = 2 WHERE id = 3; SELECT * FROM status_totals ORDER BY status; -} {1|400|2 -2|600|2 -1|450|2 -2|600|2 -1|150|1 -2|900|3} +} {1|400.0|2 +2|600.0|2 +1|450.0|2 +2|600.0|2 +1|150.0|1 +2|900.0|3} do_execsql_test_on_specific_db {:memory:} matview-integer-primary-key-basic { CREATE TABLE t(a INTEGER PRIMARY KEY, b INTEGER); @@ -214,12 +243,12 @@ do_execsql_test_on_specific_db {:memory:} matview-integer-primary-key-with-aggre DELETE FROM t WHERE a = 3; SELECT * FROM v ORDER BY b; -} {10|500|1 -20|700|2 -10|600|2 -20|700|2 -10|600|2 -20|400|1} +} {10|500.0|1 +20|700.0|2 +10|600.0|2 +20|700.0|2 +10|600.0|2 +20|400.0|1} do_execsql_test_on_specific_db {:memory:} matview-complex-filter-aggregation { CREATE TABLE transactions( @@ -253,17 +282,17 @@ do_execsql_test_on_specific_db {:memory:} matview-complex-filter-aggregation { DELETE FROM transactions WHERE id = 3; SELECT * FROM account_deposits ORDER BY account; -} {100|70|2 -200|100|1 -300|60|1 -100|95|3 -200|100|1 -300|60|1 -100|125|3 -200|100|1 -300|60|1 -100|125|3 -300|60|1} +} {100|70.0|2 +200|100.0|1 +300|60.0|1 +100|95.0|3 +200|100.0|1 +300|60.0|1 +100|125.0|3 +200|100.0|1 +300|60.0|1 +100|125.0|3 +300|60.0|1} do_execsql_test_on_specific_db {:memory:} matview-sum-count-only { CREATE TABLE data(id INTEGER, value INTEGER, category INTEGER); @@ -288,12 +317,12 @@ do_execsql_test_on_specific_db {:memory:} matview-sum-count-only { UPDATE data SET value = 35 WHERE id = 3; SELECT * FROM category_stats ORDER BY category; -} {1|80|3 -2|70|2 -1|85|4 -2|70|2 -1|85|4 -2|75|2} +} {1|80.0|3 +2|70.0|2 +1|85.0|4 +2|70.0|2 +1|85.0|4 +2|75.0|2} do_execsql_test_on_specific_db {:memory:} matview-empty-table-population { CREATE TABLE t(a INTEGER, b INTEGER); @@ -308,8 +337,8 @@ do_execsql_test_on_specific_db {:memory:} matview-empty-table-population { INSERT INTO t VALUES (1, 3), (2, 7), (3, 9); SELECT * FROM v ORDER BY b; } {0 -7|2|1 -9|3|1} +7|2.0|1 +9|3.0|1} do_execsql_test_on_specific_db {:memory:} matview-all-rows-filtered { CREATE TABLE t(a INTEGER, b INTEGER); @@ -357,17 +386,17 @@ do_execsql_test_on_specific_db {:memory:} matview-mixed-operations-sequence { INSERT INTO orders VALUES (4, 300, 150); SELECT * FROM customer_totals ORDER BY customer_id; -} {100|50|1 -200|75|1 -100|75|2 -200|75|1 -100|75|2 -200|100|1 -100|25|1 -200|100|1 -100|25|1 -200|100|1 -300|150|1} +} {100|50.0|1 +200|75.0|1 +100|75.0|2 +200|75.0|1 +100|75.0|2 +200|100.0|1 +100|25.0|1 +200|100.0|1 +100|25.0|1 +200|100.0|1 +300|150.0|1} do_execsql_test_on_specific_db {:memory:} matview-projections { CREATE TABLE t(a,b); @@ -473,13 +502,13 @@ do_execsql_test_on_specific_db {:memory:} matview-rollback-aggregation { ROLLBACK; SELECT * FROM product_totals ORDER BY product_id; -} {1|300|2 -2|400|2 -1|350|3 -2|400|2 -3|300|1 -1|300|2 -2|400|2} +} {1|300.0|2 +2|400.0|2 +1|350.0|3 +2|400.0|2 +3|300.0|1 +1|300.0|2 +2|400.0|2} do_execsql_test_on_specific_db {:memory:} matview-rollback-mixed-operations { CREATE TABLE orders(id INTEGER PRIMARY KEY, customer INTEGER, amount INTEGER); @@ -500,12 +529,12 @@ do_execsql_test_on_specific_db {:memory:} matview-rollback-mixed-operations { ROLLBACK; SELECT * FROM customer_totals ORDER BY customer; -} {100|75|2 -200|75|1 -100|150|2 -200|150|1 -100|75|2 -200|75|1} +} {100|75.0|2 +200|75.0|1 +100|150.0|2 +200|150.0|1 +100|75.0|2 +200|75.0|1} do_execsql_test_on_specific_db {:memory:} matview-rollback-filtered-aggregation { CREATE TABLE transactions(id INTEGER, account INTEGER, amount INTEGER, type TEXT); @@ -531,11 +560,11 @@ do_execsql_test_on_specific_db {:memory:} matview-rollback-filtered-aggregation ROLLBACK; SELECT * FROM deposits ORDER BY account; -} {100|50|1 -200|100|1 -100|135|2 -100|50|1 -200|100|1} +} {100|50.0|1 +200|100.0|1 +100|135.0|2 +100|50.0|1 +200|100.0|1} do_execsql_test_on_specific_db {:memory:} matview-rollback-empty-view { CREATE TABLE t(a INTEGER, b INTEGER); @@ -590,8 +619,8 @@ do_execsql_test_on_specific_db {:memory:} matview-join-with-aggregation { GROUP BY u.name; SELECT * FROM user_totals ORDER BY name; -} {Alice|250 -Bob|250} +} {Alice|250.0 +Bob|250.0} do_execsql_test_on_specific_db {:memory:} matview-three-way-join { CREATE TABLE customers(id INTEGER PRIMARY KEY, name TEXT, city TEXT); @@ -632,9 +661,9 @@ do_execsql_test_on_specific_db {:memory:} matview-three-way-join-with-aggregatio GROUP BY c.name, p.name; SELECT * FROM sales_totals ORDER BY customer_name, product_name; -} {Alice|Gadget|3|60 -Alice|Widget|9|90 -Bob|Widget|2|20} +} {Alice|Gadget|3.0|60.0 +Alice|Widget|9.0|90.0 +Bob|Widget|2.0|20.0} do_execsql_test_on_specific_db {:memory:} matview-join-incremental-insert { CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT); @@ -835,9 +864,9 @@ do_execsql_test_on_specific_db {:memory:} matview-aggregation-before-join { GROUP BY c.id, c.name, c.tier; SELECT * FROM customer_order_summary ORDER BY total_quantity DESC; -} {Bob|Silver|2|7 -Alice|Gold|3|6 -Charlie|Bronze|1|1} +} {Bob|Silver|2|7.0 +Alice|Gold|3|6.0 +Charlie|Bronze|1|1.0} # Test 4: Join with aggregation AFTER the join do_execsql_test_on_specific_db {:memory:} matview-aggregation-after-join { @@ -865,8 +894,8 @@ do_execsql_test_on_specific_db {:memory:} matview-aggregation-after-join { GROUP BY st.region; SELECT * FROM regional_sales ORDER BY total_revenue DESC; -} {North|38|3150 -South|18|1500} +} {North|38.0|3150.0 +South|18.0|1500.0} # Test 5: Modifying both tables in same transaction do_execsql_test_on_specific_db {:memory:} matview-join-both-tables-modified { @@ -1194,8 +1223,8 @@ do_execsql_test_on_specific_db {:memory:} matview-union-with-aggregation { FROM q2_sales; SELECT * FROM half_year_summary ORDER BY quarter; -} {Q1|68|16450 -Q2|105|21750} +} {Q1|68.0|16450.0 +Q2|105.0|21750.0} do_execsql_test_on_specific_db {:memory:} matview-union-with-join { CREATE TABLE customers(id INTEGER PRIMARY KEY, name TEXT, type TEXT); @@ -1625,8 +1654,8 @@ do_execsql_test_on_specific_db {:memory:} matview-groupby-scalar-function { GROUP BY substr(orderdate, 1, 4); SELECT * FROM yearly_totals ORDER BY 1; -} {2020|250 -2021|200} +} {2020|250.0 +2021|200.0} do_execsql_test_on_specific_db {:memory:} matview-groupby-alias { CREATE TABLE orders(id INTEGER, orderdate TEXT, amount INTEGER); @@ -1640,8 +1669,8 @@ do_execsql_test_on_specific_db {:memory:} matview-groupby-alias { GROUP BY year; SELECT * FROM yearly_totals ORDER BY year; -} {2020|250 -2021|200} +} {2020|250.0 +2021|200.0} do_execsql_test_on_specific_db {:memory:} matview-groupby-position { CREATE TABLE orders(id INTEGER, orderdate TEXT, amount INTEGER, nation TEXT); @@ -1655,8 +1684,8 @@ do_execsql_test_on_specific_db {:memory:} matview-groupby-position { GROUP BY 1, 2; SELECT * FROM national_yearly ORDER BY nation, year; -} {UK|2021|200 -USA|2020|250} +} {UK|2021|200.0 +USA|2020|250.0} do_execsql_test_on_specific_db {:memory:} matview-groupby-scalar-incremental { CREATE TABLE orders(id INTEGER, orderdate TEXT, amount INTEGER); @@ -1672,10 +1701,10 @@ do_execsql_test_on_specific_db {:memory:} matview-groupby-scalar-incremental { SELECT * FROM yearly_totals; INSERT INTO orders VALUES (3, '2021-03-20', 200); SELECT * FROM yearly_totals ORDER BY year; -} {2020|100 -2020|250 -2020|250 -2021|200} +} {2020|100.0 +2020|250.0 +2020|250.0 +2021|200.0} do_execsql_test_on_specific_db {:memory:} matview-groupby-join-position { CREATE TABLE t(a INTEGER); @@ -1691,3 +1720,777 @@ do_execsql_test_on_specific_db {:memory:} matview-groupby-join-position { SELECT * FROM tujoingroup ORDER BY a; } {1|2 2|1} + +do_execsql_test_on_specific_db {:memory:} matview-distinct-basic { + CREATE TABLE items(id INTEGER, category TEXT, value INTEGER); + INSERT INTO items VALUES + (1, 'A', 100), + (2, 'B', 200), + (3, 'A', 100), -- duplicate of row 1 + (4, 'C', 300), + (5, 'B', 200), -- duplicate of row 2 + (6, 'A', 100); -- another duplicate of row 1 + + CREATE MATERIALIZED VIEW distinct_items AS + SELECT DISTINCT category, value FROM items; + + SELECT category, value FROM distinct_items ORDER BY category, value; +} {A|100 +B|200 +C|300} + +do_execsql_test_on_specific_db {:memory:} matview-distinct-single-column { + CREATE TABLE numbers(n INTEGER); + INSERT INTO numbers VALUES (1), (2), (1), (3), (2), (1), (4), (3); + + CREATE MATERIALIZED VIEW distinct_numbers AS + SELECT DISTINCT n FROM numbers; + + SELECT n FROM distinct_numbers ORDER BY n; +} {1 +2 +3 +4} + +do_execsql_test_on_specific_db {:memory:} matview-distinct-incremental-insert { + CREATE TABLE data(x INTEGER, y TEXT); + CREATE MATERIALIZED VIEW distinct_data AS + SELECT DISTINCT x, y FROM data; + + -- Initial data + INSERT INTO data VALUES (1, 'alpha'), (2, 'beta'), (1, 'alpha'); + SELECT x, y FROM distinct_data ORDER BY x, y; +} {1|alpha +2|beta} + +do_execsql_test_on_specific_db {:memory:} matview-distinct-incremental-insert-new { + CREATE TABLE data(x INTEGER, y TEXT); + CREATE MATERIALIZED VIEW distinct_data AS + SELECT DISTINCT x, y FROM data; + + -- Initial data + INSERT INTO data VALUES (1, 'alpha'), (2, 'beta'), (1, 'alpha'); + + -- Add new distinct values + INSERT INTO data VALUES (3, 'gamma'), (4, 'delta'); + SELECT x, y FROM distinct_data ORDER BY x, y; +} {1|alpha +2|beta +3|gamma +4|delta} + +do_execsql_test_on_specific_db {:memory:} matview-distinct-incremental-insert-dups { + CREATE TABLE data(x INTEGER, y TEXT); + CREATE MATERIALIZED VIEW distinct_data AS + SELECT DISTINCT x, y FROM data; + + -- Initial data with some new values + INSERT INTO data VALUES + (1, 'alpha'), (2, 'beta'), (1, 'alpha'), + (3, 'gamma'), (4, 'delta'); + + -- Add duplicates of existing values + INSERT INTO data VALUES (1, 'alpha'), (2, 'beta'), (3, 'gamma'); + -- Should be same as before the duplicate insert + SELECT x, y FROM distinct_data ORDER BY x, y; +} {1|alpha +2|beta +3|gamma +4|delta} + +do_execsql_test_on_specific_db {:memory:} matview-distinct-incremental-delete { + CREATE TABLE records(id INTEGER PRIMARY KEY, category TEXT, status INTEGER); + INSERT INTO records VALUES + (1, 'X', 1), + (2, 'Y', 2), + (3, 'X', 1), -- duplicate values + (4, 'Z', 3), + (5, 'Y', 2); -- duplicate values + + CREATE MATERIALIZED VIEW distinct_records AS + SELECT DISTINCT category, status FROM records; + + SELECT category, status FROM distinct_records ORDER BY category, status; +} {X|1 +Y|2 +Z|3} + +do_execsql_test_on_specific_db {:memory:} matview-distinct-incremental-delete-partial { + CREATE TABLE records(id INTEGER PRIMARY KEY, category TEXT, status INTEGER); + INSERT INTO records VALUES + (1, 'X', 1), + (2, 'Y', 2), + (3, 'X', 1), -- duplicate values + (4, 'Z', 3), + (5, 'Y', 2); -- duplicate values + + CREATE MATERIALIZED VIEW distinct_records AS + SELECT DISTINCT category, status FROM records; + + -- Delete one instance of duplicate + DELETE FROM records WHERE id = 3; + -- X|1 should still exist (one instance remains) + SELECT category, status FROM distinct_records ORDER BY category, status; +} {X|1 +Y|2 +Z|3} + +do_execsql_test_on_specific_db {:memory:} matview-distinct-incremental-delete-all { + CREATE TABLE records(id INTEGER PRIMARY KEY, category TEXT, status INTEGER); + INSERT INTO records VALUES + (1, 'X', 1), + (2, 'Y', 2), + (4, 'Z', 3), + (5, 'Y', 2); -- duplicate values + + CREATE MATERIALIZED VIEW distinct_records AS + SELECT DISTINCT category, status FROM records; + + -- Delete all instances of X|1 + DELETE FROM records WHERE category = 'X' AND status = 1; + -- Now X|1 should be gone + SELECT category, status FROM distinct_records ORDER BY category, status; +} {Y|2 +Z|3} + +do_execsql_test_on_specific_db {:memory:} matview-distinct-incremental-reappear { + CREATE TABLE records(id INTEGER PRIMARY KEY, category TEXT, status INTEGER); + INSERT INTO records VALUES + (2, 'Y', 2), + (4, 'Z', 3), + (5, 'Y', 2); -- duplicate values + + CREATE MATERIALIZED VIEW distinct_records AS + SELECT DISTINCT category, status FROM records; + + -- Re-add a previously deleted value + INSERT INTO records VALUES (6, 'X', 1); + -- X|1 should appear + SELECT category, status FROM distinct_records ORDER BY category, status; +} {X|1 +Y|2 +Z|3} + +do_execsql_test_on_specific_db {:memory:} matview-distinct-null-handling { + CREATE TABLE nullable(a INTEGER, b TEXT); + INSERT INTO nullable VALUES + (1, 'one'), + (2, NULL), + (NULL, 'three'), + (1, 'one'), -- duplicate + (2, NULL), -- duplicate with NULL + (NULL, 'three'), -- duplicate with NULL + (NULL, NULL); + + CREATE MATERIALIZED VIEW distinct_nullable AS + SELECT DISTINCT a, b FROM nullable; + + -- NULLs should be handled as distinct values + SELECT a, b FROM distinct_nullable ORDER BY a, b; +} {| +|three +1|one +2|} + +do_execsql_test_on_specific_db {:memory:} matview-distinct-empty-table { + CREATE TABLE empty_source(x INTEGER, y TEXT); + CREATE MATERIALIZED VIEW distinct_empty AS + SELECT DISTINCT x, y FROM empty_source; + + -- Should be empty + SELECT COUNT(*) FROM distinct_empty; +} {0} + +do_execsql_test_on_specific_db {:memory:} matview-distinct-empty-then-insert { + CREATE TABLE empty_source(x INTEGER, y TEXT); + CREATE MATERIALIZED VIEW distinct_empty AS + SELECT DISTINCT x, y FROM empty_source; + + -- Insert into previously empty table + INSERT INTO empty_source VALUES (1, 'first'), (1, 'first'), (2, 'second'); + SELECT x, y FROM distinct_empty ORDER BY x, y; +} {1|first +2|second} + +do_execsql_test_on_specific_db {:memory:} matview-distinct-multi-column-types { + CREATE TABLE mixed_types(i INTEGER, t TEXT, r REAL, b BLOB); + INSERT INTO mixed_types VALUES + (1, 'text1', 1.5, x'0102'), + (2, 'text2', 2.5, x'0304'), + (1, 'text1', 1.5, x'0102'), -- exact duplicate + (3, 'text3', 3.5, x'0506'), + (2, 'text2', 2.5, x'0304'); -- another duplicate + + CREATE MATERIALIZED VIEW distinct_mixed AS + SELECT DISTINCT i, t FROM mixed_types; + + SELECT i, t FROM distinct_mixed ORDER BY i, t; +} {1|text1 +2|text2 +3|text3} + +do_execsql_test_on_specific_db {:memory:} matview-distinct-update-simulation { + CREATE TABLE updatable(id INTEGER PRIMARY KEY, val TEXT); + INSERT INTO updatable VALUES (1, 'old'), (2, 'old'), (3, 'new'); + + CREATE MATERIALIZED VIEW distinct_vals AS + SELECT DISTINCT val FROM updatable; + + SELECT val FROM distinct_vals ORDER BY val; +} {new +old} + +do_execsql_test_on_specific_db {:memory:} matview-distinct-update-simulation-change { + CREATE TABLE updatable(id INTEGER PRIMARY KEY, val TEXT); + INSERT INTO updatable VALUES (1, 'old'), (2, 'old'), (3, 'new'); + + CREATE MATERIALIZED VIEW distinct_vals AS + SELECT DISTINCT val FROM updatable; + + -- Simulate update by delete + insert + DELETE FROM updatable WHERE id = 1; + INSERT INTO updatable VALUES (1, 'new'); + + -- Now we have two 'new' and one 'old', but distinct shows each once + SELECT val FROM distinct_vals ORDER BY val; +} {new +old} + +do_execsql_test_on_specific_db {:memory:} matview-distinct-update-simulation-all-same { + CREATE TABLE updatable(id INTEGER PRIMARY KEY, val TEXT); + INSERT INTO updatable VALUES (1, 'new'), (2, 'old'), (3, 'new'); + + CREATE MATERIALIZED VIEW distinct_vals AS + SELECT DISTINCT val FROM updatable; + + -- Change the 'old' to 'new' + DELETE FROM updatable WHERE id = 2; + INSERT INTO updatable VALUES (2, 'new'); + + -- Now all three rows have 'new', old should disappear + SELECT val FROM distinct_vals ORDER BY val; +} {new} + +do_execsql_test_on_specific_db {:memory:} matview-distinct-large-duplicates { + CREATE TABLE many_dups(x INTEGER); + + -- Insert many duplicates + INSERT INTO many_dups VALUES (1), (1), (1), (1), (1); + INSERT INTO many_dups VALUES (2), (2), (2), (2), (2); + INSERT INTO many_dups VALUES (3), (3), (3), (3), (3); + + CREATE MATERIALIZED VIEW distinct_many AS + SELECT DISTINCT x FROM many_dups; + + SELECT x FROM distinct_many ORDER BY x; +} {1 +2 +3} + +do_execsql_test_on_specific_db {:memory:} matview-distinct-large-duplicates-remove { + CREATE TABLE many_dups(x INTEGER); + + -- Insert many duplicates + INSERT INTO many_dups VALUES (1), (1), (1), (1), (1); + INSERT INTO many_dups VALUES (2), (2), (2), (2), (2); + INSERT INTO many_dups VALUES (3), (3), (3), (3), (3); + + CREATE MATERIALIZED VIEW distinct_many AS + SELECT DISTINCT x FROM many_dups; + + -- Remove some instances of value 2 (rowids 7,8,9,10 keeping rowid 6) + DELETE FROM many_dups WHERE rowid IN (7, 8, 9, 10); + + -- Should still have all three values + SELECT x FROM distinct_many ORDER BY x; +} {1 +2 +3} + +do_execsql_test_on_specific_db {:memory:} matview-distinct-large-duplicates-remove-all { + CREATE TABLE many_dups(x INTEGER); + + -- Insert many duplicates but only one instance of 2 + INSERT INTO many_dups VALUES (1), (1), (1), (1), (1); + INSERT INTO many_dups VALUES (2); + INSERT INTO many_dups VALUES (3), (3), (3), (3), (3); + + CREATE MATERIALIZED VIEW distinct_many AS + SELECT DISTINCT x FROM many_dups; + + -- Remove ALL instances of value 2 + DELETE FROM many_dups WHERE x = 2; + + -- Now 2 should be gone + SELECT x FROM distinct_many ORDER BY x; +} {1 +3} + +# COUNT(DISTINCT) tests for materialized views + +do_execsql_test_on_specific_db {:memory:} matview-count-distinct-basic { + CREATE TABLE sales(region TEXT, product TEXT, amount INTEGER); + INSERT INTO sales VALUES + ('North', 'A', 100), + ('North', 'A', 100), -- Duplicate + ('North', 'B', 200), + ('South', 'A', 150), + ('South', 'A', 150); -- Duplicate + + CREATE MATERIALIZED VIEW sales_summary AS + SELECT region, COUNT(DISTINCT product) as unique_products + FROM sales GROUP BY region; + + SELECT * FROM sales_summary ORDER BY region; +} {North|2 +South|1} + +do_execsql_test_on_specific_db {:memory:} matview-count-distinct-nulls { + -- COUNT(DISTINCT) should ignore NULL values per SQL standard + CREATE TABLE data(grp INTEGER, val INTEGER); + INSERT INTO data VALUES + (1, 10), + (1, 20), + (1, NULL), + (1, NULL), -- Multiple NULLs + (2, 30), + (2, NULL); + + CREATE MATERIALIZED VIEW v AS + SELECT grp, COUNT(DISTINCT val) as cnt FROM data GROUP BY grp; + + SELECT * FROM v ORDER BY grp; + + -- Add more NULLs (should not affect count) + INSERT INTO data VALUES (1, NULL), (2, NULL); + SELECT * FROM v ORDER BY grp; + + -- Add a non-NULL value + INSERT INTO data VALUES (2, 40); + SELECT * FROM v ORDER BY grp; +} {1|2 +2|1 +1|2 +2|1 +1|2 +2|2} + +do_execsql_test_on_specific_db {:memory:} matview-count-distinct-empty-groups { + CREATE TABLE items(category TEXT, item TEXT); + INSERT INTO items VALUES + ('A', 'x'), + ('A', 'y'), + ('B', 'z'); + + CREATE MATERIALIZED VIEW category_counts AS + SELECT category, COUNT(DISTINCT item) as unique_items + FROM items GROUP BY category; + + SELECT * FROM category_counts ORDER BY category; + + -- Delete all items from category B + DELETE FROM items WHERE category = 'B'; + SELECT * FROM category_counts ORDER BY category; + + -- Re-add items to B + INSERT INTO items VALUES ('B', 'w'), ('B', 'w'); -- Same value twice + SELECT * FROM category_counts ORDER BY category; +} {A|2 +B|1 +A|2 +A|2 +B|1} + +do_execsql_test_on_specific_db {:memory:} matview-count-distinct-updates { + CREATE TABLE records(id INTEGER PRIMARY KEY, grp TEXT, val INTEGER); + INSERT INTO records VALUES + (1, 'X', 100), + (2, 'X', 200), + (3, 'Y', 100), + (4, 'Y', 100); -- Duplicate + + CREATE MATERIALIZED VIEW grp_summary AS + SELECT grp, COUNT(DISTINCT val) as distinct_vals + FROM records GROUP BY grp; + + SELECT * FROM grp_summary ORDER BY grp; + + -- Update that changes group membership + UPDATE records SET grp = 'Y' WHERE id = 1; + SELECT * FROM grp_summary ORDER BY grp; + + -- Update that changes value within group + UPDATE records SET val = 300 WHERE id = 3; + SELECT * FROM grp_summary ORDER BY grp; +} {X|2 +Y|1 +X|1 +Y|1 +X|1 +Y|2} + +do_execsql_test_on_specific_db {:memory:} matview-count-distinct-large-scale { + CREATE TABLE events(user_id INTEGER, event_type INTEGER); + + -- Insert many events with varying duplication + INSERT INTO events VALUES + (1, 1), (1, 1), (1, 1), (1, 2), (1, 3), -- User 1: 3 distinct + (2, 1), (2, 2), (2, 2), (2, 2), -- User 2: 2 distinct + (3, 4), (3, 4), (3, 4), (3, 4), (3, 4); -- User 3: 1 distinct + + CREATE MATERIALIZED VIEW user_stats AS + SELECT user_id, COUNT(DISTINCT event_type) as unique_events + FROM events GROUP BY user_id; + + SELECT * FROM user_stats ORDER BY user_id; + + -- Mass deletion + DELETE FROM events WHERE event_type = 2; + SELECT * FROM user_stats ORDER BY user_id; + + -- Mass insertion with duplicates + INSERT INTO events VALUES + (1, 5), (1, 5), (1, 6), + (2, 5), (2, 6), (2, 7); + SELECT * FROM user_stats ORDER BY user_id; +} {1|3 +2|2 +3|1 +1|2 +2|1 +3|1 +1|4 +2|4 +3|1} + +do_execsql_test_on_specific_db {:memory:} matview-count-distinct-group-by-empty-start { + CREATE TABLE measurements(device TEXT, reading INTEGER); + + CREATE MATERIALIZED VIEW device_summary AS + SELECT device, COUNT(DISTINCT reading) as unique_readings + FROM measurements GROUP BY device; + + -- Start with empty table (no groups yet) + SELECT COUNT(*) FROM device_summary; + + -- Add first group + INSERT INTO measurements VALUES ('D1', 100), ('D1', 100); + SELECT * FROM device_summary; + + -- Add second group with distinct values + INSERT INTO measurements VALUES ('D2', 200), ('D2', 300); + SELECT * FROM device_summary ORDER BY device; + + -- Remove all data + DELETE FROM measurements; + SELECT COUNT(*) FROM device_summary; +} {0 +D1|1 +D1|1 +D2|2 +0} + +do_execsql_test_on_specific_db {:memory:} matview-count-distinct-single-row-groups { + CREATE TABLE singles(k TEXT PRIMARY KEY, v INTEGER); + INSERT INTO singles VALUES + ('a', 1), + ('b', 2), + ('c', 3); + + CREATE MATERIALIZED VIEW v AS + SELECT k, COUNT(DISTINCT v) as cnt FROM singles GROUP BY k; + + SELECT * FROM v ORDER BY k; + + -- Each group has exactly one row, so COUNT(DISTINCT v) = 1 + UPDATE singles SET v = 999 WHERE k = 'b'; + SELECT * FROM v ORDER BY k; + + DELETE FROM singles WHERE k = 'c'; + SELECT * FROM v ORDER BY k; +} {a|1 +b|1 +c|1 +a|1 +b|1 +c|1 +a|1 +b|1} + +do_execsql_test_on_specific_db {:memory:} matview-count-distinct-transactions { + CREATE TABLE txn_data(grp TEXT, val INTEGER); + INSERT INTO txn_data VALUES ('A', 1), ('A', 2); + + CREATE MATERIALIZED VIEW txn_view AS + SELECT grp, COUNT(DISTINCT val) as cnt FROM txn_data GROUP BY grp; + + SELECT * FROM txn_view; + + -- Transaction that adds duplicates (should not change count) + BEGIN; + INSERT INTO txn_data VALUES ('A', 1), ('A', 2); + SELECT * FROM txn_view; + COMMIT; + + SELECT * FROM txn_view; + + -- Transaction that adds new distinct value then rolls back + BEGIN; + INSERT INTO txn_data VALUES ('A', 3); + SELECT * FROM txn_view; + ROLLBACK; + + SELECT * FROM txn_view; +} {A|2 +A|2 +A|2 +A|3 +A|2} + +do_execsql_test_on_specific_db {:memory:} matview-count-distinct-text-values { + CREATE TABLE strings(category INTEGER, str TEXT); + INSERT INTO strings VALUES + (1, 'hello'), + (1, 'world'), + (1, 'hello'), -- Duplicate + (2, 'foo'), + (2, 'bar'), + (2, 'bar'); -- Duplicate + + CREATE MATERIALIZED VIEW str_counts AS + SELECT category, COUNT(DISTINCT str) as unique_strings + FROM strings GROUP BY category; + + SELECT * FROM str_counts ORDER BY category; + + -- Case sensitivity test + INSERT INTO strings VALUES (1, 'HELLO'), (2, 'FOO'); + SELECT * FROM str_counts ORDER BY category; + + -- Empty strings + INSERT INTO strings VALUES (1, ''), (1, ''), (2, ''); + SELECT * FROM str_counts ORDER BY category; +} {1|2 +2|2 +1|3 +2|3 +1|4 +2|4} + +do_execsql_test_on_specific_db {:memory:} matview-sum-distinct { + CREATE TABLE sales(region TEXT, amount INTEGER); + INSERT INTO sales VALUES + ('North', 100), + ('North', 200), + ('North', 100), -- Duplicate + ('North', NULL), + ('South', 300), + ('South', 300), -- Duplicate + ('South', 400); + + CREATE MATERIALIZED VIEW sales_summary AS + SELECT region, SUM(DISTINCT amount) as total_distinct + FROM sales GROUP BY region; + + SELECT * FROM sales_summary ORDER BY region; + + -- Add a duplicate value + INSERT INTO sales VALUES ('North', 200); + SELECT * FROM sales_summary ORDER BY region; + + -- Add a new distinct value + INSERT INTO sales VALUES ('South', 500); + SELECT * FROM sales_summary ORDER BY region; +} {North|300.0 +South|700.0 +North|300.0 +South|700.0 +North|300.0 +South|1200.0} + +do_execsql_test_on_specific_db {:memory:} matview-avg-distinct { + CREATE TABLE grades(student TEXT, score INTEGER); + INSERT INTO grades VALUES + ('Alice', 90), + ('Alice', 85), + ('Alice', 90), -- Duplicate + ('Alice', NULL), + ('Bob', 75), + ('Bob', 80), + ('Bob', 75); -- Duplicate + + CREATE MATERIALIZED VIEW avg_grades AS + SELECT student, AVG(DISTINCT score) as avg_score + FROM grades GROUP BY student; + + SELECT * FROM avg_grades ORDER BY student; + + -- Add duplicate scores + INSERT INTO grades VALUES ('Alice', 85), ('Bob', 80); + SELECT * FROM avg_grades ORDER BY student; + + -- Add new distinct score + INSERT INTO grades VALUES ('Alice', 95); + SELECT * FROM avg_grades ORDER BY student; +} {Alice|87.5 +Bob|77.5 +Alice|87.5 +Bob|77.5 +Alice|90.0 +Bob|77.5} + +do_execsql_test_on_specific_db {:memory:} matview-min-distinct { + CREATE TABLE metrics(category TEXT, value INTEGER); + INSERT INTO metrics VALUES + ('A', 10), + ('A', 20), + ('A', 10), -- Duplicate + ('A', 30), + ('A', NULL), + ('B', 5), + ('B', 15), + ('B', 5); -- Duplicate + + CREATE MATERIALIZED VIEW metric_min AS + SELECT category, + MIN(DISTINCT value) as min_val + FROM metrics GROUP BY category; + + SELECT * FROM metric_min ORDER BY category; + + -- Add values that don't change min + INSERT INTO metrics VALUES ('A', 15), ('B', 10); + SELECT * FROM metric_min ORDER BY category; + + -- Add values that change min + INSERT INTO metrics VALUES ('A', 5), ('B', 3); + SELECT * FROM metric_min ORDER BY category; +} {A|10 +B|5 +A|10 +B|5 +A|5 +B|3} + +do_execsql_test_on_specific_db {:memory:} matview-max-distinct { + CREATE TABLE metrics2(category TEXT, value INTEGER); + INSERT INTO metrics2 VALUES + ('A', 10), + ('A', 20), + ('A', 10), -- Duplicate + ('A', 30), + ('A', NULL), + ('B', 5), + ('B', 15), + ('B', 5); -- Duplicate + + CREATE MATERIALIZED VIEW metric_max AS + SELECT category, + MAX(DISTINCT value) as max_val + FROM metrics2 GROUP BY category; + + SELECT * FROM metric_max ORDER BY category; + + -- Add values that don't change max + INSERT INTO metrics2 VALUES ('A', 15), ('B', 10); + SELECT * FROM metric_max ORDER BY category; + + -- Add values that change max + INSERT INTO metrics2 VALUES ('A', 40), ('B', 20); + SELECT * FROM metric_max ORDER BY category; +} {A|30 +B|15 +A|30 +B|15 +A|40 +B|20} + +do_execsql_test_on_specific_db {:memory:} matview-multiple-distinct-aggregates-with-groupby { + CREATE TABLE data(grp TEXT, x INTEGER, y INTEGER, z INTEGER); + INSERT INTO data VALUES + ('A', 1, 10, 100), + ('A', 2, 20, 200), + ('A', 1, 10, 300), -- x,y duplicates + ('A', 3, 30, 100), -- z duplicate + ('A', NULL, 40, 400), + ('B', 4, 50, 500), + ('B', 5, 50, 600), -- y duplicate + ('B', 4, 60, 700), -- x duplicate + ('B', 6, NULL, 500), -- z duplicate + ('B', NULL, 70, NULL); + + CREATE MATERIALIZED VIEW multi_distinct AS + SELECT grp, + COUNT(DISTINCT x) as cnt_x, + SUM(DISTINCT y) as sum_y, + AVG(DISTINCT z) as avg_z + FROM data GROUP BY grp; + + SELECT * FROM multi_distinct ORDER BY grp; + + -- Add more data with duplicates + INSERT INTO data VALUES + ('A', 1, 20, 200), -- Existing values + ('B', 7, 80, 800); -- New values + + SELECT * FROM multi_distinct ORDER BY grp; +} {A|3|100.0|250.0 +B|3|180.0|600.0 +A|3|100.0|250.0 +B|4|260.0|650.0} + +do_execsql_test_on_specific_db {:memory:} matview-multiple-distinct-aggregates-no-groupby { + CREATE TABLE data2(x INTEGER, y INTEGER, z INTEGER); + INSERT INTO data2 VALUES + (1, 10, 100), + (2, 20, 200), + (1, 10, 300), -- x,y duplicates + (3, 30, 100), -- z duplicate + (NULL, 40, 400), + (4, 50, 500), + (5, 50, 600), -- y duplicate + (4, 60, 700), -- x duplicate + (6, NULL, 500), -- z duplicate + (NULL, 70, NULL); + + CREATE MATERIALIZED VIEW multi_distinct_global AS + SELECT COUNT(DISTINCT x) as cnt_x, + SUM(DISTINCT y) as sum_y, + AVG(DISTINCT z) as avg_z + FROM data2; + + SELECT * FROM multi_distinct_global; + + -- Add more data + INSERT INTO data2 VALUES + (1, 20, 200), -- Existing values + (7, 80, 800); -- New values + + SELECT * FROM multi_distinct_global; +} {6|280.0|400.0 +7|360.0|450.0} + +do_execsql_test_on_specific_db {:memory:} matview-count-distinct-global-aggregate { + CREATE TABLE all_data(val INTEGER); + INSERT INTO all_data VALUES (1), (2), (1), (3), (2); + + CREATE MATERIALIZED VIEW summary AS + SELECT COUNT(DISTINCT val) as total_distinct FROM all_data; + + SELECT * FROM summary; + + -- Add duplicates + INSERT INTO all_data VALUES (1), (2), (3); + SELECT * FROM summary; + + -- Add new distinct values + INSERT INTO all_data VALUES (4), (5); + SELECT * FROM summary; + + -- Delete all of one value + DELETE FROM all_data WHERE val = 3; + SELECT * FROM summary; +} {3 +3 +5 +4} diff --git a/testing/offset.test b/testing/offset.test old mode 100644 new mode 100755 index 720fbebac..935070de7 --- a/testing/offset.test +++ b/testing/offset.test @@ -64,3 +64,52 @@ do_execsql_test_on_specific_db {:memory:} select-ungrouped-aggregate-with-offset INSERT INTO t VALUES (1),(2),(3),(4),(5),(6),(7),(8),(9),(10); SELECT COUNT(a) FROM t LIMIT 1 OFFSET 1; } {} + + +do_execsql_test_on_specific_db {:memory:} offset-expr-can-be-cast-losslessly-1 { + SELECT 1 LIMIT 3 OFFSET 1.1 + 2.9; +} {} + +do_execsql_test_on_specific_db {:memory:} offset-expr-can-be-cast-losslessly-2 { + CREATE TABLE T(a); + INSERT INTO T VALUES (1),(2),(3),(4); + SELECT * FROM T LIMIT 1+'2' OFFSET 1.6/2 + 3.6/3 + 4*0.25; +} {4} + +# Strings are cast to float. Final result is integer losslessly +do_execsql_test_on_specific_db {:memory:} offset-expr-can-be-cast-losslessly-3 { + CREATE TABLE T(a); + INSERT INTO T VALUES (1),(2),(3),(4); + SELECT * FROM T LIMIT 3 OFFSET '0.8' + '1.2' + '4'*'0.25'; +} {4} + +# Strings are cast to 0. Expression still valid. +do_execsql_test_on_specific_db {:memory:} offset-expr-int-and-string { + SELECT 1 LIMIT 3 OFFSET 3/3 + 'test' + 4*'test are best'; +} {} + +do_execsql_test_in_memory_error_content offset-expr-cannot-be-cast-losslessly-1 { + SELECT 1 LIMIT 3 OFFSET 1.1; +} {"datatype mismatch"} + +do_execsql_test_in_memory_error_content offset-expr-cannot-be-cast-losslessly-2 { + SELECT 1 LIMIT 3 OFFSET 1.1 + 2.2 + 1.9/8; +} {"datatype mismatch"} + +# Return error as float in expression cannot be cast losslessly +do_execsql_test_in_memory_error_content offset-expr-cannot-be-cast-losslessly-3 { + SELECT 1 LIMIT 3 OFFSET 1.1 + 'a'; +} {"datatype mismatch"} + +do_execsql_test_in_memory_error_content offset-expr-invalid-data-type-1 { + SELECT 1 LIMIT 3 OFFSET 'a'; +} {"datatype mismatch"} + +do_execsql_test_in_memory_error_content offset-expr-invalid-data-type-2 { + SELECT 1 LIMIT 3 OFFSET NULL; +} {"datatype mismatch"} + +# Expression below evaluates to NULL (string → 0) +do_execsql_test_in_memory_error_content offset-expr-invalid-data-type-3 { + SELECT 1 LIMIT 3 OFFSET 1/'iwillbezero ;-; '; +} {"datatype mismatch"} diff --git a/testing/orderby.test b/testing/orderby.test index e173d946b..3864cf695 100755 --- a/testing/orderby.test +++ b/testing/orderby.test @@ -239,4 +239,41 @@ do_execsql_test_on_specific_db {:memory:} orderby_alias_precedence { INSERT INTO t VALUES (1,200),(2,100); SELECT x AS y, y AS x FROM t ORDER BY x; } {2|100 -1|200} \ No newline at end of file +1|200} + +# Check that ORDER BY with heap-sort properly handle multiple rows with same order key + result values +do_execsql_test_on_specific_db {:memory:} orderby_same_rows { + CREATE TABLE t(x,y,z); + INSERT INTO t VALUES (1,2,3),(1,2,6),(1,2,9),(1,2,10),(1,3,-1),(1,3,-2); + SELECT x, y FROM t ORDER BY x, y LIMIT 10; +} {1|2 +1|2 +1|2 +1|2 +1|3 +1|3} + +# https://github.com/tursodatabase/turso/issues/3684 +do_execsql_test_on_specific_db {:memory:} orderby_alias_shadows_column { + CREATE TABLE t(a, b); + INSERT INTO t VALUES (1, 1), (2, 2), (3, 3); + SELECT a, -b AS a FROM t ORDER BY a; +} {3|-3 +2|-2 +1|-1} + + do_execsql_test_in_memory_any_error order_by_ambiguous_column { + CREATE TABLE a(id INT, value INT); +INSERT INTO a VALUES (1, 10), (2, 20); + +CREATE TABLE b(id INT, value INT); +INSERT INTO b VALUES (1, 100), (2, 200); + +SELECT + a.id, + b.value +FROM + a JOIN b ON a.id = b.id +ORDER BY +value; + } diff --git a/testing/scalar-functions.test b/testing/scalar-functions.test index 963ccc5be..30ba59b83 100755 --- a/testing/scalar-functions.test +++ b/testing/scalar-functions.test @@ -1023,7 +1023,58 @@ do_execsql_test sum-8 { } {1.2} +# https://github.com/tursodatabase/turso/issues/3689 +do_execsql_test iif-3-args-true { + select iif(1 < 2, 'yes', 'no'); +} {yes} + +do_execsql_test iif-3-args-false { + select iif(1 > 2, 'yes', 'no'); +} {no} + +do_execsql_test iif-2-args-true { + select iif(1 < 2, 'yes'); +} {yes} + +do_execsql_test iif-2-args-false-is-null { + select iif(1 > 2, 'yes'); +} {} + +do_execsql_test iif-multi-args-finds-first-true { + select iif(0, 'a', 1, 'b', 2, 'c', 'default'); +} {b} + +do_execsql_test iif-multi-args-falls-to-else { + select iif(0, 'a', 0, 'b', 0, 'c', 'default'); +} {default} + +do_execsql_test if-alias-3-args-true { + select if(1 < 2, 'yes', 'no'); +} {yes} + +do_execsql_test if-alias-3-args-false { + select if(1 > 2, 'yes', 'no'); +} {no} + +do_execsql_test if-alias-2-args-true { + select if(1 < 2, 'ok'); +} {ok} + +do_execsql_test if-alias-multi-args-finds-first-true { + select if(0, 'a', 1, 'b', 'c'); +} {b} + +do_execsql_test if-alias-multi-args-falls-to-else { + select if(0, 'a', 0, 'b', 'c'); +} {c} + +do_execsql_test if-alias-multi-args-no-else-is-null { + select if(0, 'a', 0, 'b'); +} {} + + # TODO: sqlite seems not enable soundex() by default unless build it with SQLITE_SOUNDEX enabled. # do_execsql_test soundex-text { # select soundex('Pfister'), soundex('husobee'), soundex('Tymczak'), soundex('Ashcraft'), soundex('Robert'), soundex('Rupert'), soundex('Rubin'), soundex('Kant'), soundex('Knuth'), soundex('x'), soundex(''); # } {P236|H210|T522|A261|R163|R163|R150|K530|K530|X000|0000} + diff --git a/testing/select.test b/testing/select.test index 5b35d3eda..c42e38e42 100755 --- a/testing/select.test +++ b/testing/select.test @@ -720,6 +720,176 @@ do_execsql_test_on_specific_db {:memory:} select-no-match-in-leaf-page { 2 2} +do_execsql_test_on_specific_db {:memory:} select-range-search-count-asc-index { + CREATE TABLE t (a, b); + CREATE INDEX t_idx ON t(a, b); + insert into t values (1, 1); + insert into t values (1, 2); + insert into t values (1, 3); + insert into t values (1, 4); + insert into t values (1, 5); + insert into t values (1, 6); + insert into t values (2, 1); + insert into t values (2, 2); + insert into t values (2, 3); + insert into t values (2, 4); + insert into t values (2, 5); + insert into t values (2, 6); + select count(*) from t where a = 1 AND b >= 2 ORDER BY a ASC, b ASC; + select count(*) from t where a = 1 AND b > 2 ORDER BY a ASC, b ASC; + select count(*) from t where a = 1 AND b <= 4 ORDER BY a ASC, b ASC; + select count(*) from t where a = 1 AND b < 4 ORDER BY a ASC, b ASC; + select count(*) from t where a = 1 AND b >= 2 AND b <= 4 ORDER BY a ASC, b ASC; + select count(*) from t where a = 1 AND b > 2 AND b <= 4 ORDER BY a ASC, b ASC; + select count(*) from t where a = 1 AND b >= 2 AND b < 4 ORDER BY a ASC, b ASC; + select count(*) from t where a = 1 AND b > 2 AND b < 4 ORDER BY a ASC, b ASC; + + select count(*) from t where a = 1 AND b >= 2 ORDER BY a DESC, b DESC; + select count(*) from t where a = 1 AND b > 2 ORDER BY a DESC, b DESC; + select count(*) from t where a = 1 AND b <= 4 ORDER BY a DESC, b DESC; + select count(*) from t where a = 1 AND b < 4 ORDER BY a DESC, b DESC; + select count(*) from t where a = 1 AND b >= 2 AND b <= 4 ORDER BY a DESC, b DESC; + select count(*) from t where a = 1 AND b > 2 AND b <= 4 ORDER BY a DESC, b DESC; + select count(*) from t where a = 1 AND b >= 2 AND b < 4 ORDER BY a DESC, b DESC; + select count(*) from t where a = 1 AND b > 2 AND b < 4 ORDER BY a DESC, b DESC; +} {5 +4 +4 +3 +3 +2 +2 +1 +5 +4 +4 +3 +3 +2 +2 +1} + +do_execsql_test_on_specific_db {:memory:} select-range-search-count-desc-index { + CREATE TABLE t (a, b); + CREATE INDEX t_idx ON t(a, b DESC); + insert into t values (1, 1); + insert into t values (1, 2); + insert into t values (1, 3); + insert into t values (1, 4); + insert into t values (1, 5); + insert into t values (1, 6); + insert into t values (2, 1); + insert into t values (2, 2); + insert into t values (2, 3); + insert into t values (2, 4); + insert into t values (2, 5); + insert into t values (2, 6); + select count(*) from t where a = 1 AND b >= 2 ORDER BY a ASC, b DESC; + select count(*) from t where a = 1 AND b > 2 ORDER BY a ASC, b DESC; + select count(*) from t where a = 1 AND b <= 4 ORDER BY a ASC, b DESC; + select count(*) from t where a = 1 AND b < 4 ORDER BY a ASC, b DESC; + select count(*) from t where a = 1 AND b >= 2 AND b <= 4 ORDER BY a ASC, b DESC; + select count(*) from t where a = 1 AND b > 2 AND b <= 4 ORDER BY a ASC, b DESC; + select count(*) from t where a = 1 AND b >= 2 AND b < 4 ORDER BY a ASC, b DESC; + select count(*) from t where a = 1 AND b > 2 AND b < 4 ORDER BY a ASC, b DESC; + + select count(*) from t where a = 1 AND b >= 2 ORDER BY a DESC, b ASC; + select count(*) from t where a = 1 AND b > 2 ORDER BY a DESC, b ASC; + select count(*) from t where a = 1 AND b <= 4 ORDER BY a DESC, b ASC; + select count(*) from t where a = 1 AND b < 4 ORDER BY a DESC, b ASC; + select count(*) from t where a = 1 AND b >= 2 AND b <= 4 ORDER BY a DESC, b ASC; + select count(*) from t where a = 1 AND b > 2 AND b <= 4 ORDER BY a DESC, b ASC; + select count(*) from t where a = 1 AND b >= 2 AND b < 4 ORDER BY a DESC, b ASC; + select count(*) from t where a = 1 AND b > 2 AND b < 4 ORDER BY a DESC, b ASC; +} {5 +4 +4 +3 +3 +2 +2 +1 +5 +4 +4 +3 +3 +2 +2 +1} + +do_execsql_test_on_specific_db {:memory:} select-range-search-scan-asc-index { + CREATE TABLE t (a, b); + CREATE INDEX t_idx ON t(a, b); + insert into t values (1, 1); + insert into t values (1, 2); + insert into t values (1, 3); + insert into t values (1, 4); + insert into t values (1, 5); + insert into t values (1, 6); + insert into t values (2, 1); + insert into t values (2, 2); + insert into t values (2, 3); + insert into t values (2, 4); + insert into t values (2, 5); + insert into t values (2, 6); + select * from t where a = 1 AND b > 1 AND b < 6 ORDER BY a ASC, b ASC; + select * from t where a = 2 AND b > 1 AND b < 6 ORDER BY a DESC, b DESC; + select * from t where a = 1 AND b > 1 AND b < 6 ORDER BY a DESC, b ASC; + select * from t where a = 2 AND b > 1 AND b < 6 ORDER BY a ASC, b DESC; +} {1|2 +1|3 +1|4 +1|5 +2|5 +2|4 +2|3 +2|2 +1|2 +1|3 +1|4 +1|5 +2|5 +2|4 +2|3 +2|2} + +do_execsql_test_on_specific_db {:memory:} select-range-search-scan-desc-index { + CREATE TABLE t (a, b); + CREATE INDEX t_idx ON t(a, b DESC); + insert into t values (1, 1); + insert into t values (1, 2); + insert into t values (1, 3); + insert into t values (1, 4); + insert into t values (1, 5); + insert into t values (1, 6); + insert into t values (2, 1); + insert into t values (2, 2); + insert into t values (2, 3); + insert into t values (2, 4); + insert into t values (2, 5); + insert into t values (2, 6); + select * from t where a = 1 AND b > 1 AND b < 6 ORDER BY a ASC, b ASC; + select * from t where a = 2 AND b > 1 AND b < 6 ORDER BY a DESC, b DESC; + select * from t where a = 1 AND b > 1 AND b < 6 ORDER BY a DESC, b ASC; + select * from t where a = 2 AND b > 1 AND b < 6 ORDER BY a ASC, b DESC; +} {1|2 +1|3 +1|4 +1|5 +2|5 +2|4 +2|3 +2|2 +1|2 +1|3 +1|4 +1|5 +2|5 +2|4 +2|3 +2|2} + # Regression tests for double-quoted strings in SELECT statements do_execsql_test_skip_lines_on_specific_db 1 {:memory:} select-double-quotes-values { .dbconfig dqs_dml on @@ -743,6 +913,19 @@ do_execsql_test_on_specific_db {:memory:} select-in-simple { } {1 0} +do_execsql_test_on_specific_db {:memory:} select-in-with-nulls { + SELECT 4 IN (1, 4, null); + SELECT 4 NOT IN (1, 4, null); +} {1 +0} + +# All should be null +do_execsql_test_on_specific_db {:memory:} select-in-with-nulls-2 { +SELECT 1 IN (2, 3, null); +SELECT 1 NOT IN (2, 3, null); +SELECT null in (null); +} {\n\n} + do_execsql_test_on_specific_db {:memory:} select-in-complex { CREATE TABLE test_table (id INTEGER, category TEXT, value INTEGER); INSERT INTO test_table VALUES (1, 'A', 10), (2, 'B', 20), (3, 'A', 30), (4, 'C', 40); @@ -768,6 +951,63 @@ foreach {testname limit ans} { "SELECT id FROM users ORDER BY id LIMIT $limit" $ans } +do_execsql_test_on_specific_db {:memory:} limit-expr-can-be-cast-losslessly-1 { + SELECT 1 LIMIT 1.1 + 2.9; +} {1} + +do_execsql_test_on_specific_db {:memory:} limit-expr-can-be-cast-losslessly-2 { + CREATE TABLE T(a); + INSERT INTO T VALUES (1),(1),(1),(1); + SELECT * FROM T LIMIT 1.6/2 + 3.6/3 + 4*0.25; +} {1 +1 +1} + +# Numeric strings are cast to float. The final evaluation of the expression returns an int losslessly +do_execsql_test_on_specific_db {:memory:} limit-expr-can-be-cast-losslessly-3 { + CREATE TABLE T(a); + INSERT INTO T VALUES (1),(1),(1),(1); + SELECT * FROM T LIMIT '0.8' + '1.2' + 4*0.25; +} {1 +1 +1} + +# Invalid strings are cast to 0. So expression is valid +do_execsql_test_on_specific_db {:memory:} limit-expr-int-and-string { + SELECT 1 LIMIT 3/3 + 'test' + 4*'test are best'; +} {1} + +do_execsql_test_in_memory_error_content limit-expr-cannot-be-cast-losslessly-1 { + SELECT 1 LIMIT 1.1; +} {"datatype mismatch"} + +do_execsql_test_in_memory_error_content limit-expr-cannot-be-cast-losslessly-2 { + SELECT 1 LIMIT 1.1 + 2.2 + 1.9/8; +} {"datatype mismatch"} + +# Return error as float in the expression cannot be cast losslessly +do_execsql_test_in_memory_error_content limit-expr-cannot-be-cast-losslessly-3 { + SELECT 1 LIMIT 1.1 +'a'; +} {"datatype mismatch"} + +do_execsql_test_in_memory_error_content limit-expr-invalid-data-type-1 { + SELECT 1 LIMIT 'a'; +} {"datatype mismatch"} + +do_execsql_test_in_memory_error_content limit-expr-invalid-data-type-2 { + SELECT 1 LIMIT NULL; +} {"datatype mismatch"} + +# The expression below evaluates to NULL as string is cast to 0 +do_execsql_test_in_memory_error_content limit-expr-invalid-data-type-3 { + SELECT 1 LIMIT 1/'iwillbezero ;-; ' ; +} {"datatype mismatch"} + +# Expression is evaluated as NULL +do_execsql_test_in_memory_error_content limit-expr-invalid-data-type-4 { + SELECT 1 LIMIT 4+NULL; +} {"datatype mismatch"} + do_execsql_test_on_specific_db {:memory:} rowid-references { CREATE TABLE test_table (id INTEGER); INSERT INTO test_table VALUES (5),(5); @@ -812,3 +1052,45 @@ do_execsql_test_on_specific_db {:memory:} null-in-search { 2|2 2|2} +do_execsql_test_in_memory_any_error limit-column-reference-error { + CREATE TABLE t(a); + SELECT * FROM t LIMIT (t.a); +} + +do_execsql_test select-binary-collation { + SELECT 'a' = 'A'; + SELECT 'a' = 'a'; +} {0 1} + +# https://github.com/tursodatabase/turso/issues/3667 regression test +do_execsql_test_in_memory_error_content rowid-select-from-clause-subquery { + CREATE TABLE t(a); + SELECT rowid FROM (SELECT * FROM t); +} {"no such column: rowid"} + +do_execsql_test_on_specific_db {:memory:} rowid-select-from-clause-subquery-explicit-works { + CREATE TABLE t(a); + INSERT INTO t values ('abc'); + SELECT rowid,a FROM (SELECT rowid,a FROM t); +} {1|abc} + +# https://github.com/tursodatabase/turso/issues/3505 regression test +do_execsql_test_in_memory_any_error ambiguous-self-join { + CREATE TABLE T(a); + INSERT INTO t VALUES (1), (2), (3); + SELECT * fROM t JOIN t; +} + +do_execsql_test_on_specific_db {:memory:} unambiguous-self-join { + CREATE TABLE T(a); + INSERT INTO t VALUES (1), (2), (3); + SELECT * fROM t as ta JOIN t order by ta.a; +} {1|1 +1|2 +1|3 +2|1 +2|2 +2|3 +3|1 +3|2 +3|3} diff --git a/testing/subquery.test b/testing/subquery.test index 98ecec001..3a909df34 100644 --- a/testing/subquery.test +++ b/testing/subquery.test @@ -433,3 +433,17 @@ do_execsql_test subquery-count-all { where u.id < 100 ); } {1089} + +do_execsql_test_on_specific_db {:memory:} subquery-cte-available-in-arbitrary-depth { + with cte as (select 1 as one) + select onehundredandeleven+1 as onehundredandtwelve + from ( + with cte2 as (select 10 as ten) + select onehundredandone+ten as onehundredandeleven + from ( + with cte3 as (select 100 as hundred) + select one+hundred as onehundredandone + from cte join cte3 + ) join cte2 + ); +} {112} \ No newline at end of file diff --git a/testing/update.test b/testing/update.test index d55d98d0e..964dfba29 100755 --- a/testing/update.test +++ b/testing/update.test @@ -386,3 +386,115 @@ do_execsql_test_on_specific_db {:memory:} can-update-rowid-directly { UPDATE test SET rowid = 5; SELECT rowid, name from test; } {5|test} + +# https://github.com/tursodatabase/turso/issues/3678 +do_execsql_test_on_specific_db {:memory:} update-alias-visibility-in-where-clause { + create table t(a); + insert into t values (0); + insert into t values (5); + update t as tt set a = 1 where tt.a = 0; + select * from t; +} {1 +5} + +# Basic UPDATE tests with indexes +do_execsql_test_on_specific_db {:memory:} update-non-indexed-column { + CREATE TABLE t (a INTEGER, b INTEGER); + CREATE INDEX idx_a ON t(a); + INSERT INTO t VALUES (1, 10), (2, 20), (3, 30); + UPDATE t SET b = 100 WHERE a = 2; + SELECT * FROM t ORDER BY a; +} {1|10 +2|100 +3|30} + +do_execsql_test_on_specific_db {:memory:} update-indexed-column { + CREATE TABLE t (a INTEGER, b INTEGER); + CREATE INDEX idx_a ON t(a); + INSERT INTO t VALUES (1, 10), (2, 20), (3, 30); + UPDATE t SET a = 5 WHERE a = 2; + SELECT * FROM t ORDER BY a; +} {1|10 +3|30 +5|20} + +do_execsql_test_on_specific_db {:memory:} update-both-indexed-and-non-indexed { + CREATE TABLE t (a INTEGER, b INTEGER, c INTEGER); + CREATE INDEX idx_a ON t(a); + INSERT INTO t VALUES (1, 10, 100), (2, 20, 200), (3, 30, 300); + UPDATE t SET a = 5, b = 50, c = 500 WHERE a = 2; + SELECT * FROM t ORDER BY a; +} {1|10|100 +3|30|300 +5|50|500} + +do_execsql_test_on_specific_db {:memory:} update-multiple-indexes { + CREATE TABLE t (a INTEGER, b INTEGER, c INTEGER); + CREATE INDEX idx_a ON t(a); + CREATE INDEX idx_b ON t(b); + INSERT INTO t VALUES (1, 10, 100), (2, 20, 200), (3, 30, 300); + UPDATE t SET a = 5, b = 50 WHERE c = 200; + SELECT * FROM t ORDER BY a; +} {1|10|100 +3|30|300 +5|50|200} + +do_execsql_test_on_specific_db {:memory:} update-all-rows-with-index { + CREATE TABLE t (a INTEGER, b INTEGER); + CREATE INDEX idx_a ON t(a); + INSERT INTO t VALUES (1, 10), (2, 20), (3, 30); + UPDATE t SET a = a + 10; + SELECT * FROM t ORDER BY a; +} {11|10 +12|20 +13|30} + +# Range update tests +do_execsql_test_on_specific_db {:memory:} update-range-non-indexed { + CREATE TABLE t (a INTEGER, b INTEGER); + CREATE INDEX idx_a ON t(a); + INSERT INTO t VALUES (1, 10), (2, 20), (3, 30), (4, 40), (5, 50); + UPDATE t SET b = 999 WHERE a >= 2 AND a <= 4; + SELECT * FROM t ORDER BY a; +} {1|10 +2|999 +3|999 +4|999 +5|50} + +do_execsql_test_on_specific_db {:memory:} update-range-indexed-column { + CREATE TABLE t (a INTEGER, b INTEGER); + CREATE INDEX idx_a ON t(a); + INSERT INTO t VALUES (1, 10), (2, 20), (3, 30), (4, 40), (5, 50); + UPDATE t SET a = a + 100 WHERE a >= 2 AND a < 4; + SELECT * FROM t ORDER BY a; +} {1|10 +4|40 +5|50 +102|20 +103|30} + +do_execsql_test_on_specific_db {:memory:} update-range-both-columns { + CREATE TABLE t (a INTEGER, b INTEGER, c INTEGER); + CREATE INDEX idx_a ON t(a); + INSERT INTO t VALUES (1, 10, 100), (2, 20, 200), (3, 30, 300), (4, 40, 400), (5, 50, 500); + UPDATE t SET a = a * 10, b = b * 2 WHERE a > 1 AND a < 5; + SELECT * FROM t ORDER BY a; +} {1|10|100 +5|50|500 +20|40|200 +30|60|300 +40|80|400} + +do_execsql_test_on_specific_db {:memory:} update-range-multiple-indexes { + CREATE TABLE t (a INTEGER, b INTEGER, c INTEGER); + CREATE INDEX idx_a ON t(a); + CREATE INDEX idx_b ON t(b); + INSERT INTO t VALUES (1, 10, 100), (2, 20, 200), (3, 30, 300), (4, 40, 400); + UPDATE t SET a = a + 10, b = b + 100 WHERE a BETWEEN 2 AND 3; + SELECT * FROM t ORDER BY a; +} {1|10|100 +4|40|400 +12|120|200 +13|130|300} + diff --git a/testing/values.test b/testing/values.test index 8ebc33978..b0cc2464a 100755 --- a/testing/values.test +++ b/testing/values.test @@ -48,3 +48,16 @@ do_execsql_test_skip_lines_on_specific_db 1 {:memory:} values-double-quotes-subq .dbconfig dqs_dml on SELECT * FROM (VALUES ("subquery_string")); } {subquery_string} + +# regression test for: https://github.com/tursodatabase/turso/issues/2158 +do_execsql_test_on_specific_db {:memory:} values-between { + CREATE TABLE t0 (c0); + INSERT INTO t0 VALUES ((0 BETWEEN 0 AND 0)), (0); + SELECT * FROM t0; +} {1 +0} + +do_execsql_test_in_memory_any_error values-illegal-column-ref { + CREATE TABLE t0 (c0); + INSERT INTO t0 VALUES (c0); +} \ No newline at end of file diff --git a/testing/vector.test b/testing/vector.test index 7cae2ca5e..12aaef940 100755 --- a/testing/vector.test +++ b/testing/vector.test @@ -12,3 +12,45 @@ do_execsql_test vector-functions-valid { {[1,2,3]} {[-1000000000000000000]} } + +do_execsql_test_on_specific_db {:memory:} vector-insert { + CREATE TABLE IF NOT EXISTS vector_test ( + id INTEGER PRIMARY KEY, + format TEXT NOT NULL, + vec_data F32_BLOB(3) -- 3-dimensional vector + ); + INSERT INTO vector_test (id, format, vec_data) + VALUES (2, 'Bracketed_comma_separated', vector('[4.000000,5.000000,6.000000]')); + SELECT id, format, vector_extract(vec_data) from vector_test; +} {2|Bracketed_comma_separated|[4,5,6]} + +do_execsql_test_on_specific_db {:memory:} vector-insert { + CREATE TABLE IF NOT EXISTS vector_test ( + id INTEGER PRIMARY KEY, + format TEXT NOT NULL, + vec_data F32_BLOB(3) -- 3-dimensional vector + ); + INSERT INTO vector_test (id, format, vec_data) + VALUES (2, 'Bracketed_comma_separated', vector('[4.000000,5.000000,6.000000]')); + SELECT id, format, vector_extract(vec_data) from vector_test; +} {2|Bracketed_comma_separated|[4,5,6]} + +do_execsql_test_in_memory_error vector-insert-no-quotes { + CREATE TABLE IF NOT EXISTS vector_test ( + id INTEGER PRIMARY KEY, + format TEXT NOT NULL, + vec_data F32_BLOB(3) -- 3-dimensional vector + ); + INSERT INTO vector_test (id, format, vec_data) + VALUES (2, 'Bracketed_comma_separated', vector([4.000000,5.000000,6.000000])); +} { × Parse error: no such column: [4.000000,5.000000,6.000000]} + +do_execsql_test_in_memory_error_content vector-insert-double-quotes { + CREATE TABLE IF NOT EXISTS vector_test ( + id INTEGER PRIMARY KEY, + format TEXT NOT NULL, + vec_data F32_BLOB(3) -- 3-dimensional vector + ); + INSERT INTO vector_test (id, format, vec_data) + VALUES (2, 'Bracketed_comma_separated', vector("[4.000000,5.000000,6.000000]")); +} {no such column: [4.000000,5.000000,6.000000]} \ No newline at end of file diff --git a/testing/views.test b/testing/views.test index e7ca6a938..f5829abaa 100755 --- a/testing/views.test +++ b/testing/views.test @@ -230,13 +230,13 @@ do_execsql_test_on_specific_db {:memory:} view-with-having { } {C|380 A|250} -do_execsql_test_error view-self-circle-detection { +do_execsql_test_in_memory_error_content view-self-circle-detection { CREATE VIEW v AS SELECT * FROM v; SELECT * FROM v; } {view v is circularly defined} -do_execsql_test_error view-mutual-circle-detection { +do_execsql_test_in_memory_error_content view-mutual-circle-detection { CREATE VIEW v AS SELECT * FROM vv; CREATE VIEW vv AS SELECT * FROM v; SELECT * FROM v; diff --git a/tests/Cargo.toml b/tests/Cargo.toml index 837635850..9b7ee7f5d 100644 --- a/tests/Cargo.toml +++ b/tests/Cargo.toml @@ -14,6 +14,10 @@ path = "lib.rs" name = "integration_tests" path = "integration/mod.rs" +[[test]] +name = "fuzz_tests" +path = "fuzz/mod.rs" + [dependencies] anyhow.workspace = true env_logger = { workspace = true } @@ -29,11 +33,13 @@ rand = { workspace = true } zerocopy = "0.8.26" ctor = "0.5.0" twox-hash = "2.1.1" +sql_generation = { path = "../sql_generation" } +turso_parser = { workspace = true } +tracing-subscriber = { workspace = true, features = ["env-filter"] } +tracing = { workspace = true } [dev-dependencies] test-log = { version = "0.2.17", features = ["trace"] } -tracing-subscriber = { workspace = true, features = ["env-filter"] } -tracing = { workspace = true } [features] default = ["test_helper"] diff --git a/tests/integration/fuzz/grammar_generator.rs b/tests/fuzz/grammar_generator.rs similarity index 99% rename from tests/integration/fuzz/grammar_generator.rs rename to tests/fuzz/grammar_generator.rs index aa046f065..5c8f701d8 100644 --- a/tests/integration/fuzz/grammar_generator.rs +++ b/tests/fuzz/grammar_generator.rs @@ -89,6 +89,12 @@ struct GrammarGeneratorInner { symbols: HashMap, } +impl Default for GrammarGenerator { + fn default() -> Self { + Self::new() + } +} + impl GrammarGenerator { pub fn new() -> Self { GrammarGenerator(Rc::new(RefCell::new(GrammarGeneratorInner { diff --git a/tests/integration/fuzz/mod.rs b/tests/fuzz/mod.rs similarity index 64% rename from tests/integration/fuzz/mod.rs rename to tests/fuzz/mod.rs index 1f40c8ae0..9f06f20fc 100644 --- a/tests/integration/fuzz/mod.rs +++ b/tests/fuzz/mod.rs @@ -1,24 +1,23 @@ pub mod grammar_generator; +pub mod rowid_alias; #[cfg(test)] -mod tests { - use rand::seq::{IndexedRandom, SliceRandom}; - use std::collections::HashSet; - use turso_core::DatabaseOpts; - - use rand::{Rng, SeedableRng}; +mod fuzz_tests { + use rand::seq::{IndexedRandom, IteratorRandom, SliceRandom}; + use rand::Rng; use rand_chacha::ChaCha8Rng; use rusqlite::{params, types::Value}; + use std::{collections::HashSet, io::Write}; + use turso_core::DatabaseOpts; - use crate::{ - common::{ - do_flush, limbo_exec_rows, limbo_exec_rows_fallible, limbo_stmt_get_column_names, - maybe_setup_tracing, rng_from_time, rng_from_time_or_env, sqlite_exec_rows, - TempDatabase, - }, - fuzz::grammar_generator::{const_str, rand_int, rand_str, GrammarGenerator}, + use core_tester::common::{ + do_flush, limbo_exec_rows, limbo_exec_rows_fallible, limbo_stmt_get_column_names, + maybe_setup_tracing, rng_from_time_or_env, rusqlite_integrity_check, sqlite_exec_rows, + TempDatabase, }; + use super::grammar_generator::{const_str, rand_int, rand_str, GrammarGenerator}; + use super::grammar_generator::SymbolHandle; /// [See this issue for more info](https://github.com/tursodatabase/turso/issues/1763) @@ -222,12 +221,7 @@ mod tests { /// A test for verifying that index seek+scan works correctly for compound keys /// on indexes with various column orderings. pub fn index_scan_compound_key_fuzz() { - let (mut rng, seed) = if std::env::var("SEED").is_ok() { - let seed = std::env::var("SEED").unwrap().parse::().unwrap(); - (ChaCha8Rng::seed_from_u64(seed), seed) - } else { - rng_from_time() - }; + let (mut rng, seed) = rng_from_time_or_env(); let table_defs: [&str; 8] = [ "CREATE TABLE t (x, y, z, nonindexed_col, PRIMARY KEY (x, y, z))", "CREATE TABLE t (x, y, z, nonindexed_col, PRIMARY KEY (x desc, y, z))", @@ -382,11 +376,28 @@ mod tests { // Use a small limit to make the test complete faster let limit = 5; - // Generate WHERE clause string + /// Generate a comparison string (e.g. x > 10 AND x < 20) or just x > 10. + fn generate_comparison( + operator: &str, + col_name: &str, + col_val: i32, + rng: &mut ChaCha8Rng, + ) -> String { + if operator != "=" && rng.random_range(0..3) == 1 { + let val2 = rng.random_range(0..=3000); + let op2 = COMPARISONS[rng.random_range(0..COMPARISONS.len())]; + format!("{col_name} {operator} {col_val} AND {col_name} {op2} {val2}") + } else { + format!("{col_name} {operator} {col_val}") + } + } + + // Generate WHERE clause string. + // Sometimes add another inequality to the WHERE clause (e.g. x > 10 AND x < 20) to exercise range queries. let where_clause_components = vec![ - comp1.map(|x| format!("x {} {}", x, col_val_first.unwrap())), - comp2.map(|x| format!("y {} {}", x, col_val_second.unwrap())), - comp3.map(|x| format!("z {} {}", x, col_val_third.unwrap())), + comp1.map(|x| generate_comparison(x, "x", col_val_first.unwrap(), &mut rng)), + comp2.map(|x| generate_comparison(x, "y", col_val_second.unwrap(), &mut rng)), + comp3.map(|x| generate_comparison(x, "z", col_val_third.unwrap(), &mut rng)), ] .into_iter() .flatten() @@ -500,12 +511,7 @@ mod tests { #[test] pub fn collation_fuzz() { let _ = env_logger::try_init(); - let (mut rng, seed) = if std::env::var("SEED").is_ok() { - let seed = std::env::var("SEED").unwrap().parse::().unwrap(); - (ChaCha8Rng::seed_from_u64(seed), seed) - } else { - rng_from_time() - }; + let (mut rng, seed) = rng_from_time_or_env(); println!("collation_fuzz seed: {seed}"); // Build six table variants that assign BINARY/NOCASE/RTRIM across (a,b,c) @@ -598,15 +604,10 @@ mod tests { // Fuzz WHERE clauses with and without explicit COLLATE on a/b/c let columns = ["a", "b", "c"]; let collates = [None, Some("BINARY"), Some("NOCASE"), Some("RTRIM")]; - let (mut rng, seed) = if std::env::var("SEED").is_ok() { - let seed = std::env::var("SEED").unwrap().parse::().unwrap(); - (ChaCha8Rng::seed_from_u64(seed), seed) - } else { - rng_from_time() - }; + let (mut rng, seed) = rng_from_time_or_env(); println!("collation_fuzz seed: {seed}"); - const ITERS: usize = 3000; + const ITERS: usize = 1000; for iter in 0..ITERS { if iter % (ITERS / 100).max(1) == 0 { println!("collation_fuzz: iteration {}/{}", iter + 1, ITERS); @@ -650,13 +651,1207 @@ mod tests { } } + #[test] + #[allow(unused_assignments)] + pub fn fk_deferred_constraints_fuzz() { + let _ = env_logger::try_init(); + let (mut rng, seed) = rng_from_time_or_env(); + println!("fk_deferred_constraints_fuzz seed: {seed}"); + + const OUTER_ITERS: usize = 10; + const INNER_ITERS: usize = 100; + + for outer in 0..OUTER_ITERS { + println!("fk_deferred_constraints_fuzz {}/{}", outer + 1, OUTER_ITERS); + + let limbo_db = TempDatabase::new_empty(true); + let sqlite_db = TempDatabase::new_empty(true); + let limbo = limbo_db.connect_limbo(); + let sqlite = rusqlite::Connection::open(sqlite_db.path.clone()).unwrap(); + + let mut stmts: Vec = Vec::new(); + let mut log_and_exec = |sql: &str| { + stmts.push(sql.to_string()); + sql.to_string() + }; + // Enable FKs + let s = log_and_exec("PRAGMA foreign_keys=ON"); + limbo_exec_rows(&limbo_db, &limbo, &s); + sqlite.execute(&s, params![]).unwrap(); + + let get_constraint_type = |rng: &mut ChaCha8Rng| match rng.random_range(0..3) { + 0 => "INTEGER PRIMARY KEY", + 1 => "UNIQUE", + 2 => "PRIMARY KEY", + _ => unreachable!(), + }; + + // Mix of immediate and deferred FK constraints + let s = log_and_exec(&format!( + "CREATE TABLE parent(id {}, a INT, b INT)", + get_constraint_type(&mut rng) + )); + limbo_exec_rows(&limbo_db, &limbo, &s); + sqlite.execute(&s, params![]).unwrap(); + + // Child with DEFERRABLE INITIALLY DEFERRED FK + let s = log_and_exec(&format!( + "CREATE TABLE child_deferred(id {}, pid INT, x INT, \ + FOREIGN KEY(pid) REFERENCES parent(id) DEFERRABLE INITIALLY DEFERRED)", + get_constraint_type(&mut rng) + )); + limbo_exec_rows(&limbo_db, &limbo, &s); + sqlite.execute(&s, params![]).unwrap(); + + // Child with immediate FK (default) + let s = log_and_exec(&format!( + "CREATE TABLE child_immediate(id {}, pid INT, y INT, \ + FOREIGN KEY(pid) REFERENCES parent(id))", + get_constraint_type(&mut rng) + )); + limbo_exec_rows(&limbo_db, &limbo, &s); + sqlite.execute(&s, params![]).unwrap(); + + let composite_constraint = match rng.random_range(0..2) { + 0 => "PRIMARY KEY", + 1 => "UNIQUE", + _ => unreachable!(), + }; + // Composite key parent for deferred testing + let s = log_and_exec( + &format!("CREATE TABLE parent_comp(a INT NOT NULL, b INT NOT NULL, c INT, {composite_constraint}(a,b))"), + ); + limbo_exec_rows(&limbo_db, &limbo, &s); + sqlite.execute(&s, params![]).unwrap(); + + // Child with composite deferred FK + let s = log_and_exec( + "CREATE TABLE child_comp_deferred(id INTEGER PRIMARY KEY, ca INT, cb INT, z INT, \ + FOREIGN KEY(ca,cb) REFERENCES parent_comp(a,b) DEFERRABLE INITIALLY DEFERRED)", + ); + limbo_exec_rows(&limbo_db, &limbo, &s); + sqlite.execute(&s, params![]).unwrap(); + + // Seed initial data + let mut parent_ids = std::collections::HashSet::new(); + for _ in 0..rng.random_range(10..=25) { + let id = rng.random_range(1..=50) as i64; + if parent_ids.insert(id) { + let a = rng.random_range(-5..=25); + let b = rng.random_range(-5..=25); + let stmt = log_and_exec(&format!("INSERT INTO parent VALUES ({id}, {a}, {b})")); + limbo_exec_rows(&limbo_db, &limbo, &stmt); + sqlite.execute(&stmt, params![]).unwrap(); + } + } + + // Seed composite parent + let mut comp_pairs = std::collections::HashSet::new(); + for _ in 0..rng.random_range(3..=10) { + let a = rng.random_range(-3..=6) as i64; + let b = rng.random_range(-3..=6) as i64; + if comp_pairs.insert((a, b)) { + let c = rng.random_range(0..=20); + let stmt = + log_and_exec(&format!("INSERT INTO parent_comp VALUES ({a}, {b}, {c})")); + limbo_exec_rows(&limbo_db, &limbo, &stmt); + sqlite.execute(&stmt, params![]).unwrap(); + } + } + + // Transaction-based mutations with mix of deferred and immediate operations + let mut in_tx = false; + for tx_num in 0..INNER_ITERS { + // Decide if we're in a transaction + let start_a_transaction = rng.random_bool(0.7); + + if start_a_transaction && !in_tx { + in_tx = true; + let s = log_and_exec("BEGIN"); + let sres = sqlite.execute(&s, params![]); + let lres = limbo_exec_rows_fallible(&limbo_db, &limbo, &s); + match (&sres, &lres) { + (Ok(_), Ok(_)) | (Err(_), Err(_)) => {} + _ => { + eprintln!("BEGIN mismatch"); + eprintln!("sqlite result: {sres:?}"); + eprintln!("limbo result: {lres:?}"); + let file = std::fs::File::create("fk_deferred.sql").unwrap(); + for stmt in stmts.iter() { + writeln!(&file, "{stmt};").unwrap(); + } + eprintln!("Wrote `tests/fk_deferred.sql` for debugging"); + eprintln!("turso path: {}", limbo_db.path.display()); + eprintln!("sqlite path: {}", sqlite_db.path.display()); + panic!("BEGIN mismatch"); + } + } + } + + let op = rng.random_range(0..12); + let stmt = match op { + // Insert into child_deferred (can violate temporarily in transaction) + 0 => { + let id = rng.random_range(1000..=2000); + let pid = if rng.random_bool(0.6) { + *parent_ids.iter().choose(&mut rng).unwrap_or(&1) + } else { + // Non-existent parent - OK if deferred and fixed before commit + rng.random_range(200..=300) as i64 + }; + let x = rng.random_range(-10..=10); + format!("INSERT INTO child_deferred VALUES ({id}, {pid}, {x})") + } + // Insert into child_immediate (must satisfy FK immediately) + 1 => { + let id = rng.random_range(3000..=4000); + let pid = if rng.random_bool(0.8) { + *parent_ids.iter().choose(&mut rng).unwrap_or(&1) + } else { + rng.random_range(200..=300) as i64 + }; + let y = rng.random_range(-10..=10); + format!("INSERT INTO child_immediate VALUES ({id}, {pid}, {y})") + } + // Insert parent (may fix deferred violations) + 2 => { + let id = rng.random_range(1..=300); + let a = rng.random_range(-5..=25); + let b = rng.random_range(-5..=25); + parent_ids.insert(id as i64); + format!("INSERT INTO parent VALUES ({id}, {a}, {b})") + } + // Delete parent (may cause violations) + 3 => { + let id = if rng.random_bool(0.5) { + *parent_ids.iter().choose(&mut rng).unwrap_or(&1) + } else { + rng.random_range(1..=300) as i64 + }; + format!("DELETE FROM parent WHERE id={id}") + } + // Update parent PK + 4 => { + let old = rng.random_range(1..=300); + let new = rng.random_range(1..=350); + format!("UPDATE parent SET id={new} WHERE id={old}") + } + // Update child_deferred FK + 5 => { + let id = rng.random_range(1000..=2000); + let pid = if rng.random_bool(0.5) { + *parent_ids.iter().choose(&mut rng).unwrap_or(&1) + } else { + rng.random_range(200..=400) as i64 + }; + format!("UPDATE child_deferred SET pid={pid} WHERE id={id}") + } + // Insert into composite deferred child + 6 => { + let id = rng.random_range(5000..=6000); + let (ca, cb) = if rng.random_bool(0.6) { + *comp_pairs.iter().choose(&mut rng).unwrap_or(&(1, 1)) + } else { + // Non-existent composite parent + ( + rng.random_range(-5..=8) as i64, + rng.random_range(-5..=8) as i64, + ) + }; + let z = rng.random_range(0..=10); + format!( + "INSERT INTO child_comp_deferred VALUES ({id}, {ca}, {cb}, {z}) ON CONFLICT DO NOTHING" + ) + } + // Insert composite parent + 7 => { + let a = rng.random_range(-5..=8) as i64; + let b = rng.random_range(-5..=8) as i64; + let c = rng.random_range(0..=20); + comp_pairs.insert((a, b)); + format!("INSERT INTO parent_comp VALUES ({a}, {b}, {c})") + } + // UPSERT with deferred child + 8 => { + let id = rng.random_range(1000..=2000); + let pid = if rng.random_bool(0.5) { + *parent_ids.iter().choose(&mut rng).unwrap_or(&1) + } else { + rng.random_range(200..=400) as i64 + }; + let x = rng.random_range(-10..=10); + format!( + "INSERT INTO child_deferred VALUES ({id}, {pid}, {x}) + ON CONFLICT(id) DO UPDATE SET pid=excluded.pid, x=excluded.x" + ) + } + // Delete from child_deferred + 9 => { + let id = rng.random_range(1000..=2000); + format!("DELETE FROM child_deferred WHERE id={id}") + } + // Self-referential deferred insert (create temp violation then fix) + 10 if start_a_transaction => { + let id = rng.random_range(400..=500); + let pid = id + 1; // References non-existent yet + format!("INSERT INTO child_deferred VALUES ({id}, {pid}, 0)") + } + _ => { + // Default: simple parent insert + let id = rng.random_range(1..=300); + format!("INSERT INTO parent VALUES ({id}, 0, 0)") + } + }; + + let stmt = log_and_exec(&stmt); + let sres = sqlite.execute(&stmt, params![]); + let lres = limbo_exec_rows_fallible(&limbo_db, &limbo, &stmt); + + if !start_a_transaction && !in_tx { + match (sres, lres) { + (Ok(_), Ok(_)) | (Err(_), Err(_)) => {} + (s, l) => { + eprintln!("Non-tx mismatch: sqlite={s:?}, limbo={l:?}"); + eprintln!("Statement: {stmt}"); + eprintln!("Seed: {seed}, outer: {outer}, tx: {tx_num}, in_tx={in_tx}"); + let mut file = std::fs::File::create("fk_deferred.sql").unwrap(); + for stmt in stmts.iter() { + writeln!(file, "{stmt};").expect("write to file"); + } + eprintln!("turso path: {}", limbo_db.path.display()); + eprintln!("sqlite path: {}", sqlite_db.path.display()); + panic!("Non-transactional operation mismatch, file written to 'tests/fk_deferred.sql'"); + } + } + } + + // Randomly COMMIT or ROLLBACK some of the time + if in_tx && rng.random_bool(0.4) { + let commit = rng.random_bool(0.7); + let s = log_and_exec("COMMIT"); + + let sres = sqlite.execute(&s, params![]); + let lres = limbo_exec_rows_fallible(&limbo_db, &limbo, &s); + + match (sres, lres) { + (Ok(_), Ok(_)) => {} + (Err(_), Err(_)) => { + // Both failed - OK, deferred constraint violation at commit + if commit && in_tx { + in_tx = false; + let s = if commit { + log_and_exec("ROLLBACK") + } else { + log_and_exec("SELECT 1") // noop if we already rolled back + }; + + let sres = sqlite.execute(&s, params![]); + let lres = limbo_exec_rows_fallible(&limbo_db, &limbo, &s); + match (sres, lres) { + (Ok(_), Ok(_)) => {} + (s, l) => { + eprintln!("Post-failed-commit cleanup mismatch: sqlite={s:?}, limbo={l:?}"); + let mut file = + std::fs::File::create("fk_deferred.sql").unwrap(); + for stmt in stmts.iter() { + writeln!(file, "{stmt};").expect("write to file"); + } + eprintln!("turso path: {}", limbo_db.path.display()); + eprintln!("sqlite path: {}", sqlite_db.path.display()); + panic!("Post-failed-commit cleanup mismatch, file written to 'tests/fk_deferred.sql'"); + } + } + } + } + (s, l) => { + eprintln!("\n=== COMMIT/ROLLBACK mismatch ==="); + eprintln!("Operation: {s:?}"); + eprintln!("sqlite={s:?}, limbo={l:?}"); + eprintln!("Seed: {seed}, outer: {outer}, tx: {tx_num}, in_tx={in_tx}"); + eprintln!("--- Replay statements ({}) ---", stmts.len()); + let mut file = std::fs::File::create("fk_deferred.sql").unwrap(); + for stmt in stmts.iter() { + writeln!(file, "{stmt};").expect("write to file"); + } + eprintln!("Turso path: {}", limbo_db.path.display()); + eprintln!("Sqlite path: {}", sqlite_db.path.display()); + panic!( + "outcome mismatch, .sql file written to `tests/fk_deferred.sql`" + ); + } + } + in_tx = false; + } + } + // Print all statements + if std::env::var("VERBOSE").is_ok() { + println!("{}", stmts.join("\n")); + println!("--------- ITERATION COMPLETED ---------"); + } + } + } + + #[test] + pub fn fk_single_pk_mutation_fuzz() { + let _ = env_logger::try_init(); + let (mut rng, seed) = rng_from_time_or_env(); + println!("fk_single_pk_mutation_fuzz seed: {seed}"); + + const OUTER_ITERS: usize = 20; + const INNER_ITERS: usize = 100; + + for outer in 0..OUTER_ITERS { + println!("fk_single_pk_mutation_fuzz {}/{}", outer + 1, OUTER_ITERS); + + let limbo_db = TempDatabase::new_empty(true); + let sqlite_db = TempDatabase::new_empty(true); + let limbo = limbo_db.connect_limbo(); + let sqlite = rusqlite::Connection::open(sqlite_db.path.clone()).unwrap(); + + // Statement log for this iteration + let mut stmts: Vec = Vec::new(); + let mut log_and_exec = |sql: &str| { + stmts.push(sql.to_string()); + sql.to_string() + }; + + // Enable FKs in both engines + let s = log_and_exec("PRAGMA foreign_keys=ON"); + limbo_exec_rows(&limbo_db, &limbo, &s); + sqlite.execute(&s, params![]).unwrap(); + + let s = log_and_exec("CREATE TABLE p(id INTEGER PRIMARY KEY, a INT, b INT)"); + limbo_exec_rows(&limbo_db, &limbo, &s); + sqlite.execute(&s, params![]).unwrap(); + + let s = log_and_exec( + "CREATE TABLE c(id INTEGER PRIMARY KEY, x INT, y INT, FOREIGN KEY(x) REFERENCES p(id))", + ); + limbo_exec_rows(&limbo_db, &limbo, &s); + sqlite.execute(&s, params![]).unwrap(); + + // Seed parent + let n_par = rng.random_range(5..=40); + let mut used_ids = std::collections::HashSet::new(); + for _ in 0..n_par { + let mut id; + loop { + id = rng.random_range(1..=200) as i64; + if used_ids.insert(id) { + break; + } + } + let a = rng.random_range(-5..=25); + let b = rng.random_range(-5..=25); + let stmt = log_and_exec(&format!("INSERT INTO p VALUES ({id}, {a}, {b})")); + let l_res = limbo_exec_rows_fallible(&limbo_db, &limbo, &stmt); + let s_res = sqlite.execute(&stmt, params![]); + match (l_res, s_res) { + (Ok(_), Ok(_)) | (Err(_), Err(_)) => {} + _ => { + panic!("Seeding parent insert mismatch"); + } + } + } + + // Seed child + let n_child = rng.random_range(5..=80); + for i in 0..n_child { + let id = 1000 + i as i64; + let x = if rng.random_bool(0.8) { + *used_ids.iter().choose(&mut rng).unwrap() + } else { + rng.random_range(1..=220) as i64 + }; + let y = rng.random_range(-10..=10); + let stmt = log_and_exec(&format!("INSERT INTO c VALUES ({id}, {x}, {y})")); + match ( + sqlite.execute(&stmt, params![]), + limbo_exec_rows_fallible(&limbo_db, &limbo, &stmt), + ) { + (Ok(_), Ok(_)) => {} + (Err(_), Err(_)) => {} + (x, y) => { + eprintln!("\n=== FK fuzz failure (seeding mismatch) ==="); + eprintln!("seed: {seed}, outer: {}", outer + 1); + eprintln!("sqlite: {x:?}, limbo: {y:?}"); + eprintln!("last stmt: {stmt}"); + eprintln!("--- replay statements ({}) ---", stmts.len()); + for (i, s) in stmts.iter().enumerate() { + eprintln!("{:04}: {};", i + 1, s); + } + panic!("Seeding child insert mismatch"); + } + } + } + + // Mutations + for _ in 0..INNER_ITERS { + let action = rng.random_range(0..8); + let stmt = match action { + // Parent INSERT + 0 => { + let mut id; + let mut tries = 0; + loop { + id = rng.random_range(1..=250) as i64; + if !used_ids.contains(&id) || tries > 10 { + break; + } + tries += 1; + } + let a = rng.random_range(-5..=25); + let b = rng.random_range(-5..=25); + format!("INSERT INTO p VALUES({id}, {a}, {b})") + } + // Parent UPDATE + 1 => { + if rng.random_bool(0.5) { + let old = rng.random_range(1..=250); + let new_id = rng.random_range(1..=260); + format!("UPDATE p SET id={new_id} WHERE id={old}") + } else { + let a = rng.random_range(-5..=25); + let b = rng.random_range(-5..=25); + let tgt = rng.random_range(1..=260); + format!("UPDATE p SET a={a}, b={b} WHERE id={tgt}") + } + } + // Parent DELETE + 2 => { + let del_id = rng.random_range(1..=260); + format!("DELETE FROM p WHERE id={del_id}") + } + // Child INSERT + 3 => { + let id = rng.random_range(1000..=2000); + let x = if rng.random_bool(0.7) { + if let Some(p) = used_ids.iter().choose(&mut rng) { + *p + } else { + rng.random_range(1..=260) as i64 + } + } else { + rng.random_range(1..=260) as i64 + }; + let y = rng.random_range(-10..=10); + format!("INSERT INTO c VALUES({id}, {x}, {y})") + } + // Child UPDATE + 4 => { + let pick = rng.random_range(1000..=2000); + if rng.random_bool(0.6) { + let new_x = if rng.random_bool(0.7) { + if let Some(p) = used_ids.iter().choose(&mut rng) { + *p + } else { + rng.random_range(1..=260) as i64 + } + } else { + rng.random_range(1..=260) as i64 + }; + format!("UPDATE c SET x={new_x} WHERE id={pick}") + } else { + let new_y = rng.random_range(-10..=10); + format!("UPDATE c SET y={new_y} WHERE id={pick}") + } + } + 5 => { + // UPSERT parent + let pick = rng.random_range(1..=250); + if rng.random_bool(0.5) { + let a = rng.random_range(-5..=25); + let b = rng.random_range(-5..=25); + format!( + "INSERT INTO p VALUES({pick}, {a}, {b}) ON CONFLICT(id) DO UPDATE SET a=excluded.a, b=excluded.b" + ) + } else { + let a = rng.random_range(-5..=25); + let b = rng.random_range(-5..=25); + format!( + "INSERT INTO p VALUES({pick}, {a}, {b}) \ + ON CONFLICT(id) DO NOTHING" + ) + } + } + 6 => { + // UPSERT child + let pick = rng.random_range(1000..=2000); + if rng.random_bool(0.5) { + let x = if rng.random_bool(0.7) { + if let Some(p) = used_ids.iter().choose(&mut rng) { + *p + } else { + rng.random_range(1..=260) as i64 + } + } else { + rng.random_range(1..=260) as i64 + }; + format!( + "INSERT INTO c VALUES({pick}, {x}, 0) ON CONFLICT(id) DO UPDATE SET x=excluded.x" + ) + } else { + let x = if rng.random_bool(0.7) { + if let Some(p) = used_ids.iter().choose(&mut rng) { + *p + } else { + rng.random_range(1..=260) as i64 + } + } else { + rng.random_range(1..=260) as i64 + }; + format!( + "INSERT INTO c VALUES({pick}, {x}, 0) ON CONFLICT(id) DO NOTHING" + ) + } + } + // Child DELETE + _ => { + let pick = rng.random_range(1000..=2000); + format!("DELETE FROM c WHERE id={pick}") + } + }; + + let stmt = log_and_exec(&stmt); + + let sres = sqlite.execute(&stmt, params![]); + let lres = limbo_exec_rows_fallible(&limbo_db, &limbo, &stmt); + + match (sres, lres) { + (Ok(_), Ok(_)) => { + if stmt.starts_with("INSERT INTO p VALUES(") { + if let Some(tok) = stmt.split_whitespace().nth(4) { + if let Some(idtok) = tok.split(['(', ',']).nth(1) { + if let Ok(idnum) = idtok.parse::() { + used_ids.insert(idnum); + } + } + } + } + let sp = sqlite_exec_rows(&sqlite, "SELECT id,a,b FROM p ORDER BY id"); + let sc = sqlite_exec_rows(&sqlite, "SELECT id,x,y FROM c ORDER BY id"); + let lp = + limbo_exec_rows(&limbo_db, &limbo, "SELECT id,a,b FROM p ORDER BY id"); + let lc = + limbo_exec_rows(&limbo_db, &limbo, "SELECT id,x,y FROM c ORDER BY id"); + + if sp != lp || sc != lc { + eprintln!("\n=== FK fuzz failure (state mismatch) ==="); + eprintln!("seed: {seed}, outer: {}", outer + 1); + eprintln!("last stmt: {stmt}"); + eprintln!("sqlite p: {sp:?}\nsqlite c: {sc:?}"); + eprintln!("limbo p: {lp:?}\nlimbo c: {lc:?}"); + eprintln!("--- replay statements ({}) ---", stmts.len()); + for (i, s) in stmts.iter().enumerate() { + eprintln!("{:04}: {};", i + 1, s); + } + panic!("State mismatch"); + } + } + (Err(_), Err(_)) => { /* parity OK */ } + (ok_sqlite, ok_limbo) => { + eprintln!("\n=== FK fuzz failure (outcome mismatch) ==="); + eprintln!("seed: {seed}, outer: {}", outer + 1); + eprintln!("sqlite: {ok_sqlite:?}, limbo: {ok_limbo:?}"); + eprintln!("last stmt: {stmt}"); + // dump final states to help decide who is right + let sp = sqlite_exec_rows(&sqlite, "SELECT id,a,b FROM p ORDER BY id"); + let sc = sqlite_exec_rows(&sqlite, "SELECT id,x,y FROM c ORDER BY id"); + let lp = + limbo_exec_rows(&limbo_db, &limbo, "SELECT id,a,b FROM p ORDER BY id"); + let lc = + limbo_exec_rows(&limbo_db, &limbo, "SELECT id,x,y FROM c ORDER BY id"); + eprintln!("sqlite p: {sp:?}\nsqlite c: {sc:?}"); + eprintln!("turso p: {lp:?}\nturso c: {lc:?}"); + eprintln!( + "--- writing ({}) statements to fk_fuzz_statements.sql ---", + stmts.len() + ); + let mut file = std::fs::File::create("fk_fuzz_statements.sql").unwrap(); + for s in stmts.iter() { + let _ = file.write_fmt(format_args!("{s};\n")); + } + file.flush().unwrap(); + panic!("DML outcome mismatch, statements written to tests/fk_fuzz_statements.sql"); + } + } + } + } + } + + #[test] + pub fn fk_edgecases_fuzzing() { + let _ = env_logger::try_init(); + let (mut rng, seed) = rng_from_time_or_env(); + println!("fk_edgecases_minifuzz seed: {seed}"); + + const OUTER_ITERS: usize = 20; + const INNER_ITERS: usize = 100; + + fn assert_parity( + seed: u64, + stmts: &[String], + sqlite_res: rusqlite::Result, + limbo_res: Result>, turso_core::LimboError>, + last_stmt: &str, + tag: &str, + ) { + match (sqlite_res.is_ok(), limbo_res.is_ok()) { + (true, true) | (false, false) => (), + _ => { + eprintln!("\n=== {tag} mismatch ==="); + eprintln!("seed: {seed}"); + eprintln!("sqlite: {sqlite_res:?}, limbo: {limbo_res:?}"); + eprintln!("stmt: {last_stmt}"); + eprintln!("--- replay statements ({}) ---", stmts.len()); + for (i, s) in stmts.iter().enumerate() { + eprintln!("{:04}: {};", i + 1, s); + } + panic!("{tag}: engines disagree"); + } + } + } + + // parent rowid, child textified integers -> MustBeInt coercion path + for outer in 0..OUTER_ITERS { + let limbo_db = TempDatabase::new_empty(true); + let sqlite_db = TempDatabase::new_empty(true); + let limbo = limbo_db.connect_limbo(); + let sqlite = rusqlite::Connection::open(sqlite_db.path.clone()).unwrap(); + + let mut stmts: Vec = Vec::new(); + let log = |s: &str, stmts: &mut Vec| { + stmts.push(s.to_string()); + s.to_string() + }; + + for s in [ + "PRAGMA foreign_keys=ON", + "CREATE TABLE p(id INTEGER PRIMARY KEY, a INT)", + "CREATE TABLE c(id INTEGER PRIMARY KEY, x INT, FOREIGN KEY(x) REFERENCES p(id))", + ] { + let s = log(s, &mut stmts); + let _ = limbo_exec_rows_fallible(&limbo_db, &limbo, &s); + let _ = sqlite.execute(&s, params![]); + } + + // Seed a few parents + for _ in 0..rng.random_range(2..=5) { + let id = rng.random_range(1..=15); + let a = rng.random_range(-5..=5); + let s = log(&format!("INSERT INTO p VALUES({id},{a})"), &mut stmts); + let _ = limbo_exec_rows_fallible(&limbo_db, &limbo, &s); + let _ = sqlite.execute(&s, params![]); + } + + // try random child inserts with weird text-ints + for i in 0..INNER_ITERS { + let id = 1000 + i as i64; + let raw = if rng.random_bool(0.7) { + 1 + rng.random_range(0..=15) + } else { + rng.random_range(100..=200) as i64 + }; + + // Randomly decorate the integer as text with spacing/zeros/plus + let pad_left_zeros = rng.random_range(0..=2); + let spaces_left = rng.random_range(0..=2); + let spaces_right = rng.random_range(0..=2); + let plus = if rng.random_bool(0.3) { "+" } else { "" }; + let txt_num = format!( + "{plus}{:0width$}", + raw, + width = (1 + pad_left_zeros) as usize + ); + let txt = format!( + "'{}{}{}'", + " ".repeat(spaces_left), + txt_num, + " ".repeat(spaces_right) + ); + + let stmt = log(&format!("INSERT INTO c VALUES({id}, {txt})"), &mut stmts); + let sres = sqlite.execute(&stmt, params![]); + let lres = limbo_exec_rows_fallible(&limbo_db, &limbo, &stmt); + assert_parity(seed, &stmts, sres, lres, &stmt, "A: rowid-coercion"); + } + println!("A {}/{} ok", outer + 1, OUTER_ITERS); + } + + // slf-referential rowid FK + for outer in 0..OUTER_ITERS { + let limbo_db = TempDatabase::new_empty(true); + let sqlite_db = TempDatabase::new_empty(true); + let limbo = limbo_db.connect_limbo(); + let sqlite = rusqlite::Connection::open(sqlite_db.path.clone()).unwrap(); + + let mut stmts: Vec = Vec::new(); + let log = |s: &str, stmts: &mut Vec| { + stmts.push(s.to_string()); + s.to_string() + }; + + for s in [ + "PRAGMA foreign_keys=ON", + "CREATE TABLE t(id INTEGER PRIMARY KEY, rid REFERENCES t(id))", + ] { + let s = log(s, &mut stmts); + limbo_exec_rows(&limbo_db, &limbo, &s); + sqlite.execute(&s, params![]).unwrap(); + } + + // Self-match should succeed for many ids + for _ in 0..INNER_ITERS { + let id = rng.random_range(1..=500); + let stmt = log( + &format!("INSERT INTO t(id,rid) VALUES({id},{id})"), + &mut stmts, + ); + let sres = sqlite.execute(&stmt, params![]); + let lres = limbo_exec_rows_fallible(&limbo_db, &limbo, &stmt); + assert_parity(seed, &stmts, sres, lres, &stmt, "B1: self-row ok"); + } + + // Mismatch (rid != id) should fail (unless the referenced id already exists). + for _ in 0..rng.random_range(1..=10) { + let id = rng.random_range(1..=20); + let s = log( + &format!("INSERT INTO t(id,rid) VALUES({id},{id})"), + &mut stmts, + ); + let s_res = sqlite.execute(&s, params![]); + let turso_rs = limbo_exec_rows_fallible(&limbo_db, &limbo, &s); + match (s_res.is_ok(), turso_rs.is_ok()) { + (true, true) | (false, false) => {} + _ => panic!("Seeding self-ref failed differently"), + } + } + + for _ in 0..INNER_ITERS { + let id = rng.random_range(600..=900); + let ref_ = rng.random_range(1..=25); + let stmt = log( + &format!("INSERT INTO t(id,rid) VALUES({id},{ref_})"), + &mut stmts, + ); + let sres = sqlite.execute(&stmt, params![]); + let lres = limbo_exec_rows_fallible(&limbo_db, &limbo, &stmt); + assert_parity(seed, &stmts, sres, lres, &stmt, "B2: self-row mismatch"); + } + println!("B {}/{} ok", outer + 1, OUTER_ITERS); + } + + // self-referential UNIQUE(u,v) parent (fast-path for composite) + for outer in 0..OUTER_ITERS { + let limbo_db = TempDatabase::new_empty(true); + let sqlite_db = TempDatabase::new_empty(true); + let limbo = limbo_db.connect_limbo(); + let sqlite = rusqlite::Connection::open(sqlite_db.path.clone()).unwrap(); + + let mut stmts: Vec = Vec::new(); + let log = |s: &str, stmts: &mut Vec| { + stmts.push(s.to_string()); + s.to_string() + }; + + let s = log("PRAGMA foreign_keys=ON", &mut stmts); + limbo_exec_rows(&limbo_db, &limbo, &s); + sqlite.execute(&s, params![]).unwrap(); + + // Variant the schema a bit: TEXT/TEXT, NUMERIC/TEXT, etc. + let decls = [ + ("TEXT", "TEXT"), + ("TEXT", "NUMERIC"), + ("NUMERIC", "TEXT"), + ("TEXT", "BLOB"), + ]; + let (tu, tv) = decls[rng.random_range(0..decls.len())]; + + let s = log( + &format!( + "CREATE TABLE sr(u {tu}, v {tv}, cu {tu}, cv {tv}, UNIQUE(u,v), \ + FOREIGN KEY(cu,cv) REFERENCES sr(u,v))" + ), + &mut stmts, + ); + limbo_exec_rows(&limbo_db, &limbo, &s); + sqlite.execute(&s, params![]).unwrap(); + + // Self-matching composite rows should succeed + for _ in 0..INNER_ITERS { + // Random small tokens, possibly padded + let u = format!("U{}", rng.random_range(0..50)); + let v = format!("V{}", rng.random_range(0..50)); + let mut cu = u.clone(); + let mut cv = v.clone(); + + // occasionally wrap child refs as blobs/text to stress coercion on parent index + if rng.random_bool(0.2) { + // child cv as hex blob of ascii v + let hex: String = v.bytes().map(|b| format!("{b:02X}")).collect(); + cv = format!("x'{hex}'"); + } else { + cu = format!("'{cu}'"); + cv = format!("'{cv}'"); + } + + let stmt = log( + &format!("INSERT INTO sr(u,v,cu,cv) VALUES('{u}','{v}',{cu},{cv})"), + &mut stmts, + ); + let sres = sqlite.execute(&stmt, params![]); + let lres = limbo_exec_rows_fallible(&limbo_db, &limbo, &stmt); + assert_parity(seed, &stmts, sres, lres, &stmt, "C1: self-UNIQUE ok"); + } + + // Non-self-match likely fails unless earlier rows happen to satisfy (u,v) + for _ in 0..INNER_ITERS { + let u = format!("U{}", rng.random_range(60..100)); + let v = format!("V{}", rng.random_range(60..100)); + let cu = format!("'U{}'", rng.random_range(0..40)); + let cv = format!("'{}{}'", "V", rng.random_range(0..40)); + let stmt = log( + &format!("INSERT INTO sr(u,v,cu,cv) VALUES('{u}','{v}',{cu},{cv})"), + &mut stmts, + ); + let sres = sqlite.execute(&stmt, params![]); + let lres = limbo_exec_rows_fallible(&limbo_db, &limbo, &stmt); + assert_parity(seed, &stmts, sres, lres, &stmt, "C2: self-UNIQUE mismatch"); + } + println!("C {}/{} ok", outer + 1, OUTER_ITERS); + } + + // parent TEXT UNIQUE(u,v), child types differ; rely on parent-index affinities + for outer in 0..OUTER_ITERS { + let limbo_db = TempDatabase::new_empty(true); + let sqlite_db = TempDatabase::new_empty(true); + let limbo = limbo_db.connect_limbo(); + let sqlite = rusqlite::Connection::open(sqlite_db.path.clone()).unwrap(); + + let mut stmts: Vec = Vec::new(); + let log = |s: &str, stmts: &mut Vec| { + stmts.push(s.to_string()); + s.to_string() + }; + + for s in [ + "PRAGMA foreign_keys=ON", + "CREATE TABLE parent(u TEXT, v TEXT, UNIQUE(u,v))", + "CREATE TABLE child(id INTEGER PRIMARY KEY, cu INT, cv BLOB, \ + FOREIGN KEY(cu,cv) REFERENCES parent(u,v))", + ] { + let s = log(s, &mut stmts); + limbo_exec_rows(&limbo_db, &limbo, &s); + sqlite.execute(&s, params![]).unwrap(); + } + + for _ in 0..rng.random_range(3..=8) { + let u_raw = rng.random_range(0..=9); + let v_raw = rng.random_range(0..=9); + let u = if rng.random_bool(0.4) { + format!("+{u_raw}") + } else { + format!("{u_raw}") + }; + let v = if rng.random_bool(0.5) { + format!("{v_raw:02}",) + } else { + format!("{v_raw}") + }; + let s = log( + &format!("INSERT INTO parent VALUES('{u}','{v}')"), + &mut stmts, + ); + let l_res = limbo_exec_rows_fallible(&limbo_db, &limbo, &s); + let s_res = sqlite.execute(&s, params![]); + match (s_res, l_res) { + (Ok(_), Ok(_)) | (Err(_), Err(_)) => {} + (x, y) => { + panic!("Parent seeding mismatch: sqlite {x:?}, limbo {y:?}"); + } + } + } + + for i in 0..INNER_ITERS { + let id = i as i64 + 1; + let u_txt = if rng.random_bool(0.7) { + format!("+{}", rng.random_range(0..=9)) + } else { + format!("{}", rng.random_range(0..=9)) + }; + let v_txt = if rng.random_bool(0.5) { + format!("{:02}", rng.random_range(0..=9)) + } else { + format!("{}", rng.random_range(0..=9)) + }; + + // produce child literals that *look different* but should match under TEXT affinity + // cu uses integer-ish form of u; cv uses blob of ASCII v or quoted v randomly. + let cu = if let Ok(u_int) = u_txt.trim().trim_start_matches('+').parse::() { + if rng.random_bool(0.5) { + format!("{u_int}",) + } else { + format!("'{u_txt}'") + } + } else { + format!("'{u_txt}'") + }; + let cv = if rng.random_bool(0.6) { + let hex: String = v_txt + .as_bytes() + .iter() + .map(|b| format!("{b:02X}")) + .collect(); + format!("x'{hex}'") + } else { + format!("'{v_txt}'") + }; + + let stmt = log( + &format!("INSERT INTO child VALUES({id}, {cu}, {cv})"), + &mut stmts, + ); + let sres = sqlite.execute(&stmt, params![]); + let lres = limbo_exec_rows_fallible(&limbo_db, &limbo, &stmt); + assert_parity(seed, &stmts, sres, lres, &stmt, "D1: parent-index affinity"); + } + + for i in 0..(INNER_ITERS / 3) { + let id = 10_000 + i as i64; + let cu = rng.random_range(0..=9); + let miss = rng.random_range(10..=19); + let stmt = log( + &format!("INSERT INTO child VALUES({id}, {cu}, x'{miss:02X}')"), + &mut stmts, + ); + let sres = sqlite.execute(&stmt, params![]); + let lres = limbo_exec_rows_fallible(&limbo_db, &limbo, &stmt); + assert_parity(seed, &stmts, sres, lres, &stmt, "D2: parent-index negative"); + } + println!("D {}/{} ok", outer + 1, OUTER_ITERS); + } + + println!("fk_edgecases_minifuzz complete (seed {seed})"); + } + + #[test] + pub fn fk_composite_pk_mutation_fuzz() { + let _ = env_logger::try_init(); + let (mut rng, seed) = rng_from_time_or_env(); + println!("fk_composite_pk_mutation_fuzz seed: {seed}"); + + const OUTER_ITERS: usize = 10; + const INNER_ITERS: usize = 100; + + for outer in 0..OUTER_ITERS { + println!( + "fk_composite_pk_mutation_fuzz {}/{}", + outer + 1, + OUTER_ITERS + ); + + let limbo_db = TempDatabase::new_empty(true); + let sqlite_db = TempDatabase::new_empty(true); + let limbo = limbo_db.connect_limbo(); + let sqlite = rusqlite::Connection::open(sqlite_db.path.clone()).unwrap(); + + let mut stmts: Vec = Vec::new(); + let mut log_and_exec = |sql: &str| { + stmts.push(sql.to_string()); + sql.to_string() + }; + + // Enable FKs in both engines + let _ = log_and_exec("PRAGMA foreign_keys=ON"); + limbo_exec_rows(&limbo_db, &limbo, "PRAGMA foreign_keys=ON"); + sqlite.execute("PRAGMA foreign_keys=ON", params![]).unwrap(); + + // Parent PK is composite (a,b). Child references (x,y) -> (a,b). + let s = log_and_exec( + "CREATE TABLE p(a INT NOT NULL, b INT NOT NULL, v INT, PRIMARY KEY(a,b))", + ); + limbo_exec_rows(&limbo_db, &limbo, &s); + sqlite.execute(&s, params![]).unwrap(); + + let s = log_and_exec( + "CREATE TABLE c(id INTEGER PRIMARY KEY, x INT, y INT, w INT, \ + FOREIGN KEY(x,y) REFERENCES p(a,b))", + ); + limbo_exec_rows(&limbo_db, &limbo, &s); + sqlite.execute(&s, params![]).unwrap(); + + // Seed parent: small grid of (a,b) + let mut pairs: Vec<(i64, i64)> = Vec::new(); + for _ in 0..rng.random_range(5..=25) { + let a = rng.random_range(-3..=6); + let b = rng.random_range(-3..=6); + if !pairs.contains(&(a, b)) { + pairs.push((a, b)); + let v = rng.random_range(0..=20); + let stmt = log_and_exec(&format!("INSERT INTO p VALUES({a},{b},{v})")); + limbo_exec_rows(&limbo_db, &limbo, &stmt); + sqlite.execute(&stmt, params![]).unwrap(); + } + } + + // Seed child rows, 70% chance to reference existing (a,b) + for i in 0..rng.random_range(5..=60) { + let id = 5000 + i as i64; + let (x, y) = if rng.random_bool(0.7) { + *pairs.choose(&mut rng).unwrap_or(&(0, 0)) + } else { + (rng.random_range(-4..=7), rng.random_range(-4..=7)) + }; + let w = rng.random_range(-10..=10); + let stmt = log_and_exec(&format!("INSERT INTO c VALUES({id}, {x}, {y}, {w})")); + let _ = sqlite.execute(&stmt, params![]); + let _ = limbo_exec_rows_fallible(&limbo_db, &limbo, &stmt); + } + + for _ in 0..INNER_ITERS { + let op = rng.random_range(0..7); + let stmt = log_and_exec(&match op { + // INSERT parent + 0 => { + let a = rng.random_range(-4..=8); + let b = rng.random_range(-4..=8); + let v = rng.random_range(0..=20); + format!("INSERT INTO p VALUES({a},{b},{v})") + } + // UPDATE parent composite key (a,b) + 1 => { + let a_old = rng.random_range(-4..=8); + let b_old = rng.random_range(-4..=8); + let a_new = rng.random_range(-4..=8); + let b_new = rng.random_range(-4..=8); + format!("UPDATE p SET a={a_new}, b={b_new} WHERE a={a_old} AND b={b_old}") + } + // DELETE parent + 2 => { + let a = rng.random_range(-4..=8); + let b = rng.random_range(-4..=8); + format!("DELETE FROM p WHERE a={a} AND b={b}") + } + // INSERT child + 3 => { + let id = rng.random_range(5000..=7000); + let (x, y) = if rng.random_bool(0.7) { + *pairs.choose(&mut rng).unwrap_or(&(0, 0)) + } else { + (rng.random_range(-4..=8), rng.random_range(-4..=8)) + }; + let w = rng.random_range(-10..=10); + format!("INSERT INTO c VALUES({id},{x},{y},{w})") + } + // UPDATE child FK columns (x,y) + 4 => { + let id = rng.random_range(5000..=7000); + let (x, y) = if rng.random_bool(0.7) { + *pairs.choose(&mut rng).unwrap_or(&(0, 0)) + } else { + (rng.random_range(-4..=8), rng.random_range(-4..=8)) + }; + format!("UPDATE c SET x={x}, y={y} WHERE id={id}") + } + 5 => { + // UPSERT parent + if rng.random_bool(0.5) { + let a = rng.random_range(-4..=8); + let b = rng.random_range(-4..=8); + let v = rng.random_range(0..=20); + format!( + "INSERT INTO p VALUES({a},{b},{v}) ON CONFLICT(a,b) DO UPDATE SET v=excluded.v" + ) + } else { + let a = rng.random_range(-4..=8); + let b = rng.random_range(-4..=8); + format!( + "INSERT INTO p VALUES({a},{b},{}) ON CONFLICT(a,b) DO NOTHING", + rng.random_range(0..=20) + ) + } + } + 6 => { + // UPSERT child + let id = rng.random_range(5000..=7000); + let (x, y) = if rng.random_bool(0.7) { + *pairs.choose(&mut rng).unwrap_or(&(0, 0)) + } else { + (rng.random_range(-4..=8), rng.random_range(-4..=8)) + }; + format!( + "INSERT INTO c VALUES({id},{x},{y},{}) ON CONFLICT(id) DO UPDATE SET x=excluded.x, y=excluded.y", + rng.random_range(-10..=10) + ) + } + // DELETE child + _ => { + let id = rng.random_range(5000..=7000); + format!("DELETE FROM c WHERE id={id}") + } + }); + + let sres = sqlite.execute(&stmt, params![]); + let lres = limbo_exec_rows_fallible(&limbo_db, &limbo, &stmt); + + match (sres, lres) { + (Ok(_), Ok(_)) => { + // Compare canonical states + let sp = sqlite_exec_rows(&sqlite, "SELECT a,b,v FROM p ORDER BY a,b,v"); + let sc = sqlite_exec_rows(&sqlite, "SELECT id,x,y,w FROM c ORDER BY id"); + let lp = limbo_exec_rows( + &limbo_db, + &limbo, + "SELECT a,b,v FROM p ORDER BY a,b,v", + ); + let lc = limbo_exec_rows( + &limbo_db, + &limbo, + "SELECT id,x,y,w FROM c ORDER BY id", + ); + assert_eq!(sp, lp, "seed {seed}, stmt {stmt}"); + assert_eq!(sc, lc, "seed {seed}, stmt {stmt}"); + } + (Err(_), Err(_)) => { /* both errored -> parity OK */ } + (ok_s, ok_l) => { + eprintln!( + "Mismatch sqlite={ok_s:?}, limbo={ok_l:?}, stmt={stmt}, seed={seed}" + ); + let sp = sqlite_exec_rows(&sqlite, "SELECT a,b,v FROM p ORDER BY a,b,v"); + let sc = sqlite_exec_rows(&sqlite, "SELECT id,x,y,w FROM c ORDER BY id"); + let lp = limbo_exec_rows( + &limbo_db, + &limbo, + "SELECT a,b,v FROM p ORDER BY a,b,v", + ); + let lc = limbo_exec_rows( + &limbo_db, + &limbo, + "SELECT id,x,y,w FROM c ORDER BY id", + ); + eprintln!( + "sqlite p={sp:?}\nsqlite c={sc:?}\nlimbo p={lp:?}\nlimbo c={lc:?}" + ); + let mut file = + std::fs::File::create("fk_composite_fuzz_statements.sql").unwrap(); + for s in stmts.iter() { + let _ = writeln!(&file, "{s};"); + } + file.flush().unwrap(); + panic!("DML outcome mismatch, sql file written to tests/fk_composite_fuzz_statements.sql"); + } + } + } + } + } + #[test] /// Create a table with a random number of columns and indexes, and then randomly update or delete rows from the table. /// Verify that the results are the same for SQLite and Turso. pub fn table_index_mutation_fuzz() { let _ = env_logger::try_init(); - let (mut rng, seed) = rng_from_time(); - println!("index_scan_single_key_mutation_fuzz seed: {seed}"); + let (mut rng, seed) = rng_from_time_or_env(); + println!("table_index_mutation_fuzz seed: {seed}"); const OUTER_ITERATIONS: usize = 100; for i in 0..OUTER_ITERATIONS { @@ -678,9 +1873,33 @@ mod tests { let table_def = format!("CREATE TABLE t ({table_def})"); let num_indexes = rng.random_range(0..=num_cols); - let indexes = (0..num_indexes) - .map(|i| format!("CREATE INDEX idx_{i} ON t(c{i})")) - .collect::>(); + let mut indexes = Vec::new(); + for i in 0..num_indexes { + // Decide if this should be a single-column or multi-column index + let is_multi_column = rng.random_bool(0.5) && num_cols > 1; + + if is_multi_column { + // Create a multi-column index with 2-3 columns + let num_index_cols = rng.random_range(2..=3.min(num_cols)); + let mut index_cols = Vec::new(); + let mut available_cols: Vec = (0..num_cols).collect(); + + for _ in 0..num_index_cols { + let idx = rng.random_range(0..available_cols.len()); + let col = available_cols.remove(idx); + index_cols.push(format!("c{col}")); + } + + indexes.push(format!( + "CREATE INDEX idx_{i} ON t({})", + index_cols.join(", ") + )); + } else { + // Single-column index + let col = rng.random_range(0..num_cols); + indexes.push(format!("CREATE INDEX idx_{i} ON t(c{col})")); + } + } // Create tables and indexes in both databases let limbo_conn = limbo_db.connect_limbo(); @@ -764,8 +1983,22 @@ mod tests { }; let query = if do_update { - let new_y = rng.random_range(0..1000); - format!("UPDATE t SET c{affected_col} = {new_y} {where_clause}") + let num_updates = rng.random_range(1..=num_cols); + let mut values = Vec::new(); + for _ in 0..num_updates { + let new_y = if rng.random_bool(0.5) { + // Update to a constant value + rng.random_range(0..1000).to_string() + } else { + let source_col = rng.random_range(0..num_cols); + // Update to a value that is a function of the another column + let operator = *["+", "-"].choose(&mut rng).unwrap(); + let amount = rng.random_range(0..1000); + format!("c{source_col} {operator} {amount}") + }; + values.push(format!("c{affected_col} = {new_y}")); + } + format!("UPDATE t SET {} {where_clause}", values.join(", ")) } else { format!("DELETE FROM t {where_clause}") }; @@ -803,6 +2036,19 @@ mod tests { "Different results after mutation! limbo: {limbo_rows:?}, sqlite: {sqlite_rows:?}, seed: {seed}, query: {query}", ); + // Run integrity check on limbo db using rusqlite + if let Err(e) = rusqlite_integrity_check(&limbo_db.path) { + println!("{table_def};"); + for t in indexes.iter() { + println!("{t};"); + } + for t in dml_statements.iter() { + println!("{t};"); + } + println!("{query};"); + panic!("seed: {seed}, error: {e}"); + } + if sqlite_rows.is_empty() { break; } @@ -825,12 +2071,7 @@ mod tests { const OUTER_ITERS: usize = 5; const INNER_ITERS: usize = 500; - let (mut rng, seed) = if std::env::var("SEED").is_ok() { - let seed = std::env::var("SEED").unwrap().parse::().unwrap(); - (ChaCha8Rng::seed_from_u64(seed), seed) - } else { - rng_from_time() - }; + let (mut rng, seed) = rng_from_time_or_env(); println!("partial_index_mutation_and_upsert_fuzz seed: {seed}"); // we want to hit unique constraints fairly often so limit the insert values const K_POOL: [&str; 35] = [ @@ -986,6 +2227,22 @@ mod tests { } for _ in 0..INNER_ITERS { + // Randomly inject transaction statements -- we don't care if they are legal, + // we just care that tursodb/sqlite behave the same way. + if rng.random_bool(0.15) { + let tx_stmt = match rng.random_range(0..4) { + 0 => "BEGIN", + 1 => "BEGIN IMMEDIATE", + 2 => "COMMIT", + 3 => "ROLLBACK", + _ => unreachable!(), + }; + println!("{tx_stmt};"); + let sqlite_res = sqlite.execute(tx_stmt, rusqlite::params![]); + let limbo_res = limbo_exec_rows_fallible(&limbo_db, &limbo_conn, tx_stmt); + // Both should succeed or both should fail + assert!(sqlite_res.is_ok() == limbo_res.is_ok()); + } let action = rng.random_range(0..4); // 0: INSERT, 1: UPDATE, 2: DELETE, 3: UPSERT (catch-all) let stmt = match action { // INSERT @@ -1161,7 +2418,7 @@ mod tests { #[test] pub fn compound_select_fuzz() { let _ = env_logger::try_init(); - let (mut rng, seed) = rng_from_time(); + let (mut rng, seed) = rng_from_time_or_env(); log::info!("compound_select_fuzz seed: {seed}"); // Constants for fuzzing parameters @@ -1297,7 +2554,7 @@ mod tests { #[test] pub fn ddl_compatibility_fuzz() { let _ = env_logger::try_init(); - let (mut rng, seed) = rng_from_time(); + let (mut rng, seed) = rng_from_time_or_env(); const ITERATIONS: usize = 1000; for i in 0..ITERATIONS { let db = TempDatabase::new_empty(true); @@ -1466,7 +2723,7 @@ mod tests { let limbo_conn = db.connect_limbo(); let sqlite_conn = rusqlite::Connection::open_in_memory().unwrap(); - let (mut rng, seed) = rng_from_time(); + let (mut rng, seed) = rng_from_time_or_env(); log::info!("seed: {seed}"); for _ in 0..1024 { let query = g.generate(&mut rng, sql, 50); @@ -1585,7 +2842,7 @@ mod tests { let limbo_conn = db.connect_limbo(); let sqlite_conn = rusqlite::Connection::open_in_memory().unwrap(); - let (mut rng, seed) = rng_from_time(); + let (mut rng, seed) = rng_from_time_or_env(); log::info!("seed: {seed}"); for _ in 0..1024 { let query = g.generate(&mut rng, sql, 50); @@ -1745,7 +3002,7 @@ mod tests { let limbo_conn = db.connect_limbo(); let sqlite_conn = rusqlite::Connection::open_in_memory().unwrap(); - let (mut rng, seed) = rng_from_time(); + let (mut rng, seed) = rng_from_time_or_env(); log::info!("seed: {seed}"); for _ in 0..1024 { let query = g.generate(&mut rng, sql, 50); @@ -2114,7 +3371,7 @@ mod tests { let limbo_conn = db.connect_limbo(); let sqlite_conn = rusqlite::Connection::open_in_memory().unwrap(); - let (mut rng, seed) = rng_from_time(); + let (mut rng, seed) = rng_from_time_or_env(); log::info!("seed: {seed}"); for _ in 0..1024 { let query = g.generate(&mut rng, sql, 50); @@ -2163,7 +3420,7 @@ mod tests { let _ = env_logger::try_init(); let datatypes = ["INTEGER", "TEXT", "REAL", "BLOB"]; - let (mut rng, seed) = rng_from_time(); + let (mut rng, seed) = rng_from_time_or_env(); log::info!("seed: {seed}"); for _ in 0..1000 { @@ -2218,7 +3475,7 @@ mod tests { pub fn affinity_fuzz() { let _ = env_logger::try_init(); - let (mut rng, seed) = rng_from_time(); + let (mut rng, seed) = rng_from_time_or_env(); log::info!("affinity_fuzz seed: {seed}"); for iteration in 0..500 { @@ -2319,7 +3576,7 @@ mod tests { pub fn sum_agg_fuzz_floats() { let _ = env_logger::try_init(); - let (mut rng, seed) = rng_from_time(); + let (mut rng, seed) = rng_from_time_or_env(); log::info!("seed: {seed}"); for _ in 0..100 { @@ -2365,7 +3622,7 @@ mod tests { pub fn sum_agg_fuzz() { let _ = env_logger::try_init(); - let (mut rng, seed) = rng_from_time(); + let (mut rng, seed) = rng_from_time_or_env(); log::info!("seed: {seed}"); for _ in 0..100 { @@ -2409,7 +3666,7 @@ mod tests { fn concat_ws_fuzz() { let _ = env_logger::try_init(); - let (mut rng, seed) = rng_from_time(); + let (mut rng, seed) = rng_from_time_or_env(); log::info!("seed: {seed}"); for _ in 0..100 { @@ -2455,7 +3712,7 @@ mod tests { pub fn total_agg_fuzz() { let _ = env_logger::try_init(); - let (mut rng, seed) = rng_from_time(); + let (mut rng, seed) = rng_from_time_or_env(); log::info!("seed: {seed}"); for _ in 0..100 { @@ -2531,7 +3788,7 @@ mod tests { ); } - let (mut rng, seed) = rng_from_time(); + let (mut rng, seed) = rng_from_time_or_env(); log::info!("seed: {seed}"); let mut i = 0; @@ -2663,7 +3920,7 @@ mod tests { } #[test] - pub fn fuzz_long_create_table_drop_table_alter_table() { + pub fn fuzz_long_create_table_drop_table_alter_table_normal() { _fuzz_long_create_table_drop_table_alter_table(false); } @@ -2695,6 +3952,8 @@ mod tests { let mut undroppable_cols = HashSet::new(); + let mut stmts = vec![]; + for iteration in 0..2000 { println!("iteration: {iteration} (seed: {seed})"); let operation = rng.random_range(0..100); // 0: create, 1: drop, 2: alter, 3: alter rename @@ -2749,8 +4008,8 @@ mod tests { format!("CREATE TABLE {table_name} ({})", columns.join(", ")); // Execute the create table statement + stmts.push(create_sql.clone()); limbo_exec_rows(&db, &limbo_conn, &create_sql); - let column_names = columns .iter() .map(|c| c.split_whitespace().next().unwrap().to_string()) @@ -2765,6 +4024,7 @@ mod tests { .collect::>() .join(", ") ); + stmts.push(insert_sql.clone()); limbo_exec_rows(&db, &limbo_conn, &insert_sql); // Successfully created table, update our tracking @@ -2779,6 +4039,7 @@ mod tests { let table_to_drop = &table_names[rng.random_range(0..table_names.len())]; let drop_sql = format!("DROP TABLE {table_to_drop}"); + stmts.push(drop_sql.clone()); limbo_exec_rows(&db, &limbo_conn, &drop_sql); // Successfully dropped table, update our tracking @@ -2799,6 +4060,7 @@ mod tests { table_to_alter, &new_col_name, col_type ); + stmts.push(alter_sql.clone()); limbo_exec_rows(&db, &limbo_conn, &alter_sql); // Successfully added column, update our tracking @@ -2830,6 +4092,7 @@ mod tests { let alter_sql = format!( "ALTER TABLE {table_to_alter} DROP COLUMN {col_to_drop}" ); + stmts.push(alter_sql.clone()); limbo_exec_rows(&db, &limbo_conn, &alter_sql); // Successfully dropped column, update our tracking @@ -2866,6 +4129,14 @@ mod tests { "seed: {seed}, mvcc: {mvcc}, table: {table_name}" ); } + if !mvcc { + if let Err(e) = rusqlite_integrity_check(&db.path) { + for stmt in stmts.iter() { + println!("{stmt};"); + } + panic!("seed: {seed}, mvcc: {mvcc}, error: {e}"); + } + } } // Final verification - the test passes if we didn't crash @@ -2878,7 +4149,7 @@ mod tests { #[test] #[cfg(feature = "test_helper")] pub fn fuzz_pending_byte_database() -> anyhow::Result<()> { - use crate::common::rusqlite_integrity_check; + use core_tester::common::rusqlite_integrity_check; maybe_setup_tracing(); let (mut rng, seed) = rng_from_time_or_env(); diff --git a/tests/fuzz/rowid_alias.rs b/tests/fuzz/rowid_alias.rs new file mode 100644 index 000000000..28c08a074 --- /dev/null +++ b/tests/fuzz/rowid_alias.rs @@ -0,0 +1,211 @@ +use core_tester::common::{limbo_exec_rows, TempDatabase}; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use sql_generation::{ + generation::{Arbitrary, GenerationContext, Opts}, + model::{ + query::{Create, Insert, Select}, + table::{Column, ColumnType, Table}, + }, +}; +use turso_parser::ast::ColumnConstraint; + +fn rng_from_time_or_env() -> (ChaCha8Rng, u64) { + let seed = if let Ok(seed_str) = std::env::var("FUZZ_SEED") { + seed_str.parse::().expect("Invalid FUZZ_SEED value") + } else { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() + }; + let rng = ChaCha8Rng::seed_from_u64(seed); + (rng, seed) +} + +// Our test context that implements GenerationContext +#[derive(Debug, Clone)] +struct FuzzTestContext { + opts: Opts, + tables: Vec
, +} + +impl FuzzTestContext { + fn new() -> Self { + Self { + opts: Opts::default(), + tables: Vec::new(), + } + } + + fn add_table(&mut self, table: Table) { + self.tables.push(table); + } +} + +impl GenerationContext for FuzzTestContext { + fn tables(&self) -> &Vec
{ + &self.tables + } + + fn opts(&self) -> &Opts { + &self.opts + } +} + +// Convert a table's CREATE statement to use INTEGER PRIMARY KEY (rowid alias) +fn convert_to_rowid_alias(create_sql: &str) -> String { + // Since we always generate INTEGER PRIMARY KEY, just return as-is + create_sql.to_string() +} + +// Convert a table's CREATE statement to NOT use rowid alias +fn convert_to_no_rowid_alias(create_sql: &str) -> String { + // Replace INTEGER PRIMARY KEY with INT PRIMARY KEY to disable rowid alias + create_sql.replace("INTEGER PRIMARY KEY", "INT PRIMARY KEY") +} + +#[test] +#[ignore] +pub fn rowid_alias_differential_fuzz() { + let (mut rng, seed) = rng_from_time_or_env(); + tracing::info!("rowid_alias_differential_fuzz seed: {}", seed); + + // Number of queries to test + let num_queries = if let Ok(num) = std::env::var("FUZZ_NUM_QUERIES") { + num.parse::().unwrap_or(1000) + } else { + 1000 + }; + + // Create two Limbo databases with indexes enabled + let db_with_alias = TempDatabase::new_empty(true); + let db_without_alias = TempDatabase::new_empty(true); + + // Connect to both databases + let conn_with_alias = db_with_alias.connect_limbo(); + let conn_without_alias = db_without_alias.connect_limbo(); + + // Create our test context + let mut context = FuzzTestContext::new(); + + let mut successful_queries = 0; + let mut skipped_queries = 0; + + for iteration in 0..num_queries { + // Decide whether to create a new table, insert data, or generate a query + let action = + if context.tables.is_empty() || (context.tables.len() < 5 && rng.random_bool(0.1)) { + 0 // Create a new table + } else if rng.random_bool(0.3) { + 1 // Insert data + } else { + 2 // Generate a SELECT query + }; + + match action { + 0 => { + // Generate a new table with an integer primary key + let primary_key = Column { + name: "id".to_string(), + column_type: ColumnType::Integer, + constraints: vec![ColumnConstraint::PrimaryKey { + order: None, + conflict_clause: None, + auto_increment: false, + }], + }; + let table_name = format!("table_{}", context.tables.len()); + let table = Table::arbitrary_with_columns( + &mut rng, + &context, + table_name, + vec![primary_key], + ); + let create = Create { + table: table.clone(), + }; + + // Create table with rowid alias in first database + let create_with_alias = convert_to_rowid_alias(&create.to_string()); + let _ = limbo_exec_rows(&db_with_alias, &conn_with_alias, &create_with_alias); + + // Create table without rowid alias in second database + let create_without_alias = convert_to_no_rowid_alias(&create.to_string()); + let _ = limbo_exec_rows( + &db_without_alias, + &conn_without_alias, + &create_without_alias, + ); + + // Add table to context for future query generation + context.add_table(table); + + skipped_queries += 1; + continue; + } + 1 => { + // Generate and execute an INSERT statement + let insert = Insert::arbitrary(&mut rng, &context); + let insert_str = insert.to_string(); + + // Execute the insert in both databases + let _ = limbo_exec_rows(&db_with_alias, &conn_with_alias, &insert_str); + let _ = limbo_exec_rows(&db_without_alias, &conn_without_alias, &insert_str); + + // Update the table's rows in the context so predicate generation knows about the data + if let Insert::Values { + table: table_name, + values, + } = &insert + { + for table in &mut context.tables { + if table.name == *table_name { + table.rows.extend(values.clone()); + break; + } + } + } + + skipped_queries += 1; + continue; + } + _ => { + // Continue to generate SELECT query below + } + } + + let select = Select::arbitrary(&mut rng, &context); + let query_str = select.to_string(); + + tracing::debug!("Comparing query {}: {}", iteration, query_str); + + let with_alias_results = limbo_exec_rows(&db_with_alias, &conn_with_alias, &query_str); + let without_alias_results = + limbo_exec_rows(&db_without_alias, &conn_without_alias, &query_str); + + let mut sorted_with_alias = with_alias_results; + let mut sorted_without_alias = without_alias_results; + + // Sort results to handle different row ordering + sorted_with_alias.sort_by(|a, b| format!("{a:?}").cmp(&format!("{b:?}"))); + sorted_without_alias.sort_by(|a, b| format!("{a:?}").cmp(&format!("{b:?}"))); + + assert_eq!( + sorted_with_alias, sorted_without_alias, + "Query produced different results with and without rowid alias!\n\ + Query: {query_str}\n\ + With rowid alias: {sorted_with_alias:?}\n\ + Without rowid alias: {sorted_without_alias:?}\n\ + Seed: {seed}" + ); + + successful_queries += 1; + } + + tracing::info!( + "Rowid alias differential fuzz test completed: {} queries tested successfully, {} queries skipped", + successful_queries, + skipped_queries + ); +} diff --git a/tests/integration/common.rs b/tests/integration/common.rs index 175a98a53..65e171d1a 100644 --- a/tests/integration/common.rs +++ b/tests/integration/common.rs @@ -163,7 +163,7 @@ impl TempDatabase { } } -pub(crate) fn do_flush(conn: &Arc, tmp_db: &TempDatabase) -> anyhow::Result<()> { +pub fn do_flush(conn: &Arc, tmp_db: &TempDatabase) -> anyhow::Result<()> { let completions = conn.cacheflush()?; for c in completions { tmp_db.io.wait_for_completion(c)?; @@ -171,7 +171,7 @@ pub(crate) fn do_flush(conn: &Arc, tmp_db: &TempDatabase) -> anyhow: Ok(()) } -pub(crate) fn compare_string(a: impl AsRef, b: impl AsRef) { +pub fn compare_string(a: impl AsRef, b: impl AsRef) { let a = a.as_ref(); let b = b.as_ref(); @@ -204,7 +204,7 @@ pub fn maybe_setup_tracing() { .try_init(); } -pub(crate) fn sqlite_exec_rows( +pub fn sqlite_exec_rows( conn: &rusqlite::Connection, query: &str, ) -> Vec> { @@ -227,7 +227,7 @@ pub(crate) fn sqlite_exec_rows( results } -pub(crate) fn limbo_exec_rows( +pub fn limbo_exec_rows( _db: &TempDatabase, conn: &Arc, query: &str, @@ -266,7 +266,8 @@ pub(crate) fn limbo_exec_rows( rows } -pub(crate) fn limbo_stmt_get_column_names( +#[allow(dead_code)] +pub fn limbo_stmt_get_column_names( _db: &TempDatabase, conn: &Arc, query: &str, @@ -280,7 +281,7 @@ pub(crate) fn limbo_stmt_get_column_names( names } -pub(crate) fn limbo_exec_rows_fallible( +pub fn limbo_exec_rows_fallible( _db: &TempDatabase, conn: &Arc, query: &str, @@ -319,7 +320,7 @@ pub(crate) fn limbo_exec_rows_fallible( Ok(rows) } -pub(crate) fn limbo_exec_rows_error( +pub fn limbo_exec_rows_error( _db: &TempDatabase, conn: &Arc, query: &str, @@ -338,7 +339,7 @@ pub(crate) fn limbo_exec_rows_error( } } -pub(crate) fn rng_from_time() -> (ChaCha8Rng, u64) { +pub fn rng_from_time() -> (ChaCha8Rng, u64) { let seed = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap() @@ -401,6 +402,7 @@ pub fn run_query_core( Ok(()) } +#[allow(dead_code)] pub fn rusqlite_integrity_check(db_path: &Path) -> anyhow::Result<()> { let conn = rusqlite::Connection::open(db_path)?; let mut stmt = conn.prepare("SELECT * FROM pragma_integrity_check;")?; diff --git a/tests/integration/functions/test_cdc.rs b/tests/integration/functions/test_cdc.rs index d631c1f93..28606eca3 100644 --- a/tests/integration/functions/test_cdc.rs +++ b/tests/integration/functions/test_cdc.rs @@ -1107,9 +1107,7 @@ fn test_cdc_schema_changes_alter_table() { Value::Text("t".to_string()), Value::Text("t".to_string()), Value::Integer(4), - Value::Text( - "CREATE TABLE t (x PRIMARY KEY, y PRIMARY KEY, z UNIQUE)".to_string() - ) + Value::Text("CREATE TABLE t (x, y, z UNIQUE, PRIMARY KEY (x, y))".to_string()) ])), Value::Blob(record([ Value::Integer(0), @@ -1135,9 +1133,7 @@ fn test_cdc_schema_changes_alter_table() { Value::Text("t".to_string()), Value::Text("t".to_string()), Value::Integer(4), - Value::Text( - "CREATE TABLE t (x PRIMARY KEY, y PRIMARY KEY, z UNIQUE)".to_string() - ) + Value::Text("CREATE TABLE t (x, y, z UNIQUE, PRIMARY KEY (x, y))".to_string()) ])), Value::Blob(record([ Value::Text("table".to_string()), @@ -1145,7 +1141,7 @@ fn test_cdc_schema_changes_alter_table() { Value::Text("t".to_string()), Value::Integer(4), Value::Text( - "CREATE TABLE t (x PRIMARY KEY, y PRIMARY KEY, z UNIQUE, t)".to_string() + "CREATE TABLE t (x, y, z UNIQUE, t, PRIMARY KEY (x, y))".to_string() ) ])), Value::Blob(record([ diff --git a/tests/integration/mod.rs b/tests/integration/mod.rs index 1149f224a..e369e68b7 100644 --- a/tests/integration/mod.rs +++ b/tests/integration/mod.rs @@ -1,6 +1,5 @@ mod common; mod functions; -mod fuzz; mod fuzz_transaction; mod pragma; mod query_processing; diff --git a/tests/integration/query_processing/test_btree.rs b/tests/integration/query_processing/test_btree.rs index e55c512f0..c24d025d6 100644 --- a/tests/integration/query_processing/test_btree.rs +++ b/tests/integration/query_processing/test_btree.rs @@ -130,7 +130,7 @@ pub fn write_varint(buf: &mut [u8], value: u64) -> usize { return 9; } - let mut encoded: [u8; 10] = [0; 10]; + let mut encoded: [u8; 9] = [0; 9]; let mut bytes = value; let mut n = 0; while bytes != 0 { diff --git a/tests/integration/query_processing/test_multi_thread.rs b/tests/integration/query_processing/test_multi_thread.rs index 6054c7a32..4f09a0f94 100644 --- a/tests/integration/query_processing/test_multi_thread.rs +++ b/tests/integration/query_processing/test_multi_thread.rs @@ -242,8 +242,10 @@ fn test_schema_reprepare_write() { } fn advance(stmt: &mut Statement) -> anyhow::Result<()> { - stmt.step()?; - stmt.run_once()?; + tracing::info!("Advancing statement: {:?}", stmt.get_sql()); + while matches!(stmt.step()?, StepResult::IO) { + stmt.run_once()?; + } Ok(()) } @@ -268,9 +270,9 @@ fn test_interleaved_transactions() -> anyhow::Result<()> { tmp_db.connect_limbo(), ]; - let mut statement2 = conn[2].prepare("BEGIN")?; let mut statement0 = conn[0].prepare("BEGIN")?; let mut statement1 = conn[1].prepare("BEGIN")?; + let mut statement2 = conn[2].prepare("BEGIN")?; advance(&mut statement2)?; diff --git a/tests/integration/query_processing/test_read_path.rs b/tests/integration/query_processing/test_read_path.rs index a0b72cd60..4b08df3b1 100644 --- a/tests/integration/query_processing/test_read_path.rs +++ b/tests/integration/query_processing/test_read_path.rs @@ -1,5 +1,5 @@ -use crate::common::TempDatabase; -use turso_core::{StepResult, Value}; +use crate::common::{limbo_exec_rows, TempDatabase}; +use turso_core::{LimboError, StepResult, Value}; #[test] fn test_statement_reset_bind() -> anyhow::Result<()> { @@ -876,3 +876,118 @@ fn test_upsert_parameters_order() -> anyhow::Result<()> { ); Ok(()) } + +#[test] +fn test_multiple_connections_visibility() -> anyhow::Result<()> { + let tmp_db = TempDatabase::new_with_rusqlite( + "CREATE TABLE test (k INTEGER PRIMARY KEY, v INTEGER);", + false, + ); + let conn1 = tmp_db.connect_limbo(); + let conn2 = tmp_db.connect_limbo(); + conn1.execute("BEGIN")?; + conn1.execute("INSERT INTO test VALUES (1, 2), (3, 4)")?; + let mut stmt = conn2.prepare("SELECT COUNT(*) FROM test").unwrap(); + let _ = stmt.step().unwrap(); + // intentionally drop not-fully-consumed statement in order to check that on Drop statement will execute reset with proper cleanup + drop(stmt); + conn1.execute("COMMIT")?; + + let rows = limbo_exec_rows(&tmp_db, &conn2, "SELECT COUNT(*) FROM test"); + assert_eq!(rows, vec![vec![rusqlite::types::Value::Integer(2)]]); + Ok(()) +} + +#[test] +/// Test that we can only join up to 63 tables, and trying to join more should fail with an error instead of panicing. +fn test_max_joined_tables_limit() { + let tmp_db = TempDatabase::new("test_max_joined_tables_limit", false); + let conn = tmp_db.connect_limbo(); + + // Create 64 tables + for i in 0..64 { + conn.execute(format!("CREATE TABLE t{i} (id INTEGER)")) + .unwrap(); + } + + // Try to join 64 tables - should fail + let mut sql = String::from("SELECT * FROM t0"); + for i in 1..64 { + sql.push_str(&format!(" JOIN t{i} ON t{i}.id = t0.id")); + } + + let Err(LimboError::ParseError(result)) = conn.prepare(&sql) else { + panic!("Expected an error but got no error"); + }; + assert!(result.contains("Only up to 63 tables can be joined")); +} + +#[test] +/// Test that we can create and select from a table with 1000 columns. +fn test_many_columns() { + let mut create_sql = String::from("CREATE TABLE test ("); + for i in 0..1000 { + if i > 0 { + create_sql.push_str(", "); + } + create_sql.push_str(&format!("col{i} INTEGER")); + } + create_sql.push(')'); + + let tmp_db = TempDatabase::new("test_many_columns", false); + let conn = tmp_db.connect_limbo(); + conn.execute(&create_sql).unwrap(); + + // Insert a row with values 0-999 + let mut insert_sql = String::from("INSERT INTO test VALUES ("); + for i in 0..1000 { + if i > 0 { + insert_sql.push_str(", "); + } + insert_sql.push_str(&i.to_string()); + } + insert_sql.push(')'); + conn.execute(&insert_sql).unwrap(); + + // Select every 100th column + let mut select_sql = String::from("SELECT "); + let mut first = true; + for i in (0..1000).step_by(100) { + if !first { + select_sql.push_str(", "); + } + select_sql.push_str(&format!("col{i}")); + first = false; + } + select_sql.push_str(" FROM test"); + + let mut rows = Vec::new(); + let mut stmt = conn.prepare(&select_sql).unwrap(); + loop { + match stmt.step().unwrap() { + StepResult::Row => { + let row = stmt.row().unwrap(); + rows.push(row.get_values().cloned().collect::>()); + } + StepResult::IO => stmt.run_once().unwrap(), + _ => break, + } + } + + // Verify we got values 0,100,200,...,900 + assert_eq!( + rows, + vec![vec![ + turso_core::Value::Integer(0), + turso_core::Value::Integer(100), + turso_core::Value::Integer(200), + turso_core::Value::Integer(300), + turso_core::Value::Integer(400), + turso_core::Value::Integer(500), + turso_core::Value::Integer(600), + turso_core::Value::Integer(700), + turso_core::Value::Integer(800), + turso_core::Value::Integer(900), + ]] + ); +} diff --git a/tests/integration/query_processing/test_transactions.rs b/tests/integration/query_processing/test_transactions.rs index 53ab7f00e..81b104788 100644 --- a/tests/integration/query_processing/test_transactions.rs +++ b/tests/integration/query_processing/test_transactions.rs @@ -95,6 +95,7 @@ fn test_deferred_transaction_no_restart() { .execute("INSERT INTO test (id, value) VALUES (2, 'second')") .unwrap(); conn2.execute("COMMIT").unwrap(); + drop(stmt); let mut stmt = conn1.query("SELECT COUNT(*) FROM test").unwrap().unwrap(); if let StepResult::Row = stmt.step().unwrap() { @@ -175,6 +176,79 @@ fn test_transaction_visibility() { } } +#[test] +/// A constraint error does not rollback the transaction, it rolls back the statement. +fn test_constraint_error_aborts_only_stmt_not_entire_transaction() { + let tmp_db = TempDatabase::new( + "test_constraint_error_aborts_only_stmt_not_entire_transaction.db", + true, + ); + let conn = tmp_db.connect_limbo(); + + // Create table succeeds + conn.execute("CREATE TABLE t (a INTEGER PRIMARY KEY)") + .unwrap(); + + // Begin succeeds + conn.execute("BEGIN").unwrap(); + + // First insert succeeds + conn.execute("INSERT INTO t VALUES (1),(2)").unwrap(); + + // Second insert fails due to UNIQUE constraint + let result = conn.execute("INSERT INTO t VALUES (2),(3)"); + assert!(matches!(result, Err(LimboError::Constraint(_)))); + + // Third insert is valid again + conn.execute("INSERT INTO t VALUES (4)").unwrap(); + + // Commit succeeds + conn.execute("COMMIT").unwrap(); + + // Make sure table has 3 rows (a=1, a=2, a=4) + let stmt = conn.query("SELECT a FROM t").unwrap().unwrap(); + let rows = helper_read_all_rows(stmt); + assert_eq!( + rows, + vec![ + vec![Value::Integer(1)], + vec![Value::Integer(2)], + vec![Value::Integer(4)] + ] + ); +} + +#[test] +/// Regression test for https://github.com/tursodatabase/turso/issues/3784 where dirty pages +/// were flushed to WAL _before_ deferred FK violations were checked. This resulted in the +/// violations being persisted to the database, even though the transaction was aborted. +/// This test ensures that dirty pages are not flushed to WAL until after deferred violations are checked. +fn test_deferred_fk_violation_rollback_in_autocommit() { + let tmp_db = TempDatabase::new("test_deferred_fk_violation_rollback.db", true); + let conn = tmp_db.connect_limbo(); + + // Enable foreign keys + conn.execute("PRAGMA foreign_keys = ON").unwrap(); + + // Create parent and child tables with deferred FK constraint + conn.execute("CREATE TABLE parent(a PRIMARY KEY)").unwrap(); + conn.execute("CREATE TABLE child(a, b, FOREIGN KEY(b) REFERENCES parent(a) DEFERRABLE INITIALLY DEFERRED)") + .unwrap(); + + // This insert should fail because parent(1) doesn't exist + // and the deferred FK violation should be caught at statement end in autocommit mode + let result = conn.execute("INSERT INTO child VALUES(1,1)"); + assert!(matches!(result, Err(LimboError::Constraint(_)))); + + // Do a truncating checkpoint + conn.execute("PRAGMA wal_checkpoint(TRUNCATE)").unwrap(); + + // Verify that the child table is empty (the insert was rolled back) + let stmt = conn.query("SELECT COUNT(*) FROM child").unwrap().unwrap(); + let row = helper_read_single_row(stmt); + assert_eq!(row, vec![Value::Integer(0)]); +} + #[test] fn test_mvcc_transactions_autocommit() { let tmp_db = TempDatabase::new_with_opts( diff --git a/tests/integration/query_processing/test_write_path.rs b/tests/integration/query_processing/test_write_path.rs index 85f666ee4..b9384cdc3 100644 --- a/tests/integration/query_processing/test_write_path.rs +++ b/tests/integration/query_processing/test_write_path.rs @@ -1,10 +1,11 @@ -use crate::common::{self, limbo_exec_rows, maybe_setup_tracing}; +use crate::common::{self, limbo_exec_rows, maybe_setup_tracing, rusqlite_integrity_check}; use crate::common::{compare_string, do_flush, TempDatabase}; use log::debug; use std::io::{Read, Seek, Write}; use std::sync::Arc; use turso_core::{ - CheckpointMode, Connection, Database, LimboError, Row, Statement, StepResult, Value, + CheckpointMode, Connection, Database, DatabaseOpts, LimboError, Row, Statement, StepResult, + Value, }; const WAL_HEADER_SIZE: usize = 32; @@ -508,6 +509,129 @@ fn test_update_regression() -> anyhow::Result<()> { Ok(()) } +#[test] +/// Test that a large insert statement containing a UNIQUE constraint violation +/// is properly rolled back so that the database size is also shrunk to the size +/// before that statement is executed. +fn test_rollback_on_unique_constraint_violation() -> anyhow::Result<()> { + let _ = env_logger::try_init(); + let tmp_db = TempDatabase::new_with_opts( + "big_statement_rollback.db", + DatabaseOpts::new().with_indexes(true), + ); + let conn = tmp_db.connect_limbo(); + + conn.execute("CREATE TABLE t(x UNIQUE)")?; + + conn.execute("BEGIN")?; + conn.execute("INSERT INTO t VALUES (10000)")?; + + // This should fail due to unique constraint violation + let result = conn.execute("INSERT INTO t SELECT value FROM generate_series(1,10000)"); + assert!(result.is_err(), "Expected unique constraint violation"); + + conn.execute("COMMIT")?; + + // Should have exactly 1 row (the first insert) + common::run_query_on_row(&tmp_db, &conn, "SELECT count(*) FROM t", |row| { + let count = row.get::(0).unwrap(); + assert_eq!(count, 1, "Expected 1 row after rollback"); + })?; + + // Check page count + common::run_query_on_row(&tmp_db, &conn, "PRAGMA page_count", |row| { + let page_count = row.get::(0).unwrap(); + assert_eq!(page_count, 3, "Expected 3 pages"); + })?; + + // Checkpoint the WAL + conn.execute("PRAGMA wal_checkpoint(TRUNCATE)")?; + + // Integrity check with rusqlite + rusqlite_integrity_check(tmp_db.path.as_path())?; + + // Size on disk should be 3 * 4096 + let db_size = std::fs::metadata(&tmp_db.path).unwrap().len(); + assert_eq!(db_size, 3 * 4096); + + Ok(()) +} + +#[test] +/// Test that a large delete statement containing a foreign key constraint violation +/// is properly rolled back. +fn test_rollback_on_foreign_key_constraint_violation() -> anyhow::Result<()> { + let _ = env_logger::try_init(); + let tmp_db = TempDatabase::new_with_opts( + "big_delete_rollback.db", + DatabaseOpts::new().with_indexes(true), + ); + let conn = tmp_db.connect_limbo(); + + // Enable foreign keys + conn.execute("PRAGMA foreign_keys = ON")?; + + // Create parent and child tables + conn.execute("CREATE TABLE parent(id INTEGER PRIMARY KEY)")?; + conn.execute( + "CREATE TABLE child(id INTEGER PRIMARY KEY, parent_id INTEGER REFERENCES parent(id))", + )?; + + // Insert 10000 parent rows + conn.execute("INSERT INTO parent SELECT value FROM generate_series(1,10000)")?; + + // Insert a child row that references the 10000th parent row + conn.execute("INSERT INTO child VALUES (1, 10000)")?; + + conn.execute("BEGIN")?; + + // Delete first parent row (should succeed) + conn.execute("DELETE FROM parent WHERE id = 1")?; + + // This should fail due to foreign key constraint violation (trying to delete parent row 10000 which has a child) + let result = conn.execute("DELETE FROM parent WHERE id >= 2"); + assert!(result.is_err(), "Expected foreign key constraint violation"); + + conn.execute("COMMIT")?; + + // Should have 9999 parent rows (10000 - 1 that was successfully deleted) + common::run_query_on_row(&tmp_db, &conn, "SELECT count(*) FROM parent", |row| { + let count = row.get::(0).unwrap(); + assert_eq!(count, 9999, "Expected 9999 parent rows after rollback"); + })?; + + // Verify rows 2-10000 are intact + common::run_query_on_row( + &tmp_db, + &conn, + "SELECT min(id), max(id) FROM parent", + |row| { + let min_id = row.get::(0).unwrap(); + let max_id = row.get::(1).unwrap(); + assert_eq!(min_id, 2, "Expected min id to be 2"); + assert_eq!(max_id, 10000, "Expected max id to be 10000"); + }, + )?; + + // Child row should still exist + common::run_query_on_row(&tmp_db, &conn, "SELECT count(*) FROM child", |row| { + let count = row.get::(0).unwrap(); + assert_eq!(count, 1, "Expected 1 child row"); + })?; + + // Checkpoint the WAL + conn.execute("PRAGMA wal_checkpoint(TRUNCATE)")?; + + // Integrity check with rusqlite + rusqlite_integrity_check(tmp_db.path.as_path())?; + + // Size on disk should be 21 * 4096 + let db_size = std::fs::metadata(&tmp_db.path).unwrap().len(); + assert_eq!(db_size, 21 * 4096); + + Ok(()) +} + #[test] fn test_multiple_statements() -> anyhow::Result<()> { let _ = env_logger::try_init(); diff --git a/tests/lib.rs b/tests/lib.rs index 8b1378917..26482d663 100644 --- a/tests/lib.rs +++ b/tests/lib.rs @@ -1 +1,3 @@ - +// Shared test utilities +#[path = "integration/common.rs"] +pub mod common; diff --git a/whopper/io.rs b/whopper/io.rs index 5b9da7b3e..9f8c8a872 100644 --- a/whopper/io.rs +++ b/whopper/io.rs @@ -142,6 +142,11 @@ impl IO for SimulatorIO { let mut rng = self.rng.lock().unwrap(); rng.next_u64() as i64 } + + fn fill_bytes(&self, dest: &mut [u8]) { + let mut rng = self.rng.lock().unwrap(); + rng.fill_bytes(dest); + } } const MAX_FILE_SIZE: usize = 1 << 33; // 8 GiB diff --git a/whopper/main.rs b/whopper/main.rs index 086b4687e..0a18edf2c 100644 --- a/whopper/main.rs +++ b/whopper/main.rs @@ -4,11 +4,13 @@ use rand::{Rng, RngCore, SeedableRng}; use rand_chacha::ChaCha8Rng; use sql_generation::{ generation::{Arbitrary, GenerationContext, Opts}, - model::query::{ - create::Create, create_index::CreateIndex, delete::Delete, drop_index::DropIndex, - insert::Insert, select::Select, update::Update, + model::{ + query::{ + create::Create, create_index::CreateIndex, delete::Delete, drop_index::DropIndex, + insert::Insert, select::Select, update::Update, + }, + table::{Column, ColumnType, Index, Table}, }, - model::table::{Column, ColumnType, Table}, }; use std::cell::RefCell; use std::collections::HashMap; @@ -18,7 +20,7 @@ use tracing_subscriber::{EnvFilter, layer::SubscriberExt, util::SubscriberInitEx use turso_core::{ CipherMode, Connection, Database, DatabaseOpts, EncryptionOpts, IO, OpenFlags, Statement, }; -use turso_parser::ast::SortOrder; +use turso_parser::ast::{ColumnConstraint, SortOrder}; mod io; use crate::io::FILE_SIZE_SOFT_LIMIT; @@ -66,7 +68,7 @@ struct SimulatorFiber { struct SimulatorContext { fibers: Vec, tables: Vec
, - indexes: Vec, + indexes: Vec<(String, String)>, opts: Opts, stats: Stats, disable_indexes: bool, @@ -208,7 +210,10 @@ fn main() -> anyhow::Result<()> { let mut context = SimulatorContext { fibers, tables, - indexes: indexes.iter().map(|idx| idx.index_name.clone()).collect(), + indexes: indexes + .iter() + .map(|idx| (idx.table_name.clone(), idx.index_name.clone())) + .collect(), opts: Opts::default(), stats: Stats::default(), disable_indexes: args.disable_indexes, @@ -306,9 +311,11 @@ fn create_initial_indexes(rng: &mut ChaCha8Rng, tables: &[Table]) -> Vec Vec { let num_columns = rng.random_range(2..=8); let mut columns = Vec::new(); + // TODO: there is no proper unique generation yet in whopper, so disable primary keys for now + + // let primary = ColumnConstraint::PrimaryKey { + // order: None, + // conflict_clause: None, + // auto_increment: false, + // }; // Always add an id column as primary key columns.push(Column { name: "id".to_string(), column_type: ColumnType::Integer, - primary: true, - unique: false, + constraints: vec![], }); // Add random columns @@ -348,11 +361,19 @@ fn create_initial_schema(rng: &mut ChaCha8Rng) -> Vec { _ => ColumnType::Float, }; + // FIXME: before sql_generation did not incorporate ColumnConstraint into the sql string + // now it does and it the simulation here fails `whopper` with UNIQUE CONSTRAINT ERROR + // 20% chance of unique + let constraints = if rng.random_bool(0.0) { + vec![ColumnConstraint::Unique(None)] + } else { + Vec::new() + }; + columns.push(Column { name: format!("col_{j}"), column_type: col_type, - primary: false, - unique: rng.random_bool(0.2), // 20% chance of unique + constraints, }); } @@ -365,7 +386,6 @@ fn create_initial_schema(rng: &mut ChaCha8Rng) -> Vec { schema.push(Create { table }); } - schema } @@ -550,7 +570,10 @@ fn perform_work( let sql = create_index.to_string(); if let Ok(stmt) = context.fibers[fiber_idx].connection.prepare(&sql) { context.fibers[fiber_idx].statement.replace(Some(stmt)); - context.indexes.push(create_index.index_name.clone()); + context.indexes.push(( + create_index.index.table_name.clone(), + create_index.index_name.clone(), + )); } trace!("{} CREATE INDEX: {}", fiber_idx, sql); } @@ -559,8 +582,11 @@ fn perform_work( // DROP INDEX (2%) if !context.disable_indexes && !context.indexes.is_empty() { let index_idx = rng.random_range(0..context.indexes.len()); - let index_name = context.indexes.remove(index_idx); - let drop_index = DropIndex { index_name }; + let (table_name, index_name) = context.indexes.remove(index_idx); + let drop_index = DropIndex { + table_name, + index_name, + }; let sql = drop_index.to_string(); if let Ok(stmt) = context.fibers[fiber_idx].connection.prepare(&sql) { context.fibers[fiber_idx].statement.replace(Some(stmt));