merge main

TcMits
2025-09-02 18:25:20 +07:00
94 changed files with 3786 additions and 1684 deletions


@@ -64,8 +64,12 @@ jobs:
- uses: useblacksmith/rust-cache@v3
with:
prefix-key: "v1-rust" # can be updated if we need to reset caches due to a non-trivial change in the dependencies (for example, custom env vars set for a single workspace project)
- name: Install the project
- name: Simulator default
run: ./scripts/run-sim --maximum-tests 1000 --min-tick 10 --max-tick 50 loop -n 10 -s
- name: Simulator InsertHeavy
run: ./scripts/run-sim --maximum-tests 1000 --min-tick 10 --max-tick 50 --profile write_heavy loop -n 10 -s
- name: Simulator Faultless
run: ./scripts/run-sim --maximum-tests 1000 --min-tick 10 --max-tick 50 --profile faultless loop -n 10 -s
test-limbo:
runs-on: blacksmith-4vcpu-ubuntu-2404

Cargo.lock (generated)

@@ -426,6 +426,15 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]]
name = "castaway"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dec551ab6e7578819132c713a93c022a05d60159dc86e7a7050223577484c55a"
dependencies = [
"rustversion",
]
[[package]]
name = "cc"
version = "1.2.17"
@@ -601,6 +610,21 @@ dependencies = [
"unicode-width 0.2.0",
]
[[package]]
name = "compact_str"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b79c4069c6cad78e2e0cdfcbd26275770669fb39fd308a752dc110e83b9af32"
dependencies = [
"castaway",
"cfg-if",
"itoa",
"rustversion",
"ryu",
"serde",
"static_assertions",
]
[[package]]
name = "concurrent-queue"
version = "2.5.0"
@@ -643,7 +667,7 @@ checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b"
[[package]]
name = "core_tester"
version = "0.1.5-pre.1"
version = "0.1.5-pre.2"
dependencies = [
"anyhow",
"assert_cmd",
@@ -1402,6 +1426,29 @@ dependencies = [
"slab",
]
[[package]]
name = "garde"
version = "0.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a989bd2fd12136080f7825ff410d9239ce84a2a639487fc9d924ee42e2fb84f"
dependencies = [
"compact_str",
"garde_derive",
"serde",
"smallvec",
]
[[package]]
name = "garde_derive"
version = "0.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f7f0545bbbba0a37d4d445890fa5759814e0716f02417b39f6fab292193df68"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.100",
]
[[package]]
name = "genawaiter"
version = "0.99.1"
@@ -1953,6 +2000,17 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "json5"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96b0db21af676c1ce64250b5f40f3ce2cf27e4e47cb91ed91eb6fe9350b430c1"
dependencies = [
"pest",
"pest_derive",
"serde",
]
[[package]]
name = "julian_day_converter"
version = "0.4.5"
@@ -2068,7 +2126,7 @@ dependencies = [
[[package]]
name = "limbo_completion"
version = "0.1.5-pre.1"
version = "0.1.5-pre.2"
dependencies = [
"mimalloc",
"turso_ext",
@@ -2076,7 +2134,7 @@ dependencies = [
[[package]]
name = "limbo_crypto"
version = "0.1.5-pre.1"
version = "0.1.5-pre.2"
dependencies = [
"blake3",
"data-encoding",
@@ -2089,7 +2147,7 @@ dependencies = [
[[package]]
name = "limbo_csv"
version = "0.1.5-pre.1"
version = "0.1.5-pre.2"
dependencies = [
"csv",
"mimalloc",
@@ -2099,7 +2157,7 @@ dependencies = [
[[package]]
name = "limbo_ipaddr"
version = "0.1.5-pre.1"
version = "0.1.5-pre.2"
dependencies = [
"ipnetwork",
"mimalloc",
@@ -2108,7 +2166,7 @@ dependencies = [
[[package]]
name = "limbo_percentile"
version = "0.1.5-pre.1"
version = "0.1.5-pre.2"
dependencies = [
"mimalloc",
"turso_ext",
@@ -2116,7 +2174,7 @@ dependencies = [
[[package]]
name = "limbo_regexp"
version = "0.1.5-pre.1"
version = "0.1.5-pre.2"
dependencies = [
"mimalloc",
"regex",
@@ -2125,15 +2183,17 @@ dependencies = [
[[package]]
name = "limbo_sim"
version = "0.1.5-pre.1"
version = "0.1.5-pre.2"
dependencies = [
"anyhow",
"chrono",
"clap",
"dirs 6.0.0",
"env_logger 0.10.2",
"garde",
"hex",
"itertools 0.14.0",
"json5",
"log",
"notify",
"rand 0.9.2",
@@ -2141,9 +2201,11 @@ dependencies = [
"regex",
"regex-syntax 0.8.5",
"rusqlite",
"schemars 1.0.4",
"serde",
"serde_json",
"sql_generation",
"strum",
"tracing",
"tracing-subscriber",
"turso_core",
@@ -2152,7 +2214,7 @@ dependencies = [
[[package]]
name = "limbo_sqlite_test_ext"
version = "0.1.5-pre.1"
version = "0.1.5-pre.2"
dependencies = [
"cc",
]
@@ -2641,6 +2703,50 @@ version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
[[package]]
name = "pest"
version = "2.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1db05f56d34358a8b1066f67cbb203ee3e7ed2ba674a6263a1d5ec6db2204323"
dependencies = [
"memchr",
"thiserror 2.0.12",
"ucd-trie",
]
[[package]]
name = "pest_derive"
version = "2.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bb056d9e8ea77922845ec74a1c4e8fb17e7c218cc4fc11a15c5d25e189aa40bc"
dependencies = [
"pest",
"pest_generator",
]
[[package]]
name = "pest_generator"
version = "2.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87e404e638f781eb3202dc82db6760c8ae8a1eeef7fb3fa8264b2ef280504966"
dependencies = [
"pest",
"pest_meta",
"proc-macro2",
"quote",
"syn 2.0.100",
]
[[package]]
name = "pest_meta"
version = "2.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "edd1101f170f5903fde0914f899bb503d9ff5271d7ba76bbb70bea63690cc0d5"
dependencies = [
"pest",
"sha2",
]
[[package]]
name = "pin-project-lite"
version = "0.2.16"
@@ -2863,7 +2969,7 @@ dependencies = [
[[package]]
name = "py-turso"
version = "0.1.5-pre.1"
version = "0.1.5-pre.2"
dependencies = [
"anyhow",
"pyo3",
@@ -3109,6 +3215,26 @@ dependencies = [
"thiserror 2.0.12",
]
[[package]]
name = "ref-cast"
version = "1.0.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a0ae411dbe946a674d89546582cea4ba2bb8defac896622d6496f14c23ba5cf"
dependencies = [
"ref-cast-impl",
]
[[package]]
name = "ref-cast-impl"
version = "1.0.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1165225c21bff1f3bbce98f5a1f889949bc902d3575308cc7b0de30b4f6d27c7"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.100",
]
[[package]]
name = "regex"
version = "1.11.1"
@@ -3355,7 +3481,20 @@ checksum = "3fbf2ae1b8bc8e02df939598064d22402220cd5bbcca1c76f7d6a310974d5615"
dependencies = [
"dyn-clone",
"indexmap 1.9.3",
"schemars_derive",
"schemars_derive 0.8.22",
"serde",
"serde_json",
]
[[package]]
name = "schemars"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "82d20c4491bc164fa2f6c5d44565947a52ad80b9505d8e36f8d54c27c739fcd0"
dependencies = [
"dyn-clone",
"ref-cast",
"schemars_derive 1.0.4",
"serde",
"serde_json",
]
@@ -3372,6 +3511,18 @@ dependencies = [
"syn 2.0.100",
]
[[package]]
name = "schemars_derive"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33d020396d1d138dc19f1165df7545479dcd58d93810dc5d646a16e55abefa80"
dependencies = [
"proc-macro2",
"quote",
"serde_derive_internals",
"syn 2.0.100",
]
[[package]]
name = "scopeguard"
version = "1.2.0"
@@ -3436,6 +3587,17 @@ dependencies = [
"serde",
]
[[package]]
name = "sha2"
version = "0.10.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283"
dependencies = [
"cfg-if",
"cpufeatures",
"digest",
]
[[package]]
name = "sharded-slab"
version = "0.1.7"
@@ -3474,6 +3636,9 @@ name = "smallvec"
version = "1.15.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03"
dependencies = [
"serde",
]
[[package]]
name = "socket2"
@@ -3499,14 +3664,16 @@ checksum = "d372029cb5195f9ab4e4b9aef550787dce78b124fcaee8d82519925defcd6f0d"
[[package]]
name = "sql_generation"
version = "0.1.5-pre.1"
version = "0.1.5-pre.2"
dependencies = [
"anarchist-readable-name-generator-lib 0.2.0",
"anyhow",
"garde",
"hex",
"itertools 0.14.0",
"rand 0.9.2",
"rand_chacha 0.9.0",
"schemars 1.0.4",
"serde",
"tracing",
"turso_core",
@@ -3528,6 +3695,12 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
[[package]]
name = "static_assertions"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
[[package]]
name = "str_stack"
version = "0.1.0"
@@ -4000,7 +4173,7 @@ dependencies = [
[[package]]
name = "turso"
version = "0.1.5-pre.1"
version = "0.1.5-pre.2"
dependencies = [
"rand 0.8.5",
"rand_chacha 0.3.1",
@@ -4012,7 +4185,7 @@ dependencies = [
[[package]]
name = "turso-java"
version = "0.1.5-pre.1"
version = "0.1.5-pre.2"
dependencies = [
"jni",
"thiserror 2.0.12",
@@ -4021,7 +4194,7 @@ dependencies = [
[[package]]
name = "turso_cli"
version = "0.1.5-pre.1"
version = "0.1.5-pre.2"
dependencies = [
"anyhow",
"cfg-if",
@@ -4038,7 +4211,7 @@ dependencies = [
"mimalloc",
"nu-ansi-term 0.50.1",
"rustyline",
"schemars",
"schemars 0.8.22",
"serde",
"serde_json",
"shlex",
@@ -4054,7 +4227,7 @@ dependencies = [
[[package]]
name = "turso_core"
version = "0.1.5-pre.1"
version = "0.1.5-pre.2"
dependencies = [
"aegis",
"aes",
@@ -4113,7 +4286,7 @@ dependencies = [
[[package]]
name = "turso_dart"
version = "0.1.5-pre.1"
version = "0.1.5-pre.2"
dependencies = [
"flutter_rust_bridge",
"turso_core",
@@ -4121,7 +4294,7 @@ dependencies = [
[[package]]
name = "turso_ext"
version = "0.1.5-pre.1"
version = "0.1.5-pre.2"
dependencies = [
"chrono",
"getrandom 0.3.2",
@@ -4130,7 +4303,7 @@ dependencies = [
[[package]]
name = "turso_ext_tests"
version = "0.1.5-pre.1"
version = "0.1.5-pre.2"
dependencies = [
"env_logger 0.11.7",
"lazy_static",
@@ -4141,7 +4314,7 @@ dependencies = [
[[package]]
name = "turso_macros"
version = "0.1.5-pre.1"
version = "0.1.5-pre.2"
dependencies = [
"proc-macro2",
"quote",
@@ -4150,7 +4323,7 @@ dependencies = [
[[package]]
name = "turso_node"
version = "0.1.5-pre.1"
version = "0.1.5-pre.2"
dependencies = [
"napi",
"napi-build",
@@ -4161,7 +4334,7 @@ dependencies = [
[[package]]
name = "turso_parser"
version = "0.1.5-pre.1"
version = "0.1.5-pre.2"
dependencies = [
"bitflags 2.9.0",
"criterion",
@@ -4178,7 +4351,7 @@ dependencies = [
[[package]]
name = "turso_sqlite3"
version = "0.1.5-pre.1"
version = "0.1.5-pre.2"
dependencies = [
"env_logger 0.11.7",
"libc",
@@ -4191,7 +4364,7 @@ dependencies = [
[[package]]
name = "turso_sqlite3_parser"
version = "0.1.5-pre.1"
version = "0.1.5-pre.2"
dependencies = [
"bitflags 2.9.0",
"cc",
@@ -4209,7 +4382,7 @@ dependencies = [
[[package]]
name = "turso_stress"
version = "0.1.5-pre.1"
version = "0.1.5-pre.2"
dependencies = [
"anarchist-readable-name-generator-lib 0.1.2",
"antithesis_sdk",
@@ -4225,7 +4398,7 @@ dependencies = [
[[package]]
name = "turso_sync_engine"
version = "0.1.5-pre.1"
version = "0.1.5-pre.2"
dependencies = [
"base64",
"bytes",
@@ -4251,7 +4424,7 @@ dependencies = [
[[package]]
name = "turso_sync_js"
version = "0.1.5-pre.1"
version = "0.1.5-pre.2"
dependencies = [
"genawaiter",
"http",
@@ -4270,6 +4443,12 @@ version = "1.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f"
[[package]]
name = "ucd-trie"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971"
[[package]]
name = "uncased"
version = "0.9.10"


@@ -33,29 +33,29 @@ members = [
exclude = ["perf/latency/limbo"]
[workspace.package]
version = "0.1.5-pre.1"
version = "0.1.5-pre.2"
authors = ["the Limbo authors"]
edition = "2021"
license = "MIT"
repository = "https://github.com/tursodatabase/turso"
[workspace.dependencies]
turso = { path = "bindings/rust", version = "0.1.5-pre.1" }
turso_node = { path = "bindings/javascript", version = "0.1.5-pre.1" }
limbo_completion = { path = "extensions/completion", version = "0.1.5-pre.1" }
turso_core = { path = "core", version = "0.1.5-pre.1" }
turso_sync_engine = { path = "sync/engine", version = "0.1.5-pre.1" }
limbo_crypto = { path = "extensions/crypto", version = "0.1.5-pre.1" }
limbo_csv = { path = "extensions/csv", version = "0.1.5-pre.1" }
turso_ext = { path = "extensions/core", version = "0.1.5-pre.1" }
turso_ext_tests = { path = "extensions/tests", version = "0.1.5-pre.1" }
limbo_ipaddr = { path = "extensions/ipaddr", version = "0.1.5-pre.1" }
turso_macros = { path = "macros", version = "0.1.5-pre.1" }
limbo_percentile = { path = "extensions/percentile", version = "0.1.5-pre.1" }
limbo_regexp = { path = "extensions/regexp", version = "0.1.5-pre.1" }
turso_sqlite3_parser = { path = "vendored/sqlite3-parser", version = "0.1.5-pre.1" }
limbo_uuid = { path = "extensions/uuid", version = "0.1.5-pre.1" }
turso_parser = { path = "parser" }
turso = { path = "bindings/rust", version = "0.1.5-pre.2" }
turso_node = { path = "bindings/javascript", version = "0.1.5-pre.2" }
limbo_completion = { path = "extensions/completion", version = "0.1.5-pre.2" }
turso_core = { path = "core", version = "0.1.5-pre.2" }
turso_sync_engine = { path = "sync/engine", version = "0.1.5-pre.2" }
limbo_crypto = { path = "extensions/crypto", version = "0.1.5-pre.2" }
limbo_csv = { path = "extensions/csv", version = "0.1.5-pre.2" }
turso_ext = { path = "extensions/core", version = "0.1.5-pre.2" }
turso_ext_tests = { path = "extensions/tests", version = "0.1.5-pre.2" }
limbo_ipaddr = { path = "extensions/ipaddr", version = "0.1.5-pre.2" }
turso_macros = { path = "macros", version = "0.1.5-pre.2" }
limbo_percentile = { path = "extensions/percentile", version = "0.1.5-pre.2" }
limbo_regexp = { path = "extensions/regexp", version = "0.1.5-pre.2" }
turso_sqlite3_parser = { path = "vendored/sqlite3-parser", version = "0.1.5-pre.2" }
limbo_uuid = { path = "extensions/uuid", version = "0.1.5-pre.2" }
turso_parser = { path = "parser", version = "0.1.5-pre.2" }
sql_generation = { path = "sql_generation" }
strum = { version = "0.26", features = ["derive"] }
strum_macros = "0.26"
@@ -67,6 +67,8 @@ rusqlite = { version = "0.37.0", features = ["bundled"] }
itertools = "0.14.0"
rand = "0.9.2"
tracing = "0.1.41"
schemars = "1.0.4"
garde = "0.22"
[profile.release]
debug = "line-tables-only"


@@ -13,7 +13,6 @@ FROM chef AS planner
COPY ./Cargo.lock ./Cargo.lock
COPY ./Cargo.toml ./Cargo.toml
COPY ./bindings/dart ./bindings/dart/
COPY ./bindings/go ./bindings/go/
COPY ./bindings/java ./bindings/java/
COPY ./bindings/javascript ./bindings/javascript/
COPY ./bindings/python ./bindings/python/
@@ -25,13 +24,14 @@ COPY ./macros ./macros/
COPY ./packages ./packages/
COPY ./parser ./parser/
COPY ./simulator ./simulator/
COPY ./sql_generation ./sql_generation
COPY ./sqlite3 ./sqlite3/
COPY ./tests ./tests/
COPY ./stress ./stress/
COPY ./sync ./sync/
COPY ./vendored ./vendored/
COPY ./testing/sqlite_test_ext ./testing/sqlite_test_ext/
COPY ./testing/unreliable-libc ./testing/unreliable-libc/
COPY ./tests ./tests/
COPY ./vendored ./vendored/
RUN cargo chef prepare --bin turso_stress --recipe-path recipe.json
#
@@ -51,30 +51,26 @@ COPY stress/libvoidstar.so /opt/antithesis/libvoidstar.so
COPY --from=planner /app/recipe.json recipe.json
RUN cargo chef cook --bin turso_stress --release --recipe-path recipe.json
COPY --from=planner /app/Cargo.toml ./Cargo.toml
COPY --from=planner /app/bindings/dart ./bindings/dart/
COPY --from=planner /app/bindings/java ./bindings/java/
COPY --from=planner /app/bindings/javascript ./bindings/javascript/
COPY --from=planner /app/bindings/python ./bindings/python/
COPY --from=planner /app/bindings/rust ./bindings/rust/
COPY --from=planner /app/cli ./cli/
COPY --from=planner /app/core ./core/
COPY --from=planner /app/extensions ./extensions/
COPY --from=planner /app/macros ./macros/
COPY --from=planner /app/simulator ./simulator/
COPY --from=planner /app/sqlite3 ./sqlite3/
COPY --from=planner /app/tests ./tests/
COPY --from=planner /app/stress ./stress/
COPY --from=planner /app/bindings/rust ./bindings/rust/
COPY --from=planner /app/bindings/dart ./bindings/dart/
COPY --from=planner /app/bindings/go ./bindings/go/
COPY --from=planner /app/bindings/javascript ./bindings/javascript/
COPY --from=planner /app/bindings/java ./bindings/java/
COPY --from=planner /app/bindings/python ./bindings/python/
COPY --from=planner /app/packages ./packages/
COPY --from=planner /app/core ./core/
COPY --from=planner /app/extensions ./extensions/
COPY --from=planner /app/macros ./macros/
COPY --from=planner /app/parser ./parser/
COPY --from=planner /app/simulator ./simulator/
COPY --from=planner /app/sql_generation ./sql_generation
COPY --from=planner /app/sqlite3 ./sqlite3/
COPY --from=planner /app/stress ./stress/
COPY --from=planner /app/sync ./sync/
COPY --from=planner /app/parser ./parser/
COPY --from=planner /app/vendored ./vendored/
COPY --from=planner /app/testing/sqlite_test_ext ./testing/sqlite_test_ext/
COPY --from=planner /app/testing/unreliable-libc ./testing/unreliable-libc/
COPY --from=planner /app/tests ./tests/
COPY --from=planner /app/vendored ./vendored/
RUN if [ "$antithesis" = "true" ]; then \
cp /opt/antithesis/libvoidstar.so /usr/lib/libvoidstar.so && \


@@ -55,7 +55,7 @@ uv-sync-test:
uv sync --all-extras --dev --package turso_test
.PHONY: uv-sync
test: limbo uv-sync-test test-compat test-vector test-sqlite3 test-shell test-memory test-write test-update test-constraint test-collate test-extensions test-mvcc test-matviews
test: limbo uv-sync-test test-compat test-alter-column test-vector test-sqlite3 test-shell test-memory test-write test-update test-constraint test-collate test-extensions test-mvcc test-matviews
.PHONY: test
test-extensions: limbo uv-sync-test
@@ -82,6 +82,10 @@ test-matviews:
RUST_LOG=$(RUST_LOG) SQLITE_EXEC=$(SQLITE_EXEC) ./testing/materialized_views.test
.PHONY: test-matviews
test-alter-column:
RUST_LOG=$(RUST_LOG) SQLITE_EXEC=$(SQLITE_EXEC) ./testing/alter_column.test
.PHONY: test-alter-column
reset-db:
./scripts/clone_test_db.sh
.PHONY: reset-db
@@ -201,3 +205,8 @@ endif
fi
.PHONY: merge-pr
sim-schema:
mkdir -p simulator/configs/custom
cargo run -p limbo_sim -- print-schema > simulator/configs/custom/profile-schema.json
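
The sim-schema target just captures the simulator's stdout, so print-schema can reduce to a few lines on the Rust side. A sketch under the assumption that the profile type derives schemars::JsonSchema (the actual subcommand wiring is not shown in this diff):

fn print_schema<T: schemars::JsonSchema>() {
    let schema = schemars::schema_for!(T);
    // stdout is redirected by the Makefile into profile-schema.json
    println!("{}", serde_json::to_string_pretty(&schema).expect("schema serializes"));
}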


@@ -6,6 +6,7 @@ import java.io.InputStream;
import java.io.Reader;
import java.math.BigDecimal;
import java.net.URL;
import java.nio.ByteBuffer;
import java.sql.Array;
import java.sql.Blob;
import java.sql.Clob;
@@ -121,17 +122,38 @@ public final class JDBC4PreparedStatement extends JDBC4Statement implements Prep
@Override
public void setDate(int parameterIndex, Date x) throws SQLException {
// TODO
requireNonNull(this.statement);
if (x == null) {
this.statement.bindNull(parameterIndex);
} else {
long time = x.getTime();
this.statement.bindBlob(
parameterIndex, ByteBuffer.allocate(Long.BYTES).putLong(time).array());
}
}
@Override
public void setTime(int parameterIndex, Time x) throws SQLException {
// TODO
requireNonNull(this.statement);
if (x == null) {
this.statement.bindNull(parameterIndex);
} else {
long time = x.getTime();
this.statement.bindBlob(
parameterIndex, ByteBuffer.allocate(Long.BYTES).putLong(time).array());
}
}
@Override
public void setTimestamp(int parameterIndex, Timestamp x) throws SQLException {
// TODO
requireNonNull(this.statement);
if (x == null) {
this.statement.bindNull(parameterIndex);
} else {
long time = x.getTime();
this.statement.bindBlob(
parameterIndex, ByteBuffer.allocate(Long.BYTES).putLong(time).array());
}
}
@Override
@@ -212,17 +234,18 @@ public final class JDBC4PreparedStatement extends JDBC4Statement implements Prep
@Override
public void setDate(int parameterIndex, Date x, Calendar cal) throws SQLException {
// TODO
setDate(parameterIndex, x);
}
@Override
public void setTime(int parameterIndex, Time x, Calendar cal) throws SQLException {
// TODO
setTime(parameterIndex, x);
}
@Override
public void setTimestamp(int parameterIndex, Timestamp x, Calendar cal) throws SQLException {
// TODO
// TODO: Apply calendar timezone conversion
setTimestamp(parameterIndex, x);
}
@Override
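
All three setters share one encoding: epoch milliseconds from getTime(), written as the 8-byte big-endian blob that ByteBuffer.putLong produces. Any consumer can reverse it; a sketch in Rust (the helper name is hypothetical, not part of the diff):

fn millis_from_jdbc_blob(blob: &[u8]) -> Option<i64> {
    // Long.BYTES == 8; ByteBuffer's default byte order is big-endian
    let bytes: [u8; 8] = blob.try_into().ok()?;
    Some(i64::from_be_bytes(bytes))
}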


@@ -5,6 +5,7 @@ import java.io.Reader;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.net.URL;
import java.nio.ByteBuffer;
import java.sql.Array;
import java.sql.Blob;
import java.sql.Clob;
@@ -146,21 +147,63 @@ public final class JDBC4ResultSet implements ResultSet, ResultSetMetaData {
}
@Override
@SkipNullableCheck
@Nullable
public Date getDate(int columnIndex) throws SQLException {
throw new UnsupportedOperationException("not implemented");
final Object result = resultSet.get(columnIndex);
if (result == null) {
return null;
}
return wrapTypeConversion(
() -> {
if (result instanceof byte[]) {
byte[] bytes = (byte[]) result;
if (bytes.length == Long.BYTES) {
long time = ByteBuffer.wrap(bytes).getLong();
return new Date(time);
}
}
throw new SQLException("Cannot convert value to Date: " + result.getClass());
});
}
@Override
@SkipNullableCheck
public Time getTime(int columnIndex) throws SQLException {
throw new UnsupportedOperationException("not implemented");
final Object result = resultSet.get(columnIndex);
if (result == null) {
return null;
}
return wrapTypeConversion(
() -> {
if (result instanceof byte[]) {
byte[] bytes = (byte[]) result;
if (bytes.length == Long.BYTES) {
long time = ByteBuffer.wrap(bytes).getLong();
return new Time(time);
}
}
throw new SQLException("Cannot convert value to Date: " + result.getClass());
});
}
@Override
@SkipNullableCheck
public Timestamp getTimestamp(int columnIndex) throws SQLException {
throw new UnsupportedOperationException("not implemented");
final Object result = resultSet.get(columnIndex);
if (result == null) {
return null;
}
return wrapTypeConversion(
() -> {
if (result instanceof byte[]) {
byte[] bytes = (byte[]) result;
if (bytes.length == Long.BYTES) {
long time = ByteBuffer.wrap(bytes).getLong();
return new Timestamp(time);
}
}
throw new SQLException("Cannot convert value to Timestamp: " + result.getClass());
});
}
@Override
@@ -238,9 +281,27 @@ public final class JDBC4ResultSet implements ResultSet, ResultSetMetaData {
}
@Override
@SkipNullableCheck
@Nullable
public Date getDate(String columnLabel) throws SQLException {
throw new UnsupportedOperationException("not implemented");
final Object result = resultSet.get(columnLabel);
if (result == null) {
return null;
}
return wrapTypeConversion(
() -> {
if (result instanceof byte[]) {
byte[] bytes = (byte[]) result;
if (bytes.length == Long.BYTES) {
long time = ByteBuffer.wrap(bytes).getLong();
return new Date(time);
}
}
// Try to parse as string if it's stored as TEXT
if (result instanceof String) {
return Date.valueOf((String) result);
}
throw new SQLException("Cannot convert value to Date: " + result.getClass());
});
}
@Override
@@ -252,7 +313,7 @@ public final class JDBC4ResultSet implements ResultSet, ResultSetMetaData {
@Override
@SkipNullableCheck
public Timestamp getTimestamp(String columnLabel) throws SQLException {
throw new UnsupportedOperationException("not implemented");
return getTimestamp(findColumn(columnLabel));
}
@Override
@@ -738,39 +799,45 @@ public final class JDBC4ResultSet implements ResultSet, ResultSetMetaData {
}
@Override
@SkipNullableCheck
@Nullable
public Date getDate(int columnIndex, Calendar cal) throws SQLException {
throw new UnsupportedOperationException("not implemented");
// TODO: Properly handle timezone conversion with Calendar
return getDate(columnIndex);
}
@Override
@SkipNullableCheck
@Nullable
public Date getDate(String columnLabel, Calendar cal) throws SQLException {
throw new UnsupportedOperationException("not implemented");
// TODO: Properly handle timezone conversion with Calendar
return getDate(columnLabel);
}
@Override
@SkipNullableCheck
public Time getTime(int columnIndex, Calendar cal) throws SQLException {
throw new UnsupportedOperationException("not implemented");
// TODO: Properly handle timezone conversion with Calendar
return getTime(columnIndex);
}
@Override
@SkipNullableCheck
public Time getTime(String columnLabel, Calendar cal) throws SQLException {
throw new UnsupportedOperationException("not implemented");
// TODO: Properly handle timezone conversion with Calendar
return getTime(columnLabel);
}
@Override
@SkipNullableCheck
public Timestamp getTimestamp(int columnIndex, Calendar cal) throws SQLException {
throw new UnsupportedOperationException("not implemented");
// TODO: Apply calendar timezone conversion
return getTimestamp(columnIndex);
}
@Override
@SkipNullableCheck
public Timestamp getTimestamp(String columnLabel, Calendar cal) throws SQLException {
throw new UnsupportedOperationException("not implemented");
// TODO: Apply calendar timezone conversion
return getTimestamp(findColumn(columnLabel));
}
@Override


@@ -6,9 +6,12 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.math.BigDecimal;
import java.sql.Date;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Time;
import java.sql.Timestamp;
import java.util.Properties;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -225,6 +228,84 @@ class JDBC4PreparedStatementTest {
assertArrayEquals(new byte[] {7, 8, 9}, rs.getBytes(1));
}
@Test
void testSetDate() throws SQLException {
connection.prepareStatement("CREATE TABLE test (col BLOB)").execute();
PreparedStatement stmt =
connection.prepareStatement("INSERT INTO test (col) VALUES (?), (?), (?)");
Date date1 = new Date(1000000000000L);
Date date2 = new Date(1500000000000L);
Date date3 = new Date(2000000000000L);
stmt.setDate(1, date1);
stmt.setDate(2, date2);
stmt.setDate(3, date3);
stmt.execute();
PreparedStatement stmt2 = connection.prepareStatement("SELECT * FROM test;");
JDBC4ResultSet rs = (JDBC4ResultSet) stmt2.executeQuery();
assertTrue(rs.next());
assertEquals(date1, rs.getDate(1));
assertTrue(rs.next());
assertEquals(date2, rs.getDate(1));
assertTrue(rs.next());
assertEquals(date3, rs.getDate(1));
}
@Test
void testSetTime() throws SQLException {
connection.prepareStatement("CREATE TABLE test (col BLOB)").execute();
PreparedStatement stmt =
connection.prepareStatement("INSERT INTO test (col) VALUES (?), (?), (?)");
Time time1 = new Time(1000000000000L);
Time time2 = new Time(1500000000000L);
Time time3 = new Time(2000000000000L);
stmt.setTime(1, time1);
stmt.setTime(2, time2);
stmt.setTime(3, time3);
stmt.execute();
PreparedStatement stmt2 = connection.prepareStatement("SELECT * FROM test;");
JDBC4ResultSet rs = (JDBC4ResultSet) stmt2.executeQuery();
assertTrue(rs.next());
assertEquals(time1, rs.getTime(1));
assertTrue(rs.next());
assertEquals(time2, rs.getTime(1));
assertTrue(rs.next());
assertEquals(time3, rs.getTime(1));
}
@Test
void testSetTimestamp() throws SQLException {
connection.prepareStatement("CREATE TABLE test (col BLOB)").execute();
PreparedStatement stmt =
connection.prepareStatement("INSERT INTO test (col) VALUES (?), (?), (?)");
Timestamp timestamp1 = new Timestamp(1000000000000L);
Timestamp timestamp2 = new Timestamp(1500000000000L);
Timestamp timestamp3 = new Timestamp(2000000000000L);
stmt.setTimestamp(1, timestamp1);
stmt.setTimestamp(2, timestamp2);
stmt.setTimestamp(3, timestamp3);
stmt.execute();
PreparedStatement stmt2 = connection.prepareStatement("SELECT * FROM test;");
JDBC4ResultSet rs = (JDBC4ResultSet) stmt2.executeQuery();
assertTrue(rs.next());
assertEquals(timestamp1, rs.getTimestamp(1));
assertTrue(rs.next());
assertEquals(timestamp2, rs.getTimestamp(1));
assertTrue(rs.next());
assertEquals(timestamp3, rs.getTimestamp(1));
}
@Test
void testInsertMultipleTypes() throws SQLException {
connection


@@ -1,12 +1,12 @@
{
"name": "@tursodatabase/database",
"version": "0.1.5-pre.1",
"version": "0.1.5-pre.2",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "@tursodatabase/database",
"version": "0.1.5-pre.1",
"version": "0.1.5-pre.2",
"license": "MIT",
"devDependencies": {
"@napi-rs/cli": "^3.0.4",


@@ -1,6 +1,6 @@
{
"name": "@tursodatabase/database",
"version": "0.1.5-pre.1",
"version": "0.1.5-pre.2",
"repository": {
"type": "git",
"url": "https://github.com/tursodatabase/turso"


@@ -578,7 +578,7 @@ impl turso_core::DatabaseStorage for DatabaseFile {
if !(512..=65536).contains(&size) || size & (size - 1) != 0 {
return Err(turso_core::LimboError::NotADB);
}
let pos = (page_idx - 1) * size;
let pos = (page_idx as u64 - 1) * size as u64;
self.file.pread(pos, c)
}
@@ -590,7 +590,7 @@ impl turso_core::DatabaseStorage for DatabaseFile {
c: turso_core::Completion,
) -> turso_core::Result<turso_core::Completion> {
let size = buffer.len();
let pos = (page_idx - 1) * size;
let pos = (page_idx as u64 - 1) * size as u64;
self.file.pwrite(pos, buffer, c)
}
@@ -602,7 +602,7 @@ impl turso_core::DatabaseStorage for DatabaseFile {
_io_ctx: &turso_core::IOContext,
c: turso_core::Completion,
) -> turso_core::Result<turso_core::Completion> {
let pos = first_page_idx.saturating_sub(1) * page_size;
let pos = first_page_idx.saturating_sub(1) as u64 * page_size as u64;
let c = self.file.pwritev(pos, buffers, c)?;
Ok(c)
}
@@ -620,7 +620,7 @@ impl turso_core::DatabaseStorage for DatabaseFile {
len: usize,
c: turso_core::Completion,
) -> turso_core::Result<turso_core::Completion> {
let c = self.file.truncate(len, c)?;
let c = self.file.truncate(len as u64, c)?;
Ok(c)
}
}
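
The widening matters on 32-bit targets, where usize arithmetic would wrap for any page offset past the 4 GiB mark. A sketch of the same computation with an explicit overflow check (core's DatabaseFile further down takes this checked_mul route):

fn page_offset(page_idx: usize, page_size: usize) -> Option<u64> {
    // pages are 1-indexed; widen before multiplying so the math stays in u64
    (page_idx as u64).checked_sub(1)?.checked_mul(page_size as u64)
}

fn main() {
    // page 1_048_577 at 4 KiB pages sits exactly at 4 GiB -- past u32::MAX
    assert_eq!(page_offset(1_048_577, 4096), Some(4_294_967_296));
}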


@@ -1,5 +1,7 @@
use thiserror::Error;
use crate::storage::page_cache::CacheError;
#[derive(Debug, Clone, Error, miette::Diagnostic)]
pub enum LimboError {
#[error("Corrupt database: {0}")]
@@ -8,8 +10,8 @@ pub enum LimboError {
NotADB,
#[error("Internal error: {0}")]
InternalError(String),
#[error("Page cache is full")]
CacheFull,
#[error(transparent)]
CacheError(#[from] CacheError),
#[error("Database is full: {0}")]
DatabaseFull(String),
#[error("Parse error: {0}")]


@@ -583,6 +583,7 @@ impl Display for MathFunc {
#[derive(Debug)]
pub enum AlterTableFunc {
RenameTable,
AlterColumn,
RenameColumn,
}
@@ -591,6 +592,7 @@ impl Display for AlterTableFunc {
match self {
AlterTableFunc::RenameTable => write!(f, "limbo_rename_table"),
AlterTableFunc::RenameColumn => write!(f, "limbo_rename_column"),
AlterTableFunc::AlterColumn => write!(f, "limbo_alter_column"),
}
}
}


@@ -1,8 +1,40 @@
use core::f64;
use crate::types::Value;
use crate::vdbe::Register;
use crate::LimboError;
// TODO: Support %!.3s %i, %x, %X, %o, %e, %E, %c. flags: - + 0 ! ,
fn get_exponential_formatted_str(number: &f64, uppercase: bool) -> crate::Result<String> {
let pre_formatted = format!("{number:.6e}");
let mut parts = pre_formatted.split("e");
let maybe_base = parts.next();
let maybe_exponent = parts.next();
let mut result = String::new();
match (maybe_base, maybe_exponent) {
(Some(base), Some(exponent)) => {
result.push_str(base);
result.push_str(if uppercase { "E" } else { "e" });
match exponent.parse::<i32>() {
Ok(exponent_number) => {
let exponent_fmt = format!("{exponent_number:+03}");
result.push_str(&exponent_fmt);
Ok(result)
}
Err(_) => Err(LimboError::InternalError(
"unable to parse exponential expression's exponent".into(),
)),
}
}
(_, _) => Err(LimboError::InternalError(
"unable to parse exponential expression".into(),
)),
}
}
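// A worked example of the exponent formatting above (pure std behavior):
// `{:+03}` forces a sign and zero-pads the exponent to two digits, which is
// what makes the output match SQLite's %e style.
#[cfg(test)]
mod exponent_format_demo {
    #[test]
    fn exponent_sign_and_zero_padding() {
        assert_eq!(format!("{:+03}", 7), "+07"); // 23000000.0 -> "2.300000e+07"
        assert_eq!(format!("{:+03}", -4), "-04"); // 0.0003235 -> "3.235000e-04"
        assert_eq!(format!("{:+03}", 120), "+120"); // wider exponents keep all digits
    }
}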
// TODO: Support %!.3s. flags: - + 0 ! ,
#[inline(always)]
pub fn exec_printf(values: &[Register]) -> crate::Result<Value> {
if values.is_empty() {
@@ -40,6 +72,20 @@ pub fn exec_printf(values: &[Register]) -> crate::Result<Value> {
}
args_index += 1;
}
Some('u') => {
if args_index >= values.len() {
return Err(LimboError::InvalidArgument("not enough arguments".into()));
}
let value = &values[args_index].get_value();
match value {
Value::Integer(_) => {
let converted_value = value.as_uint();
result.push_str(&format!("{converted_value}"))
}
_ => result.push('0'),
}
args_index += 1;
}
Some('s') => {
if args_index >= values.len() {
return Err(LimboError::InvalidArgument("not enough arguments".into()));
@@ -63,6 +109,119 @@ pub fn exec_printf(values: &[Register]) -> crate::Result<Value> {
}
args_index += 1;
}
Some('e') => {
if args_index >= values.len() {
return Err(LimboError::InvalidArgument("not enough arguments".into()));
}
let value = &values[args_index].get_value();
match value {
Value::Float(f) => match get_exponential_formatted_str(f, false) {
Ok(str) => result.push_str(&str),
Err(e) => return Err(e),
},
Value::Integer(i) => {
let f = *i as f64;
match get_exponential_formatted_str(&f, false) {
Ok(str) => result.push_str(&str),
Err(e) => return Err(e),
}
}
Value::Text(s) => {
let number: f64 = s
.as_str()
.trim_start()
.trim_end_matches(|c: char| !c.is_numeric())
.parse()
.unwrap_or(0.0);
match get_exponential_formatted_str(&number, false) {
Ok(str) => result.push_str(&str),
Err(e) => return Err(e),
};
}
_ => result.push_str("0.000000e+00"),
}
args_index += 1;
}
Some('E') => {
if args_index >= values.len() {
return Err(LimboError::InvalidArgument("not enough arguments".into()));
}
let value = &values[args_index].get_value();
match value {
Value::Float(f) => match get_exponential_formatted_str(f, true) {
Ok(str) => result.push_str(&str),
Err(e) => return Err(e),
},
Value::Integer(i) => {
let f = *i as f64;
match get_exponential_formatted_str(&f, true) {
Ok(str) => result.push_str(&str),
Err(e) => return Err(e),
}
}
Value::Text(s) => {
let number: f64 = s
.as_str()
.trim_start()
.trim_end_matches(|c: char| !c.is_numeric())
.parse()
.unwrap_or(0.0);
match get_exponential_formatted_str(&number, true) {
Ok(str) => result.push_str(&str),
Err(e) => return Err(e),
};
}
_ => result.push_str("0.000000e+00"),
}
args_index += 1;
}
Some('c') => {
if args_index >= values.len() {
return Err(LimboError::InvalidArgument("not enough arguments".into()));
}
let value = &values[args_index].get_value();
let value_str: String = format!("{value}");
if !value_str.is_empty() {
result.push_str(&value_str[0..1]);
}
args_index += 1;
}
Some('x') => {
if args_index >= values.len() {
return Err(LimboError::InvalidArgument("not enough arguments".into()));
}
let value = &values[args_index].get_value();
match value {
Value::Float(f) => result.push_str(&format!("{:x}", *f as i64)),
Value::Integer(i) => result.push_str(&format!("{i:x}")),
_ => result.push('0'),
}
args_index += 1;
}
Some('X') => {
if args_index >= values.len() {
return Err(LimboError::InvalidArgument("not enough arguments".into()));
}
let value = &values[args_index].get_value();
match value {
Value::Float(f) => result.push_str(&format!("{:X}", *f as i64)),
Value::Integer(i) => result.push_str(&format!("{i:X}")),
_ => result.push('0'),
}
args_index += 1;
}
Some('o') => {
if args_index >= values.len() {
return Err(LimboError::InvalidArgument("not enough arguments".into()));
}
let value = &values[args_index].get_value();
match value {
Value::Float(f) => result.push_str(&format!("{:o}", *f as i64)),
Value::Integer(i) => result.push_str(&format!("{i:o}")),
_ => result.push('0'),
}
args_index += 1;
}
None => {
return Err(LimboError::InvalidArgument(
"incomplete format specifier".into(),
@@ -159,6 +318,29 @@ mod tests {
}
}
#[test]
fn test_printf_unsigned_integer_formatting() {
let test_cases = vec![
// Basic
(vec![text("Number: %u"), integer(42)], text("Number: 42")),
// Multiple numbers
(
vec![text("%u + %u = %u"), integer(2), integer(3), integer(5)],
text("2 + 3 = 5"),
),
// Negative number should be represented as its uint representation
(
vec![text("Negative: %u"), integer(-1)],
text("Negative: 18446744073709551615"),
),
// Non-numeric value defaults to 0
(vec![text("NaN: %u"), text("not a number")], text("NaN: 0")),
];
for (input, output) in test_cases {
assert_eq!(exec_printf(&input).unwrap(), *output.get_value())
}
}
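// The %u expectation above is plain two's-complement reinterpretation:
// casting i64 to u64 maps -1 to u64::MAX, which is the arithmetic as_uint
// is exercising (a standalone check of the same fact):
#[cfg(test)]
mod unsigned_reinterpretation_demo {
    #[test]
    fn negative_one_maps_to_u64_max() {
        assert_eq!((-1i64) as u64, 18_446_744_073_709_551_615);
        assert_eq!((-1i64) as u64, u64::MAX);
    }
}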
#[test]
fn test_printf_float_formatting() {
let test_cases = vec![
@@ -194,6 +376,178 @@ mod tests {
}
}
#[test]
fn test_printf_character_formatting() {
let test_cases = vec![
// Simple character
(vec![text("character: %c"), text("a")], text("character: a")),
// Character with string
(
vec![text("character: %c"), text("this is a test")],
text("character: t"),
),
// Character with empty
(vec![text("character: %c"), text("")], text("character: ")),
// Character with integer
(
vec![text("character: %c"), integer(123)],
text("character: 1"),
),
// Character with float
(
vec![text("character: %c"), float(42.5)],
text("character: 4"),
),
];
for (input, expected) in test_cases {
assert_eq!(exec_printf(&input).unwrap(), *expected.get_value());
}
}
#[test]
fn test_printf_exponential_formatting() {
let test_cases = vec![
// Simple number
(
vec![text("Exp: %e"), float(23000000.0)],
text("Exp: 2.300000e+07"),
),
// Negative number
(
vec![text("Exp: %e"), float(-23000000.0)],
text("Exp: -2.300000e+07"),
),
// Non integer float
(
vec![text("Exp: %e"), float(250.375)],
text("Exp: 2.503750e+02"),
),
// Positive, but smaller than zero
(
vec![text("Exp: %e"), float(0.0003235)],
text("Exp: 3.235000e-04"),
),
// Zero
(vec![text("Exp: %e"), float(0.0)], text("Exp: 0.000000e+00")),
// Uppercase "E"
(
vec![text("Exp: %E"), float(0.0003235)],
text("Exp: 3.235000E-04"),
),
// String with integer number
(
vec![text("Exp: %e"), text("123")],
text("Exp: 1.230000e+02"),
),
// String with floating point number
(
vec![text("Exp: %e"), text("123.45")],
text("Exp: 1.234500e+02"),
),
// String with number with leftmost zeroes
(
vec![text("Exp: %e"), text("00123")],
text("Exp: 1.230000e+02"),
),
// String with text
(
vec![text("Exp: %e"), text("test")],
text("Exp: 0.000000e+00"),
),
// String starting with number, but with text on the end
(
vec![text("Exp: %e"), text("123ab")],
text("Exp: 1.230000e+02"),
),
// String starting with text, but with number on the end
(
vec![text("Exp: %e"), text("ab123")],
text("Exp: 0.000000e+00"),
),
// String with exponential representation
(
vec![text("Exp: %e"), text("1.230000e+02")],
text("Exp: 1.230000e+02"),
),
];
for (input, expected) in test_cases {
assert_eq!(exec_printf(&input).unwrap(), *expected.get_value());
}
}
#[test]
fn test_printf_hexadecimal_formatting() {
let test_cases = vec![
// Simple number
(vec![text("hex: %x"), integer(4)], text("hex: 4")),
// Bigger Number
(
vec![text("hex: %x"), integer(15565303546)],
text("hex: 39fc3aefa"),
),
// Uppercase letters
(
vec![text("hex: %X"), integer(15565303546)],
text("hex: 39FC3AEFA"),
),
// Negative
(
vec![text("hex: %x"), integer(-15565303546)],
text("hex: fffffffc603c5106"),
),
// Float
(vec![text("hex: %x"), float(42.5)], text("hex: 2a")),
// Negative Float
(
vec![text("hex: %x"), float(-42.5)],
text("hex: ffffffffffffffd6"),
),
// Text
(vec![text("hex: %x"), text("42")], text("hex: 0")),
// Empty Text
(vec![text("hex: %x"), text("")], text("hex: 0")),
];
for (input, expected) in test_cases {
assert_eq!(exec_printf(&input).unwrap(), *expected.get_value());
}
}
#[test]
fn test_printf_octal_formatting() {
let test_cases = vec![
// Simple number
(vec![text("octal: %o"), integer(4)], text("octal: 4")),
// Bigger Number
(
vec![text("octal: %o"), integer(15565303546)],
text("octal: 163760727372"),
),
// Negative
(
vec![text("octal: %o"), integer(-15565303546)],
text("octal: 1777777777614017050406"),
),
// Float
(vec![text("octal: %o"), float(42.5)], text("octal: 52")),
// Negative Float
(
vec![text("octal: %o"), float(-42.5)],
text("octal: 1777777777777777777726"),
),
// Text
(vec![text("octal: %o"), text("42")], text("octal: 0")),
// Empty Text
(vec![text("octal: %o"), text("")], text("octal: 0")),
];
for (input, expected) in test_cases {
assert_eq!(exec_printf(&input).unwrap(), *expected.get_value());
}
}
#[test]
fn test_printf_mixed_formatting() {
let test_cases = vec![


@@ -68,9 +68,9 @@ impl File for GenericFile {
}
#[instrument(skip(self, c), level = Level::TRACE)]
fn pread(&self, pos: usize, c: Completion) -> Result<Completion> {
fn pread(&self, pos: u64, c: Completion) -> Result<Completion> {
let mut file = self.file.write();
file.seek(std::io::SeekFrom::Start(pos as u64))?;
file.seek(std::io::SeekFrom::Start(pos))?;
let nr = {
let r = c.as_read();
let buf = r.buf();
@@ -83,9 +83,9 @@ impl File for GenericFile {
}
#[instrument(skip(self, c, buffer), level = Level::TRACE)]
fn pwrite(&self, pos: usize, buffer: Arc<crate::Buffer>, c: Completion) -> Result<Completion> {
fn pwrite(&self, pos: u64, buffer: Arc<crate::Buffer>, c: Completion) -> Result<Completion> {
let mut file = self.file.write();
file.seek(std::io::SeekFrom::Start(pos as u64))?;
file.seek(std::io::SeekFrom::Start(pos))?;
let buf = buffer.as_slice();
file.write_all(buf)?;
c.complete(buffer.len() as i32);
@@ -101,9 +101,9 @@ impl File for GenericFile {
}
#[instrument(err, skip_all, level = Level::TRACE)]
fn truncate(&self, len: usize, c: Completion) -> Result<Completion> {
fn truncate(&self, len: u64, c: Completion) -> Result<Completion> {
let file = self.file.write();
file.set_len(len as u64)?;
file.set_len(len)?;
c.complete(0);
Ok(c)
}


@@ -182,7 +182,7 @@ struct WritevState {
/// File descriptor/id of the file we are writing to
file_id: Fd,
/// absolute file offset for next submit
file_pos: usize,
file_pos: u64,
/// current buffer index in `bufs`
current_buffer_idx: usize,
/// intra-buffer offset
@@ -198,7 +198,7 @@ struct WritevState {
}
impl WritevState {
fn new(file: &UringFile, pos: usize, bufs: Vec<Arc<crate::Buffer>>) -> Self {
fn new(file: &UringFile, pos: u64, bufs: Vec<Arc<crate::Buffer>>) -> Self {
let file_id = file
.id()
.map(Fd::Fixed)
@@ -223,23 +223,23 @@ impl WritevState {
/// Advance (idx, off, pos) after written bytes
#[inline(always)]
fn advance(&mut self, written: usize) {
fn advance(&mut self, written: u64) {
let mut remaining = written;
while remaining > 0 {
let current_buf_len = self.bufs[self.current_buffer_idx].len();
let left = current_buf_len - self.current_buffer_offset;
if remaining < left {
self.current_buffer_offset += remaining;
if remaining < left as u64 {
self.current_buffer_offset += remaining as usize;
self.file_pos += remaining;
remaining = 0;
} else {
remaining -= left;
self.file_pos += left;
remaining -= left as u64;
self.file_pos += left as u64;
self.current_buffer_idx += 1;
self.current_buffer_offset = 0;
}
}
self.total_written += written;
self.total_written += written as usize;
}
#[inline(always)]
@@ -400,7 +400,7 @@ impl WrappedIOUring {
iov_allocation[0].iov_len as u32,
id as u16,
)
.offset(st.file_pos as u64)
.offset(st.file_pos)
.build()
.user_data(key)
} else {
@@ -409,7 +409,7 @@ impl WrappedIOUring {
iov_allocation[0].iov_base as *const u8,
iov_allocation[0].iov_len as u32,
)
.offset(st.file_pos as u64)
.offset(st.file_pos)
.build()
.user_data(key)
}
@@ -425,7 +425,7 @@ impl WrappedIOUring {
let entry = with_fd!(st.file_id, |fd| {
io_uring::opcode::Writev::new(fd, ptr, iov_count as u32)
.offset(st.file_pos as u64)
.offset(st.file_pos)
.build()
.user_data(key)
});
@@ -443,8 +443,8 @@ impl WrappedIOUring {
return;
}
let written = result as usize;
state.advance(written);
let written = result;
state.advance(written as u64);
match state.remaining() {
0 => {
tracing::info!(
@@ -643,7 +643,7 @@ impl File for UringFile {
Ok(())
}
fn pread(&self, pos: usize, c: Completion) -> Result<Completion> {
fn pread(&self, pos: u64, c: Completion) -> Result<Completion> {
let r = c.as_read();
let mut io = self.io.borrow_mut();
let read_e = {
@@ -663,14 +663,14 @@ impl File for UringFile {
io.debug_check_fixed(idx, ptr, len);
}
io_uring::opcode::ReadFixed::new(fd, ptr, len as u32, idx as u16)
.offset(pos as u64)
.offset(pos)
.build()
.user_data(get_key(c.clone()))
} else {
trace!("pread(pos = {}, length = {})", pos, len);
// Use Read opcode if fixed buffer is not available
io_uring::opcode::Read::new(fd, buf.as_mut_ptr(), len as u32)
.offset(pos as u64)
.offset(pos)
.build()
.user_data(get_key(c.clone()))
}
@@ -680,7 +680,7 @@ impl File for UringFile {
Ok(c)
}
fn pwrite(&self, pos: usize, buffer: Arc<crate::Buffer>, c: Completion) -> Result<Completion> {
fn pwrite(&self, pos: u64, buffer: Arc<crate::Buffer>, c: Completion) -> Result<Completion> {
let mut io = self.io.borrow_mut();
let write = {
let ptr = buffer.as_ptr();
@@ -698,13 +698,13 @@ impl File for UringFile {
io.debug_check_fixed(idx, ptr, len);
}
io_uring::opcode::WriteFixed::new(fd, ptr, len as u32, idx as u16)
.offset(pos as u64)
.offset(pos)
.build()
.user_data(get_key(c.clone()))
} else {
trace!("pwrite(pos = {}, length = {})", pos, buffer.len());
io_uring::opcode::Write::new(fd, ptr, len as u32)
.offset(pos as u64)
.offset(pos)
.build()
.user_data(get_key(c.clone()))
}
@@ -728,7 +728,7 @@ impl File for UringFile {
fn pwritev(
&self,
pos: usize,
pos: u64,
bufs: Vec<Arc<crate::Buffer>>,
c: Completion,
) -> Result<Completion> {
@@ -748,10 +748,10 @@ impl File for UringFile {
Ok(self.file.metadata()?.len())
}
fn truncate(&self, len: usize, c: Completion) -> Result<Completion> {
fn truncate(&self, len: u64, c: Completion) -> Result<Completion> {
let mut io = self.io.borrow_mut();
let truncate = with_fd!(self, |fd| {
io_uring::opcode::Ftruncate::new(fd, len as u64)
io_uring::opcode::Ftruncate::new(fd, len)
.build()
.user_data(get_key(c.clone()))
});
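
advance has to survive short writes that end mid-buffer, which is why it tracks a buffer index, an intra-buffer offset, and the absolute file position together. A simplified standalone model of the same bookkeeping (plain slices instead of the real WritevState):

fn advance(lens: &[usize], mut idx: usize, mut off: usize, mut pos: u64, written: u64) -> (usize, usize, u64) {
    let mut remaining = written;
    while remaining > 0 {
        let left = (lens[idx] - off) as u64; // bytes left in the current buffer
        if remaining < left {
            off += remaining as usize;
            pos += remaining;
            remaining = 0;
        } else {
            remaining -= left;
            pos += left;
            idx += 1; // move to the next buffer, starting at its beginning
            off = 0;
        }
    }
    (idx, off, pos)
}

fn main() {
    // two 4 KiB buffers, kernel reports 5000 bytes written from pos 0:
    // resume in buffer 1 at offset 904, file position 5000
    assert_eq!(advance(&[4096, 4096], 0, 0, 0, 5000), (1, 904, 5000));
}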


@@ -69,17 +69,12 @@ impl IO for MemoryIO {
files.remove(path);
Ok(())
}
fn run_once(&self) -> Result<()> {
// nop
Ok(())
}
}
pub struct MemoryFile {
path: String,
pages: UnsafeCell<BTreeMap<usize, MemPage>>,
size: Cell<usize>,
size: Cell<u64>,
}
unsafe impl Send for MemoryFile {}
unsafe impl Sync for MemoryFile {}
@@ -92,10 +87,10 @@ impl File for MemoryFile {
Ok(())
}
fn pread(&self, pos: usize, c: Completion) -> Result<Completion> {
fn pread(&self, pos: u64, c: Completion) -> Result<Completion> {
tracing::debug!("pread(path={}): pos={}", self.path, pos);
let r = c.as_read();
let buf_len = r.buf().len();
let buf_len = r.buf().len() as u64;
if buf_len == 0 {
c.complete(0);
return Ok(c);
@@ -110,8 +105,8 @@ impl File for MemoryFile {
let read_len = buf_len.min(file_size - pos);
{
let read_buf = r.buf();
let mut offset = pos;
let mut remaining = read_len;
let mut offset = pos as usize;
let mut remaining = read_len as usize;
let mut buf_offset = 0;
while remaining > 0 {
@@ -134,7 +129,7 @@ impl File for MemoryFile {
Ok(c)
}
fn pwrite(&self, pos: usize, buffer: Arc<Buffer>, c: Completion) -> Result<Completion> {
fn pwrite(&self, pos: u64, buffer: Arc<Buffer>, c: Completion) -> Result<Completion> {
tracing::debug!(
"pwrite(path={}): pos={}, size={}",
self.path,
@@ -147,7 +142,7 @@ impl File for MemoryFile {
return Ok(c);
}
let mut offset = pos;
let mut offset = pos as usize;
let mut remaining = buf_len;
let mut buf_offset = 0;
let data = &buffer.as_slice();
@@ -158,7 +153,7 @@ impl File for MemoryFile {
let bytes_to_write = remaining.min(PAGE_SIZE - page_offset);
{
let page = self.get_or_allocate_page(page_no);
let page = self.get_or_allocate_page(page_no as u64);
page[page_offset..page_offset + bytes_to_write]
.copy_from_slice(&data[buf_offset..buf_offset + bytes_to_write]);
}
@@ -169,7 +164,7 @@ impl File for MemoryFile {
}
self.size
.set(core::cmp::max(pos + buf_len, self.size.get()));
.set(core::cmp::max(pos + buf_len as u64, self.size.get()));
c.complete(buf_len as i32);
Ok(c)
@@ -182,13 +177,13 @@ impl File for MemoryFile {
Ok(c)
}
fn truncate(&self, len: usize, c: Completion) -> Result<Completion> {
fn truncate(&self, len: u64, c: Completion) -> Result<Completion> {
tracing::debug!("truncate(path={}): len={}", self.path, len);
if len < self.size.get() {
// Truncate pages
unsafe {
let pages = &mut *self.pages.get();
pages.retain(|&k, _| k * PAGE_SIZE < len);
pages.retain(|&k, _| k * PAGE_SIZE < len as usize);
}
}
self.size.set(len);
@@ -196,14 +191,14 @@ impl File for MemoryFile {
Ok(c)
}
fn pwritev(&self, pos: usize, buffers: Vec<Arc<Buffer>>, c: Completion) -> Result<Completion> {
fn pwritev(&self, pos: u64, buffers: Vec<Arc<Buffer>>, c: Completion) -> Result<Completion> {
tracing::debug!(
"pwritev(path={}): pos={}, buffers={:?}",
self.path,
pos,
buffers.iter().map(|x| x.len()).collect::<Vec<_>>()
);
let mut offset = pos;
let mut offset = pos as usize;
let mut total_written = 0;
for buffer in buffers {
@@ -222,7 +217,7 @@ impl File for MemoryFile {
let bytes_to_write = remaining.min(PAGE_SIZE - page_offset);
{
let page = self.get_or_allocate_page(page_no);
let page = self.get_or_allocate_page(page_no as u64);
page[page_offset..page_offset + bytes_to_write]
.copy_from_slice(&data[buf_offset..buf_offset + bytes_to_write]);
}
@@ -235,23 +230,23 @@ impl File for MemoryFile {
}
c.complete(total_written as i32);
self.size
.set(core::cmp::max(pos + total_written, self.size.get()));
.set(core::cmp::max(pos + total_written as u64, self.size.get()));
Ok(c)
}
fn size(&self) -> Result<u64> {
tracing::debug!("size(path={}): {}", self.path, self.size.get());
Ok(self.size.get() as u64)
Ok(self.size.get())
}
}
impl MemoryFile {
#[allow(clippy::mut_from_ref)]
fn get_or_allocate_page(&self, page_no: usize) -> &mut MemPage {
fn get_or_allocate_page(&self, page_no: u64) -> &mut MemPage {
unsafe {
let pages = &mut *self.pages.get();
pages
.entry(page_no)
.entry(page_no as usize)
.or_insert_with(|| Box::new([0; PAGE_SIZE]))
}
}


@@ -12,10 +12,10 @@ use std::{fmt::Debug, pin::Pin};
pub trait File: Send + Sync {
fn lock_file(&self, exclusive: bool) -> Result<()>;
fn unlock_file(&self) -> Result<()>;
fn pread(&self, pos: usize, c: Completion) -> Result<Completion>;
fn pwrite(&self, pos: usize, buffer: Arc<Buffer>, c: Completion) -> Result<Completion>;
fn pread(&self, pos: u64, c: Completion) -> Result<Completion>;
fn pwrite(&self, pos: u64, buffer: Arc<Buffer>, c: Completion) -> Result<Completion>;
fn sync(&self, c: Completion) -> Result<Completion>;
fn pwritev(&self, pos: usize, buffers: Vec<Arc<Buffer>>, c: Completion) -> Result<Completion> {
fn pwritev(&self, pos: u64, buffers: Vec<Arc<Buffer>>, c: Completion) -> Result<Completion> {
use std::sync::atomic::{AtomicUsize, Ordering};
if buffers.is_empty() {
c.complete(0);
@@ -56,12 +56,12 @@ pub trait File: Send + Sync {
c.abort();
return Err(e);
}
pos += len;
pos += len as u64;
}
Ok(c)
}
fn size(&self) -> Result<u64>;
fn truncate(&self, len: usize, c: Completion) -> Result<Completion>;
fn truncate(&self, len: u64, c: Completion) -> Result<Completion>;
}
#[derive(Debug, Copy, Clone, PartialEq)]
@@ -87,7 +87,9 @@ pub trait IO: Clock + Send + Sync {
// remove_file is used in the sync-engine
fn remove_file(&self, path: &str) -> Result<()>;
fn run_once(&self) -> Result<()>;
fn run_once(&self) -> Result<()> {
Ok(())
}
fn wait_for_completion(&self, c: Completion) -> Result<()> {
while !c.finished() {
@@ -214,6 +216,10 @@ impl Completion {
self.inner.result.get().is_some_and(|val| val.is_some())
}
pub fn get_error(&self) -> Option<CompletionError> {
self.inner.result.get().and_then(|res| *res)
}
/// Checks if the Completion completed or errored
pub fn finished(&self) -> bool {
self.inner.result.get().is_some()
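
finished() treats an errored Completion as done, so pollers pair it with the new get_error accessor rather than waiting forever on a failed op. A sketch of that loop against only the trait surface shown here (the default run_once makes it a no-op for synchronous backends):

fn wait(io: &dyn IO, c: &Completion) -> Result<()> {
    while !c.finished() {
        io.run_once()?; // default impl: Ok(()) for backends with no event loop
    }
    if let Some(err) = c.get_error() {
        tracing::warn!("completion failed: {err:?}");
    }
    Ok(())
}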


@@ -15,8 +15,6 @@ use std::{io::ErrorKind, sync::Arc};
use tracing::debug;
use tracing::{instrument, trace, Level};
/// UnixIO lives longer than any of the files it creates, so it is
/// safe to store references to it's internals in the UnixFiles
pub struct UnixIO {}
unsafe impl Send for UnixIO {}
@@ -127,24 +125,6 @@ impl IO for UnixIO {
}
}
// enum CompletionCallback {
// Read(Arc<Mutex<std::fs::File>>, Completion, usize),
// Write(
// Arc<Mutex<std::fs::File>>,
// Completion,
// Arc<crate::Buffer>,
// usize,
// ),
// Writev(
// Arc<Mutex<std::fs::File>>,
// Completion,
// Vec<Arc<crate::Buffer>>,
// usize, // absolute file offset
// usize, // buf index
// usize, // intra-buf offset
// ),
// }
pub struct UnixFile {
file: Arc<Mutex<std::fs::File>>,
}
@@ -192,7 +172,7 @@ impl File for UnixFile {
}
#[instrument(err, skip_all, level = Level::TRACE)]
fn pread(&self, pos: usize, c: Completion) -> Result<Completion> {
fn pread(&self, pos: u64, c: Completion) -> Result<Completion> {
let file = self.file.lock();
let result = unsafe {
let r = c.as_read();
@@ -217,7 +197,7 @@ impl File for UnixFile {
}
#[instrument(err, skip_all, level = Level::TRACE)]
fn pwrite(&self, pos: usize, buffer: Arc<crate::Buffer>, c: Completion) -> Result<Completion> {
fn pwrite(&self, pos: u64, buffer: Arc<crate::Buffer>, c: Completion) -> Result<Completion> {
let file = self.file.lock();
let result = unsafe {
libc::pwrite(
@@ -241,7 +221,7 @@ impl File for UnixFile {
#[instrument(err, skip_all, level = Level::TRACE)]
fn pwritev(
&self,
pos: usize,
pos: u64,
buffers: Vec<Arc<crate::Buffer>>,
c: Completion,
) -> Result<Completion> {
@@ -251,7 +231,7 @@ impl File for UnixFile {
}
let file = self.file.lock();
match try_pwritev_raw(file.as_raw_fd(), pos as u64, &buffers, 0, 0) {
match try_pwritev_raw(file.as_raw_fd(), pos, &buffers, 0, 0) {
Ok(written) => {
trace!("pwritev wrote {written}");
c.complete(written as i32);
@@ -268,7 +248,6 @@ impl File for UnixFile {
let file = self.file.lock();
let result = unsafe {
#[cfg(not(any(target_os = "macos", target_os = "ios")))]
{
libc::fsync(file.as_raw_fd())
@@ -278,14 +257,12 @@ impl File for UnixFile {
{
libc::fcntl(file.as_raw_fd(), libc::F_FULLFSYNC)
}
};
if result == -1 {
let e = std::io::Error::last_os_error();
Err(e.into())
} else {
#[cfg(not(any(target_os = "macos", target_os = "ios")))]
trace!("fsync");
@@ -304,9 +281,9 @@ impl File for UnixFile {
}
#[instrument(err, skip_all, level = Level::INFO)]
fn truncate(&self, len: usize, c: Completion) -> Result<Completion> {
fn truncate(&self, len: u64, c: Completion) -> Result<Completion> {
let file = self.file.lock();
let result = file.set_len(len as u64);
let result = file.set_len(len);
match result {
Ok(()) => {
trace!("file truncated to len=({})", len);


@@ -81,8 +81,6 @@ impl VfsMod {
}
}
// #Safety:
// the callback wrapper in the extension library is FnOnce, so we know
/// # Safety
/// the callback wrapper in the extension library is FnOnce, so we know
/// that the into_raw/from_raw contract will hold
@@ -121,7 +119,7 @@ impl File for VfsFileImpl {
Ok(())
}
fn pread(&self, pos: usize, c: Completion) -> Result<Completion> {
fn pread(&self, pos: u64, c: Completion) -> Result<Completion> {
if self.vfs.is_null() {
c.complete(-1);
return Err(LimboError::ExtensionError("VFS is null".to_string()));
@@ -145,7 +143,7 @@ impl File for VfsFileImpl {
Ok(c)
}
fn pwrite(&self, pos: usize, buffer: Arc<Buffer>, c: Completion) -> Result<Completion> {
fn pwrite(&self, pos: u64, buffer: Arc<Buffer>, c: Completion) -> Result<Completion> {
if self.vfs.is_null() {
c.complete(-1);
return Err(LimboError::ExtensionError("VFS is null".to_string()));
@@ -192,7 +190,7 @@ impl File for VfsFileImpl {
}
}
fn truncate(&self, len: usize, c: Completion) -> Result<Completion> {
fn truncate(&self, len: u64, c: Completion) -> Result<Completion> {
if self.vfs.is_null() {
c.complete(-1);
return Err(LimboError::ExtensionError("VFS is null".to_string()));

core/io/windows.rs (new file)

@@ -0,0 +1,115 @@
use crate::{Clock, Completion, File, Instant, LimboError, OpenFlags, Result, IO};
use parking_lot::RwLock;
use std::io::{Read, Seek, Write};
use std::sync::Arc;
use tracing::{debug, instrument, trace, Level};
pub struct WindowsIO {}
impl WindowsIO {
pub fn new() -> Result<Self> {
debug!("Using IO backend 'syscall'");
Ok(Self {})
}
}
impl IO for WindowsIO {
#[instrument(err, skip_all, level = Level::TRACE)]
fn open_file(&self, path: &str, flags: OpenFlags, direct: bool) -> Result<Arc<dyn File>> {
trace!("open_file(path = {})", path);
let mut file = std::fs::File::options();
file.read(true);
if !flags.contains(OpenFlags::ReadOnly) {
file.write(true);
file.create(flags.contains(OpenFlags::Create));
}
let file = file.open(path)?;
Ok(Arc::new(WindowsFile {
file: RwLock::new(file),
}))
}
#[instrument(err, skip_all, level = Level::TRACE)]
fn remove_file(&self, path: &str) -> Result<()> {
trace!("remove_file(path = {})", path);
Ok(std::fs::remove_file(path)?)
}
#[instrument(err, skip_all, level = Level::TRACE)]
fn run_once(&self) -> Result<()> {
Ok(())
}
}
impl Clock for WindowsIO {
fn now(&self) -> Instant {
let now = chrono::Local::now();
Instant {
secs: now.timestamp(),
micros: now.timestamp_subsec_micros(),
}
}
}
pub struct WindowsFile {
file: RwLock<std::fs::File>,
}
impl File for WindowsFile {
#[instrument(err, skip_all, level = Level::TRACE)]
fn lock_file(&self, exclusive: bool) -> Result<()> {
unimplemented!()
}
#[instrument(err, skip_all, level = Level::TRACE)]
fn unlock_file(&self) -> Result<()> {
unimplemented!()
}
#[instrument(skip(self, c), level = Level::TRACE)]
fn pread(&self, pos: u64, c: Completion) -> Result<Completion> {
let mut file = self.file.write();
file.seek(std::io::SeekFrom::Start(pos))?;
let nr = {
let r = c.as_read();
let buf = r.buf();
let buf = buf.as_mut_slice();
file.read_exact(buf)?;
buf.len() as i32
};
c.complete(nr);
Ok(c)
}
#[instrument(skip(self, c, buffer), level = Level::TRACE)]
fn pwrite(&self, pos: u64, buffer: Arc<crate::Buffer>, c: Completion) -> Result<Completion> {
let mut file = self.file.write();
file.seek(std::io::SeekFrom::Start(pos))?;
let buf = buffer.as_slice();
file.write_all(buf)?;
c.complete(buffer.len() as i32);
Ok(c)
}
#[instrument(err, skip_all, level = Level::TRACE)]
fn sync(&self, c: Completion) -> Result<Completion> {
let file = self.file.write();
file.sync_all()?;
c.complete(0);
Ok(c)
}
#[instrument(err, skip_all, level = Level::TRACE)]
fn truncate(&self, len: u64, c: Completion) -> Result<Completion> {
let file = self.file.write();
file.set_len(len)?;
c.complete(0);
Ok(c)
}
fn size(&self) -> Result<u64> {
let file = self.file.read();
Ok(file.metadata()?.len())
}
}
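// Illustrative sketch, separate from the file above: the portable seek-then-read
// pattern WindowsFile::pread uses, shown standalone. The backend repositions the
// shared cursor under the write lock and then blocks on the read, rather than
// using platform positional I/O.
use std::io::{Read, Seek, SeekFrom};

fn read_at(file: &mut std::fs::File, pos: u64, buf: &mut [u8]) -> std::io::Result<()> {
    file.seek(SeekFrom::Start(pos))?; // move the cursor to the requested offset
    file.read_exact(buf) // fill the whole buffer, erroring on a short read
}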

View File

@@ -1277,6 +1277,12 @@ impl Connection {
std::fs::set_permissions(&opts.path, perms.permissions())?;
}
let conn = db.connect()?;
if let Some(cipher) = opts.cipher {
let _ = conn.pragma_update("cipher", format!("'{cipher}'"));
}
if let Some(hexkey) = opts.hexkey {
let _ = conn.pragma_update("hexkey", format!("'{hexkey}'"));
}
Ok((io, conn))
}
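// Illustrative sketch (hypothetical cipher name and key): with the wiring above,
// opening an encrypted database reduces to issuing two pragmas on a fresh
// connection, mirroring the pragma_update calls exactly.
//
//   let _ = conn.pragma_update("cipher", "'aegis256'".to_string());
//   let _ = conn.pragma_update("hexkey", format!("'{}'", "aa".repeat(32))); // 32-byte key as 64 hex chars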

View File

@@ -1047,8 +1047,8 @@ impl Column {
}
// TODO: This might replace some of util::columns_from_create_table_body
impl From<ColumnDefinition> for Column {
fn from(value: ColumnDefinition) -> Self {
impl From<&ColumnDefinition> for Column {
fn from(value: &ColumnDefinition) -> Self {
let name = value.col_name.as_str();
let mut default = None;
@@ -1057,13 +1057,13 @@ impl From<ColumnDefinition> for Column {
let mut unique = false;
let mut collation = None;
for ast::NamedColumnConstraint { constraint, .. } in value.constraints {
for ast::NamedColumnConstraint { constraint, .. } in &value.constraints {
match constraint {
ast::ColumnConstraint::PrimaryKey { .. } => primary_key = true,
ast::ColumnConstraint::NotNull { .. } => notnull = true,
ast::ColumnConstraint::Unique(..) => unique = true,
ast::ColumnConstraint::Default(expr) => {
default.replace(expr);
default.replace(expr.clone());
}
ast::ColumnConstraint::Collate { collation_name } => {
collation.replace(
@@ -1082,11 +1082,14 @@ impl From<ColumnDefinition> for Column {
let ty_str = value
.col_type
.as_ref()
.map(|t| t.name.to_string())
.unwrap_or_default();
let hidden = ty_str.contains("HIDDEN");
Column {
name: Some(name.to_string()),
name: Some(normalize_ident(name)),
ty,
default,
notnull,
@@ -1095,7 +1098,7 @@ impl From<ColumnDefinition> for Column {
is_rowid_alias: primary_key && matches!(ty, Type::Integer),
unique,
collation,
hidden: false,
hidden,
}
}
}
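// Illustrative sketch of the borrowing conversion above (hypothetical definition
// for a column declared as `price INTEGER HIDDEN`):
//
//   let column = Column::from(&col_def); // borrows: col_def stays usable
//   assert!(column.hidden);              // "HIDDEN" in the declared type sets the flag
//   assert_eq!(column.name.as_deref(), Some("price")); // name passed through normalize_ident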

View File

@@ -91,7 +91,9 @@ impl DatabaseStorage for DatabaseFile {
if !(512..=65536).contains(&size) || size & (size - 1) != 0 {
return Err(LimboError::NotADB);
}
let pos = (page_idx - 1) * size;
let Some(pos) = (page_idx as u64 - 1).checked_mul(size as u64) else {
return Err(LimboError::IntegerOverflow);
};
if let Some(ctx) = io_ctx.encryption_context() {
let encryption_ctx = ctx.clone();
@@ -145,7 +147,9 @@ impl DatabaseStorage for DatabaseFile {
assert!(buffer_size >= 512);
assert!(buffer_size <= 65536);
assert_eq!(buffer_size & (buffer_size - 1), 0);
let pos = (page_idx - 1) * buffer_size;
let Some(pos) = (page_idx as u64 - 1).checked_mul(buffer_size as u64) else {
return Err(LimboError::IntegerOverflow);
};
let buffer = {
if let Some(ctx) = io_ctx.encryption_context() {
encrypt_buffer(page_idx, buffer, ctx)
@@ -169,7 +173,9 @@ impl DatabaseStorage for DatabaseFile {
assert!(page_size <= 65536);
assert_eq!(page_size & (page_size - 1), 0);
let pos = (first_page_idx - 1) * page_size;
let Some(pos) = (first_page_idx as u64 - 1).checked_mul(page_size as u64) else {
return Err(LimboError::IntegerOverflow);
};
let buffers = {
if let Some(ctx) = io_ctx.encryption_context() {
buffers
@@ -198,7 +204,7 @@ impl DatabaseStorage for DatabaseFile {
#[instrument(skip_all, level = Level::INFO)]
fn truncate(&self, len: usize, c: Completion) -> Result<Completion> {
let c = self.file.truncate(len, c)?;
let c = self.file.truncate(len as u64, c)?;
Ok(c)
}
}
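// Illustrative sketch of the overflow guard introduced above: widen to u64 first,
// then multiply with checked_mul so an oversized page index fails cleanly instead
// of wrapping. page_idx is 1-based as in the callers; checked_sub additionally
// rejects a zero index.
fn page_offset(page_idx: usize, page_size: usize) -> Option<u64> {
    (page_idx as u64).checked_sub(1)?.checked_mul(page_size as u64)
}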

View File

@@ -1,14 +1,13 @@
#![allow(unused_variables, dead_code)]
use crate::{LimboError, Result};
use aegis::aegis256::Aegis256;
use aes_gcm::{
aead::{Aead, AeadCore, KeyInit, OsRng},
Aes256Gcm, Key, Nonce,
};
use aes_gcm::aead::{AeadCore, OsRng};
use std::ops::Deref;
use turso_macros::match_ignore_ascii_case;
pub const ENCRYPTED_PAGE_SIZE: usize = 4096;
// AEGIS-256 supports both 16 and 32 byte tags, we use the 16 byte variant, it is faster
// and provides sufficient security for our use case.
const AEGIS_TAG_SIZE: usize = 16;
const AES256GCM_TAG_SIZE: usize = 16;
#[repr(transparent)]
#[derive(Clone)]
@@ -74,10 +73,25 @@ impl Drop for EncryptionKey {
}
}
pub trait AeadCipher {
fn encrypt(&self, plaintext: &[u8], ad: &[u8]) -> Result<(Vec<u8>, Vec<u8>)>;
fn decrypt(&self, ciphertext: &[u8], nonce: &[u8], ad: &[u8]) -> Result<Vec<u8>>;
fn encrypt_detached(&self, plaintext: &[u8], ad: &[u8]) -> Result<(Vec<u8>, Vec<u8>, Vec<u8>)>;
fn decrypt_detached(
&self,
ciphertext: &[u8],
nonce: &[u8],
tag: &[u8],
ad: &[u8],
) -> Result<Vec<u8>>;
}
// wrapper struct for the AEGIS-256 cipher: the crate we use is a bit low-level, so we add
// some nicer abstractions here
// note: AEGIS has many variants and supports hardware acceleration. Here we just use the
// vanilla version, which is still order of maginitudes faster than AES-GCM in software. Hardware
// vanilla version, which is still orders of magnitude faster than AES-GCM in software. Hardware
// based compilation is left for future work.
#[derive(Clone)]
pub struct Aegis256Cipher {
@@ -85,39 +99,154 @@ pub struct Aegis256Cipher {
}
impl Aegis256Cipher {
// AEGIS-256 supports both 16 and 32 byte tags, we use the 16 byte variant, it is faster
// and provides sufficient security for our use case.
const TAG_SIZE: usize = 16;
fn new(key: &EncryptionKey) -> Self {
Self { key: key.clone() }
}
}
fn encrypt(&self, plaintext: &[u8], ad: &[u8]) -> Result<(Vec<u8>, [u8; 32])> {
impl AeadCipher for Aegis256Cipher {
fn encrypt(&self, plaintext: &[u8], ad: &[u8]) -> Result<(Vec<u8>, Vec<u8>)> {
let nonce = generate_secure_nonce();
let (ciphertext, tag) =
Aegis256::<16>::new(self.key.as_bytes(), &nonce).encrypt(plaintext, ad);
Aegis256::<AEGIS_TAG_SIZE>::new(self.key.as_bytes(), &nonce).encrypt(plaintext, ad);
let mut result = ciphertext;
result.extend_from_slice(&tag);
Ok((result, nonce))
Ok((result, nonce.to_vec()))
}
fn decrypt(&self, ciphertext: &[u8], nonce: &[u8; 32], ad: &[u8]) -> Result<Vec<u8>> {
if ciphertext.len() < Self::TAG_SIZE {
return Err(LimboError::InternalError(
"Ciphertext too short for AEGIS-256".into(),
));
fn decrypt(&self, ciphertext: &[u8], nonce: &[u8], ad: &[u8]) -> Result<Vec<u8>> {
if ciphertext.len() < AEGIS_TAG_SIZE {
return Err(LimboError::InternalError("Ciphertext too short".into()));
}
let (ct, tag) = ciphertext.split_at(ciphertext.len() - Self::TAG_SIZE);
let tag_array: [u8; 16] = tag
.try_into()
.map_err(|_| LimboError::InternalError("Invalid tag size for AEGIS-256".into()))?;
let (ct, tag) = ciphertext.split_at(ciphertext.len() - AEGIS_TAG_SIZE);
let tag_array: [u8; AEGIS_TAG_SIZE] = tag.try_into().map_err(|_| {
LimboError::InternalError(format!("Invalid tag size for AEGIS-256 {AEGIS_TAG_SIZE}"))
})?;
let plaintext = Aegis256::<16>::new(self.key.as_bytes(), nonce)
let nonce_array: [u8; 32] = nonce
.try_into()
.map_err(|_| LimboError::InternalError("Invalid nonce size for AEGIS-256".into()))?;
Aegis256::<AEGIS_TAG_SIZE>::new(self.key.as_bytes(), &nonce_array)
.decrypt(ct, &tag_array, ad)
.map_err(|_| {
LimboError::InternalError("AEGIS-256 decryption failed: invalid tag".into())
})?;
Ok(plaintext)
.map_err(|_| LimboError::InternalError("AEGIS-256 decryption failed".into()))
}
fn encrypt_detached(&self, plaintext: &[u8], ad: &[u8]) -> Result<(Vec<u8>, Vec<u8>, Vec<u8>)> {
let nonce = generate_secure_nonce();
let (ciphertext, tag) =
Aegis256::<AEGIS_TAG_SIZE>::new(self.key.as_bytes(), &nonce).encrypt(plaintext, ad);
Ok((ciphertext, tag.to_vec(), nonce.to_vec()))
}
fn decrypt_detached(
&self,
ciphertext: &[u8],
nonce: &[u8],
tag: &[u8],
ad: &[u8],
) -> Result<Vec<u8>> {
let tag_array: [u8; AEGIS_TAG_SIZE] = tag.try_into().map_err(|_| {
LimboError::InternalError(format!("Invalid tag size for AEGIS-256 {AEGIS_TAG_SIZE}"))
})?;
let nonce_array: [u8; 32] = nonce
.try_into()
.map_err(|_| LimboError::InternalError("Invalid nonce size for AEGIS-256".into()))?;
Aegis256::<AEGIS_TAG_SIZE>::new(self.key.as_bytes(), &nonce_array)
.decrypt(ciphertext, &tag_array, ad)
.map_err(|_| LimboError::InternalError("AEGIS-256 decrypt_detached failed".into()))
}
}
#[derive(Clone)]
pub struct Aes256GcmCipher {
key: EncryptionKey,
}
impl Aes256GcmCipher {
fn new(key: &EncryptionKey) -> Self {
Self { key: key.clone() }
}
}
impl AeadCipher for Aes256GcmCipher {
fn encrypt(&self, plaintext: &[u8], ad: &[u8]) -> Result<(Vec<u8>, Vec<u8>)> {
use aes_gcm::aead::{AeadInPlace, KeyInit};
use aes_gcm::Aes256Gcm;
let cipher = Aes256Gcm::new_from_slice(self.key.as_bytes())
.map_err(|_| LimboError::InternalError("Bad AES key".into()))?;
let nonce = Aes256Gcm::generate_nonce(&mut rand::thread_rng());
let mut buffer = plaintext.to_vec();
let tag = cipher
.encrypt_in_place_detached(&nonce, ad, &mut buffer) // use the caller's ad so decrypt can verify it
.map_err(|_| LimboError::InternalError("AES-GCM encrypt failed".into()))?;
buffer.extend_from_slice(&tag[..AES256GCM_TAG_SIZE]);
Ok((buffer, nonce.to_vec()))
}
fn decrypt(&self, ciphertext: &[u8], nonce: &[u8], ad: &[u8]) -> Result<Vec<u8>> {
use aes_gcm::aead::{AeadInPlace, KeyInit};
use aes_gcm::{Aes256Gcm, Nonce};
if ciphertext.len() < AES256GCM_TAG_SIZE {
return Err(LimboError::InternalError("Ciphertext too short".into()));
}
let (ct, tag) = ciphertext.split_at(ciphertext.len() - AES256GCM_TAG_SIZE);
let cipher = Aes256Gcm::new_from_slice(self.key.as_bytes())
.map_err(|_| LimboError::InternalError("Bad AES key".into()))?;
let nonce = Nonce::from_slice(nonce);
let mut buffer = ct.to_vec();
cipher
.decrypt_in_place_detached(nonce, ad, &mut buffer, tag.into())
.map_err(|_| LimboError::InternalError("AES-GCM decrypt failed".into()))?;
Ok(buffer)
}
fn encrypt_detached(&self, plaintext: &[u8], ad: &[u8]) -> Result<(Vec<u8>, Vec<u8>, Vec<u8>)> {
use aes_gcm::aead::{AeadInPlace, KeyInit};
use aes_gcm::Aes256Gcm;
let cipher = Aes256Gcm::new_from_slice(self.key.as_bytes())
.map_err(|_| LimboError::InternalError("Bad AES key".into()))?;
let nonce = Aes256Gcm::generate_nonce(&mut rand::thread_rng());
let mut buffer = plaintext.to_vec();
let tag = cipher
.encrypt_in_place_detached(&nonce, ad, &mut buffer)
.map_err(|_| LimboError::InternalError("AES-GCM encrypt_detached failed".into()))?;
Ok((buffer, nonce.to_vec(), tag.to_vec()))
}
fn decrypt_detached(
&self,
ciphertext: &[u8],
nonce: &[u8],
tag: &[u8],
ad: &[u8],
) -> Result<Vec<u8>> {
use aes_gcm::aead::{AeadInPlace, KeyInit};
use aes_gcm::{Aes256Gcm, Nonce};
let cipher = Aes256Gcm::new_from_slice(self.key.as_bytes())
.map_err(|_| LimboError::InternalError("Bad AES key".into()))?;
let nonce = Nonce::from_slice(nonce);
let mut buffer = ciphertext.to_vec();
cipher
.decrypt_in_place_detached(nonce, ad, &mut buffer, tag.into())
.map_err(|_| LimboError::InternalError("AES-GCM decrypt_detached failed".into()))?;
Ok(buffer)
}
}
@@ -180,8 +309,8 @@ impl CipherMode {
/// Returns the authentication tag size for this cipher mode.
pub fn tag_size(&self) -> usize {
match self {
CipherMode::Aes256Gcm => 16,
CipherMode::Aegis256 => 16,
CipherMode::Aes256Gcm => AES256GCM_TAG_SIZE,
CipherMode::Aegis256 => AEGIS_TAG_SIZE,
}
}
@@ -193,8 +322,17 @@ impl CipherMode {
#[derive(Clone)]
pub enum Cipher {
Aes256Gcm(Box<Aes256Gcm>),
Aegis256(Box<Aegis256Cipher>),
Aes256Gcm(Aes256GcmCipher),
Aegis256(Aegis256Cipher),
}
impl Cipher {
fn as_aead(&self) -> &dyn AeadCipher {
match self {
Cipher::Aes256Gcm(c) => c,
Cipher::Aegis256(c) => c,
}
}
}
impl std::fmt::Debug for Cipher {
@@ -210,10 +348,11 @@ impl std::fmt::Debug for Cipher {
pub struct EncryptionContext {
cipher_mode: CipherMode,
cipher: Cipher,
page_size: usize,
}
impl EncryptionContext {
pub fn new(cipher_mode: CipherMode, key: &EncryptionKey) -> Result<Self> {
pub fn new(cipher_mode: CipherMode, key: &EncryptionKey, page_size: usize) -> Result<Self> {
let required_size = cipher_mode.required_key_size();
if key.as_slice().len() != required_size {
return Err(crate::LimboError::InvalidArgument(format!(
@@ -225,15 +364,13 @@ impl EncryptionContext {
}
let cipher = match cipher_mode {
CipherMode::Aes256Gcm => {
let cipher_key: &Key<Aes256Gcm> = key.as_ref().into();
Cipher::Aes256Gcm(Box::new(Aes256Gcm::new(cipher_key)))
}
CipherMode::Aegis256 => Cipher::Aegis256(Box::new(Aegis256Cipher::new(key))),
CipherMode::Aes256Gcm => Cipher::Aes256Gcm(Aes256GcmCipher::new(key)),
CipherMode::Aegis256 => Cipher::Aegis256(Aegis256Cipher::new(key)),
};
Ok(Self {
cipher_mode,
cipher,
page_size,
})
}
@@ -255,36 +392,38 @@ impl EncryptionContext {
tracing::debug!("encrypting page {}", page_id);
assert_eq!(
page.len(),
ENCRYPTED_PAGE_SIZE,
"Page data must be exactly {ENCRYPTED_PAGE_SIZE} bytes"
self.page_size,
"Page data must be exactly {} bytes",
self.page_size
);
let metadata_size = self.cipher_mode.metadata_size();
let reserved_bytes = &page[ENCRYPTED_PAGE_SIZE - metadata_size..];
let reserved_bytes = &page[self.page_size - metadata_size..];
let reserved_bytes_zeroed = reserved_bytes.iter().all(|&b| b == 0);
assert!(
reserved_bytes_zeroed,
"last reserved bytes must be empty/zero, but found non-zero bytes"
);
let payload = &page[..ENCRYPTED_PAGE_SIZE - metadata_size];
let payload = &page[..self.page_size - metadata_size];
let (encrypted, nonce) = self.encrypt_raw(payload)?;
let nonce_size = self.cipher_mode.nonce_size();
assert_eq!(
encrypted.len(),
ENCRYPTED_PAGE_SIZE - nonce_size,
self.page_size - nonce_size,
"Encrypted page must be exactly {} bytes",
ENCRYPTED_PAGE_SIZE - nonce_size
self.page_size - nonce_size
);
let mut result = Vec::with_capacity(ENCRYPTED_PAGE_SIZE);
let mut result = Vec::with_capacity(self.page_size);
result.extend_from_slice(&encrypted);
result.extend_from_slice(&nonce);
assert_eq!(
result.len(),
ENCRYPTED_PAGE_SIZE,
"Encrypted page must be exactly {ENCRYPTED_PAGE_SIZE} bytes"
self.page_size,
"Encrypted page must be exactly {} bytes",
self.page_size
);
Ok(result)
}
@@ -298,8 +437,9 @@ impl EncryptionContext {
tracing::debug!("decrypting page {}", page_id);
assert_eq!(
encrypted_page.len(),
ENCRYPTED_PAGE_SIZE,
"Encrypted page data must be exactly {ENCRYPTED_PAGE_SIZE} bytes"
self.page_size,
"Encrypted page data must be exactly {} bytes",
self.page_size
);
let nonce_size = self.cipher_mode.nonce_size();
@@ -312,60 +452,40 @@ impl EncryptionContext {
let metadata_size = self.cipher_mode.metadata_size();
assert_eq!(
decrypted_data.len(),
ENCRYPTED_PAGE_SIZE - metadata_size,
self.page_size - metadata_size,
"Decrypted page data must be exactly {} bytes",
ENCRYPTED_PAGE_SIZE - metadata_size
self.page_size - metadata_size
);
let mut result = Vec::with_capacity(ENCRYPTED_PAGE_SIZE);
let mut result = Vec::with_capacity(self.page_size);
result.extend_from_slice(&decrypted_data);
result.resize(ENCRYPTED_PAGE_SIZE, 0);
result.resize(self.page_size, 0);
assert_eq!(
result.len(),
ENCRYPTED_PAGE_SIZE,
"Decrypted page data must be exactly {ENCRYPTED_PAGE_SIZE} bytes"
self.page_size,
"Decrypted page data must be exactly {} bytes",
self.page_size
);
Ok(result)
}
/// encrypts raw data using the configured cipher, returns ciphertext and nonce
fn encrypt_raw(&self, plaintext: &[u8]) -> Result<(Vec<u8>, Vec<u8>)> {
match &self.cipher {
Cipher::Aes256Gcm(cipher) => {
let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
let ciphertext = cipher
.encrypt(&nonce, plaintext)
.map_err(|e| LimboError::InternalError(format!("Encryption failed: {e:?}")))?;
Ok((ciphertext, nonce.to_vec()))
}
Cipher::Aegis256(cipher) => {
let ad = b"";
let (ciphertext, nonce) = cipher.encrypt(plaintext, ad)?;
Ok((ciphertext, nonce.to_vec()))
}
}
self.cipher.as_aead().encrypt(plaintext, b"")
}
fn decrypt_raw(&self, ciphertext: &[u8], nonce: &[u8]) -> Result<Vec<u8>> {
match &self.cipher {
Cipher::Aes256Gcm(cipher) => {
let nonce = Nonce::from_slice(nonce);
let plaintext = cipher.decrypt(nonce, ciphertext).map_err(|e| {
crate::LimboError::InternalError(format!("Decryption failed: {e:?}"))
})?;
Ok(plaintext)
}
Cipher::Aegis256(cipher) => {
let nonce_array: [u8; 32] = nonce.try_into().map_err(|_| {
LimboError::InternalError(format!(
"Invalid nonce size for AEGIS-256: expected 32, got {}",
nonce.len()
))
})?;
let ad = b"";
cipher.decrypt(ciphertext, &nonce_array, ad)
}
}
self.cipher.as_aead().decrypt(ciphertext, nonce, b"")
}
fn encrypt_raw_detached(&self, plaintext: &[u8]) -> Result<(Vec<u8>, Vec<u8>, Vec<u8>)> {
self.cipher.as_aead().encrypt_detached(plaintext, b"")
}
fn decrypt_raw_detached(&self, ciphertext: &[u8], nonce: &[u8], tag: &[u8]) -> Result<Vec<u8>> {
self.cipher
.as_aead()
.decrypt_detached(ciphertext, nonce, tag, b"")
}
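// Worked layout example (assuming metadata_size = nonce_size + tag_size, which is
// what the encrypt_page/decrypt_page assertions above imply) for page_size = 4096:
//
//   AEGIS-256  : 32-byte nonce + 16-byte tag = 48 bytes of metadata
//                [ 4048 plaintext bytes | 16 tag | 32 nonce ] = 4096
//   AES-256-GCM: 12-byte nonce + 16-byte tag = 28 bytes of metadata
//                [ 4068 plaintext bytes | 16 tag | 12 nonce ] = 4096
//
// The trailing reserved bytes of the plaintext page must already be zero, so the
// cipher metadata can occupy them without destroying data.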
#[cfg(not(feature = "encryption"))]
@@ -391,10 +511,12 @@ fn generate_secure_nonce() -> [u8; 32] {
nonce
}
#[cfg(feature = "encryption")]
#[cfg(test)]
mod tests {
use super::*;
use rand::Rng;
const DEFAULT_ENCRYPTED_PAGE_SIZE: usize = 4096;
fn generate_random_hex_key() -> String {
let mut rng = rand::thread_rng();
@@ -404,15 +526,14 @@ mod tests {
}
#[test]
#[cfg(feature = "encryption")]
fn test_aes_encrypt_decrypt_round_trip() {
let mut rng = rand::thread_rng();
let cipher_mode = CipherMode::Aes256Gcm;
let metadata_size = cipher_mode.metadata_size();
let data_size = ENCRYPTED_PAGE_SIZE - metadata_size;
let data_size = DEFAULT_ENCRYPTED_PAGE_SIZE - metadata_size;
let page_data = {
let mut page = vec![0u8; ENCRYPTED_PAGE_SIZE];
let mut page = vec![0u8; DEFAULT_ENCRYPTED_PAGE_SIZE];
page.iter_mut()
.take(data_size)
.for_each(|byte| *byte = rng.gen());
@@ -420,21 +541,21 @@ mod tests {
};
let key = EncryptionKey::from_hex_string(&generate_random_hex_key()).unwrap();
let ctx = EncryptionContext::new(CipherMode::Aes256Gcm, &key).unwrap();
let ctx = EncryptionContext::new(CipherMode::Aes256Gcm, &key, DEFAULT_ENCRYPTED_PAGE_SIZE)
.unwrap();
let page_id = 42;
let encrypted = ctx.encrypt_page(&page_data, page_id).unwrap();
assert_eq!(encrypted.len(), ENCRYPTED_PAGE_SIZE);
assert_eq!(encrypted.len(), DEFAULT_ENCRYPTED_PAGE_SIZE);
assert_ne!(&encrypted[..data_size], &page_data[..data_size]);
assert_ne!(&encrypted[..], &page_data[..]);
let decrypted = ctx.decrypt_page(&encrypted, page_id).unwrap();
assert_eq!(decrypted.len(), ENCRYPTED_PAGE_SIZE);
assert_eq!(decrypted.len(), DEFAULT_ENCRYPTED_PAGE_SIZE);
assert_eq!(decrypted, page_data);
}
#[test]
#[cfg(feature = "encryption")]
fn test_aegis256_cipher_wrapper() {
let key = EncryptionKey::from_hex_string(&generate_random_hex_key()).unwrap();
let cipher = Aegis256Cipher::new(&key);
@@ -451,10 +572,10 @@ mod tests {
}
#[test]
#[cfg(feature = "encryption")]
fn test_aegis256_raw_encryption() {
let key = EncryptionKey::from_hex_string(&generate_random_hex_key()).unwrap();
let ctx = EncryptionContext::new(CipherMode::Aegis256, &key).unwrap();
let ctx = EncryptionContext::new(CipherMode::Aegis256, &key, DEFAULT_ENCRYPTED_PAGE_SIZE)
.unwrap();
let plaintext = b"Hello, AEGIS-256!";
let (ciphertext, nonce) = ctx.encrypt_raw(plaintext).unwrap();
@@ -467,15 +588,14 @@ mod tests {
}
#[test]
#[cfg(feature = "encryption")]
fn test_aegis256_encrypt_decrypt_round_trip() {
let mut rng = rand::thread_rng();
let cipher_mode = CipherMode::Aegis256;
let metadata_size = cipher_mode.metadata_size();
let data_size = ENCRYPTED_PAGE_SIZE - metadata_size;
let data_size = DEFAULT_ENCRYPTED_PAGE_SIZE - metadata_size;
let page_data = {
let mut page = vec![0u8; ENCRYPTED_PAGE_SIZE];
let mut page = vec![0u8; DEFAULT_ENCRYPTED_PAGE_SIZE];
page.iter_mut()
.take(data_size)
.for_each(|byte| *byte = rng.gen());
@@ -483,15 +603,16 @@ mod tests {
};
let key = EncryptionKey::from_hex_string(&generate_random_hex_key()).unwrap();
let ctx = EncryptionContext::new(CipherMode::Aegis256, &key).unwrap();
let ctx = EncryptionContext::new(CipherMode::Aegis256, &key, DEFAULT_ENCRYPTED_PAGE_SIZE)
.unwrap();
let page_id = 42;
let encrypted = ctx.encrypt_page(&page_data, page_id).unwrap();
assert_eq!(encrypted.len(), ENCRYPTED_PAGE_SIZE);
assert_eq!(encrypted.len(), DEFAULT_ENCRYPTED_PAGE_SIZE);
assert_ne!(&encrypted[..data_size], &page_data[..data_size]);
let decrypted = ctx.decrypt_page(&encrypted, page_id).unwrap();
assert_eq!(decrypted.len(), ENCRYPTED_PAGE_SIZE);
assert_eq!(decrypted.len(), DEFAULT_ENCRYPTED_PAGE_SIZE);
assert_eq!(decrypted, page_data);
}
}
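// Illustrative sketch of the two AEAD shapes the trait above unifies; `cipher` is
// any `&dyn AeadCipher`, and construction is elided:
//
//   // attached: the tag rides at the tail of the ciphertext
//   let (ct_and_tag, nonce) = cipher.encrypt(msg, ad)?;
//   let pt = cipher.decrypt(&ct_and_tag, &nonce, ad)?;
//
//   // detached: ciphertext, tag and nonce travel separately
//   let (ct, tag, nonce) = cipher.encrypt_detached(msg, ad)?;
//   let pt = cipher.decrypt_detached(&ct, &nonce, &tag, ad)?;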

View File

@@ -12,7 +12,7 @@ use super::pager::PageRef;
const DEFAULT_PAGE_CACHE_SIZE_IN_PAGES_MAKE_ME_SMALLER_ONCE_WAL_SPILL_IS_IMPLEMENTED: usize =
100000;
#[derive(Debug, Eq, Hash, PartialEq, Clone)]
#[derive(Debug, Eq, Hash, PartialEq, Clone, Copy)]
pub struct PageCacheKey {
pgno: usize,
}
@@ -47,14 +47,21 @@ struct HashMapNode {
value: NonNull<PageCacheEntry>,
}
#[derive(Debug, PartialEq)]
#[derive(Debug, Clone, PartialEq, thiserror::Error)]
pub enum CacheError {
#[error("{0}")]
InternalError(String),
#[error("page {pgno} is locked")]
Locked { pgno: usize },
#[error("page {pgno} is dirty")]
Dirty { pgno: usize },
#[error("page {pgno} is pinned")]
Pinned { pgno: usize },
#[error("cache active refs")]
ActiveRefs,
#[error("Page cache is full")]
Full,
#[error("key already exists")]
KeyExists,
}
@@ -105,7 +112,7 @@ impl DumbLruPageCache {
trace!("insert(key={:?})", key);
// Check first if page already exists in cache
if !ignore_exists {
if let Some(existing_page_ref) = self.get(&key) {
if let Some(existing_page_ref) = self.get(&key)? {
assert!(
Arc::ptr_eq(&value, &existing_page_ref),
"Attempted to insert different page with same key: {key:?}"
@@ -115,7 +122,7 @@ impl DumbLruPageCache {
}
self.make_room_for(1)?;
let entry = Box::new(PageCacheEntry {
key: key.clone(),
key,
next: None,
prev: None,
page: value,
@@ -156,8 +163,21 @@ impl DumbLruPageCache {
ptr.copied()
}
pub fn get(&mut self, key: &PageCacheKey) -> Option<PageRef> {
self.peek(key, true)
pub fn get(&mut self, key: &PageCacheKey) -> Result<Option<PageRef>, CacheError> {
if let Some(page) = self.peek(key, true) {
// Because a read_page completion can be aborted, a page can sit in the cache while
// being neither loaded nor locked. If we do not evict such a page here, we will later
// hand out an unloaded page and trip assertions downstream. This is worsened by the fact
// that the page cache is not per-`Statement`: aborting a completion in one Statement
// could otherwise trigger an error in the next one.
if !page.is_loaded() && !page.is_locked() {
self.delete(*key)?;
Ok(None)
} else {
Ok(Some(page))
}
} else {
Ok(None)
}
}
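// Illustrative caller pattern for the now-fallible get (hypothetical cache/key):
// a miss and an evicted unloaded page both surface as Ok(None); only genuine
// cache failures surface as Err.
//
//   match cache.get(&key)? {
//       Some(page) => debug_assert!(page.is_loaded() || page.is_locked()),
//       None => { /* fetch the page from storage and insert it */ }
//   }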
/// Get page without promoting entry
@@ -309,7 +329,7 @@ impl DumbLruPageCache {
let entry = unsafe { current.as_ref() };
// Pick prev before modifying entry
current_opt = entry.prev;
match self.delete(entry.key.clone()) {
match self.delete(entry.key) {
Err(_) => {}
Ok(_) => need_to_evict -= 1,
}
@@ -396,7 +416,7 @@ impl DumbLruPageCache {
let mut current = head_ptr;
while let Some(node) = current {
unsafe {
this_keys.push(node.as_ref().key.clone());
this_keys.push(node.as_ref().key);
let node_ref = node.as_ref();
current = node_ref.next;
}
@@ -647,7 +667,7 @@ impl PageHashMap {
pub fn rehash(&self, new_capacity: usize) -> PageHashMap {
let mut new_hash_map = PageHashMap::new(new_capacity);
for node in self.iter() {
new_hash_map.insert(node.key.clone(), node.value);
new_hash_map.insert(node.key, node.value);
}
new_hash_map
}
@@ -698,7 +718,7 @@ mod tests {
fn insert_page(cache: &mut DumbLruPageCache, id: usize) -> PageCacheKey {
let key = create_key(id);
let page = page_with_content(id);
assert!(cache.insert(key.clone(), page).is_ok());
assert!(cache.insert(key, page).is_ok());
key
}
@@ -712,7 +732,7 @@ mod tests {
) -> (PageCacheKey, NonNull<PageCacheEntry>) {
let key = create_key(id);
let page = page_with_content(id);
assert!(cache.insert(key.clone(), page).is_ok());
assert!(cache.insert(key, page).is_ok());
let entry = cache.get_ptr(&key).expect("Entry should exist");
(key, entry)
}
@@ -727,7 +747,7 @@ mod tests {
assert!(cache.tail.borrow().is_some());
assert_eq!(*cache.head.borrow(), *cache.tail.borrow());
assert!(cache.delete(key1.clone()).is_ok());
assert!(cache.delete(key1).is_ok());
assert_eq!(
cache.len(),
@@ -759,7 +779,7 @@ mod tests {
"Initial head check"
);
assert!(cache.delete(key3.clone()).is_ok());
assert!(cache.delete(key3).is_ok());
assert_eq!(cache.len(), 2, "Length should be 2 after deleting head");
assert!(
@@ -803,7 +823,7 @@ mod tests {
"Initial tail check"
);
assert!(cache.delete(key1.clone()).is_ok()); // Delete tail
assert!(cache.delete(key1).is_ok()); // Delete tail
assert_eq!(cache.len(), 2, "Length should be 2 after deleting tail");
assert!(
@@ -854,7 +874,7 @@ mod tests {
let head_ptr_before = cache.head.borrow().unwrap();
let tail_ptr_before = cache.tail.borrow().unwrap();
assert!(cache.delete(key2.clone()).is_ok()); // Detach a middle element (key2)
assert!(cache.delete(key2).is_ok()); // Detach a middle element (key2)
assert_eq!(cache.len(), 3, "Length should be 3 after deleting middle");
assert!(
@@ -895,11 +915,11 @@ mod tests {
let mut cache = DumbLruPageCache::default();
let key1 = create_key(1);
let page1 = page_with_content(1);
assert!(cache.insert(key1.clone(), page1.clone()).is_ok());
assert!(cache.insert(key1, page1.clone()).is_ok());
assert!(page_has_content(&page1));
cache.verify_list_integrity();
let result = cache.delete(key1.clone());
let result = cache.delete(key1);
assert!(result.is_err());
assert_eq!(result.unwrap_err(), CacheError::ActiveRefs);
assert_eq!(cache.len(), 1);
@@ -918,10 +938,10 @@ mod tests {
let key1 = create_key(1);
let page1_v1 = page_with_content(1);
let page1_v2 = page_with_content(1);
assert!(cache.insert(key1.clone(), page1_v1.clone()).is_ok());
assert!(cache.insert(key1, page1_v1.clone()).is_ok());
assert_eq!(cache.len(), 1);
cache.verify_list_integrity();
let _ = cache.insert(key1.clone(), page1_v2.clone()); // Panic
let _ = cache.insert(key1, page1_v2.clone()); // Panic
}
#[test]
@@ -929,7 +949,7 @@ mod tests {
let mut cache = DumbLruPageCache::default();
let key_nonexist = create_key(99);
assert!(cache.delete(key_nonexist.clone()).is_ok()); // no-op
assert!(cache.delete(key_nonexist).is_ok()); // no-op
}
#[test]
@@ -937,8 +957,8 @@ mod tests {
let mut cache = DumbLruPageCache::new(1);
let key1 = insert_page(&mut cache, 1);
let key2 = insert_page(&mut cache, 2);
assert_eq!(cache.get(&key2).unwrap().get().id, 2);
assert!(cache.get(&key1).is_none());
assert_eq!(cache.get(&key2).unwrap().unwrap().get().id, 2);
assert!(cache.get(&key1).unwrap().is_none());
}
#[test]
@@ -1002,7 +1022,7 @@ mod tests {
fn test_detach_with_cleaning() {
let mut cache = DumbLruPageCache::default();
let (key, entry) = insert_and_get_entry(&mut cache, 1);
let page = cache.get(&key).expect("Page should exist");
let page = cache.get(&key).unwrap().expect("Page should exist");
assert!(page_has_content(&page));
drop(page);
assert!(cache.detach(entry, true).is_ok());
@@ -1034,8 +1054,8 @@ mod tests {
let (key1, _) = insert_and_get_entry(&mut cache, 1);
let (key2, entry2) = insert_and_get_entry(&mut cache, 2);
let (key3, _) = insert_and_get_entry(&mut cache, 3);
let head_key = unsafe { cache.head.borrow().unwrap().as_ref().key.clone() };
let tail_key = unsafe { cache.tail.borrow().unwrap().as_ref().key.clone() };
let head_key = unsafe { cache.head.borrow().unwrap().as_ref().key };
let tail_key = unsafe { cache.tail.borrow().unwrap().as_ref().key };
assert_eq!(head_key, key3, "Head should be key3");
assert_eq!(tail_key, key1, "Tail should be key1");
assert!(cache.detach(entry2, false).is_ok());
@@ -1044,12 +1064,12 @@ mod tests {
assert_eq!(head_entry.key, key3, "Head should still be key3");
assert_eq!(tail_entry.key, key1, "Tail should still be key1");
assert_eq!(
unsafe { head_entry.next.unwrap().as_ref().key.clone() },
unsafe { head_entry.next.unwrap().as_ref().key },
key1,
"Head's next should point to tail after middle element detached"
);
assert_eq!(
unsafe { tail_entry.prev.unwrap().as_ref().key.clone() },
unsafe { tail_entry.prev.unwrap().as_ref().key },
key3,
"Tail's prev should point to head after middle element detached"
);
@@ -1085,7 +1105,7 @@ mod tests {
continue; // skip duplicate page ids
}
tracing::debug!("inserting page {:?}", key);
match cache.insert(key.clone(), page.clone()) {
match cache.insert(key, page.clone()) {
Err(CacheError::Full | CacheError::ActiveRefs) => {} // Ignore
Err(err) => {
// Any other error should fail the test
@@ -1106,7 +1126,7 @@ mod tests {
PageCacheKey::new(id_page as usize)
} else {
let i = rng.next_u64() as usize % lru.len();
let key: PageCacheKey = lru.iter().nth(i).unwrap().0.clone();
let key: PageCacheKey = *lru.iter().nth(i).unwrap().0;
key
};
tracing::debug!("removing page {:?}", key);
@@ -1133,7 +1153,7 @@ mod tests {
let this_keys = cache.keys();
let mut lru_keys = Vec::new();
for (lru_key, _) in lru {
lru_keys.push(lru_key.clone());
lru_keys.push(*lru_key);
}
if this_keys != lru_keys {
cache.print();
@@ -1149,8 +1169,8 @@ mod tests {
let mut cache = DumbLruPageCache::default();
let key1 = insert_page(&mut cache, 1);
let key2 = insert_page(&mut cache, 2);
assert_eq!(cache.get(&key1).unwrap().get().id, 1);
assert_eq!(cache.get(&key2).unwrap().get().id, 2);
assert_eq!(cache.get(&key1).unwrap().unwrap().get().id, 1);
assert_eq!(cache.get(&key2).unwrap().unwrap().get().id, 2);
}
#[test]
@@ -1159,17 +1179,17 @@ mod tests {
let key1 = insert_page(&mut cache, 1);
let key2 = insert_page(&mut cache, 2);
let key3 = insert_page(&mut cache, 3);
assert!(cache.get(&key1).is_none());
assert_eq!(cache.get(&key2).unwrap().get().id, 2);
assert_eq!(cache.get(&key3).unwrap().get().id, 3);
assert!(cache.get(&key1).unwrap().is_none());
assert_eq!(cache.get(&key2).unwrap().unwrap().get().id, 2);
assert_eq!(cache.get(&key3).unwrap().unwrap().get().id, 3);
}
#[test]
fn test_page_cache_delete() {
let mut cache = DumbLruPageCache::default();
let key1 = insert_page(&mut cache, 1);
assert!(cache.delete(key1.clone()).is_ok());
assert!(cache.get(&key1).is_none());
assert!(cache.delete(key1).is_ok());
assert!(cache.get(&key1).unwrap().is_none());
}
#[test]
@@ -1178,8 +1198,8 @@ mod tests {
let key1 = insert_page(&mut cache, 1);
let key2 = insert_page(&mut cache, 2);
assert!(cache.clear().is_ok());
assert!(cache.get(&key1).is_none());
assert!(cache.get(&key2).is_none());
assert!(cache.get(&key1).unwrap().is_none());
assert!(cache.get(&key2).unwrap().is_none());
}
#[test]
@@ -1216,8 +1236,8 @@ mod tests {
assert_eq!(result, CacheResizeResult::Done);
assert_eq!(cache.len(), 2);
assert_eq!(cache.capacity, 5);
assert!(cache.get(&create_key(1)).is_some());
assert!(cache.get(&create_key(2)).is_some());
assert!(cache.get(&create_key(1)).unwrap().is_some());
assert!(cache.get(&create_key(2)).unwrap().is_some());
for i in 3..=5 {
let _ = insert_page(&mut cache, i);
}

View File

@@ -1123,7 +1123,7 @@ impl Pager {
tracing::trace!("read_page(page_idx = {})", page_idx);
let mut page_cache = self.page_cache.write();
let page_key = PageCacheKey::new(page_idx);
if let Some(page) = page_cache.get(&page_key) {
if let Some(page) = page_cache.get(&page_key)? {
tracing::trace!("read_page(page_idx = {}) = cached", page_idx);
return Ok((page.clone(), None));
}
@@ -1158,25 +1158,20 @@ impl Pager {
let page_key = PageCacheKey::new(page_idx);
match page_cache.insert(page_key, page.clone()) {
Ok(_) => {}
Err(CacheError::Full) => return Err(LimboError::CacheFull),
Err(CacheError::KeyExists) => {
unreachable!("Page should not exist in cache after get() miss")
}
Err(e) => {
return Err(LimboError::InternalError(format!(
"Failed to insert page into cache: {e:?}"
)))
}
Err(e) => return Err(e.into()),
}
Ok(())
}
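// The `e.into()` above leans on a CacheError -> LimboError conversion; a minimal
// sketch of such an impl (the crate's real conversion may differ) would be:
//
//   impl From<CacheError> for LimboError {
//       fn from(e: CacheError) -> Self {
//           match e {
//               CacheError::Full => LimboError::CacheFull,
//               other => LimboError::InternalError(other.to_string()),
//           }
//       }
//   }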
// Get a page from the cache, if it exists.
pub fn cache_get(&self, page_idx: usize) -> Option<PageRef> {
pub fn cache_get(&self, page_idx: usize) -> Result<Option<PageRef>> {
tracing::trace!("read_page(page_idx = {})", page_idx);
let mut page_cache = self.page_cache.write();
let page_key = PageCacheKey::new(page_idx);
page_cache.get(&page_key)
Ok(page_cache.get(&page_key)?)
}
/// Get a page from cache only if it matches the target frame
@@ -1185,10 +1180,10 @@ impl Pager {
page_idx: usize,
target_frame: u64,
seq: u32,
) -> Option<PageRef> {
) -> Result<Option<PageRef>> {
let mut page_cache = self.page_cache.write();
let page_key = PageCacheKey::new(page_idx);
page_cache.get(&page_key).and_then(|page| {
let page = page_cache.get(&page_key)?.and_then(|page| {
if page.is_valid_for_checkpoint(target_frame, seq) {
tracing::trace!(
"cache_get_for_checkpoint: page {} frame {} is valid",
@@ -1207,7 +1202,8 @@ impl Pager {
);
None
}
})
});
Ok(page)
}
/// Changes the size of the page cache.
@@ -1261,7 +1257,7 @@ impl Pager {
let page = {
let mut cache = self.page_cache.write();
let page_key = PageCacheKey::new(*page_id);
let page = cache.get(&page_key).expect("we somehow added a page to dirty list but we didn't mark it as dirty, causing cache to drop it.");
let page = cache.get(&page_key)?.expect("we somehow added a page to dirty list but we didn't mark it as dirty, causing cache to drop it.");
let page_type = page.get_contents().maybe_page_type();
trace!("cacheflush(page={}, page_type={:?}", page_id, page_type);
page
@@ -1344,7 +1340,7 @@ impl Pager {
let page = {
let mut cache = self.page_cache.write();
let page_key = PageCacheKey::new(page_id);
let page = cache.get(&page_key).expect(
let page = cache.get(&page_key)?.expect(
"dirty list contained a page that cache dropped (page={page_id})",
);
trace!(
@@ -1482,7 +1478,7 @@ impl Pager {
header.db_size as u64,
raw_page,
)?;
if let Some(page) = self.cache_get(header.page_number as usize) {
if let Some(page) = self.cache_get(header.page_number as usize)? {
let content = page.get_contents();
content.as_ptr().copy_from_slice(raw_page);
turso_assert!(
@@ -1505,7 +1501,7 @@ impl Pager {
for page_id in self.dirty_pages.borrow().iter() {
let page_key = PageCacheKey::new(*page_id);
let mut cache = self.page_cache.write();
let page = cache.get(&page_key).expect("we somehow added a page to dirty list but we didn't mark it as dirty, causing cache to drop it.");
let page = cache.get(&page_key)?.expect("we somehow added a page to dirty list but we didn't mark it as dirty, causing cache to drop it.");
page.clear_dirty();
}
self.dirty_pages.borrow_mut().clear();
@@ -1902,15 +1898,7 @@ impl Pager {
self.add_dirty(&page);
let page_key = PageCacheKey::new(page.get().id);
let mut cache = self.page_cache.write();
match cache.insert(page_key, page.clone()) {
Ok(_) => (),
Err(CacheError::Full) => return Err(LimboError::CacheFull),
Err(_) => {
return Err(LimboError::InternalError(
"Unknown error inserting page to cache".into(),
))
}
}
cache.insert(page_key, page.clone())?;
}
}
@@ -2081,15 +2069,7 @@ impl Pager {
{
// Run in separate block to avoid deadlock on page cache write lock
let mut cache = self.page_cache.write();
match cache.insert(page_key, page.clone()) {
Err(CacheError::Full) => return Err(LimboError::CacheFull),
Err(_) => {
return Err(LimboError::InternalError(
"Unknown error inserting page to cache".into(),
))
}
Ok(_) => {}
};
cache.insert(page_key, page.clone())?;
}
header.database_size = new_db_size.into();
*state = AllocatePageState::Start;
@@ -2186,7 +2166,8 @@ impl Pager {
}
pub fn set_encryption_context(&self, cipher_mode: CipherMode, key: &EncryptionKey) {
let encryption_ctx = EncryptionContext::new(cipher_mode, key).unwrap();
let page_size = self.page_size.get().unwrap().get() as usize;
let encryption_ctx = EncryptionContext::new(cipher_mode, key, page_size).unwrap();
{
let mut io_ctx = self.io_ctx.borrow_mut();
io_ctx.set_encryption(encryption_ctx);
@@ -2428,13 +2409,16 @@ mod tests {
std::thread::spawn(move || {
let mut cache = cache.write();
let page_key = PageCacheKey::new(1);
cache.insert(page_key, Arc::new(Page::new(1))).unwrap();
let page = Page::new(1);
// Mark the page loaded so it is not evicted: get() drops cached pages that are neither loaded nor locked
page.set_loaded();
cache.insert(page_key, Arc::new(page)).unwrap();
})
};
let _ = thread.join();
let mut cache = cache.write();
let page_key = PageCacheKey::new(1);
let page = cache.get(&page_key);
let page = cache.get(&page_key).unwrap();
assert_eq!(page.unwrap().get().id, 1);
}
}

View File

@@ -1838,7 +1838,7 @@ pub fn read_entire_wal_dumb(file: &Arc<dyn File>) -> Result<Arc<UnsafeCell<WalFi
pub fn begin_read_wal_frame_raw(
buffer_pool: &Arc<BufferPool>,
io: &Arc<dyn File>,
offset: usize,
offset: u64,
complete: Box<ReadComplete>,
) -> Result<Completion> {
tracing::trace!("begin_read_wal_frame_raw(offset={})", offset);
@@ -1851,7 +1851,7 @@ pub fn begin_read_wal_frame_raw(
pub fn begin_read_wal_frame(
io: &Arc<dyn File>,
offset: usize,
offset: u64,
buffer_pool: Arc<BufferPool>,
complete: Box<ReadComplete>,
page_idx: usize,

View File

@@ -1082,7 +1082,7 @@ impl Wal for WalFile {
});
begin_read_wal_frame(
&self.get_shared().file,
offset + WAL_FRAME_HEADER_SIZE,
offset + WAL_FRAME_HEADER_SIZE as u64,
buffer_pool,
complete,
page_idx,
@@ -1095,6 +1095,11 @@ impl Wal for WalFile {
tracing::debug!("read_frame({})", frame_id);
let offset = self.frame_offset(frame_id);
let (frame_ptr, frame_len) = (frame.as_mut_ptr(), frame.len());
let encryption_ctx = {
let io_ctx = self.io_ctx.borrow();
io_ctx.encryption_context().cloned()
};
let complete = Box::new(move |res: Result<(Arc<Buffer>, i32), CompletionError>| {
let Ok((buf, bytes_read)) = res else {
return;
@@ -1104,10 +1109,34 @@ impl Wal for WalFile {
bytes_read == buf_len as i32,
"read({bytes_read}) != expected({buf_len})"
);
let buf_ptr = buf.as_mut_ptr();
let buf_ptr = buf.as_ptr();
let frame_ref: &mut [u8] =
unsafe { std::slice::from_raw_parts_mut(frame_ptr, frame_len) };
// Copy the just-read WAL frame into the destination buffer
unsafe {
std::ptr::copy_nonoverlapping(buf_ptr, frame_ptr, frame_len);
}
// Now parse the header from the freshly-copied data
let (header, raw_page) = sqlite3_ondisk::parse_wal_frame_header(frame_ref);
if let Some(ctx) = encryption_ctx.clone() {
match ctx.decrypt_page(raw_page, header.page_number as usize) {
Ok(decrypted_data) => {
turso_assert!(
(frame_len - WAL_FRAME_HEADER_SIZE) == decrypted_data.len(),
"frame_len - header_size({}) != expected({})",
frame_len - WAL_FRAME_HEADER_SIZE,
decrypted_data.len()
);
frame_ref[WAL_FRAME_HEADER_SIZE..].copy_from_slice(&decrypted_data);
}
Err(_) => {
tracing::error!("Failed to decrypt page data for frame_id={frame_id}");
}
}
}
});
let c =
begin_read_wal_frame_raw(&self.buffer_pool, &self.get_shared().file, offset, complete)?;
@@ -1167,7 +1196,7 @@ impl Wal for WalFile {
});
let c = begin_read_wal_frame(
&self.get_shared().file,
offset + WAL_FRAME_HEADER_SIZE,
offset + WAL_FRAME_HEADER_SIZE as u64,
buffer_pool,
complete,
page_id as usize,
@@ -1431,14 +1460,14 @@ impl Wal for WalFile {
let mut next_frame_id = self.max_frame + 1;
// Build every frame in order, updating the rolling checksum
for (idx, page) in pages.iter().enumerate() {
let page_id = page.get().id as u64;
let page_id = page.get().id;
let plain = page.get_contents().as_ptr();
let data_to_write: std::borrow::Cow<[u8]> = {
let io_ctx = self.io_ctx.borrow();
let ectx = io_ctx.encryption_context();
if let Some(ctx) = ectx.as_ref() {
Cow::Owned(ctx.encrypt_page(plain, page_id as usize)?)
Cow::Owned(ctx.encrypt_page(plain, page_id)?)
} else {
Cow::Borrowed(plain)
}
@@ -1552,11 +1581,10 @@ impl WalFile {
self.get_shared().wal_header.lock().page_size
}
fn frame_offset(&self, frame_id: u64) -> usize {
fn frame_offset(&self, frame_id: u64) -> u64 {
assert!(frame_id > 0, "Frame ID must be 1-based");
let page_offset = (frame_id - 1) * (self.page_size() + WAL_FRAME_HEADER_SIZE as u32) as u64;
let offset = WAL_HEADER_SIZE as u64 + page_offset;
offset as usize
WAL_HEADER_SIZE as u64 + page_offset
}
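// Worked example (assuming the standard SQLite constants WAL_HEADER_SIZE = 32 and
// WAL_FRAME_HEADER_SIZE = 24) with a 4096-byte page:
//
//   frame 1: 32 + 0 * (4096 + 24) = 32
//   frame 2: 32 + 1 * (4096 + 24) = 4152
//   frame n: 32 + (n - 1) * 4120
//
// Returning u64 instead of usize keeps the offset correct on 32-bit targets once
// the WAL grows past 4 GiB.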
#[allow(clippy::mut_from_ref)]
@@ -1748,7 +1776,7 @@ impl WalFile {
// Try cache first, if enabled
if let Some(cached_page) =
pager.cache_get_for_checkpoint(page_id as usize, target_frame, seq)
pager.cache_get_for_checkpoint(page_id as usize, target_frame, seq)?
{
let contents = cached_page.get_contents();
let buffer = contents.buffer.clone();
@@ -1805,7 +1833,7 @@ impl WalFile {
self.ongoing_checkpoint.pages_to_checkpoint.iter()
{
if *cached {
let page = pager.cache_get((*page_id) as usize);
let page = pager.cache_get((*page_id) as usize)?;
turso_assert!(
page.is_some(),
"page should still exist in the page cache"
@@ -2102,7 +2130,7 @@ impl WalFile {
// schedule read of the page payload
let c = begin_read_wal_frame(
&self.get_shared().file,
offset + WAL_FRAME_HEADER_SIZE,
offset + WAL_FRAME_HEADER_SIZE as u64,
self.buffer_pool.clone(),
complete,
page_id,
@@ -2288,7 +2316,7 @@ pub mod test {
let done = Rc::new(Cell::new(false));
let _done = done.clone();
let _ = file.file.truncate(
WAL_HEADER_SIZE,
WAL_HEADER_SIZE as u64,
Completion::new_trunc(move |_| {
let done = _done.clone();
done.set(true);

View File

@@ -125,27 +125,161 @@ pub fn handle_distinct(program: &mut ProgramBuilder, agg: &Aggregate, agg_arg_re
});
}
/// Emits the bytecode for processing an aggregate step.
/// E.g. in `SELECT SUM(price) FROM t`, 'price' is evaluated for every row, and the result is added to the accumulator.
/// Enum representing the source of the aggregate function arguments
///
/// This is distinct from the final step, which is called after the main loop has finished processing
/// Aggregate arguments can come from different sources, depending on how the aggregation
/// is evaluated:
/// * In the common grouped case, the aggregate function arguments are first inserted
/// into a sorter in the main loop, and in the GROUP BY aggregation phase we read
/// the data from the sorter.
/// * In grouped cases where no sorting is required, arguments are retrieved directly
/// from registers allocated in the main loop.
/// * In ungrouped cases, arguments are computed directly from the `args` expressions.
pub enum AggArgumentSource<'a> {
/// The aggregate function arguments are retrieved from a pseudo cursor
/// which reads from the GROUP BY sorter.
PseudoCursor {
cursor_id: usize,
col_start: usize,
dest_reg_start: usize,
aggregate: &'a Aggregate,
},
/// The aggregate function arguments are retrieved from a contiguous block of registers
/// allocated in the main loop for that given aggregate function.
Register {
src_reg_start: usize,
aggregate: &'a Aggregate,
},
/// The aggregate function arguments are retrieved by evaluating expressions.
Expression { aggregate: &'a Aggregate },
}
impl<'a> AggArgumentSource<'a> {
/// Create a new [AggArgumentSource] that retrieves the values from a GROUP BY sorter.
pub fn new_from_cursor(
program: &mut ProgramBuilder,
cursor_id: usize,
col_start: usize,
aggregate: &'a Aggregate,
) -> Self {
let dest_reg_start = program.alloc_registers(aggregate.args.len());
Self::PseudoCursor {
cursor_id,
col_start,
dest_reg_start,
aggregate,
}
}
/// Create a new [AggArgumentSource] that retrieves the values directly from an already
/// populated register or registers.
pub fn new_from_registers(src_reg_start: usize, aggregate: &'a Aggregate) -> Self {
Self::Register {
src_reg_start,
aggregate,
}
}
/// Create a new [AggArgumentSource] that retrieves the values by evaluating `args` expressions.
pub fn new_from_expression(aggregate: &'a Aggregate) -> Self {
Self::Expression { aggregate }
}
pub fn aggregate(&self) -> &Aggregate {
match self {
AggArgumentSource::PseudoCursor { aggregate, .. } => aggregate,
AggArgumentSource::Register { aggregate, .. } => aggregate,
AggArgumentSource::Expression { aggregate } => aggregate,
}
}
pub fn agg_func(&self) -> &AggFunc {
match self {
AggArgumentSource::PseudoCursor { aggregate, .. } => &aggregate.func,
AggArgumentSource::Register { aggregate, .. } => &aggregate.func,
AggArgumentSource::Expression { aggregate } => &aggregate.func,
}
}
pub fn args(&self) -> &[ast::Expr] {
match self {
AggArgumentSource::PseudoCursor { aggregate, .. } => &aggregate.args,
AggArgumentSource::Register { aggregate, .. } => &aggregate.args,
AggArgumentSource::Expression { aggregate } => &aggregate.args,
}
}
pub fn num_args(&self) -> usize {
match self {
AggArgumentSource::PseudoCursor { aggregate, .. } => aggregate.args.len(),
AggArgumentSource::Register { aggregate, .. } => aggregate.args.len(),
AggArgumentSource::Expression { aggregate } => aggregate.args.len(),
}
}
/// Read the value of an aggregate function argument
pub fn translate(
&self,
program: &mut ProgramBuilder,
referenced_tables: &TableReferences,
resolver: &Resolver,
arg_idx: usize,
) -> Result<usize> {
match self {
AggArgumentSource::PseudoCursor {
cursor_id,
col_start,
dest_reg_start,
..
} => {
program.emit_column_or_rowid(
*cursor_id,
*col_start + arg_idx,
dest_reg_start + arg_idx,
);
Ok(dest_reg_start + arg_idx)
}
AggArgumentSource::Register {
src_reg_start: start_reg,
..
} => Ok(*start_reg + arg_idx),
AggArgumentSource::Expression { aggregate } => {
let dest_reg = program.alloc_register();
translate_expr(
program,
Some(referenced_tables),
&aggregate.args[arg_idx],
dest_reg,
resolver,
)
}
}
}
}
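// Illustrative mapping from query shape to variant (hypothetical surrounding
// state; `program`, `cursor_id`, `col_start`, `reg` and `agg` come from the
// enclosing translation):
//
//   // SELECT sum(x) FROM t GROUP BY y      -- arguments staged through the sorter
//   let src = AggArgumentSource::new_from_cursor(program, cursor_id, col_start, &agg);
//
//   // sorter-free GROUP BY                 -- arguments already sit in registers
//   let src = AggArgumentSource::new_from_registers(reg, &agg);
//
//   // SELECT sum(x) FROM t                 -- single implicit group
//   let src = AggArgumentSource::new_from_expression(&agg);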
/// Emits the bytecode for processing an aggregate step.
///
/// This is distinct from the final step, which is called after a single group has been entirely accumulated,
/// and the actual result value of the aggregation is materialized.
///
/// Ungrouped aggregation is a special case of grouped aggregation that involves a single group.
///
/// Examples:
/// * In `SELECT SUM(price) FROM t`, `price` is evaluated for each row and added to the accumulator.
/// * In `SELECT product_category, SUM(price) FROM t GROUP BY product_category`, `price` is evaluated for
/// each row in the group and added to that group's accumulator.
pub fn translate_aggregation_step(
program: &mut ProgramBuilder,
referenced_tables: &TableReferences,
agg: &Aggregate,
agg_arg_source: AggArgumentSource,
target_register: usize,
resolver: &Resolver,
) -> Result<usize> {
let dest = match agg.func {
let num_args = agg_arg_source.num_args();
let func = agg_arg_source.agg_func();
let dest = match func {
AggFunc::Avg => {
if agg.args.len() != 1 {
if num_args != 1 {
crate::bail_parse_error!("avg bad number of arguments");
}
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
@@ -155,20 +289,16 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::Count | AggFunc::Count0 => {
let expr_reg = if agg.args.is_empty() {
program.alloc_register()
} else {
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
expr_reg
};
handle_distinct(program, agg, expr_reg);
if num_args != 1 {
crate::bail_parse_error!("count bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: if matches!(agg.func, AggFunc::Count0) {
func: if matches!(func, AggFunc::Count0) {
AggFunc::Count0
} else {
AggFunc::Count
@@ -177,18 +307,16 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::GroupConcat => {
if agg.args.len() != 1 && agg.args.len() != 2 {
if num_args != 1 && num_args != 2 {
crate::bail_parse_error!("group_concat bad number of arguments");
}
let expr_reg = program.alloc_register();
let delimiter_reg = program.alloc_register();
let expr = &agg.args[0];
let delimiter_expr: ast::Expr;
if agg.args.len() == 2 {
match &agg.args[1] {
if num_args == 2 {
match &agg_arg_source.args()[1] {
arg @ ast::Expr::Column { .. } => {
delimiter_expr = arg.clone();
}
@@ -201,8 +329,8 @@ pub fn translate_aggregation_step(
delimiter_expr = ast::Expr::Literal(ast::Literal::String(String::from("\",\"")));
}
translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
translate_expr(
program,
Some(referenced_tables),
@@ -221,13 +349,12 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::Max => {
if agg.args.len() != 1 {
if num_args != 1 {
crate::bail_parse_error!("max bad number of arguments");
}
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
let expr = &agg_arg_source.args()[0];
emit_collseq_if_needed(program, referenced_tables, expr);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
@@ -238,13 +365,12 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::Min => {
if agg.args.len() != 1 {
if num_args != 1 {
crate::bail_parse_error!("min bad number of arguments");
}
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
let expr = &agg_arg_source.args()[0];
emit_collseq_if_needed(program, referenced_tables, expr);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
@@ -256,23 +382,12 @@ pub fn translate_aggregation_step(
}
#[cfg(feature = "json")]
AggFunc::JsonGroupObject | AggFunc::JsonbGroupObject => {
if agg.args.len() != 2 {
if num_args != 2 {
crate::bail_parse_error!("max bad number of arguments");
}
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let value_expr = &agg.args[1];
let value_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let _ = translate_expr(
program,
Some(referenced_tables),
value_expr,
value_reg,
resolver,
)?;
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
let value_reg = agg_arg_source.translate(program, referenced_tables, resolver, 1)?;
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
@@ -284,13 +399,11 @@ pub fn translate_aggregation_step(
}
#[cfg(feature = "json")]
AggFunc::JsonGroupArray | AggFunc::JsonbGroupArray => {
if agg.args.len() != 1 {
if num_args != 1 {
crate::bail_parse_error!("max bad number of arguments");
}
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
@@ -300,15 +413,13 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::StringAgg => {
if agg.args.len() != 2 {
if num_args != 2 {
crate::bail_parse_error!("string_agg bad number of arguments");
}
let expr_reg = program.alloc_register();
let delimiter_reg = program.alloc_register();
let expr = &agg.args[0];
let delimiter_expr = match &agg.args[1] {
let delimiter_expr = match &agg_arg_source.args()[1] {
arg @ ast::Expr::Column { .. } => arg.clone(),
ast::Expr::Literal(ast::Literal::String(s)) => {
ast::Expr::Literal(ast::Literal::String(s.to_string()))
@@ -316,7 +427,7 @@ pub fn translate_aggregation_step(
_ => crate::bail_parse_error!("Incorrect delimiter parameter"),
};
translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
translate_expr(
program,
Some(referenced_tables),
@@ -335,13 +446,11 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::Sum => {
if agg.args.len() != 1 {
if num_args != 1 {
crate::bail_parse_error!("sum bad number of arguments");
}
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
@@ -351,13 +460,11 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::Total => {
if agg.args.len() != 1 {
if num_args != 1 {
crate::bail_parse_error!("total bad number of arguments");
}
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
@@ -367,31 +474,24 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::External(ref func) => {
let expr_reg = program.alloc_register();
let argc = func.agg_args().map_err(|_| {
LimboError::ExtensionError(
"External aggregate function called with wrong number of arguments".to_string(),
)
})?;
if argc != agg.args.len() {
if argc != num_args {
crate::bail_parse_error!(
"External aggregate function called with wrong number of arguments"
);
}
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
for i in 0..argc {
if i != 0 {
let _ = program.alloc_register();
let _ = agg_arg_source.translate(program, referenced_tables, resolver, i)?;
}
let _ = translate_expr(
program,
Some(referenced_tables),
&agg.args[i],
expr_reg + i,
resolver,
)?;
// invariant: distinct aggregates are only supported for single-argument functions
if argc == 1 {
handle_distinct(program, agg, expr_reg + i);
handle_distinct(program, agg_arg_source.aggregate(), expr_reg + i);
}
}
program.emit_insn(Insn::AggStep {

View File

@@ -1,5 +1,8 @@
use std::sync::Arc;
use turso_parser::{ast, parser::Parser};
use turso_parser::{
ast::{self, fmt::ToTokens as _},
parser::Parser,
};
use crate::{
function::{AlterTableFunc, Func},
@@ -166,7 +169,7 @@ pub fn translate_alter_table(
)?
}
ast::AlterTableBody::AddColumn(col_def) => {
let column = Column::from(col_def);
let column = Column::from(&col_def);
if let Some(default) = &column.default {
if !matches!(
@@ -233,97 +236,6 @@ pub fn translate_alter_table(
},
)?
}
ast::AlterTableBody::RenameColumn { old, new } => {
let rename_from = old.as_str();
let rename_to = new.as_str();
let Some((column_index, _)) = btree.get_column(rename_from) else {
return Err(LimboError::ParseError(format!(
"no such column: \"{rename_from}\""
)));
};
if btree.get_column(rename_to).is_some() {
return Err(LimboError::ParseError(format!(
"duplicate column name: \"{rename_from}\""
)));
};
let sqlite_schema = schema
.get_btree_table(SQLITE_TABLEID)
.expect("sqlite_schema should be on schema");
let cursor_id = program.alloc_cursor_id(CursorType::BTreeTable(sqlite_schema.clone()));
program.emit_insn(Insn::OpenWrite {
cursor_id,
root_page: RegisterOrLiteral::Literal(sqlite_schema.root_page),
db: 0,
});
program.cursor_loop(cursor_id, |program, rowid| {
let sqlite_schema_column_len = sqlite_schema.columns.len();
assert_eq!(sqlite_schema_column_len, 5);
let first_column = program.alloc_registers(sqlite_schema_column_len);
for i in 0..sqlite_schema_column_len {
program.emit_column_or_rowid(cursor_id, i, first_column + i);
}
program.emit_string8_new_reg(table_name.to_string());
program.mark_last_insn_constant();
program.emit_string8_new_reg(rename_from.to_string());
program.mark_last_insn_constant();
program.emit_string8_new_reg(rename_to.to_string());
program.mark_last_insn_constant();
let out = program.alloc_registers(sqlite_schema_column_len);
program.emit_insn(Insn::Function {
constant_mask: 0,
start_reg: first_column,
dest: out,
func: crate::function::FuncCtx {
func: Func::AlterTable(AlterTableFunc::RenameColumn),
arg_count: 8,
},
});
let record = program.alloc_register();
program.emit_insn(Insn::MakeRecord {
start_reg: out,
count: sqlite_schema_column_len,
dest_reg: record,
index_name: None,
});
program.emit_insn(Insn::Insert {
cursor: cursor_id,
key_reg: rowid,
record_reg: record,
flag: crate::vdbe::insn::InsertFlags(0),
table_name: table_name.to_string(),
});
});
program.emit_insn(Insn::SetCookie {
db: 0,
cookie: Cookie::SchemaVersion,
value: schema.schema_version as i32 + 1,
p5: 0,
});
program.emit_insn(Insn::RenameColumn {
table: table_name.to_owned(),
column_index,
name: rename_to.to_owned(),
});
program
}
ast::AlterTableBody::RenameTo(new_name) => {
let new_name = new_name.as_str();
@@ -409,6 +321,148 @@ pub fn translate_alter_table(
to: new_name.to_owned(),
});
program
}
body @ (ast::AlterTableBody::AlterColumn { .. }
| ast::AlterTableBody::RenameColumn { .. }) => {
let from;
let definition;
let col_name;
let rename;
match body {
ast::AlterTableBody::AlterColumn { old, new } => {
from = old;
definition = new;
col_name = definition.col_name.clone();
rename = false;
}
ast::AlterTableBody::RenameColumn { old, new } => {
from = old;
definition = ast::ColumnDefinition {
col_name: new.clone(),
col_type: None,
constraints: vec![],
};
col_name = new;
rename = true;
}
_ => unreachable!(),
}
let from = from.as_str();
let col_name = col_name.as_str();
let Some((column_index, _)) = btree.get_column(from) else {
return Err(LimboError::ParseError(format!(
"no such column: \"{from}\""
)));
};
if btree.get_column(col_name).is_some() {
return Err(LimboError::ParseError(format!(
"duplicate column name: \"{col_name}\""
)));
};
if definition
.constraints
.iter()
.any(|c| matches!(c.constraint, ast::ColumnConstraint::PrimaryKey { .. }))
{
return Err(LimboError::ParseError(
"PRIMARY KEY constraint cannot be altered".to_string(),
));
}
if definition
.constraints
.iter()
.any(|c| matches!(c.constraint, ast::ColumnConstraint::Unique { .. }))
{
return Err(LimboError::ParseError(
"UNIQUE constraint cannot be altered".to_string(),
));
}
let sqlite_schema = schema
.get_btree_table(SQLITE_TABLEID)
.expect("sqlite_schema should be on schema");
let cursor_id = program.alloc_cursor_id(CursorType::BTreeTable(sqlite_schema.clone()));
program.emit_insn(Insn::OpenWrite {
cursor_id,
root_page: RegisterOrLiteral::Literal(sqlite_schema.root_page),
db: 0,
});
program.cursor_loop(cursor_id, |program, rowid| {
let sqlite_schema_column_len = sqlite_schema.columns.len();
assert_eq!(sqlite_schema_column_len, 5);
let first_column = program.alloc_registers(sqlite_schema_column_len);
for i in 0..sqlite_schema_column_len {
program.emit_column_or_rowid(cursor_id, i, first_column + i);
}
program.emit_string8_new_reg(table_name.to_string());
program.mark_last_insn_constant();
program.emit_string8_new_reg(from.to_string());
program.mark_last_insn_constant();
program.emit_string8_new_reg(definition.format().unwrap());
program.mark_last_insn_constant();
let out = program.alloc_registers(sqlite_schema_column_len);
program.emit_insn(Insn::Function {
constant_mask: 0,
start_reg: first_column,
dest: out,
func: crate::function::FuncCtx {
func: Func::AlterTable(if rename {
AlterTableFunc::RenameColumn
} else {
AlterTableFunc::AlterColumn
}),
arg_count: 8,
},
});
let record = program.alloc_register();
program.emit_insn(Insn::MakeRecord {
start_reg: out,
count: sqlite_schema_column_len,
dest_reg: record,
index_name: None,
});
program.emit_insn(Insn::Insert {
cursor: cursor_id,
key_reg: rowid,
record_reg: record,
flag: crate::vdbe::insn::InsertFlags(0),
table_name: table_name.to_string(),
});
});
program.emit_insn(Insn::SetCookie {
db: 0,
cookie: Cookie::SchemaVersion,
value: schema.schema_version as i32 + 1,
p5: 0,
});
program.emit_insn(Insn::AlterColumn {
table: table_name.to_owned(),
column_index,
definition,
rename,
});
program
}
})

View File

@@ -1,9 +1,16 @@
use turso_parser::ast;
use super::{
emitter::TranslateCtx,
expr::{translate_condition_expr, translate_expr, ConditionMetadata},
order_by::order_by_sorter_insert,
plan::{Distinctness, GroupBy, SelectPlan},
result_row::emit_select_result,
};
use crate::translate::aggregation::{translate_aggregation_step, AggArgumentSource};
use crate::translate::expr::{walk_expr, WalkControl};
use crate::translate::plan::ResultSetColumn;
use crate::{
function::AggFunc,
schema::PseudoCursorType,
translate::collate::CollationSeq,
util::exprs_are_equivalent,
@@ -15,15 +22,6 @@ use crate::{
Result,
};
use super::{
aggregation::handle_distinct,
emitter::{Resolver, TranslateCtx},
expr::{translate_condition_expr, translate_expr, ConditionMetadata},
order_by::order_by_sorter_insert,
plan::{Aggregate, Distinctness, GroupBy, SelectPlan, TableReferences},
result_row::emit_select_result,
};
/// Labels needed for various jumps in GROUP BY handling.
#[derive(Debug)]
pub struct GroupByLabels {
@@ -394,102 +392,6 @@ pub enum GroupByRowSource {
},
}
/// Enum representing the source of the aggregate function arguments
/// emitted for a group by aggregation.
/// In the common case, the aggregate function arguments are first inserted
/// into a sorter in the main loop, and in the group by aggregation phase
/// we read the data from the sorter.
///
/// In the alternative case, no sorting is required for group by,
/// and the aggregate function arguments are retrieved directly from
/// registers allocated in the main loop.
pub enum GroupByAggArgumentSource<'a> {
/// The aggregate function arguments are retrieved from a pseudo cursor
/// which reads from the GROUP BY sorter.
PseudoCursor {
cursor_id: usize,
col_start: usize,
dest_reg_start: usize,
aggregate: &'a Aggregate,
},
/// The aggregate function arguments are retrieved from a contiguous block of registers
/// allocated in the main loop for that given aggregate function.
Register {
src_reg_start: usize,
aggregate: &'a Aggregate,
},
}
impl<'a> GroupByAggArgumentSource<'a> {
/// Create a new [GroupByAggArgumentSource] that retrieves the values from a GROUP BY sorter.
pub fn new_from_cursor(
program: &mut ProgramBuilder,
cursor_id: usize,
col_start: usize,
aggregate: &'a Aggregate,
) -> Self {
let dest_reg_start = program.alloc_registers(aggregate.args.len());
Self::PseudoCursor {
cursor_id,
col_start,
dest_reg_start,
aggregate,
}
}
/// Create a new [GroupByAggArgumentSource] that retrieves the values directly from an already
/// populated register or registers.
pub fn new_from_registers(src_reg_start: usize, aggregate: &'a Aggregate) -> Self {
Self::Register {
src_reg_start,
aggregate,
}
}
pub fn aggregate(&self) -> &Aggregate {
match self {
GroupByAggArgumentSource::PseudoCursor { aggregate, .. } => aggregate,
GroupByAggArgumentSource::Register { aggregate, .. } => aggregate,
}
}
pub fn agg_func(&self) -> &AggFunc {
match self {
GroupByAggArgumentSource::PseudoCursor { aggregate, .. } => &aggregate.func,
GroupByAggArgumentSource::Register { aggregate, .. } => &aggregate.func,
}
}
pub fn args(&self) -> &[ast::Expr] {
match self {
GroupByAggArgumentSource::PseudoCursor { aggregate, .. } => &aggregate.args,
GroupByAggArgumentSource::Register { aggregate, .. } => &aggregate.args,
}
}
pub fn num_args(&self) -> usize {
match self {
GroupByAggArgumentSource::PseudoCursor { aggregate, .. } => aggregate.args.len(),
GroupByAggArgumentSource::Register { aggregate, .. } => aggregate.args.len(),
}
}
/// Read the value of an aggregate function argument either from sorter data or directly from a register.
pub fn translate(&self, program: &mut ProgramBuilder, arg_idx: usize) -> Result<usize> {
match self {
GroupByAggArgumentSource::PseudoCursor {
cursor_id,
col_start,
dest_reg_start,
..
} => {
program.emit_column_or_rowid(*cursor_id, *col_start, dest_reg_start + arg_idx);
Ok(dest_reg_start + arg_idx)
}
GroupByAggArgumentSource::Register {
src_reg_start: start_reg,
..
} => Ok(*start_reg + arg_idx),
}
}
}
/// Emits bytecode for processing a single GROUP BY group.
pub fn group_by_process_single_group(
program: &mut ProgramBuilder,
@@ -593,21 +495,19 @@ pub fn group_by_process_single_group(
.expect("aggregate registers must be initialized");
let agg_result_reg = start_reg + i;
let agg_arg_source = match &row_source {
GroupByRowSource::Sorter { pseudo_cursor, .. } => {
GroupByAggArgumentSource::new_from_cursor(
program,
*pseudo_cursor,
cursor_index + offset,
agg,
)
}
GroupByRowSource::Sorter { pseudo_cursor, .. } => AggArgumentSource::new_from_cursor(
program,
*pseudo_cursor,
cursor_index + offset,
agg,
),
GroupByRowSource::MainLoop { start_reg_src, .. } => {
// Aggregation arguments are always placed in the registers that follow any scalars.
let start_reg_aggs = start_reg_src + t_ctx.non_aggregate_expressions.len();
GroupByAggArgumentSource::new_from_registers(start_reg_aggs + offset, agg)
AggArgumentSource::new_from_registers(start_reg_aggs + offset, agg)
}
};
translate_aggregation_step_groupby(
translate_aggregation_step(
program,
&plan.table_references,
agg_arg_source,
@@ -897,220 +797,3 @@ pub fn group_by_emit_row_phase<'a>(
program.preassign_label_to_next_insn(labels.label_group_by_end);
Ok(())
}
/// Emits the bytecode for processing an aggregate step within a GROUP BY clause.
/// Eg. in `SELECT product_category, SUM(price) FROM t GROUP BY line_item`, 'price' is evaluated for every row
/// where the 'product_category' is the same, and the result is added to the accumulator for that category.
///
/// This is distinct from the final step, which is called after a single group has been entirely accumulated,
/// and the actual result value of the aggregation is materialized.
pub fn translate_aggregation_step_groupby(
program: &mut ProgramBuilder,
referenced_tables: &TableReferences,
agg_arg_source: GroupByAggArgumentSource,
target_register: usize,
resolver: &Resolver,
) -> Result<usize> {
let num_args = agg_arg_source.num_args();
let dest = match agg_arg_source.agg_func() {
AggFunc::Avg => {
if num_args != 1 {
crate::bail_parse_error!("avg bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::Avg,
});
target_register
}
AggFunc::Count | AggFunc::Count0 => {
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: if matches!(agg_arg_source.agg_func(), AggFunc::Count0) {
AggFunc::Count0
} else {
AggFunc::Count
},
});
target_register
}
AggFunc::GroupConcat => {
let num_args = agg_arg_source.num_args();
if num_args != 1 && num_args != 2 {
crate::bail_parse_error!("group_concat bad number of arguments");
}
let delimiter_reg = program.alloc_register();
let delimiter_expr: ast::Expr;
if num_args == 2 {
match &agg_arg_source.args()[1] {
arg @ ast::Expr::Column { .. } => {
delimiter_expr = arg.clone();
}
ast::Expr::Literal(ast::Literal::String(s)) => {
delimiter_expr = ast::Expr::Literal(ast::Literal::String(s.to_string()));
}
_ => crate::bail_parse_error!("Incorrect delimiter parameter"),
};
} else {
delimiter_expr = ast::Expr::Literal(ast::Literal::String(String::from("\",\"")));
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
translate_expr(
program,
Some(referenced_tables),
&delimiter_expr,
delimiter_reg,
resolver,
)?;
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: delimiter_reg,
func: AggFunc::GroupConcat,
});
target_register
}
AggFunc::Max => {
if num_args != 1 {
crate::bail_parse_error!("max bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::Max,
});
target_register
}
AggFunc::Min => {
if num_args != 1 {
crate::bail_parse_error!("min bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::Min,
});
target_register
}
#[cfg(feature = "json")]
AggFunc::JsonGroupArray | AggFunc::JsonbGroupArray => {
if num_args != 1 {
crate::bail_parse_error!("min bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::JsonGroupArray,
});
target_register
}
#[cfg(feature = "json")]
AggFunc::JsonGroupObject | AggFunc::JsonbGroupObject => {
if num_args != 2 {
crate::bail_parse_error!("max bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
let value_reg = agg_arg_source.translate(program, 1)?;
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: value_reg,
func: AggFunc::JsonGroupObject,
});
target_register
}
AggFunc::StringAgg => {
if num_args != 2 {
crate::bail_parse_error!("string_agg bad number of arguments");
}
let delimiter_reg = program.alloc_register();
let delimiter_expr = match &agg_arg_source.args()[1] {
arg @ ast::Expr::Column { .. } => arg.clone(),
ast::Expr::Literal(ast::Literal::String(s)) => {
ast::Expr::Literal(ast::Literal::String(s.to_string()))
}
_ => crate::bail_parse_error!("Incorrect delimiter parameter"),
};
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
translate_expr(
program,
Some(referenced_tables),
&delimiter_expr,
delimiter_reg,
resolver,
)?;
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: delimiter_reg,
func: AggFunc::StringAgg,
});
target_register
}
AggFunc::Sum => {
if num_args != 1 {
crate::bail_parse_error!("sum bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::Sum,
});
target_register
}
AggFunc::Total => {
if num_args != 1 {
crate::bail_parse_error!("total bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::Total,
});
target_register
}
AggFunc::External(_) => {
todo!("External aggregate functions are not yet supported in GROUP BY");
}
};
Ok(dest)
}

View File

@@ -19,7 +19,7 @@ use crate::{
};
use super::{
aggregation::translate_aggregation_step,
aggregation::{translate_aggregation_step, AggArgumentSource},
emitter::{OperationMode, TranslateCtx},
expr::{
translate_condition_expr, translate_expr, translate_expr_no_constant_opt,
@@ -868,7 +868,7 @@ fn emit_loop_source(
translate_aggregation_step(
program,
&plan.table_references,
agg,
AggArgumentSource::new_from_expression(agg),
reg,
&t_ctx.resolver,
)?;

View File

@@ -1048,6 +1048,24 @@ pub struct Aggregate {
}
impl Aggregate {
pub fn new(func: AggFunc, args: &[Box<Expr>], expr: &Expr, distinctness: Distinctness) -> Self {
let agg_args = if args.is_empty() {
// The AggStep instruction requires at least one argument. For functions that accept
// zero arguments (e.g. COUNT()), we insert a dummy literal so that AggStep remains valid.
// This does not cause ambiguity: the resolver has already verified that the function
// takes zero arguments, so the dummy value will be ignored.
vec![Expr::Literal(ast::Literal::Numeric("1".to_string()))]
} else {
args.iter().map(|arg| *arg.clone()).collect()
};
Aggregate {
func,
args: agg_args,
original_expr: expr.clone(),
distinctness,
}
}
pub fn is_distinct(&self) -> bool {
self.distinctness.is_distinct()
}

View File

@@ -73,12 +73,7 @@ pub fn resolve_aggregates(
"DISTINCT aggregate functions must have exactly one argument"
);
}
aggs.push(Aggregate {
func: f,
args: args.iter().map(|arg| *arg.clone()).collect(),
original_expr: expr.clone(),
distinctness,
});
aggs.push(Aggregate::new(f, args, expr, distinctness));
contains_aggregates = true;
}
_ => {
@@ -95,12 +90,7 @@ pub fn resolve_aggregates(
);
}
if let Ok(Func::Agg(f)) = Func::resolve_function(name.as_str(), 0) {
aggs.push(Aggregate {
func: f,
args: vec![],
original_expr: expr.clone(),
distinctness: Distinctness::NonDistinct,
});
aggs.push(Aggregate::new(f, &[], expr, Distinctness::NonDistinct));
contains_aggregates = true;
}
}

View File

@@ -371,27 +371,7 @@ fn prepare_one_select_plan(
}
match Func::resolve_function(name.as_str(), args_count) {
Ok(Func::Agg(f)) => {
let agg_args = match (args.is_empty(), &f) {
(true, crate::function::AggFunc::Count0) => {
// COUNT() case
vec![ast::Expr::Literal(ast::Literal::Numeric(
"1".to_string(),
))
.into()]
}
(true, _) => crate::bail_parse_error!(
"Aggregate function {} requires arguments",
name.as_str()
),
(false, _) => args.clone(),
};
let agg = Aggregate {
func: f,
args: agg_args.iter().map(|arg| *arg.clone()).collect(),
original_expr: *expr.clone(),
distinctness,
};
let agg = Aggregate::new(f, args, expr, distinctness);
aggregate_expressions.push(agg);
plan.result_columns.push(ResultSetColumn {
alias: maybe_alias.as_ref().map(|alias| match alias {
@@ -446,15 +426,12 @@ fn prepare_one_select_plan(
contains_aggregates,
});
} else {
let agg = Aggregate {
func: AggFunc::External(f.func.clone().into()),
args: args
.iter()
.map(|arg| *arg.clone())
.collect(),
original_expr: *expr.clone(),
let agg = Aggregate::new(
AggFunc::External(f.func.clone().into()),
args,
expr,
distinctness,
};
);
aggregate_expressions.push(agg);
plan.result_columns.push(ResultSetColumn {
alias: maybe_alias.as_ref().map(|alias| {
@@ -488,14 +465,8 @@ fn prepare_one_select_plan(
}
match Func::resolve_function(name.as_str(), 0) {
Ok(Func::Agg(f)) => {
let agg = Aggregate {
func: f,
args: vec![ast::Expr::Literal(ast::Literal::Numeric(
"1".to_string(),
))],
original_expr: *expr.clone(),
distinctness: Distinctness::NonDistinct,
};
let agg =
Aggregate::new(f, &[], expr, Distinctness::NonDistinct);
aggregate_expressions.push(agg);
plan.result_columns.push(ResultSetColumn {
alias: maybe_alias.as_ref().map(|alias| match alias {

View File

@@ -14,7 +14,7 @@ use crate::translate::plan::IterationDirection;
use crate::vdbe::sorter::Sorter;
use crate::vdbe::Register;
use crate::vtab::VirtualTableCursor;
use crate::{turso_assert, Completion, Result, IO};
use crate::{turso_assert, Completion, CompletionError, Result, IO};
use std::fmt::{Debug, Display};
const MAX_REAL_SIZE: u8 = 15;
@@ -350,6 +350,13 @@ impl Value {
}
}
pub fn as_uint(&self) -> u64 {
match self {
Value::Integer(i) => (*i).cast_unsigned(),
_ => 0,
}
}
pub fn from_text(text: &str) -> Self {
Value::Text(Text::new(text))
}
@@ -2502,6 +2509,13 @@ impl IOCompletions {
IOCompletions::Many(completions) => completions.iter().for_each(|c| c.abort()),
}
}
pub fn get_error(&self) -> Option<CompletionError> {
match self {
IOCompletions::Single(c) => c.get_error(),
IOCompletions::Many(completions) => completions.iter().find_map(|c| c.get_error()),
}
}
}
#[derive(Debug)]

View File

@@ -703,58 +703,7 @@ pub fn columns_from_create_table_body(
use turso_parser::ast;
Ok(columns
.iter()
.map(
|ast::ColumnDefinition {
col_name: name,
col_type,
constraints,
}| {
Column {
name: Some(normalize_ident(name.as_str())),
ty: match col_type {
Some(ref data_type) => type_from_name(data_type.name.as_str()).0,
None => Type::Null,
},
default: constraints.iter().find_map(|c| match &c.constraint {
ast::ColumnConstraint::Default(val) => Some(val.clone()),
_ => None,
}),
notnull: constraints
.iter()
.any(|c| matches!(c.constraint, ast::ColumnConstraint::NotNull { .. })),
ty_str: col_type
.clone()
.map(|t| t.name.to_string())
.unwrap_or_default(),
primary_key: constraints
.iter()
.any(|c| matches!(c.constraint, ast::ColumnConstraint::PrimaryKey { .. })),
is_rowid_alias: false,
unique: constraints
.iter()
.any(|c| matches!(c.constraint, ast::ColumnConstraint::Unique(..))),
collation: constraints.iter().find_map(|c| match &c.constraint {
// TODO: see if this should be the correct behavior
// currently there cannot be any user defined collation sequences.
// But in the future, when a user defines a collation sequence, creates a table with it,
// then closes the db and opens it again. This may panic here if the collation seq is not registered
// before reading the columns
ast::ColumnConstraint::Collate { collation_name } => Some(
CollationSeq::new(collation_name.as_str())
.expect("collation should have been set correctly in create table"),
),
_ => None,
}),
hidden: col_type
.as_ref()
.map(|data_type| data_type.name.as_str().contains("HIDDEN"))
.unwrap_or(false),
}
},
)
.collect::<Vec<_>>())
Ok(columns.iter().map(Into::into).collect())
}
/// This function checks if a given expression is a constant value that can be pushed down to the database engine.
@@ -803,6 +752,10 @@ pub struct OpenOptions<'a> {
pub cache: CacheMode,
/// immutable=1|0 specifies that the database is stored on read-only media
pub immutable: bool,
// The encryption cipher
pub cipher: Option<String>,
// The encryption key in hex format
pub hexkey: Option<String>,
}
pub const MEMORY_PATH: &str = ":memory:";
@@ -954,6 +907,8 @@ fn parse_query_params(query: &str, opts: &mut OpenOptions) -> Result<()> {
"cache" => opts.cache = decoded_value.as_str().into(),
"immutable" => opts.immutable = decoded_value == "1",
"vfs" => opts.vfs = Some(decoded_value),
"cipher" => opts.cipher = Some(decoded_value),
"hexkey" => opts.hexkey = Some(decoded_value),
_ => {}
}
}

View File

@@ -4461,10 +4461,16 @@ pub fn op_function(
}
}
ScalarFunc::SqliteVersion => {
let version_integer =
return_if_io!(pager.with_header(|header| header.version_number)).get() as i64;
let version = execute_sqlite_version(version_integer);
state.registers[*dest] = Register::Value(Value::build_text(version));
if !program.connection.is_db_initialized() {
state.registers[*dest] =
Register::Value(Value::build_text(info::build::PKG_VERSION));
} else {
let version_integer =
return_if_io!(pager.with_header(|header| header.version_number)).get()
as i64;
let version = execute_sqlite_version(version_integer);
state.registers[*dest] = Register::Value(Value::build_text(version));
}
}
ScalarFunc::SqliteSourceId => {
let src_id = format!(
@@ -4852,10 +4858,10 @@ pub fn op_function(
match stmt {
ast::Stmt::CreateIndex {
tbl_name,
unique,
if_not_exists,
idx_name,
tbl_name,
columns,
where_clause,
} => {
@@ -4867,10 +4873,10 @@ pub fn op_function(
Some(
ast::Stmt::CreateIndex {
tbl_name: ast::Name::new(&rename_to),
unique,
if_not_exists,
idx_name,
tbl_name: ast::Name::new(&rename_to),
columns,
where_clause,
}
@@ -4879,9 +4885,9 @@ pub fn op_function(
)
}
ast::Stmt::CreateTable {
tbl_name,
temporary,
if_not_exists,
tbl_name,
body,
} => {
let table_name = normalize_ident(tbl_name.name.as_str());
@@ -4892,13 +4898,13 @@ pub fn op_function(
Some(
ast::Stmt::CreateTable {
temporary,
if_not_exists,
tbl_name: ast::QualifiedName {
db_name: None,
name: ast::Name::new(&rename_to),
alias: None,
},
temporary,
if_not_exists,
body,
}
.format()
@@ -4911,7 +4917,7 @@ pub fn op_function(
(new_name, new_tbl_name, new_sql)
}
AlterTableFunc::RenameColumn => {
AlterTableFunc::AlterColumn | AlterTableFunc::RenameColumn => {
let table = {
match &state.registers[*start_reg + 5].get_value() {
Value::Text(rename_to) => normalize_ident(rename_to.as_str()),
@@ -4926,13 +4932,17 @@ pub fn op_function(
}
};
let rename_to = {
let column_def = {
match &state.registers[*start_reg + 7].get_value() {
Value::Text(rename_to) => normalize_ident(rename_to.as_str()),
Value::Text(column_def) => column_def.as_str(),
_ => panic!("rename_to parameter should be TEXT"),
}
};
let column_def = Parser::new(column_def.as_bytes())
.parse_column_definition(true)
.unwrap();
let new_sql = 'sql: {
if table != tbl_name {
break 'sql None;
@@ -4949,11 +4959,11 @@ pub fn op_function(
match stmt {
ast::Stmt::CreateIndex {
tbl_name,
mut columns,
unique,
if_not_exists,
idx_name,
tbl_name,
mut columns,
where_clause,
} => {
if table != normalize_ident(tbl_name.as_str()) {
@@ -4965,7 +4975,7 @@ pub fn op_function(
ast::Expr::Id(ast::Name::Ident(id))
if normalize_ident(id) == rename_from =>
{
*id = rename_to.clone();
*id = column_def.col_name.as_str().to_owned();
}
_ => {}
}
@@ -4973,11 +4983,11 @@ pub fn op_function(
Some(
ast::Stmt::CreateIndex {
tbl_name,
columns,
unique,
if_not_exists,
idx_name,
tbl_name,
columns,
where_clause,
}
.format()
@@ -4985,10 +4995,10 @@ pub fn op_function(
)
}
ast::Stmt::CreateTable {
temporary,
if_not_exists,
tbl_name,
body,
temporary,
if_not_exists,
} => {
if table != normalize_ident(tbl_name.name.as_str()) {
break 'sql None;
@@ -5008,18 +5018,24 @@ pub fn op_function(
.find(|column| column.col_name == ast::Name::new(&rename_from))
.expect("column being renamed should be present");
column.col_name = ast::Name::new(&rename_to);
match alter_func {
AlterTableFunc::AlterColumn => *column = column_def,
AlterTableFunc::RenameColumn => {
column.col_name = column_def.col_name
}
_ => unreachable!(),
}
Some(
ast::Stmt::CreateTable {
temporary,
if_not_exists,
tbl_name,
body: ast::CreateTableBody::ColumnsAndConstraints {
columns,
constraints,
options,
},
temporary,
if_not_exists,
}
.format()
.unwrap(),
@@ -7303,7 +7319,7 @@ pub fn op_add_column(
Ok(InsnFunctionStepResult::Step)
}
pub fn op_rename_column(
pub fn op_alter_column(
program: &Program,
state: &mut ProgramState,
insn: &Insn,
@@ -7311,16 +7327,19 @@ pub fn op_rename_column(
mv_store: Option<&Arc<MvStore>>,
) -> Result<InsnFunctionStepResult> {
load_insn!(
RenameColumn {
AlterColumn {
table: table_name,
column_index,
name
definition,
rename,
},
insn
);
let conn = program.connection.clone();
let new_column = crate::schema::Column::from(definition);
conn.with_schema_mut(|schema| {
let table = schema
.tables
@@ -7347,13 +7366,17 @@ pub fn op_rename_column(
if index_column.name
== *column.name.as_ref().expect("btree column should be named")
{
index_column.name = name.to_owned();
index_column.name = definition.col_name.as_str().to_owned();
}
}
}
}
column.name = Some(name.to_owned());
if *rename {
column.name = new_column.name;
} else {
*column = new_column;
}
});
state.pc += 1;

View File

@@ -1672,14 +1672,14 @@ pub fn insn_to_str(
0,
format!("add_column({table}, {column:?})"),
),
Insn::RenameColumn { table, column_index, name } => (
"RenameColumn",
Insn::AlterColumn { table, column_index, definition: column, rename } => (
"AlterColumn",
0,
0,
0,
Value::build_text(""),
0,
format!("rename_column({table}, {column_index}, {name})"),
format!("alter_column({table}, {column_index}, {column:?}, {rename:?})"),
),
Insn::MaxPgcnt { db, dest, new_max } => (
"MaxPgcnt",

View File

@@ -1053,10 +1053,11 @@ pub enum Insn {
table: String,
column: Column,
},
RenameColumn {
AlterColumn {
table: String,
column_index: usize,
name: String,
definition: turso_parser::ast::ColumnDefinition,
rename: bool,
},
/// Try to set the maximum page count for database P1 to the value in P3.
/// Do not let the maximum page count fall below the current page count and
@@ -1209,7 +1210,7 @@ impl Insn {
Insn::RenameTable { .. } => execute::op_rename_table,
Insn::DropColumn { .. } => execute::op_drop_column,
Insn::AddColumn { .. } => execute::op_add_column,
Insn::RenameColumn { .. } => execute::op_rename_column,
Insn::AlterColumn { .. } => execute::op_alter_column,
Insn::MaxPgcnt { .. } => execute::op_max_pgcnt,
Insn::JournalMode { .. } => execute::op_journal_mode,
}

View File

@@ -460,6 +460,11 @@ impl Program {
if !io.finished() {
return Ok(StepResult::IO);
}
if let Some(err) = io.get_error() {
let err = err.into();
handle_program_error(&pager, &self.connection, &err)?;
return Err(err);
}
state.io_completions = None;
}
// invalidate row

View File

@@ -370,7 +370,7 @@ struct SortedChunk {
/// The chunk file.
file: Arc<dyn File>,
/// Offset of the start of chunk in file
start_offset: usize,
start_offset: u64,
/// The size of this chunk file in bytes.
chunk_size: usize,
/// The read buffer.
@@ -391,7 +391,7 @@ impl SortedChunk {
fn new(file: Arc<dyn File>, start_offset: usize, buffer_size: usize) -> Self {
Self {
file,
start_offset,
start_offset: start_offset as u64,
chunk_size: 0,
buffer: Rc::new(RefCell::new(vec![0; buffer_size])),
buffer_len: Rc::new(Cell::new(0)),
@@ -522,7 +522,7 @@ impl SortedChunk {
let c = Completion::new_read(read_buffer_ref, read_complete);
let c = self
.file
.pread(self.start_offset + self.total_bytes_read.get(), c)?;
.pread(self.start_offset + self.total_bytes_read.get() as u64, c)?;
Ok(c)
}

View File

@@ -38,6 +38,7 @@ Welcome to Turso database manual!
- [WAL manipulation](#wal-manipulation)
- [`libsql_wal_frame_count`](#libsql_wal_frame_count)
- [Encryption](#encryption)
- [CDC](#cdc-early-preview)
- [Appendix A: Turso Internals](#appendix-a-turso-internals)
- [Frontend](#frontend)
- [Parser](#parser)
@@ -510,6 +511,114 @@ PRAGMA cipher = 'aegis256'; -- or 'aes256gcm'
PRAGMA hexkey = '2d7a30108d3eb3e45c90a732041fe54778bdcf707c76749fab7da335d1b39c1d';
```
## CDC (Early Preview)
Turso supports [Change Data Capture](https://en.wikipedia.org/wiki/Change_data_capture), a powerful pattern for tracking and recording changes to your database in real time. Instead of periodically scanning tables to find what changed, CDC automatically logs every insert, update, and delete as it happens, on a per-connection basis.
### Enabling CDC
```sql
PRAGMA unstable_capture_data_changes_conn('<mode>[,custom_cdc_table]');
```
### Parameters
- `<mode>` can be:
- `off`: Turn off CDC for the connection
- `id`: Logs only the `rowid` (most compact)
- `before`: Captures row state before updates and deletes
- `after`: Captures row state after inserts and updates
- `full`: Captures both before and after states (recommended for a complete audit trail)
- `custom_cdc_table` is optional; it lets you specify a custom table to capture changes.
If no table is provided, Turso uses the default `turso_cdc` table.
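For example, a minimal session might enable full capture into a custom table and later turn it back off (the table name `my_changes` is illustrative):
```sql
-- Log full before/after images into a custom table for this connection.
PRAGMA unstable_capture_data_changes_conn('full,my_changes');

-- Later, disable capture for this connection.
PRAGMA unstable_capture_data_changes_conn('off');
```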
When **Change Data Capture (CDC)** is enabled for a connection, Turso automatically logs all modifications from that connection into a dedicated table (default: `turso_cdc`). This table records each change with details about the operation, the affected row or schema object, and its state **before** and **after** the modification.
> **Note:** Currently, the CDC table is a regular table stored explicitly on disk. If you use full CDC mode and update rows frequently, each update of size N bytes will be written three times to disk (once for the before state, once for the after state, and once for the actual value in the WAL). Frequent updates in full mode can therefore significantly increase disk I/O.
- **`change_id` (INTEGER)**
A monotonically increasing integer uniquely identifying each change record (guaranteed by turso-db).
- Always strictly increasing.
- Serves as the primary key.
- **`change_time` (INTEGER)**
Local timestamp (Unix epoch, seconds) when the change was recorded.
- Not guaranteed to be strictly increasing (can drift or repeat).
> turso-db guarantees nothing about the properties of the change_time sequence.
- **`change_type` (INTEGER)**
Indicates the type of operation (a SQL sketch for decoding these codes follows the notes below):
- `1` → INSERT
- `0` → UPDATE (also used for ALTER TABLE)
- `-1` → DELETE (also covers DROP TABLE, DROP INDEX)
- **`table_name` (TEXT)**
Name of the affected table.
- For schema changes (DDL), this is always `"sqlite_schema"`.
- **`id` (INTEGER)**
Rowid of the affected row in the source table.
- For DDL operations: rowid of the `sqlite_schema` entry.
- **Note:** `WITHOUT ROWID` tables are not supported by tursodb, and therefore not by CDC.
- **`before` (BLOB)**
Full state of the row/schema **before** an UPDATE or DELETE.
- NULL for INSERT.
- For DDL changes, may contain the definition of the object before modification.
- **`after` (BLOB)**
Full state of the row/schema **after** an INSERT or UPDATE.
- NULL for DELETE.
- For DDL changes, may contain the definition of the object after modification.
- **`updates` (BLOB)**
Granular details about the change.
- For UPDATE: shows specific column modifications.
> CDC records are visible even before a transaction commits.
> Operations that fail (e.g., constraint violations) are not recorded in CDC.
> Changes to the CDC table itself are also logged to the CDC table if CDC is enabled for that connection.
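For a quick, readable audit trail, the documented codes can be decoded directly in SQL. Here is a minimal sketch that relies only on the column values described above:
```sql
SELECT change_id,
       change_time,
       CASE change_type
           WHEN  1 THEN 'INSERT'
           WHEN  0 THEN 'UPDATE'   -- also ALTER TABLE
           WHEN -1 THEN 'DELETE'   -- also DROP TABLE / DROP INDEX
       END AS operation,
       table_name,
       id
FROM turso_cdc
ORDER BY change_id;
```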
```zsh
Example:
turso> PRAGMA unstable_capture_data_changes_conn('full');
turso> .tables
turso_cdc
turso> CREATE TABLE users (
id INTEGER PRIMARY KEY,
name TEXT
);
turso> INSERT INTO users VALUES (1, 'John'), (2, 'Jane');
UPDATE users SET name='John Doe' WHERE id=1;
DELETE FROM users WHERE id=2;
SELECT * FROM turso_cdc;
┌───────────┬─────────────┬─────────────┬───────────────┬────┬────────┬──────────────────────────────────────────────────────────────────────────────┬──────────┐
│ change_id │ change_time │ change_type │ table_name    │ id │ before │ after                                                                        │ updates  │
├───────────┼─────────────┼─────────────┼───────────────┼────┼────────┼──────────────────────────────────────────────────────────────────────────────┼──────────┤
│ 1         │ 1756713161  │ 1           │ sqlite_schema │ 2  │        │ table users users … CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)  │          │
├───────────┼─────────────┼─────────────┼───────────────┼────┼────────┼──────────────────────────────────────────────────────────────────────────────┼──────────┤
│ 2         │ 1756713176  │ 1           │ users         │ 1  │        │ John                                                                         │          │
├───────────┼─────────────┼─────────────┼───────────────┼────┼────────┼──────────────────────────────────────────────────────────────────────────────┼──────────┤
│ 3         │ 1756713176  │ 1           │ users         │ 2  │        │ Jane                                                                         │          │
├───────────┼─────────────┼─────────────┼───────────────┼────┼────────┼──────────────────────────────────────────────────────────────────────────────┼──────────┤
│ 4         │ 1756713176  │ 0           │ users         │ 1  │ John   │ John Doe                                                                     │ John Doe │
├───────────┼─────────────┼─────────────┼───────────────┼────┼────────┼──────────────────────────────────────────────────────────────────────────────┼──────────┤
│ 5         │ 1756713176  │ -1          │ users         │ 2  │ Jane   │                                                                              │          │
└───────────┴─────────────┴─────────────┴───────────────┴────┴────────┴──────────────────────────────────────────────────────────────────────────────┴──────────┘
turso>
```
If you modify your table schema (adding or dropping columns), the `table_columns_json_array()` function returns the current schema, not the historical one, which can lead to incorrect results when decoding older CDC records. To decode older records reliably, manually track schema versions by storing the output of `table_columns_json_array()` before making schema changes.
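One way to keep those snapshots, as a minimal sketch: the bookkeeping table `cdc_schema_versions` below is hypothetical, and `table_columns_json_array()` is assumed to take the table name as an argument.
```sql
-- Hypothetical bookkeeping table; snapshot the schema before every DDL change.
CREATE TABLE IF NOT EXISTS cdc_schema_versions (
    version_id INTEGER PRIMARY KEY,
    table_name TEXT NOT NULL,
    columns    TEXT NOT NULL,    -- output of table_columns_json_array()
    changed_at INTEGER NOT NULL  -- Unix epoch seconds
);

-- Before running ALTER TABLE users ..., record the current shape of 'users'.
INSERT INTO cdc_schema_versions (table_name, columns, changed_at)
VALUES ('users', table_columns_json_array('users'), strftime('%s', 'now'));
```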
## Appendix A: Turso Internals
Turso's architecture resembles SQLite's but differs primarily in its

View File

@@ -982,6 +982,8 @@ pub enum AlterTableBody {
RenameTo(Name),
/// `ADD COLUMN`
AddColumn(ColumnDefinition), // TODO distinction between ADD and ADD COLUMN
/// `ALTER COLUMN`
AlterColumn { old: Name, new: ColumnDefinition },
/// `RENAME COLUMN`
RenameColumn {
/// old name

View File

@@ -1409,6 +1409,13 @@ impl ToTokens for AlterTableBody {
s.append(TK_COLUMNKW, None)?;
def.to_tokens_with_context(s, context)
}
Self::AlterColumn { old, new } => {
s.append(TK_ALTER, None)?;
s.append(TK_COLUMNKW, None)?;
old.to_tokens_with_context(s, context)?;
s.append(TK_TO, None)?;
new.to_tokens_with_context(s, context)
}
Self::RenameColumn { old, new } => {
s.append(TK_RENAME, None)?;
s.append(TK_COLUMNKW, None)?;

View File

@@ -169,7 +169,7 @@ pub fn is_identifier_continue(b: u8) -> bool {
|| b > b'\x7F'
}
#[derive(Clone, PartialEq, Eq)] // do not derive Copy for Token, just use .clone() when needed
#[derive(Clone, PartialEq, Eq, Debug)] // do not derive Copy for Token, just use .clone() when needed
pub struct Token<'a> {
pub value: &'a [u8],
pub token_type: Option<TokenType>, // None means Token is whitespaces or comments

View File

@@ -419,7 +419,8 @@ impl<'a> Parser<'a> {
}
TK_COLUMNKW => {
let prev_tt = self.current_token.token_type.unwrap_or(TK_EOF);
let can_be_columnkw = matches!(prev_tt, TK_ADD | TK_RENAME | TK_DROP);
let can_be_columnkw =
matches!(prev_tt, TK_ADD | TK_RENAME | TK_DROP | TK_ALTER);
if !can_be_columnkw {
tok.token_type = Some(TK_ID);
@@ -1536,7 +1537,21 @@ impl<'a> Parser<'a> {
Name::Ident(s) => Literal::String(s),
})))
} else {
Ok(Box::new(Expr::Id(name)))
match name {
Name::Ident(s) => {
let s_bytes = s.as_bytes();
match_ignore_ascii_case!(match s_bytes {
b"true" => {
Ok(Box::new(Expr::Literal(Literal::Numeric("1".into()))))
}
b"false" => {
Ok(Box::new(Expr::Literal(Literal::Numeric("0".into()))))
}
_ => return Ok(Box::new(Expr::Id(Name::Ident(s)))),
})
}
_ => Ok(Box::new(Expr::Id(name))),
}
}
}
}
@@ -3400,7 +3415,7 @@ impl<'a> Parser<'a> {
Ok(result)
}
fn parse_column_definition(&mut self, in_alter: bool) -> Result<ColumnDefinition> {
pub fn parse_column_definition(&mut self, in_alter: bool) -> Result<ColumnDefinition> {
let col_name = self.parse_nm()?;
if !in_alter && col_name.as_str().eq_ignore_ascii_case("rowid") {
return Err(Error::Custom("cannot use reserved word: ROWID".to_owned()));
@@ -3419,7 +3434,7 @@ impl<'a> Parser<'a> {
eat_assert!(self, TK_ALTER);
eat_expect!(self, TK_TABLE);
let tbl_name = self.parse_fullname(false)?;
let tok = eat_expect!(self, TK_ADD, TK_DROP, TK_RENAME);
let tok = eat_expect!(self, TK_ADD, TK_DROP, TK_RENAME, TK_ALTER);
match tok.token_type.unwrap() {
TK_ADD => {
@@ -3470,6 +3485,19 @@ impl<'a> Parser<'a> {
}))
}
}
TK_ALTER => {
eat_expect!(self, TK_COLUMNKW);
let col_name = self.parse_nm()?;
eat_expect!(self, TK_TO);
let new = self.parse_column_definition(true)?;
Ok(Stmt::AlterTable(AlterTable {
name: tbl_name,
body: AlterTableBody::AlterColumn { old: col_name, new },
}))
}
_ => unreachable!(),
}
}
@@ -9837,6 +9865,24 @@ mod tests {
}),
}))],
),
(
b"ALTER TABLE foo ALTER COLUMN bar TO baz INTEGER".as_slice(),
vec![Cmd::Stmt(Stmt::AlterTable (AlterTable {
name: QualifiedName { db_name: None, name: Name::Ident("foo".to_owned()), alias: None },
body: AlterTableBody::AlterColumn {
old: Name::Ident("bar".to_owned()),
new: ColumnDefinition {
col_name: Name::Ident("baz".to_owned()),
col_type: Some(Type {
name: "INTEGER".to_owned(),
size: None,
}),
constraints: vec![],
},
},
}))],
),
// parse create index
(
b"CREATE INDEX idx_foo ON foo (bar)".as_slice(),

View File

@@ -3,5 +3,6 @@
cargo publish -p turso_macros
cargo publish -p turso_ext
cargo publish -p turso_sqlite3_parser
cargo publish -p turso_parser
cargo publish -p turso_core
cargo publish -p turso

View File

@@ -1,4 +1,4 @@
FROM lukemathwalker/cargo-chef:latest-rust-1.87.0 AS chef
FROM lukemathwalker/cargo-chef:latest-rust-1.88.0 AS chef
RUN apt update \
&& apt install -y git libssl-dev pkg-config\
&& apt clean \

1
simulator/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
configs/custom

View File

@@ -4,7 +4,7 @@
name = "limbo_sim"
version.workspace = true
authors.workspace = true
edition.workspace = true
edition = "2024"
license.workspace = true
repository.workspace = true
description = "The Limbo deterministic simulator"
@@ -38,3 +38,7 @@ hex = "0.4.3"
itertools = "0.14.0"
sql_generation = { workspace = true }
turso_parser = { workspace = true }
schemars = { workspace = true }
garde = { workspace = true, features = ["derive", "serde"] }
json5 = { version = "0.4.1" }
strum = { workspace = true }

View File

@@ -106,6 +106,18 @@ it should generate the necessary queries and assertions for the property.
You can use the `--differential` flag to run the simulator in differential testing mode. This mode will run the same interaction plan on both Limbo and SQLite, and compare the results. It will also check for any panics or errors in either database.
## Simulator Profiles
A Simulator Profile allows you to influence query generation and I/O fault injection. You can run predefined profiles or create your own custom profile in a separate JSON file. Select the profile you want by passing the `--profile` flag to the CLI; it accepts a predefined profile name or a file path.
For development purposes, you can run `make sim-schema` to generate a JSON Schema of the `Profile` struct. You can then create profiles to test locally in the gitignored `configs/custom` folder, with editor integration enabled by adding a `$schema` tag that references the generated schema:
```json
{
"$schema": "./profile-schema.json",
...
}
```
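As a sketch of what a custom profile might look like: the weight field names below are inferred from the `QueryProfile` weights used elsewhere in this commit and may not match the generated schema exactly, so regenerate it with `make sim-schema` to confirm.
```json
{
  "$schema": "./profile-schema.json",
  "query": {
    "select_weight": 10,
    "insert_weight": 5,
    "update_weight": 3,
    "delete_weight": 1
  }
}
```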
## Resources
- [(reading) TigerBeetle Deterministic Simulation Testing](https://docs.tigerbeetle.com/about/vopr/)

View File

@@ -25,9 +25,17 @@ impl GenerationContext for SimulatorEnv {
&self.tables.tables
}
fn opts(&self) -> sql_generation::generation::Opts {
sql_generation::generation::Opts {
indexes: self.opts.experimental_indexes,
}
fn opts(&self) -> &sql_generation::generation::Opts {
&self.profile.query.gen_opts
}
}
impl GenerationContext for &mut SimulatorEnv {
fn tables(&self) -> &Vec<sql_generation::model::table::Table> {
&self.tables.tables
}
fn opts(&self) -> &sql_generation::generation::Opts {
&self.profile.query.gen_opts
}
}

View File

@@ -9,25 +9,25 @@ use std::{
use serde::{Deserialize, Serialize};
use sql_generation::{
generation::{frequency, query::SelectFree, Arbitrary, ArbitraryFrom},
generation::{Arbitrary, ArbitraryFrom, GenerationContext, frequency, query::SelectFree},
model::{
query::{update::Update, Create, CreateIndex, Delete, Drop, Insert, Select},
query::{Create, CreateIndex, Delete, Drop, Insert, Select, update::Update},
table::SimValue,
},
};
use turso_core::{Connection, Result, StepResult};
use crate::{
SimulatorEnv,
generation::Shadow,
model::Query,
runner::{
env::{SimConnection, SimulationType, SimulatorTables},
io::SimulatorIO,
},
SimulatorEnv,
};
use super::property::{remaining, Property};
use super::property::{Property, remaining};
pub(crate) type ResultSet = Result<Vec<Vec<SimValue>>>;
@@ -254,16 +254,27 @@ impl Display for InteractionPlan {
#[derive(Debug, Clone, Copy)]
pub(crate) struct InteractionStats {
pub(crate) read_count: usize,
pub(crate) write_count: usize,
pub(crate) delete_count: usize,
pub(crate) update_count: usize,
pub(crate) create_count: usize,
pub(crate) create_index_count: usize,
pub(crate) drop_count: usize,
pub(crate) begin_count: usize,
pub(crate) commit_count: usize,
pub(crate) rollback_count: usize,
pub(crate) select_count: u32,
pub(crate) insert_count: u32,
pub(crate) delete_count: u32,
pub(crate) update_count: u32,
pub(crate) create_count: u32,
pub(crate) create_index_count: u32,
pub(crate) drop_count: u32,
pub(crate) begin_count: u32,
pub(crate) commit_count: u32,
pub(crate) rollback_count: u32,
}
impl InteractionStats {
pub fn total_writes(&self) -> u32 {
self.insert_count
+ self.delete_count
+ self.update_count
+ self.create_count
+ self.create_index_count
+ self.drop_count
}
}
impl Display for InteractionStats {
@@ -271,8 +282,8 @@ impl Display for InteractionStats {
write!(
f,
"Read: {}, Write: {}, Delete: {}, Update: {}, Create: {}, CreateIndex: {}, Drop: {}, Begin: {}, Commit: {}, Rollback: {}",
self.read_count,
self.write_count,
self.select_count,
self.insert_count,
self.delete_count,
self.update_count,
self.create_count,
@@ -351,8 +362,8 @@ impl InteractionPlan {
pub(crate) fn stats(&self) -> InteractionStats {
let mut stats = InteractionStats {
read_count: 0,
write_count: 0,
select_count: 0,
insert_count: 0,
delete_count: 0,
update_count: 0,
create_count: 0,
@@ -365,8 +376,8 @@ impl InteractionPlan {
fn query_stat(q: &Query, stats: &mut InteractionStats) {
match q {
Query::Select(_) => stats.read_count += 1,
Query::Insert(_) => stats.write_count += 1,
Query::Select(_) => stats.select_count += 1,
Query::Insert(_) => stats.insert_count += 1,
Query::Delete(_) => stats.delete_count += 1,
Query::Create(_) => stats.create_count += 1,
Query::Drop(_) => stats.drop_count += 1,
@@ -395,16 +406,14 @@ impl InteractionPlan {
stats
}
}
impl ArbitraryFrom<&mut SimulatorEnv> for InteractionPlan {
fn arbitrary_from<R: rand::Rng>(rng: &mut R, env: &mut SimulatorEnv) -> Self {
pub fn generate_plan<R: rand::Rng>(rng: &mut R, env: &mut SimulatorEnv) -> Self {
let mut plan = InteractionPlan::new();
let num_interactions = env.opts.max_interactions;
let num_interactions = env.opts.max_interactions as usize;
// First create at least one table
let create_query = Create::arbitrary(rng);
let create_query = Create::arbitrary(rng, env);
env.tables.push(create_query.table.clone());
plan.plan
@@ -416,7 +425,7 @@ impl ArbitraryFrom<&mut SimulatorEnv> for InteractionPlan {
plan.plan.len(),
num_interactions
);
let interactions = Interactions::arbitrary_from(rng, (env, plan.stats()));
let interactions = Interactions::arbitrary_from(rng, env, (env, plan.stats()));
interactions.shadow(&mut env.tables);
plan.plan.push(interactions);
}
@@ -756,42 +765,42 @@ fn reopen_database(env: &mut SimulatorEnv) {
}
fn random_create<R: rand::Rng>(rng: &mut R, env: &SimulatorEnv) -> Interactions {
let mut create = Create::arbitrary(rng);
let mut create = Create::arbitrary(rng, env);
while env.tables.iter().any(|t| t.name == create.table.name) {
create = Create::arbitrary(rng);
create = Create::arbitrary(rng, env);
}
Interactions::Query(Query::Create(create))
}
fn random_read<R: rand::Rng>(rng: &mut R, env: &SimulatorEnv) -> Interactions {
Interactions::Query(Query::Select(Select::arbitrary_from(rng, env)))
Interactions::Query(Query::Select(Select::arbitrary(rng, env)))
}
fn random_expr<R: rand::Rng>(rng: &mut R, env: &SimulatorEnv) -> Interactions {
Interactions::Query(Query::Select(SelectFree::arbitrary_from(rng, env).0))
Interactions::Query(Query::Select(SelectFree::arbitrary(rng, env).0))
}
fn random_write<R: rand::Rng>(rng: &mut R, env: &SimulatorEnv) -> Interactions {
Interactions::Query(Query::Insert(Insert::arbitrary_from(rng, env)))
Interactions::Query(Query::Insert(Insert::arbitrary(rng, env)))
}
fn random_delete<R: rand::Rng>(rng: &mut R, env: &SimulatorEnv) -> Interactions {
Interactions::Query(Query::Delete(Delete::arbitrary_from(rng, env)))
Interactions::Query(Query::Delete(Delete::arbitrary(rng, env)))
}
fn random_update<R: rand::Rng>(rng: &mut R, env: &SimulatorEnv) -> Interactions {
Interactions::Query(Query::Update(Update::arbitrary_from(rng, env)))
Interactions::Query(Query::Update(Update::arbitrary(rng, env)))
}
fn random_drop<R: rand::Rng>(rng: &mut R, env: &SimulatorEnv) -> Interactions {
Interactions::Query(Query::Drop(Drop::arbitrary_from(rng, env)))
Interactions::Query(Query::Drop(Drop::arbitrary(rng, env)))
}
fn random_create_index<R: rand::Rng>(rng: &mut R, env: &SimulatorEnv) -> Option<Interactions> {
if env.tables.is_empty() {
return None;
}
let mut create_index = CreateIndex::arbitrary_from(rng, env);
let mut create_index = CreateIndex::arbitrary(rng, env);
while env
.tables
.iter()
@@ -801,7 +810,7 @@ fn random_create_index<R: rand::Rng>(rng: &mut R, env: &SimulatorEnv) -> Option<
.iter()
.any(|i| i == &create_index.index_name)
{
create_index = CreateIndex::arbitrary_from(rng, env);
create_index = CreateIndex::arbitrary(rng, env);
}
Some(Interactions::Query(Query::CreateIndex(create_index)))
@@ -818,29 +827,30 @@ fn random_fault<R: rand::Rng>(rng: &mut R, env: &SimulatorEnv) -> Interactions {
}
impl ArbitraryFrom<(&SimulatorEnv, InteractionStats)> for Interactions {
fn arbitrary_from<R: rand::Rng>(
fn arbitrary_from<R: rand::Rng, C: GenerationContext>(
rng: &mut R,
_context: &C,
(env, stats): (&SimulatorEnv, InteractionStats),
) -> Self {
let remaining_ = remaining(env, &stats);
let remaining_ = remaining(env.opts.max_interactions, &env.profile.query, &stats);
frequency(
vec![
(
f64::min(remaining_.read, remaining_.write) + remaining_.create,
u32::min(remaining_.select, remaining_.insert) + remaining_.create,
Box::new(|rng: &mut R| {
Interactions::Property(Property::arbitrary_from(rng, (env, &stats)))
Interactions::Property(Property::arbitrary_from(rng, env, (env, &stats)))
}),
),
(
remaining_.read,
remaining_.select,
Box::new(|rng: &mut R| random_read(rng, env)),
),
(
remaining_.read / 3.0,
remaining_.select / 3,
Box::new(|rng: &mut R| random_expr(rng, env)),
),
(
remaining_.write,
remaining_.insert,
Box::new(|rng: &mut R| random_write(rng, env)),
),
(
@@ -868,15 +878,15 @@ impl ArbitraryFrom<(&SimulatorEnv, InteractionStats)> for Interactions {
),
(
// remaining_.drop,
0.0,
0,
Box::new(|rng: &mut R| random_drop(rng, env)),
),
(
remaining_
.read
.min(remaining_.write)
.select
.min(remaining_.insert)
.min(remaining_.create)
.max(1.0),
.max(1),
Box::new(|rng: &mut R| random_fault(rng, env)),
),
],

View File

@@ -1,21 +1,23 @@
use serde::{Deserialize, Serialize};
use sql_generation::{
generation::{frequency, pick, pick_index, ArbitraryFrom},
generation::{Arbitrary, ArbitraryFrom, GenerationContext, frequency, pick, pick_index},
model::{
query::{
Create, Delete, Drop, Insert, Select,
predicate::Predicate,
select::{CompoundOperator, CompoundSelect, ResultColumn, SelectBody, SelectInner},
transaction::{Begin, Commit, Rollback},
update::Update,
Create, Delete, Drop, Insert, Select,
},
table::SimValue,
},
};
use turso_core::{types, LimboError};
use turso_core::{LimboError, types};
use turso_parser::ast::{self, Distinctness};
use crate::{generation::Shadow as _, model::Query, runner::env::SimulatorEnv};
use crate::{
generation::Shadow as _, model::Query, profiles::query::QueryProfile, runner::env::SimulatorEnv,
};
use super::plan::{Assertion, Interaction, InteractionStats, ResultSet};
@@ -301,7 +303,10 @@ impl Property {
for row in rows {
for (i, (col, val)) in update.set_values.iter().enumerate() {
if &row[i] != val {
return Ok(Err(format!("updated row {} has incorrect value for column {col}: expected {val}, got {}", i, row[i])));
return Ok(Err(format!(
"updated row {} has incorrect value for column {col}: expected {val}, got {}",
i, row[i]
)));
}
}
}
@@ -380,7 +385,10 @@ impl Property {
if found {
Ok(Ok(()))
} else {
Ok(Err(format!("row [{:?}] not found in table", row.iter().map(|v| v.to_string()).collect::<Vec<String>>())))
Ok(Err(format!(
"row [{:?}] not found in table",
row.iter().map(|v| v.to_string()).collect::<Vec<String>>()
)))
}
}
Err(err) => Err(LimboError::InternalError(err.to_string())),
@@ -854,15 +862,22 @@ impl Property {
match (select_result_set, select_tlp_result_set) {
(Ok(select_rows), Ok(select_tlp_rows)) => {
if select_rows.len() != select_tlp_rows.len() {
return Ok(Err(format!("row count mismatch: select returned {} rows, select_tlp returned {} rows", select_rows.len(), select_tlp_rows.len())));
return Ok(Err(format!(
"row count mismatch: select returned {} rows, select_tlp returned {} rows",
select_rows.len(),
select_tlp_rows.len()
)));
}
// Check if any row in select_rows is not in select_tlp_rows
for row in select_rows.iter() {
if !select_tlp_rows.iter().any(|r| r == row) {
tracing::debug!(
"select and select_tlp returned different rows, ({}) is in select but not in select_tlp",
row.iter().map(|v| v.to_string()).collect::<Vec<String>>().join(", ")
);
"select and select_tlp returned different rows, ({}) is in select but not in select_tlp",
row.iter()
.map(|v| v.to_string())
.collect::<Vec<String>>()
.join(", ")
);
return Ok(Err(format!(
"row mismatch: row [{}] exists in select results but not in select_tlp results",
print_row(row)
@@ -873,9 +888,12 @@ impl Property {
for row in select_tlp_rows.iter() {
if !select_rows.iter().any(|r| r == row) {
tracing::debug!(
"select and select_tlp returned different rows, ({}) is in select_tlp but not in select",
row.iter().map(|v| v.to_string()).collect::<Vec<String>>().join(", ")
);
"select and select_tlp returned different rows, ({}) is in select_tlp but not in select",
row.iter()
.map(|v| v.to_string())
.collect::<Vec<String>>()
.join(", ")
);
return Ok(Err(format!(
"row mismatch: row [{}] exists in select_tlp but not in select",
@@ -935,7 +953,9 @@ impl Property {
if union_count == count1 + count2 {
Ok(Ok(()))
} else {
Ok(Err(format!("UNION ALL should preserve cardinality but it didn't: {count1} + {count2} != {union_count}")))
Ok(Err(format!(
"UNION ALL should preserve cardinality but it didn't: {count1} + {count2} != {union_count}"
)))
}
}
(Err(e), _, _) | (_, Err(e), _) | (_, _, Err(e)) => {
@@ -952,7 +972,7 @@ impl Property {
}
fn assert_all_table_values(tables: &[String]) -> impl Iterator<Item = Interaction> + use<'_> {
let checks = tables.iter().flat_map(|table| {
tables.iter().flat_map(|table| {
let select = Interaction::Query(Query::Select(Select::simple(
table.clone(),
Predicate::true_(),
@@ -1006,50 +1026,64 @@ fn assert_all_table_values(tables: &[String]) -> impl Iterator<Item = Interactio
}),
});
[select, assertion].into_iter()
});
checks
})
}
#[derive(Debug)]
pub(crate) struct Remaining {
pub(crate) read: f64,
pub(crate) write: f64,
pub(crate) create: f64,
pub(crate) create_index: f64,
pub(crate) delete: f64,
pub(crate) update: f64,
pub(crate) drop: f64,
pub(crate) select: u32,
pub(crate) insert: u32,
pub(crate) create: u32,
pub(crate) create_index: u32,
pub(crate) delete: u32,
pub(crate) update: u32,
pub(crate) drop: u32,
}
pub(crate) fn remaining(env: &SimulatorEnv, stats: &InteractionStats) -> Remaining {
let remaining_read = ((env.opts.max_interactions as f64 * env.opts.read_percent / 100.0)
- (stats.read_count as f64))
.max(0.0);
let remaining_write = ((env.opts.max_interactions as f64 * env.opts.write_percent / 100.0)
- (stats.write_count as f64))
.max(0.0);
let remaining_create = ((env.opts.max_interactions as f64 * env.opts.create_percent / 100.0)
- (stats.create_count as f64))
.max(0.0);
pub(crate) fn remaining(
max_interactions: u32,
opts: &QueryProfile,
stats: &InteractionStats,
) -> Remaining {
let total_weight = opts.select_weight
+ opts.create_table_weight
+ opts.create_index_weight
+ opts.insert_weight
+ opts.update_weight
+ opts.delete_weight
+ opts.drop_table_weight;
let remaining_create_index =
((env.opts.max_interactions as f64 * env.opts.create_index_percent / 100.0)
- (stats.create_index_count as f64))
.max(0.0);
let total_select = (max_interactions * opts.select_weight) / total_weight;
let total_insert = (max_interactions * opts.insert_weight) / total_weight;
let total_create = (max_interactions * opts.create_table_weight) / total_weight;
let total_create_index = (max_interactions * opts.create_index_weight) / total_weight;
let total_delete = (max_interactions * opts.delete_weight) / total_weight;
let total_update = (max_interactions * opts.update_weight) / total_weight;
let total_drop = (max_interactions * opts.drop_table_weight) / total_weight;
let remaining_delete = ((env.opts.max_interactions as f64 * env.opts.delete_percent / 100.0)
- (stats.delete_count as f64))
.max(0.0);
let remaining_update = ((env.opts.max_interactions as f64 * env.opts.update_percent / 100.0)
- (stats.update_count as f64))
.max(0.0);
let remaining_drop = ((env.opts.max_interactions as f64 * env.opts.drop_percent / 100.0)
- (stats.drop_count as f64))
.max(0.0);
let remaining_select = total_select
.checked_sub(stats.select_count)
.unwrap_or_default();
let remaining_insert = total_insert
.checked_sub(stats.insert_count)
.unwrap_or_default();
let remaining_create = total_create
.checked_sub(stats.create_count)
.unwrap_or_default();
let remaining_create_index = total_create_index
.checked_sub(stats.create_index_count)
.unwrap_or_default();
let remaining_delete = total_delete
.checked_sub(stats.delete_count)
.unwrap_or_default();
let remaining_update = total_update
.checked_sub(stats.update_count)
.unwrap_or_default();
let remaining_drop = total_drop.checked_sub(stats.drop_count).unwrap_or_default();
Remaining {
read: remaining_read,
write: remaining_write,
select: remaining_select,
insert: remaining_insert,
create: remaining_create,
create_index: remaining_create_index,
delete: remaining_delete,
@@ -1067,7 +1101,7 @@ fn property_insert_values_select<R: rand::Rng>(
let table = pick(&env.tables, rng);
// Generate rows to insert
let rows = (0..rng.random_range(1..=5))
.map(|_| Vec::<SimValue>::arbitrary_from(rng, table))
.map(|_| Vec::<SimValue>::arbitrary_from(rng, env, table))
.collect::<Vec<_>>();
// Pick a random row to select
@@ -1101,7 +1135,7 @@ fn property_insert_values_select<R: rand::Rng>(
}));
}
for _ in 0..rng.random_range(0..3) {
let query = Query::arbitrary_from(rng, (env, remaining));
let query = Query::arbitrary_from(rng, env, remaining);
match &query {
Query::Delete(Delete {
table: t,
@@ -1144,7 +1178,7 @@ fn property_insert_values_select<R: rand::Rng>(
// Select the row
let select_query = Select::simple(
table.name.clone(),
Predicate::arbitrary_from(rng, (table, &row)),
Predicate::arbitrary_from(rng, env, (table, &row)),
);
Property::InsertValuesSelect {
@@ -1158,7 +1192,7 @@ fn property_insert_values_select<R: rand::Rng>(
fn property_read_your_updates_back<R: rand::Rng>(rng: &mut R, env: &SimulatorEnv) -> Property {
// e.g. UPDATE t SET a=1, b=2 WHERE c=1;
let update = Update::arbitrary_from(rng, env);
let update = Update::arbitrary(rng, env);
// e.g. SELECT a, b FROM t WHERE c=1;
let select = Select::single(
update.table().to_string(),
@@ -1190,7 +1224,7 @@ fn property_select_limit<R: rand::Rng>(rng: &mut R, env: &SimulatorEnv) -> Prope
let select = Select::single(
table.name.clone(),
vec![ResultColumn::Star],
Predicate::arbitrary_from(rng, table),
Predicate::arbitrary_from(rng, env, table),
Some(rng.random_range(1..=5)),
Distinctness::All,
);
@@ -1215,7 +1249,7 @@ fn property_double_create_failure<R: rand::Rng>(
// - [x] There will be no errors in the middle interactions. (best effort)
// - [ ] Table `t` will not be renamed or dropped. (todo: add this constraint once ALTER or DROP is implemented)
for _ in 0..rng.random_range(0..3) {
let query = Query::arbitrary_from(rng, (env, remaining));
let query = Query::arbitrary_from(rng, env, remaining);
if let Query::Create(Create { table: t }) = &query {
// There will be no errors in the middle interactions.
// - Creating the same table is an error
@@ -1240,7 +1274,7 @@ fn property_delete_select<R: rand::Rng>(
// Get a random table
let table = pick(&env.tables, rng);
// Generate a random predicate
let predicate = Predicate::arbitrary_from(rng, table);
let predicate = Predicate::arbitrary_from(rng, env, table);
// Create random queries respecting the constraints
let mut queries = Vec::new();
@@ -1248,7 +1282,7 @@ fn property_delete_select<R: rand::Rng>(
// - [x] A row that holds for the predicate will not be inserted.
// - [ ] The table `t` will not be renamed, dropped, or altered. (todo: add this constraint once ALTER or DROP is implemented)
for _ in 0..rng.random_range(0..3) {
let query = Query::arbitrary_from(rng, (env, remaining));
let query = Query::arbitrary_from(rng, env, remaining);
match &query {
Query::Insert(Insert::Values { table: t, values }) => {
// A row that holds for the predicate will not be inserted.
@@ -1303,7 +1337,7 @@ fn property_drop_select<R: rand::Rng>(
// - [x] There will be no errors in the middle interactions. (this constraint is impossible to check, so this is just best effort)
// - [-] The table `t` will not be created, no table will be renamed to `t`. (todo: update this constraint once ALTER is implemented)
for _ in 0..rng.random_range(0..3) {
let query = Query::arbitrary_from(rng, (env, remaining));
let query = Query::arbitrary_from(rng, env, remaining);
if let Query::Create(Create { table: t }) = &query {
// - The table `t` will not be created
if t.name == table.name {
@@ -1313,7 +1347,10 @@ fn property_drop_select<R: rand::Rng>(
queries.push(query);
}
let select = Select::simple(table.name.clone(), Predicate::arbitrary_from(rng, table));
let select = Select::simple(
table.name.clone(),
Predicate::arbitrary_from(rng, env, table),
);
Property::DropSelect {
table: table.name.clone(),
@@ -1326,7 +1363,7 @@ fn property_select_select_optimizer<R: rand::Rng>(rng: &mut R, env: &SimulatorEn
// Get a random table
let table = pick(&env.tables, rng);
// Generate a random predicate
let predicate = Predicate::arbitrary_from(rng, table);
let predicate = Predicate::arbitrary_from(rng, env, table);
// Transform into a Binary predicate to force values to be casted to a bool
let expr = ast::Expr::Binary(
Box::new(predicate.0),
@@ -1344,8 +1381,8 @@ fn property_where_true_false_null<R: rand::Rng>(rng: &mut R, env: &SimulatorEnv)
// Get a random table
let table = pick(&env.tables, rng);
// Generate a random predicate
let p1 = Predicate::arbitrary_from(rng, table);
let p2 = Predicate::arbitrary_from(rng, table);
let p1 = Predicate::arbitrary_from(rng, env, table);
let p2 = Predicate::arbitrary_from(rng, env, table);
// Create the select query
let select = Select::simple(table.name.clone(), p1);
@@ -1363,8 +1400,8 @@ fn property_union_all_preserves_cardinality<R: rand::Rng>(
// Get a random table
let table = pick(&env.tables, rng);
// Generate a random predicate
let p1 = Predicate::arbitrary_from(rng, table);
let p2 = Predicate::arbitrary_from(rng, table);
let p1 = Predicate::arbitrary_from(rng, env, table);
let p2 = Predicate::arbitrary_from(rng, env, table);
// Create the select query
let select = Select::single(
@@ -1387,7 +1424,7 @@ fn property_fsync_no_wait<R: rand::Rng>(
remaining: &Remaining,
) -> Property {
Property::FsyncNoWait {
query: Query::arbitrary_from(rng, (env, remaining)),
query: Query::arbitrary_from(rng, env, remaining),
tables: env.tables.iter().map(|t| t.name.clone()).collect(),
}
}
@@ -1398,108 +1435,111 @@ fn property_faulty_query<R: rand::Rng>(
remaining: &Remaining,
) -> Property {
Property::FaultyQuery {
query: Query::arbitrary_from(rng, (env, remaining)),
query: Query::arbitrary_from(rng, env, remaining),
tables: env.tables.iter().map(|t| t.name.clone()).collect(),
}
}
impl ArbitraryFrom<(&SimulatorEnv, &InteractionStats)> for Property {
fn arbitrary_from<R: rand::Rng>(
fn arbitrary_from<R: rand::Rng, C: GenerationContext>(
rng: &mut R,
context: &C,
(env, stats): (&SimulatorEnv, &InteractionStats),
) -> Self {
let remaining_ = remaining(env, stats);
let opts = context.opts();
let remaining_ = remaining(env.opts.max_interactions, &env.profile.query, stats);
frequency(
vec![
(
if !env.opts.disable_insert_values_select {
f64::min(remaining_.read, remaining_.write)
u32::min(remaining_.select, remaining_.insert)
} else {
0.0
0
},
Box::new(|rng: &mut R| property_insert_values_select(rng, env, &remaining_)),
),
(
remaining_.read,
remaining_.select,
Box::new(|rng: &mut R| property_table_has_expected_content(rng, env)),
),
(
f64::min(remaining_.read, remaining_.write),
u32::min(remaining_.select, remaining_.insert),
Box::new(|rng: &mut R| property_read_your_updates_back(rng, env)),
),
(
if !env.opts.disable_double_create_failure {
remaining_.create / 2.0
remaining_.create / 2
} else {
0.0
0
},
Box::new(|rng: &mut R| property_double_create_failure(rng, env, &remaining_)),
),
(
if !env.opts.disable_select_limit {
remaining_.read
remaining_.select
} else {
0.0
0
},
Box::new(|rng: &mut R| property_select_limit(rng, env)),
),
(
if !env.opts.disable_delete_select {
f64::min(remaining_.read, remaining_.write).min(remaining_.delete)
u32::min(remaining_.select, remaining_.insert).min(remaining_.delete)
} else {
0.0
0
},
Box::new(|rng: &mut R| property_delete_select(rng, env, &remaining_)),
),
(
if !env.opts.disable_drop_select {
// remaining_.drop
0.0
0
} else {
0.0
0
},
Box::new(|rng: &mut R| property_drop_select(rng, env, &remaining_)),
),
(
if !env.opts.disable_select_optimizer {
remaining_.read / 2.0
remaining_.select / 2
} else {
0.0
0
},
Box::new(|rng: &mut R| property_select_select_optimizer(rng, env)),
),
(
if env.opts.experimental_indexes && !env.opts.disable_where_true_false_null {
remaining_.read / 2.0
if opts.indexes && !env.opts.disable_where_true_false_null {
remaining_.select / 2
} else {
0.0
0
},
Box::new(|rng: &mut R| property_where_true_false_null(rng, env)),
),
(
if env.opts.experimental_indexes
&& !env.opts.disable_union_all_preserves_cardinality
{
remaining_.read / 3.0
if opts.indexes && !env.opts.disable_union_all_preserves_cardinality {
remaining_.select / 3
} else {
0.0
0
},
Box::new(|rng: &mut R| property_union_all_preserves_cardinality(rng, env)),
),
(
if !env.opts.disable_fsync_no_wait {
50.0 // Freestyle number
if env.profile.io.enable && !env.opts.disable_fsync_no_wait {
50 // Freestyle number
} else {
0.0
0
},
Box::new(|rng: &mut R| property_fsync_no_wait(rng, env, &remaining_)),
),
(
if !env.opts.disable_faulty_query {
20.0
if env.profile.io.enable
&& env.profile.io.fault.enable
&& !env.opts.disable_faulty_query
{
20
} else {
0.0
0
},
Box::new(|rng: &mut R| property_faulty_query(rng, env, &remaining_)),
),


@@ -1,35 +1,39 @@
use crate::{model::Query, SimulatorEnv};
use crate::model::Query;
use rand::Rng;
use sql_generation::{
generation::{frequency, Arbitrary, ArbitraryFrom},
model::query::{update::Update, Create, Delete, Insert, Select},
generation::{Arbitrary, ArbitraryFrom, GenerationContext, frequency},
model::query::{Create, Delete, Insert, Select, update::Update},
};
use super::property::Remaining;
impl ArbitraryFrom<(&SimulatorEnv, &Remaining)> for Query {
fn arbitrary_from<R: Rng>(rng: &mut R, (env, remaining): (&SimulatorEnv, &Remaining)) -> Self {
impl ArbitraryFrom<&Remaining> for Query {
fn arbitrary_from<R: Rng, C: GenerationContext>(
rng: &mut R,
context: &C,
remaining: &Remaining,
) -> Self {
frequency(
vec![
(
remaining.create,
Box::new(|rng| Self::Create(Create::arbitrary(rng))),
Box::new(|rng| Self::Create(Create::arbitrary(rng, context))),
),
(
remaining.read,
Box::new(|rng| Self::Select(Select::arbitrary_from(rng, env))),
remaining.select,
Box::new(|rng| Self::Select(Select::arbitrary(rng, context))),
),
(
remaining.write,
Box::new(|rng| Self::Insert(Insert::arbitrary_from(rng, env))),
remaining.insert,
Box::new(|rng| Self::Insert(Insert::arbitrary(rng, context))),
),
(
remaining.update,
Box::new(|rng| Self::Update(Update::arbitrary_from(rng, env))),
Box::new(|rng| Self::Update(Update::arbitrary(rng, context))),
),
(
f64::min(remaining.write, remaining.delete),
Box::new(|rng| Self::Delete(Delete::arbitrary_from(rng, env))),
remaining.insert.min(remaining.delete),
Box::new(|rng| Self::Delete(Delete::arbitrary(rng, context))),
),
],
rng,
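The weighted dispatch above goes through the frequency helper in sql_generation::generation; a minimal sketch of that pattern, under the assumption that the total weight is non-zero (the real signature may differ):

use rand::Rng;

fn frequency_sketch<R: Rng, T>(
    choices: Vec<(u32, Box<dyn Fn(&mut R) -> T + '_>)>,
    rng: &mut R,
) -> T {
    // Roll a number below the total weight, then walk the choices until the
    // roll falls inside one choice's weight band.
    let total: u32 = choices.iter().map(|(w, _)| *w).sum();
    let mut roll = rng.random_range(0..total);
    for (weight, f) in &choices {
        if roll < *weight {
            return f(rng);
        }
        roll -= *weight;
    }
    unreachable!("roll is always below the total weight")
}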


@@ -8,25 +8,26 @@ use rand::prelude::*;
use runner::bugbase::{Bug, BugBase, LoadedBug};
use runner::cli::{SimulatorCLI, SimulatorCommand};
use runner::env::SimulatorEnv;
use runner::execution::{execute_plans, Execution, ExecutionHistory, ExecutionResult};
use runner::execution::{Execution, ExecutionHistory, ExecutionResult, execute_plans};
use runner::{differential, watch};
use sql_generation::generation::ArbitraryFrom;
use std::any::Any;
use std::backtrace::Backtrace;
use std::fs::OpenOptions;
use std::io::{IsTerminal, Write};
use std::path::Path;
use std::sync::{mpsc, Arc, Mutex};
use std::sync::{Arc, Mutex, mpsc};
use tracing_subscriber::EnvFilter;
use tracing_subscriber::field::MakeExt;
use tracing_subscriber::fmt::format;
use tracing_subscriber::EnvFilter;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use crate::profiles::Profile;
use crate::runner::doublecheck;
use crate::runner::env::{Paths, SimulationPhase, SimulationType};
mod generation;
mod model;
mod profiles;
mod runner;
mod shrink;
@@ -35,80 +36,89 @@ fn main() -> anyhow::Result<()> {
let mut cli_opts = SimulatorCLI::parse();
cli_opts.validate()?;
match cli_opts.subcommand {
Some(SimulatorCommand::List) => {
let mut bugbase = BugBase::load()?;
bugbase.list_bugs()
}
Some(SimulatorCommand::Loop { n, short_circuit }) => {
banner();
for i in 0..n {
println!("iteration {i}");
let result = testing_main(&cli_opts);
if result.is_err() && short_circuit {
println!("short circuiting after {i} iterations");
return result;
} else if result.is_err() {
println!("iteration {i} failed");
} else {
println!("iteration {i} succeeded");
}
let profile = Profile::parse_from_type(cli_opts.profile.clone())?;
tracing::debug!(sim_profile = ?profile);
if let Some(ref command) = cli_opts.subcommand {
match command {
SimulatorCommand::List => {
let mut bugbase = BugBase::load()?;
bugbase.list_bugs()
}
SimulatorCommand::Loop { n, short_circuit } => {
banner();
for i in 0..*n {
println!("iteration {i}");
let result = testing_main(&cli_opts, &profile);
if result.is_err() && *short_circuit {
println!("short circuiting after {i} iterations");
return result;
} else if result.is_err() {
println!("iteration {i} failed");
} else {
println!("iteration {i} succeeded");
}
}
Ok(())
}
SimulatorCommand::Test { filter } => {
let mut bugbase = BugBase::load()?;
let bugs = bugbase.load_bugs()?;
let mut bugs = bugs
.into_iter()
.flat_map(|bug| {
let runs = bug
.runs
.into_iter()
.filter_map(|run| run.error.clone().map(|_| run))
.filter(|run| run.error.as_ref().unwrap().contains(filter))
.map(|run| run.cli_options)
.collect::<Vec<_>>();
runs.into_iter()
.map(|mut cli_opts| {
cli_opts.seed = Some(bug.seed);
cli_opts.load = None;
cli_opts
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
bugs.sort();
bugs.dedup_by(|a, b| a == b);
println!(
"found {} previously triggered configurations with {}",
bugs.len(),
filter
);
let results = bugs
.into_iter()
.map(|cli_opts| testing_main(&cli_opts, &profile))
.collect::<Vec<_>>();
let (successes, failures): (Vec<_>, Vec<_>) =
results.into_iter().partition(|result| result.is_ok());
println!("the results of the change are:");
println!("\t{} successful runs", successes.len());
println!("\t{} failed runs", failures.len());
Ok(())
}
SimulatorCommand::PrintSchema => {
let schema = schemars::schema_for!(crate::Profile);
println!("{}", serde_json::to_string_pretty(&schema).unwrap());
Ok(())
}
Ok(())
}
Some(SimulatorCommand::Test { filter }) => {
let mut bugbase = BugBase::load()?;
let bugs = bugbase.load_bugs()?;
let mut bugs = bugs
.into_iter()
.flat_map(|bug| {
let runs = bug
.runs
.into_iter()
.filter_map(|run| run.error.clone().map(|_| run))
.filter(|run| run.error.as_ref().unwrap().contains(&filter))
.map(|run| run.cli_options)
.collect::<Vec<_>>();
runs.into_iter()
.map(|mut cli_opts| {
cli_opts.seed = Some(bug.seed);
cli_opts.load = None;
cli_opts
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
bugs.sort();
bugs.dedup_by(|a, b| a == b);
println!(
"found {} previously triggered configurations with {}",
bugs.len(),
filter
);
let results = bugs
.into_iter()
.map(|cli_opts| testing_main(&cli_opts))
.collect::<Vec<_>>();
let (successes, failures): (Vec<_>, Vec<_>) =
results.into_iter().partition(|result| result.is_ok());
println!("the results of the change are:");
println!("\t{} successful runs", successes.len());
println!("\t{} failed runs", failures.len());
Ok(())
}
None => {
banner();
testing_main(&cli_opts)
}
} else {
banner();
testing_main(&cli_opts, &profile)
}
}
fn testing_main(cli_opts: &SimulatorCLI) -> anyhow::Result<()> {
fn testing_main(cli_opts: &SimulatorCLI, profile: &Profile) -> anyhow::Result<()> {
let mut bugbase = if cli_opts.disable_bugbase {
None
} else {
@@ -116,7 +126,7 @@ fn testing_main(cli_opts: &SimulatorCLI) -> anyhow::Result<()> {
Some(BugBase::load()?)
};
let (seed, mut env, plans) = setup_simulation(bugbase.as_mut(), cli_opts);
let (seed, mut env, plans) = setup_simulation(bugbase.as_mut(), cli_opts, profile);
if cli_opts.watch {
watch_mode(env).unwrap();
@@ -471,6 +481,7 @@ impl SandboxedResult {
fn setup_simulation(
bugbase: Option<&mut BugBase>,
cli_opts: &SimulatorCLI,
profile: &Profile,
) -> (u64, SimulatorEnv, Vec<InteractionPlan>) {
if let Some(seed) = &cli_opts.load {
let seed = seed.parse::<u64>().expect("seed should be a number");
@@ -484,7 +495,13 @@ fn setup_simulation(
if !paths.base.exists() {
std::fs::create_dir_all(&paths.base).unwrap();
}
let env = SimulatorEnv::new(bug.seed(), cli_opts, paths, SimulationType::Default);
let env = SimulatorEnv::new(
bug.seed(),
cli_opts,
paths,
SimulationType::Default,
profile,
);
let plan = match bug {
Bug::Loaded(LoadedBug { plan, .. }) => plan.clone(),
@@ -528,12 +545,12 @@ fn setup_simulation(
Paths::new(&dir)
};
let mut env = SimulatorEnv::new(seed, cli_opts, paths, SimulationType::Default);
let mut env = SimulatorEnv::new(seed, cli_opts, paths, SimulationType::Default, profile);
tracing::info!("Generating database interaction plan...");
let plans = (1..=env.opts.max_connections)
.map(|_| InteractionPlan::arbitrary_from(&mut env.rng.clone(), &mut env))
.map(|_| InteractionPlan::generate_plan(&mut env.rng.clone(), &mut env))
.collect::<Vec<_>>();
// todo: for now, we only use 1 connection, so it's safe to use the first plan.


@@ -5,14 +5,14 @@ use itertools::Itertools;
use serde::{Deserialize, Serialize};
use sql_generation::model::{
query::{
Create, CreateIndex, Delete, Drop, EmptyContext, Insert, Select,
select::{CompoundOperator, FromClause, ResultColumn, SelectInner},
transaction::{Begin, Commit, Rollback},
update::Update,
Create, CreateIndex, Delete, Drop, EmptyContext, Insert, Select,
},
table::{JoinTable, JoinType, SimValue, Table, TableContext},
};
use turso_parser::ast::{fmt::ToTokens, Distinctness};
use turso_parser::ast::{Distinctness, fmt::ToTokens};
use crate::{generation::Shadow, runner::env::SimulatorTables};
@@ -282,10 +282,11 @@ impl Shadow for SelectInner {
Ok(join_table)
} else {
assert!(self
.columns
.iter()
.all(|col| matches!(col, ResultColumn::Expr(_))));
assert!(
self.columns
.iter()
.all(|col| matches!(col, ResultColumn::Expr(_)))
);
// If `WHERE` is false, just return an empty table
if !self.where_clause.test(&[], &Table::anonymous(vec![])) {

simulator/profiles/io.rs (new file)

@@ -0,0 +1,79 @@
use garde::Validate;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use super::{max_dependent, min_dependent};
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Validate)]
#[serde(deny_unknown_fields, default)]
pub struct IOProfile {
#[garde(skip)]
pub enable: bool,
#[garde(dive)]
pub latency: LatencyProfile,
#[garde(dive)]
pub fault: FaultProfile,
// TODO: expand here with header corruption options and faults on specific IO operations
}
impl Default for IOProfile {
fn default() -> Self {
Self {
enable: true,
latency: Default::default(),
fault: Default::default(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Validate)]
#[serde(deny_unknown_fields, default)]
pub struct LatencyProfile {
#[garde(skip)]
pub enable: bool,
#[garde(range(min = 0, max = 100))]
/// Added IO latency probability
pub latency_probability: usize,
#[garde(custom(max_dependent(&self.max_tick)))]
/// Minimum tick time in microseconds for simulated time
pub min_tick: u64,
#[garde(custom(min_dependent(&self.min_tick)))]
/// Maximum tick time in microseconds for simulated time
pub max_tick: u64,
}
impl Default for LatencyProfile {
fn default() -> Self {
Self {
enable: true,
latency_probability: 1,
min_tick: 1,
max_tick: 30,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Validate)]
#[serde(deny_unknown_fields, default)]
pub struct FaultProfile {
#[garde(skip)]
pub enable: bool,
// TODO: modify SimIo impls to have a FaultProfile inside so they can skip faults depending on the profile
#[garde(skip)]
pub read: bool,
#[garde(skip)]
pub write: bool,
#[garde(skip)]
pub sync: bool,
}
impl Default for FaultProfile {
fn default() -> Self {
Self {
enable: true,
read: true,
write: true,
sync: true,
}
}
}

simulator/profiles/mod.rs (new file)

@@ -0,0 +1,192 @@
use std::{
fmt::Display,
fs,
num::NonZeroU32,
path::{Path, PathBuf},
str::FromStr,
};
use anyhow::Context;
use garde::Validate;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use sql_generation::generation::{InsertOpts, LargeTableOpts, Opts, QueryOpts, TableOpts};
use strum::EnumString;
use crate::profiles::{
io::{FaultProfile, IOProfile},
query::QueryProfile,
};
pub mod io;
pub mod query;
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Validate)]
#[serde(default)]
#[schemars(deny_unknown_fields)]
pub struct Profile {
#[garde(skip)]
/// Experimental MVCC feature
pub experimental_mvcc: bool,
#[garde(dive)]
pub io: IOProfile,
#[garde(dive)]
pub query: QueryProfile,
}
#[allow(clippy::derivable_impls)]
impl Default for Profile {
fn default() -> Self {
Self {
experimental_mvcc: false,
io: Default::default(),
query: Default::default(),
}
}
}
impl Profile {
pub fn write_heavy() -> Self {
let profile = Profile {
query: QueryProfile {
gen_opts: Opts {
// TODO: in the future tweak blob size for bigger inserts
// TODO: increase number of rows as well
table: TableOpts {
large_table: LargeTableOpts {
large_table_prob: 0.4,
..Default::default()
},
..Default::default()
},
query: QueryOpts {
insert: InsertOpts {
min_rows: NonZeroU32::new(5).unwrap(),
max_rows: NonZeroU32::new(11).unwrap(),
},
..Default::default()
},
..Default::default()
},
select_weight: 30,
insert_weight: 70,
delete_weight: 0,
update_weight: 0,
..Default::default()
},
..Default::default()
};
// Validate that we as the developer are not creating an incorrect default profile
profile.validate().unwrap();
profile
}
pub fn faultless() -> Self {
let profile = Profile {
io: IOProfile {
fault: FaultProfile {
enable: false,
..Default::default()
},
..Default::default()
},
query: QueryProfile {
create_table_weight: 0,
create_index_weight: 0,
..Default::default()
},
..Default::default()
};
// Validate that we as the developer are not creating an incorrect default profile
profile.validate().unwrap();
profile
}
pub fn parse_from_type(profile_type: ProfileType) -> anyhow::Result<Self> {
let profile = match profile_type {
ProfileType::Default => Self::default(),
ProfileType::WriteHeavy => Self::write_heavy(),
ProfileType::Faultless => Self::faultless(),
ProfileType::Custom(path) => {
Self::parse(path).with_context(|| "failed to parse JSON profile")?
}
};
Ok(profile)
}
// TODO: in the future handle extension and composability of profiles here
pub fn parse(path: impl AsRef<Path>) -> anyhow::Result<Self> {
let contents = fs::read_to_string(path)?;
// use json5 so we can support comments and trailing commas
let profile: Profile = json5::from_str(&contents)?;
profile.validate()?;
Ok(profile)
}
}
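Because parsing goes through json5 and every struct derives #[serde(default)], a custom profile only needs the fields it overrides; comments and trailing commas are allowed. A hedged sketch:

fn parse_sketch() -> anyhow::Result<()> {
    let src = r#"{
        experimental_mvcc: false,
        io: { fault: { enable: false } }, // comments and trailing commas are fine
        query: { select_weight: 80, insert_weight: 20 },
    }"#;
    // Unspecified fields fall back to the Default impls above.
    let profile: Profile = json5::from_str(src)?;
    assert_eq!(profile.query.select_weight, 80);
    Ok(())
}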
#[derive(
Debug,
Default,
Clone,
Serialize,
Deserialize,
EnumString,
PartialEq,
Eq,
PartialOrd,
Ord,
strum::Display,
strum::VariantNames,
)]
#[serde(rename_all = "snake_case")]
#[strum(ascii_case_insensitive, serialize_all = "snake_case")]
pub enum ProfileType {
#[default]
Default,
WriteHeavy,
Faultless,
#[strum(disabled)]
Custom(PathBuf),
}
impl ProfileType {
pub fn parse(s: &str) -> anyhow::Result<Self> {
if let Ok(prof) = ProfileType::from_str(s) {
Ok(prof)
        } else {
            let path = PathBuf::from(s);
            if path.exists() {
                Ok(ProfileType::Custom(path))
            } else {
                Err(anyhow::anyhow!(
                    "failed to identify a predefined profile or custom profile path"
                ))
            }
        }
}
}
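A sketch of the accepted inputs: predefined names match case-insensitively via strum, and anything else only resolves if it is an existing file path.

fn profile_type_sketch() {
    assert_eq!(ProfileType::parse("write_heavy").unwrap(), ProfileType::WriteHeavy);
    assert_eq!(ProfileType::parse("FAULTLESS").unwrap(), ProfileType::Faultless);
    // Neither a known name nor an existing file, so parsing fails.
    assert!(ProfileType::parse("no_such_profile_or_file").is_err());
}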
/// Minimum value of field is dependent on another field in the struct
fn min_dependent<T: PartialOrd + Display>(min: &T) -> impl FnOnce(&T, &()) -> garde::Result + '_ {
move |value, _| {
if value < min {
return Err(garde::Error::new(format!(
"`{value}` is smaller than `{min}`"
)));
}
Ok(())
}
}
/// Maximum value of field is dependent on another field in the struct
fn max_dependent<T: PartialOrd + Display>(max: &T) -> impl FnOnce(&T, &()) -> garde::Result + '_ {
move |value, _| {
if value > max {
return Err(garde::Error::new(format!(
"`{value}` is bigger than `{max}`"
)));
}
Ok(())
}
}
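As a hedged example of how these cross-field validators fire, an inverted tick range on LatencyProfile from io.rs should fail validation:

fn latency_validation_sketch() {
    use garde::Validate;
    // min_tick > max_tick violates both dependent-range validators above.
    let bad = io::LatencyProfile {
        enable: true,
        latency_probability: 1,
        min_tick: 50,
        max_tick: 30,
    };
    assert!(bad.validate().is_err());
}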

simulator/profiles/query.rs (new file)

@@ -0,0 +1,50 @@
use garde::Validate;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use sql_generation::generation::Opts;
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Validate)]
#[serde(deny_unknown_fields, default)]
pub struct QueryProfile {
#[garde(dive)]
pub gen_opts: Opts,
#[garde(skip)]
pub select_weight: u32,
#[garde(skip)]
pub create_table_weight: u32,
#[garde(skip)]
pub create_index_weight: u32,
#[garde(skip)]
pub insert_weight: u32,
#[garde(skip)]
pub update_weight: u32,
#[garde(skip)]
pub delete_weight: u32,
#[garde(skip)]
pub drop_table_weight: u32,
}
impl Default for QueryProfile {
fn default() -> Self {
Self {
gen_opts: Opts::default(),
select_weight: 60,
create_table_weight: 15,
create_index_weight: 5,
insert_weight: 30,
update_weight: 20,
delete_weight: 20,
drop_table_weight: 2,
}
}
}
#[derive(Debug, Clone, strum::VariantArray)]
pub enum QueryTypes {
CreateTable,
CreateIndex,
Insert,
Update,
Delete,
DropTable,
}
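A sketch of what strum::VariantArray buys here, e.g. enumerating the query kinds when wiring each one to its weight:

fn query_types_sketch() {
    use strum::VariantArray;
    // VARIANTS is a const slice containing every QueryTypes variant.
    for kind in QueryTypes::VARIANTS {
        println!("{kind:?}");
    }
}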


@@ -6,7 +6,7 @@ use std::{
time::SystemTime,
};
use anyhow::{anyhow, Context};
use anyhow::{Context, anyhow};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};


@@ -1,6 +1,13 @@
use clap::{command, Parser};
use clap::{
Arg, Command, Error, Parser,
builder::{PossibleValue, TypedValueParser, ValueParserFactory},
command,
error::{ContextKind, ContextValue, ErrorKind},
};
use serde::{Deserialize, Serialize};
use crate::profiles::ProfileType;
#[derive(Parser, Debug, Clone, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
#[command(name = "limbo-simulator")]
#[command(author, version, about, long_about = None)]
@@ -107,34 +114,25 @@ pub struct SimulatorCLI {
pub disable_faulty_query: bool,
#[clap(long, help = "disable Reopen-Database fault", default_value_t = false)]
pub disable_reopen_database: bool,
#[clap(
long = "latency-prob",
help = "added IO latency probability",
default_value_t = 1
)]
pub latency_probability: usize,
#[clap(
long,
help = "Minimum tick time in microseconds for simulated time",
default_value_t = 1
)]
pub min_tick: u64,
#[clap(
long,
help = "Maximum tick time in microseconds for simulated time",
default_value_t = 30
)]
pub max_tick: u64,
#[clap(long = "latency-prob", help = "added IO latency probability")]
pub latency_probability: Option<usize>,
#[clap(long, help = "Minimum tick time in microseconds for simulated time")]
pub min_tick: Option<u64>,
#[clap(long, help = "Maximum tick time in microseconds for simulated time")]
pub max_tick: Option<u64>,
#[clap(long, help = "Enable experimental MVCC feature")]
pub experimental_mvcc: bool,
pub experimental_mvcc: Option<bool>,
#[clap(long, help = "Disable experimental indexing feature")]
pub disable_experimental_indexes: bool,
pub disable_experimental_indexes: Option<bool>,
#[clap(
long,
help = "Keep all database and plan files",
default_value_t = false
)]
pub keep_files: bool,
#[clap(long, default_value_t = ProfileType::Default)]
/// Profile selector for Simulation run
pub profile: ProfileType,
}
#[derive(Parser, Debug, Clone, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
@@ -167,6 +165,8 @@ pub enum SimulatorCommand {
)]
filter: String,
},
/// Print profile Json Schema
PrintSchema,
}
impl SimulatorCLI {
@@ -192,10 +192,10 @@ impl SimulatorCLI {
anyhow::bail!("Cannot set seed and load plan at the same time");
}
if self.latency_probability > 100 {
if self.latency_probability.is_some_and(|prob| prob > 100) {
anyhow::bail!(
"latency probability must be a number between 0 and 100. Got `{}`",
self.latency_probability
self.latency_probability.unwrap()
);
}
@@ -206,3 +206,70 @@ impl SimulatorCLI {
Ok(())
}
}
#[derive(Clone)]
pub struct ProfileTypeParser;
impl TypedValueParser for ProfileTypeParser {
type Value = ProfileType;
fn parse_ref(
&self,
cmd: &Command,
arg: Option<&Arg>,
value: &std::ffi::OsStr,
) -> Result<Self::Value, Error> {
let s = value
.to_str()
.ok_or_else(|| Error::new(ErrorKind::InvalidUtf8).with_cmd(cmd))?;
ProfileType::parse(s).map_err(|_| {
let mut err = Error::new(ErrorKind::InvalidValue).with_cmd(cmd);
if let Some(arg) = arg {
err.insert(
ContextKind::InvalidArg,
ContextValue::String(arg.to_string()),
);
}
err.insert(
ContextKind::InvalidValue,
ContextValue::String(s.to_string()),
);
err.insert(
ContextKind::ValidValue,
ContextValue::Strings(
self.possible_values()
.unwrap()
.map(|s| s.get_name().to_string())
.collect(),
),
);
err
})
}
fn possible_values(&self) -> Option<Box<dyn Iterator<Item = PossibleValue> + '_>> {
use strum::VariantNames;
Some(Box::new(
Self::Value::VARIANTS
.iter()
.map(|variant| {
// Custom variant should be listed as a Custom path
if variant.eq_ignore_ascii_case("custom") {
"CUSTOM_PATH"
} else {
variant
}
})
.map(PossibleValue::new),
))
}
}
impl ValueParserFactory for ProfileType {
type Parser = ProfileTypeParser;
fn value_parser() -> Self::Parser {
ProfileTypeParser
}
}


@@ -4,18 +4,18 @@ use sql_generation::{generation::pick_index, model::table::SimValue};
use turso_core::Value;
use crate::{
InteractionPlan,
generation::{
plan::{Interaction, InteractionPlanState, ResultSet},
Shadow as _,
plan::{Interaction, InteractionPlanState, ResultSet},
},
model::Query,
runner::execution::ExecutionContinuation,
InteractionPlan,
};
use super::{
env::{SimConnection, SimulatorEnv},
execution::{execute_interaction, Execution, ExecutionHistory, ExecutionResult},
execution::{Execution, ExecutionHistory, ExecutionResult, execute_interaction},
};
pub(crate) fn run_simulation(
@@ -249,7 +249,9 @@ fn execute_plan(
match (limbo_values, rusqlite_values) {
(Ok(limbo_values), Ok(rusqlite_values)) => {
if limbo_values != rusqlite_values {
tracing::error!("returned values from limbo and rusqlite results do not match");
tracing::error!(
"returned values from limbo and rusqlite results do not match"
);
let diff = limbo_values
.iter()
.zip(rusqlite_values.iter())
@@ -303,7 +305,9 @@ fn execute_plan(
tracing::warn!("rusqlite error {}", rusqlite_err);
}
(Ok(limbo_result), Err(rusqlite_err)) => {
tracing::error!("limbo and rusqlite results do not match, limbo returned values but rusqlite failed");
tracing::error!(
"limbo and rusqlite results do not match, limbo returned values but rusqlite failed"
);
tracing::error!("limbo values {:?}", limbo_result);
tracing::error!("rusqlite error {}", rusqlite_err);
return Err(turso_core::LimboError::InternalError(
@@ -311,7 +315,9 @@ fn execute_plan(
));
}
(Err(limbo_err), Ok(_)) => {
tracing::error!("limbo and rusqlite results do not match, limbo failed but rusqlite returned values");
tracing::error!(
"limbo and rusqlite results do not match, limbo failed but rusqlite returned values"
);
tracing::error!("limbo error {}", limbo_err);
return Err(turso_core::LimboError::InternalError(
"limbo and rusqlite results do not match".into(),


@@ -6,13 +6,13 @@ use std::{
use sql_generation::generation::pick_index;
use crate::{
generation::plan::InteractionPlanState, runner::execution::ExecutionContinuation,
InteractionPlan,
InteractionPlan, generation::plan::InteractionPlanState,
runner::execution::ExecutionContinuation,
};
use super::{
env::{SimConnection, SimulatorEnv},
execution::{execute_interaction, Execution, ExecutionHistory, ExecutionResult},
execution::{Execution, ExecutionHistory, ExecutionResult, execute_interaction},
};
pub(crate) fn run_simulation(
@@ -207,7 +207,9 @@ fn execute_plan(
match (limbo_values, doublecheck_values) {
(Ok(limbo_values), Ok(doublecheck_values)) => {
if limbo_values != doublecheck_values {
tracing::error!("returned values from limbo and doublecheck results do not match");
tracing::error!(
"returned values from limbo and doublecheck results do not match"
);
tracing::debug!("limbo values {:?}", limbo_values);
tracing::debug!(
"doublecheck values {:?}",
@@ -231,7 +233,9 @@ fn execute_plan(
}
}
(Ok(limbo_result), Err(doublecheck_err)) => {
tracing::error!("limbo and doublecheck results do not match, limbo returned values but doublecheck failed");
tracing::error!(
"limbo and doublecheck results do not match, limbo returned values but doublecheck failed"
);
tracing::error!("limbo values {:?}", limbo_result);
tracing::error!("doublecheck error {}", doublecheck_err);
return Err(turso_core::LimboError::InternalError(
@@ -239,7 +243,9 @@ fn execute_plan(
));
}
(Err(limbo_err), Ok(_)) => {
tracing::error!("limbo and doublecheck results do not match, limbo failed but doublecheck returned values");
tracing::error!(
"limbo and doublecheck results do not match, limbo failed but doublecheck returned values"
);
tracing::error!("limbo error {}", limbo_err);
return Err(turso_core::LimboError::InternalError(
"limbo and doublecheck results do not match".into(),


@@ -5,11 +5,13 @@ use std::panic::UnwindSafe;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use garde::Validate;
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;
use sql_generation::model::table::Table;
use turso_core::Database;
use crate::profiles::Profile;
use crate::runner::io::SimulatorIO;
use super::cli::SimulatorCLI;
@@ -59,6 +61,7 @@ impl Deref for SimulatorTables {
pub(crate) struct SimulatorEnv {
pub(crate) opts: SimulatorOpts,
pub profile: Profile,
pub(crate) connections: Vec<SimConnection>,
pub(crate) io: Arc<SimulatorIO>,
pub(crate) db: Option<Arc<Database>>,
@@ -85,6 +88,7 @@ impl SimulatorEnv {
paths: self.paths.clone(),
type_: self.type_,
phase: self.phase,
profile: self.profile.clone(),
}
}
@@ -93,13 +97,15 @@ impl SimulatorEnv {
self.connections.iter_mut().for_each(|c| c.disconnect());
self.rng = ChaCha8Rng::seed_from_u64(self.opts.seed);
let latency_prof = &self.profile.io.latency;
let io = Arc::new(
SimulatorIO::new(
self.opts.seed,
self.opts.page_size,
self.opts.latency_probability,
self.opts.min_tick,
self.opts.max_tick,
latency_prof.latency_probability,
latency_prof.min_tick,
latency_prof.max_tick,
)
.unwrap(),
);
@@ -119,8 +125,8 @@ impl SimulatorEnv {
let db = match Database::open_file(
io.clone(),
db_path.to_str().unwrap(),
self.opts.experimental_mvcc,
self.opts.experimental_indexes,
self.profile.experimental_mvcc,
self.profile.query.gen_opts.indexes,
) {
Ok(db) => db,
Err(e) => {
@@ -161,6 +167,7 @@ impl SimulatorEnv {
cli_opts: &SimulatorCLI,
paths: Paths,
simulation_type: SimulationType,
profile: &Profile,
) -> Self {
let mut rng = ChaCha8Rng::seed_from_u64(seed);
@@ -223,13 +230,6 @@ impl SimulatorEnv {
max_connections: 1, // TODO: for now let's use one connection as we didn't implement
// correct transactions processing
max_tables: rng.random_range(0..128),
create_percent,
create_index_percent,
read_percent,
write_percent,
delete_percent,
drop_percent,
update_percent,
disable_select_optimizer: cli_opts.disable_select_optimizer,
disable_insert_values_select: cli_opts.disable_insert_values_select,
disable_double_create_failure: cli_opts.disable_double_create_failure,
@@ -242,27 +242,12 @@ impl SimulatorEnv {
disable_fsync_no_wait: cli_opts.disable_fsync_no_wait,
disable_faulty_query: cli_opts.disable_faulty_query,
page_size: 4096, // TODO: randomize this too
max_interactions: rng.random_range(cli_opts.minimum_tests..=cli_opts.maximum_tests),
max_interactions: rng.random_range(cli_opts.minimum_tests..=cli_opts.maximum_tests)
as u32,
max_time_simulation: cli_opts.maximum_time,
disable_reopen_database: cli_opts.disable_reopen_database,
latency_probability: cli_opts.latency_probability,
experimental_mvcc: cli_opts.experimental_mvcc,
experimental_indexes: !cli_opts.disable_experimental_indexes,
min_tick: cli_opts.min_tick,
max_tick: cli_opts.max_tick,
};
let io = Arc::new(
SimulatorIO::new(
seed,
opts.page_size,
cli_opts.latency_probability,
cli_opts.min_tick,
cli_opts.max_tick,
)
.unwrap(),
);
// Remove existing database file if it exists
let db_path = paths.db(&simulation_type, &SimulationPhase::Test);
@@ -275,11 +260,44 @@ impl SimulatorEnv {
std::fs::remove_file(&wal_path).unwrap();
}
let mut profile = profile.clone();
// Conditionals here so that we can override some profile options from the CLI
if let Some(mvcc) = cli_opts.experimental_mvcc {
profile.experimental_mvcc = mvcc;
}
if let Some(disable_indexes) = cli_opts.disable_experimental_indexes {
    // The CLI flag disables indexes, so invert it before storing (matching
    // the removed `experimental_indexes: !cli_opts.disable_experimental_indexes`).
    profile.query.gen_opts.indexes = !disable_indexes;
}
if let Some(latency_prob) = cli_opts.latency_probability {
profile.io.latency.latency_probability = latency_prob;
}
if let Some(max_tick) = cli_opts.max_tick {
profile.io.latency.max_tick = max_tick;
}
if let Some(min_tick) = cli_opts.min_tick {
profile.io.latency.min_tick = min_tick;
}
profile.validate().unwrap();
let latency_prof = &profile.io.latency;
let io = Arc::new(
SimulatorIO::new(
seed,
opts.page_size,
latency_prof.latency_probability,
latency_prof.min_tick,
latency_prof.max_tick,
)
.unwrap(),
);
let db = match Database::open_file(
io.clone(),
db_path.to_str().unwrap(),
opts.experimental_mvcc,
opts.experimental_indexes,
profile.experimental_mvcc,
profile.query.gen_opts.indexes,
) {
Ok(db) => db,
Err(e) => {
@@ -301,6 +319,7 @@ impl SimulatorEnv {
db: Some(db),
type_: simulation_type,
phase: SimulationPhase::Test,
profile: profile.clone(),
}
}
@@ -394,15 +413,6 @@ pub(crate) struct SimulatorOpts {
pub(crate) ticks: usize,
pub(crate) max_connections: usize,
pub(crate) max_tables: usize,
// this next options are the distribution of workload where read_percent + write_percent +
// delete_percent == 100%
pub(crate) create_percent: f64,
pub(crate) create_index_percent: f64,
pub(crate) read_percent: f64,
pub(crate) write_percent: f64,
pub(crate) delete_percent: f64,
pub(crate) update_percent: f64,
pub(crate) drop_percent: f64,
pub(crate) disable_select_optimizer: bool,
pub(crate) disable_insert_values_select: bool,
@@ -416,14 +426,9 @@ pub(crate) struct SimulatorOpts {
pub(crate) disable_faulty_query: bool,
pub(crate) disable_reopen_database: bool,
pub(crate) max_interactions: usize,
pub(crate) max_interactions: u32,
pub(crate) page_size: usize,
pub(crate) max_time_simulation: usize,
pub(crate) latency_probability: usize,
pub(crate) experimental_mvcc: bool,
pub(crate) experimental_indexes: bool,
pub min_tick: u64,
pub max_tick: u64,
}
#[derive(Debug, Clone)]


@@ -5,8 +5,8 @@ use tracing::instrument;
use turso_core::{Connection, LimboError, Result, StepResult};
use crate::generation::{
plan::{Interaction, InteractionPlan, InteractionPlanState, ResultSet},
Shadow as _,
plan::{Interaction, InteractionPlan, InteractionPlanState, ResultSet},
};
use super::env::{SimConnection, SimulatorEnv};


@@ -6,10 +6,10 @@ use std::{
use rand::Rng as _;
use rand_chacha::ChaCha8Rng;
use tracing::{instrument, Level};
use tracing::{Level, instrument};
use turso_core::{File, Result};
use crate::runner::{clock::SimulatorClock, FAULT_ERROR_MSG};
use crate::runner::{FAULT_ERROR_MSG, clock::SimulatorClock};
pub(crate) struct SimulatorFile {
pub path: String,
pub(crate) inner: Arc<dyn File>,
@@ -150,7 +150,7 @@ impl File for SimulatorFile {
self.inner.unlock_file()
}
fn pread(&self, pos: usize, c: turso_core::Completion) -> Result<turso_core::Completion> {
fn pread(&self, pos: u64, c: turso_core::Completion) -> Result<turso_core::Completion> {
self.nr_pread_calls.set(self.nr_pread_calls.get() + 1);
if self.fault.get() {
tracing::debug!("pread fault");
@@ -173,7 +173,7 @@ impl File for SimulatorFile {
fn pwrite(
&self,
pos: usize,
pos: u64,
buffer: Arc<turso_core::Buffer>,
c: turso_core::Completion,
) -> Result<turso_core::Completion> {
@@ -201,7 +201,9 @@ impl File for SimulatorFile {
self.nr_sync_calls.set(self.nr_sync_calls.get() + 1);
if self.fault.get() {
// TODO: Enable this when https://github.com/tursodatabase/turso/issues/2091 is fixed.
tracing::debug!("ignoring sync fault because it causes false positives with current simulator design");
tracing::debug!(
"ignoring sync fault because it causes false positives with current simulator design"
);
self.fault.set(false);
}
let c = if let Some(latency) = self.generate_latency_duration() {
@@ -225,7 +227,7 @@ impl File for SimulatorFile {
fn pwritev(
&self,
pos: usize,
pos: u64,
buffers: Vec<Arc<turso_core::Buffer>>,
c: turso_core::Completion,
) -> Result<turso_core::Completion> {
@@ -255,7 +257,7 @@ impl File for SimulatorFile {
self.inner.size()
}
fn truncate(&self, len: usize, c: turso_core::Completion) -> Result<turso_core::Completion> {
fn truncate(&self, len: u64, c: turso_core::Completion) -> Result<turso_core::Completion> {
if self.fault.get() {
return Err(turso_core::LimboError::InternalError(
FAULT_ERROR_MSG.into(),


@@ -5,7 +5,7 @@ use std::{
use rand::{RngCore, SeedableRng};
use rand_chacha::ChaCha8Rng;
use turso_core::{Clock, Instant, OpenFlags, PlatformIO, Result, IO};
use turso_core::{Clock, IO, Instant, OpenFlags, PlatformIO, Result};
use crate::runner::{clock::SimulatorClock, file::SimulatorFile};


@@ -10,7 +10,7 @@ use crate::{
use super::{
env::{SimConnection, SimulatorEnv},
execution::{execute_interaction, Execution, ExecutionHistory, ExecutionResult},
execution::{Execution, ExecutionHistory, ExecutionResult, execute_interaction},
};
pub(crate) fn run_simulation(


@@ -1,4 +1,5 @@
use crate::{
SandboxedResult, SimulatorEnv,
generation::{
plan::{Interaction, InteractionPlan, Interactions},
property::Property,
@@ -6,7 +7,6 @@ use crate::{
model::Query,
run_simulation,
runner::execution::Execution,
SandboxedResult, SimulatorEnv,
};
use std::sync::{Arc, Mutex};


@@ -19,6 +19,8 @@ anarchist-readable-name-generator-lib = "0.2.0"
itertools = { workspace = true }
anyhow = { workspace = true }
tracing = { workspace = true }
schemars = { workspace = true }
garde = { workspace = true, features = ["derive", "serde"] }
[dev-dependencies]
rand_chacha = "0.9.0"


@@ -5,7 +5,7 @@ use turso_parser::ast::{
use crate::{
generation::{
frequency, gen_random_text, one_of, pick, pick_index, Arbitrary, ArbitraryFrom,
ArbitrarySizedFrom, GenerationContext,
ArbitrarySized, ArbitrarySizedFrom, GenerationContext,
},
model::table::SimValue,
};
@@ -14,8 +14,21 @@ impl<T> Arbitrary for Box<T>
where
T: Arbitrary,
{
fn arbitrary<R: rand::Rng>(rng: &mut R) -> Self {
Box::from(T::arbitrary(rng))
fn arbitrary<R: rand::Rng, C: GenerationContext>(rng: &mut R, context: &C) -> Self {
Box::from(T::arbitrary(rng, context))
}
}
impl<T> ArbitrarySized for Box<T>
where
T: ArbitrarySized,
{
fn arbitrary_sized<R: rand::Rng, C: GenerationContext>(
rng: &mut R,
context: &C,
size: usize,
) -> Self {
Box::from(T::arbitrary_sized(rng, context, size))
}
}
@@ -23,8 +36,13 @@ impl<A, T> ArbitrarySizedFrom<A> for Box<T>
where
T: ArbitrarySizedFrom<A>,
{
fn arbitrary_sized_from<R: rand::Rng>(rng: &mut R, t: A, size: usize) -> Self {
Box::from(T::arbitrary_sized_from(rng, t, size))
fn arbitrary_sized_from<R: rand::Rng, C: GenerationContext>(
rng: &mut R,
context: &C,
t: A,
size: usize,
) -> Self {
Box::from(T::arbitrary_sized_from(rng, context, t, size))
}
}
@@ -32,8 +50,8 @@ impl<T> Arbitrary for Option<T>
where
T: Arbitrary,
{
fn arbitrary<R: rand::Rng>(rng: &mut R) -> Self {
rng.random_bool(0.5).then_some(T::arbitrary(rng))
fn arbitrary<R: rand::Rng, C: GenerationContext>(rng: &mut R, context: &C) -> Self {
rng.random_bool(0.5).then_some(T::arbitrary(rng, context))
}
}
@@ -41,9 +59,14 @@ impl<A, T> ArbitrarySizedFrom<A> for Option<T>
where
T: ArbitrarySizedFrom<A>,
{
fn arbitrary_sized_from<R: rand::Rng>(rng: &mut R, t: A, size: usize) -> Self {
fn arbitrary_sized_from<R: rand::Rng, C: GenerationContext>(
rng: &mut R,
context: &C,
t: A,
size: usize,
) -> Self {
rng.random_bool(0.5)
.then_some(T::arbitrary_sized_from(rng, t, size))
.then_some(T::arbitrary_sized_from(rng, context, t, size))
}
}
@@ -51,20 +74,26 @@ impl<A: Copy, T> ArbitraryFrom<A> for Vec<T>
where
T: ArbitraryFrom<A>,
{
fn arbitrary_from<R: rand::Rng>(rng: &mut R, t: A) -> Self {
fn arbitrary_from<R: rand::Rng, C: GenerationContext>(rng: &mut R, context: &C, t: A) -> Self {
let size = rng.random_range(0..5);
(0..size).map(|_| T::arbitrary_from(rng, t)).collect()
(0..size)
.map(|_| T::arbitrary_from(rng, context, t))
.collect()
}
}
// Freestyling generation
impl<C: GenerationContext> ArbitrarySizedFrom<&C> for Expr {
fn arbitrary_sized_from<R: rand::Rng>(rng: &mut R, t: &C, size: usize) -> Self {
impl ArbitrarySized for Expr {
fn arbitrary_sized<R: rand::Rng, C: GenerationContext>(
rng: &mut R,
context: &C,
size: usize,
) -> Self {
frequency(
vec![
(
1,
Box::new(|rng| Expr::Literal(ast::Literal::arbitrary_from(rng, t))),
Box::new(|rng| Expr::Literal(ast::Literal::arbitrary(rng, context))),
),
(
size,
@@ -79,9 +108,9 @@ impl<C: GenerationContext> ArbitrarySizedFrom<&C> for Expr {
// }),
Box::new(|rng: &mut R| {
Expr::Binary(
Box::arbitrary_sized_from(rng, t, size - 1),
Operator::arbitrary(rng),
Box::arbitrary_sized_from(rng, t, size - 1),
Box::arbitrary_sized(rng, context, size - 1),
Operator::arbitrary(rng, context),
Box::arbitrary_sized(rng, context, size - 1),
)
}),
// Box::new(|rng| Expr::Case {
@@ -133,8 +162,8 @@ impl<C: GenerationContext> ArbitrarySizedFrom<&C> for Expr {
// })
Box::new(|rng| {
Expr::Unary(
UnaryOperator::arbitrary_from(rng, t),
Box::arbitrary_sized_from(rng, t, size - 1),
UnaryOperator::arbitrary(rng, context),
Box::arbitrary_sized(rng, context, size - 1),
)
}),
// TODO: skip Exists for now
@@ -159,7 +188,7 @@ impl<C: GenerationContext> ArbitrarySizedFrom<&C> for Expr {
}
impl Arbitrary for Operator {
fn arbitrary<R: rand::Rng>(rng: &mut R) -> Self {
fn arbitrary<R: rand::Rng, C: GenerationContext>(rng: &mut R, _context: &C) -> Self {
let choices = [
Operator::Add,
Operator::And,
@@ -190,7 +219,7 @@ impl Arbitrary for Operator {
}
impl Arbitrary for Type {
fn arbitrary<R: rand::Rng>(rng: &mut R) -> Self {
fn arbitrary<R: rand::Rng, C: GenerationContext>(rng: &mut R, _context: &C) -> Self {
let name = pick(&["INT", "INTEGER", "REAL", "TEXT", "BLOB", "ANY"], rng).to_string();
Self {
name,
@@ -199,11 +228,11 @@ impl Arbitrary for Type {
}
}
impl<C: GenerationContext> ArbitraryFrom<&C> for QualifiedName {
fn arbitrary_from<R: rand::Rng>(rng: &mut R, t: &C) -> Self {
impl Arbitrary for QualifiedName {
fn arbitrary<R: rand::Rng, C: GenerationContext>(rng: &mut R, context: &C) -> Self {
// TODO: for now just generate table name
let table_idx = pick_index(t.tables().len(), rng);
let table = &t.tables()[table_idx];
let table_idx = pick_index(context.tables().len(), rng);
let table = &context.tables()[table_idx];
// TODO: for now forego alias
Self {
db_name: None,
@@ -213,8 +242,8 @@ impl<C: GenerationContext> ArbitraryFrom<&C> for QualifiedName {
}
}
impl<C: GenerationContext> ArbitraryFrom<&C> for LikeOperator {
fn arbitrary_from<R: rand::Rng>(rng: &mut R, _t: &C) -> Self {
impl Arbitrary for LikeOperator {
fn arbitrary<R: rand::Rng, C: GenerationContext>(rng: &mut R, _t: &C) -> Self {
let choice = rng.random_range(0..4);
match choice {
0 => LikeOperator::Glob,
@@ -227,8 +256,8 @@ impl<C: GenerationContext> ArbitraryFrom<&C> for LikeOperator {
}
// Current implementation does not take into account the columns affinity nor if table is Strict
impl<C: GenerationContext> ArbitraryFrom<&C> for ast::Literal {
fn arbitrary_from<R: rand::Rng>(rng: &mut R, _t: &C) -> Self {
impl Arbitrary for ast::Literal {
fn arbitrary<R: rand::Rng, C: GenerationContext>(rng: &mut R, _t: &C) -> Self {
loop {
let choice = rng.random_range(0..5);
let lit = match choice {
@@ -255,7 +284,11 @@ impl<C: GenerationContext> ArbitraryFrom<&C> for ast::Literal {
// Creates a literal value
impl ArbitraryFrom<&Vec<&SimValue>> for ast::Expr {
fn arbitrary_from<R: rand::Rng>(rng: &mut R, values: &Vec<&SimValue>) -> Self {
fn arbitrary_from<R: rand::Rng, C: GenerationContext>(
rng: &mut R,
_context: &C,
values: &Vec<&SimValue>,
) -> Self {
if values.is_empty() {
return Self::Literal(ast::Literal::Null);
}
@@ -265,8 +298,8 @@ impl ArbitraryFrom<&Vec<&SimValue>> for ast::Expr {
}
}
impl<C: GenerationContext> ArbitraryFrom<&C> for UnaryOperator {
fn arbitrary_from<R: rand::Rng>(rng: &mut R, _t: &C) -> Self {
impl Arbitrary for UnaryOperator {
fn arbitrary<R: rand::Rng, C: GenerationContext>(rng: &mut R, _t: &C) -> Self {
let choice = rng.random_range(0..4);
match choice {
0 => Self::BitwiseNot,


@@ -3,24 +3,13 @@ use std::{iter::Sum, ops::SubAssign};
use anarchist_readable_name_generator_lib::readable_name_custom;
use rand::{distr::uniform::SampleUniform, Rng};
use crate::model::table::Table;
pub mod expr;
pub mod opts;
pub mod predicate;
pub mod query;
pub mod table;
#[derive(Debug, Clone, Copy)]
pub struct Opts {
/// Indexes enabled
pub indexes: bool,
}
/// Trait used to provide context to generation functions
pub trait GenerationContext {
fn tables(&self) -> &Vec<Table>;
fn opts(&self) -> Opts;
}
pub use opts::*;
type ArbitraryFromFunc<'a, R, T> = Box<dyn Fn(&mut R) -> T + 'a>;
type Choice<'a, R, T> = (usize, Box<dyn Fn(&mut R) -> Option<T> + 'a>);
@@ -30,7 +19,7 @@ type Choice<'a, R, T> = (usize, Box<dyn Fn(&mut R) -> Option<T> + 'a>);
/// the possible values of the type, with a bias towards smaller values for
/// practicality.
pub trait Arbitrary {
fn arbitrary<R: Rng>(rng: &mut R) -> Self;
fn arbitrary<R: Rng, C: GenerationContext>(rng: &mut R, context: &C) -> Self;
}
/// ArbitrarySized trait for generating random values of a specific size
@@ -40,7 +29,8 @@ pub trait Arbitrary {
/// must fit in the given size. This is useful for generating values that are
/// constrained by a specific size, such as integers or strings.
pub trait ArbitrarySized {
fn arbitrary_sized<R: Rng>(rng: &mut R, size: usize) -> Self;
fn arbitrary_sized<R: Rng, C: GenerationContext>(rng: &mut R, context: &C, size: usize)
-> Self;
}
/// ArbitraryFrom trait for generating random values from a given value
@@ -49,7 +39,7 @@ pub trait ArbitrarySized {
/// such as generating an integer within an interval, or a value that fits in a table,
/// or a predicate satisfying a given table row.
pub trait ArbitraryFrom<T> {
fn arbitrary_from<R: Rng>(rng: &mut R, t: T) -> Self;
fn arbitrary_from<R: Rng, C: GenerationContext>(rng: &mut R, context: &C, t: T) -> Self;
}
/// ArbitrarySizedFrom trait for generating random values from a given value
@@ -61,12 +51,21 @@ pub trait ArbitraryFrom<T> {
/// This is useful for generating values that are constrained by a specific size,
/// such as integers or strings, while still being dependent on the given value.
pub trait ArbitrarySizedFrom<T> {
fn arbitrary_sized_from<R: Rng>(rng: &mut R, t: T, size: usize) -> Self;
fn arbitrary_sized_from<R: Rng, C: GenerationContext>(
rng: &mut R,
context: &C,
t: T,
size: usize,
) -> Self;
}
/// ArbitraryFromMaybe trait for fallibly generating random values from a given value
pub trait ArbitraryFromMaybe<T> {
fn arbitrary_from_maybe<R: Rng>(rng: &mut R, t: T) -> Option<Self>
fn arbitrary_from_maybe<R: Rng, C: GenerationContext>(
rng: &mut R,
context: &C,
t: T,
) -> Option<Self>
where
Self: Sized;
}
@@ -143,11 +142,15 @@ pub fn pick_index<R: Rng>(choices: usize, rng: &mut R) -> usize {
/// pick_n_unique is a helper function for uniformly picking N unique elements from a range.
/// The elements themselves are usize, typically representing indices.
pub fn pick_n_unique<R: Rng>(range: std::ops::Range<usize>, n: usize, rng: &mut R) -> Vec<usize> {
pub fn pick_n_unique<R: Rng>(
range: std::ops::Range<usize>,
n: usize,
rng: &mut R,
) -> impl Iterator<Item = usize> {
use rand::seq::SliceRandom;
let mut items: Vec<usize> = range.collect();
items.shuffle(rng);
items.into_iter().take(n).collect()
items.into_iter().take(n)
}
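Since pick_n_unique now returns a lazy iterator, callers that only loop avoid the intermediate Vec; a sketch:

fn pick_n_unique_sketch() {
    let mut rng = rand::rng();
    // Iterate directly, or .collect::<Vec<_>>() when a Vec is still needed.
    for idx in pick_n_unique(0..10, 3, &mut rng) {
        println!("picked index {idx}");
    }
}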
/// gen_random_text uses `anarchist_readable_name_generator_lib` to generate random
@@ -169,20 +172,41 @@ pub fn gen_random_text<T: Rng>(rng: &mut T) -> String {
}
}
pub fn pick_unique<T: ToOwned + PartialEq>(
items: &[T],
pub fn pick_unique<'a, T: PartialEq>(
items: &'a [T],
count: usize,
rng: &mut impl rand::Rng,
) -> Vec<T::Owned>
where
<T as ToOwned>::Owned: PartialEq,
{
let mut picked: Vec<T::Owned> = Vec::new();
) -> impl Iterator<Item = &'a T> {
let mut picked: Vec<&T> = Vec::new();
while picked.len() < count {
let item = pick(items, rng);
if !picked.contains(&item.to_owned()) {
picked.push(item.to_owned());
if !picked.contains(&item) {
picked.push(item);
}
}
picked.into_iter()
}
#[cfg(test)]
mod tests {
use crate::{
generation::{GenerationContext, Opts},
model::table::Table,
};
#[derive(Debug, Default, Clone)]
pub struct TestContext {
pub opts: Opts,
pub tables: Vec<Table>,
}
impl GenerationContext for TestContext {
fn tables(&self) -> &Vec<Table> {
&self.tables
}
fn opts(&self) -> &Opts {
&self.opts
}
}
picked
}


@@ -0,0 +1,238 @@
use std::{
fmt::Display,
num::{NonZero, NonZeroU32},
ops::Range,
};
use garde::Validate;
use rand::distr::weighted::WeightedIndex;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use crate::model::table::Table;
/// Trait used to provide context to generation functions
pub trait GenerationContext {
fn tables(&self) -> &Vec<Table>;
fn opts(&self) -> &Opts;
}
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Validate)]
#[serde(deny_unknown_fields, default)]
pub struct Opts {
#[garde(skip)]
/// Indexes enabled
pub indexes: bool,
#[garde(dive)]
pub table: TableOpts,
#[garde(dive)]
pub query: QueryOpts,
}
impl Default for Opts {
fn default() -> Self {
Self {
indexes: true,
table: Default::default(),
query: Default::default(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Validate)]
#[serde(deny_unknown_fields, default)]
pub struct TableOpts {
#[garde(dive)]
pub large_table: LargeTableOpts,
/// Range of numbers of columns to generate
#[garde(custom(range_struct_min(1)))]
pub column_range: Range<u32>,
}
impl Default for TableOpts {
fn default() -> Self {
Self {
large_table: Default::default(),
// Up to 10 columns
column_range: 1..11,
}
}
}
/// Options for generating large tables
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Validate)]
#[serde(deny_unknown_fields, default)]
pub struct LargeTableOpts {
#[garde(skip)]
pub enable: bool,
#[garde(range(min = 0.0, max = 1.0))]
pub large_table_prob: f64,
/// Range of numbers of columns to generate
#[garde(custom(range_struct_min(1)))]
pub column_range: Range<u32>,
}
impl Default for LargeTableOpts {
fn default() -> Self {
Self {
enable: true,
large_table_prob: 0.1,
// todo: make this higher (128+)
column_range: 64..125,
}
}
}
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema, Validate)]
#[serde(deny_unknown_fields, default)]
pub struct QueryOpts {
#[garde(dive)]
pub select: SelectOpts,
#[garde(dive)]
pub from_clause: FromClauseOpts,
#[garde(dive)]
pub insert: InsertOpts,
}
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Validate)]
#[serde(deny_unknown_fields, default)]
pub struct SelectOpts {
#[garde(range(min = 0.0, max = 1.0))]
pub order_by_prob: f64,
#[garde(length(min = 1))]
pub compound_selects: Vec<CompoundSelectWeight>,
}
impl Default for SelectOpts {
fn default() -> Self {
Self {
order_by_prob: 0.3,
compound_selects: vec![
CompoundSelectWeight {
num_compound_selects: 0,
weight: 95,
},
CompoundSelectWeight {
num_compound_selects: 1,
weight: 4,
},
CompoundSelectWeight {
num_compound_selects: 2,
weight: 1,
},
],
}
}
}
impl SelectOpts {
pub fn compound_select_weighted_index(&self) -> WeightedIndex<u32> {
WeightedIndex::new(self.compound_selects.iter().map(|weight| weight.weight)).unwrap()
}
}
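A sketch of sampling from these weights with the rand 0.9 API:

fn compound_select_sketch() {
    use rand::distr::Distribution;
    let opts = SelectOpts::default();
    let dist = opts.compound_select_weighted_index();
    let mut rng = rand::rng();
    // With the default weights, index 0 (no compound selects) is drawn ~95%
    // of the time.
    let idx = dist.sample(&mut rng);
    let n = opts.compound_selects[idx].num_compound_selects;
    let _ = n;
}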
#[derive(Debug, Clone, PartialEq, PartialOrd, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct CompoundSelectWeight {
pub num_compound_selects: u32,
pub weight: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Validate)]
#[serde(deny_unknown_fields)]
pub struct FromClauseOpts {
#[garde(length(min = 1))]
pub joins: Vec<JoinWeight>,
}
impl Default for FromClauseOpts {
fn default() -> Self {
Self {
joins: vec![
JoinWeight {
num_joins: 0,
weight: 90,
},
JoinWeight {
num_joins: 1,
weight: 7,
},
JoinWeight {
num_joins: 2,
weight: 3,
},
],
}
}
}
impl FromClauseOpts {
pub fn as_weighted_index(&self) -> WeightedIndex<u32> {
WeightedIndex::new(self.joins.iter().map(|weight| weight.weight)).unwrap()
}
}
#[derive(Debug, Clone, PartialEq, PartialOrd, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct JoinWeight {
pub num_joins: u32,
pub weight: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Validate)]
#[serde(deny_unknown_fields)]
pub struct InsertOpts {
#[garde(skip)]
pub min_rows: NonZeroU32,
#[garde(skip)]
pub max_rows: NonZeroU32,
}
impl Default for InsertOpts {
fn default() -> Self {
Self {
min_rows: NonZero::new(1).unwrap(),
max_rows: NonZero::new(10).unwrap(),
}
}
}
fn range_struct_min<T: PartialOrd + Display>(
min: T,
) -> impl FnOnce(&Range<T>, &()) -> garde::Result {
move |value, _| {
if value.start < min {
return Err(garde::Error::new(format!(
"range start `{}` is smaller than {min}",
value.start
)));
} else if value.end < min {
return Err(garde::Error::new(format!(
"range end `{}` is smaller than {min}",
value.end
)));
}
Ok(())
}
}
#[allow(dead_code)]
fn range_struct_max<T: PartialOrd + Display>(
max: T,
) -> impl FnOnce(&Range<T>, &()) -> garde::Result {
move |value, _| {
if value.start > max {
return Err(garde::Error::new(format!(
"range start `{}` is smaller than {max}",
value.start
)));
} else if value.end > max {
return Err(garde::Error::new(format!(
"range end `{}` is smaller than {max}",
value.end
)));
}
Ok(())
}
}


@@ -7,7 +7,7 @@ use crate::{
backtrack, one_of, pick,
predicate::{CompoundPredicate, SimplePredicate},
table::{GTValue, LTValue, LikeValue},
ArbitraryFrom, ArbitraryFromMaybe as _,
ArbitraryFrom, ArbitraryFromMaybe as _, GenerationContext,
},
model::{
query::predicate::Predicate,
@@ -17,8 +17,9 @@ use crate::{
impl Predicate {
/// Generate an [ast::Expr::Binary] [Predicate] from a column and [SimValue]
pub fn from_column_binary<R: rand::Rng>(
pub fn from_column_binary<R: rand::Rng, C: GenerationContext>(
rng: &mut R,
context: &C,
column_name: &str,
value: &SimValue,
) -> Predicate {
@@ -32,7 +33,7 @@ impl Predicate {
)
}),
Box::new(|rng| {
let gt_value = GTValue::arbitrary_from(rng, value).0;
let gt_value = GTValue::arbitrary_from(rng, context, value).0;
Expr::Binary(
Box::new(Expr::Id(ast::Name::Ident(column_name.to_string()))),
ast::Operator::Greater,
@@ -40,7 +41,7 @@ impl Predicate {
)
}),
Box::new(|rng| {
let lt_value = LTValue::arbitrary_from(rng, value).0;
let lt_value = LTValue::arbitrary_from(rng, context, value).0;
Expr::Binary(
Box::new(Expr::Id(ast::Name::Ident(column_name.to_string()))),
ast::Operator::Less,
@@ -54,7 +55,12 @@ impl Predicate {
}
/// Produces an [ast::Expr::Binary] [Predicate] that is true for the provided row in the given table
pub fn true_binary<R: rand::Rng>(rng: &mut R, t: &Table, row: &[SimValue]) -> Predicate {
pub fn true_binary<R: rand::Rng, C: GenerationContext>(
rng: &mut R,
context: &C,
t: &Table,
row: &[SimValue],
) -> Predicate {
// Pick a column
let column_index = rng.random_range(0..t.columns.len());
let mut column = t.columns[column_index].clone();
@@ -93,7 +99,7 @@ impl Predicate {
(
1,
Box::new(|rng| {
let v = SimValue::arbitrary_from(rng, &column.column_type);
let v = SimValue::arbitrary_from(rng, context, &column.column_type);
if &v == value {
None
} else {
@@ -111,7 +117,7 @@ impl Predicate {
(
1,
Box::new(|rng| {
let lt_value = LTValue::arbitrary_from(rng, value).0;
let lt_value = LTValue::arbitrary_from(rng, context, value).0;
Some(Expr::Binary(
Box::new(ast::Expr::Qualified(
ast::Name::new(&table_name),
@@ -125,7 +131,7 @@ impl Predicate {
(
1,
Box::new(|rng| {
let gt_value = GTValue::arbitrary_from(rng, value).0;
let gt_value = GTValue::arbitrary_from(rng, context, value).0;
Some(Expr::Binary(
Box::new(ast::Expr::Qualified(
ast::Name::new(&table_name),
@@ -140,7 +146,7 @@ impl Predicate {
1,
Box::new(|rng| {
// TODO: generation for Like and Glob expressions should be extracted to a different module
LikeValue::arbitrary_from_maybe(rng, value).map(|like| {
LikeValue::arbitrary_from_maybe(rng, context, value).map(|like| {
Expr::Like {
lhs: Box::new(ast::Expr::Qualified(
ast::Name::new(&table_name),
@@ -162,7 +168,12 @@ impl Predicate {
}
/// Produces an [ast::Expr::Binary] [Predicate] that is false for the provided row in the given table
pub fn false_binary<R: rand::Rng>(rng: &mut R, t: &Table, row: &[SimValue]) -> Predicate {
pub fn false_binary<R: rand::Rng, C: GenerationContext>(
rng: &mut R,
context: &C,
t: &Table,
row: &[SimValue],
) -> Predicate {
// Pick a column
let column_index = rng.random_range(0..t.columns.len());
let mut column = t.columns[column_index].clone();
@@ -197,7 +208,7 @@ impl Predicate {
}),
Box::new(|rng| {
let v = loop {
let v = SimValue::arbitrary_from(rng, &column.column_type);
let v = SimValue::arbitrary_from(rng, context, &column.column_type);
if &v != value {
break v;
}
@@ -212,7 +223,7 @@ impl Predicate {
)
}),
Box::new(|rng| {
let gt_value = GTValue::arbitrary_from(rng, value).0;
let gt_value = GTValue::arbitrary_from(rng, context, value).0;
Expr::Binary(
Box::new(ast::Expr::Qualified(
ast::Name::new(&table_name),
@@ -223,7 +234,7 @@ impl Predicate {
)
}),
Box::new(|rng| {
let lt_value = LTValue::arbitrary_from(rng, value).0;
let lt_value = LTValue::arbitrary_from(rng, context, value).0;
Expr::Binary(
Box::new(ast::Expr::Qualified(
ast::Name::new(&table_name),
@@ -242,8 +253,9 @@ impl Predicate {
impl SimplePredicate {
/// Generates a true [ast::Expr::Binary] [SimplePredicate] from a [TableContext] for a row in the table
pub fn true_binary<R: rand::Rng, T: TableContext>(
pub fn true_binary<R: rand::Rng, C: GenerationContext, T: TableContext>(
rng: &mut R,
context: &C,
table: &T,
row: &[SimValue],
) -> Self {
@@ -271,7 +283,7 @@ impl SimplePredicate {
)
}),
Box::new(|rng| {
let lt_value = LTValue::arbitrary_from(rng, column_value).0;
let lt_value = LTValue::arbitrary_from(rng, context, column_value).0;
Expr::Binary(
Box::new(Expr::Qualified(
ast::Name::new(table_name),
@@ -282,7 +294,7 @@ impl SimplePredicate {
)
}),
Box::new(|rng| {
let gt_value = GTValue::arbitrary_from(rng, column_value).0;
let gt_value = GTValue::arbitrary_from(rng, context, column_value).0;
Expr::Binary(
Box::new(Expr::Qualified(
ast::Name::new(table_name),
@@ -299,8 +311,9 @@ impl SimplePredicate {
}
/// Generates a false [ast::Expr::Binary] [SimplePredicate] from a [TableContext] for a row in the table
pub fn false_binary<R: rand::Rng, T: TableContext>(
pub fn false_binary<R: rand::Rng, C: GenerationContext, T: TableContext>(
rng: &mut R,
context: &C,
table: &T,
row: &[SimValue],
) -> Self {
@@ -328,7 +341,7 @@ impl SimplePredicate {
)
}),
Box::new(|rng| {
let gt_value = GTValue::arbitrary_from(rng, column_value).0;
let gt_value = GTValue::arbitrary_from(rng, context, column_value).0;
Expr::Binary(
Box::new(ast::Expr::Qualified(
ast::Name::new(table_name),
@@ -339,7 +352,7 @@ impl SimplePredicate {
)
}),
Box::new(|rng| {
let lt_value = LTValue::arbitrary_from(rng, column_value).0;
let lt_value = LTValue::arbitrary_from(rng, context, column_value).0;
Expr::Binary(
Box::new(ast::Expr::Qualified(
ast::Name::new(table_name),
@@ -360,8 +373,9 @@ impl CompoundPredicate {
/// Decide if you want to create an AND or an OR
///
/// Creates a Compound Predicate that is TRUE or FALSE for at least a single row
pub fn from_table_binary<R: rand::Rng, T: TableContext>(
pub fn from_table_binary<R: rand::Rng, C: GenerationContext, T: TableContext>(
rng: &mut R,
context: &C,
table: &T,
predicate_value: bool,
) -> Self {
@@ -381,7 +395,7 @@ impl CompoundPredicate {
// An AND for false requires at least one of its children to be false
if predicate_value {
(0..rng.random_range(1..=3))
.map(|_| SimplePredicate::arbitrary_from(rng, (table, row, true)).0)
.map(|_| SimplePredicate::arbitrary_from(rng, context, (table, row, true)).0)
.reduce(|accum, curr| {
Predicate(Expr::Binary(
Box::new(accum.0),
@@ -405,7 +419,7 @@ impl CompoundPredicate {
booleans
.iter()
.map(|b| SimplePredicate::arbitrary_from(rng, (table, row, *b)).0)
.map(|b| SimplePredicate::arbitrary_from(rng, context, (table, row, *b)).0)
.reduce(|accum, curr| {
Predicate(Expr::Binary(
Box::new(accum.0),
@@ -431,7 +445,7 @@ impl CompoundPredicate {
booleans
.iter()
.map(|b| SimplePredicate::arbitrary_from(rng, (table, row, *b)).0)
.map(|b| SimplePredicate::arbitrary_from(rng, context, (table, row, *b)).0)
.reduce(|accum, curr| {
Predicate(Expr::Binary(
Box::new(accum.0),
@@ -442,7 +456,7 @@ impl CompoundPredicate {
.unwrap_or(Predicate::true_())
} else {
(0..rng.random_range(1..=3))
.map(|_| SimplePredicate::arbitrary_from(rng, (table, row, false)).0)
.map(|_| SimplePredicate::arbitrary_from(rng, context, (table, row, false)).0)
.reduce(|accum, curr| {
Predicate(Expr::Binary(
Box::new(accum.0),
@@ -463,7 +477,9 @@ mod tests {
use rand_chacha::ChaCha8Rng;
use crate::{
generation::{pick, predicate::SimplePredicate, Arbitrary, ArbitraryFrom as _},
generation::{
pick, predicate::SimplePredicate, tests::TestContext, Arbitrary, ArbitraryFrom as _,
},
model::{
query::predicate::{expr_to_value, Predicate},
table::{SimValue, Table},
@@ -481,20 +497,22 @@ mod tests {
fn fuzz_true_binary_predicate() {
let seed = get_seed();
let mut rng = ChaCha8Rng::seed_from_u64(seed);
let context = &TestContext::default();
for _ in 0..10000 {
let table = Table::arbitrary(&mut rng);
let table = Table::arbitrary(&mut rng, context);
let num_rows = rng.random_range(1..10);
let values: Vec<Vec<SimValue>> = (0..num_rows)
.map(|_| {
table
.columns
.iter()
.map(|c| SimValue::arbitrary_from(&mut rng, &c.column_type))
.map(|c| SimValue::arbitrary_from(&mut rng, context, &c.column_type))
.collect()
})
.collect();
let row = pick(&values, &mut rng);
let predicate = Predicate::true_binary(&mut rng, &table, row);
let predicate = Predicate::true_binary(&mut rng, context, &table, row);
let value = expr_to_value(&predicate.0, row, &table);
assert!(
value.as_ref().is_some_and(|value| value.as_bool()),
@@ -507,20 +525,22 @@ mod tests {
fn fuzz_false_binary_predicate() {
let seed = get_seed();
let mut rng = ChaCha8Rng::seed_from_u64(seed);
let context = &TestContext::default();
for _ in 0..10000 {
let table = Table::arbitrary(&mut rng);
let table = Table::arbitrary(&mut rng, context);
let num_rows = rng.random_range(1..10);
let values: Vec<Vec<SimValue>> = (0..num_rows)
.map(|_| {
table
.columns
.iter()
.map(|c| SimValue::arbitrary_from(&mut rng, &c.column_type))
.map(|c| SimValue::arbitrary_from(&mut rng, context, &c.column_type))
.collect()
})
.collect();
let row = pick(&values, &mut rng);
let predicate = Predicate::false_binary(&mut rng, &table, row);
let predicate = Predicate::false_binary(&mut rng, context, &table, row);
let value = expr_to_value(&predicate.0, row, &table);
assert!(
!value.as_ref().is_some_and(|value| value.as_bool()),
@@ -533,21 +553,23 @@ mod tests {
fn fuzz_true_binary_simple_predicate() {
let seed = get_seed();
let mut rng = ChaCha8Rng::seed_from_u64(seed);
let context = &TestContext::default();
for _ in 0..10000 {
let mut table = Table::arbitrary(&mut rng);
let mut table = Table::arbitrary(&mut rng, context);
let num_rows = rng.random_range(1..10);
let values: Vec<Vec<SimValue>> = (0..num_rows)
.map(|_| {
table
.columns
.iter()
.map(|c| SimValue::arbitrary_from(&mut rng, &c.column_type))
.map(|c| SimValue::arbitrary_from(&mut rng, context, &c.column_type))
.collect()
})
.collect();
table.rows.extend(values.clone());
let row = pick(&table.rows, &mut rng);
let predicate = SimplePredicate::true_binary(&mut rng, &table, row);
let predicate = SimplePredicate::true_binary(&mut rng, context, &table, row);
let result = values
.iter()
.map(|row| predicate.0.test(row, &table))
@@ -561,21 +583,23 @@ mod tests {
fn fuzz_false_binary_simple_predicate() {
let seed = get_seed();
let mut rng = ChaCha8Rng::seed_from_u64(seed);
let context = &TestContext::default();
for _ in 0..10000 {
let mut table = Table::arbitrary(&mut rng);
let mut table = Table::arbitrary(&mut rng, context);
let num_rows = rng.random_range(1..10);
let values: Vec<Vec<SimValue>> = (0..num_rows)
.map(|_| {
table
.columns
.iter()
.map(|c| SimValue::arbitrary_from(&mut rng, &c.column_type))
.map(|c| SimValue::arbitrary_from(&mut rng, context, &c.column_type))
.collect()
})
.collect();
table.rows.extend(values.clone());
let row = pick(&table.rows, &mut rng);
let predicate = SimplePredicate::false_binary(&mut rng, &table, row);
let predicate = SimplePredicate::false_binary(&mut rng, context, &table, row);
let result = values
.iter()
.map(|row| predicate.0.test(row, &table))

View File

@@ -1,9 +1,12 @@
use rand::{seq::SliceRandom as _, Rng};
use turso_parser::ast::{self, Expr};
use crate::model::{
query::predicate::Predicate,
table::{SimValue, Table, TableContext},
use crate::{
generation::GenerationContext,
model::{
query::predicate::Predicate,
table::{SimValue, Table, TableContext},
},
};
use super::{one_of, ArbitraryFrom};
@@ -18,20 +21,24 @@ struct CompoundPredicate(Predicate);
struct SimplePredicate(Predicate);
impl<A: AsRef<[SimValue]>, T: TableContext> ArbitraryFrom<(&T, A, bool)> for SimplePredicate {
fn arbitrary_from<R: Rng>(rng: &mut R, (table, row, predicate_value): (&T, A, bool)) -> Self {
fn arbitrary_from<R: Rng, C: GenerationContext>(
rng: &mut R,
context: &C,
(table, row, predicate_value): (&T, A, bool),
) -> Self {
let row = row.as_ref();
// Pick an operator
let choice = rng.random_range(0..2);
match predicate_value {
true => match choice {
0 => SimplePredicate::true_binary(rng, table, row),
1 => SimplePredicate::true_unary(rng, table, row),
0 => SimplePredicate::true_binary(rng, context, table, row),
1 => SimplePredicate::true_unary(rng, context, table, row),
_ => unreachable!(),
},
false => match choice {
0 => SimplePredicate::false_binary(rng, table, row),
1 => SimplePredicate::false_unary(rng, table, row),
0 => SimplePredicate::false_binary(rng, context, table, row),
1 => SimplePredicate::false_unary(rng, context, table, row),
_ => unreachable!(),
},
}
@@ -39,43 +46,59 @@ impl<A: AsRef<[SimValue]>, T: TableContext> ArbitraryFrom<(&T, A, bool)> for Sim
}
impl<T: TableContext> ArbitraryFrom<(&T, bool)> for CompoundPredicate {
fn arbitrary_from<R: Rng>(rng: &mut R, (table, predicate_value): (&T, bool)) -> Self {
CompoundPredicate::from_table_binary(rng, table, predicate_value)
fn arbitrary_from<R: Rng, C: GenerationContext>(
rng: &mut R,
context: &C,
(table, predicate_value): (&T, bool),
) -> Self {
CompoundPredicate::from_table_binary(rng, context, table, predicate_value)
}
}
impl<T: TableContext> ArbitraryFrom<&T> for Predicate {
fn arbitrary_from<R: Rng>(rng: &mut R, table: &T) -> Self {
fn arbitrary_from<R: Rng, C: GenerationContext>(rng: &mut R, context: &C, table: &T) -> Self {
let predicate_value = rng.random_bool(0.5);
Predicate::arbitrary_from(rng, (table, predicate_value)).parens()
Predicate::arbitrary_from(rng, context, (table, predicate_value)).parens()
}
}
impl<T: TableContext> ArbitraryFrom<(&T, bool)> for Predicate {
fn arbitrary_from<R: Rng>(rng: &mut R, (table, predicate_value): (&T, bool)) -> Self {
CompoundPredicate::arbitrary_from(rng, (table, predicate_value)).0
fn arbitrary_from<R: Rng, C: GenerationContext>(
rng: &mut R,
context: &C,
(table, predicate_value): (&T, bool),
) -> Self {
CompoundPredicate::arbitrary_from(rng, context, (table, predicate_value)).0
}
}
impl ArbitraryFrom<(&str, &SimValue)> for Predicate {
fn arbitrary_from<R: Rng>(rng: &mut R, (column_name, value): (&str, &SimValue)) -> Self {
Predicate::from_column_binary(rng, column_name, value)
fn arbitrary_from<R: Rng, C: GenerationContext>(
rng: &mut R,
context: &C,
(column_name, value): (&str, &SimValue),
) -> Self {
Predicate::from_column_binary(rng, context, column_name, value)
}
}
impl ArbitraryFrom<(&Table, &Vec<SimValue>)> for Predicate {
fn arbitrary_from<R: Rng>(rng: &mut R, (t, row): (&Table, &Vec<SimValue>)) -> Self {
fn arbitrary_from<R: Rng, C: GenerationContext>(
rng: &mut R,
context: &C,
(t, row): (&Table, &Vec<SimValue>),
) -> Self {
// We want to produce a predicate that is true for the row
// We can do this by creating several predicates that
// are true, some that are false, and combining them in ways that correspond to the creation of a true predicate
// Produce some true and false predicates
let mut true_predicates = (1..=rng.random_range(1..=4))
.map(|_| Predicate::true_binary(rng, t, row))
.map(|_| Predicate::true_binary(rng, context, t, row))
.collect::<Vec<_>>();
let false_predicates = (0..=rng.random_range(0..=3))
.map(|_| Predicate::false_binary(rng, t, row))
.map(|_| Predicate::false_binary(rng, context, t, row))
.collect::<Vec<_>>();
// Start building a top level predicate from a true predicate
@@ -231,7 +254,9 @@ mod tests {
use rand_chacha::ChaCha8Rng;
use crate::{
generation::{pick, predicate::SimplePredicate, Arbitrary, ArbitraryFrom as _},
generation::{
pick, predicate::SimplePredicate, tests::TestContext, Arbitrary, ArbitraryFrom as _,
},
model::{
query::predicate::{expr_to_value, Predicate},
table::{SimValue, Table},
@@ -249,20 +274,23 @@ mod tests {
fn fuzz_arbitrary_table_true_simple_predicate() {
let seed = get_seed();
let mut rng = ChaCha8Rng::seed_from_u64(seed);
let context = &TestContext::default();
for _ in 0..10000 {
let table = Table::arbitrary(&mut rng);
let table = Table::arbitrary(&mut rng, context);
let num_rows = rng.random_range(1..10);
let values: Vec<Vec<SimValue>> = (0..num_rows)
.map(|_| {
table
.columns
.iter()
.map(|c| SimValue::arbitrary_from(&mut rng, &c.column_type))
.map(|c| SimValue::arbitrary_from(&mut rng, context, &c.column_type))
.collect()
})
.collect();
let row = pick(&values, &mut rng);
let predicate = SimplePredicate::arbitrary_from(&mut rng, (&table, row, true)).0;
let predicate =
SimplePredicate::arbitrary_from(&mut rng, context, (&table, row, true)).0;
let value = expr_to_value(&predicate.0, row, &table);
assert!(
value.as_ref().is_some_and(|value| value.as_bool()),
@@ -275,20 +303,23 @@ mod tests {
fn fuzz_arbitrary_table_false_simple_predicate() {
let seed = get_seed();
let mut rng = ChaCha8Rng::seed_from_u64(seed);
let context = &TestContext::default();
for _ in 0..10000 {
let table = Table::arbitrary(&mut rng);
let table = Table::arbitrary(&mut rng, context);
let num_rows = rng.random_range(1..10);
let values: Vec<Vec<SimValue>> = (0..num_rows)
.map(|_| {
table
.columns
.iter()
.map(|c| SimValue::arbitrary_from(&mut rng, &c.column_type))
.map(|c| SimValue::arbitrary_from(&mut rng, context, &c.column_type))
.collect()
})
.collect();
let row = pick(&values, &mut rng);
let predicate = SimplePredicate::arbitrary_from(&mut rng, (&table, row, false)).0;
let predicate =
SimplePredicate::arbitrary_from(&mut rng, context, (&table, row, false)).0;
let value = expr_to_value(&predicate.0, row, &table);
assert!(
!value.as_ref().is_some_and(|value| value.as_bool()),
@@ -301,20 +332,22 @@ mod tests {
fn fuzz_arbitrary_row_table_predicate() {
let seed = get_seed();
let mut rng = ChaCha8Rng::seed_from_u64(seed);
let context = &TestContext::default();
for _ in 0..10000 {
let table = Table::arbitrary(&mut rng);
let table = Table::arbitrary(&mut rng, context);
let num_rows = rng.random_range(1..10);
let values: Vec<Vec<SimValue>> = (0..num_rows)
.map(|_| {
table
.columns
.iter()
.map(|c| SimValue::arbitrary_from(&mut rng, &c.column_type))
.map(|c| SimValue::arbitrary_from(&mut rng, context, &c.column_type))
.collect()
})
.collect();
let row = pick(&values, &mut rng);
let predicate = Predicate::arbitrary_from(&mut rng, (&table, row));
let predicate = Predicate::arbitrary_from(&mut rng, context, (&table, row));
let value = expr_to_value(&predicate.0, row, &table);
assert!(
value.as_ref().is_some_and(|value| value.as_bool()),
@@ -327,20 +360,22 @@ mod tests {
fn fuzz_arbitrary_true_table_predicate() {
let seed = get_seed();
let mut rng = ChaCha8Rng::seed_from_u64(seed);
let context = &TestContext::default();
for _ in 0..10000 {
let mut table = Table::arbitrary(&mut rng);
let mut table = Table::arbitrary(&mut rng, context);
let num_rows = rng.random_range(1..10);
let values: Vec<Vec<SimValue>> = (0..num_rows)
.map(|_| {
table
.columns
.iter()
.map(|c| SimValue::arbitrary_from(&mut rng, &c.column_type))
.map(|c| SimValue::arbitrary_from(&mut rng, context, &c.column_type))
.collect()
})
.collect();
table.rows.extend(values.clone());
let predicate = Predicate::arbitrary_from(&mut rng, (&table, true));
let predicate = Predicate::arbitrary_from(&mut rng, context, (&table, true));
let result = values
.iter()
.map(|row| predicate.test(row, &table))
@@ -354,20 +389,22 @@ mod tests {
fn fuzz_arbitrary_false_table_predicate() {
let seed = get_seed();
let mut rng = ChaCha8Rng::seed_from_u64(seed);
let context = &TestContext::default();
for _ in 0..10000 {
let mut table = Table::arbitrary(&mut rng);
let mut table = Table::arbitrary(&mut rng, context);
let num_rows = rng.random_range(1..10);
let values: Vec<Vec<SimValue>> = (0..num_rows)
.map(|_| {
table
.columns
.iter()
.map(|c| SimValue::arbitrary_from(&mut rng, &c.column_type))
.map(|c| SimValue::arbitrary_from(&mut rng, context, &c.column_type))
.collect()
})
.collect();
table.rows.extend(values.clone());
let predicate = Predicate::arbitrary_from(&mut rng, (&table, false));
let predicate = Predicate::arbitrary_from(&mut rng, context, (&table, false));
let result = values
.iter()
.map(|row| predicate.test(row, &table))

View File

@@ -5,7 +5,9 @@
use turso_parser::ast::{self, Expr};
use crate::{
generation::{backtrack, pick, predicate::SimplePredicate, ArbitraryFromMaybe},
generation::{
backtrack, pick, predicate::SimplePredicate, ArbitraryFromMaybe, GenerationContext,
},
model::{
query::predicate::Predicate,
table::{SimValue, TableContext},
@@ -15,7 +17,11 @@ use crate::{
pub struct TrueValue(pub SimValue);
impl ArbitraryFromMaybe<&SimValue> for TrueValue {
fn arbitrary_from_maybe<R: rand::Rng>(_rng: &mut R, value: &SimValue) -> Option<Self>
fn arbitrary_from_maybe<R: rand::Rng, C: GenerationContext>(
_rng: &mut R,
_context: &C,
value: &SimValue,
) -> Option<Self>
where
Self: Sized,
{
@@ -25,7 +31,11 @@ impl ArbitraryFromMaybe<&SimValue> for TrueValue {
}
impl ArbitraryFromMaybe<&Vec<&SimValue>> for TrueValue {
fn arbitrary_from_maybe<R: rand::Rng>(rng: &mut R, values: &Vec<&SimValue>) -> Option<Self>
fn arbitrary_from_maybe<R: rand::Rng, C: GenerationContext>(
rng: &mut R,
context: &C,
values: &Vec<&SimValue>,
) -> Option<Self>
where
Self: Sized,
{
@@ -34,14 +44,18 @@ impl ArbitraryFromMaybe<&Vec<&SimValue>> for TrueValue {
}
let value = pick(values, rng);
Self::arbitrary_from_maybe(rng, *value)
Self::arbitrary_from_maybe(rng, context, *value)
}
}
pub struct FalseValue(pub SimValue);
impl ArbitraryFromMaybe<&SimValue> for FalseValue {
fn arbitrary_from_maybe<R: rand::Rng>(_rng: &mut R, value: &SimValue) -> Option<Self>
fn arbitrary_from_maybe<R: rand::Rng, C: GenerationContext>(
_rng: &mut R,
_context: &C,
value: &SimValue,
) -> Option<Self>
where
Self: Sized,
{
@@ -51,7 +65,11 @@ impl ArbitraryFromMaybe<&SimValue> for FalseValue {
}
impl ArbitraryFromMaybe<&Vec<&SimValue>> for FalseValue {
fn arbitrary_from_maybe<R: rand::Rng>(rng: &mut R, values: &Vec<&SimValue>) -> Option<Self>
fn arbitrary_from_maybe<R: rand::Rng, C: GenerationContext>(
rng: &mut R,
context: &C,
values: &Vec<&SimValue>,
) -> Option<Self>
where
Self: Sized,
{
@@ -60,7 +78,7 @@ impl ArbitraryFromMaybe<&Vec<&SimValue>> for FalseValue {
}
let value = pick(values, rng);
Self::arbitrary_from_maybe(rng, *value)
Self::arbitrary_from_maybe(rng, context, *value)
}
}
@@ -68,8 +86,9 @@ impl ArbitraryFromMaybe<&Vec<&SimValue>> for FalseValue {
pub struct BitNotValue(pub SimValue);
impl ArbitraryFromMaybe<(&SimValue, bool)> for BitNotValue {
fn arbitrary_from_maybe<R: rand::Rng>(
fn arbitrary_from_maybe<R: rand::Rng, C: GenerationContext>(
_rng: &mut R,
_context: &C,
(value, predicate): (&SimValue, bool),
) -> Option<Self>
where
@@ -82,8 +101,9 @@ impl ArbitraryFromMaybe<(&SimValue, bool)> for BitNotValue {
}
impl ArbitraryFromMaybe<(&Vec<&SimValue>, bool)> for BitNotValue {
fn arbitrary_from_maybe<R: rand::Rng>(
fn arbitrary_from_maybe<R: rand::Rng, C: GenerationContext>(
rng: &mut R,
context: &C,
(values, predicate): (&Vec<&SimValue>, bool),
) -> Option<Self>
where
@@ -94,15 +114,16 @@ impl ArbitraryFromMaybe<(&Vec<&SimValue>, bool)> for BitNotValue {
}
let value = pick(values, rng);
Self::arbitrary_from_maybe(rng, (*value, predicate))
Self::arbitrary_from_maybe(rng, context, (*value, predicate))
}
}
// TODO: have some more complex generation with column names here as well
impl SimplePredicate {
/// Generates a true [ast::Expr::Unary] [SimplePredicate] from a [TableContext] for some values in the table
pub fn true_unary<R: rand::Rng, T: TableContext>(
pub fn true_unary<R: rand::Rng, C: GenerationContext, T: TableContext>(
rng: &mut R,
context: &C,
table: &T,
row: &[SimValue],
) -> Self {
@@ -120,7 +141,7 @@ impl SimplePredicate {
(
num_retries,
Box::new(|rng| {
TrueValue::arbitrary_from_maybe(rng, column_value).map(|value| {
TrueValue::arbitrary_from_maybe(rng, context, column_value).map(|value| {
assert!(value.0.as_bool());
// Positive is a no-op in SQLite
Expr::unary(ast::UnaryOperator::Positive, Expr::Literal(value.0.into()))
@@ -151,7 +172,7 @@ impl SimplePredicate {
(
num_retries,
Box::new(|rng| {
FalseValue::arbitrary_from_maybe(rng, column_value).map(|value| {
FalseValue::arbitrary_from_maybe(rng, context, column_value).map(|value| {
assert!(!value.0.as_bool());
Expr::unary(ast::UnaryOperator::Not, Expr::Literal(value.0.into()))
})
@@ -167,8 +188,9 @@ impl SimplePredicate {
}
/// Generates a false [ast::Expr::Unary] [SimplePredicate] from a [TableContext] for a row in the table
pub fn false_unary<R: rand::Rng, T: TableContext>(
pub fn false_unary<R: rand::Rng, C: GenerationContext, T: TableContext>(
rng: &mut R,
context: &C,
table: &T,
row: &[SimValue],
) -> Self {
@@ -217,7 +239,7 @@ impl SimplePredicate {
(
num_retries,
Box::new(|rng| {
TrueValue::arbitrary_from_maybe(rng, column_value).map(|value| {
TrueValue::arbitrary_from_maybe(rng, context, column_value).map(|value| {
assert!(value.0.as_bool());
Expr::unary(ast::UnaryOperator::Not, Expr::Literal(value.0.into()))
})
@@ -239,7 +261,9 @@ mod tests {
use rand_chacha::ChaCha8Rng;
use crate::{
generation::{pick, predicate::SimplePredicate, Arbitrary, ArbitraryFrom as _},
generation::{
pick, predicate::SimplePredicate, tests::TestContext, Arbitrary, ArbitraryFrom as _,
},
model::table::{SimValue, Table},
};
@@ -254,21 +278,23 @@ mod tests {
fn fuzz_true_unary_simple_predicate() {
let seed = get_seed();
let mut rng = ChaCha8Rng::seed_from_u64(seed);
let context = &TestContext::default();
for _ in 0..10000 {
let mut table = Table::arbitrary(&mut rng);
let mut table = Table::arbitrary(&mut rng, context);
let num_rows = rng.random_range(1..10);
let values: Vec<Vec<SimValue>> = (0..num_rows)
.map(|_| {
table
.columns
.iter()
.map(|c| SimValue::arbitrary_from(&mut rng, &c.column_type))
.map(|c| SimValue::arbitrary_from(&mut rng, context, &c.column_type))
.collect()
})
.collect();
table.rows.extend(values.clone());
let row = pick(&table.rows, &mut rng);
let predicate = SimplePredicate::true_unary(&mut rng, &table, row);
let predicate = SimplePredicate::true_unary(&mut rng, context, &table, row);
let result = values
.iter()
.map(|row| predicate.0.test(row, &table))
@@ -282,21 +308,23 @@ mod tests {
fn fuzz_false_unary_simple_predicate() {
let seed = get_seed();
let mut rng = ChaCha8Rng::seed_from_u64(seed);
let context = &TestContext::default();
for _ in 0..10000 {
let mut table = Table::arbitrary(&mut rng);
let mut table = Table::arbitrary(&mut rng, context);
let num_rows = rng.random_range(1..10);
let values: Vec<Vec<SimValue>> = (0..num_rows)
.map(|_| {
table
.columns
.iter()
.map(|c| SimValue::arbitrary_from(&mut rng, &c.column_type))
.map(|c| SimValue::arbitrary_from(&mut rng, context, &c.column_type))
.collect()
})
.collect();
table.rows.extend(values.clone());
let row = pick(&table.rows, &mut rng);
let predicate = SimplePredicate::false_unary(&mut rng, &table, row);
let predicate = SimplePredicate::false_unary(&mut rng, context, &table, row);
let result = values
.iter()
.map(|row| predicate.0.test(row, &table))

View File

@@ -1,5 +1,5 @@
use crate::generation::{
gen_random_text, pick_n_unique, pick_unique, Arbitrary, ArbitraryFrom, ArbitrarySizedFrom,
gen_random_text, pick_n_unique, pick_unique, Arbitrary, ArbitraryFrom, ArbitrarySized,
GenerationContext,
};
use crate::model::query::predicate::Predicate;
@@ -17,23 +17,20 @@ use turso_parser::ast::{Expr, SortOrder};
use super::{backtrack, pick};
impl Arbitrary for Create {
fn arbitrary<R: Rng>(rng: &mut R) -> Self {
fn arbitrary<R: Rng, C: GenerationContext>(rng: &mut R, context: &C) -> Self {
Create {
table: Table::arbitrary(rng),
table: Table::arbitrary(rng, context),
}
}
}
impl ArbitraryFrom<&Vec<Table>> for FromClause {
fn arbitrary_from<R: Rng>(rng: &mut R, tables: &Vec<Table>) -> Self {
let num_joins = match rng.random_range(0..=100) {
0..=90 => 0,
91..=97 => 1,
98..=100 => 2,
_ => unreachable!(),
};
impl Arbitrary for FromClause {
fn arbitrary<R: Rng, C: GenerationContext>(rng: &mut R, context: &C) -> Self {
let opts = &context.opts().query.from_clause;
let weights = opts.as_weighted_index();
let num_joins = opts.joins[rng.sample(weights)].num_joins;
let mut tables = tables.clone();
let mut tables = context.tables().clone();
let mut table = pick(&tables, rng).clone();
tables.retain(|t| t.name != table.name);
@@ -74,7 +71,7 @@ impl ArbitraryFrom<&Vec<Table>> for FromClause {
);
}
let predicate = Predicate::arbitrary_from(rng, &table);
let predicate = Predicate::arbitrary_from(rng, context, &table);
Some(JoinedTable {
table: joined_table_name,
join_type: JoinType::Inner,
@@ -86,31 +83,32 @@ impl ArbitraryFrom<&Vec<Table>> for FromClause {
}
}
impl<C: GenerationContext> ArbitraryFrom<&C> for SelectInner {
fn arbitrary_from<R: Rng>(rng: &mut R, env: &C) -> Self {
let from = FromClause::arbitrary_from(rng, env.tables());
impl Arbitrary for SelectInner {
fn arbitrary<R: Rng, C: GenerationContext>(rng: &mut R, env: &C) -> Self {
let from = FromClause::arbitrary(rng, env);
let tables = env.tables().clone();
let join_table = from.into_join_table(&tables);
let cuml_col_count = join_table.columns().count();
let order_by = 'order_by: {
if rng.random_bool(0.3) {
let order_by = rng
.random_bool(env.opts().query.select.order_by_prob)
.then(|| {
let order_by_table_candidates = from
.joins
.iter()
.map(|j| j.table.clone())
.chain(std::iter::once(from.table.clone()))
.map(|j| &j.table)
.chain(std::iter::once(&from.table))
.collect::<Vec<_>>();
let order_by_col_count =
(rng.random::<f64>() * rng.random::<f64>() * (cuml_col_count as f64)) as usize; // skew towards 0
if order_by_col_count == 0 {
break 'order_by None;
return None;
}
let mut col_names = std::collections::HashSet::new();
let mut order_by_cols = Vec::new();
while order_by_cols.len() < order_by_col_count {
let table = pick(&order_by_table_candidates, rng);
let table = tables.iter().find(|t| t.name == *table).unwrap();
let table = tables.iter().find(|t| t.name == table.as_str()).unwrap();
let col = pick(&table.columns, rng);
let col_name = format!("{}.{}", table.name, col.name);
if col_names.insert(col_name.clone()) {
@@ -127,38 +125,38 @@ impl<C: GenerationContext> ArbitraryFrom<&C> for SelectInner {
Some(OrderBy {
columns: order_by_cols,
})
} else {
None
}
};
})
.flatten();
SelectInner {
distinctness: if env.opts().indexes {
Distinctness::arbitrary(rng)
Distinctness::arbitrary(rng, env)
} else {
Distinctness::All
},
columns: vec![ResultColumn::Star],
from: Some(from),
where_clause: Predicate::arbitrary_from(rng, &join_table),
where_clause: Predicate::arbitrary_from(rng, env, &join_table),
order_by,
}
}
}
impl<C: GenerationContext> ArbitrarySizedFrom<&C> for SelectInner {
fn arbitrary_sized_from<R: Rng>(rng: &mut R, env: &C, num_result_columns: usize) -> Self {
let mut select_inner = SelectInner::arbitrary_from(rng, env);
impl ArbitrarySized for SelectInner {
fn arbitrary_sized<R: Rng, C: GenerationContext>(
rng: &mut R,
env: &C,
num_result_columns: usize,
) -> Self {
let mut select_inner = SelectInner::arbitrary(rng, env);
let select_from = &select_inner.from.as_ref().unwrap();
let table_names = select_from
.joins
.iter()
.map(|j| j.table.clone())
.chain(std::iter::once(select_from.table.clone()))
.collect::<Vec<_>>();
.map(|j| &j.table)
.chain(std::iter::once(&select_from.table));
let flat_columns_names = table_names
.iter()
.flat_map(|t| {
env.tables()
.iter()
@@ -166,29 +164,30 @@ impl<C: GenerationContext> ArbitrarySizedFrom<&C> for SelectInner {
.unwrap()
.columns
.iter()
.map(|c| format!("{}.{}", t.clone(), c.name))
.map(move |c| format!("{}.{}", t, c.name))
})
.collect::<Vec<_>>();
let selected_columns = pick_unique(&flat_columns_names, num_result_columns, rng);
let mut columns = Vec::new();
for column_name in selected_columns {
columns.push(ResultColumn::Column(column_name.clone()));
}
let columns = selected_columns
.map(|col_name| ResultColumn::Column(col_name.clone()))
.collect();
select_inner.columns = columns;
select_inner
}
}
impl Arbitrary for Distinctness {
fn arbitrary<R: Rng>(rng: &mut R) -> Self {
fn arbitrary<R: Rng, C: GenerationContext>(rng: &mut R, _context: &C) -> Self {
match rng.random_range(0..=5) {
0..4 => Distinctness::All,
_ => Distinctness::Distinct,
}
}
}
impl Arbitrary for CompoundOperator {
fn arbitrary<R: Rng>(rng: &mut R) -> Self {
fn arbitrary<R: Rng, C: GenerationContext>(rng: &mut R, _context: &C) -> Self {
match rng.random_range(0..=1) {
0 => CompoundOperator::Union,
1 => CompoundOperator::UnionAll,
@@ -202,26 +201,23 @@ impl Arbitrary for CompoundOperator {
/// arbitrary expressions without referring to the tables.
pub struct SelectFree(pub Select);
impl<C: GenerationContext> ArbitraryFrom<&C> for SelectFree {
fn arbitrary_from<R: Rng>(rng: &mut R, env: &C) -> Self {
let expr = Predicate(Expr::arbitrary_sized_from(rng, env, 8));
impl Arbitrary for SelectFree {
fn arbitrary<R: Rng, C: GenerationContext>(rng: &mut R, env: &C) -> Self {
let expr = Predicate(Expr::arbitrary_sized(rng, env, 8));
let select = Select::expr(expr);
Self(select)
}
}
impl<C: GenerationContext> ArbitraryFrom<&C> for Select {
fn arbitrary_from<R: Rng>(rng: &mut R, env: &C) -> Self {
impl Arbitrary for Select {
fn arbitrary<R: Rng, C: GenerationContext>(rng: &mut R, env: &C) -> Self {
// Generate a number of selects based on the query size
// If experimental indexes are enabled, we can have selects with compounds
// Otherwise, we just have a single select with no compounds
let opts = &env.opts().query.select;
let num_compound_selects = if env.opts().indexes {
match rng.random_range(0..=100) {
0..=95 => 0,
96..=99 => 1,
100 => 2,
_ => unreachable!(),
}
opts.compound_selects[rng.sample(opts.compound_select_weighted_index())]
.num_compound_selects
} else {
0
};
@@ -231,10 +227,10 @@ impl<C: GenerationContext> ArbitraryFrom<&C> for Select {
let num_result_columns = rng.random_range(1..=min_column_count_across_tables);
let mut first = SelectInner::arbitrary_sized_from(rng, env, num_result_columns);
let mut first = SelectInner::arbitrary_sized(rng, env, num_result_columns);
let mut rest: Vec<SelectInner> = (0..num_compound_selects)
.map(|_| SelectInner::arbitrary_sized_from(rng, env, num_result_columns))
.map(|_| SelectInner::arbitrary_sized(rng, env, num_result_columns))
.collect();
if !rest.is_empty() {
@@ -251,7 +247,7 @@ impl<C: GenerationContext> ArbitraryFrom<&C> for Select {
compounds: rest
.into_iter()
.map(|s| CompoundSelect {
operator: CompoundOperator::arbitrary(rng),
operator: CompoundOperator::arbitrary(rng, env),
select: Box::new(s),
})
.collect(),
@@ -261,17 +257,18 @@ impl<C: GenerationContext> ArbitraryFrom<&C> for Select {
}
}
impl<C: GenerationContext> ArbitraryFrom<&C> for Insert {
fn arbitrary_from<R: Rng>(rng: &mut R, env: &C) -> Self {
impl Arbitrary for Insert {
fn arbitrary<R: Rng, C: GenerationContext>(rng: &mut R, env: &C) -> Self {
let opts = &env.opts().query.insert;
let gen_values = |rng: &mut R| {
let table = pick(env.tables(), rng);
let num_rows = rng.random_range(1..10);
let num_rows = rng.random_range(opts.min_rows.get()..opts.max_rows.get());
let values: Vec<Vec<SimValue>> = (0..num_rows)
.map(|_| {
table
.columns
.iter()
.map(|c| SimValue::arbitrary_from(rng, &c.column_type))
.map(|c| SimValue::arbitrary_from(rng, env, &c.column_type))
.collect()
})
.collect();
@@ -285,7 +282,7 @@ impl<C: GenerationContext> ArbitraryFrom<&C> for Insert {
// Find a non-empty table
let select_table = env.tables().iter().find(|t| !t.rows.is_empty())?;
let row = pick(&select_table.rows, rng);
let predicate = Predicate::arbitrary_from(rng, (select_table, row));
let predicate = Predicate::arbitrary_from(rng, env, (select_table, row));
// Pick another table to insert into
let select = Select::simple(select_table.name.clone(), predicate);
let table = pick(env.tables(), rng);
@@ -301,18 +298,18 @@ impl<C: GenerationContext> ArbitraryFrom<&C> for Insert {
}
}
impl<C: GenerationContext> ArbitraryFrom<&C> for Delete {
fn arbitrary_from<R: Rng>(rng: &mut R, env: &C) -> Self {
impl Arbitrary for Delete {
fn arbitrary<R: Rng, C: GenerationContext>(rng: &mut R, env: &C) -> Self {
let table = pick(env.tables(), rng);
Self {
table: table.name.clone(),
predicate: Predicate::arbitrary_from(rng, table),
predicate: Predicate::arbitrary_from(rng, env, table),
}
}
}
impl<C: GenerationContext> ArbitraryFrom<&C> for Drop {
fn arbitrary_from<R: Rng>(rng: &mut R, env: &C) -> Self {
impl Arbitrary for Drop {
fn arbitrary<R: Rng, C: GenerationContext>(rng: &mut R, env: &C) -> Self {
let table = pick(env.tables(), rng);
Self {
table: table.name.clone(),
@@ -320,8 +317,8 @@ impl<C: GenerationContext> ArbitraryFrom<&C> for Drop {
}
}
impl<C: GenerationContext> ArbitraryFrom<&C> for CreateIndex {
fn arbitrary_from<R: Rng>(rng: &mut R, env: &C) -> Self {
impl Arbitrary for CreateIndex {
fn arbitrary<R: Rng, C: GenerationContext>(rng: &mut R, env: &C) -> Self {
assert!(
!env.tables().is_empty(),
"Cannot create an index when no tables exist in the environment."
@@ -340,7 +337,6 @@ impl<C: GenerationContext> ArbitraryFrom<&C> for CreateIndex {
let picked_column_indices = pick_n_unique(0..table.columns.len(), num_columns_to_pick, rng);
let columns = picked_column_indices
.into_iter()
.map(|i| {
let column = &table.columns[i];
(
@@ -368,24 +364,23 @@ impl<C: GenerationContext> ArbitraryFrom<&C> for CreateIndex {
}
}
impl<C: GenerationContext> ArbitraryFrom<&C> for Update {
fn arbitrary_from<R: Rng>(rng: &mut R, env: &C) -> Self {
impl Arbitrary for Update {
fn arbitrary<R: Rng, C: GenerationContext>(rng: &mut R, env: &C) -> Self {
let table = pick(env.tables(), rng);
let num_cols = rng.random_range(1..=table.columns.len());
let columns = pick_unique(&table.columns, num_cols, rng);
let set_values: Vec<(String, SimValue)> = columns
.iter()
.map(|column| {
(
column.name.clone(),
SimValue::arbitrary_from(rng, &column.column_type),
SimValue::arbitrary_from(rng, env, &column.column_type),
)
})
.collect();
Update {
table: table.name.clone(),
set_values,
predicate: Predicate::arbitrary_from(rng, table),
predicate: Predicate::arbitrary_from(rng, env, table),
}
}
}
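
The shape of the refactor running through this file: Arbitrary and ArbitraryFrom now thread a GenerationContext so every generator can reach the configured options and tables. A minimal sketch under assumed trait definitions (the real trait also exposes opts(), and returns model tables rather than strings):

use rand::Rng;

trait GenerationContext {
    fn tables(&self) -> &[String]; // assumption: simplified stand-in for the crate's table model
}

trait Arbitrary {
    fn arbitrary<R: Rng, C: GenerationContext>(rng: &mut R, context: &C) -> Self;
}

struct TableRef(String);

impl Arbitrary for TableRef {
    fn arbitrary<R: Rng, C: GenerationContext>(rng: &mut R, context: &C) -> Self {
        // the context parameter replaces the old ad-hoc &Vec<Table> arguments
        let tables = context.tables();
        let i = rng.random_range(0..tables.len());
        TableRef(tables[i].clone())
    }
}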

View File

@@ -3,54 +3,52 @@ use std::collections::HashSet;
use rand::Rng;
use turso_core::Value;
use crate::generation::{gen_random_text, pick, readable_name_custom, Arbitrary, ArbitraryFrom};
use crate::generation::{
gen_random_text, pick, readable_name_custom, Arbitrary, ArbitraryFrom, GenerationContext,
};
use crate::model::table::{Column, ColumnType, Name, SimValue, Table};
use super::ArbitraryFromMaybe;
impl Arbitrary for Name {
fn arbitrary<R: Rng>(rng: &mut R) -> Self {
fn arbitrary<R: Rng, C: GenerationContext>(rng: &mut R, _c: &C) -> Self {
let name = readable_name_custom("_", rng);
Name(name.replace("-", "_"))
}
}
impl Arbitrary for Table {
fn arbitrary<R: Rng>(rng: &mut R) -> Self {
let name = Name::arbitrary(rng).0;
let columns = loop {
let large_table = rng.random_bool(0.1);
let column_size = if large_table {
rng.random_range(64..125) // todo: make this higher (128+)
} else {
rng.random_range(1..=10)
};
let columns = (1..=column_size)
.map(|_| Column::arbitrary(rng))
.collect::<Vec<_>>();
// TODO: see if there is a better way to detect duplicates here
let mut set = HashSet::with_capacity(columns.len());
set.extend(columns.iter());
// Has repeated column name inside so generate again
if set.len() != columns.len() {
continue;
fn arbitrary<R: Rng, C: GenerationContext>(rng: &mut R, context: &C) -> Self {
let opts = context.opts().table.clone();
let name = Name::arbitrary(rng, context).0;
let large_table =
opts.large_table.enable && rng.random_bool(opts.large_table.large_table_prob);
let column_size = if large_table {
rng.random_range(opts.large_table.column_range)
} else {
rng.random_range(opts.column_range)
} as usize;
let mut column_set = HashSet::with_capacity(column_size);
for col in std::iter::repeat_with(|| Column::arbitrary(rng, context)) {
column_set.insert(col);
if column_set.len() == column_size {
break;
}
break columns;
};
}
Table {
rows: Vec::new(),
name,
columns,
columns: Vec::from_iter(column_set),
indexes: vec![],
}
}
}
impl Arbitrary for Column {
fn arbitrary<R: Rng>(rng: &mut R) -> Self {
let name = Name::arbitrary(rng).0;
let column_type = ColumnType::arbitrary(rng);
fn arbitrary<R: Rng, C: GenerationContext>(rng: &mut R, context: &C) -> Self {
let name = Name::arbitrary(rng, context).0;
let column_type = ColumnType::arbitrary(rng, context);
Self {
name,
column_type,
@@ -61,16 +59,20 @@ impl Arbitrary for Column {
}
impl Arbitrary for ColumnType {
fn arbitrary<R: Rng>(rng: &mut R) -> Self {
fn arbitrary<R: Rng, C: GenerationContext>(rng: &mut R, _context: &C) -> Self {
pick(&[Self::Integer, Self::Float, Self::Text, Self::Blob], rng).to_owned()
}
}
impl ArbitraryFrom<&Table> for Vec<SimValue> {
fn arbitrary_from<R: Rng>(rng: &mut R, table: &Table) -> Self {
fn arbitrary_from<R: Rng, C: GenerationContext>(
rng: &mut R,
context: &C,
table: &Table,
) -> Self {
let mut row = Vec::new();
for column in table.columns.iter() {
let value = SimValue::arbitrary_from(rng, &column.column_type);
let value = SimValue::arbitrary_from(rng, context, &column.column_type);
row.push(value);
}
row
@@ -78,7 +80,11 @@ impl ArbitraryFrom<&Table> for Vec<SimValue> {
}
impl ArbitraryFrom<&Vec<&SimValue>> for SimValue {
fn arbitrary_from<R: Rng>(rng: &mut R, values: &Vec<&Self>) -> Self {
fn arbitrary_from<R: Rng, C: GenerationContext>(
rng: &mut R,
_context: &C,
values: &Vec<&Self>,
) -> Self {
if values.is_empty() {
return Self(Value::Null);
}
@@ -88,7 +94,11 @@ impl ArbitraryFrom<&Vec<&SimValue>> for SimValue {
}
impl ArbitraryFrom<&ColumnType> for SimValue {
fn arbitrary_from<R: Rng>(rng: &mut R, column_type: &ColumnType) -> Self {
fn arbitrary_from<R: Rng, C: GenerationContext>(
rng: &mut R,
_context: &C,
column_type: &ColumnType,
) -> Self {
let value = match column_type {
ColumnType::Integer => Value::Integer(rng.random_range(i64::MIN..i64::MAX)),
ColumnType::Float => Value::Float(rng.random_range(-1e10..1e10)),
@@ -102,19 +112,27 @@ impl ArbitraryFrom<&ColumnType> for SimValue {
pub struct LTValue(pub SimValue);
impl ArbitraryFrom<&Vec<&SimValue>> for LTValue {
fn arbitrary_from<R: Rng>(rng: &mut R, values: &Vec<&SimValue>) -> Self {
fn arbitrary_from<R: Rng, C: GenerationContext>(
rng: &mut R,
context: &C,
values: &Vec<&SimValue>,
) -> Self {
if values.is_empty() {
return Self(SimValue(Value::Null));
}
// Get a value less than all of the values
let value = Value::exec_min(values.iter().map(|value| &value.0));
Self::arbitrary_from(rng, &SimValue(value))
Self::arbitrary_from(rng, context, &SimValue(value))
}
}
impl ArbitraryFrom<&SimValue> for LTValue {
fn arbitrary_from<R: Rng>(rng: &mut R, value: &SimValue) -> Self {
fn arbitrary_from<R: Rng, C: GenerationContext>(
rng: &mut R,
_context: &C,
value: &SimValue,
) -> Self {
let new_value = match &value.0 {
Value::Integer(i) => Value::Integer(rng.random_range(i64::MIN..*i - 1)),
Value::Float(f) => Value::Float(f - rng.random_range(0.0..1e10)),
@@ -164,19 +182,27 @@ impl ArbitraryFrom<&SimValue> for LTValue {
pub struct GTValue(pub SimValue);
impl ArbitraryFrom<&Vec<&SimValue>> for GTValue {
fn arbitrary_from<R: Rng>(rng: &mut R, values: &Vec<&SimValue>) -> Self {
fn arbitrary_from<R: Rng, C: GenerationContext>(
rng: &mut R,
context: &C,
values: &Vec<&SimValue>,
) -> Self {
if values.is_empty() {
return Self(SimValue(Value::Null));
}
// Get a value greater than all of the values
let value = Value::exec_max(values.iter().map(|value| &value.0));
Self::arbitrary_from(rng, &SimValue(value))
Self::arbitrary_from(rng, context, &SimValue(value))
}
}
impl ArbitraryFrom<&SimValue> for GTValue {
fn arbitrary_from<R: Rng>(rng: &mut R, value: &SimValue) -> Self {
fn arbitrary_from<R: Rng, C: GenerationContext>(
rng: &mut R,
_context: &C,
value: &SimValue,
) -> Self {
let new_value = match &value.0 {
Value::Integer(i) => Value::Integer(rng.random_range(*i..i64::MAX)),
Value::Float(f) => Value::Float(rng.random_range(*f..1e10)),
@@ -226,7 +252,11 @@ impl ArbitraryFrom<&SimValue> for GTValue {
pub struct LikeValue(pub SimValue);
impl ArbitraryFromMaybe<&SimValue> for LikeValue {
fn arbitrary_from_maybe<R: Rng>(rng: &mut R, value: &SimValue) -> Option<Self> {
fn arbitrary_from_maybe<R: Rng, C: GenerationContext>(
rng: &mut R,
_context: &C,
value: &SimValue,
) -> Option<Self> {
match &value.0 {
value @ Value::Text(..) => {
let t = value.to_string();

View File

@@ -81,7 +81,7 @@ pub async fn db_bootstrap<C: ProtocolIO, Ctx>(
while !c.is_completed() {
coro.yield_(ProtocolCommand::IO).await?;
}
pos += content_len;
pos += content_len as u64;
}
if content.is_done()? {
break;
@@ -123,7 +123,7 @@ pub async fn wal_apply_from_file<Ctx>(
// todo(sivukhin): we need to error out in case of partial read
assert!(size as usize == WAL_FRAME_SIZE);
});
let c = frames_file.pread(offset as usize, c)?;
let c = frames_file.pread(offset, c)?;
while !c.is_completed() {
coro.yield_(ProtocolCommand::IO).await?;
}
@@ -243,7 +243,7 @@ pub async fn wal_pull_to_file_v1<C: ProtocolIO, Ctx>(
while !c.is_completed() {
coro.yield_(ProtocolCommand::IO).await?;
}
offset += WAL_FRAME_SIZE;
offset += WAL_FRAME_SIZE as u64;
}
let c = Completion::new_sync(move |_| {
@@ -323,7 +323,7 @@ pub async fn wal_pull_to_file_legacy<C: ProtocolIO, Ctx>(
coro.yield_(ProtocolCommand::IO).await?;
}
last_offset += WAL_FRAME_SIZE;
last_offset += WAL_FRAME_SIZE as u64;
buffer_len = 0;
start_frame += 1;
@@ -974,7 +974,7 @@ pub async fn bootstrap_db_file_v1<C: ProtocolIO, Ctx>(
};
assert!(rc as usize == 0);
});
let c = file.truncate(header.db_size as usize * PAGE_SIZE, c)?;
let c = file.truncate(header.db_size * PAGE_SIZE as u64, c)?;
while !c.is_completed() {
coro.yield_(ProtocolCommand::IO).await?;
}
@@ -984,7 +984,7 @@ pub async fn bootstrap_db_file_v1<C: ProtocolIO, Ctx>(
while let Some(page_data) =
wait_proto_message::<Ctx, PageData>(coro, &completion, &mut bytes).await?
{
let offset = page_data.page_id as usize * PAGE_SIZE;
let offset = page_data.page_id * PAGE_SIZE as u64;
let page = decode_page(&header, page_data)?;
if page.len() != PAGE_SIZE {
return Err(Error::DatabaseSyncEngineError(format!(
@@ -1074,7 +1074,7 @@ pub async fn reset_wal_file<Ctx>(
// truncate the WAL file completely so this operation executes safely on an empty WAL during the initial bootstrap phase
0
} else {
WAL_HEADER + WAL_FRAME_SIZE * (frames_count as usize)
WAL_HEADER as u64 + WAL_FRAME_SIZE as u64 * frames_count
};
tracing::debug!("reset db wal to the size of {} frames", frames_count);
let c = Completion::new_trunc(move |result| {

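Spelled out, the truncation length above follows SQLite's WAL layout; a small sketch with assumed constants (32-byte WAL header, 24-byte frame header per page; the crate's actual values may differ):

const PAGE_SIZE: u64 = 4096;
const WAL_HEADER: u64 = 32;
const WAL_FRAME_SIZE: u64 = PAGE_SIZE + 24; // page payload plus per-frame header

fn reset_len(frames_count: u64) -> u64 {
    if frames_count == 0 {
        0 // truncate fully so bootstrap can run against an empty WAL
    } else {
        WAL_HEADER + WAL_FRAME_SIZE * frames_count
    }
}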
View File

@@ -59,7 +59,7 @@ impl IoOperations for Arc<dyn turso_core::IO> {
};
tracing::debug!("file truncated: rc={}", rc);
});
let c = file.truncate(len, c)?;
let c = file.truncate(len as u64, c)?;
while !c.is_completed() {
coro.yield_(ProtocolCommand::IO).await?;
}

View File

@@ -1,6 +1,6 @@
{
"name": "@tursodatabase/sync",
"version": "0.1.5-pre.1",
"version": "0.1.5-pre.2",
"repository": {
"type": "git",
"url": "https://github.com/tursodatabase/turso"

View File

@@ -144,3 +144,34 @@ do_execsql_test select-agg-json-array-object {
do_execsql_test select-distinct-agg-functions {
SELECT sum(distinct age), count(distinct age), avg(distinct age) FROM users;
} {5050|100|50.5}
do_execsql_test select-json-group-object {
select price,
json_group_object(cast (id as text), name)
from products
group by price
order by price;
} {1.0|{"9":"boots"}
18.0|{"3":"shirt"}
25.0|{"4":"sweater"}
33.0|{"10":"coat"}
70.0|{"6":"shorts"}
74.0|{"5":"sweatshirt"}
78.0|{"7":"jeans"}
79.0|{"1":"hat"}
81.0|{"11":"accessories"}
82.0|{"2":"cap","8":"sneakers"}}
do_execsql_test select-json-group-object-no-sorting-required {
select age,
json_group_object(cast (id as text), first_name)
from users
where first_name like 'Am%'
group by age
order by age
limit 5;
} {1|{"6737":"Amy"}
2|{"2297":"Amy","3580":"Amanda"}
3|{"3437":"Amanda"}
5|{"2378":"Amy","3227":"Amy","5605":"Amanda"}
7|{"2454":"Amber"}}

24
testing/alter_column.test Executable file
View File

@@ -0,0 +1,24 @@
#!/usr/bin/env tclsh
set testdir [file dirname $argv0]
source $testdir/tester.tcl
do_execsql_test_on_specific_db {:memory:} alter-column-rename-and-type {
CREATE TABLE t (a INTEGER);
CREATE INDEX i ON t (a);
ALTER TABLE t ALTER COLUMN a TO b BLOB;
SELECT sql FROM sqlite_schema;
} {
"CREATE TABLE t (b BLOB)"
"CREATE INDEX i ON t (b)"
}
do_execsql_test_in_memory_any_error fail-alter-column-primary-key {
CREATE TABLE t (a);
ALTER TABLE t ALTER COLUMN a TO a PRIMARY KEY;
}
do_execsql_test_in_memory_any_error fail-alter-column-unique {
CREATE TABLE t (a);
ALTER TABLE t ALTER COLUMN a TO a UNIQUE;
}

View File

@@ -10,12 +10,12 @@ do_execsql_test_on_specific_db {:memory:} alter-table-rename-table {
} { "t2" }
do_execsql_test_on_specific_db {:memory:} alter-table-rename-column {
CREATE TABLE t (a);
CREATE TABLE t (a INTEGER);
CREATE INDEX i ON t (a);
ALTER TABLE t RENAME a TO b;
SELECT sql FROM sqlite_schema;
} {
"CREATE TABLE t (b)"
"CREATE TABLE t (b INTEGER)"
"CREATE INDEX i ON t (b)"
}

View File

@@ -7,22 +7,22 @@ from cli_tests.test_turso_cli import TestTursoShell
sqlite_exec = "./scripts/limbo-sqlite3"
sqlite_flags = os.getenv("SQLITE_FLAGS", "-q").split(" ")
test_data = """CREATE TABLE numbers ( id INTEGER PRIMARY KEY, value FLOAT NOT NULL);
INSERT INTO numbers (value) VALUES (1.0);
INSERT INTO numbers (value) VALUES (2.0);
INSERT INTO numbers (value) VALUES (3.0);
INSERT INTO numbers (value) VALUES (4.0);
INSERT INTO numbers (value) VALUES (5.0);
INSERT INTO numbers (value) VALUES (6.0);
INSERT INTO numbers (value) VALUES (7.0);
CREATE TABLE test (value REAL, percent REAL);
INSERT INTO test values (10, 25);
INSERT INTO test values (20, 25);
INSERT INTO test values (30, 25);
INSERT INTO test values (40, 25);
INSERT INTO test values (50, 25);
INSERT INTO test values (60, 25);
INSERT INTO test values (70, 25);
test_data = """CREATE TABLE numbers ( id INTEGER PRIMARY KEY, value FLOAT NOT NULL, category TEXT DEFAULT 'A');
INSERT INTO numbers (value, category) VALUES (1.0, 'A');
INSERT INTO numbers (value, category) VALUES (2.0, 'A');
INSERT INTO numbers (value, category) VALUES (3.0, 'A');
INSERT INTO numbers (value, category) VALUES (4.0, 'B');
INSERT INTO numbers (value, category) VALUES (5.0, 'B');
INSERT INTO numbers (value, category) VALUES (6.0, 'B');
INSERT INTO numbers (value, category) VALUES (7.0, 'B');
CREATE TABLE test (value REAL, percent REAL, category TEXT);
INSERT INTO test values (10, 25, 'A');
INSERT INTO test values (20, 25, 'A');
INSERT INTO test values (30, 25, 'B');
INSERT INTO test values (40, 25, 'C');
INSERT INTO test values (50, 25, 'C');
INSERT INTO test values (60, 25, 'C');
INSERT INTO test values (70, 25, 'D');
"""
@@ -174,6 +174,39 @@ def test_aggregates():
limbo.quit()
def test_grouped_aggregates():
limbo = TestTursoShell(init_commands=test_data)
extension_path = "./target/debug/liblimbo_percentile"
limbo.execute_dot(f".load {extension_path}")
limbo.run_test_fn(
"SELECT median(value) FROM numbers GROUP BY category;",
lambda res: "2.0\n5.5" == res,
"median aggregate function works",
)
limbo.run_test_fn(
"SELECT percentile(value, percent) FROM test GROUP BY category;",
lambda res: "12.5\n30.0\n45.0\n70.0" == res,
"grouped aggregate percentile function with 2 arguments works",
)
limbo.run_test_fn(
"SELECT percentile(value, 55) FROM test GROUP BY category;",
lambda res: "15.5\n30.0\n51.0\n70.0" == res,
"grouped aggregate percentile function with 1 argument works",
)
limbo.run_test_fn(
"SELECT percentile_cont(value, 0.25) FROM test GROUP BY category;",
lambda res: "12.5\n30.0\n45.0\n70.0" == res,
"grouped aggregate percentile_cont function works",
)
limbo.run_test_fn(
"SELECT percentile_disc(value, 0.55) FROM test GROUP BY category;",
lambda res: "10.0\n30.0\n50.0\n70.0" == res,
"grouped aggregate percentile_disc function works",
)
limbo.quit()
# Encoders and decoders
def validate_url_encode(a):
return a == "%2Fhello%3Ftext%3D%28%E0%B2%A0_%E0%B2%A0%29"
@@ -770,6 +803,7 @@ def main():
test_regexp()
test_uuid()
test_aggregates()
test_grouped_aggregates()
test_crypto()
test_series()
test_ipaddr()

View File

@@ -74,3 +74,31 @@ do_execsql_test_on_specific_db {:memory:} collate_aggregation_explicit_nocase {
insert into fruits(name) values ('Apple') ,('banana') ,('CHERRY');
select max(name collate nocase) from fruits;
} {CHERRY}
do_execsql_test_on_specific_db {:memory:} collate_grouped_aggregation_default_binary {
create table fruits(name collate binary, category text);
insert into fruits(name, category) values ('Apple', 'A'), ('banana', 'A'), ('CHERRY', 'B'), ('blueberry', 'B');
select max(name) from fruits group by category;
} {banana
blueberry}
do_execsql_test_on_specific_db {:memory:} collate_grouped_aggregation_default_nocase {
create table fruits(name collate nocase, category text);
insert into fruits(name, category) values ('Apple', 'A'), ('banana', 'A'), ('CHERRY', 'B'), ('blueberry', 'B');
select max(name) from fruits group by category;
} {banana
CHERRY}
do_execsql_test_on_specific_db {:memory:} collate_grouped_aggregation_explicit_binary {
create table fruits(name collate nocase, category text);
insert into fruits(name, category) values ('Apple', 'A'), ('banana', 'A'), ('CHERRY', 'B'), ('blueberry', 'B');
select max(name collate binary) from fruits group by category;
} {banana
blueberry}
do_execsql_test_on_specific_db {:memory:} collate_grouped_aggregation_explicit_nocase {
create table fruits(name collate binary, category text);
insert into fruits(name, category) values ('Apple', 'A'), ('banana', 'A'), ('CHERRY', 'B'), ('blueberry', 'B');
select max(name collate nocase) from fruits group by category;
} {banana
CHERRY}

View File

@@ -145,6 +145,18 @@ do_execsql_test group_by_count_star {
select u.first_name, count(*) from users u group by u.first_name limit 1;
} {Aaron|41}
do_execsql_test group_by_count_star_in_expression {
select u.first_name, count(*) % 3 from users u group by u.first_name order by u.first_name limit 3;
} {Aaron|2
Abigail|1
Adam|0}
do_execsql_test group_by_count_no_args_in_expression {
select u.first_name, count() % 3 from users u group by u.first_name order by u.first_name limit 3;
} {Aaron|2
Abigail|1
Adam|0}
do_execsql_test having {
select u.first_name, round(avg(u.age)) from users u group by u.first_name having avg(u.age) > 97 order by avg(u.age) desc limit 5;
} {Nina|100.0

View File

@@ -613,3 +613,23 @@ do_execsql_test_in_memory_error_content null-insert-in-nulltype-column-notnull-
CREATE TABLE test (id INTEGER,name NULL NOT NULL);
INSERT INTO test (id, name) VALUES (1, NULL);
} {NOT NULL constraint failed}
do_execsql_test_on_specific_db {:memory:} returning-true-literal {
CREATE TABLE test (id INTEGER, value TEXT);
INSERT INTO test (id, value) VALUES (1, true) RETURNING id, value;
} {1|1}
do_execsql_test_on_specific_db {:memory:} returning-false-literal {
CREATE TABLE test (id INTEGER, value TEXT);
INSERT INTO test (id, value) VALUES (1, false) RETURNING id, value;
} {1|0}
do_execsql_test_on_specific_db {:memory:} boolean-literal-edgecase {
CREATE TABLE true (id INTEGER, value TEXT);
INSERT INTO true (id, value) VALUES (1, true) RETURNING id, value;
} {1|1}
do_execsql_test_on_specific_db {:memory:} boolean-literal-edgecase-false {
CREATE TABLE false (id INTEGER, true TEXT);
INSERT INTO false (id, true) VALUES (1, false) RETURNING id, false;
} {1|0}

View File

@@ -18,3 +18,11 @@ do_execsql_test notnull {
do_execsql_test not-null {
select null not null, 'hi' not null;
} {0|1}
do_execsql_test sel-true {
select true;
} {1}
do_execsql_test sel-false {
select false;
} {0}

View File

@@ -124,5 +124,78 @@ fn test_per_page_encryption() -> anyhow::Result<()> {
})?;
}
{
// let's test connecting to the encrypted db using a URI
let uri = format!(
"file:{}?cipher=aegis256&hexkey=b1bbfda4f589dc9daaf004fe21111e00dc00c98237102f5c7002a5669fc76327",
db_path.to_str().unwrap()
);
let (_io, conn) = turso_core::Connection::from_uri(&uri, true, false, false)?;
let mut row_count = 0;
run_query_on_row(&tmp_db, &conn, "SELECT * FROM test", |row: &Row| {
assert_eq!(row.get::<i64>(0).unwrap(), 1);
assert_eq!(row.get::<String>(1).unwrap(), "Hello, World!");
row_count += 1;
})?;
assert_eq!(row_count, 1);
}
Ok(())
}
#[test]
fn test_non_4k_page_size_encryption() -> anyhow::Result<()> {
let _ = env_logger::try_init();
let db_name = format!("test-8k-{}.db", rng().next_u32());
let tmp_db = TempDatabase::new(&db_name, false);
let db_path = tmp_db.path.clone();
{
let conn = tmp_db.connect_limbo();
// Set page size to 8k (8192 bytes) and test encryption. Default page size is 4k.
run_query(&tmp_db, &conn, "PRAGMA page_size = 8192;")?;
run_query(
&tmp_db,
&conn,
"PRAGMA hexkey = 'b1bbfda4f589dc9daaf004fe21111e00dc00c98237102f5c7002a5669fc76327';",
)?;
run_query(&tmp_db, &conn, "PRAGMA cipher = 'aegis256';")?;
run_query(
&tmp_db,
&conn,
"CREATE TABLE test (id INTEGER PRIMARY KEY, value TEXT);",
)?;
run_query(
&tmp_db,
&conn,
"INSERT INTO test (value) VALUES ('Hello, World!')",
)?;
let mut row_count = 0;
run_query_on_row(&tmp_db, &conn, "SELECT * FROM test", |row: &Row| {
assert_eq!(row.get::<i64>(0).unwrap(), 1);
assert_eq!(row.get::<String>(1).unwrap(), "Hello, World!");
row_count += 1;
})?;
assert_eq!(row_count, 1);
do_flush(&conn, &tmp_db)?;
}
{
// Reopen the existing db with 8k page size and test encryption
let existing_db = TempDatabase::new_with_existent(&db_path, false);
let conn = existing_db.connect_limbo();
run_query(&tmp_db, &conn, "PRAGMA cipher = 'aegis256';")?;
run_query(
&existing_db,
&conn,
"PRAGMA hexkey = 'b1bbfda4f589dc9daaf004fe21111e00dc00c98237102f5c7002a5669fc76327';",
)?;
run_query_on_row(&existing_db, &conn, "SELECT * FROM test", |row: &Row| {
assert_eq!(row.get::<i64>(0).unwrap(), 1);
assert_eq!(row.get::<String>(1).unwrap(), "Hello, World!");
})?;
}
Ok(())
}

View File

@@ -434,7 +434,7 @@ fn write_at(io: &impl IO, file: Arc<dyn File>, offset: usize, data: &[u8]) {
// reference the buffer to keep alive for async io
let _buf = _buf.clone();
});
let result = file.pwrite(offset, buffer, completion).unwrap();
let result = file.pwrite(offset as u64, buffer, completion).unwrap();
while !result.is_completed() {
io.run_once().unwrap();
}