diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index acf74e39a..be3b1a84b 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -34,10 +34,9 @@ jobs: go-version: "1.23" - name: build Go bindings library - run: cargo build --package limbo-go + run: cargo build --package turso-go - name: run Go tests env: LD_LIBRARY_PATH: ${{ github.workspace }}/target/debug:$LD_LIBRARY_PATH run: go test - diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 12073def9..b354dd8e3 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -39,7 +39,9 @@ jobs: - name: Build run: cargo build --verbose - name: Test Encryption - run: cargo test --features encryption --color=always --test integration_tests query_processing::encryption + run: | + cargo test --features encryption --color=always --test integration_tests query_processing::encryption + cargo test --features encryption --color=always --lib storage::encryption - name: Test env: RUST_LOG: ${{ runner.debug && 'turso_core::storage=trace' || '' }} diff --git a/Cargo.lock b/Cargo.lock index 526ecbf51..57cf52cd2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -27,6 +27,16 @@ dependencies = [ "generic-array", ] +[[package]] +name = "aegis" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2a1c2f54793fee13c334f70557d3bd6a029a9d453ebffd82ba571d139064da8" +dependencies = [ + "cc", + "softaes", +] + [[package]] name = "aes" version = "0.8.4" @@ -109,6 +119,15 @@ dependencies = [ "rand 0.8.5", ] +[[package]] +name = "anarchist-readable-name-generator-lib" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09a645c34bad5551ed4b2496536985efdc4373b097c0e57abf2eb14774538278" +dependencies = [ + "rand 0.9.2", +] + [[package]] name = "android-tzdata" version = "0.1.1" @@ -2047,13 +2066,6 @@ dependencies = [ "vcpkg", ] -[[package]] -name = "limbo-go" -version = "0.1.4" -dependencies = [ - "turso_core", -] - [[package]] name = "limbo_completion" version = "0.1.4" @@ -2115,7 +2127,6 @@ dependencies = [ name = "limbo_sim" version = "0.1.4" dependencies = [ - "anarchist-readable-name-generator-lib", "anyhow", "chrono", "clap", @@ -2125,17 +2136,18 @@ dependencies = [ "itertools 0.14.0", "log", "notify", - "rand 0.8.5", - "rand_chacha 0.3.1", + "rand 0.9.2", + "rand_chacha 0.9.0", "regex", "regex-syntax 0.8.5", "rusqlite", "serde", "serde_json", + "sql_generation", "tracing", "tracing-subscriber", "turso_core", - "turso_sqlite3_parser", + "turso_parser", ] [[package]] @@ -3440,12 +3452,34 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "softaes" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fef461faaeb36c340b6c887167a9054a034f6acfc50a014ead26a02b4356b3de" + [[package]] name = "sorted-vec" version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d372029cb5195f9ab4e4b9aef550787dce78b124fcaee8d82519925defcd6f0d" +[[package]] +name = "sql_generation" +version = "0.1.4" +dependencies = [ + "anarchist-readable-name-generator-lib 0.2.0", + "anyhow", + "hex", + "itertools 0.14.0", + "rand 0.9.2", + "rand_chacha 0.9.0", + "serde", + "tracing", + "turso_core", + "turso_parser", +] + [[package]] name = "sqlparser_bench" version = "0.1.0" @@ -3943,6 +3977,13 @@ dependencies = [ "turso_core", ] +[[package]] +name = "turso-go" +version = "0.1.4" +dependencies = [ + "turso_core", +] + [[package]] name = "turso-java" version = 
"0.1.4" @@ -3989,6 +4030,7 @@ dependencies = [ name = "turso_core" version = "0.1.4" dependencies = [ + "aegis", "aes", "aes-gcm", "antithesis_sdk", @@ -4142,7 +4184,7 @@ dependencies = [ name = "turso_stress" version = "0.1.4" dependencies = [ - "anarchist-readable-name-generator-lib", + "anarchist-readable-name-generator-lib 0.1.2", "antithesis_sdk", "clap", "hex", diff --git a/Cargo.toml b/Cargo.toml index 646cd977c..0dbb4b1fa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,7 @@ members = [ "parser", "sync/engine", "sync/javascript", + "sql_generation", ] exclude = ["perf/latency/limbo"] @@ -56,6 +57,7 @@ limbo_regexp = { path = "extensions/regexp", version = "0.1.4" } turso_sqlite3_parser = { path = "vendored/sqlite3-parser", version = "0.1.4" } limbo_uuid = { path = "extensions/uuid", version = "0.1.4" } turso_parser = { path = "parser" } +sql_generation = { path = "sql_generation" } strum = { version = "0.26", features = ["derive"] } strum_macros = "0.26" serde = "1.0" @@ -63,6 +65,9 @@ serde_json = "1.0" anyhow = "1.0.98" mimalloc = { version = "0.1.47", default-features = false } rusqlite = { version = "0.37.0", features = ["bundled"] } +itertools = "0.14.0" +rand = "0.9.2" +tracing = "0.1.41" [profile.release] debug = "line-tables-only" diff --git a/PERF.md b/PERF.md index 7ddf5eab0..ed09730cf 100644 --- a/PERF.md +++ b/PERF.md @@ -29,6 +29,7 @@ strace -f -c ../../Mobibench/shell/mobibench-turso -f 1024 -r 4 -a 0 -y 0 -t 1 - ./mobibench -p -n 1000 -d 0 -j 4 ``` + ## Clickbench We have a modified version of the Clickbench benchmark script that can be run with: @@ -41,7 +42,6 @@ This will build Turso in release mode, create a database, and run the benchmarks It will run the queries for both Turso and SQLite, and print the results. - ## Comparing VFS's/IO Back-ends (io_uring | syscall) ```shell @@ -54,26 +54,9 @@ The naive script will build and run limbo in release mode and execute the given ## TPC-H -1. Clone the Taratool TPC-H benchmarking tool: +Run the benchmark script: ```shell -git clone git@github.com:tarantool/tpch.git +./perf/tpc-h/benchmark.sh ``` -2. Patch the benchmark runner script: - -```patch -diff --git a/bench_queries.sh b/bench_queries.sh -index 6b894f9..c808e9a 100755 ---- a/bench_queries.sh -+++ b/bench_queries.sh -@@ -4,7 +4,7 @@ function check_q { - local query=queries/$*.sql - ( - echo $query -- time ( sqlite3 TPC-H.db < $query > /dev/null ) -+ time ( ../../limbo/target/release/limbo -m list TPC-H.db < $query > /dev/null ) - ) - } -``` - diff --git a/README.md b/README.md index a98999ee5..ceaca572c 100644 --- a/README.md +++ b/README.md @@ -174,8 +174,8 @@ print(res.fetchone()) 1. Clone the repository 2. Build the library and set your LD_LIBRARY_PATH to include turso's target directory ```console -cargo build --package limbo-go -export LD_LIBRARY_PATH=/path/to/limbo/target/debug:$LD_LIBRARY_PATH +cargo build --package turso-go +export LD_LIBRARY_PATH=/path/to/turso/target/debug:$LD_LIBRARY_PATH ``` 3. 
Use the driver @@ -191,7 +191,7 @@ import ( _ "github.com/tursodatabase/turso" ) -conn, _ = sql.Open("sqlite3", "sqlite.db") +conn, _ = sql.Open("turso", "sqlite.db") defer conn.Close() stmt, _ := conn.Prepare("select * from users") diff --git a/antithesis-tests/stress/singleton_driver_stress.sh b/antithesis-tests/stress/singleton_driver_stress.sh index fcab5ce2a..38c7392b4 100755 --- a/antithesis-tests/stress/singleton_driver_stress.sh +++ b/antithesis-tests/stress/singleton_driver_stress.sh @@ -1,3 +1,3 @@ #!/usr/bin/env bash -/bin/turso_stress --silent --nr-iterations 10000 +/bin/turso_stress --silent --nr-threads 2 --nr-iterations 10000 diff --git a/bindings/go/Cargo.toml b/bindings/go/Cargo.toml index 228aead5e..8f8a55d76 100644 --- a/bindings/go/Cargo.toml +++ b/bindings/go/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "limbo-go" +name = "turso-go" version.workspace = true authors.workspace = true edition.workspace = true @@ -8,7 +8,7 @@ repository.workspace = true publish = false [lib] -name = "_limbo_go" +name = "_turso_go" crate-type = ["cdylib"] path = "rs_src/lib.rs" diff --git a/bindings/go/README.md b/bindings/go/README.md index 72672ebfe..af74b98a4 100644 --- a/bindings/go/README.md +++ b/bindings/go/README.md @@ -1,4 +1,4 @@ -# Limbo driver for Go's `database/sql` library +# Turso driver for Go's `database/sql` library **NOTE:** this is currently __heavily__ W.I.P and is not yet in a usable state. @@ -17,7 +17,7 @@ To build with embedded library support, follow these steps: git clone https://github.com/tursodatabase/turso # Navigate to the Go bindings directory -cd limbo/bindings/go +cd turso/bindings/go # Build the library (defaults to release build) ./build_lib.sh @@ -52,34 +52,34 @@ Build the driver with the embedded library as described above, then simply impor #### Linux | MacOS -_All commands listed are relative to the bindings/go directory in the limbo repository_ +_All commands listed are relative to the bindings/go directory in the turso repository_ ``` -cargo build --package limbo-go +cargo build --package turso-go -# Your LD_LIBRARY_PATH environment variable must include limbo's `target/debug` directory +# Your LD_LIBRARY_PATH environment variable must include turso's `target/debug` directory -export LD_LIBRARY_PATH="/path/to/limbo/target/debug:$LD_LIBRARY_PATH" +export LD_LIBRARY_PATH="/path/to/turso/target/debug:$LD_LIBRARY_PATH" ``` #### Windows ``` -cargo build --package limbo-go +cargo build --package turso-go -# You must add limbo's `target/debug` directory to your PATH +# You must add turso's `target/debug` directory to your PATH # or you could built + copy the .dll to a location in your PATH # or just the CWD of your go module -cp path\to\limbo\target\debug\lib_limbo_go.dll . +cp path\to\turso\target\debug\lib_turso_go.dll . 
go test ``` -**Temporarily** you may have to clone the limbo repository and run: +**Temporarily** you may have to clone the turso repository and run: -`go mod edit -replace github.com/tursodatabase/turso=/path/to/limbo/bindings/go` +`go mod edit -replace github.com/tursodatabase/turso=/path/to/turso/bindings/go` ```go import ( @@ -89,19 +89,19 @@ import ( ) func main() { - conn, err := sql.Open("sqlite3", ":memory:") + conn, err := sql.Open("turso", ":memory:") if err != nil { fmt.Printf("Error: %v\n", err) os.Exit(1) } - sql := "CREATE table go_limbo (foo INTEGER, bar TEXT)" + sql := "CREATE table go_turso (foo INTEGER, bar TEXT)" _ = conn.Exec(sql) - sql = "INSERT INTO go_limbo (foo, bar) values (?, ?)" + sql = "INSERT INTO go_turso (foo, bar) values (?, ?)" stmt, _ := conn.Prepare(sql) defer stmt.Close() - _ = stmt.Exec(42, "limbo") - rows, _ := conn.Query("SELECT * from go_limbo") + _ = stmt.Exec(42, "turso") + rows, _ := conn.Query("SELECT * from go_turso") defer rows.Close() for rows.Next() { var a int diff --git a/bindings/go/build_lib.sh b/bindings/go/build_lib.sh index 26bf07ab0..1b77bfa26 100755 --- a/bindings/go/build_lib.sh +++ b/bindings/go/build_lib.sh @@ -6,12 +6,12 @@ set -e # Accept build type as parameter, default to release BUILD_TYPE=${1:-release} -echo "Building Limbo Go library for current platform (build type: $BUILD_TYPE)..." +echo "Building turso Go library for current platform (build type: $BUILD_TYPE)..." # Determine platform-specific details case "$(uname -s)" in Darwin*) - OUTPUT_NAME="lib_limbo_go.dylib" + OUTPUT_NAME="lib_turso_go.dylib" # Map x86_64 to amd64 for Go compatibility ARCH=$(uname -m) if [ "$ARCH" == "x86_64" ]; then @@ -20,7 +20,7 @@ case "$(uname -s)" in PLATFORM="darwin_${ARCH}" ;; Linux*) - OUTPUT_NAME="lib_limbo_go.so" + OUTPUT_NAME="lib_turso_go.so" # Map x86_64 to amd64 for Go compatibility ARCH=$(uname -m) if [ "$ARCH" == "x86_64" ]; then @@ -29,7 +29,7 @@ case "$(uname -s)" in PLATFORM="linux_${ARCH}" ;; MINGW*|MSYS*|CYGWIN*) - OUTPUT_NAME="lib_limbo_go.dll" + OUTPUT_NAME="lib_turso_go.dll" if [ "$(uname -m)" == "x86_64" ]; then PLATFORM="windows_amd64" else @@ -60,11 +60,11 @@ else fi # Build the library -echo "Running cargo build ${CARGO_ARGS} --package limbo-go" -cargo build ${CARGO_ARGS} --package limbo-go +echo "Running cargo build ${CARGO_ARGS} --package turso-go" +cargo build ${CARGO_ARGS} --package turso-go # Copy to the appropriate directory echo "Copying $OUTPUT_NAME to $OUTPUT_DIR/" cp "../../target/${TARGET_DIR}/$OUTPUT_NAME" "$OUTPUT_DIR/" -echo "Library built successfully for $PLATFORM ($BUILD_TYPE build)" \ No newline at end of file +echo "Library built successfully for $PLATFORM ($BUILD_TYPE build)" diff --git a/bindings/go/connection.go b/bindings/go/connection.go index 27d7fc06a..2d0a1dc7b 100644 --- a/bindings/go/connection.go +++ b/bindings/go/connection.go @@ -1,4 +1,4 @@ -package limbo +package turso import ( "context" @@ -16,16 +16,16 @@ func init() { if err != nil { panic(err) } - sql.Register(driverName, &limboDriver{}) + sql.Register(driverName, &tursoDriver{}) } -type limboDriver struct { +type tursoDriver struct { sync.Mutex } var ( libOnce sync.Once - limboLib uintptr + tursoLib uintptr loadErr error dbOpen func(string) uintptr dbClose func(uintptr) uintptr @@ -49,32 +49,32 @@ var ( // Register all the symbols on library load func ensureLibLoaded() error { libOnce.Do(func() { - limboLib, loadErr = loadLibrary() + tursoLib, loadErr = loadLibrary() if loadErr != nil { return } - purego.RegisterLibFunc(&dbOpen, 
limboLib, FfiDbOpen) - purego.RegisterLibFunc(&dbClose, limboLib, FfiDbClose) - purego.RegisterLibFunc(&connPrepare, limboLib, FfiDbPrepare) - purego.RegisterLibFunc(&connGetError, limboLib, FfiDbGetError) - purego.RegisterLibFunc(&freeBlobFunc, limboLib, FfiFreeBlob) - purego.RegisterLibFunc(&freeStringFunc, limboLib, FfiFreeCString) - purego.RegisterLibFunc(&rowsGetColumns, limboLib, FfiRowsGetColumns) - purego.RegisterLibFunc(&rowsGetColumnName, limboLib, FfiRowsGetColumnName) - purego.RegisterLibFunc(&rowsGetValue, limboLib, FfiRowsGetValue) - purego.RegisterLibFunc(&closeRows, limboLib, FfiRowsClose) - purego.RegisterLibFunc(&rowsNext, limboLib, FfiRowsNext) - purego.RegisterLibFunc(&rowsGetError, limboLib, FfiRowsGetError) - purego.RegisterLibFunc(&stmtQuery, limboLib, FfiStmtQuery) - purego.RegisterLibFunc(&stmtExec, limboLib, FfiStmtExec) - purego.RegisterLibFunc(&stmtParamCount, limboLib, FfiStmtParameterCount) - purego.RegisterLibFunc(&stmtGetError, limboLib, FfiStmtGetError) - purego.RegisterLibFunc(&stmtClose, limboLib, FfiStmtClose) + purego.RegisterLibFunc(&dbOpen, tursoLib, FfiDbOpen) + purego.RegisterLibFunc(&dbClose, tursoLib, FfiDbClose) + purego.RegisterLibFunc(&connPrepare, tursoLib, FfiDbPrepare) + purego.RegisterLibFunc(&connGetError, tursoLib, FfiDbGetError) + purego.RegisterLibFunc(&freeBlobFunc, tursoLib, FfiFreeBlob) + purego.RegisterLibFunc(&freeStringFunc, tursoLib, FfiFreeCString) + purego.RegisterLibFunc(&rowsGetColumns, tursoLib, FfiRowsGetColumns) + purego.RegisterLibFunc(&rowsGetColumnName, tursoLib, FfiRowsGetColumnName) + purego.RegisterLibFunc(&rowsGetValue, tursoLib, FfiRowsGetValue) + purego.RegisterLibFunc(&closeRows, tursoLib, FfiRowsClose) + purego.RegisterLibFunc(&rowsNext, tursoLib, FfiRowsNext) + purego.RegisterLibFunc(&rowsGetError, tursoLib, FfiRowsGetError) + purego.RegisterLibFunc(&stmtQuery, tursoLib, FfiStmtQuery) + purego.RegisterLibFunc(&stmtExec, tursoLib, FfiStmtExec) + purego.RegisterLibFunc(&stmtParamCount, tursoLib, FfiStmtParameterCount) + purego.RegisterLibFunc(&stmtGetError, tursoLib, FfiStmtGetError) + purego.RegisterLibFunc(&stmtClose, tursoLib, FfiStmtClose) }) return loadErr } -func (d *limboDriver) Open(name string) (driver.Conn, error) { +func (d *tursoDriver) Open(name string) (driver.Conn, error) { d.Lock() conn, err := openConn(name) d.Unlock() @@ -84,23 +84,23 @@ func (d *limboDriver) Open(name string) (driver.Conn, error) { return conn, nil } -type limboConn struct { +type tursoConn struct { sync.Mutex ctx uintptr } -func openConn(dsn string) (*limboConn, error) { +func openConn(dsn string) (*tursoConn, error) { ctx := dbOpen(dsn) if ctx == 0 { return nil, fmt.Errorf("failed to open database for dsn=%q", dsn) } - return &limboConn{ + return &tursoConn{ sync.Mutex{}, ctx, }, loadErr } -func (c *limboConn) Close() error { +func (c *tursoConn) Close() error { if c.ctx == 0 { return nil } @@ -111,7 +111,7 @@ func (c *limboConn) Close() error { return nil } -func (c *limboConn) getError() error { +func (c *tursoConn) getError() error { if c.ctx == 0 { return errors.New("connection closed") } @@ -124,7 +124,7 @@ func (c *limboConn) getError() error { return errors.New(cpy) } -func (c *limboConn) Prepare(query string) (driver.Stmt, error) { +func (c *tursoConn) Prepare(query string) (driver.Stmt, error) { if c.ctx == 0 { return nil, errors.New("connection closed") } @@ -137,13 +137,13 @@ func (c *limboConn) Prepare(query string) (driver.Stmt, error) { return newStmt(stmtPtr, query), nil } -// limboTx implements driver.Tx -type 
limboTx struct { - conn *limboConn +// tursoTx implements driver.Tx +type tursoTx struct { + conn *tursoConn } // Begin starts a new transaction with default isolation level -func (c *limboConn) Begin() (driver.Tx, error) { +func (c *tursoConn) Begin() (driver.Tx, error) { c.Lock() defer c.Unlock() @@ -165,12 +165,12 @@ func (c *limboConn) Begin() (driver.Tx, error) { return nil, err } - return &limboTx{conn: c}, nil + return &tursoTx{conn: c}, nil } // BeginTx starts a transaction with the specified options. // Currently only supports default isolation level and non-read-only transactions. -func (c *limboConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { +func (c *tursoConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { // Skip handling non-default isolation levels and read-only mode // for now, letting database/sql package handle these cases if opts.Isolation != driver.IsolationLevel(sql.LevelDefault) || opts.ReadOnly { @@ -187,7 +187,7 @@ func (c *limboConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver. } // Commit commits the transaction -func (tx *limboTx) Commit() error { +func (tx *tursoTx) Commit() error { tx.conn.Lock() defer tx.conn.Unlock() @@ -208,8 +208,7 @@ func (tx *limboTx) Commit() error { } // Rollback aborts the transaction. -// Note: This operation is not currently fully supported by Limbo and will return an error. -func (tx *limboTx) Rollback() error { +func (tx *tursoTx) Rollback() error { tx.conn.Lock() defer tx.conn.Unlock() diff --git a/bindings/go/embedded.go b/bindings/go/embedded.go index 9f04f2d79..2a44795d6 100644 --- a/bindings/go/embedded.go +++ b/bindings/go/embedded.go @@ -1,4 +1,4 @@ -// Go bindings for the Limbo database. +// Go bindings for the turso database. // // This file implements library embedding and extraction at runtime, a pattern // also used in several other Go projects that need to distribute native binaries: @@ -21,7 +21,7 @@ // The embedded library is extracted to a user-specific temporary directory and // loaded dynamically. If extraction fails, the code falls back to the traditional // method of searching system paths. 
-package limbo +package turso import ( "embed" @@ -52,11 +52,11 @@ func extractEmbeddedLibrary() (string, error) { switch runtime.GOOS { case "darwin": - libName = "lib_limbo_go.dylib" + libName = "lib_turso_go.dylib" case "linux": - libName = "lib_limbo_go.so" + libName = "lib_turso_go.so" case "windows": - libName = "lib_limbo_go.dll" + libName = "lib_turso_go.dll" default: extractErr = fmt.Errorf("unsupported operating system: %s", runtime.GOOS) return @@ -80,7 +80,7 @@ func extractEmbeddedLibrary() (string, error) { platformDir = fmt.Sprintf("%s_%s", runtime.GOOS, archSuffix) // Create a unique temporary directory for the current user - tempDir := filepath.Join(os.TempDir(), fmt.Sprintf("limbo-go-%d", os.Getuid())) + tempDir := filepath.Join(os.TempDir(), fmt.Sprintf("turso-go-%d", os.Getuid())) if err := os.MkdirAll(tempDir, 0755); err != nil { extractErr = fmt.Errorf("failed to create temp directory: %w", err) return diff --git a/bindings/go/rows.go b/bindings/go/rows.go index 1d14e0d0c..c82bd2e65 100644 --- a/bindings/go/rows.go +++ b/bindings/go/rows.go @@ -1,4 +1,4 @@ -package limbo +package turso import ( "database/sql/driver" @@ -8,7 +8,7 @@ import ( "sync" ) -type limboRows struct { +type tursoRows struct { mu sync.Mutex ctx uintptr columns []string @@ -16,8 +16,8 @@ type limboRows struct { closed bool } -func newRows(ctx uintptr) *limboRows { - return &limboRows{ +func newRows(ctx uintptr) *tursoRows { + return &tursoRows{ mu: sync.Mutex{}, ctx: ctx, columns: nil, @@ -26,14 +26,14 @@ func newRows(ctx uintptr) *limboRows { } } -func (r *limboRows) isClosed() bool { +func (r *tursoRows) isClosed() bool { if r.ctx == 0 || r.closed { return true } return false } -func (r *limboRows) Columns() []string { +func (r *tursoRows) Columns() []string { if r.isClosed() { return nil } @@ -54,7 +54,7 @@ func (r *limboRows) Columns() []string { return r.columns } -func (r *limboRows) Close() error { +func (r *tursoRows) Close() error { r.err = errors.New(RowsClosedErr) if r.isClosed() { return r.err @@ -67,7 +67,7 @@ func (r *limboRows) Close() error { return nil } -func (r *limboRows) Err() error { +func (r *tursoRows) Err() error { if r.err == nil { r.mu.Lock() defer r.mu.Unlock() @@ -76,7 +76,7 @@ func (r *limboRows) Err() error { return r.err } -func (r *limboRows) Next(dest []driver.Value) error { +func (r *tursoRows) Next(dest []driver.Value) error { r.mu.Lock() defer r.mu.Unlock() if r.isClosed() { @@ -106,7 +106,7 @@ func (r *limboRows) Next(dest []driver.Value) error { } // mutex will already be locked. 
this is always called after FFI -func (r *limboRows) getError() error { +func (r *tursoRows) getError() error { if r.isClosed() { return r.err } diff --git a/bindings/go/rs_src/lib.rs b/bindings/go/rs_src/lib.rs index 475ad062f..26a2a4dfd 100644 --- a/bindings/go/rs_src/lib.rs +++ b/bindings/go/rs_src/lib.rs @@ -23,19 +23,19 @@ pub unsafe extern "C" fn db_open(path: *const c_char) -> *mut c_void { let Ok((io, conn)) = Connection::from_uri(path, true, false, false) else { panic!("Failed to open connection with path: {path}"); }; - LimboConn::new(conn, io).to_ptr() + TursoConn::new(conn, io).to_ptr() } #[allow(dead_code)] -struct LimboConn { +struct TursoConn { conn: Arc<Connection>, io: Arc<dyn IO>, err: Option<LimboError>, } -impl LimboConn { +impl TursoConn { fn new(conn: Arc<Connection>, io: Arc<dyn IO>) -> Self { - LimboConn { + TursoConn { conn, io, err: None, @@ -47,11 +47,11 @@ impl LimboConn { Box::into_raw(Box::new(self)) as *mut c_void } - fn from_ptr(ptr: *mut c_void) -> &'static mut LimboConn { + fn from_ptr(ptr: *mut c_void) -> &'static mut TursoConn { if ptr.is_null() { panic!("Null pointer"); } - unsafe { &mut *(ptr as *mut LimboConn) } + unsafe { &mut *(ptr as *mut TursoConn) } } fn get_error(&mut self) -> *const c_char { @@ -73,7 +73,7 @@ pub extern "C" fn db_get_error(ctx: *mut c_void) -> *const c_char { if ctx.is_null() { return std::ptr::null(); } - let conn = LimboConn::from_ptr(ctx); + let conn = TursoConn::from_ptr(ctx); conn.get_error() } @@ -83,6 +83,6 @@ pub extern "C" fn db_get_error(ctx: *mut c_void) -> *const c_char { #[no_mangle] pub unsafe extern "C" fn db_close(db: *mut c_void) { if !db.is_null() { - let _ = unsafe { Box::from_raw(db as *mut LimboConn) }; + let _ = unsafe { Box::from_raw(db as *mut TursoConn) }; } } diff --git a/bindings/go/rs_src/rows.rs b/bindings/go/rs_src/rows.rs index d3d64ed3c..8a05440a5 100644 --- a/bindings/go/rs_src/rows.rs +++ b/bindings/go/rs_src/rows.rs @@ -1,19 +1,19 @@ use crate::{ - types::{LimboValue, ResultCode}, - LimboConn, + types::{ResultCode, TursoValue}, + TursoConn, }; use std::ffi::{c_char, c_void}; use turso_core::{LimboError, Statement, StepResult, Value}; -pub struct LimboRows<'conn> { +pub struct TursoRows<'conn> { stmt: Box<Statement>, - _conn: &'conn mut LimboConn, + _conn: &'conn mut TursoConn, err: Option<LimboError>, } -impl<'conn> LimboRows<'conn> { - pub fn new(stmt: Statement, conn: &'conn mut LimboConn) -> Self { - LimboRows { +impl<'conn> TursoRows<'conn> { + pub fn new(stmt: Statement, conn: &'conn mut TursoConn) -> Self { + TursoRows { stmt: Box::new(stmt), _conn: conn, err: None, @@ -25,11 +25,11 @@ impl<'conn> LimboRows<'conn> { Box::into_raw(Box::new(self)) as *mut c_void } - pub fn from_ptr(ptr: *mut c_void) -> &'conn mut LimboRows<'conn> { + pub fn from_ptr(ptr: *mut c_void) -> &'conn mut TursoRows<'conn> { if ptr.is_null() { panic!("Null pointer"); } - unsafe { &mut *(ptr as *mut LimboRows) } + unsafe { &mut *(ptr as *mut TursoRows) } } fn get_error(&mut self) -> *const c_char { @@ -49,7 +49,7 @@ pub extern "C" fn rows_next(ctx: *mut c_void) -> ResultCode { if ctx.is_null() { return ResultCode::Error; } - let ctx = LimboRows::from_ptr(ctx); + let ctx = TursoRows::from_ptr(ctx); match ctx.stmt.step() { Ok(StepResult::Row) => ResultCode::Row, @@ -76,11 +76,11 @@ pub extern "C" fn rows_get_value(ctx: *mut c_void, col_idx: usize) -> *const c_v if ctx.is_null() { return std::ptr::null(); } - let ctx = LimboRows::from_ptr(ctx); + let ctx = TursoRows::from_ptr(ctx); if let Some(row) = ctx.stmt.row() { if let Ok(value) = row.get::<&Value>(col_idx) { - return 
LimboValue::from_db_value(value).to_ptr(); + return TursoValue::from_db_value(value).to_ptr(); } } std::ptr::null() @@ -101,7 +101,7 @@ pub extern "C" fn rows_get_columns(rows_ptr: *mut c_void) -> i32 { if rows_ptr.is_null() { return -1; } - let rows = LimboRows::from_ptr(rows_ptr); + let rows = TursoRows::from_ptr(rows_ptr); rows.stmt.num_columns() as i32 } @@ -113,7 +113,7 @@ pub extern "C" fn rows_get_column_name(rows_ptr: *mut c_void, idx: i32) -> *cons if rows_ptr.is_null() { return std::ptr::null_mut(); } - let rows = LimboRows::from_ptr(rows_ptr); + let rows = TursoRows::from_ptr(rows_ptr); if idx < 0 || idx as usize >= rows.stmt.num_columns() { return std::ptr::null_mut(); } @@ -127,18 +127,18 @@ pub extern "C" fn rows_get_error(ctx: *mut c_void) -> *const c_char { if ctx.is_null() { return std::ptr::null(); } - let ctx = LimboRows::from_ptr(ctx); + let ctx = TursoRows::from_ptr(ctx); ctx.get_error() } #[no_mangle] pub extern "C" fn rows_close(ctx: *mut c_void) { if !ctx.is_null() { - let rows = LimboRows::from_ptr(ctx); + let rows = TursoRows::from_ptr(ctx); rows.stmt.reset(); rows.err = None; } unsafe { - let _ = Box::from_raw(ctx.cast::()); + let _ = Box::from_raw(ctx.cast::()); } } diff --git a/bindings/go/rs_src/statement.rs b/bindings/go/rs_src/statement.rs index 4dc115aec..65859161d 100644 --- a/bindings/go/rs_src/statement.rs +++ b/bindings/go/rs_src/statement.rs @@ -1,6 +1,6 @@ -use crate::rows::LimboRows; -use crate::types::{AllocPool, LimboValue, ResultCode}; -use crate::LimboConn; +use crate::rows::TursoRows; +use crate::types::{AllocPool, ResultCode, TursoValue}; +use crate::TursoConn; use std::ffi::{c_char, c_void}; use std::num::NonZero; use turso_core::{LimboError, Statement, StepResult}; @@ -12,10 +12,10 @@ pub extern "C" fn db_prepare(ctx: *mut c_void, query: *const c_char) -> *mut c_v } let query_str = unsafe { std::ffi::CStr::from_ptr(query) }.to_str().unwrap(); - let db = LimboConn::from_ptr(ctx); + let db = TursoConn::from_ptr(ctx); let stmt = db.conn.prepare(query_str); match stmt { - Ok(stmt) => LimboStatement::new(Some(stmt), db).to_ptr(), + Ok(stmt) => TursoStatement::new(Some(stmt), db).to_ptr(), Err(err) => { db.err = Some(err); std::ptr::null_mut() @@ -26,14 +26,14 @@ pub extern "C" fn db_prepare(ctx: *mut c_void, query: *const c_char) -> *mut c_v #[no_mangle] pub extern "C" fn stmt_execute( ctx: *mut c_void, - args_ptr: *mut LimboValue, + args_ptr: *mut TursoValue, arg_count: usize, changes: *mut i64, ) -> ResultCode { if ctx.is_null() { return ResultCode::Error; } - let stmt = LimboStatement::from_ptr(ctx); + let stmt = TursoStatement::from_ptr(ctx); let args = if !args_ptr.is_null() && arg_count > 0 { unsafe { std::slice::from_raw_parts(args_ptr, arg_count) } @@ -88,7 +88,7 @@ pub extern "C" fn stmt_parameter_count(ctx: *mut c_void) -> i32 { if ctx.is_null() { return -1; } - let stmt = LimboStatement::from_ptr(ctx); + let stmt = TursoStatement::from_ptr(ctx); let Some(statement) = stmt.statement.as_ref() else { stmt.err = Some(LimboError::InternalError("Statement is closed".to_string())); return -1; @@ -99,13 +99,13 @@ pub extern "C" fn stmt_parameter_count(ctx: *mut c_void) -> i32 { #[no_mangle] pub extern "C" fn stmt_query( ctx: *mut c_void, - args_ptr: *mut LimboValue, + args_ptr: *mut TursoValue, args_count: usize, ) -> *mut c_void { if ctx.is_null() { return std::ptr::null_mut(); } - let stmt = LimboStatement::from_ptr(ctx); + let stmt = TursoStatement::from_ptr(ctx); let args = if !args_ptr.is_null() && args_count > 0 { unsafe { 
std::slice::from_raw_parts(args_ptr, args_count) } } else { @@ -119,21 +119,21 @@ pub extern "C" fn stmt_query( let val = arg.to_value(&mut pool); statement.bind_at(NonZero::new(i + 1).unwrap(), val); } - // ownership of the statement is transferred to the LimboRows object. - LimboRows::new(statement, stmt.conn).to_ptr() + // ownership of the statement is transferred to the TursoRows object. + TursoRows::new(statement, stmt.conn).to_ptr() } -pub struct LimboStatement<'conn> { - /// If 'query' is ran on the statement, ownership is transferred to the LimboRows object +pub struct TursoStatement<'conn> { + /// If 'query' is run on the statement, ownership is transferred to the TursoRows object pub statement: Option<Statement>, - pub conn: &'conn mut LimboConn, + pub conn: &'conn mut TursoConn, pub err: Option<LimboError>, } #[no_mangle] pub extern "C" fn stmt_close(ctx: *mut c_void) -> ResultCode { if !ctx.is_null() { - let stmt = unsafe { Box::from_raw(ctx as *mut LimboStatement) }; + let stmt = unsafe { Box::from_raw(ctx as *mut TursoStatement) }; drop(stmt); return ResultCode::Ok; } @@ -145,13 +145,13 @@ pub extern "C" fn stmt_get_error(ctx: *mut c_void) -> *const c_char { if ctx.is_null() { return std::ptr::null(); } - let stmt = LimboStatement::from_ptr(ctx); + let stmt = TursoStatement::from_ptr(ctx); stmt.get_error() } -impl<'conn> LimboStatement<'conn> { - pub fn new(statement: Option<Statement>, conn: &'conn mut LimboConn) -> Self { - LimboStatement { +impl<'conn> TursoStatement<'conn> { + pub fn new(statement: Option<Statement>, conn: &'conn mut TursoConn) -> Self { + TursoStatement { statement, conn, err: None, @@ -163,11 +163,11 @@ impl<'conn> LimboStatement<'conn> { Box::into_raw(Box::new(self)) as *mut c_void } - fn from_ptr(ptr: *mut c_void) -> &'conn mut LimboStatement<'conn> { + fn from_ptr(ptr: *mut c_void) -> &'conn mut TursoStatement<'conn> { if ptr.is_null() { panic!("Null pointer"); } - unsafe { &mut *(ptr as *mut LimboStatement) } + unsafe { &mut *(ptr as *mut TursoStatement) } } fn get_error(&mut self) -> *const c_char { diff --git a/bindings/go/rs_src/types.rs b/bindings/go/rs_src/types.rs index 683cfde3f..9ec06b3bf 100644 --- a/bindings/go/rs_src/types.rs +++ b/bindings/go/rs_src/types.rs @@ -34,33 +34,33 @@ pub enum ValueType { } #[repr(C)] -pub struct LimboValue { +pub struct TursoValue { value_type: ValueType, value: ValueUnion, } -impl Debug for LimboValue { +impl Debug for TursoValue { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self.value_type { ValueType::Integer => { let i = self.value.to_int(); - f.debug_struct("LimboValue").field("value", &i).finish() + f.debug_struct("TursoValue").field("value", &i).finish() } ValueType::Real => { let r = self.value.to_real(); - f.debug_struct("LimboValue").field("value", &r).finish() + f.debug_struct("TursoValue").field("value", &r).finish() } ValueType::Text => { let t = self.value.to_str(); - f.debug_struct("LimboValue").field("value", &t).finish() + f.debug_struct("TursoValue").field("value", &t).finish() } ValueType::Blob => { let blob = self.value.to_bytes(); - f.debug_struct("LimboValue") + f.debug_struct("TursoValue") .field("value", &blob.to_vec()) .finish() } ValueType::Null => f - .debug_struct("LimboValue") + .debug_struct("TursoValue") .field("value", &"NULL") .finish(), } @@ -164,9 +164,9 @@ impl ValueUnion { } } -impl LimboValue { +impl TursoValue { fn new(value_type: ValueType, value: ValueUnion) -> Self { - LimboValue { value_type, value } + TursoValue { value_type, value } } #[allow(clippy::wrong_self_convention)] @@ 
-177,18 +177,18 @@ impl LimboValue { pub fn from_db_value(value: &turso_core::Value) -> Self { match value { turso_core::Value::Integer(i) => { - LimboValue::new(ValueType::Integer, ValueUnion::from_int(*i)) + TursoValue::new(ValueType::Integer, ValueUnion::from_int(*i)) } turso_core::Value::Float(r) => { - LimboValue::new(ValueType::Real, ValueUnion::from_real(*r)) + TursoValue::new(ValueType::Real, ValueUnion::from_real(*r)) } turso_core::Value::Text(s) => { - LimboValue::new(ValueType::Text, ValueUnion::from_str(s.as_str())) + TursoValue::new(ValueType::Text, ValueUnion::from_str(s.as_str())) } turso_core::Value::Blob(b) => { - LimboValue::new(ValueType::Blob, ValueUnion::from_bytes(b.as_slice())) + TursoValue::new(ValueType::Blob, ValueUnion::from_bytes(b.as_slice())) } - turso_core::Value::Null => LimboValue::new(ValueType::Null, ValueUnion::from_null()), + turso_core::Value::Null => TursoValue::new(ValueType::Null, ValueUnion::from_null()), } } diff --git a/bindings/go/stmt.go b/bindings/go/stmt.go index 9e045175e..c12ae9d71 100644 --- a/bindings/go/stmt.go +++ b/bindings/go/stmt.go @@ -1,4 +1,4 @@ -package limbo +package turso import ( "context" @@ -9,22 +9,22 @@ import ( "unsafe" ) -type limboStmt struct { +type tursoStmt struct { mu sync.Mutex ctx uintptr sql string err error } -func newStmt(ctx uintptr, sql string) *limboStmt { - return &limboStmt{ +func newStmt(ctx uintptr, sql string) *tursoStmt { + return &tursoStmt{ ctx: uintptr(ctx), sql: sql, err: nil, } } -func (ls *limboStmt) NumInput() int { +func (ls *tursoStmt) NumInput() int { ls.mu.Lock() defer ls.mu.Unlock() res := int(stmtParamCount(ls.ctx)) @@ -35,7 +35,7 @@ func (ls *limboStmt) NumInput() int { return res } -func (ls *limboStmt) Close() error { +func (ls *tursoStmt) Close() error { ls.mu.Lock() defer ls.mu.Unlock() if ls.ctx == 0 { @@ -49,7 +49,7 @@ func (ls *limboStmt) Close() error { return nil } -func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) { +func (ls *tursoStmt) Exec(args []driver.Value) (driver.Result, error) { argArray, cleanup, err := buildArgs(args) defer cleanup() if err != nil { @@ -80,7 +80,7 @@ func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) { } } -func (ls *limboStmt) Query(args []driver.Value) (driver.Rows, error) { +func (ls *tursoStmt) Query(args []driver.Value) (driver.Rows, error) { queryArgs, cleanup, err := buildArgs(args) defer cleanup() if err != nil { @@ -99,7 +99,7 @@ func (ls *limboStmt) Query(args []driver.Value) (driver.Rows, error) { return newRows(rowsPtr), nil } -func (ls *limboStmt) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { +func (ls *tursoStmt) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { stripped := namedValueToValue(args) argArray, cleanup, err := getArgsPtr(stripped) defer cleanup() @@ -129,7 +129,7 @@ func (ls *limboStmt) ExecContext(ctx context.Context, query string, args []drive } } -func (ls *limboStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { +func (ls *tursoStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { queryArgs, allocs, err := buildNamedArgs(args) defer allocs() if err != nil { @@ -154,7 +154,7 @@ func (ls *limboStmt) QueryContext(ctx context.Context, args []driver.NamedValue) } } -func (ls *limboStmt) Err() error { +func (ls *tursoStmt) Err() error { if ls.err == nil { ls.mu.Lock() defer ls.mu.Unlock() @@ -164,7 +164,7 @@ func (ls 
*limboStmt) Err() error { } // mutex should always be locked when calling - always called after FFI -func (ls *limboStmt) getError() error { +func (ls *tursoStmt) getError() error { err := stmtGetError(ls.ctx) if err == 0 { return nil diff --git a/bindings/go/limbo_test.go b/bindings/go/turso_test.go similarity index 96% rename from bindings/go/limbo_test.go rename to bindings/go/turso_test.go index 8fe36ae17..ff2bc90a4 100644 --- a/bindings/go/limbo_test.go +++ b/bindings/go/turso_test.go @@ -1,4 +1,4 @@ -package limbo_test +package turso_test import ( "database/sql" @@ -17,7 +17,7 @@ var ( ) func TestMain(m *testing.M) { - conn, connErr = sql.Open("sqlite3", ":memory:") + conn, connErr = sql.Open("turso", ":memory:") if connErr != nil { panic(connErr) } @@ -146,7 +146,7 @@ func TestFunctions(t *testing.T) { } func TestDuplicateConnection(t *testing.T) { - newConn, err := sql.Open("sqlite3", ":memory:") + newConn, err := sql.Open("turso", ":memory:") if err != nil { t.Fatalf("Error opening new connection: %v", err) } @@ -177,7 +177,7 @@ func TestDuplicateConnection(t *testing.T) { } func TestDuplicateConnection2(t *testing.T) { - newConn, err := sql.Open("sqlite3", ":memory:") + newConn, err := sql.Open("turso", ":memory:") if err != nil { t.Fatalf("Error opening new connection: %v", err) } @@ -209,7 +209,7 @@ func TestDuplicateConnection2(t *testing.T) { } func TestConnectionError(t *testing.T) { - newConn, err := sql.Open("sqlite3", ":memory:") + newConn, err := sql.Open("turso", ":memory:") if err != nil { t.Fatalf("Error opening new connection: %v", err) } @@ -228,7 +228,7 @@ func TestConnectionError(t *testing.T) { } func TestStatementError(t *testing.T) { - newConn, err := sql.Open("sqlite3", ":memory:") + newConn, err := sql.Open("turso", ":memory:") if err != nil { t.Fatalf("Error opening new connection: %v", err) } @@ -250,7 +250,7 @@ func TestStatementError(t *testing.T) { } func TestDriverRowsErrorMessages(t *testing.T) { - db, err := sql.Open("sqlite3", ":memory:") + db, err := sql.Open("turso", ":memory:") if err != nil { t.Fatalf("failed to open database: %v", err) } @@ -285,7 +285,7 @@ func TestDriverRowsErrorMessages(t *testing.T) { func TestTransaction(t *testing.T) { // Open database connection - db, err := sql.Open("sqlite3", ":memory:") + db, err := sql.Open("turso", ":memory:") if err != nil { t.Fatalf("Error opening database: %v", err) } @@ -359,7 +359,7 @@ func TestTransaction(t *testing.T) { } func TestVectorOperations(t *testing.T) { - db, err := sql.Open("sqlite3", ":memory:") + db, err := sql.Open("turso", ":memory:") if err != nil { t.Fatalf("Error opening connection: %v", err) } @@ -397,7 +397,7 @@ func TestVectorOperations(t *testing.T) { } func TestSQLFeatures(t *testing.T) { - db, err := sql.Open("sqlite3", ":memory:") + db, err := sql.Open("turso", ":memory:") if err != nil { t.Fatalf("Error opening connection: %v", err) } @@ -501,7 +501,7 @@ func TestSQLFeatures(t *testing.T) { } func TestDateTimeFunctions(t *testing.T) { - db, err := sql.Open("sqlite3", ":memory:") + db, err := sql.Open("turso", ":memory:") if err != nil { t.Fatalf("Error opening connection: %v", err) } @@ -536,7 +536,7 @@ func TestDateTimeFunctions(t *testing.T) { } func TestMathFunctions(t *testing.T) { - db, err := sql.Open("sqlite3", ":memory:") + db, err := sql.Open("turso", ":memory:") if err != nil { t.Fatalf("Error opening connection: %v", err) } @@ -572,7 +572,7 @@ func TestMathFunctions(t *testing.T) { } func TestJSONFunctions(t *testing.T) { - db, err := sql.Open("sqlite3", 
":memory:") + db, err := sql.Open("turso", ":memory:") if err != nil { t.Fatalf("Error opening connection: %v", err) } @@ -610,7 +610,7 @@ func TestJSONFunctions(t *testing.T) { } func TestParameterOrdering(t *testing.T) { - newConn, err := sql.Open("sqlite3", ":memory:") + newConn, err := sql.Open("turso", ":memory:") if err != nil { t.Fatalf("Error opening new connection: %v", err) } @@ -685,7 +685,7 @@ func TestParameterOrdering(t *testing.T) { } func TestIndex(t *testing.T) { - newConn, err := sql.Open("sqlite3", ":memory:") + newConn, err := sql.Open("turso", ":memory:") if err != nil { t.Fatalf("Error opening new connection: %v", err) } diff --git a/bindings/go/limbo_unix.go b/bindings/go/turso_unix.go similarity index 99% rename from bindings/go/limbo_unix.go rename to bindings/go/turso_unix.go index 1dd51f42e..3e61f278f 100644 --- a/bindings/go/limbo_unix.go +++ b/bindings/go/turso_unix.go @@ -1,6 +1,6 @@ //go:build linux || darwin -package limbo +package turso import ( "fmt" diff --git a/bindings/go/limbo_windows.go b/bindings/go/turso_windows.go similarity index 98% rename from bindings/go/limbo_windows.go rename to bindings/go/turso_windows.go index 2fddfd9a4..3926bedc9 100644 --- a/bindings/go/limbo_windows.go +++ b/bindings/go/turso_windows.go @@ -1,6 +1,6 @@ //go:build windows -package limbo +package turso import ( "fmt" diff --git a/bindings/go/types.go b/bindings/go/types.go index f35899828..1c9f00d62 100644 --- a/bindings/go/types.go +++ b/bindings/go/types.go @@ -1,4 +1,4 @@ -package limbo +package turso import ( "database/sql/driver" @@ -66,8 +66,8 @@ func (rc ResultCode) String() string { } const ( - driverName = "sqlite3" - libName = "lib_limbo_go" + driverName = "turso" + libName = "lib_turso_go" RowsClosedErr = "sql: Rows closed" FfiDbOpen = "db_open" FfiDbClose = "db_close" @@ -98,7 +98,7 @@ func namedValueToValue(named []driver.NamedValue) []driver.Value { return out } -func buildNamedArgs(named []driver.NamedValue) ([]limboValue, func(), error) { +func buildNamedArgs(named []driver.NamedValue) ([]tursoValue, func(), error) { args := namedValueToValue(named) return buildArgs(args) } @@ -131,7 +131,7 @@ func (vt valueType) String() string { } // struct to pass Go values over FFI -type limboValue struct { +type tursoValue struct { Type valueType _ [4]byte Value [8]byte @@ -143,12 +143,12 @@ type Blob struct { Len int64 } -// convert a limboValue to a native Go value +// convert a tursoValue to a native Go value func toGoValue(valPtr uintptr) interface{} { if valPtr == 0 { return nil } - val := (*limboValue)(unsafe.Pointer(valPtr)) + val := (*tursoValue)(unsafe.Pointer(valPtr)) switch val.Type { case intVal: return *(*int64)(unsafe.Pointer(&val.Value)) @@ -228,50 +228,50 @@ func freeCString(cstrPtr uintptr) { freeStringFunc(cstrPtr) } -// convert a Go slice of driver.Value to a slice of limboValue that can be sent over FFI +// convert a Go slice of driver.Value to a slice of tursoValue that can be sent over FFI // for Blob types, we have to pin them so they are not garbage collected before they can be copied // into a buffer on the Rust side, so we return a function to unpin them that can be deferred after this call -func buildArgs(args []driver.Value) ([]limboValue, func(), error) { +func buildArgs(args []driver.Value) ([]tursoValue, func(), error) { pinner := new(runtime.Pinner) - argSlice := make([]limboValue, len(args)) + argSlice := make([]tursoValue, len(args)) for i, v := range args { - limboVal := limboValue{} + tursoVal := tursoValue{} switch val := v.(type) 
{ case nil: - limboVal.Type = nullVal + tursoVal.Type = nullVal case int64: - limboVal.Type = intVal - limboVal.Value = *(*[8]byte)(unsafe.Pointer(&val)) + tursoVal.Type = intVal + tursoVal.Value = *(*[8]byte)(unsafe.Pointer(&val)) case float64: - limboVal.Type = realVal - limboVal.Value = *(*[8]byte)(unsafe.Pointer(&val)) + tursoVal.Type = realVal + tursoVal.Value = *(*[8]byte)(unsafe.Pointer(&val)) case bool: - limboVal.Type = intVal + tursoVal.Type = intVal boolAsInt := int64(0) if val { boolAsInt = 1 } - limboVal.Value = *(*[8]byte)(unsafe.Pointer(&boolAsInt)) + tursoVal.Value = *(*[8]byte)(unsafe.Pointer(&boolAsInt)) case string: - limboVal.Type = textVal + tursoVal.Type = textVal cstr := CString(val) pinner.Pin(cstr) - *(*uintptr)(unsafe.Pointer(&limboVal.Value)) = uintptr(unsafe.Pointer(cstr)) + *(*uintptr)(unsafe.Pointer(&tursoVal.Value)) = uintptr(unsafe.Pointer(cstr)) case []byte: - limboVal.Type = blobVal + tursoVal.Type = blobVal blob := makeBlob(val) pinner.Pin(blob) - *(*uintptr)(unsafe.Pointer(&limboVal.Value)) = uintptr(unsafe.Pointer(blob)) + *(*uintptr)(unsafe.Pointer(&tursoVal.Value)) = uintptr(unsafe.Pointer(blob)) case time.Time: - limboVal.Type = textVal + tursoVal.Type = textVal timeStr := val.Format(time.RFC3339) cstr := CString(timeStr) pinner.Pin(cstr) - *(*uintptr)(unsafe.Pointer(&limboVal.Value)) = uintptr(unsafe.Pointer(cstr)) + *(*uintptr)(unsafe.Pointer(&tursoVal.Value)) = uintptr(unsafe.Pointer(cstr)) default: return nil, pinner.Unpin, fmt.Errorf("unsupported type: %T", v) } - argSlice[i] = limboVal + argSlice[i] = tursoVal } return argSlice, pinner.Unpin, nil } diff --git a/core/Cargo.toml b/core/Cargo.toml index e9f11969a..37c150524 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -77,6 +77,7 @@ bytemuck = "1.23.1" aes-gcm = { version = "0.10.3"} aes = { version = "0.8.4"} turso_parser = { workspace = true } +aegis = "0.9.0" [build-dependencies] chrono = { version = "0.4.38", default-features = false } diff --git a/core/incremental/expr_compiler.rs b/core/incremental/expr_compiler.rs new file mode 100644 index 000000000..f94d72a2a --- /dev/null +++ b/core/incremental/expr_compiler.rs @@ -0,0 +1,457 @@ +// Expression compilation for incremental operators +// This module provides utilities to compile SQL expressions into VDBE subprograms +// that can be executed efficiently in the incremental computation context. + +use crate::schema::Schema; +use crate::storage::pager::Pager; +use crate::translate::emitter::Resolver; +use crate::translate::expr::translate_expr; +use crate::types::Text; +use crate::vdbe::builder::{ProgramBuilder, ProgramBuilderOpts}; +use crate::vdbe::insn::Insn; +use crate::vdbe::{Program, ProgramState, Register}; +use crate::SymbolTable; +use crate::{CaptureDataChangesMode, Connection, QueryMode, Result, Value}; +use std::rc::Rc; +use std::sync::Arc; +use turso_parser::ast::{Expr, Literal, Operator}; + +// Transform an expression to replace column references with Register expressions. Why do we want to +// do this? +// +// Imagine you have a view like: +// +// create materialized view hex(count(*) + 2). translate_expr will usually try to match names +// to either literals or columns. But "count(*)" is not a column in any sqlite table. +// +// We *could* theoretically have a table-representation of every DBSP-step, but it is a lot simpler +// to just pass registers as parameters to the VDBE expression, and teach translate_expr to +// recognize those.
+//
+// But because the expression compiler will not generate those register inputs, we have to +// transform the expression. +fn transform_expr_for_dbsp(expr: &Expr, input_column_names: &[String]) -> Expr { + match expr { + // Transform column references (represented as Id) to Register expressions + Expr::Id(name) => { + // Check if this is a column name from our input + if let Some(idx) = input_column_names + .iter() + .position(|col| col == name.as_str()) + { + // Replace with a Register expression + Expr::Register(idx) + } else { + // Not a column reference, keep as is + expr.clone() + } + } + // Recursively transform nested expressions + Expr::Binary(lhs, op, rhs) => Expr::Binary( + Box::new(transform_expr_for_dbsp(lhs, input_column_names)), + *op, + Box::new(transform_expr_for_dbsp(rhs, input_column_names)), + ), + Expr::Unary(op, operand) => Expr::Unary( + *op, + Box::new(transform_expr_for_dbsp(operand, input_column_names)), + ), + Expr::FunctionCall { + name, + distinctness, + args, + order_by, + filter_over, + } => Expr::FunctionCall { + name: name.clone(), + distinctness: *distinctness, + args: args + .iter() + .map(|arg| Box::new(transform_expr_for_dbsp(arg, input_column_names))) + .collect(), + order_by: order_by.clone(), + filter_over: filter_over.clone(), + }, + Expr::Parenthesized(exprs) => Expr::Parenthesized( + exprs + .iter() + .map(|e| Box::new(transform_expr_for_dbsp(e, input_column_names))) + .collect(), + ), + // For other expression types, keep as is + _ => expr.clone(), + } +} + +/// Enum to represent either a trivial or compiled expression +#[derive(Clone)] +pub enum ExpressionExecutor { + /// Trivial expression that can be evaluated inline + Trivial(TrivialExpression), + /// Compiled VDBE program for complex expressions + Compiled(Arc<Program>), +} + +/// Trivial expression that can be evaluated inline without VDBE +/// Only supports operations where operands have the same type (no coercion) +#[derive(Clone, Debug)] +pub enum TrivialExpression { + /// Direct column reference + Column(usize), + /// Immediate value + Immediate(Value), + /// Binary operation on trivial expressions (same-type operands only) + Binary { + left: Box<TrivialExpression>, + op: Operator, + right: Box<TrivialExpression>, + }, +} + +impl TrivialExpression { + /// Evaluate the trivial expression with the given input values + /// Panics if type mismatch occurs (this indicates a bug in validation) + pub fn evaluate(&self, values: &[Value]) -> Value { + match self { + TrivialExpression::Column(idx) => values.get(*idx).cloned().unwrap_or(Value::Null), + TrivialExpression::Immediate(val) => val.clone(), + TrivialExpression::Binary { left, op, right } => { + let left_val = left.evaluate(values); + let right_val = right.evaluate(values); + + // Only perform operations on same-type operands + match op { + Operator::Add => match (&left_val, &right_val) { + (Value::Integer(a), Value::Integer(b)) => Value::Integer(a + b), + (Value::Float(a), Value::Float(b)) => Value::Float(a + b), + (Value::Null, _) | (_, Value::Null) => Value::Null, + _ => panic!("Type mismatch in trivial expression: {left_val:?} + {right_val:?}. This is a bug in trivial expression validation."), + }, + Operator::Subtract => match (&left_val, &right_val) { + (Value::Integer(a), Value::Integer(b)) => Value::Integer(a - b), + (Value::Float(a), Value::Float(b)) => Value::Float(a - b), + (Value::Null, _) | (_, Value::Null) => Value::Null, + _ => panic!("Type mismatch in trivial expression: {left_val:?} - {right_val:?}. 
This is a bug in trivial expression validation."), + }, + Operator::Multiply => match (&left_val, &right_val) { + (Value::Integer(a), Value::Integer(b)) => Value::Integer(a * b), + (Value::Float(a), Value::Float(b)) => Value::Float(a * b), + (Value::Null, _) | (_, Value::Null) => Value::Null, + _ => panic!("Type mismatch in trivial expression: {left_val:?} * {right_val:?}. This is a bug in trivial expression validation."), + }, + Operator::Divide => match (&left_val, &right_val) { + (Value::Integer(a), Value::Integer(b)) => { + if *b != 0 { + Value::Integer(a / b) + } else { + Value::Null + } + } + (Value::Float(a), Value::Float(b)) => { + if *b != 0.0 { + Value::Float(a / b) + } else { + Value::Null + } + } + (Value::Null, _) | (_, Value::Null) => Value::Null, + _ => panic!("Type mismatch in trivial expression: {left_val:?} / {right_val:?}. This is a bug in trivial expression validation."), + }, + _ => panic!("Unsupported operator in trivial expression: {op:?}"), + } + } + } + } +} + +/// Compiled expression that can be executed on row values +#[derive(Clone)] +pub struct CompiledExpression { + /// The expression executor (trivial or compiled) + pub executor: ExpressionExecutor, + /// Number of input values expected (columns from the row) + pub input_count: usize, +} + +impl std::fmt::Debug for CompiledExpression { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut s = f.debug_struct("CompiledExpression"); + s.field("input_count", &self.input_count); + match &self.executor { + ExpressionExecutor::Trivial(t) => s.field("executor", &format!("Trivial({t:?})")), + ExpressionExecutor::Compiled(p) => { + s.field("executor", &format!("Compiled({} insns)", p.insns.len())) + } + }; + s.finish() + } +} + +#[derive(PartialEq)] +enum TrivialType { + Integer, + Float, + Text, + Null, +} + +impl CompiledExpression { + /// Get the "type" of a trivial expression for type checking + /// Returns None if type can't be determined statically + fn get_trivial_type(expr: &TrivialExpression) -> Option<TrivialType> { + match expr { + TrivialExpression::Column(_) => None, // Can't know column type statically + TrivialExpression::Immediate(val) => match val { + Value::Integer(_) => Some(TrivialType::Integer), + Value::Float(_) => Some(TrivialType::Float), + Value::Text(_) => Some(TrivialType::Text), + Value::Null => Some(TrivialType::Null), + _ => None, + }, + TrivialExpression::Binary { left, right, ..
 } => { + // For binary ops, both sides must have the same type + let left_type = Self::get_trivial_type(left)?; + let right_type = Self::get_trivial_type(right)?; + if left_type == right_type { + Some(left_type) + } else { + None // Type mismatch + } + } + } + } + + // Validates if an expression is trivial (columns, immediates, and simple arithmetic) + // Only considers expressions trivial if they don't require type coercion + fn try_get_trivial_expr( + expr: &Expr, + input_column_names: &[String], + ) -> Option<TrivialExpression> { + match expr { + // Column reference or register + Expr::Id(name) => input_column_names + .iter() + .position(|col| col == name.as_str()) + .map(TrivialExpression::Column), + Expr::Register(idx) => Some(TrivialExpression::Column(*idx)), + + // Immediate values + Expr::Literal(lit) => { + let value = match lit { + Literal::Numeric(n) => { + if let Ok(i) = n.parse::<i64>() { + Value::Integer(i) + } else if let Ok(f) = n.parse::<f64>() { + Value::Float(f) + } else { + return None; + } + } + Literal::String(s) => { + let cleaned = s.trim_matches('\'').trim_matches('"'); + Value::Text(Text::new(cleaned)) + } + Literal::Null => Value::Null, + _ => return None, + }; + Some(TrivialExpression::Immediate(value)) + } + + // Binary operations with simple operators + Expr::Binary(left, op, right) => { + // Only support simple arithmetic operators + match op { + Operator::Add | Operator::Subtract | Operator::Multiply | Operator::Divide => { + // Both operands must be trivial + let left_trivial = Self::try_get_trivial_expr(left, input_column_names)?; + let right_trivial = Self::try_get_trivial_expr(right, input_column_names)?; + + // Check if we can determine types statically + // If both are immediates, they must have the same type + // If either is a column, we can't validate at compile time, + // but we'll assert at runtime if there's a mismatch + if let (Some(left_type), Some(right_type)) = ( + Self::get_trivial_type(&left_trivial), + Self::get_trivial_type(&right_trivial), + ) { + // Both types are known - they must match (or one is null) + if left_type != right_type + && left_type != TrivialType::Null + && right_type != TrivialType::Null + { + return None; // Type mismatch - not trivial + } + } + // If we can't determine types (columns involved), we optimistically + // assume they'll match at runtime (and assert if they don't) + + Some(TrivialExpression::Binary { + left: Box::new(left_trivial), + op: *op, + right: Box::new(right_trivial), + }) + } + _ => None, + } + } + + // Parenthesized expressions with single element + Expr::Parenthesized(exprs) if exprs.len() == 1 => { + Self::try_get_trivial_expr(&exprs[0], input_column_names) + } + + _ => None, + } + } + + /// Compile a SQL expression into either a trivial executor or VDBE program + /// + /// For trivial expressions (columns, immediates, simple same-type arithmetic), uses inline evaluation. + /// For complex expressions or those requiring type coercion, compiles to VDBE bytecode. 
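// Editor's note, not part of the original patch: with a hypothetical input column
// list ["a"], an expression like `a + 1` or `a * 2` satisfies the rules above
// (column references, immediates, same-type arithmetic) and stays on the inline
// path, while `hex(a)` or `a + 1.5` (a function call, INTEGER/REAL coercion)
// fails try_get_trivial_expr and is compiled to a VDBE subprogram instead.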
+ pub fn compile( + expr: &Expr, + input_column_names: &[String], + schema: &Schema, + syms: &SymbolTable, + connection: Arc, + ) -> Result { + let input_count = input_column_names.len(); + + // First, check if this is a trivial expression + if let Some(trivial) = Self::try_get_trivial_expr(expr, input_column_names) { + return Ok(CompiledExpression { + executor: ExpressionExecutor::Trivial(trivial), + input_count, + }); + } + + // Fall back to VDBE compilation for complex expressions + // Create a minimal program builder for expression compilation + let mut builder = ProgramBuilder::new( + QueryMode::Normal, + CaptureDataChangesMode::Off, + ProgramBuilderOpts { + num_cursors: 0, + approx_num_insns: 5, // Most expressions are simple + approx_num_labels: 0, // Expressions don't need labels + }, + ); + + // Allocate registers for input values + let input_count = input_column_names.len(); + + // Allocate input registers + for _ in 0..input_count { + builder.alloc_register(); + } + + // Allocate a temp register for computation + let temp_result_register = builder.alloc_register(); + + // Transform the expression to replace column references with Register expressions + let transformed_expr = transform_expr_for_dbsp(expr, input_column_names); + + // Create a resolver for translate_expr + let resolver = Resolver::new(schema, syms); + + // Translate the transformed expression to bytecode + translate_expr( + &mut builder, + None, // No table references needed for pure expressions + &transformed_expr, + temp_result_register, + &resolver, + )?; + + // Copy the result to register 0 for return + builder.emit_insn(Insn::Copy { + src_reg: temp_result_register, + dst_reg: 0, + extra_amount: 0, + }); + + // Add a Halt instruction to complete the subprogram + builder.emit_insn(Insn::Halt { + err_code: 0, + description: String::new(), + }); + + // Build the program from the compiled expression bytecode + let program = Arc::new(builder.build(connection, false, "")); + + Ok(CompiledExpression { + executor: ExpressionExecutor::Compiled(program), + input_count, + }) + } + + /// Execute the compiled expression with the given input values + pub fn execute(&self, values: &[Value], pager: Rc) -> Result { + match &self.executor { + ExpressionExecutor::Trivial(trivial) => { + // Fast path: evaluate trivial expression inline + Ok(trivial.evaluate(values)) + } + ExpressionExecutor::Compiled(program) => { + // Slow path: execute VDBE program + // Create a state with the input values loaded into registers + let mut state = ProgramState::new(program.max_registers, 0); + + // Load input values into registers + assert_eq!( + values.len(), + self.input_count, + "Mismatch in number of registers! Got {}, expected {}", + values.len(), + self.input_count + ); + for (idx, value) in values.iter().enumerate() { + state.set_register(idx, Register::Value(value.clone())); + } + + // Execute the program + let mut pc = 0usize; + while pc < program.insns.len() { + let (insn, insn_fn) = &program.insns[pc]; + state.pc = pc as u32; + + // Execute the instruction + match insn_fn(program, &mut state, insn, &pager, None)? 
+                        crate::vdbe::execute::InsnFunctionStepResult::IO(_) => {
+                            return Err(crate::LimboError::InternalError(
+                                "Expression evaluation encountered unexpected I/O".to_string(),
+                            ));
+                        }
+                        crate::vdbe::execute::InsnFunctionStepResult::Done => {
+                            break;
+                        }
+                        crate::vdbe::execute::InsnFunctionStepResult::Row => {
+                            return Err(crate::LimboError::InternalError(
+                                "Expression evaluation produced unexpected row".to_string(),
+                            ));
+                        }
+                        crate::vdbe::execute::InsnFunctionStepResult::Interrupt => {
+                            return Err(crate::LimboError::InternalError(
+                                "Expression evaluation was interrupted".to_string(),
+                            ));
+                        }
+                        crate::vdbe::execute::InsnFunctionStepResult::Busy => {
+                            return Err(crate::LimboError::InternalError(
+                                "Expression evaluation encountered busy state".to_string(),
+                            ));
+                        }
+                        crate::vdbe::execute::InsnFunctionStepResult::Step => {
+                            pc = state.pc as usize;
+                        }
+                    }
+                }
+
+                // The compiled expression puts the result in register 0
+                match state.get_register(0) {
+                    Register::Value(v) => Ok(v.clone()),
+                    _ => Ok(Value::Null),
+                }
+            }
+        }
+    }
+}
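A usage sketch of the two paths above. The driver function, the column layout, and the pre-parsed `a * 2` expression are hypothetical; `CompiledExpression::compile`/`execute` are the entry points added by this diff:

```rust
use std::{rc::Rc, sync::Arc};

// Hypothetical caller: `expr` is a parsed `a * 2`, and the schema, symbol
// table, connection, and pager come from the surrounding context (as in
// ProjectOperator::from_select further down).
fn eval_a_times_two(
    expr: &turso_parser::ast::Expr,
    schema: &Schema,
    syms: &SymbolTable,
    conn: Arc<Connection>,
    pager: Rc<Pager>,
) -> crate::Result<Value> {
    let cols = vec!["a".to_string(), "b".to_string()];
    let compiled = CompiledExpression::compile(expr, &cols, schema, syms, conn)?;
    // `a * 2` multiplies a column by an integer immediate: both operands are
    // trivial, so this takes the inline fast path; type agreement is checked
    // at runtime rather than proven at compile time.
    compiled.execute(&[Value::Integer(10), Value::Integer(3)], pager)
}
```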
diff --git a/core/incremental/mod.rs b/core/incremental/mod.rs
index d80a09081..328f1a510 100644
--- a/core/incremental/mod.rs
+++ b/core/incremental/mod.rs
@@ -1,4 +1,5 @@
 pub mod dbsp;
+pub mod expr_compiler;
 pub mod hashable_row;
 pub mod operator;
 pub mod view;
diff --git a/core/incremental/operator.rs b/core/incremental/operator.rs
index 740044988..0391e3c0a 100644
--- a/core/incremental/operator.rs
+++ b/core/incremental/operator.rs
@@ -2,9 +2,10 @@
 // Operator DAG for DBSP-style incremental computation
 // Based on Feldera DBSP design but adapted for Turso's architecture
 
+use crate::incremental::expr_compiler::CompiledExpression;
 use crate::incremental::hashable_row::HashableRow;
 use crate::types::Text;
-use crate::Value;
+use crate::{Connection, Database, SymbolTable, Value};
 use std::collections::{HashMap, HashSet};
 use std::fmt::{self, Debug, Display};
 use std::sync::Arc;
@@ -342,14 +343,13 @@ impl FilterPredicate {
 }
 
 #[derive(Debug, Clone)]
-pub enum ProjectColumn {
-    /// Direct column reference
-    Column(String),
-    /// Computed expression
-    Expression {
-        expr: Box<turso_parser::ast::Expr>,
-        alias: Option<String>,
-    },
+pub struct ProjectColumn {
+    /// The original SQL expression (for debugging/fallback)
+    pub expr: turso_parser::ast::Expr,
+    /// Optional alias for the column
+    pub alias: Option<String>,
+    /// Compiled expression (handles both trivial columns and complex expressions)
+    pub compiled: CompiledExpression,
 }
 
 #[derive(Debug, Clone)]
@@ -584,34 +584,153 @@ impl IncrementalOperator for FilterOperator {
 }
 
 /// Project operator - selects/transforms columns
-#[derive(Debug, Clone)]
+#[derive(Clone)]
 pub struct ProjectOperator {
     columns: Vec<ProjectColumn>,
     input_column_names: Vec<String>,
     output_column_names: Vec<String>,
     current_state: Delta,
    tracker: Option<Arc<Mutex<ComputationTracker>>>,
+    // Internal in-memory connection for expression evaluation.
+    // Programs are very dependent on having a connection, so give it one.
+    //
+    // We could in theory pass the current connection, but there are a host of problems with that.
+    // For example: during a write transaction, which is when views are usually updated, we have
+    // autocommit on. When the program we are executing calls Halt, it will try to commit the
+    // current transaction, which is absolutely incorrect.
+    //
+    // There are other ways to solve this, but a read-only connection to an empty in-memory
+    // database gives us the closest environment we need to execute expressions.
+    internal_conn: Arc<Connection>,
+}
+
+impl std::fmt::Debug for ProjectOperator {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("ProjectOperator")
+            .field("columns", &self.columns)
+            .field("input_column_names", &self.input_column_names)
+            .field("output_column_names", &self.output_column_names)
+            .field("current_state", &self.current_state)
+            .field("tracker", &self.tracker)
+            .finish_non_exhaustive()
+    }
 }
 
 impl ProjectOperator {
-    pub fn new(columns: Vec<ProjectColumn>, input_column_names: Vec<String>) -> Self {
+    /// Create a new ProjectOperator from a SELECT statement, extracting the projection columns
+    pub fn from_select(
+        select: &turso_parser::ast::Select,
+        input_column_names: Vec<String>,
+        schema: &crate::schema::Schema,
+    ) -> crate::Result<Self> {
+        use turso_parser::ast::*;
+
+        // Set up an internal connection for expression evaluation
+        let io = Arc::new(crate::MemoryIO::new());
+        let db = Database::open_file(
+            io, ":memory:", false, // no MVCC needed for expression evaluation
+            false, // no indexes needed
+        )?;
+        let internal_conn = db.connect()?;
+        // Set to read-only mode and disable auto-commit since we're only evaluating expressions
+        internal_conn.query_only.set(true);
+        internal_conn.auto_commit.set(false);
+
+        let temp_syms = SymbolTable::new();
+
+        // Extract columns from the SELECT statement
+        let columns = if let OneSelect::Select {
+            columns: ref select_columns,
+            ..
+        } = &select.body.select
+        {
+            let mut columns = Vec::new();
+            for result_col in select_columns {
+                match result_col {
+                    ResultColumn::Expr(expr, alias) => {
+                        let alias_str = if let Some(As::As(alias_name)) = alias {
+                            Some(alias_name.as_str().to_string())
+                        } else {
+                            None
+                        };
+                        // Try to compile the expression (handles both columns and complex expressions)
+                        let compiled = CompiledExpression::compile(
+                            expr,
+                            &input_column_names,
+                            schema,
+                            &temp_syms,
+                            internal_conn.clone(),
+                        )?;
+                        columns.push(ProjectColumn {
+                            expr: (**expr).clone(),
+                            alias: alias_str,
+                            compiled,
+                        });
+                    }
+                    ResultColumn::Star => {
+                        // Select all columns - create trivial column references
+                        for name in &input_column_names {
+                            // Create an Id expression for the column
+                            let expr = Expr::Id(Name::Ident(name.clone()));
+                            let compiled = CompiledExpression::compile(
+                                &expr,
+                                &input_column_names,
+                                schema,
+                                &temp_syms,
+                                internal_conn.clone(),
+                            )?;
+                            columns.push(ProjectColumn {
+                                expr,
+                                alias: None,
+                                compiled,
+                            });
+                        }
+                    }
+                    x => {
+                        return Err(crate::LimboError::ParseError(format!(
+                            "Unsupported {x:?} clause when compiling project operator",
+                        )));
+                    }
+                }
+            }
+
+            if columns.is_empty() {
+                return Err(crate::LimboError::ParseError(
+                    "No columns found when compiling project operator".to_string(),
+                ));
+            }
+            columns
+        } else {
+            return Err(crate::LimboError::ParseError(
+                "Expression is not a valid SELECT expression".to_string(),
+            ));
+        };
+
+        // Generate output column names based on aliases or expressions
         let output_column_names = columns
             .iter()
-            .map(|c| match c {
-                ProjectColumn::Column(name) => name.clone(),
-                ProjectColumn::Expression { alias, .. } => {
-                    alias.clone().unwrap_or_else(|| "expr".to_string())
-                }
+            .map(|c| {
+                c.alias.clone().unwrap_or_else(|| match &c.expr {
+                    Expr::Id(name) => name.as_str().to_string(),
+                    Expr::Qualified(table, column) => {
+                        format!("{}.{}", table.as_str(), column.as_str())
+                    }
+                    Expr::DoublyQualified(db, table, column) => {
+                        format!("{}.{}.{}", db.as_str(), table.as_str(), column.as_str())
+                    }
+                    _ => c.expr.to_string(),
+                })
             })
             .collect();
 
-        Self {
+        Ok(Self {
             columns,
             input_column_names,
             output_column_names,
             current_state: Delta::new(),
             tracker: None,
-        }
+            internal_conn,
+        })
     }
 
     /// Get the columns for this projection
@@ -623,24 +742,15 @@ impl ProjectOperator {
         let mut output = Vec::new();
 
         for col in &self.columns {
-            match col {
-                ProjectColumn::Column(name) => {
-                    if let Some(idx) = self.input_column_names.iter().position(|c| c == name) {
-                        if let Some(v) = values.get(idx) {
-                            output.push(v.clone());
-                        } else {
-                            output.push(Value::Null);
-                        }
-                    } else {
-                        output.push(Value::Null);
-                    }
-                }
-                ProjectColumn::Expression { expr, .. } => {
-                    // Evaluate the expression
-                    let result = self.evaluate_expression(expr, values);
-                    output.push(result);
-                }
-            }
+            // Use the internal connection's pager for expression evaluation
+            let internal_pager = self.internal_conn.pager.borrow().clone();
+
+            // Execute the compiled expression (handles both columns and complex expressions)
+            let result = col
+                .compiled
+                .execute(values, internal_pager)
+                .expect("Failed to execute compiled expression for the Project operator");
+            output.push(result);
         }
 
         output
@@ -648,7 +758,6 @@ impl ProjectOperator {
 
     fn evaluate_expression(&self, expr: &turso_parser::ast::Expr, values: &[Value]) -> Value {
         use turso_parser::ast::*;
-
         match expr {
             Expr::Id(name) => {
                 if let Some(idx) = self
diff --git a/core/incremental/view.rs b/core/incremental/view.rs
index 3db09dea2..4f4d4c6e6 100644
--- a/core/incremental/view.rs
+++ b/core/incremental/view.rs
@@ -1,7 +1,7 @@
 use super::dbsp::{RowKeyStream, RowKeyZSet};
 use super::operator::{
     AggregateFunction, AggregateOperator, ComputationTracker, Delta, FilterOperator,
-    FilterPredicate, IncrementalOperator, ProjectColumn, ProjectOperator,
+    FilterPredicate, IncrementalOperator, ProjectOperator,
 };
 use crate::schema::{BTreeTable, Column, Schema};
 use crate::types::{IOCompletions, IOResult, Value};
@@ -96,10 +96,7 @@ pub struct IncrementalView {
 impl IncrementalView {
     /// Validate that a CREATE MATERIALIZED VIEW statement can be handled by IncrementalView.
     /// This should be called early, before updating sqlite_master.
-    pub fn can_create_view(select: &ast::Select, schema: &Schema) -> Result<()> {
-        // Check for aggregations
-        let (group_by_columns, aggregate_functions, _) = Self::extract_aggregation_info(select);
-
+    pub fn can_create_view(select: &ast::Select) -> Result<()> {
         // Check for JOINs
         let (join_tables, join_condition) = Self::extract_join_info(select);
         if join_tables.is_some() || join_condition.is_some() {
@@ -108,29 +105,6 @@ impl IncrementalView {
             ));
         }
 
-        // Check that we have a base table
-        let base_table_name = Self::extract_base_table(select).ok_or_else(|| {
-            LimboError::ParseError("views without a base table not supported yet".to_string())
-        })?;
-
-        // Get the base table
-        let base_table = schema.get_btree_table(&base_table_name).ok_or_else(|| {
-            LimboError::ParseError(format!("Table '{base_table_name}' not found in schema"))
-        })?;
-
-        // Get base table column names for validation
-        let base_table_column_names: Vec<String> = base_table
-            .columns
-            .iter()
-            .enumerate()
-            .map(|(i, col)| col.name.clone().unwrap_or_else(|| format!("column_{i}")))
-            .collect();
-
-        // For non-aggregated views, validate columns are a strict subset
-        if group_by_columns.is_empty() && aggregate_functions.is_empty() {
-            Self::validate_view_columns(select, &base_table_column_names)?;
-        }
-
         Ok(())
     }
 
@@ -242,6 +216,7 @@ impl IncrementalView {
             view_columns,
             group_by_columns,
             aggregate_functions,
+            schema,
         )
     }
 
@@ -256,6 +231,7 @@ impl IncrementalView {
        columns: Vec<Column>,
        group_by_columns: Vec<String>,
        aggregate_functions: Vec<AggregateFunction>,
+        schema: &Schema,
     ) -> Result<Self> {
         let mut records = BTreeMap::new();
 
@@ -302,15 +278,11 @@ impl IncrementalView {
         // Only create project operator for non-aggregated views
         let project_operator = if !is_aggregated {
-            let columns = Self::extract_project_columns(&select_stmt, &base_table_column_names)
-                .unwrap_or_else(|| {
-                    // If we can't extract columns, default to projecting all columns
-                    base_table_column_names
-                        .iter()
-                        .map(|name| ProjectColumn::Column(name.to_string()))
-                        .collect()
-                });
-            let mut proj_op = ProjectOperator::new(columns, base_table_column_names.clone());
+            let mut proj_op = ProjectOperator::from_select(
+                &select_stmt,
+                base_table_column_names.clone(),
+                schema,
+            )?;
             proj_op.set_tracker(tracker.clone());
             Some(proj_op)
         } else {
@@ -347,52 +319,6 @@ impl IncrementalView {
         vec![self.base_table.clone()]
     }
 
-    /// Validate that view columns are a strict subset of the base table columns:
-    /// no duplicates, no complex expressions, only simple column references.
-    fn validate_view_columns(
-        select: &ast::Select,
-        base_table_column_names: &[String],
-    ) -> Result<()> {
-        if let ast::OneSelect::Select { ref columns, .. } = select.body.select {
-            let mut seen_columns = std::collections::HashSet::new();
-
-            for result_col in columns {
-                match result_col {
-                    ast::ResultColumn::Expr(expr, _)
-                        if matches!(expr.as_ref(), ast::Expr::Id(_)) =>
-                    {
-                        let ast::Expr::Id(name) = expr.as_ref() else {
-                            unreachable!()
-                        };
-                        let col_name = name.as_str();
-
-                        // Check for duplicates
-                        if !seen_columns.insert(col_name) {
-                            return Err(LimboError::ParseError(format!(
-                                "Duplicate column '{col_name}' in view. Views must have columns as a strict subset of the base table (no duplicates)"
-                            )));
-                        }
-
-                        // Check that column exists in base table
-                        if !base_table_column_names.iter().any(|n| n == col_name) {
-                            return Err(LimboError::ParseError(format!(
-                                "Column '{col_name}' not found in base table. Views must have columns as a strict subset of the base table"
-                            )));
-                        }
-                    }
-                    ast::ResultColumn::Star => {
-                        // SELECT * is allowed - it's the full set
-                    }
-                    _ => {
-                        // Any other expression is not allowed
-                        return Err(LimboError::ParseError("Complex expressions, functions, or computed columns are not supported in views. Views must have columns as a strict subset of the base table".to_string()));
-                    }
-                }
-            }
-        }
-        Ok(())
-    }
-
     /// Extract the base table name from a SELECT statement (for non-join cases)
     fn extract_base_table(select: &ast::Select) -> Option<String> {
         if let ast::OneSelect::Select {
@@ -417,16 +343,14 @@ impl IncrementalView {
         // Get the columns used by the projection operator
         let mut columns = Vec::new();
         for col in project_op.columns() {
-            match col {
-                ProjectColumn::Column(name) => {
-                    columns.push(name.clone());
-                }
-                ProjectColumn::Expression { .. } => {
-                    // For expressions, we need all columns (for now)
-                    columns.clear();
-                    columns.push("*".to_string());
-                    break;
-                }
+            // Check if it's a simple column reference
+            if let turso_parser::ast::Expr::Id(name) = &col.expr {
+                columns.push(name.as_str().to_string());
+            } else {
+                // For expressions, we need all columns (for now)
+                columns.clear();
+                columns.push("*".to_string());
+                break;
             }
         }
         if columns.is_empty() || columns.contains(&"*".to_string()) {
@@ -808,62 +732,6 @@ impl IncrementalView {
         None
     }
 
-    /// Extract projection columns from SELECT statement
-    fn extract_project_columns(
-        select: &ast::Select,
-        column_names: &[String],
-    ) -> Option<Vec<ProjectColumn>> {
-        use turso_parser::ast::*;
-
-        if let OneSelect::Select {
-            columns: ref select_columns,
-            ..
-        } = select.body.select
-        {
-            let mut columns = Vec::new();
-
-            for result_col in select_columns {
-                match result_col {
-                    ResultColumn::Expr(expr, alias) => {
-                        match expr.as_ref() {
-                            Expr::Id(name) => {
-                                // Simple column reference
-                                columns.push(ProjectColumn::Column(name.as_str().to_string()));
-                            }
-                            _ => {
-                                // Expression - store it for evaluation
-                                let alias_str = if let Some(As::As(alias_name)) = alias {
-                                    Some(alias_name.as_str().to_string())
-                                } else {
-                                    None
-                                };
-                                columns.push(ProjectColumn::Expression {
-                                    expr: expr.clone(),
-                                    alias: alias_str,
-                                });
-                            }
-                        }
-                    }
-                    ResultColumn::Star => {
-                        // Select all columns
-                        for name in column_names {
-                            columns.push(ProjectColumn::Column(name.as_str().to_string()));
-                        }
-                    }
-                    _ => {
-                        // For now, skip TableStar and other cases
-                    }
-                }
-            }
-
-            if !columns.is_empty() {
-                return Some(columns);
-            }
-        }
-
-        None
-    }
-
     /// Get the current records as an iterator - for cursor-based access
     pub fn iter(&self) -> impl Iterator<Item = (i64, Vec<Value>)> + '_ {
         self.stream.to_vec().into_iter().filter_map(move |row| {
@@ -927,6 +795,12 @@ impl IncrementalView {
         // Apply operators in pipeline
         let mut current_delta = delta.clone();
         current_delta = self.apply_filter_to_delta(current_delta);
+
+        // Apply the projection operator if present (for non-aggregated views)
+        if let Some(ref mut project_op) = self.project_operator {
+            current_delta = project_op.process_delta(current_delta);
+        }
+
         current_delta = self.apply_aggregation_to_delta(current_delta);
 
         // Update records and stream with the processed delta
@@ -1083,7 +957,7 @@ mod tests {
     #[test]
     fn test_projection_function_call() {
         let schema = create_test_schema();
-        let sql = "CREATE MATERIALIZED VIEW v AS SELECT hex(a) as hex_a, b FROM t";
+        let sql = "CREATE MATERIALIZED VIEW v AS SELECT abs(a - 300) as abs_diff, b FROM t";
 
         let view = IncrementalView::from_sql(sql, &schema).unwrap();
 
@@ -1101,10 +975,8 @@ mod tests {
         let result = temp_project.get_current_state();
         let (output, _weight) = result.changes.first().unwrap();
 
-        assert_eq!(
-            output.values,
-            vec![Value::Text("FF".into()), Value::Integer(20),]
-        );
+        // abs(255 - 300) = abs(-45) = 45
+        assert_eq!(output.values, vec![Value::Integer(45), Value::Integer(20),]);
     }
 
     #[test]
@@ -1214,12 +1086,12 @@ mod tests {
         assert_eq!(
             output.values,
             vec![
-                Value::Integer(5),       // a
-                Value::Integer(2),       // b
-                Value::Integer(10),      // a * 2
-                Value::Integer(6),       // b * 3
-                Value::Integer(7),       // a + b
-                Value::Text("F".into()), // hex(15)
+                Value::Integer(5),          // a
+                Value::Integer(2),          // b
+                Value::Integer(10),         // a * 2
+                Value::Integer(6),          // b * 3
+                Value::Integer(7),          // a + b
+                Value::Text("3135".into()), // hex(15) - SQLite converts 15 to the string "15", then hex-encodes it
             ]
         );
     }
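To make the new delta flow concrete, here is a toy model of the operator order the view now applies (filter, then projection for non-aggregated views, then aggregation); `Row` and the closures are stand-ins for the real Delta and operator types:

```rust
type Row = Vec<i64>;

fn apply_pipeline(
    delta: Vec<Row>,
    keep: impl Fn(&Row) -> bool,
    project: impl Fn(&Row) -> Row,
) -> Vec<Row> {
    delta
        .into_iter()
        .filter(|row| keep(row))  // FilterOperator
        .map(|row| project(&row)) // ProjectOperator::process_delta
        .collect()                // aggregation would fold here
}

// SELECT abs(a - 300) AS abs_diff, b FROM t over the delta row [255, 20]
// yields [45, 20], matching the updated test above.
fn example() -> Vec<Row> {
    apply_pipeline(
        vec![vec![255, 20]],
        |_| true,
        |row| vec![(row[0] - 300).abs(), row[1]],
    )
}
```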
diff --git a/core/io/generic.rs b/core/io/generic.rs
index 9a87c6c63..83caa1405 100644
--- a/core/io/generic.rs
+++ b/core/io/generic.rs
@@ -1,24 +1,20 @@
-use super::MemoryIO;
-use crate::{Clock, Completion, CompletionType, File, Instant, LimboError, OpenFlags, Result, IO};
-use std::cell::RefCell;
+use crate::{Clock, Completion, File, Instant, LimboError, OpenFlags, Result, IO};
+use parking_lot::RwLock;
 use std::io::{Read, Seek, Write};
 use std::sync::Arc;
-use tracing::{debug, trace};
-
+use tracing::{debug, instrument, trace, Level};
 pub struct GenericIO {}
 
 impl GenericIO {
     pub fn new() -> Result<Self> {
-        debug!("Using IO backend 'generic'");
+        debug!("Using IO backend 'syscall'");
         Ok(Self {})
     }
 }
 
-unsafe impl Send for GenericIO {}
-unsafe impl Sync for GenericIO {}
-
 impl IO for GenericIO {
-    fn open_file(&self, path: &str, flags: OpenFlags, _direct: bool) -> Result<Arc<dyn File>> {
+    #[instrument(err, skip_all, level = Level::TRACE)]
+    fn open_file(&self, path: &str, flags: OpenFlags, direct: bool) -> Result<Arc<dyn File>> {
         trace!("open_file(path = {})", path);
         let mut file = std::fs::File::options();
         file.read(true);
@@ -30,17 +26,17 @@ impl IO for GenericIO {
         let file = file.open(path)?;
 
         Ok(Arc::new(GenericFile {
-            file: RefCell::new(file),
-            memory_io: Arc::new(MemoryIO::new()),
+            file: RwLock::new(file),
         }))
     }
-
+
+    #[instrument(err, skip_all, level = Level::TRACE)]
     fn remove_file(&self, path: &str) -> Result<()> {
         trace!("remove_file(path = {})", path);
-        std::fs::remove_file(path)?;
-        Ok(())
+        Ok(std::fs::remove_file(path)?)
     }
 
+    #[instrument(err, skip_all, level = Level::TRACE)]
     fn run_once(&self) -> Result<()> {
         Ok(())
     }
@@ -57,68 +53,63 @@ impl Clock for GenericIO {
 }
 
 pub struct GenericFile {
-    file: RefCell<std::fs::File>,
-    memory_io: Arc<MemoryIO>,
+    file: RwLock<std::fs::File>,
 }
 
-unsafe impl Send for GenericFile {}
-unsafe impl Sync for GenericFile {}
-
 impl File for GenericFile {
-    // Since we let the OS handle the locking, file locking is not supported on the generic IO implementation.
-    // The no-op implementation allows compilation but provides no actual file locking.
-    fn lock_file(&self, _exclusive: bool) -> Result<()> {
-        Ok(())
+    #[instrument(err, skip_all, level = Level::TRACE)]
+    fn lock_file(&self, exclusive: bool) -> Result<()> {
+        unimplemented!()
     }
 
+    #[instrument(err, skip_all, level = Level::TRACE)]
     fn unlock_file(&self) -> Result<()> {
-        Ok(())
+        unimplemented!()
     }
 
+    #[instrument(skip(self, c), level = Level::TRACE)]
     fn pread(&self, pos: usize, c: Completion) -> Result<Completion> {
-        let mut file = self.file.borrow_mut();
+        let mut file = self.file.write();
         file.seek(std::io::SeekFrom::Start(pos as u64))?;
-        {
+        let nr = {
             let r = c.as_read();
-            let mut buf = r.buf();
+            let buf = r.buf();
             let buf = buf.as_mut_slice();
             file.read_exact(buf)?;
-        }
-        c.complete(0);
+            buf.len() as i32
+        };
+        c.complete(nr);
         Ok(c)
     }
 
+    #[instrument(skip(self, c, buffer), level = Level::TRACE)]
     fn pwrite(&self, pos: usize, buffer: Arc<Buffer>, c: Completion) -> Result<Completion> {
-        let mut file = self.file.borrow_mut();
+        let mut file = self.file.write();
         file.seek(std::io::SeekFrom::Start(pos as u64))?;
         let buf = buffer.as_slice();
         file.write_all(buf)?;
-        c.complete(buf.len() as i32);
+        c.complete(buffer.len() as i32);
         Ok(c)
     }
 
+    #[instrument(err, skip_all, level = Level::TRACE)]
     fn sync(&self, c: Completion) -> Result<Completion> {
-        let mut file = self.file.borrow_mut();
+        let file = self.file.write();
         file.sync_all()?;
         c.complete(0);
         Ok(c)
     }
 
+    #[instrument(err, skip_all, level = Level::TRACE)]
     fn truncate(&self, len: usize, c: Completion) -> Result<Completion> {
-        let mut file = self.file.borrow_mut();
+        let file = self.file.write();
         file.set_len(len as u64)?;
         c.complete(0);
         Ok(c)
     }
 
     fn size(&self) -> Result<u64> {
-        let file = self.file.borrow();
+        let file = self.file.read();
         Ok(file.metadata().unwrap().len())
     }
 }
-
-impl Drop for GenericFile {
-    fn drop(&mut self) {
-        self.unlock_file().expect("Failed to unlock file");
-    }
-}
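One side effect of the RefCell-to-RwLock switch above is that the hand-written `unsafe impl Send/Sync` blocks could be deleted; a standalone sketch of the compile-time reasoning (using parking_lot, as the diff does):

```rust
use parking_lot::RwLock;
use std::sync::Arc;

struct SharedFile {
    file: RwLock<std::fs::File>,
}

// Compile-time witness: File is Send + Sync and RwLock preserves that, so
// SharedFile (and Arc<SharedFile>) crosses threads without any unsafe impls.
fn assert_send_sync<T: Send + Sync>() {}

fn witness() {
    assert_send_sync::<Arc<SharedFile>>();
}
```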
diff --git a/core/io/mod.rs b/core/io/mod.rs
index f376b5a3c..992eabac0 100644
--- a/core/io/mod.rs
+++ b/core/io/mod.rs
@@ -506,21 +506,7 @@ cfg_block! {
         pub use PlatformIO as SyscallIO;
     }
 
-    #[cfg(any(target_os = "android", target_os = "ios"))] {
-        mod unix;
-        #[cfg(feature = "fs")]
-        pub use unix::UnixIO;
-        pub use unix::UnixIO as SyscallIO;
-        pub use unix::UnixIO as PlatformIO;
-    }
-
-    #[cfg(target_os = "windows")] {
-        mod windows;
-        pub use windows::WindowsIO as PlatformIO;
-        pub use PlatformIO as SyscallIO;
-    }
-
-    #[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows", target_os = "android", target_os = "ios")))] {
+    #[cfg(not(any(target_family = "unix", target_os = "android", target_os = "ios")))] {
         mod generic;
         pub use generic::GenericIO as PlatformIO;
         pub use PlatformIO as SyscallIO;
diff --git a/core/io/unix.rs b/core/io/unix.rs
index e10f7f3ec..b7567b683 100644
--- a/core/io/unix.rs
+++ b/core/io/unix.rs
@@ -259,12 +259,32 @@ impl File for UnixFile {
     #[instrument(err, skip_all, level = Level::TRACE)]
     fn sync(&self, c: Completion) -> Result<Completion> {
         let file = self.file.lock();
-        let result = unsafe { libc::fsync(file.as_raw_fd()) };
+
+        let result = unsafe {
+
+            #[cfg(not(any(target_os = "macos", target_os = "ios")))]
+            {
+                libc::fsync(file.as_raw_fd())
+            }
+
+            #[cfg(any(target_os = "macos", target_os = "ios"))]
+            {
+                libc::fcntl(file.as_raw_fd(), libc::F_FULLFSYNC)
+            }
+
+        };
+
         if result == -1 {
             let e = std::io::Error::last_os_error();
             Err(e.into())
         } else {
+
+            #[cfg(not(any(target_os = "macos", target_os = "ios")))]
             trace!("fsync");
+
+            #[cfg(any(target_os = "macos", target_os = "ios"))]
+            trace!("fcntl(F_FULLFSYNC)");
+
             c.complete(0);
             Ok(c)
         }
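The platform split above exists because on macOS and iOS `fsync(2)` does not guarantee the data has reached stable storage, while `fcntl(F_FULLFSYNC)` does. A self-contained sketch of the same dispatch (unix-only, assumes the `libc` crate):

```rust
use std::os::fd::AsRawFd;

fn full_sync(file: &std::fs::File) -> std::io::Result<()> {
    let rc = unsafe {
        // macOS/iOS: ask the drive to flush its own cache too.
        #[cfg(any(target_os = "macos", target_os = "ios"))]
        {
            libc::fcntl(file.as_raw_fd(), libc::F_FULLFSYNC)
        }
        // Elsewhere, plain fsync is the durability barrier.
        #[cfg(not(any(target_os = "macos", target_os = "ios")))]
        {
            libc::fsync(file.as_raw_fd())
        }
    };
    if rc == -1 {
        Err(std::io::Error::last_os_error())
    } else {
        Ok(())
    }
}
```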
diff --git a/core/io/windows.rs b/core/io/windows.rs
deleted file mode 100644
index acb12b344..000000000
--- a/core/io/windows.rs
+++ /dev/null
@@ -1,115 +0,0 @@
-use crate::{Clock, Completion, File, Instant, LimboError, OpenFlags, Result, IO};
-use parking_lot::RwLock;
-use std::io::{Read, Seek, Write};
-use std::sync::Arc;
-use tracing::{debug, instrument, trace, Level};
-pub struct WindowsIO {}
-
-impl WindowsIO {
-    pub fn new() -> Result<Self> {
-        debug!("Using IO backend 'syscall'");
-        Ok(Self {})
-    }
-}
-
-impl IO for WindowsIO {
-    #[instrument(err, skip_all, level = Level::TRACE)]
-    fn open_file(&self, path: &str, flags: OpenFlags, direct: bool) -> Result<Arc<dyn File>> {
-        trace!("open_file(path = {})", path);
-        let mut file = std::fs::File::options();
-        file.read(true);
-
-        if !flags.contains(OpenFlags::ReadOnly) {
-            file.write(true);
-            file.create(flags.contains(OpenFlags::Create));
-        }
-
-        let file = file.open(path)?;
-        Ok(Arc::new(WindowsFile {
-            file: RwLock::new(file),
-        }))
-    }
-
-    #[instrument(err, skip_all, level = Level::TRACE)]
-    fn remove_file(&self, path: &str) -> Result<()> {
-        trace!("remove_file(path = {})", path);
-        Ok(std::fs::remove_file(path)?)
-    }
-
-    #[instrument(err, skip_all, level = Level::TRACE)]
-    fn run_once(&self) -> Result<()> {
-        Ok(())
-    }
-}
-
-impl Clock for WindowsIO {
-    fn now(&self) -> Instant {
-        let now = chrono::Local::now();
-        Instant {
-            secs: now.timestamp(),
-            micros: now.timestamp_subsec_micros(),
-        }
-    }
-}
-
-pub struct WindowsFile {
-    file: RwLock<std::fs::File>,
-}
-
-impl File for WindowsFile {
-    #[instrument(err, skip_all, level = Level::TRACE)]
-    fn lock_file(&self, exclusive: bool) -> Result<()> {
-        unimplemented!()
-    }
-
-    #[instrument(err, skip_all, level = Level::TRACE)]
-    fn unlock_file(&self) -> Result<()> {
-        unimplemented!()
-    }
-
-    #[instrument(skip(self, c), level = Level::TRACE)]
-    fn pread(&self, pos: usize, c: Completion) -> Result<Completion> {
-        let mut file = self.file.write();
-        file.seek(std::io::SeekFrom::Start(pos as u64))?;
-        let nr = {
-            let r = c.as_read();
-            let buf = r.buf();
-            let buf = buf.as_mut_slice();
-            file.read_exact(buf)?;
-            buf.len() as i32
-        };
-        c.complete(nr);
-        Ok(c)
-    }
-
-    #[instrument(skip(self, c, buffer), level = Level::TRACE)]
-    fn pwrite(&self, pos: usize, buffer: Arc<Buffer>, c: Completion) -> Result<Completion> {
-        let mut file = self.file.write();
-        file.seek(std::io::SeekFrom::Start(pos as u64))?;
-        let buf = buffer.as_slice();
-        file.write_all(buf)?;
-        c.complete(buffer.len() as i32);
-        Ok(c)
-    }
-
-    #[instrument(err, skip_all, level = Level::TRACE)]
-    fn sync(&self, c: Completion) -> Result<Completion> {
-        let file = self.file.write();
-        file.sync_all()?;
-        c.complete(0);
-        Ok(c)
-    }
-
-    #[instrument(err, skip_all, level = Level::TRACE)]
-    fn truncate(&self, len: usize, c: Completion) -> Result<Completion> {
-        let file = self.file.write();
-        file.set_len(len as u64)?;
-        c.complete(0);
-        Ok(c)
-    }
-
-    fn size(&self) -> Result<u64> {
-        let file = self.file.read();
-        Ok(file.metadata().unwrap().len())
-    }
-}
diff --git a/core/json/jsonb.rs b/core/json/jsonb.rs
index cb4028abd..c4c95aac0 100644
--- a/core/json/jsonb.rs
+++ b/core/json/jsonb.rs
@@ -909,12 +909,13 @@ impl Jsonb {
         }
     }
 
-    pub fn to_string(&self) -> Result<String> {
+    #[expect(clippy::inherent_to_string)]
+    pub fn to_string(&self) -> String {
         let mut result = String::with_capacity(self.data.len() * 2);
 
-        self.write_to_string(&mut result, JsonIndentation::None)?;
+        self.write_to_string(&mut result, JsonIndentation::None);
 
-        Ok(result)
+        result
     }
 
     pub fn to_string_pretty(&self, indentation: Option<&str>) -> Result<String> {
@@ -924,16 +925,15 @@ impl Jsonb {
         } else {
             JsonIndentation::Indentation(Cow::Borrowed("    "))
         };
-        self.write_to_string(&mut result, ind)?;
+        self.write_to_string(&mut result, ind);
         Ok(result)
     }
 
-    fn write_to_string(&self, string: &mut String, indentation: JsonIndentation) -> Result<()> {
+    fn write_to_string(&self, string: &mut String, indentation: JsonIndentation) {
         let cursor = 0;
         let ind = indentation;
         let _ = self.serialize_value(string, cursor, 0, &ind);
-        Ok(())
     }
 
     fn serialize_value(
@@ -3093,7 +3093,7 @@ mod tests {
         jsonb.data.push(ElementType::NULL as u8);
 
         // Test serialization
-        let json_str = jsonb.to_string().unwrap();
+        let json_str = jsonb.to_string();
         assert_eq!(json_str, "null");
 
         // Test round-trip
@@ -3106,12 +3106,12 @@ mod tests {
         // True
         let mut jsonb_true = Jsonb::new(10, None);
         jsonb_true.data.push(ElementType::TRUE as u8);
-        assert_eq!(jsonb_true.to_string().unwrap(), "true");
+        assert_eq!(jsonb_true.to_string(), "true");
 
         // False
         let mut jsonb_false = Jsonb::new(10, None);
         jsonb_false.data.push(ElementType::FALSE as u8);
-        assert_eq!(jsonb_false.to_string().unwrap(), "false");
+        assert_eq!(jsonb_false.to_string(), "false");
 
         // Round-trip
         let true_parsed = Jsonb::from_str("true").unwrap();
@@ -3125,15 +3125,15 @@ fn test_integer_serialization() {
         // Standard integer
         let parsed = Jsonb::from_str("42").unwrap();
-        assert_eq!(parsed.to_string().unwrap(), "42");
+        assert_eq!(parsed.to_string(), "42");
 
         // Negative integer
         let parsed = Jsonb::from_str("-123").unwrap();
-        assert_eq!(parsed.to_string().unwrap(), "-123");
+        assert_eq!(parsed.to_string(), "-123");
 
         // Zero
         let parsed = Jsonb::from_str("0").unwrap();
-        assert_eq!(parsed.to_string().unwrap(), "0");
+        assert_eq!(parsed.to_string(), "0");
 
         // Verify correct type
         let header = JsonbHeader::from_slice(0, &parsed.data).unwrap().0;
@@ -3144,15 +3144,15 @@ fn test_json5_integer_serialization() {
         // Hexadecimal notation
         let parsed = Jsonb::from_str("0x1A").unwrap();
-        assert_eq!(parsed.to_string().unwrap(), "26"); // Should convert to decimal
+        assert_eq!(parsed.to_string(), "26"); // Should convert to decimal
 
         // Positive sign (JSON5)
         let parsed = Jsonb::from_str("+42").unwrap();
-        assert_eq!(parsed.to_string().unwrap(), "42");
+        assert_eq!(parsed.to_string(), "42");
 
         // Negative hexadecimal
         let parsed = Jsonb::from_str("-0xFF").unwrap();
-        assert_eq!(parsed.to_string().unwrap(), "-255");
+        assert_eq!(parsed.to_string(), "-255");
 
         // Verify correct type
         let header = JsonbHeader::from_slice(0, &parsed.data).unwrap().0;
@@ -3163,15 +3163,15 @@ fn test_float_serialization() {
         // Standard float
         let parsed = Jsonb::from_str("3.14159").unwrap();
-        assert_eq!(parsed.to_string().unwrap(), "3.14159");
+        assert_eq!(parsed.to_string(), "3.14159");
 
         // Negative float
         let parsed = Jsonb::from_str("-2.718").unwrap();
-        assert_eq!(parsed.to_string().unwrap(), "-2.718");
+        assert_eq!(parsed.to_string(), "-2.718");
 
         // Scientific notation
         let parsed = Jsonb::from_str("6.022e23").unwrap();
-        assert_eq!(parsed.to_string().unwrap(), "6.022e23");
+        assert_eq!(parsed.to_string(), "6.022e23");
 
         // Verify correct type
         let header = JsonbHeader::from_slice(0, &parsed.data).unwrap().0;
@@ -3182,23 +3182,23 @@ fn test_json5_float_serialization() {
         // Leading decimal point
         let parsed = Jsonb::from_str(".123").unwrap();
-        assert_eq!(parsed.to_string().unwrap(), "0.123");
+        assert_eq!(parsed.to_string(), "0.123");
 
         // Trailing decimal point
         let parsed = Jsonb::from_str("42.").unwrap();
-        assert_eq!(parsed.to_string().unwrap(), "42.0");
+        assert_eq!(parsed.to_string(), "42.0");
 
         // Plus sign in exponent
         let parsed = Jsonb::from_str("1.5e+10").unwrap();
-        assert_eq!(parsed.to_string().unwrap(), "1.5e+10");
+        assert_eq!(parsed.to_string(), "1.5e+10");
 
         // Infinity
         let parsed = Jsonb::from_str("Infinity").unwrap();
-        assert_eq!(parsed.to_string().unwrap(), "9e999");
+        assert_eq!(parsed.to_string(), "9e999");
 
         // Negative Infinity
         let parsed = Jsonb::from_str("-Infinity").unwrap();
-        assert_eq!(parsed.to_string().unwrap(), "-9e999");
+        assert_eq!(parsed.to_string(), "-9e999");
 
         // Verify correct type
         let header = JsonbHeader::from_slice(0, &parsed.data).unwrap().0;
@@ -3209,15 +3209,15 @@ fn test_string_serialization() {
         // Simple string
         let parsed = Jsonb::from_str(r#""hello world""#).unwrap();
-        assert_eq!(parsed.to_string().unwrap(), r#""hello world""#);
+        assert_eq!(parsed.to_string(), r#""hello world""#);
 
         // String with escaped characters
         let parsed = Jsonb::from_str(r#""hello\nworld""#).unwrap();
-        assert_eq!(parsed.to_string().unwrap(), r#""hello\nworld""#);
+        assert_eq!(parsed.to_string(), r#""hello\nworld""#);
 
         // Unicode escape
Jsonb::from_str(r#""hello\u0020world""#).unwrap(); - assert_eq!(parsed.to_string().unwrap(), r#""hello\u0020world""#); + assert_eq!(parsed.to_string(), r#""hello\u0020world""#); // Verify correct type let header = JsonbHeader::from_slice(0, &parsed.data).unwrap().0; @@ -3228,11 +3228,11 @@ mod tests { fn test_json5_string_serialization() { // Single quotes let parsed = Jsonb::from_str("'hello world'").unwrap(); - assert_eq!(parsed.to_string().unwrap(), r#""hello world""#); + assert_eq!(parsed.to_string(), r#""hello world""#); // Hex escape let parsed = Jsonb::from_str(r#"'\x41\x42\x43'"#).unwrap(); - assert_eq!(parsed.to_string().unwrap(), r#""\u0041\u0042\u0043""#); + assert_eq!(parsed.to_string(), r#""\u0041\u0042\u0043""#); // Multiline string with line continuation let parsed = Jsonb::from_str( @@ -3240,11 +3240,11 @@ mod tests { world""#, ) .unwrap(); - assert_eq!(parsed.to_string().unwrap(), r#""hello world""#); + assert_eq!(parsed.to_string(), r#""hello world""#); // Escaped single quote let parsed = Jsonb::from_str(r#"'Don\'t worry'"#).unwrap(); - assert_eq!(parsed.to_string().unwrap(), r#""Don't worry""#); + assert_eq!(parsed.to_string(), r#""Don't worry""#); // Verify correct type let header = JsonbHeader::from_slice(0, &parsed.data).unwrap().0; @@ -3255,20 +3255,20 @@ world""#, fn test_array_serialization() { // Empty array let parsed = Jsonb::from_str("[]").unwrap(); - assert_eq!(parsed.to_string().unwrap(), "[]"); + assert_eq!(parsed.to_string(), "[]"); // Simple array let parsed = Jsonb::from_str("[1,2,3]").unwrap(); - assert_eq!(parsed.to_string().unwrap(), "[1,2,3]"); + assert_eq!(parsed.to_string(), "[1,2,3]"); // Nested array let parsed = Jsonb::from_str("[[1,2],[3,4]]").unwrap(); - assert_eq!(parsed.to_string().unwrap(), "[[1,2],[3,4]]"); + assert_eq!(parsed.to_string(), "[[1,2],[3,4]]"); // Mixed types array let parsed = Jsonb::from_str(r#"[1,"text",true,null,{"key":"value"}]"#).unwrap(); assert_eq!( - parsed.to_string().unwrap(), + parsed.to_string(), r#"[1,"text",true,null,{"key":"value"}]"# ); @@ -3281,44 +3281,41 @@ world""#, fn test_json5_array_serialization() { // Trailing comma let parsed = Jsonb::from_str("[1,2,3,]").unwrap(); - assert_eq!(parsed.to_string().unwrap(), "[1,2,3]"); + assert_eq!(parsed.to_string(), "[1,2,3]"); // Comments in array let parsed = Jsonb::from_str("[1,/* comment */2,3]").unwrap(); - assert_eq!(parsed.to_string().unwrap(), "[1,2,3]"); + assert_eq!(parsed.to_string(), "[1,2,3]"); // Line comment in array let parsed = Jsonb::from_str("[1,// line comment\n2,3]").unwrap(); - assert_eq!(parsed.to_string().unwrap(), "[1,2,3]"); + assert_eq!(parsed.to_string(), "[1,2,3]"); } #[test] fn test_object_serialization() { // Empty object let parsed = Jsonb::from_str("{}").unwrap(); - assert_eq!(parsed.to_string().unwrap(), "{}"); + assert_eq!(parsed.to_string(), "{}"); // Simple object let parsed = Jsonb::from_str(r#"{"key":"value"}"#).unwrap(); - assert_eq!(parsed.to_string().unwrap(), r#"{"key":"value"}"#); + assert_eq!(parsed.to_string(), r#"{"key":"value"}"#); // Multiple properties let parsed = Jsonb::from_str(r#"{"a":1,"b":2,"c":3}"#).unwrap(); - assert_eq!(parsed.to_string().unwrap(), r#"{"a":1,"b":2,"c":3}"#); + assert_eq!(parsed.to_string(), r#"{"a":1,"b":2,"c":3}"#); // Nested object let parsed = Jsonb::from_str(r#"{"outer":{"inner":"value"}}"#).unwrap(); - assert_eq!( - parsed.to_string().unwrap(), - r#"{"outer":{"inner":"value"}}"# - ); + assert_eq!(parsed.to_string(), r#"{"outer":{"inner":"value"}}"#); // Mixed values let parsed = 
Jsonb::from_str(r#"{"str":"text","num":42,"bool":true,"null":null,"arr":[1,2]}"#) .unwrap(); assert_eq!( - parsed.to_string().unwrap(), + parsed.to_string(), r#"{"str":"text","num":42,"bool":true,"null":null,"arr":[1,2]}"# ); @@ -3331,19 +3328,19 @@ world""#, fn test_json5_object_serialization() { // Unquoted keys let parsed = Jsonb::from_str("{key:\"value\"}").unwrap(); - assert_eq!(parsed.to_string().unwrap(), r#"{"key":"value"}"#); + assert_eq!(parsed.to_string(), r#"{"key":"value"}"#); // Trailing comma let parsed = Jsonb::from_str(r#"{"a":1,"b":2,}"#).unwrap(); - assert_eq!(parsed.to_string().unwrap(), r#"{"a":1,"b":2}"#); + assert_eq!(parsed.to_string(), r#"{"a":1,"b":2}"#); // Comments in object let parsed = Jsonb::from_str(r#"{"a":1,/*comment*/"b":2}"#).unwrap(); - assert_eq!(parsed.to_string().unwrap(), r#"{"a":1,"b":2}"#); + assert_eq!(parsed.to_string(), r#"{"a":1,"b":2}"#); // Single quotes for keys and values let parsed = Jsonb::from_str("{'a':'value'}").unwrap(); - assert_eq!(parsed.to_string().unwrap(), r#"{"a":"value"}"#); + assert_eq!(parsed.to_string(), r#"{"a":"value"}"#); } #[test] @@ -3366,8 +3363,8 @@ world""#, let parsed = Jsonb::from_str(complex_json).unwrap(); // Round-trip test - let reparsed = Jsonb::from_str(&parsed.to_string().unwrap()).unwrap(); - assert_eq!(parsed.to_string().unwrap(), reparsed.to_string().unwrap()); + let reparsed = Jsonb::from_str(&parsed.to_string()).unwrap(); + assert_eq!(parsed.to_string(), reparsed.to_string()); } #[test] @@ -3486,11 +3483,11 @@ world""#, fn test_unicode_escapes() { // Basic unicode escape let parsed = Jsonb::from_str(r#""\u00A9""#).unwrap(); // Copyright symbol - assert_eq!(parsed.to_string().unwrap(), r#""\u00A9""#); + assert_eq!(parsed.to_string(), r#""\u00A9""#); // Non-BMP character (surrogate pair) let parsed = Jsonb::from_str(r#""\uD83D\uDE00""#).unwrap(); // Smiley emoji - assert_eq!(parsed.to_string().unwrap(), r#""\uD83D\uDE00""#); + assert_eq!(parsed.to_string(), r#""\uD83D\uDE00""#); } #[test] @@ -3503,7 +3500,7 @@ world""#, }"#, ) .unwrap(); - assert_eq!(parsed.to_string().unwrap(), r#"{"key":"value"}"#); + assert_eq!(parsed.to_string(), r#"{"key":"value"}"#); // Block comments let parsed = Jsonb::from_str( @@ -3514,7 +3511,7 @@ world""#, }"#, ) .unwrap(); - assert_eq!(parsed.to_string().unwrap(), r#"{"key":"value"}"#); + assert_eq!(parsed.to_string(), r#"{"key":"value"}"#); // Comments inside array let parsed = Jsonb::from_str( @@ -3522,7 +3519,7 @@ world""#, 2, /* Another comment */ 3]"#, ) .unwrap(); - assert_eq!(parsed.to_string().unwrap(), "[1,2,3]"); + assert_eq!(parsed.to_string(), "[1,2,3]"); } #[test] @@ -3540,7 +3537,7 @@ world""#, let parsed = Jsonb::from_str(json_with_whitespace).unwrap(); assert_eq!( - parsed.to_string().unwrap(), + parsed.to_string(), r#"{"key1":"value1","key2":[1,2,3],"key3":{"nested":true}}"# ); } @@ -3554,7 +3551,7 @@ world""#, // Create a new Jsonb from the binary data let from_binary = Jsonb::new(0, Some(&binary_data)); - assert_eq!(from_binary.to_string().unwrap(), original); + assert_eq!(from_binary.to_string(), original); } #[test] @@ -3570,8 +3567,8 @@ world""#, large_array.push(']'); let parsed = Jsonb::from_str(&large_array).unwrap(); - assert!(parsed.to_string().unwrap().starts_with("[0,1,2,")); - assert!(parsed.to_string().unwrap().ends_with("998,999]")); + assert!(parsed.to_string().starts_with("[0,1,2,")); + assert!(parsed.to_string().ends_with("998,999]")); } #[test] @@ -3600,7 +3597,7 @@ world""#, }"#; let parsed = Jsonb::from_str(json).unwrap(); - let result 
+        let result = parsed.to_string();
 
         assert!(result.contains(r#""escaped_quotes":"He said \"Hello\"""#));
         assert!(result.contains(r#""backslashes":"C:\\Windows\\System32""#));
@@ -3767,7 +3764,7 @@ mod path_operations_tests {
         assert!(result.is_ok());
 
         // Verify the value was updated
-        let updated_json = jsonb.to_string().unwrap();
+        let updated_json = jsonb.to_string();
         assert_eq!(updated_json, r#"{"name":"Jane","age":30}"#);
     }
 
@@ -3791,7 +3788,7 @@ mod path_operations_tests {
         assert!(result.is_ok());
 
         // Verify the value was inserted
-        let updated_json = jsonb.to_string().unwrap();
+        let updated_json = jsonb.to_string();
         assert_eq!(updated_json, r#"{"name":"John","age":30}"#);
     }
 
@@ -3814,7 +3811,7 @@ mod path_operations_tests {
         assert!(result.is_ok());
 
         // Verify the property was deleted
-        let updated_json = jsonb.to_string().unwrap();
+        let updated_json = jsonb.to_string();
         assert_eq!(updated_json, r#"{"name":"John"}"#);
     }
 
@@ -3839,7 +3836,7 @@ mod path_operations_tests {
         assert!(result.is_ok());
 
         // Verify the value was replaced
-        let updated_json = jsonb.to_string().unwrap();
+        let updated_json = jsonb.to_string();
         assert_eq!(updated_json, r#"{"items":[10,50,30]}"#);
     }
 
@@ -3863,7 +3860,7 @@ mod path_operations_tests {
         // Get the search result
         let search_result = operation.result();
-        let result_str = search_result.to_string().unwrap();
+        let result_str = search_result.to_string();
 
         // Verify the search found the correct value
         assert_eq!(result_str, r#"{"name":"John","age":30}"#);
@@ -3912,7 +3909,7 @@ mod path_operations_tests {
         assert!(result.is_ok());
 
         // Verify the deep value was updated
-        let updated_json = jsonb.to_string().unwrap();
+        let updated_json = jsonb.to_string();
         assert_eq!(
             updated_json,
             r#"{"level1":{"level2":{"level3":{"value":100}}}}"#
@@ -3953,7 +3950,7 @@ mod path_operations_tests {
         let result = jsonb.operate_on_path(&path, &mut operation);
         assert!(result.is_ok());
 
-        let updated_json = jsonb.to_string().unwrap();
+        let updated_json = jsonb.to_string();
         assert_eq!(updated_json, r#"{"name":"John","age":30}"#);
 
         // 3. InsertNew mode - should fail when path already exists
@@ -3991,7 +3988,7 @@ mod path_operations_tests {
         let result = jsonb.operate_on_path(&path, &mut operation);
         assert!(result.is_ok());
 
-        let updated_json = jsonb.to_string().unwrap();
+        let updated_json = jsonb.to_string();
         assert_eq!(updated_json, r#"{"name":"John","age":31,"surname":"Doe"}"#);
     }
 }
diff --git a/core/json/mod.rs b/core/json/mod.rs
index 5e8aeadef..caa1b28a0 100644
--- a/core/json/mod.rs
+++ b/core/json/mod.rs
@@ -44,7 +44,7 @@ pub fn get_json(json_value: &Value, indent: Option<&str>) -> crate::Result<Value> {
         let json = match indent {
             Some(indent) => json_val.to_string_pretty(Some(indent))?,
-            None => json_val.to_string()?,
+            None => json_val.to_string(),
         };
 
         Ok(Value::Text(Text::json(json)))
@@ -53,7 +53,7 @@ pub fn get_json(json_value: &Value, indent: Option<&str>) -> crate::Result<Value> {
 ) -> crate::Result<Value> {
-    let mut json_string = json.to_string()?;
+    let mut json_string = json.to_string();
     if matches!(flag, OutputVariant::Binary) {
         return Ok(Value::Blob(json.data()));
     }
diff --git a/core/lib.rs b/core/lib.rs
index 5ba3a4faa..553076375 100644
--- a/core/lib.rs
+++ b/core/lib.rs
@@ -41,6 +41,7 @@ pub mod numeric;
 mod numeric;
 use crate::incremental::view::ViewTransactionState;
+use crate::storage::encryption::CipherMode;
 use crate::translate::optimizer::optimize_plan;
 use crate::translate::pragma::TURSO_CDC_DEFAULT_TABLE_NAME;
 #[cfg(all(feature = "fs", feature = "conn_raw_api"))]
@@ -71,11 +72,11 @@ use std::{
     num::NonZero,
     ops::Deref,
     rc::Rc,
-    sync::{Arc, LazyLock, Mutex, Weak},
+    sync::{atomic::AtomicUsize, Arc, LazyLock, Mutex, Weak},
 };
 #[cfg(feature = "fs")]
 use storage::database::DatabaseFile;
-pub use storage::encryption::{EncryptionKey, EncryptionContext};
+pub use storage::encryption::{EncryptionContext, EncryptionKey};
 use storage::page_cache::DumbLruPageCache;
 use storage::pager::{AtomicDbState, DbState};
 use storage::sqlite3_ondisk::PageSize;
@@ -137,6 +138,7 @@ pub struct Database {
     open_flags: OpenFlags,
     builtin_syms: RefCell<SymbolTable>,
     experimental_views: bool,
+    n_connections: AtomicUsize,
 }
 
 unsafe impl Send for Database {}
@@ -185,6 +187,12 @@ impl fmt::Debug for Database {
         };
         debug_struct.field("page_cache", &cache_info);
 
+        debug_struct.field(
+            "n_connections",
+            &self
+                .n_connections
+                .load(std::sync::atomic::Ordering::Relaxed),
+        );
         debug_struct.finish()
     }
 }
@@ -372,6 +380,7 @@ impl Database {
             init_lock: Arc::new(Mutex::new(())),
             experimental_views: enable_views,
             buffer_pool: BufferPool::begin_init(&io, arena_size),
+            n_connections: AtomicUsize::new(0),
         });
         db.register_global_builtin_extensions()
             .expect("unable to register global extensions");
@@ -455,7 +464,10 @@ impl Database {
             metrics: RefCell::new(ConnectionMetrics::new()),
             is_nested_stmt: Cell::new(false),
             encryption_key: RefCell::new(None),
+            encryption_cipher_mode: Cell::new(None),
         });
+        self.n_connections
+            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
         let builtin_syms = self.builtin_syms.borrow();
         // add built-in extensions symbols to the connection to prevent having to load each time
         conn.syms.borrow_mut().extend(&builtin_syms);
@@ -886,6 +898,18 @@ pub struct Connection {
     /// Generally this is only true for ParseSchema.
     is_nested_stmt: Cell<bool>,
     encryption_key: RefCell<Option<EncryptionKey>>,
+    encryption_cipher_mode: Cell<Option<CipherMode>>,
+}
+
+impl Drop for Connection {
+    fn drop(&mut self) {
+        if !self.closed.get() {
+            // if the connection wasn't properly closed, decrement the connection counter
+            self._db
+                .n_connections
+                .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
+        }
+    }
 }
 
 impl Connection {
@@ -1500,16 +1524,23 @@ impl Connection {
                 pager.end_tx(
                     true, // rollback = true for close
                     self,
-                    self.wal_auto_checkpoint_disabled.get(),
                 )
             })?;
             self.transaction_state.set(TransactionState::None);
         }
     }
 
-        self.pager
-            .borrow()
-            .checkpoint_shutdown(self.wal_auto_checkpoint_disabled.get())
+        if self
+            ._db
+            .n_connections
+            .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
+            .eq(&1)
+        {
+            self.pager
+                .borrow()
+                .checkpoint_shutdown(self.wal_auto_checkpoint_disabled.get())?;
+        };
+
+        Ok(())
     }
 
     pub fn wal_auto_checkpoint_disable(&self) {
@@ -1958,8 +1989,31 @@ impl Connection {
     pub fn set_encryption_key(&self, key: EncryptionKey) {
         tracing::trace!("setting encryption key for connection");
         *self.encryption_key.borrow_mut() = Some(key.clone());
+        self.set_encryption_context();
+    }
+
+    pub fn set_encryption_cipher(&self, cipher_mode: CipherMode) {
+        tracing::trace!("setting encryption cipher for connection");
+        self.encryption_cipher_mode.replace(Some(cipher_mode));
+        self.set_encryption_context();
+    }
+
+    pub fn get_encryption_cipher_mode(&self) -> Option<CipherMode> {
+        self.encryption_cipher_mode.get()
+    }
+
+    // if both key and cipher are set, set the encryption context on the pager
+    fn set_encryption_context(&self) {
+        let key_ref = self.encryption_key.borrow();
+        let Some(key) = key_ref.as_ref() else {
+            return;
+        };
+        let Some(cipher_mode) = self.encryption_cipher_mode.get() else {
+            return;
+        };
+        tracing::trace!("setting encryption ctx for connection");
         let pager = self.pager.borrow();
-        pager.set_encryption_context(&key);
+        pager.set_encryption_context(cipher_mode, key);
     }
 }
@@ -2106,7 +2160,7 @@ impl Statement {
         }
         let state = self.program.connection.transaction_state.get();
         if let TransactionState::Write { .. } = state {
-            let end_tx_res = self.pager.end_tx(true, &self.program.connection, true)?;
+            let end_tx_res = self.pager.end_tx(true, &self.program.connection)?;
             self.program
                 .connection
                 .transaction_state
@@ -2166,10 +2220,18 @@ impl Statement {
         self.program.parameters.count()
     }
 
+    pub fn parameter_index(&self, name: &str) -> Option<NonZero<usize>> {
+        self.program.parameters.index(name)
+    }
+
     pub fn bind_at(&mut self, index: NonZero<usize>, value: Value) {
         self.state.bind_at(index, value);
     }
 
+    pub fn clear_bindings(&mut self) {
+        self.state.clear_bindings();
+    }
+
     pub fn reset(&mut self) {
         self.state.reset();
     }
diff --git a/core/mvcc/database/mod.rs b/core/mvcc/database/mod.rs
index bb1a0608e..13e28c437 100644
--- a/core/mvcc/database/mod.rs
+++ b/core/mvcc/database/mod.rs
@@ -546,7 +546,6 @@ impl StateTransition for CommitStateMachine {
             .end_tx(
                 false, // rollback = false since we're committing
                 &self.connection,
-                self.connection.wal_auto_checkpoint_disabled.get(),
             )
             .map_err(|e| LimboError::InternalError(e.to_string()))
             .unwrap();
diff --git a/core/pragma.rs b/core/pragma.rs
index f1a77fa66..e006963c0 100644
--- a/core/pragma.rs
+++ b/core/pragma.rs
@@ -109,7 +109,11 @@ pub fn pragma_for(pragma: &PragmaName) -> Pragma {
         FreelistCount => Pragma::new(PragmaFlags::Result0, &["freelist_count"]),
         EncryptionKey => Pragma::new(
             PragmaFlags::Result0 | PragmaFlags::SchemaReq | PragmaFlags::NoColumns1,
-            &["key"],
+            &["hexkey"],
+        ),
+        EncryptionCipher => Pragma::new(
+            PragmaFlags::Result0 | PragmaFlags::SchemaReq | PragmaFlags::NoColumns1,
+            &["cipher"],
         ),
     }
 }
diff --git a/core/storage/btree.rs b/core/storage/btree.rs
index ca5624220..49f4c06ef 100644
--- a/core/storage/btree.rs
+++ b/core/storage/btree.rs
@@ -6,9 +6,10 @@ use crate::{
     storage::{
         pager::{BtreePageAllocMode, Pager},
         sqlite3_ondisk::{
-            read_u32, read_varint, BTreeCell, DatabaseHeader, PageContent, PageSize, PageType,
-            TableInteriorCell, TableLeafCell, CELL_PTR_SIZE_BYTES, INTERIOR_PAGE_HEADER_SIZE_BYTES,
-            LEAF_PAGE_HEADER_SIZE_BYTES, LEFT_CHILD_PTR_SIZE_BYTES,
+            payload_overflows, read_u32, read_varint, BTreeCell, DatabaseHeader, PageContent,
+            PageSize, PageType, TableInteriorCell, TableLeafCell, CELL_PTR_SIZE_BYTES,
+            INTERIOR_PAGE_HEADER_SIZE_BYTES, LEAF_PAGE_HEADER_SIZE_BYTES,
+            LEFT_CHILD_PTR_SIZE_BYTES,
         },
         state_machines::{
             AdvanceState, CountState, EmptyTableState, MoveToRightState, MoveToState, RewindState,
@@ -37,7 +38,6 @@ use super::{
         write_varint_to_vec, IndexInteriorCell, IndexLeafCell, OverflowCell, MINIMUM_CELL_SIZE,
     },
 };
-#[cfg(debug_assertions)]
 use std::collections::HashSet;
 use std::{
     cell::{Cell, Ref, RefCell},
@@ -2274,7 +2274,7 @@ impl BTreeCursor {
                 ref mut fill_cell_payload_state,
             } => {
                 return_if_io!(fill_cell_payload(
-                    page.get().get().contents.as_ref().unwrap(),
+                    page.get(),
                     bkey.maybe_rowid(),
                     new_payload,
                     *cell_idx,
@@ -4321,6 +4321,9 @@ impl BTreeCursor {
                 return Ok(IOResult::Done(None));
             }
         }
+        if self.get_null_flag() {
+            return Ok(IOResult::Done(None));
+        }
         if self.has_record.get() {
             let page = self.stack.top();
             let page = page.get();
@@ -5176,10 +5179,9 @@ impl BTreeCursor {
                 fill_cell_payload_state,
             } => {
                 let page = page_ref.get();
-                let page_contents = page.get().contents.as_ref().unwrap();
                 {
                     return_if_io!(fill_cell_payload(
-                        page_contents,
+                        page,
                         *rowid,
                         new_payload,
                         cell_idx,
@@ -5502,6 +5504,11 @@ pub enum IntegrityCheckError {
         got: usize,
         expected: usize,
     },
+    #[error("Page {page_id} referenced multiple times")]
+    PageReferencedMultipleTimes {
+        page_id: usize,
+        is_overflow_page: bool,
+    },
 }
 
 #[derive(Clone)]
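The new `PageReferencedMultipleTimes` error backs a classic visited-set check: every page should be reachable from exactly one parent or overflow pointer. A minimal sketch of the idea, with page ids reduced to plain `usize` values:

```rust
use std::collections::HashSet;

fn duplicated_refs(refs: &[usize]) -> Vec<usize> {
    let mut seen = HashSet::new();
    let mut duplicated = Vec::new();
    for &page_id in refs {
        // `insert` returns false if the id was already present, i.e. the
        // page is referenced by more than one parent or overflow chain.
        if !seen.insert(page_id) {
            duplicated.push(page_id);
        }
    }
    duplicated
}
```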
@@ -5514,6 +5521,7 @@ pub struct IntegrityCheckState {
     pub current_page: usize,
     page_stack: Vec<IntegrityCheckPageEntry>,
     first_leaf_level: Option<usize>,
+    pages_referenced: HashSet<usize>,
     page: Option<PageRef>,
 }
 
@@ -5526,6 +5534,7 @@ impl IntegrityCheckState {
                 level: 0,
                 max_intkey: i64::MAX,
             }],
+            pages_referenced: HashSet::new(),
             first_leaf_level: None,
             page: None,
         }
@@ -5555,165 +5564,214 @@ pub fn integrity_check(
     errors: &mut Vec<IntegrityCheckError>,
     pager: &Rc<Pager>,
 ) -> Result<IOResult<()>> {
-    let Some(IntegrityCheckPageEntry {
-        page_idx,
-        level,
-        max_intkey,
-    }) = state.page_stack.last().cloned()
-    else {
-        return Ok(IOResult::Done(()));
-    };
-    let page = match state.page.take() {
-        Some(page) => page,
-        None => {
-            let (page, c) = btree_read_page(pager, page_idx)?;
-            state.page = Some(page.get());
-            if let Some(c) = c {
-                io_yield_one!(c);
+    loop {
+        let Some(IntegrityCheckPageEntry {
+            page_idx,
+            level,
+            max_intkey,
+        }) = state.page_stack.last().cloned()
+        else {
+            return Ok(IOResult::Done(()));
+        };
+        let page = match state.page.take() {
+            Some(page) => page,
+            None => {
+                let (page, c) = btree_read_page(pager, page_idx)?;
+                state.page = Some(page.get());
+                if let Some(c) = c {
+                    io_yield_one!(c);
+                }
+                state.page.take().expect("page should be present")
             }
-            page.get()
-        }
-    };
-    turso_assert!(page.is_loaded(), "page should be loaded");
-    state.page_stack.pop();
+        };
+        turso_assert!(page.is_loaded(), "page should be loaded");
+        state.page_stack.pop();
 
-    let contents = page.get_contents();
-    let usable_space = pager.usable_space();
-    let mut coverage_checker = CoverageChecker::new(page.get().id);
+        let contents = page.get_contents();
+        let is_overflow_page = contents.maybe_page_type().is_none();
+        if !state.pages_referenced.insert(page.get().id) {
+            errors.push(IntegrityCheckError::PageReferencedMultipleTimes {
+                page_id: page.get().id,
+                is_overflow_page,
+            });
+            continue;
+        }
+        let usable_space = pager.usable_space();
+        let mut coverage_checker = CoverageChecker::new(page.get().id);
 
-    // Now we check every cell for a few things:
-    // 1. Check the cell is in the correct range: it must not exceed the page and must not start
-    //    before the marked cell content area.
-    // 2. We add the cell to the coverage checker in order to check that cells do not overlap.
-    // 3. We check the order of rowids on table pages. We iterate backwards to check that the
-    //    current cell's rowid is less than the next cell's, and no greater than the parent's
-    //    divider cell; if this page is the root page, the max rowid is i64::MAX.
-    // 4. We append pages to the stack to check later.
-    // 5. For leaf pages, check that the current level (depth) equals that of the other leaf
-    //    pages we have seen.
-    let mut next_rowid = max_intkey;
-    for cell_idx in (0..contents.cell_count()).rev() {
-        let (cell_start, cell_length) = contents.cell_get_raw_region(cell_idx, usable_space);
-        if cell_start < contents.cell_content_area() as usize || cell_start > usable_space - 4 {
-            errors.push(IntegrityCheckError::CellOutOfRange {
-                cell_idx,
-                page_id: page.get().id,
-                cell_start,
-                cell_end: cell_start + cell_length,
-                content_area: contents.cell_content_area() as usize,
-                usable_space,
-            });
-        }
-        if cell_start + cell_length > usable_space {
-            errors.push(IntegrityCheckError::CellOverflowsPage {
-                cell_idx,
-                page_id: page.get().id,
-                cell_start,
-                cell_end: cell_start + cell_length,
-                content_area: contents.cell_content_area() as usize,
-                usable_space,
-            });
-        }
-        coverage_checker.add_cell(cell_start, cell_start + cell_length);
-        let cell = contents.cell_get(cell_idx, usable_space)?;
-        match cell {
-            BTreeCell::TableInteriorCell(table_interior_cell) => {
+        if is_overflow_page {
+            let next_overflow_page = contents.read_u32_no_offset(0);
+            if next_overflow_page != 0 {
                 state.page_stack.push(IntegrityCheckPageEntry {
-                    page_idx: table_interior_cell.left_child_page as usize,
-                    level: level + 1,
-                    max_intkey: table_interior_cell.rowid,
-                });
-                let rowid = table_interior_cell.rowid;
-                if rowid > max_intkey || rowid > next_rowid {
-                    errors.push(IntegrityCheckError::CellRowidOutOfRange {
-                        page_id: page.get().id,
-                        cell_idx,
-                        rowid,
-                        max_intkey,
-                        next_rowid,
-                    });
-                }
-                next_rowid = rowid;
-            }
-            BTreeCell::TableLeafCell(table_leaf_cell) => {
-                // check depth of leaf pages are equal
-                if let Some(expected_leaf_level) = state.first_leaf_level {
-                    if expected_leaf_level != level {
-                        errors.push(IntegrityCheckError::LeafDepthMismatch {
-                            page_id: page.get().id,
-                            this_page_depth: level,
-                            other_page_depth: expected_leaf_level,
-                        });
-                    }
-                } else {
-                    state.first_leaf_level = Some(level);
-                }
-                let rowid = table_leaf_cell.rowid;
-                if rowid > max_intkey || rowid > next_rowid {
-                    errors.push(IntegrityCheckError::CellRowidOutOfRange {
-                        page_id: page.get().id,
-                        cell_idx,
-                        rowid,
-                        max_intkey,
-                        next_rowid,
-                    });
-                }
-                next_rowid = rowid;
-            }
-            BTreeCell::IndexInteriorCell(index_interior_cell) => {
-                state.page_stack.push(IntegrityCheckPageEntry {
-                    page_idx: index_interior_cell.left_child_page as usize,
-                    level: level + 1,
-                    max_intkey, // we don't care about intkey in non-table pages
+                    page_idx: next_overflow_page as usize,
+                    level,
+                    max_intkey,
                 });
             }
-            BTreeCell::IndexLeafCell(_) => {
-                // check depth of leaf pages are equal
-                if let Some(expected_leaf_level) = state.first_leaf_level {
-                    if expected_leaf_level != level {
-                        errors.push(IntegrityCheckError::LeafDepthMismatch {
-                            page_id: page.get().id,
-                            this_page_depth: level,
-                            other_page_depth: expected_leaf_level,
-                        });
-                    }
-                } else {
-                    state.first_leaf_level = Some(level);
-                }
-            }
-        }
-    }
+            continue;
         }
 
-    // Now we add free blocks to the coverage checker
-    let first_freeblock = contents.first_freeblock() as usize;
-    if first_freeblock > 0 {
-        let mut pc = first_freeblock;
-        while pc > 0 {
-            let next = contents.read_u16_no_offset(pc as usize) as usize;
-            let size = contents.read_u16_no_offset(pc as usize + 2) as usize;
-            // check it doesn't go out of range
-            if pc > usable_space - 4 {
-                errors.push(IntegrityCheckError::FreeBlockOutOfRange {
+        // Now we check every cell for a few things:
+        // 1. Check the cell is in the correct range: it must not exceed the page and must not
+        //    start before the marked cell content area.
+        // 2. We add the cell to the coverage checker in order to check that cells do not overlap.
+        // 3. We check the order of rowids on table pages. We iterate backwards to check that the
+        //    current cell's rowid is less than the next cell's, and no greater than the parent's
+        //    divider cell; if this page is the root page, the max rowid is i64::MAX.
+        // 4. We append pages to the stack to check later.
+        // 5. For leaf pages, check that the current level (depth) equals that of the other leaf
+        //    pages we have seen.
+        let mut next_rowid = max_intkey;
+        for cell_idx in (0..contents.cell_count()).rev() {
+            let (cell_start, cell_length) = contents.cell_get_raw_region(cell_idx, usable_space);
+            if cell_start < contents.cell_content_area() as usize || cell_start > usable_space - 4 {
+                errors.push(IntegrityCheckError::CellOutOfRange {
+                    cell_idx,
                     page_id: page.get().id,
-                    start: pc,
-                    end: pc + size,
+                    cell_start,
+                    cell_end: cell_start + cell_length,
+                    content_area: contents.cell_content_area() as usize,
+                    usable_space,
                 });
-                break;
             }
-            coverage_checker.add_free_block(pc, pc + size);
-            pc = next;
+            if cell_start + cell_length > usable_space {
+                errors.push(IntegrityCheckError::CellOverflowsPage {
+                    cell_idx,
+                    page_id: page.get().id,
+                    cell_start,
+                    cell_end: cell_start + cell_length,
+                    content_area: contents.cell_content_area() as usize,
+                    usable_space,
+                });
+            }
+            coverage_checker.add_cell(cell_start, cell_start + cell_length);
+            let cell = contents.cell_get(cell_idx, usable_space)?;
+            match cell {
+                BTreeCell::TableInteriorCell(table_interior_cell) => {
+                    state.page_stack.push(IntegrityCheckPageEntry {
+                        page_idx: table_interior_cell.left_child_page as usize,
+                        level: level + 1,
+                        max_intkey: table_interior_cell.rowid,
+                    });
+                    let rowid = table_interior_cell.rowid;
+                    if rowid > max_intkey || rowid > next_rowid {
+                        errors.push(IntegrityCheckError::CellRowidOutOfRange {
+                            page_id: page.get().id,
+                            cell_idx,
+                            rowid,
+                            max_intkey,
+                            next_rowid,
+                        });
+                    }
+                    next_rowid = rowid;
+                }
+                BTreeCell::TableLeafCell(table_leaf_cell) => {
+                    // check depth of leaf pages are equal
+                    if let Some(expected_leaf_level) = state.first_leaf_level {
+                        if expected_leaf_level != level {
+                            errors.push(IntegrityCheckError::LeafDepthMismatch {
+                                page_id: page.get().id,
+                                this_page_depth: level,
+                                other_page_depth: expected_leaf_level,
+                            });
+                        }
+                    } else {
+                        state.first_leaf_level = Some(level);
+                    }
+                    let rowid = table_leaf_cell.rowid;
+                    if rowid > max_intkey || rowid > next_rowid {
+                        errors.push(IntegrityCheckError::CellRowidOutOfRange {
+                            page_id: page.get().id,
+                            cell_idx,
+                            rowid,
+                            max_intkey,
+                            next_rowid,
+                        });
+                    }
+                    next_rowid = rowid;
+                    if let Some(first_overflow_page) = table_leaf_cell.first_overflow_page {
+                        state.page_stack.push(IntegrityCheckPageEntry {
+                            page_idx: first_overflow_page as usize,
+                            level,
+                            max_intkey,
+                        });
+                    }
+                }
+                BTreeCell::IndexInteriorCell(index_interior_cell) => {
+                    state.page_stack.push(IntegrityCheckPageEntry {
+                        page_idx: index_interior_cell.left_child_page as usize,
+                        level: level + 1,
+                        max_intkey, // we don't care about intkey in non-table pages
+                    });
+                    if let Some(first_overflow_page) = index_interior_cell.first_overflow_page {
+                        state.page_stack.push(IntegrityCheckPageEntry {
+                            page_idx: first_overflow_page as usize,
+                            level,
+                            max_intkey,
+                        });
+                    }
+                }
+                BTreeCell::IndexLeafCell(index_leaf_cell) => {
+                    // check depth of leaf pages are equal
+                    if let Some(expected_leaf_level) = state.first_leaf_level {
+                        if expected_leaf_level != level {
errors.push(IntegrityCheckError::LeafDepthMismatch { + page_id: page.get().id, + this_page_depth: level, + other_page_depth: expected_leaf_level, + }); + } + } else { + state.first_leaf_level = Some(level); + } + if let Some(first_overflow_page) = index_leaf_cell.first_overflow_page { + state.page_stack.push(IntegrityCheckPageEntry { + page_idx: first_overflow_page as usize, + level, + max_intkey, + }); + } + } + } } + + if let Some(rightmost) = contents.rightmost_pointer() { + state.page_stack.push(IntegrityCheckPageEntry { + page_idx: rightmost as usize, + level: level + 1, + max_intkey, + }); + } + + // Now we add free blocks to the coverage checker + let first_freeblock = contents.first_freeblock() as usize; + if first_freeblock > 0 { + let mut pc = first_freeblock; + while pc > 0 { + let next = contents.read_u16_no_offset(pc as usize) as usize; + let size = contents.read_u16_no_offset(pc as usize + 2) as usize; + // check it doesn't go out of range + if pc > usable_space - 4 { + errors.push(IntegrityCheckError::FreeBlockOutOfRange { + page_id: page.get().id, + start: pc, + end: pc + size, + }); + break; + } + coverage_checker.add_free_block(pc, pc + size); + pc = next; + } + } + + // Let's check the overlap of freeblocks and cells now that we have collected them all. + coverage_checker.analyze( + usable_space, + contents.cell_content_area() as usize, + errors, + contents.num_frag_free_bytes() as usize, + ); } - - // Let's check the overlap of freeblocks and cells now that we have collected them all. - coverage_checker.analyze( - usable_space, - contents.cell_content_area() as usize, - errors, - contents.num_frag_free_bytes() as usize, - ); - - Ok(IOResult::Done(())) } pub fn btree_read_page( @@ -6081,59 +6139,104 @@ impl BTreePageInner { } } -/// Try to find a free block available and allocate it if found -fn find_free_cell(page_ref: &PageContent, usable_space: usize, amount: usize) -> Result { +/// Try to find a freeblock inside the cell content area that is large enough to fit the given amount of bytes. +/// Used to check if a cell can be inserted into a freeblock to reduce fragmentation. +/// Returns the absolute byte offset of the freeblock if found. 
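+/// Illustrative example (hypothetical offsets, not from the test suite): given a freeblock at
+/// offset 1000 with size 20, a request for 16 bytes shrinks the block to size 4 and returns
+/// `Some(1004)`, the tail of the shrunk block. A request for 18 bytes would leave a 2-byte
+/// remainder, which is too small to be a freeblock, so the block is deleted from the freelist,
+/// the 2 bytes are counted as fragmented, and `Some(1000)` is returned.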
+fn find_free_slot(
+    page_ref: &PageContent,
+    usable_space: usize,
+    amount: usize,
+) -> Result<Option<usize>> {
+    const CELL_SIZE_MIN: usize = 4;
     // NOTE: the freelist is in ascending order of keys and pc
     // unused space is reserved bytes at the end of the page, therefore we must subtract it from the maximum start offset
-    let mut prev_pc = page_ref.offset + offset::BTREE_FIRST_FREEBLOCK;
-    let mut pc = page_ref.first_freeblock() as usize;
-    let maxpc = usable_space - amount;
+    let mut prev_block = None;
+    let mut cur_block = match page_ref.first_freeblock() {
+        0 => None,
+        first_block => Some(first_block as usize),
+    };
 
-    while pc <= maxpc {
-        if pc + 4 > usable_space {
+    let max_start_offset = usable_space - amount;
+
+    while let Some(cur) = cur_block {
+        if cur + CELL_SIZE_MIN > usable_space {
             return_corrupt!("Free block header extends beyond page");
         }
-        let next = page_ref.read_u16_no_offset(pc);
-        let size = page_ref.read_u16_no_offset(pc + 2);
+        let (next, size) = {
+            let cur_u16: u16 = cur
+                .try_into()
+                .unwrap_or_else(|_| panic!("cur={cur} is too large to fit in a u16"));
+            let (next, size) = page_ref.read_freeblock(cur_u16);
+            (next as usize, size as usize)
+        };
 
-        if amount <= size as usize {
-            let new_size = size as usize - amount;
-            if new_size < 4 {
-                // The code is checking if using a free slot that would leave behind a very small fragment (x < 4 bytes)
-                // would cause the total fragmentation to exceed the limit of 60 bytes
-                // check sqlite docs https://www.sqlite.org/fileformat.html#:~:text=A%20freeblock%20requires,not%20exceed%2060
-                if page_ref.num_frag_free_bytes() > 57 {
-                    return Ok(0);
-                }
-                // Delete the slot from freelist and update the page's fragment count.
-                page_ref.write_u16_no_offset(prev_pc, next);
-                let frag = page_ref.num_frag_free_bytes() + new_size as u8;
-                page_ref.write_fragmented_bytes_count(frag);
-                return Ok(pc);
-            } else if new_size + pc > maxpc {
-                return_corrupt!("Free block extends beyond page end");
-            } else {
-                // Requested amount fits inside the current free slot so we reduce its size
-                // to account for newly allocated space.
-                page_ref.write_u16_no_offset(pc + 2, new_size as u16);
-                return Ok(pc + new_size);
+        // Doesn't fit in this freeblock, try the next one.
+        if amount > size {
+            if next == 0 {
+                // No next -> can't fit.
+                return Ok(None);
             }
-        }
-        prev_pc = pc;
-        pc = next as usize;
-        if pc <= prev_pc {
-            if pc != 0 {
+            prev_block = cur_block;
+            if next <= cur {
                 return_corrupt!("Free list not in ascending order");
             }
-            return Ok(0);
+            cur_block = Some(next);
+            continue;
+        }
+
+        let new_size = size - amount;
+        // If the freeblock's new size is < CELL_SIZE_MIN, the freeblock is deleted and the remaining bytes
+        // become fragmented free bytes.
+        if new_size < CELL_SIZE_MIN {
+            if page_ref.num_frag_free_bytes() > 57 {
+                // SQLite has a fragmentation limit of 60 bytes.
+                // check sqlite docs https://www.sqlite.org/fileformat.html#:~:text=A%20freeblock%20requires,not%20exceed%2060
+                return Ok(None);
+            }
+            // Delete the slot from the freelist and update the page's fragment count.
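+            // e.g. (hypothetical numbers) an 18-byte freeblock serving a 16-byte request leaves
+            // 2 bytes behind, which are too small to form a freeblock and are therefore counted
+            // as fragmented free bytes.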
+ match prev_block { + Some(prev) => { + let prev_u16: u16 = prev + .try_into() + .unwrap_or_else(|_| panic!("prev={prev} is too large to fit in a u16")); + let next_u16: u16 = next + .try_into() + .unwrap_or_else(|_| panic!("next={next} is too large to fit in a u16")); + page_ref.write_freeblock_next_ptr(prev_u16, next_u16); + } + None => { + let next_u16: u16 = next + .try_into() + .unwrap_or_else(|_| panic!("next={next} is too large to fit in a u16")); + page_ref.write_first_freeblock(next_u16); + } + } + let new_size_u8: u8 = new_size + .try_into() + .unwrap_or_else(|_| panic!("new_size={new_size} is too large to fit in a u8")); + let frag = page_ref.num_frag_free_bytes() + new_size_u8; + page_ref.write_fragmented_bytes_count(frag); + return Ok(cur_block); + } else if new_size + cur > max_start_offset { + return_corrupt!("Free block extends beyond page end"); + } else { + // Requested amount fits inside the current free slot so we reduce its size + // to account for newly allocated space. + let cur_u16: u16 = cur + .try_into() + .unwrap_or_else(|_| panic!("cur={cur} is too large to fit in a u16")); + let new_size_u16: u16 = new_size + .try_into() + .unwrap_or_else(|_| panic!("new_size={new_size} is too large to fit in a u16")); + page_ref.write_freeblock_size(cur_u16, new_size_u16); + // Return the offset immediately after the shrunk freeblock. + return Ok(Some(cur + new_size)); } } - if pc > maxpc + amount - 4 { - return_corrupt!("Free block chain extends beyond page end"); - } - Ok(0) + + Ok(None) } pub fn btree_init_page(page: &BTreePage, page_type: PageType, offset: usize, usable_space: usize) { @@ -6384,121 +6487,176 @@ fn page_insert_array( /// This function also updates the freeblock list in the page. /// Freeblocks are used to keep track of free space in the page, /// and are organized as a linked list. +/// +/// This function may merge the freed cell range into either the next freeblock, +/// previous freeblock, or both. fn free_cell_range( page: &mut PageContent, mut offset: usize, len: usize, usable_space: usize, ) -> Result<()> { - if len < 4 { - return_corrupt!("Minimum cell size is 4"); + const CELL_SIZE_MIN: usize = 4; + if len < CELL_SIZE_MIN { + return_corrupt!("free_cell_range: minimum cell size is {CELL_SIZE_MIN}"); } - - if offset > usable_space.saturating_sub(4) { - return_corrupt!("Start offset beyond usable space"); + if offset > usable_space.saturating_sub(CELL_SIZE_MIN) { + return_corrupt!("free_cell_range: start offset beyond usable space: offset={offset} usable_space={usable_space}"); } let mut size = len; let mut end = offset + len; - let mut pointer_to_pc = page.offset + 1; - // if the freeblock list is empty, we set this block as the first freeblock in the page header. - let pc = if page.first_freeblock() == 0 { - 0 - } else { - // if the freeblock list is not empty, and the offset is greater than the first freeblock, - // then we need to do some more calculation to figure out where to insert the freeblock - // in the freeblock linked list. 
-        let first_block = page.first_freeblock() as usize;
-
-        let mut pc = first_block;
-
-        while pc < offset {
-            if pc <= pointer_to_pc {
-                if pc == 0 {
-                    break;
-                }
-                return_corrupt!("free cell range free block not in ascending order");
-            }
-
-            let next = page.read_u16_no_offset(pc) as usize;
-            pointer_to_pc = pc;
-            pc = next;
+    let cur_content_area = page.cell_content_area() as usize;
+    let first_block = page.first_freeblock() as usize;
+    if first_block == 0 {
+        if offset < cur_content_area {
+            return_corrupt!("free_cell_range: free block before content area: offset={offset} cell_content_area={cur_content_area}");
         }
-
-        if pc > usable_space - 4 {
-            return_corrupt!("Free block beyond usable space");
+        if offset == cur_content_area {
+            // if the freeblock list is empty and the freed range is exactly at the beginning of the content area,
+            // we are not creating a freeblock; instead we are just extending the unallocated region.
+            page.write_cell_content_area(end);
+        } else {
+            // otherwise we set it as the first freeblock in the page header.
+            let offset_u16: u16 = offset
+                .try_into()
+                .unwrap_or_else(|_| panic!("offset={offset} is too large to fit in a u16"));
+            page.write_first_freeblock(offset_u16);
+            let size_u16: u16 = size
+                .try_into()
+                .unwrap_or_else(|_| panic!("size={size} is too large to fit in a u16"));
+            page.write_freeblock(offset_u16, size_u16, None);
         }
-        let mut removed_fragmentation = 0;
-        if pc > 0 && offset + len + 3 >= pc {
-            removed_fragmentation = (pc - end) as u8;
+        return Ok(());
+    }
 
-            if end > pc {
-                return_corrupt!("Invalid block overlap");
-            }
-            end = pc + page.read_u16_no_offset(pc + 2) as usize;
+    // if the freeblock list is not empty, we need to find the correct position to insert the new freeblock
+    // resulting from the freeing of this cell range; we may also be able to merge the freed range into existing freeblocks.
+    let mut prev_block = None;
+    let mut next_block = Some(first_block);
+
+    while let Some(next) = next_block {
+        if prev_block.is_some_and(|prev| next <= prev) {
+            return_corrupt!("free_cell_range: freeblocks not in ascending order: next_block={next} prev_block={prev_block:?}");
+        }
+        if next >= offset {
+            break;
+        }
+        prev_block = Some(next);
+        next_block = match page.read_u16_no_offset(next) {
+            // Freed range extends beyond the last freeblock, so we are creating a new freeblock.
+            0 => None,
+            next => Some(next as usize),
+        };
+    }
+
+    if let Some(next) = next_block {
+        if next + CELL_SIZE_MIN > usable_space {
+            return_corrupt!("free_cell_range: free block beyond usable space: next_block={next} usable_space={usable_space}");
+        }
+    }
+    let mut removed_fragmentation = 0;
+    const SINGLE_FRAGMENT_SIZE_MAX: usize = CELL_SIZE_MIN - 1;
+
+    if end > usable_space {
+        return_corrupt!("free_cell_range: freed range extends beyond usable space: offset={offset} len={len} end={end} usable_space={usable_space}");
+    }
+
+    // If the freed range extends into the next freeblock, we will merge the freed range into it.
+    // If there is a 1-3 byte gap between the freed range and the next freeblock, we are effectively
+    // clearing that amount of fragmented bytes, since a 1-3 byte range cannot be a valid cell.
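+    // Illustrative example (hypothetical offsets): freeing the range [100, 120) when the next
+    // freeblock starts at 122 with size 30 coalesces them into a single range [100, 152) and
+    // subtracts the 2-byte gap from the fragmented-bytes counter.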
+    if let Some(next) = next_block {
+        if end + SINGLE_FRAGMENT_SIZE_MAX >= next {
+            removed_fragmentation = (next - end) as u8;
+            let next_size = page.read_u16_no_offset(next + 2) as usize;
+            end = next + next_size;
             if end > usable_space {
-                return_corrupt!("Coalesced block extends beyond page");
+                return_corrupt!("free_cell_range: coalesced block extends beyond page: offset={offset} len={len} end={end} usable_space={usable_space}");
             }
             size = end - offset;
-            pc = page.read_u16_no_offset(pc) as usize;
+            // Since we merged the two freeblocks, we need to update next_block to the next freeblock in the list.
+            next_block = match page.read_u16_no_offset(next) {
+                0 => None,
+                next => Some(next as usize),
+            };
         }
+    }
 
-        if pointer_to_pc > page.offset + 1 {
-            let prev_end = pointer_to_pc + page.read_u16_no_offset(pointer_to_pc + 2) as usize;
-            if prev_end + 3 >= offset {
-                if prev_end > offset {
-                    return_corrupt!("Invalid previous block overlap");
+    // If the freed range reaches back to the previous freeblock, we merge them in the same way.
+    if let Some(prev) = prev_block {
+        let prev_size = page.read_u16_no_offset(prev + 2) as usize;
+        let prev_end = prev + prev_size;
+        if prev_end > offset {
+            return_corrupt!(
+                "free_cell_range: previous block overlap: prev_end={prev_end} offset={offset}"
+            );
+        }
+        // If the previous freeblock ends at the freed range, or within 3 bytes of it, we merge the
+        // freed range into the previous freeblock and clear any 1-3 byte fragmentation in between, as above.
+        if prev_end + SINGLE_FRAGMENT_SIZE_MAX >= offset {
+            removed_fragmentation += (offset - prev_end) as u8;
+            size = end - prev;
+            offset = prev;
+        }
+    }
+
+    let cur_frag_free_bytes = page.num_frag_free_bytes();
+    if removed_fragmentation > cur_frag_free_bytes {
+        return_corrupt!("free_cell_range: invalid fragmentation count: removed_fragmentation={removed_fragmentation} num_frag_free_bytes={cur_frag_free_bytes}");
+    }
+    let frag = cur_frag_free_bytes - removed_fragmentation;
+    page.write_fragmented_bytes_count(frag);
+
+    if offset < cur_content_area {
+        return_corrupt!("free_cell_range: free block before content area: offset={offset} cell_content_area={cur_content_area}");
+    }
+
+    // As above, if the freed range is exactly at the beginning of the content area, we are not creating a freeblock;
+    // instead we are just extending the unallocated region.
+    if offset == cur_content_area {
+        if prev_block.is_some_and(|prev| prev != first_block) {
+            return_corrupt!("free_cell_range: invalid content area merge - freed range should have been merged with previous freeblock: prev_block={prev_block:?} first_block={first_block}");
+        }
+        // If we get here, we are freeing data from the left end of the content area,
+        // so we are extending the unallocated region instead of creating a freeblock.
+        // We update the first freeblock to be the next one, and shrink the content area to start from the end
+        // of the freed range.
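+        // Illustrative example (hypothetical offsets): with the content area starting at 900,
+        // freeing the range [900, 940) simply moves the content area start to 940 instead of
+        // recording a 40-byte freeblock.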
+ match next_block { + Some(next) => { + if next <= end { + return_corrupt!("free_cell_range: invalid content area merge - first freeblock should either be 0 or greater than the content area start: next_block={next} end={end}"); } - removed_fragmentation += (offset - prev_end) as u8; - size = end - pointer_to_pc; - offset = pointer_to_pc; + let next_u16: u16 = next + .try_into() + .unwrap_or_else(|_| panic!("next={next} is too large to fit in a u16")); + page.write_first_freeblock(next_u16); + } + None => { + page.write_first_freeblock(0); } } - if removed_fragmentation > page.num_frag_free_bytes() { - return_corrupt!(format!( - "Invalid fragmentation count. Had {} and removed {}", - page.num_frag_free_bytes(), - removed_fragmentation - )); - } - let frag = page.num_frag_free_bytes() - removed_fragmentation; - page.write_fragmented_bytes_count(frag); - pc - }; - - if (offset as u32) <= page.cell_content_area() { - if (offset as u32) < page.cell_content_area() { - return_corrupt!("Free block before content area"); - } - if pointer_to_pc != page.offset + offset::BTREE_FIRST_FREEBLOCK { - return_corrupt!("Invalid content area merge"); - } - turso_assert!( - pc < PageSize::MAX as usize, - "pc={pc} PageSize::MAX={}", - PageSize::MAX - ); - page.write_first_freeblock(pc as u16); page.write_cell_content_area(end); } else { - turso_assert!( - pointer_to_pc < PageSize::MAX as usize, - "pointer_to_pc={pointer_to_pc} PageSize::MAX={}", - PageSize::MAX - ); - turso_assert!( - offset < PageSize::MAX as usize, - "offset={offset} PageSize::MAX={}", - PageSize::MAX - ); - turso_assert!( - size < PageSize::MAX as usize, - "size={size} PageSize::MAX={}", - PageSize::MAX - ); - page.write_u16_no_offset(pointer_to_pc, offset as u16); - page.write_u16_no_offset(offset, pc as u16); - page.write_u16_no_offset(offset + 2, size as u16); + // If we are creating a new freeblock: + // a) if it's the first one, we update the header to indicate so, + // b) if it's not the first one, we update the previous freeblock to point to the new one, + // and the new one to point to the next one. + let offset_u16: u16 = offset + .try_into() + .unwrap_or_else(|_| panic!("offset={offset} is too large to fit in a u16")); + if let Some(prev) = prev_block { + page.write_u16_no_offset(prev, offset_u16); + } else { + page.write_first_freeblock(offset_u16); + } + let size_u16: u16 = size + .try_into() + .unwrap_or_else(|_| panic!("size={size} is too large to fit in a u16")); + let next_block_u16 = next_block.map(|b| { + b.try_into() + .unwrap_or_else(|_| panic!("next_block={b} is too large to fit in a u16")) + }); + page.write_freeblock(offset_u16, size_u16, next_block_u16); } Ok(()) @@ -6893,7 +7051,7 @@ fn compute_free_space(page: &PageContent, usable_space: usize) -> usize { // Next should always be 0 (NULL) at this point since we have reached the end of the freeblocks linked list assert_eq!( next, 0, - "corrupted page: freeblocks list not in ascending order" + "corrupted page: freeblocks list not in ascending order: cur_freeblock_ptr={cur_freeblock_ptr} size={size} next={next}" ); assert!( @@ -6929,8 +7087,7 @@ fn allocate_cell_space( && unallocated_region_start + CELL_PTR_SIZE_BYTES <= cell_content_area_start { // find slot - let pc = find_free_cell(page_ref, usable_space, amount)?; - if pc != 0 { + if let Some(pc) = find_free_slot(page_ref, usable_space, amount)? { // we can fit the cell in a freeblock. 
             return Ok(pc as u16);
         }
 
@@ -6963,39 +7120,66 @@
 
 #[derive(Debug, Clone)]
 pub enum FillCellPayloadState {
+    /// Determine whether we can fit the record on the current page.
+    /// If yes, return immediately after copying the data.
+    /// Otherwise move to the [FillCellPayloadState::CopyData] state.
     Start,
-    AllocateOverflowPages {
-        /// Arc because we clone [WriteState] for some reason and we use unsafe pointer dereferences in [FillCellPayloadState::AllocateOverflowPages]
-        /// so the underlying bytes must not be cloned in upper layers.
-        record_buf: Arc<[u8]>,
-        space_left: usize,
-        to_copy_buffer_ptr: *const u8,
-        to_copy_buffer_len: usize,
-        pointer: *mut u8,
-        pointer_to_next: *mut u8,
+    /// Copy the next chunk of data from the record buffer to the cell payload.
+    /// If we can't fit all of the remaining data on the current page,
+    /// move the internal state to [CopyDataState::AllocateOverflowPage].
+    CopyData {
+        /// Internal state of the copy data operation.
+        /// We can either be copying data or allocating an overflow page.
+        state: CopyDataState,
+        /// Track how much space we have left on the current page we are copying data into.
+        /// This is reset whenever a new overflow page is allocated.
+        space_left_on_cur_page: usize,
+        /// Offset into the record buffer to copy from.
+        src_data_offset: usize,
+        /// Offset into the destination buffer we are copying data into.
+        /// This is either:
+        /// - an offset in the btree page where the cell is, or
+        /// - an offset in an overflow page
+        dst_data_offset: usize,
+        /// If this is Some, we will copy data into this overflow page.
+        /// If this is None, we will copy data into the cell payload on the btree page.
+        /// Also: to safely form a chain of overflow pages, the current page must be pinned to the page cache
+        /// so that e.g. a spilling operation does not evict it to disk.
+        current_overflow_page: Option<PageRef>,
     },
 }
 
+#[derive(Debug, Clone, PartialEq, Eq, Copy)]
+pub enum CopyDataState {
+    /// Copy the next chunk of data from the record buffer to the cell payload.
+    Copy,
+    /// Allocate a new overflow page if we couldn't fit all data on the current page.
+    AllocateOverflowPage,
+}
+
 /// Fill in the cell payload with the record.
 /// If the record is too large to fit in the cell, it will spill onto overflow pages.
 /// This function needs a separate [FillCellPayloadState] because allocating overflow pages
 /// may require I/O.
 #[allow(clippy::too_many_arguments)]
 fn fill_cell_payload(
-    page_contents: &PageContent,
+    page: PageRef,
     int_key: Option<i64>,
     cell_payload: &mut Vec<u8>,
     cell_idx: usize,
     record: &ImmutableRecord,
     usable_space: usize,
     pager: Rc<Pager>,
-    state: &mut FillCellPayloadState,
+    fill_cell_payload_state: &mut FillCellPayloadState,
 ) -> Result<IOResult<()>> {
-    loop {
-        match state {
+    let overflow_page_pointer_size = 4;
+    let overflow_page_data_size = usable_space - overflow_page_pointer_size;
+    let result = loop {
+        let record_buf = record.get_payload();
+        match fill_cell_payload_state {
             FillCellPayloadState::Start => {
-                // TODO: make record raw from start, having to serialize is not good
-                let record_buf: Arc<[u8]> = Arc::from(record.get_payload());
+                page.pin(); // We need to pin this page because we will be accessing its contents after fill_cell_payload is done.
+ let page_contents = page.get().contents.as_ref().unwrap(); let page_type = page_contents.page_type(); // fill in header @@ -7014,124 +7198,124 @@ fn fill_cell_payload( write_varint_to_vec(record_buf.len() as u64, cell_payload); } - let payload_overflow_threshold_max = - payload_overflow_threshold_max(page_type, usable_space); - tracing::debug!( - "fill_cell_payload(record_size={}, payload_overflow_threshold_max={})", - record_buf.len(), - payload_overflow_threshold_max - ); - if record_buf.len() <= payload_overflow_threshold_max { + let max_local = payload_overflow_threshold_max(page_type, usable_space); + let min_local = payload_overflow_threshold_min(page_type, usable_space); + + let (overflows, local_size_if_overflow) = + payload_overflows(record_buf.len(), max_local, min_local, usable_space); + if !overflows { // enough allowed space to fit inside a btree page cell_payload.extend_from_slice(record_buf.as_ref()); - return Ok(IOResult::Done(())); + break Ok(IOResult::Done(())); } - let payload_overflow_threshold_min = - payload_overflow_threshold_min(page_type, usable_space); - // see e.g. https://github.com/sqlite/sqlite/blob/9591d3fe93936533c8c3b0dc4d025ac999539e11/src/dbstat.c#L371 - let mut space_left = payload_overflow_threshold_min - + (record_buf.len() - payload_overflow_threshold_min) % (usable_space - 4); + // so far we've written any of: left child page, rowid, payload size (depending on page type) + let cell_non_payload_elems_size = cell_payload.len(); + let new_total_local_size = cell_non_payload_elems_size + local_size_if_overflow; + cell_payload.resize(new_total_local_size, 0); - if space_left > payload_overflow_threshold_max { - space_left = payload_overflow_threshold_min; - } - - // cell_size must be equal to first value of space_left as this will be the bytes copied to non-overflow page. - let cell_size = space_left + cell_payload.len() + 4; // 4 is the number of bytes of pointer to first overflow page - let to_copy_buffer = record_buf.as_ref(); - - let prev_size = cell_payload.len(); - cell_payload.resize(prev_size + space_left + 4, 0); - assert_eq!( - cell_size, - cell_payload.len(), - "cell_size={} != cell_payload.len()={}", - cell_size, - cell_payload.len() - ); - - // SAFETY: this pointer is valid because it points to a buffer in an Arc>> that lives at least as long as this function, - // and the Vec will not be mutated in FillCellPayloadState::AllocateOverflowPages, which we will move to next. - let pointer = unsafe { cell_payload.as_mut_ptr().add(prev_size) }; - let pointer_to_next = - unsafe { cell_payload.as_mut_ptr().add(prev_size + space_left) }; - - let to_copy_buffer_ptr = to_copy_buffer.as_ptr(); - let to_copy_buffer_len = to_copy_buffer.len(); - - *state = FillCellPayloadState::AllocateOverflowPages { - record_buf, - space_left, - to_copy_buffer_ptr, - to_copy_buffer_len, - pointer, - pointer_to_next, + *fill_cell_payload_state = FillCellPayloadState::CopyData { + state: CopyDataState::Copy, + space_left_on_cur_page: local_size_if_overflow - overflow_page_pointer_size, // local_size_if_overflow includes the overflow page pointer, but we don't want to write payload data there. 
+ src_data_offset: 0, + dst_data_offset: cell_non_payload_elems_size, + current_overflow_page: None, }; continue; } - FillCellPayloadState::AllocateOverflowPages { - record_buf: _record_buf, - space_left, - to_copy_buffer_ptr, - to_copy_buffer_len, - pointer, - pointer_to_next, + FillCellPayloadState::CopyData { + state, + src_data_offset, + space_left_on_cur_page, + dst_data_offset, + current_overflow_page, } => { - let to_copy; - { - let to_copy_buffer_ptr = *to_copy_buffer_ptr; - let to_copy_buffer_len = *to_copy_buffer_len; - let pointer = *pointer; - let space_left = *space_left; + match state { + CopyDataState::Copy => { + turso_assert!(*src_data_offset < record_buf.len(), "trying to read past end of record buffer: record_offset={} < record_buf.len()={}", src_data_offset, record_buf.len()); + let record_offset_slice = &record_buf[*src_data_offset..]; + let amount_to_copy = + (*space_left_on_cur_page).min(record_offset_slice.len()); + let record_offset_slice_to_copy = &record_offset_slice[..amount_to_copy]; + if let Some(cur_page) = current_overflow_page { + // Copy data into the current overflow page. + turso_assert!( + cur_page.is_loaded(), + "current overflow page is not loaded" + ); + turso_assert!(*dst_data_offset == overflow_page_pointer_size, "data must be copied to offset {overflow_page_pointer_size} on overflow pages, instead tried to copy to offset {dst_data_offset}"); + let contents = cur_page.get_contents(); + let buf = &mut contents.as_ptr() + [*dst_data_offset..*dst_data_offset + amount_to_copy]; + buf.copy_from_slice(record_offset_slice_to_copy); + } else { + // Copy data into the cell payload on the btree page. + let buf = &mut cell_payload + [*dst_data_offset..*dst_data_offset + amount_to_copy]; + buf.copy_from_slice(record_offset_slice_to_copy); + } - // SAFETY: we know to_copy_buffer_ptr is valid because it refers to record_buf which lives at least as long as this function, - // and the underlying bytes are not mutated in FillCellPayloadState::AllocateOverflowPages. - let to_copy_buffer = unsafe { - std::slice::from_raw_parts(to_copy_buffer_ptr, to_copy_buffer_len) - }; - to_copy = space_left.min(to_copy_buffer_len); - // SAFETY: we know 'pointer' is valid because it refers to cell_payload which lives at least as long as this function, - // and the underlying bytes are not mutated in FillCellPayloadState::AllocateOverflowPages. - unsafe { std::ptr::copy(to_copy_buffer_ptr, pointer, to_copy) }; + if record_offset_slice.len() - amount_to_copy == 0 { + let cur_page = current_overflow_page.as_ref().expect("we must have overflowed if the remaining payload fits on the current page"); + cur_page.unpin(); // We can safely unpin the current overflow page now. + // Everything copied. + break Ok(IOResult::Done(())); + } + *state = CopyDataState::AllocateOverflowPage; + *src_data_offset += amount_to_copy; + } + CopyDataState::AllocateOverflowPage => { + let new_overflow_page = match pager.allocate_overflow_page() { + Ok(IOResult::Done(new_overflow_page)) => new_overflow_page, + Ok(IOResult::IO(io_result)) => return Ok(IOResult::IO(io_result)), + Err(e) => { + if let Some(cur_page) = current_overflow_page { + cur_page.unpin(); + } + break Err(e); + } + }; + new_overflow_page.pin(); // Pin the current overflow page so the cache won't evict it because we need this page to be in memory for the next iteration of FillCellPayloadState::CopyData. + if let Some(prev_page) = current_overflow_page { + prev_page.unpin(); // We can safely unpin the previous overflow page now. 
+ } - let left = to_copy_buffer.len() - to_copy; - if left == 0 { - break; + turso_assert!( + new_overflow_page.is_loaded(), + "new overflow page is not loaded" + ); + let new_overflow_page_id = new_overflow_page.get().id as u32; + + if let Some(prev_page) = current_overflow_page { + // Update the previous overflow page's "next overflow page" pointer to point to the new overflow page. + turso_assert!( + prev_page.is_loaded(), + "previous overflow page is not loaded" + ); + let contents = prev_page.get_contents(); + let buf = &mut contents.as_ptr()[..overflow_page_pointer_size]; + buf.copy_from_slice(&new_overflow_page_id.to_be_bytes()); + } else { + // Update the cell payload's "next overflow page" pointer to point to the new overflow page. + let first_overflow_page_ptr_offset = + cell_payload.len() - overflow_page_pointer_size; + let buf = &mut cell_payload[first_overflow_page_ptr_offset + ..first_overflow_page_ptr_offset + overflow_page_pointer_size]; + buf.copy_from_slice(&new_overflow_page_id.to_be_bytes()); + } + + *dst_data_offset = overflow_page_pointer_size; + *space_left_on_cur_page = overflow_page_data_size; + *current_overflow_page = Some(new_overflow_page.clone()); + *state = CopyDataState::Copy; } } - - // we still have bytes to add, we will need to allocate new overflow page - // FIXME: handle page cache is full - let overflow_page = return_if_io!(pager.allocate_overflow_page()); - turso_assert!(overflow_page.is_loaded(), "overflow page is not loaded"); - { - let id = overflow_page.get().id as u32; - let contents = overflow_page.get_contents(); - - // TODO: take into account offset here? - let buf = contents.as_ptr(); - let as_bytes = id.to_be_bytes(); - // update pointer to new overflow page - // SAFETY: we know 'pointer_to_next' is valid because it refers to an offset in cell_payload which is less than space_left + 4, - // and the underlying bytes are not mutated in FillCellPayloadState::AllocateOverflowPages. - unsafe { std::ptr::copy(as_bytes.as_ptr(), *pointer_to_next, 4) }; - - *pointer = unsafe { buf.as_mut_ptr().add(4) }; - *pointer_to_next = buf.as_mut_ptr(); - *space_left = usable_space - 4; - } - - *to_copy_buffer_len -= to_copy; - // SAFETY: we know 'to_copy_buffer_ptr' is valid because it refers to record_buf which lives at least as long as this function, - // and that the offset is less than its length, and the underlying bytes are not mutated in FillCellPayloadState::AllocateOverflowPages. - *to_copy_buffer_ptr = unsafe { to_copy_buffer_ptr.add(to_copy) }; } } - } - Ok(IOResult::Done(())) + }; + page.unpin(); + result } - /// Returns the maximum payload size (X) that can be stored directly on a b-tree page without spilling to overflow pages. 
/// /// For table leaf pages: X = usable_size - 35 @@ -7285,7 +7469,7 @@ mod tests { fn add_record( id: usize, pos: usize, - page: &mut PageContent, + page: PageRef, record: ImmutableRecord, conn: &Arc, ) -> Vec { @@ -7294,7 +7478,7 @@ mod tests { run_until_done( || { fill_cell_payload( - page, + page.clone(), Some(id as i64), &mut payload, pos, @@ -7307,7 +7491,7 @@ mod tests { &conn.pager.borrow().clone(), ) .unwrap(); - insert_into_cell(page, &payload, pos, 4096).unwrap(); + insert_into_cell(page.get_contents(), &payload, pos, 4096).unwrap(); payload } @@ -7317,17 +7501,17 @@ mod tests { let conn = db.connect().unwrap(); let page = get_page(2); let page = page.get(); - let page = page.get_contents(); let header_size = 8; let regs = &[Register::Value(Value::Integer(1))]; let record = ImmutableRecord::from_registers(regs, regs.len()); - let payload = add_record(1, 0, page, record, &conn); - assert_eq!(page.cell_count(), 1); - let free = compute_free_space(page, 4096); + let payload = add_record(1, 0, page.clone(), record, &conn); + let page_contents = page.get_contents(); + assert_eq!(page_contents.cell_count(), 1); + let free = compute_free_space(page_contents, 4096); assert_eq!(free, 4096 - payload.len() - 2 - header_size); let cell_idx = 0; - ensure_cell(page, cell_idx, &payload); + ensure_cell(page_contents, cell_idx, &payload); } struct Cell { @@ -7342,7 +7526,7 @@ mod tests { let page = get_page(2); let page = page.get(); - let page = page.get_contents(); + let page_contents = page.get_contents(); let header_size = 8; let mut total_size = 0; @@ -7351,22 +7535,22 @@ mod tests { for i in 0..3 { let regs = &[Register::Value(Value::Integer(i as i64))]; let record = ImmutableRecord::from_registers(regs, regs.len()); - let payload = add_record(i, i, page, record, &conn); - assert_eq!(page.cell_count(), i + 1); - let free = compute_free_space(page, usable_space); + let payload = add_record(i, i, page.clone(), record, &conn); + assert_eq!(page_contents.cell_count(), i + 1); + let free = compute_free_space(page_contents, usable_space); total_size += payload.len() + 2; assert_eq!(free, 4096 - total_size - header_size); cells.push(Cell { pos: i, payload }); } for (i, cell) in cells.iter().enumerate() { - ensure_cell(page, i, &cell.payload); + ensure_cell(page_contents, i, &cell.payload); } cells.remove(1); - drop_cell(page, 1, usable_space).unwrap(); + drop_cell(page_contents, 1, usable_space).unwrap(); for (i, cell) in cells.iter().enumerate() { - ensure_cell(page, i, &cell.payload); + ensure_cell(page_contents, i, &cell.payload); } } @@ -7860,10 +8044,7 @@ mod tests { pager.deref(), ) .unwrap(); - pager - .io - .block(|| pager.end_tx(false, &conn, false)) - .unwrap(); + pager.io.block(|| pager.end_tx(false, &conn)).unwrap(); pager.begin_read_tx().unwrap(); // FIXME: add sorted vector instead, should be okay for small amounts of keys for now :P, too lazy to fix right now let _c = cursor.move_to_root().unwrap(); @@ -8008,10 +8189,7 @@ mod tests { if let Some(c) = c { pager.io.wait_for_completion(c).unwrap(); } - pager - .io - .block(|| pager.end_tx(false, &conn, false)) - .unwrap(); + pager.io.block(|| pager.end_tx(false, &conn)).unwrap(); } // Check that all keys can be found by seeking @@ -8217,10 +8395,7 @@ mod tests { if let Some(c) = c { pager.io.wait_for_completion(c).unwrap(); } - pager - .io - .block(|| pager.end_tx(false, &conn, false)) - .unwrap(); + pager.io.block(|| pager.end_tx(false, &conn)).unwrap(); } // Final validation @@ -8321,7 +8496,7 @@ mod tests { let page = get_page(2); 
let page = page.get(); - let page = page.get_contents(); + let page_contents = page.get_contents(); let header_size = 8; let mut total_size = 0; @@ -8331,9 +8506,9 @@ mod tests { for i in 0..total_cells { let regs = &[Register::Value(Value::Integer(i as i64))]; let record = ImmutableRecord::from_registers(regs, regs.len()); - let payload = add_record(i, i, page, record, &conn); - assert_eq!(page.cell_count(), i + 1); - let free = compute_free_space(page, usable_space); + let payload = add_record(i, i, page.clone(), record, &conn); + assert_eq!(page_contents.cell_count(), i + 1); + let free = compute_free_space(page_contents, usable_space); total_size += payload.len() + 2; assert_eq!(free, 4096 - total_size - header_size); cells.push(Cell { pos: i, payload }); @@ -8343,7 +8518,7 @@ mod tests { let mut new_cells = Vec::new(); for cell in cells { if cell.pos % 2 == 1 { - drop_cell(page, cell.pos - removed, usable_space).unwrap(); + drop_cell(page_contents, cell.pos - removed, usable_space).unwrap(); removed += 1; } else { new_cells.push(cell); @@ -8351,11 +8526,11 @@ mod tests { } let cells = new_cells; for (i, cell) in cells.iter().enumerate() { - ensure_cell(page, i, &cell.payload); + ensure_cell(page_contents, i, &cell.payload); } for (i, cell) in cells.iter().enumerate() { - ensure_cell(page, i, &cell.payload); + ensure_cell(page_contents, i, &cell.payload); } } @@ -8773,7 +8948,7 @@ mod tests { let page = get_page(2); let page = page.get(); - let page = page.get_contents(); + let page_contents = page.get_contents(); let header_size = 8; let mut total_size = 0; @@ -8782,28 +8957,28 @@ mod tests { for i in 0..3 { let regs = &[Register::Value(Value::Integer(i as i64))]; let record = ImmutableRecord::from_registers(regs, regs.len()); - let payload = add_record(i, i, page, record, &conn); - assert_eq!(page.cell_count(), i + 1); - let free = compute_free_space(page, usable_space); + let payload = add_record(i, i, page.clone(), record, &conn); + assert_eq!(page_contents.cell_count(), i + 1); + let free = compute_free_space(page_contents, usable_space); total_size += payload.len() + 2; assert_eq!(free, 4096 - total_size - header_size); cells.push(Cell { pos: i, payload }); } for (i, cell) in cells.iter().enumerate() { - ensure_cell(page, i, &cell.payload); + ensure_cell(page_contents, i, &cell.payload); } cells.remove(1); - drop_cell(page, 1, usable_space).unwrap(); + drop_cell(page_contents, 1, usable_space).unwrap(); for (i, cell) in cells.iter().enumerate() { - ensure_cell(page, i, &cell.payload); + ensure_cell(page_contents, i, &cell.payload); } - defragment_page(page, usable_space, 4).unwrap(); + defragment_page(page_contents, usable_space, 4).unwrap(); for (i, cell) in cells.iter().enumerate() { - ensure_cell(page, i, &cell.payload); + ensure_cell(page_contents, i, &cell.payload); } } @@ -8814,7 +8989,7 @@ mod tests { let page = get_page(2); let page = page.get(); - let page = page.get_contents(); + let page_contents = page.get_contents(); let header_size = 8; let mut total_size = 0; @@ -8824,9 +8999,9 @@ mod tests { for i in 0..total_cells { let regs = &[Register::Value(Value::Integer(i as i64))]; let record = ImmutableRecord::from_registers(regs, regs.len()); - let payload = add_record(i, i, page, record, &conn); - assert_eq!(page.cell_count(), i + 1); - let free = compute_free_space(page, usable_space); + let payload = add_record(i, i, page.clone(), record, &conn); + assert_eq!(page_contents.cell_count(), i + 1); + let free = compute_free_space(page_contents, usable_space); total_size 
+= payload.len() + 2; assert_eq!(free, 4096 - total_size - header_size); cells.push(Cell { pos: i, payload }); @@ -8836,7 +9011,7 @@ mod tests { let mut new_cells = Vec::new(); for cell in cells { if cell.pos % 2 == 1 { - drop_cell(page, cell.pos - removed, usable_space).unwrap(); + drop_cell(page_contents, cell.pos - removed, usable_space).unwrap(); removed += 1; } else { new_cells.push(cell); @@ -8844,13 +9019,13 @@ mod tests { } let cells = new_cells; for (i, cell) in cells.iter().enumerate() { - ensure_cell(page, i, &cell.payload); + ensure_cell(page_contents, i, &cell.payload); } - defragment_page(page, usable_space, 4).unwrap(); + defragment_page(page_contents, usable_space, 4).unwrap(); for (i, cell) in cells.iter().enumerate() { - ensure_cell(page, i, &cell.payload); + ensure_cell(page_contents, i, &cell.payload); } } @@ -8861,7 +9036,7 @@ mod tests { let page = get_page(2); let page = page.get(); - let page = page.get_contents(); + let page_contents = page.get_contents(); let header_size = 8; let mut total_size = 0; @@ -8876,8 +9051,8 @@ mod tests { match rng.next_u64() % 4 { 0 => { // allow appends with extra place to insert - let cell_idx = rng.next_u64() as usize % (page.cell_count() + 1); - let free = compute_free_space(page, usable_space); + let cell_idx = rng.next_u64() as usize % (page_contents.cell_count() + 1); + let free = compute_free_space(page_contents, usable_space); let regs = &[Register::Value(Value::Integer(i as i64))]; let record = ImmutableRecord::from_registers(regs, regs.len()); let mut payload: Vec = Vec::new(); @@ -8885,7 +9060,7 @@ mod tests { run_until_done( || { fill_cell_payload( - page, + page.clone(), Some(i as i64), &mut payload, cell_idx, @@ -8902,34 +9077,34 @@ mod tests { // do not try to insert overflow pages because they require balancing continue; } - insert_into_cell(page, &payload, cell_idx, 4096).unwrap(); - assert!(page.overflow_cells.is_empty()); + insert_into_cell(page_contents, &payload, cell_idx, 4096).unwrap(); + assert!(page_contents.overflow_cells.is_empty()); total_size += payload.len() + 2; cells.insert(cell_idx, Cell { pos: i, payload }); } 1 => { - if page.cell_count() == 0 { + if page_contents.cell_count() == 0 { continue; } - let cell_idx = rng.next_u64() as usize % page.cell_count(); - let (_, len) = page.cell_get_raw_region(cell_idx, usable_space); - drop_cell(page, cell_idx, usable_space).unwrap(); + let cell_idx = rng.next_u64() as usize % page_contents.cell_count(); + let (_, len) = page_contents.cell_get_raw_region(cell_idx, usable_space); + drop_cell(page_contents, cell_idx, usable_space).unwrap(); total_size -= len + 2; cells.remove(cell_idx); } 2 => { - defragment_page(page, usable_space, 4).unwrap(); + defragment_page(page_contents, usable_space, 4).unwrap(); } 3 => { // check cells for (i, cell) in cells.iter().enumerate() { - ensure_cell(page, i, &cell.payload); + ensure_cell(page_contents, i, &cell.payload); } - assert_eq!(page.cell_count(), cells.len()); + assert_eq!(page_contents.cell_count(), cells.len()); } _ => unreachable!(), } - let free = compute_free_space(page, usable_space); + let free = compute_free_space(page_contents, usable_space); assert_eq!(free, 4096 - total_size - header_size); } } @@ -8943,7 +9118,7 @@ mod tests { let page = get_page(2); let page = page.get(); - let page = page.get_contents(); + let page_contents = page.get_contents(); let header_size = 8; let mut total_size = 0; @@ -8958,8 +9133,8 @@ mod tests { match rng.next_u64() % 3 { 0 => { // allow appends with extra place to insert - let 
cell_idx = rng.next_u64() as usize % (page.cell_count() + 1); - let free = compute_free_space(page, usable_space); + let cell_idx = rng.next_u64() as usize % (page_contents.cell_count() + 1); + let free = compute_free_space(page_contents, usable_space); let regs = &[Register::Value(Value::Integer(i))]; let record = ImmutableRecord::from_registers(regs, regs.len()); let mut payload: Vec = Vec::new(); @@ -8967,7 +9142,7 @@ mod tests { run_until_done( || { fill_cell_payload( - page, + page.clone(), Some(i), &mut payload, cell_idx, @@ -8984,8 +9159,8 @@ mod tests { // do not try to insert overflow pages because they require balancing continue; } - insert_into_cell(page, &payload, cell_idx, 4096).unwrap(); - assert!(page.overflow_cells.is_empty()); + insert_into_cell(page_contents, &payload, cell_idx, 4096).unwrap(); + assert!(page_contents.overflow_cells.is_empty()); total_size += payload.len() + 2; cells.push(Cell { pos: i as usize, @@ -8993,21 +9168,21 @@ mod tests { }); } 1 => { - if page.cell_count() == 0 { + if page_contents.cell_count() == 0 { continue; } - let cell_idx = rng.next_u64() as usize % page.cell_count(); - let (_, len) = page.cell_get_raw_region(cell_idx, usable_space); - drop_cell(page, cell_idx, usable_space).unwrap(); + let cell_idx = rng.next_u64() as usize % page_contents.cell_count(); + let (_, len) = page_contents.cell_get_raw_region(cell_idx, usable_space); + drop_cell(page_contents, cell_idx, usable_space).unwrap(); total_size -= len + 2; cells.remove(cell_idx); } 2 => { - defragment_page(page, usable_space, 4).unwrap(); + defragment_page(page_contents, usable_space, 4).unwrap(); } _ => unreachable!(), } - let free = compute_free_space(page, usable_space); + let free = compute_free_space(page_contents, usable_space); assert_eq!(free, 4096 - total_size - header_size); } } @@ -9119,14 +9294,14 @@ mod tests { let conn = db.connect().unwrap(); let page = get_page(2); let page = page.get(); - let page = page.get_contents(); + let page_contents = page.get_contents(); let header_size = 8; let usable_space = 4096; let regs = &[Register::Value(Value::Integer(0))]; let record = ImmutableRecord::from_registers(regs, regs.len()); - let payload = add_record(0, 0, page, record, &conn); - let free = compute_free_space(page, usable_space); + let payload = add_record(0, 0, page.clone(), record, &conn); + let free = compute_free_space(page_contents, usable_space); assert_eq!(free, 4096 - payload.len() - 2 - header_size); } @@ -9137,18 +9312,18 @@ mod tests { let page = get_page(2); let page = page.get(); - let page = page.get_contents(); + let page_contents = page.get_contents(); let usable_space = 4096; let regs = &[Register::Value(Value::Integer(0))]; let record = ImmutableRecord::from_registers(regs, regs.len()); - let payload = add_record(0, 0, page, record, &conn); + let payload = add_record(0, 0, page.clone(), record, &conn); - assert_eq!(page.cell_count(), 1); - defragment_page(page, usable_space, 4).unwrap(); - assert_eq!(page.cell_count(), 1); - let (start, len) = page.cell_get_raw_region(0, usable_space); - let buf = page.as_ptr(); + assert_eq!(page_contents.cell_count(), 1); + defragment_page(page_contents, usable_space, 4).unwrap(); + assert_eq!(page_contents.cell_count(), 1); + let (start, len) = page_contents.cell_get_raw_region(0, usable_space); + let buf = page_contents.as_ptr(); assert_eq!(&payload, &buf[start..start + len]); } @@ -9159,7 +9334,7 @@ mod tests { let page = get_page(2); let page = page.get(); - let page = page.get_contents(); + let page_contents = 
page.get_contents(); let usable_space = 4096; let regs = &[ @@ -9167,19 +9342,19 @@ mod tests { Register::Value(Value::Text(Text::new("aaaaaaaa"))), ]; let record = ImmutableRecord::from_registers(regs, regs.len()); - let _ = add_record(0, 0, page, record, &conn); + let _ = add_record(0, 0, page.clone(), record, &conn); - assert_eq!(page.cell_count(), 1); - drop_cell(page, 0, usable_space).unwrap(); - assert_eq!(page.cell_count(), 0); + assert_eq!(page_contents.cell_count(), 1); + drop_cell(page_contents, 0, usable_space).unwrap(); + assert_eq!(page_contents.cell_count(), 0); let regs = &[Register::Value(Value::Integer(0))]; let record = ImmutableRecord::from_registers(regs, regs.len()); - let payload = add_record(0, 0, page, record, &conn); - assert_eq!(page.cell_count(), 1); + let payload = add_record(0, 0, page.clone(), record, &conn); + assert_eq!(page_contents.cell_count(), 1); - let (start, len) = page.cell_get_raw_region(0, usable_space); - let buf = page.as_ptr(); + let (start, len) = page_contents.cell_get_raw_region(0, usable_space); + let buf = page_contents.as_ptr(); assert_eq!(&payload, &buf[start..start + len]); } @@ -9190,7 +9365,7 @@ mod tests { let page = get_page(2); let page = page.get(); - let page = page.get_contents(); + let page_contents = page.get_contents(); let usable_space = 4096; let regs = &[ @@ -9198,20 +9373,20 @@ mod tests { Register::Value(Value::Text(Text::new("aaaaaaaa"))), ]; let record = ImmutableRecord::from_registers(regs, regs.len()); - let _ = add_record(0, 0, page, record, &conn); + let _ = add_record(0, 0, page.clone(), record, &conn); for _ in 0..100 { - assert_eq!(page.cell_count(), 1); - drop_cell(page, 0, usable_space).unwrap(); - assert_eq!(page.cell_count(), 0); + assert_eq!(page_contents.cell_count(), 1); + drop_cell(page_contents, 0, usable_space).unwrap(); + assert_eq!(page_contents.cell_count(), 0); let regs = &[Register::Value(Value::Integer(0))]; let record = ImmutableRecord::from_registers(regs, regs.len()); - let payload = add_record(0, 0, page, record, &conn); - assert_eq!(page.cell_count(), 1); + let payload = add_record(0, 0, page.clone(), record, &conn); + assert_eq!(page_contents.cell_count(), 1); - let (start, len) = page.cell_get_raw_region(0, usable_space); - let buf = page.as_ptr(); + let (start, len) = page_contents.cell_get_raw_region(0, usable_space); + let buf = page_contents.as_ptr(); assert_eq!(&payload, &buf[start..start + len]); } } @@ -9223,23 +9398,23 @@ mod tests { let page = get_page(2); let page = page.get(); - let page = page.get_contents(); + let page_contents = page.get_contents(); let usable_space = 4096; let regs = &[Register::Value(Value::Integer(0))]; let record = ImmutableRecord::from_registers(regs, regs.len()); - let payload = add_record(0, 0, page, record, &conn); + let payload = add_record(0, 0, page.clone(), record, &conn); let regs = &[Register::Value(Value::Integer(1))]; let record = ImmutableRecord::from_registers(regs, regs.len()); - let _ = add_record(1, 1, page, record, &conn); + let _ = add_record(1, 1, page.clone(), record, &conn); let regs = &[Register::Value(Value::Integer(2))]; let record = ImmutableRecord::from_registers(regs, regs.len()); - let _ = add_record(2, 2, page, record, &conn); + let _ = add_record(2, 2, page.clone(), record, &conn); - drop_cell(page, 1, usable_space).unwrap(); - drop_cell(page, 1, usable_space).unwrap(); + drop_cell(page_contents, 1, usable_space).unwrap(); + drop_cell(page_contents, 1, usable_space).unwrap(); - ensure_cell(page, 0, &payload); + 
ensure_cell(page_contents, 0, &payload); } #[test] @@ -9249,29 +9424,29 @@ mod tests { let page = get_page(2); let page = page.get(); - let page = page.get_contents(); + let page_contents = page.get_contents(); let usable_space = 4096; let regs = &[Register::Value(Value::Integer(0))]; let record = ImmutableRecord::from_registers(regs, regs.len()); - let _ = add_record(0, 0, page, record, &conn); + let _ = add_record(0, 0, page.clone(), record, &conn); let regs = &[Register::Value(Value::Integer(0))]; let record = ImmutableRecord::from_registers(regs, regs.len()); - let _ = add_record(0, 0, page, record, &conn); - drop_cell(page, 0, usable_space).unwrap(); + let _ = add_record(0, 0, page.clone(), record, &conn); + drop_cell(page_contents, 0, usable_space).unwrap(); - defragment_page(page, usable_space, 4).unwrap(); + defragment_page(page_contents, usable_space, 4).unwrap(); let regs = &[Register::Value(Value::Integer(0))]; let record = ImmutableRecord::from_registers(regs, regs.len()); - let _ = add_record(0, 1, page, record, &conn); + let _ = add_record(0, 1, page.clone(), record, &conn); - drop_cell(page, 0, usable_space).unwrap(); + drop_cell(page_contents, 0, usable_space).unwrap(); let regs = &[Register::Value(Value::Integer(0))]; let record = ImmutableRecord::from_registers(regs, regs.len()); - let _ = add_record(0, 1, page, record, &conn); + let _ = add_record(0, 1, page.clone(), record, &conn); } #[test] @@ -9295,21 +9470,21 @@ mod tests { let page = page.get(); defragment(page.get_contents()); defragment(page.get_contents()); - insert(0, page.get_contents()); + insert(0, page.clone()); drop(0, page.get_contents()); - insert(0, page.get_contents()); + insert(0, page.clone()); drop(0, page.get_contents()); - insert(0, page.get_contents()); + insert(0, page.clone()); defragment(page.get_contents()); defragment(page.get_contents()); drop(0, page.get_contents()); defragment(page.get_contents()); - insert(0, page.get_contents()); + insert(0, page.clone()); drop(0, page.get_contents()); - insert(0, page.get_contents()); - insert(1, page.get_contents()); - insert(1, page.get_contents()); - insert(0, page.get_contents()); + insert(0, page.clone()); + insert(1, page.clone()); + insert(1, page.clone()); + insert(0, page.clone()); drop(3, page.get_contents()); drop(2, page.get_contents()); compute_free_space(page.get_contents(), usable_space); @@ -9340,7 +9515,7 @@ mod tests { run_until_done( || { fill_cell_payload( - page.get().get_contents(), + page.get(), Some(0), &mut payload, 0, @@ -9354,11 +9529,11 @@ mod tests { ) .unwrap(); let page = page.get(); - insert(0, page.get_contents()); + insert(0, page.clone()); defragment(page.get_contents()); - insert(0, page.get_contents()); + insert(0, page.clone()); defragment(page.get_contents()); - insert(0, page.get_contents()); + insert(0, page.clone()); drop(2, page.get_contents()); drop(0, page.get_contents()); let free = compute_free_space(page.get_contents(), usable_space); @@ -9426,7 +9601,7 @@ mod tests { run_until_done( || { fill_cell_payload( - page.get().get_contents(), + page.get(), Some(0), &mut payload, 0, @@ -9778,7 +9953,7 @@ mod tests { while compute_free_space(page.get_contents(), pager.usable_space()) >= size as usize + 10 { - insert_cell(i, size, page.get_contents(), pager.clone()); + insert_cell(i, size, page.clone(), pager.clone()); i += 1; size = (rng.next_u64() % 1024) as u16; } @@ -9829,15 +10004,16 @@ mod tests { } } - fn insert_cell(cell_idx: u64, size: u16, contents: &mut PageContent, pager: Rc) { + fn insert_cell(cell_idx: 
u64, size: u16, page: PageRef, pager: Rc<Pager>) {
         let mut payload = Vec::new();
         let regs = &[Register::Value(Value::Blob(vec![0; size as usize]))];
         let record = ImmutableRecord::from_registers(regs, regs.len());
         let mut fill_cell_payload_state = FillCellPayloadState::Start;
+        let contents = page.get_contents();
         run_until_done(
             || {
                 fill_cell_payload(
-                    contents,
+                    page.clone(),
                     Some(cell_idx as i64),
                     &mut payload,
                     cell_idx as usize,
diff --git a/core/storage/database.rs b/core/storage/database.rs
index d608558fc..b13962d30 100644
--- a/core/storage/database.rs
+++ b/core/storage/database.rs
@@ -184,11 +184,7 @@ impl DatabaseFile {
     }
 }
 
-fn encrypt_buffer(
-    page_idx: usize,
-    buffer: Arc<Buffer>,
-    ctx: &EncryptionContext,
-) -> Arc<Buffer> {
+fn encrypt_buffer(page_idx: usize, buffer: Arc<Buffer>, ctx: &EncryptionContext) -> Arc<Buffer> {
     let encrypted_data = ctx.encrypt_page(buffer.as_slice(), page_idx).unwrap();
     Arc::new(Buffer::new(encrypted_data.to_vec()))
 }
diff --git a/core/storage/encryption.rs b/core/storage/encryption.rs
index 97836d3d1..51cf84ee5 100644
--- a/core/storage/encryption.rs
+++ b/core/storage/encryption.rs
@@ -1,15 +1,13 @@
 #![allow(unused_variables, dead_code)]
 use crate::{LimboError, Result};
+use aegis::aegis256::Aegis256;
 use aes_gcm::{
     aead::{Aead, AeadCore, KeyInit, OsRng},
     Aes256Gcm, Key, Nonce,
 };
 use std::ops::Deref;
 
-pub const ENCRYPTION_METADATA_SIZE: usize = 28;
 pub const ENCRYPTED_PAGE_SIZE: usize = 4096;
-pub const ENCRYPTION_NONCE_SIZE: usize = 12;
-pub const ENCRYPTION_TAG_SIZE: usize = 16;
 
 #[repr(transparent)]
 #[derive(Clone)]
@@ -20,12 +18,17 @@ impl EncryptionKey {
         Self(key)
     }
 
-    pub fn from_string(s: &str) -> Self {
-        let mut key = [0u8; 32];
-        let bytes = s.as_bytes();
-        let len = bytes.len().min(32);
-        key[..len].copy_from_slice(&bytes[..len]);
-        Self(key)
+    pub fn from_hex_string(s: &str) -> Result<Self> {
+        let hex_str = s.trim();
+        let bytes = hex::decode(hex_str)
+            .map_err(|e| LimboError::InvalidArgument(format!("Invalid hex string: {e}")))?;
+        let key: [u8; 32] = bytes.try_into().map_err(|v: Vec<u8>| {
+            LimboError::InvalidArgument(format!(
+                "Hex string must decode to exactly 32 bytes, got {}",
+                v.len()
+            ))
+        })?;
+        Ok(Self(key))
     }
 
     pub fn as_bytes(&self) -> &[u8; 32] {
@@ -70,9 +73,88 @@ impl Drop for EncryptionKey {
     }
 }
 
+// Wrapper struct for the AEGIS-256 cipher; the crate we use is a bit low-level, so we add
+// some nicer abstractions here.
+// Note that AEGIS has many variants and supports hardware acceleration. Here we just use the
+// vanilla version, which is still orders of magnitude faster than AES-GCM in software.
+// Hardware-accelerated builds are left for future work.
+#[derive(Clone)]
+pub struct Aegis256Cipher {
+    key: EncryptionKey,
+}
+
+impl Aegis256Cipher {
+    // AEGIS-256 supports both 16- and 32-byte tags; we use the 16-byte variant, which is faster
+    // and provides sufficient security for our use case.
+    const TAG_SIZE: usize = 16;
+
+    fn new(key: &EncryptionKey) -> Self {
+        Self { key: key.clone() }
+    }
+
+    fn encrypt(&self, plaintext: &[u8], ad: &[u8]) -> Result<(Vec<u8>, [u8; 32])> {
+        let nonce = generate_secure_nonce();
+        let (ciphertext, tag) =
+            Aegis256::<16>::new(self.key.as_bytes(), &nonce).encrypt(plaintext, ad);
+        let mut result = ciphertext;
+        result.extend_from_slice(&tag);
+        Ok((result, nonce))
+    }
+
+    fn decrypt(&self, ciphertext: &[u8], nonce: &[u8; 32], ad: &[u8]) -> Result<Vec<u8>> {
+        if ciphertext.len() < Self::TAG_SIZE {
+            return Err(LimboError::InternalError(
+                "Ciphertext too short for AEGIS-256".into(),
+            ));
+        }
+        let (ct, tag) = ciphertext.split_at(ciphertext.len() - Self::TAG_SIZE);
+        let tag_array: [u8; 16] = tag
+            .try_into()
+            .map_err(|_| LimboError::InternalError("Invalid tag size for AEGIS-256".into()))?;
+
+        let plaintext = Aegis256::<16>::new(self.key.as_bytes(), nonce)
+            .decrypt(ct, &tag_array, ad)
+            .map_err(|_| {
+                LimboError::InternalError("AEGIS-256 decryption failed: invalid tag".into())
+            })?;
+        Ok(plaintext)
+    }
+}
+
+impl std::fmt::Debug for Aegis256Cipher {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("Aegis256Cipher")
+            .field("key", &"<redacted>")
+            .finish()
+    }
+}
+
 #[derive(Debug, Clone, Copy, PartialEq)]
 pub enum CipherMode {
     Aes256Gcm,
+    Aegis256,
+}
+
+impl TryFrom<&str> for CipherMode {
+    type Error = LimboError;
+
+    fn try_from(s: &str) -> Result<Self, Self::Error> {
+        match s.to_lowercase().as_str() {
+            "aes256gcm" | "aes-256-gcm" | "aes_256_gcm" => Ok(CipherMode::Aes256Gcm),
+            "aegis256" | "aegis-256" | "aegis_256" => Ok(CipherMode::Aegis256),
+            _ => Err(LimboError::InvalidArgument(format!(
+                "Unknown cipher name: {s}"
+            ))),
+        }
+    }
+}
+
+impl std::fmt::Display for CipherMode {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            CipherMode::Aes256Gcm => write!(f, "aes256gcm"),
+            CipherMode::Aegis256 => write!(f, "aegis256"),
+        }
+    }
 }
 
 impl CipherMode {
@@ -81,33 +163,43 @@ impl CipherMode {
     pub fn required_key_size(&self) -> usize {
         match self {
             CipherMode::Aes256Gcm => 32,
+            CipherMode::Aegis256 => 32,
         }
     }
 
-    /// Returns the nonce size for this cipher mode. Though most AEAD ciphers use 12-byte nonces.
+    /// Returns the nonce size for this cipher mode.
     pub fn nonce_size(&self) -> usize {
         match self {
-            CipherMode::Aes256Gcm => ENCRYPTION_NONCE_SIZE,
+            CipherMode::Aes256Gcm => 12,
+            CipherMode::Aegis256 => 32,
         }
     }
 
-    /// Returns the authentication tag size for this cipher mode. All common AEAD ciphers use 16-byte tags.
+    /// Returns the authentication tag size for this cipher mode.
    pub fn tag_size(&self) -> usize {
         match self {
-            CipherMode::Aes256Gcm => ENCRYPTION_TAG_SIZE,
+            CipherMode::Aes256Gcm => 16,
+            CipherMode::Aegis256 => 16,
        }
     }
+
+    /// Returns the total metadata size (nonce + tag) for this cipher mode.
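+    /// e.g. AES-256-GCM: 12-byte nonce + 16-byte tag = 28 bytes;
+    /// AEGIS-256: 32-byte nonce + 16-byte tag = 48 bytes.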
+    pub fn metadata_size(&self) -> usize {
+        self.nonce_size() + self.tag_size()
+    }
 }
 
 #[derive(Clone)]
 pub enum Cipher {
     Aes256Gcm(Box<Aes256Gcm>),
+    Aegis256(Box<Aegis256Cipher>),
 }
 
 impl std::fmt::Debug for Cipher {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         match self {
             Cipher::Aes256Gcm(_) => write!(f, "Cipher::Aes256Gcm"),
+            Cipher::Aegis256(_) => write!(f, "Cipher::Aegis256"),
         }
     }
 }
@@ -119,8 +211,7 @@ pub struct EncryptionContext {
 }
 
 impl EncryptionContext {
-    pub fn new(key: &EncryptionKey) -> Result<Self> {
-        let cipher_mode = CipherMode::Aes256Gcm;
+    pub fn new(cipher_mode: CipherMode, key: &EncryptionKey) -> Result<Self> {
         let required_size = cipher_mode.required_key_size();
         if key.as_slice().len() != required_size {
             return Err(crate::LimboError::InvalidArgument(format!(
@@ -136,6 +227,7 @@
                 let cipher_key: &Key<Aes256Gcm> = key.as_ref().into();
                 Cipher::Aes256Gcm(Box::new(Aes256Gcm::new(cipher_key)))
             }
+            CipherMode::Aegis256 => Cipher::Aegis256(Box::new(Aegis256Cipher::new(key))),
         };
         Ok(Self {
             cipher_mode,
@@ -147,6 +239,11 @@
         self.cipher_mode
     }
 
+    /// Returns the number of reserved bytes required at the end of each page for encryption metadata.
+    pub fn required_reserved_bytes(&self) -> u8 {
+        self.cipher_mode.metadata_size() as u8
+    }
+
     #[cfg(feature = "encryption")]
     pub fn encrypt_page(&self, page: &[u8], page_id: usize) -> Result<Vec<u8>> {
         if page_id == 1 {
@@ -159,21 +256,26 @@
             ENCRYPTED_PAGE_SIZE,
             "Page data must be exactly {ENCRYPTED_PAGE_SIZE} bytes"
         );
-        let reserved_bytes = &page[ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE..];
+
+        let metadata_size = self.cipher_mode.metadata_size();
+        let reserved_bytes = &page[ENCRYPTED_PAGE_SIZE - metadata_size..];
         let reserved_bytes_zeroed = reserved_bytes.iter().all(|&b| b == 0);
         assert!(
             reserved_bytes_zeroed,
             "last reserved bytes must be empty/zero, but found non-zero bytes"
         );
-        let payload = &page[..ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE];
+
+        let payload = &page[..ENCRYPTED_PAGE_SIZE - metadata_size];
         let (encrypted, nonce) = self.encrypt_raw(payload)?;
+        let nonce_size = self.cipher_mode.nonce_size();
         assert_eq!(
             encrypted.len(),
-            ENCRYPTED_PAGE_SIZE - nonce.len(),
+            ENCRYPTED_PAGE_SIZE - nonce_size,
             "Encrypted page must be exactly {} bytes",
-            ENCRYPTED_PAGE_SIZE - nonce.len()
+            ENCRYPTED_PAGE_SIZE - nonce_size
         );
+
         let mut result = Vec::with_capacity(ENCRYPTED_PAGE_SIZE);
         result.extend_from_slice(&encrypted);
         result.extend_from_slice(&nonce);
@@ -198,18 +300,21 @@
             "Encrypted page data must be exactly {ENCRYPTED_PAGE_SIZE} bytes"
         );
 
-        let nonce_start = encrypted_page.len() - ENCRYPTION_NONCE_SIZE;
+        let nonce_size = self.cipher_mode.nonce_size();
+        let nonce_start = encrypted_page.len() - nonce_size;
         let payload = &encrypted_page[..nonce_start];
         let nonce = &encrypted_page[nonce_start..];
 
         let decrypted_data = self.decrypt_raw(payload, nonce)?;
+        let metadata_size = self.cipher_mode.metadata_size();
         assert_eq!(
             decrypted_data.len(),
-            ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE,
+            ENCRYPTED_PAGE_SIZE - metadata_size,
             "Decrypted page data must be exactly {} bytes",
-            ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE
+            ENCRYPTED_PAGE_SIZE - metadata_size
         );
+
         let mut result = Vec::with_capacity(ENCRYPTED_PAGE_SIZE);
         result.extend_from_slice(&decrypted_data);
         result.resize(ENCRYPTED_PAGE_SIZE, 0);
@@ -231,6 +336,11 @@
                     .map_err(|e| LimboError::InternalError(format!("Encryption failed: {e:?}")))?;
                 Ok((ciphertext, nonce.to_vec()))
} + Cipher::Aegis256(cipher) => { + let ad = b""; + let (ciphertext, nonce) = cipher.encrypt(plaintext, ad)?; + Ok((ciphertext, nonce.to_vec())) + } } } @@ -243,6 +353,16 @@ impl EncryptionContext { })?; Ok(plaintext) } + Cipher::Aegis256(cipher) => { + let nonce_array: [u8; 32] = nonce.try_into().map_err(|_| { + LimboError::InternalError(format!( + "Invalid nonce size for AEGIS-256: expected 32, got {}", + nonce.len() + )) + })?; + let ad = b""; + cipher.decrypt(ciphertext, &nonce_array, ad) + } } } @@ -261,16 +381,33 @@ impl EncryptionContext { } } +fn generate_secure_nonce() -> [u8; 32] { + // use OsRng directly to fill bytes, similar to how AeadCore does it + use aes_gcm::aead::rand_core::RngCore; + let mut nonce = [0u8; 32]; + OsRng.fill_bytes(&mut nonce); + nonce +} + #[cfg(test)] mod tests { use super::*; use rand::Rng; + fn generate_random_hex_key() -> String { + let mut rng = rand::thread_rng(); + let mut bytes = [0u8; 32]; + rng.fill(&mut bytes); + hex::encode(bytes) + } + #[test] #[cfg(feature = "encryption")] - fn test_encrypt_decrypt_round_trip() { + fn test_aes_encrypt_decrypt_round_trip() { let mut rng = rand::thread_rng(); - let data_size = ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE; + let cipher_mode = CipherMode::Aes256Gcm; + let metadata_size = cipher_mode.metadata_size(); + let data_size = ENCRYPTED_PAGE_SIZE - metadata_size; let page_data = { let mut page = vec![0u8; ENCRYPTED_PAGE_SIZE]; @@ -280,8 +417,8 @@ mod tests { page }; - let key = EncryptionKey::from_string("alice and bob use encryption on database"); - let ctx = EncryptionContext::new(&key).unwrap(); + let key = EncryptionKey::from_hex_string(&generate_random_hex_key()).unwrap(); + let ctx = EncryptionContext::new(CipherMode::Aes256Gcm, &key).unwrap(); let page_id = 42; let encrypted = ctx.encrypt_page(&page_data, page_id).unwrap(); @@ -293,4 +430,66 @@ mod tests { assert_eq!(decrypted.len(), ENCRYPTED_PAGE_SIZE); assert_eq!(decrypted, page_data); } + + #[test] + #[cfg(feature = "encryption")] + fn test_aegis256_cipher_wrapper() { + let key = EncryptionKey::from_hex_string(&generate_random_hex_key()).unwrap(); + let cipher = Aegis256Cipher::new(&key); + + let plaintext = b"Hello, AEGIS-256!"; + let ad = b"additional data"; + + let (ciphertext, nonce) = cipher.encrypt(plaintext, ad).unwrap(); + assert_eq!(nonce.len(), 32); + assert_ne!(ciphertext[..plaintext.len()], plaintext[..]); + + let decrypted = cipher.decrypt(&ciphertext, &nonce, ad).unwrap(); + assert_eq!(decrypted, plaintext); + } + + #[test] + #[cfg(feature = "encryption")] + fn test_aegis256_raw_encryption() { + let key = EncryptionKey::from_hex_string(&generate_random_hex_key()).unwrap(); + let ctx = EncryptionContext::new(CipherMode::Aegis256, &key).unwrap(); + + let plaintext = b"Hello, AEGIS-256!"; + let (ciphertext, nonce) = ctx.encrypt_raw(plaintext).unwrap(); + + assert_eq!(nonce.len(), 32); // AEGIS-256 uses 32-byte nonces + assert_ne!(ciphertext[..plaintext.len()], plaintext[..]); + + let decrypted = ctx.decrypt_raw(&ciphertext, &nonce).unwrap(); + assert_eq!(decrypted, plaintext); + } + + #[test] + #[cfg(feature = "encryption")] + fn test_aegis256_encrypt_decrypt_round_trip() { + let mut rng = rand::thread_rng(); + let cipher_mode = CipherMode::Aegis256; + let metadata_size = cipher_mode.metadata_size(); + let data_size = ENCRYPTED_PAGE_SIZE - metadata_size; + + let page_data = { + let mut page = vec![0u8; ENCRYPTED_PAGE_SIZE]; + page.iter_mut() + .take(data_size) + .for_each(|byte| *byte = rng.gen()); + page + }; + + let key = 
EncryptionKey::from_hex_string(&generate_random_hex_key()).unwrap(); + let ctx = EncryptionContext::new(CipherMode::Aegis256, &key).unwrap(); + + let page_id = 42; + let encrypted = ctx.encrypt_page(&page_data, page_id).unwrap(); + assert_eq!(encrypted.len(), ENCRYPTED_PAGE_SIZE); + assert_ne!(&encrypted[..data_size], &page_data[..data_size]); + + let decrypted = ctx.decrypt_page(&encrypted, page_id).unwrap(); + assert_eq!(decrypted.len(), ENCRYPTED_PAGE_SIZE); + assert_eq!(decrypted, page_data); + } } diff --git a/core/storage/pager.rs b/core/storage/pager.rs index 4512b93c7..fcee1ec4c 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -1,4 +1,5 @@ use crate::result::LimboResult; +use crate::storage::wal::IOV_MAX; use crate::storage::{ btree::BTreePageInner, buffer_pool::BufferPool, @@ -28,9 +29,7 @@ use super::btree::{btree_init_page, BTreePage}; use super::page_cache::{CacheError, CacheResizeResult, DumbLruPageCache, PageCacheKey}; use super::sqlite3_ondisk::begin_write_btree_page; use super::wal::CheckpointMode; -use crate::storage::encryption::{ - EncryptionKey, EncryptionContext, ENCRYPTION_METADATA_SIZE, -}; +use crate::storage::encryption::{CipherMode, EncryptionContext, EncryptionKey}; /// SQLite's default maximum page count const DEFAULT_MAX_PAGE_COUNT: u32 = 0xfffffffe; @@ -125,6 +124,14 @@ pub struct PageInner { pub flags: AtomicUsize, pub contents: Option, pub id: usize, + /// If >0, the page is pinned and not eligible for eviction from the page cache. + /// The reason this is a counter is that multiple nested code paths may signal that + /// a page must not be evicted from the page cache, so even if an inner code path + /// requests unpinning via [Page::unpin], the pin count will still be >0 if the outer + /// code path has not yet requested to unpin the page as well. + /// + /// Note that [DumbLruPageCache::clear] evicts the pages even if pinned, so as long as + /// we clear the page cache on errors, pins will not 'leak'. pub pin_count: AtomicUsize, /// The WAL frame number this page was loaded from (0 if loaded from main DB file) /// This tracks which version of the page we have in memory @@ -239,12 +246,13 @@ impl Page { } } - /// Pin the page to prevent it from being evicted from the page cache. + /// Increment the pin count by 1. A pin count >0 means the page is pinned and not eligible for eviction from the page cache. pub fn pin(&self) { self.get().pin_count.fetch_add(1, Ordering::Relaxed); } - /// Unpin the page to allow it to be evicted from the page cache. + /// Decrement the pin count by 1. If the count reaches 0, the page is no longer + /// pinned and is eligible for eviction from the page cache. pub fn unpin(&self) { let was_pinned = self.try_unpin(); @@ -255,8 +263,8 @@ impl Page { ); } - /// Try to unpin the page if it's pinned, otherwise do nothing. - /// Returns true if the page was originally pinned. + /// Try to decrement the pin count by 1, but do nothing if it was already 0. + /// Returns true if the pin count was decremented. pub fn try_unpin(&self) -> bool { self.get() .pin_count @@ -270,6 +278,7 @@ impl Page { .is_ok() } + /// Returns true if the page is pinned and thus not eligible for eviction from the page cache. pub fn is_pinned(&self) -> bool { self.get().pin_count.load(Ordering::Acquire) > 0 } @@ -350,6 +359,7 @@ pub enum BtreePageAllocMode { /// This will keep track of the state of current cache commit in order to not repeat work struct CommitInfo { state: Cell, + time: Cell, } /// Track the state of the auto-vacuum mode. 
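A note on the sizes involved: with per-cipher nonce and tag sizes replacing the old fixed 28-byte constant, the usable payload of a 4096-byte page now varies by cipher. A minimal self-contained sketch of that arithmetic (plain Rust; only the 4096-byte page size and the 12/16/32-byte figures come from this patch, the helper name is illustrative):

const ENCRYPTED_PAGE_SIZE: usize = 4096;

// Each encrypted page is laid out as [ payload | tag | nonce ]; the trailing
// metadata region must be zeroed before encryption and is sized per cipher.
fn usable_payload(nonce_size: usize, tag_size: usize) -> usize {
    ENCRYPTED_PAGE_SIZE - (nonce_size + tag_size) // metadata_size = nonce + tag
}

fn main() {
    // AES-256-GCM: 12-byte nonce + 16-byte tag = 28 reserved bytes
    // (the old ENCRYPTION_METADATA_SIZE).
    assert_eq!(usable_payload(12, 16), 4068);
    // AEGIS-256: 32-byte nonce + 16-byte tag = 48 reserved bytes.
    assert_eq!(usable_payload(32, 16), 4048);
}

The AEGIS-256 figure is why reserved_space in the database header is now derived from the encryption context via required_reserved_bytes() (see the pager.rs hunk further down) rather than hard-coded, and why the cipher is selected at runtime through the new `cipher` pragma alongside `hexkey`.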
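The pin-count doc comments above describe a counter rather than a boolean flag. A compact sketch of that nesting protocol using the same atomics (the Page type here is a simplified stand-in, and the fetch_update orderings are an assumption, not copied from the crate):

use std::sync::atomic::{AtomicUsize, Ordering};

struct Page {
    pin_count: AtomicUsize,
}

impl Page {
    fn pin(&self) {
        self.pin_count.fetch_add(1, Ordering::Relaxed);
    }
    // Decrement if >0; returns false when the count was already 0.
    fn try_unpin(&self) -> bool {
        self.pin_count
            .fetch_update(Ordering::Release, Ordering::Acquire, |c| c.checked_sub(1))
            .is_ok()
    }
    fn is_pinned(&self) -> bool {
        self.pin_count.load(Ordering::Acquire) > 0
    }
}

fn main() {
    let p = Page { pin_count: AtomicUsize::new(0) };
    p.pin(); // outer code path
    p.pin(); // nested code path
    assert!(p.try_unpin()); // inner unpin: count 2 -> 1, still pinned
    assert!(p.is_pinned());
    assert!(p.try_unpin()); // outer unpin: count 1 -> 0, now evictable
    assert!(!p.is_pinned());
    assert!(!p.try_unpin()); // already 0: no-op
}

As the comment notes, DumbLruPageCache::clear still evicts even pinned pages, which is what keeps pins from leaking as long as the cache is cleared on error paths.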
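The commit-path hunks that follow replace one-page-at-a-time append_frame calls with vectored appends of up to IOV_MAX frames per pwritev, where only the very last frame carries the database size and thereby acts as the commit frame. A sketch of just that chunking policy (illustrative names, not the crate's API):

const IOV_MAX: usize = 1024;

// Split page ids into pwritev batches; only the final frame of the final
// batch is the commit frame and carries the database size.
fn plan_batches(page_ids: &[u64], db_size: u32) -> Vec<(Vec<u64>, Option<u32>)> {
    let total = page_ids.len();
    let mut batches = Vec::new();
    let mut batch = Vec::with_capacity(total.min(IOV_MAX));
    for (i, &id) in page_ids.iter().enumerate() {
        batch.push(id);
        if batch.len() == IOV_MAX || i == total - 1 {
            let commit = (i == total - 1).then_some(db_size);
            batches.push((std::mem::take(&mut batch), commit));
        }
    }
    batches
}

fn main() {
    let ids: Vec<u64> = (1..=2500).collect();
    let batches = plan_batches(&ids, 2500);
    assert_eq!(batches.len(), 3); // 1024 + 1024 + 452
    assert!(batches[..2].iter().all(|(_, commit)| commit.is_none()));
    assert_eq!(batches[2].1, Some(2500));
}

On failure the real code additionally aborts every previously returned completion, which this sketch omits.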
@@ -565,6 +575,7 @@ impl Pager { } else { RefCell::new(AllocatePage1State::Done) }; + let now = io.now(); Ok(Self { db_file, wal, @@ -575,6 +586,7 @@ impl Pager { ))), commit_info: CommitInfo { state: CommitState::Start.into(), + time: now.into(), }, syncing: Rc::new(Cell::new(false)), checkpoint_state: RefCell::new(CheckpointState::Checkpoint), @@ -1028,7 +1040,6 @@ impl Pager { &self, rollback: bool, connection: &Connection, - wal_auto_checkpoint_disabled: bool, ) -> Result> { if connection.is_nested_stmt.get() { // Parent statement will handle the transaction rollback. @@ -1052,7 +1063,8 @@ impl Pager { self.rollback(schema_did_change, connection, is_write)?; return Ok(IOResult::Done(PagerCommitResult::Rollback)); } - let commit_status = return_if_io!(self.commit_dirty_pages(wal_auto_checkpoint_disabled)); + let commit_status = + return_if_io!(self.commit_dirty_pages(connection.wal_auto_checkpoint_disabled.get())); wal.borrow().end_write_tx(); wal.borrow().end_read_tx(); @@ -1252,36 +1264,51 @@ impl Pager { .iter() .copied() .collect::>(); - let mut completions: Vec = Vec::with_capacity(dirty_pages.len()); - for page_id in dirty_pages { + let len = dirty_pages.len().min(IOV_MAX); + let mut completions: Vec = Vec::new(); + let mut pages = Vec::with_capacity(len); + let page_sz = self.page_size.get().unwrap_or_default(); + let commit_frame = None; // cacheflush only so we are not setting a commit frame here + for (idx, page_id) in dirty_pages.iter().enumerate() { let page = { let mut cache = self.page_cache.write(); - let page_key = PageCacheKey::new(page_id); + let page_key = PageCacheKey::new(*page_id); let page = cache.get(&page_key).expect("we somehow added a page to dirty list but we didn't mark it as dirty, causing cache to drop it."); let page_type = page.get().contents.as_ref().unwrap().maybe_page_type(); - trace!( - "commit_dirty_pages(page={}, page_type={:?}", - page_id, - page_type - ); + trace!("cacheflush(page={}, page_type={:?}", page_id, page_type); page }; + pages.push(page); + if pages.len() == IOV_MAX { + let c = wal + .borrow_mut() + .append_frames_vectored( + std::mem::replace( + &mut pages, + Vec::with_capacity(std::cmp::min(IOV_MAX, dirty_pages.len() - idx)), + ), + page_sz, + commit_frame, + ) + .inspect_err(|_| { + for c in completions.iter() { + c.abort(); + } + })?; + completions.push(c); + } + } + if !pages.is_empty() { let c = wal .borrow_mut() - .append_frame( - page.clone(), - self.page_size.get().expect("page size not set"), - 0, - ) + .append_frames_vectored(pages, page_sz, commit_frame) .inspect_err(|_| { for c in completions.iter() { c.abort(); } })?; - // TODO: invalidade previous completions if this one fails completions.push(c); } - // Pages are cleared dirty on callback completion Ok(completions) } @@ -1299,57 +1326,70 @@ impl Pager { "commit_dirty_pages() called on database without WAL".to_string(), )); }; + let mut checkpoint_result = CheckpointResult::default(); let res = loop { let state = self.commit_info.state.get(); trace!(?state); match state { CommitState::Start => { - let db_size = { + let now = self.io.now(); + self.commit_info.time.set(now); + let db_size_after = { self.io .block(|| self.with_header(|header| header.database_size))? 
.get() }; - let dirty_len = self.dirty_pages.borrow().iter().len(); - let mut completions: Vec = Vec::with_capacity(dirty_len); - for (curr_page_idx, page_id) in - self.dirty_pages.borrow().iter().copied().enumerate() - { - let is_last_frame = curr_page_idx == dirty_len - 1; - let db_size = if is_last_frame { db_size } else { 0 }; + let dirty_ids: Vec = self.dirty_pages.borrow().iter().copied().collect(); + if dirty_ids.is_empty() { + return Ok(IOResult::Done(PagerCommitResult::WalWritten)); + } + let page_sz = self.page_size.get().expect("page size not set"); + let mut completions: Vec = Vec::new(); + let mut pages: Vec = Vec::with_capacity(dirty_ids.len().min(IOV_MAX)); + let total = dirty_ids.len(); + + for (i, page_id) in dirty_ids.into_iter().enumerate() { let page = { let mut cache = self.page_cache.write(); let page_key = PageCacheKey::new(page_id); - let page = cache.get(&page_key).unwrap_or_else(|| { - panic!( - "we somehow added a page to dirty list but we didn't mark it as dirty, causing cache to drop it. page={page_id}" - ) - }); - let page_type = page.get().contents.as_ref().unwrap().maybe_page_type(); + let page = cache.get(&page_key).expect( + "dirty list contained a page that cache dropped (page={page_id})", + ); trace!( - "commit_dirty_pages(page={}, page_type={:?}", + "commit_dirty_pages(page={}, page_type={:?})", page_id, - page_type + page.get().contents.as_ref().unwrap().maybe_page_type() ); page }; + pages.push(page); - // TODO: invalidade previous completions on error here - let c = wal - .borrow_mut() - .append_frame( - page.clone(), - self.page_size.get().expect("page size not set"), - db_size, - ) - .inspect_err(|_| { - for c in completions.iter() { - c.abort(); + let end_of_chunk = pages.len() == IOV_MAX || i == total - 1; + if end_of_chunk { + let commit_flag = if i == total - 1 { + // Only the commit frame (final) frame carries the db_size + Some(db_size_after) + } else { + None + }; + let r = wal.borrow_mut().append_frames_vectored( + std::mem::take(&mut pages), + page_sz, + commit_flag, + ); + match r { + Ok(c) => completions.push(c), + Err(e) => { + for c in &completions { + c.abort(); + } + return Err(e); } - })?; - completions.push(c); + } + } } self.dirty_pages.borrow_mut().clear(); // Nothing to append @@ -1357,13 +1397,17 @@ impl Pager { return Ok(IOResult::Done(PagerCommitResult::WalWritten)); } else { self.commit_info.state.set(CommitState::SyncWal); + } + if !completions.iter().all(|c| c.is_completed()) { io_yield_many!(completions); } } CommitState::SyncWal => { self.commit_info.state.set(CommitState::AfterSyncWal); let c = wal.borrow_mut().sync()?; - io_yield_one!(c); + if !c.is_completed() { + io_yield_one!(c); + } } CommitState::AfterSyncWal => { turso_assert!(!wal.borrow().is_syncing(), "wal should have synced"); @@ -1380,7 +1424,9 @@ impl Pager { CommitState::SyncDbFile => { let c = sqlite3_ondisk::begin_sync(self.db_file.clone(), self.syncing.clone())?; self.commit_info.state.set(CommitState::AfterSyncDbFile); - io_yield_one!(c); + if !c.is_completed() { + io_yield_one!(c); + } } CommitState::AfterSyncDbFile => { turso_assert!(!self.syncing.get(), "should have finished syncing"); @@ -1389,7 +1435,15 @@ impl Pager { } } }; - // We should only signal that we finished appenind frames after wal sync to avoid inconsistencies when sync fails + + let now = self.io.now(); + tracing::debug!( + "total time flushing cache: {} ms", + now.to_system_time() + .duration_since(self.commit_info.time.get().to_system_time()) + .unwrap() + .as_millis() + ); 
wal.borrow_mut().finish_append_frames_commit()?; Ok(IOResult::Done(res)) } @@ -1418,22 +1472,25 @@ impl Pager { "wal_insert_frame() called on database without WAL".to_string(), )); }; - let mut wal = wal.borrow_mut(); let (header, raw_page) = parse_wal_frame_header(frame); - wal.write_frame_raw( - self.buffer_pool.clone(), - frame_no, - header.page_number as u64, - header.db_size as u64, - raw_page, - )?; - if let Some(page) = self.cache_get(header.page_number as usize) { - let content = page.get_contents(); - content.as_ptr().copy_from_slice(raw_page); - turso_assert!( - page.get().id == header.page_number as usize, - "page has unexpected id" - ); + + { + let mut wal = wal.borrow_mut(); + wal.write_frame_raw( + self.buffer_pool.clone(), + frame_no, + header.page_number as u64, + header.db_size as u64, + raw_page, + )?; + if let Some(page) = self.cache_get(header.page_number as usize) { + let content = page.get_contents(); + content.as_ptr().copy_from_slice(raw_page); + turso_assert!( + page.get().id == header.page_number as usize, + "page has unexpected id" + ); + } } if header.page_number == 1 { let db_size = self @@ -1508,8 +1565,18 @@ impl Pager { .expect("Failed to clear page cache"); } + /// Checkpoint in Truncate mode and delete the WAL file. This method is _only_ to be called + /// for shutting down the last remaining connection to a database. + /// + /// sqlite3.h + /// Usually, when a database in [WAL mode] is closed or detached from a + /// database handle, SQLite checks if if there are other connections to the + /// same database, and if there are no other database connection (if the + /// connection being closed is the last open connection to the database), + /// then SQLite performs a [checkpoint] before closing the connection and + /// deletes the WAL file. pub fn checkpoint_shutdown(&self, wal_auto_checkpoint_disabled: bool) -> Result<()> { - let mut _attempts = 0; + let mut attempts = 0; { let Some(wal) = self.wal.as_ref() else { return Err(LimboError::InternalError( @@ -1518,16 +1585,25 @@ impl Pager { }; let mut wal = wal.borrow_mut(); // fsync the wal syncronously before beginning checkpoint - // TODO: for now forget about timeouts as they fail regularly in SIM - // need to think of a better way to do this let c = wal.sync()?; self.io.wait_for_completion(c)?; } if !wal_auto_checkpoint_disabled { - self.wal_checkpoint(CheckpointMode::Passive { + while let Err(LimboError::Busy) = self.wal_checkpoint(CheckpointMode::Truncate { upper_bound_inclusive: None, - })?; + }) { + if attempts == 3 { + // don't return error on `close` if we are unable to checkpoint, we can silently fail + tracing::warn!( + "Failed to checkpoint WAL on shutdown after 3 attempts, giving up" + ); + return Ok(()); + } + attempts += 1; + } } + // TODO: delete the WAL file here after truncate checkpoint, but *only* if we are sure that + // no other connections have opened since. 
Ok(()) } @@ -1726,8 +1802,8 @@ impl Pager { default_header.database_size = 1.into(); // if a key is set, then we will reserve space for encryption metadata - if self.encryption_ctx.borrow().is_some() { - default_header.reserved_space = ENCRYPTION_METADATA_SIZE as u8; + if let Some(ref ctx) = *self.encryption_ctx.borrow() { + default_header.reserved_space = ctx.required_reserved_bytes() } if let Some(size) = self.page_size.get() { @@ -2086,6 +2162,7 @@ impl Pager { self.checkpoint_state.replace(CheckpointState::Checkpoint); self.syncing.replace(false); self.commit_info.state.set(CommitState::Start); + self.commit_info.time.set(self.io.now()); self.allocate_page_state.replace(AllocatePageState::Start); self.free_page_state.replace(FreePageState::Start); #[cfg(not(feature = "omit_autovacuum"))] @@ -2111,8 +2188,8 @@ impl Pager { Ok(IOResult::Done(f(header))) } - pub fn set_encryption_context(&self, key: &EncryptionKey) { - let encryption_ctx = EncryptionContext::new(key).unwrap(); + pub fn set_encryption_context(&self, cipher_mode: CipherMode, key: &EncryptionKey) { + let encryption_ctx = EncryptionContext::new(cipher_mode, key).unwrap(); self.encryption_ctx.replace(Some(encryption_ctx.clone())); let Some(wal) = self.wal.as_ref() else { return }; wal.borrow_mut().set_encryption_context(encryption_ctx) diff --git a/core/storage/sqlite3_ondisk.rs b/core/storage/sqlite3_ondisk.rs index 8ec4c861f..53ebf04a4 100644 --- a/core/storage/sqlite3_ondisk.rs +++ b/core/storage/sqlite3_ondisk.rs @@ -568,6 +568,41 @@ impl PageContent { self.write_u16(BTREE_FIRST_FREEBLOCK, value); } + /// Write a freeblock to the page content at the given absolute offset. + /// Parameters: + /// - offset: the absolute offset of the freeblock + /// - size: the size of the freeblock + /// - next_block: the absolute offset of the next freeblock, or None if this is the last freeblock + pub fn write_freeblock(&self, offset: u16, size: u16, next_block: Option) { + self.write_freeblock_next_ptr(offset, next_block.unwrap_or(0)); + self.write_freeblock_size(offset, size); + } + + /// Write the new size of a freeblock. + /// Parameters: + /// - offset: the absolute offset of the freeblock + /// - size: the new size of the freeblock + pub fn write_freeblock_size(&self, offset: u16, size: u16) { + self.write_u16_no_offset(offset as usize + 2, size); + } + + /// Write the absolute offset of the next freeblock. + /// Parameters: + /// - offset: the absolute offset of the current freeblock + /// - next_block: the absolute offset of the next freeblock + pub fn write_freeblock_next_ptr(&self, offset: u16, next_block: u16) { + self.write_u16_no_offset(offset as usize, next_block); + } + + /// Read a freeblock from the page content at the given absolute offset. + /// Returns (absolute offset of next freeblock, size of the current freeblock) + pub fn read_freeblock(&self, offset: u16) -> (u16, u16) { + ( + self.read_u16_no_offset(offset as usize), + self.read_u16_no_offset(offset as usize + 2), + ) + } + /// Write the number of cells on this page. 
pub fn write_cell_count(&self, value: u16) {
         self.write_u16(BTREE_CELL_COUNT, value);
diff --git a/core/storage/wal.rs b/core/storage/wal.rs
index 3b07d80f8..15fd04afe 100644
--- a/core/storage/wal.rs
+++ b/core/storage/wal.rs
@@ -1,7 +1,7 @@
-#![allow(clippy::arc_with_non_send_sync)]
 #![allow(clippy::not_unsafe_ptr_arg_deref)]
 use std::array;
+use std::borrow::Cow;
 use std::cell::{RefCell, UnsafeCell};
 use std::collections::{BTreeMap, HashMap, HashSet};
 use strum::EnumString;
@@ -276,6 +283,13 @@
         db_size: u32,
     ) -> Result<Completion>;
 
+    fn append_frames_vectored(
+        &mut self,
+        pages: Vec<PageRef>,
+        page_sz: PageSize,
+        db_size_on_commit: Option<u32>,
+    ) -> Result<Completion>;
+
     /// Complete append of frames by updating shared wal state. Before this
     /// all changes were stored locally.
     fn finish_append_frames_commit(&mut self) -> Result<()>;
@@ -318,7 +325,8 @@ pub const CKPT_BATCH_PAGES: usize = 512;
 const MIN_AVG_RUN_FOR_FLUSH: f32 = 32.0;
 const MIN_BATCH_LEN_FOR_FLUSH: usize = 512;
 const MAX_INFLIGHT_WRITES: usize = 64;
-const MAX_INFLIGHT_READS: usize = 512;
+pub const MAX_INFLIGHT_READS: usize = 512;
+pub const IOV_MAX: usize = 1024;
 
 type PageId = usize;
 struct InflightRead {
@@ -815,14 +823,14 @@
         // WAL and fetch pages directly from the DB file. We do this
         // by taking read‑lock 0, and capturing the latest state.
         if shared_max == nbackfills {
-            let lock_idx = 0;
-            if !self.get_shared().read_locks[lock_idx].read() {
+            let lock_0_idx = 0;
+            if !self.get_shared().read_locks[lock_0_idx].read() {
                 return Ok((LimboResult::Busy, db_changed));
             }
             // we need to keep self.max_frame set to the appropriate
             // max frame in the wal at the time this transaction starts.
             self.max_frame = shared_max;
-            self.max_frame_read_lock_index.set(lock_idx);
+            self.max_frame_read_lock_index.set(lock_0_idx);
             self.min_frame = nbackfills + 1;
             self.last_checksum = last_checksum;
             return Ok((LimboResult::Ok, db_changed));
@@ -965,7 +973,7 @@
         }
 
         // Snapshot is stale, give up and let caller retry from scratch
-        tracing::debug!("unable to upgrade transaction from read to write: snapshot is stale, give up and let caller retry from scratch");
+        tracing::debug!("unable to upgrade transaction from read to write: snapshot is stale, give up and let caller retry from scratch, self.max_frame={}, shared_max={}", self.max_frame, shared_max);
         shared.write_lock.unlock();
         Ok(LimboResult::Busy)
     }
@@ -1000,8 +1008,18 @@
             "frame_watermark must be >= than current WAL backfill amount: frame_watermark={:?}, nBackfill={}",
             frame_watermark, self.get_shared().nbackfills.load(Ordering::Acquire)
         );
-        // if we are holding read_lock 0, skip and read right from db file.
-        if self.max_frame_read_lock_index.get() == 0 {
+        // if we are holding read_lock 0 and didn't write anything to the WAL, skip and read right from db file.
+        //
+        // Note that max_frame_read_lock_index is set to 0 only when shared_max_frame == nbackfill, in which case
+        // min_frame is set to nbackfill + 1 and max_frame is set to shared_max_frame.
+        //
+        // By default, SQLite tries to restart the log file in this case - but for now let's keep it simple in turso-db.
+        if self.max_frame_read_lock_index.get() == 0 && self.max_frame < self.min_frame {
+            tracing::debug!(
+                "find_frame(page_id={}, frame_watermark={:?}): max_frame is 0 - read from DB file",
+                page_id,
+                frame_watermark,
+            );
             return Ok(None);
         }
         let shared = self.get_shared();
@@ -1009,8 +1027,21 @@
         let range = frame_watermark
             .map(|x| 0..=x)
             .unwrap_or(self.min_frame..=self.max_frame);
+        tracing::debug!(
+            "find_frame(page_id={}, frame_watermark={:?}): min_frame={}, max_frame={}",
+            page_id,
+            frame_watermark,
+            self.min_frame,
+            self.max_frame
+        );
         if let Some(list) = frames.get(&page_id) {
             if let Some(f) = list.iter().rfind(|&&f| range.contains(&f)) {
+                tracing::debug!(
+                    "find_frame(page_id={}, frame_watermark={:?}): found frame={}",
+                    page_id,
+                    frame_watermark,
+                    *f
+                );
                 return Ok(Some(*f));
             }
         }
@@ -1369,6 +1400,112 @@
         Ok(pages)
     }
 
+    /// Use pwritev to append many frames to the log at once
+    fn append_frames_vectored(
+        &mut self,
+        pages: Vec<PageRef>,
+        page_sz: PageSize,
+        db_size_on_commit: Option<u32>,
+    ) -> Result<Completion> {
+        turso_assert!(
+            pages.len() <= IOV_MAX,
+            "we limit number of iovecs to IOV_MAX"
+        );
+        self.ensure_header_if_needed(page_sz)?;
+
+        let (header, shared_page_size, seq) = {
+            let shared = self.get_shared();
+            let hdr_guard = shared.wal_header.lock();
+            let header: WalHeader = *hdr_guard;
+            let shared_page_size = header.page_size;
+            let seq = header.checkpoint_seq;
+            (header, shared_page_size, seq)
+        };
+        turso_assert!(
+            shared_page_size == page_sz.get(),
+            "page size mismatch, tried to change page size after WAL header was already initialized: shared.page_size={shared_page_size}, page_size={}",
+            page_sz.get()
+        );
+
+        // Prepare write buffers and bookkeeping
+        let mut iovecs: Vec<Arc<Buffer>> = Vec::with_capacity(pages.len());
+        let mut page_frame_and_checksum: Vec<(PageRef, u64, (u32, u32))> =
+            Vec::with_capacity(pages.len());
+
+        // Rolling checksum input to each frame build
+        let mut rolling_checksum: (u32, u32) = self.last_checksum;
+
+        let mut next_frame_id = self.max_frame + 1;
+        // Build every frame in order, updating the rolling checksum
+        for (idx, page) in pages.iter().enumerate() {
+            let page_id = page.get().id as u64;
+            let plain = page.get_contents().as_ptr();
+
+            let data_to_write: std::borrow::Cow<[u8]> = {
+                let ectx = self.encryption_ctx.borrow();
+                if let Some(ctx) = ectx.as_ref() {
+                    Cow::Owned(ctx.encrypt_page(plain, page_id as usize)?)
+                } else {
+                    Cow::Borrowed(plain)
+                }
+            };
+
+            let frame_db_size = if idx + 1 == pages.len() {
+                // If it's the final frame we are appending and the caller included a db_size for
+                // the commit frame, set it in the header.
+ db_size_on_commit.unwrap_or(0) + } else { + 0 + }; + let (new_checksum, frame_bytes) = prepare_wal_frame( + &self.buffer_pool, + &header, + rolling_checksum, + shared_page_size, + page_id as u32, + frame_db_size, + &data_to_write, + ); + iovecs.push(frame_bytes); + + // (page, assigned_frame_id, cumulative_checksum_at_this_frame) + page_frame_and_checksum.push((page.clone(), next_frame_id, new_checksum)); + + // Advance for the next frame + rolling_checksum = new_checksum; + next_frame_id += 1; + } + + let first_frame_id = self.max_frame + 1; + let start_off = self.frame_offset(first_frame_id); + + // pre-advance in-memory WAL state + for (page, fid, csum) in &page_frame_and_checksum { + self.complete_append_frame(page.get().id as u64, *fid, *csum); + } + + // single completion for the whole batch + let total_len: i32 = iovecs.iter().map(|b| b.len() as i32).sum(); + let page_frame_for_cb = page_frame_and_checksum.clone(); + let c = Completion::new_write(move |res: Result| { + let Ok(bytes_written) = res else { + return; + }; + turso_assert!( + bytes_written == total_len, + "pwritev wrote {bytes_written} bytes, expected {total_len}" + ); + + for (page, fid, _csum) in &page_frame_for_cb { + page.clear_dirty(); + page.set_wal_tag(*fid, seq); + } + }); + + let c = self.get_shared().file.pwritev(start_off, iovecs, c)?; + Ok(c) + } + #[cfg(debug_assertions)] fn as_any(&self) -> &dyn std::any::Any { self diff --git a/core/translate/analyze.rs b/core/translate/analyze.rs index 4b72b1457..0d8f8de4e 100644 --- a/core/translate/analyze.rs +++ b/core/translate/analyze.rs @@ -1,19 +1,27 @@ +use std::sync::Arc; + use turso_parser::ast; use crate::{ bail_parse_error, - schema::Schema, + schema::{BTreeTable, Schema}, + storage::pager::CreateBTreeFlags, + translate::{ + emitter::Resolver, + schema::{emit_schema_entry, SchemaEntryType, SQLITE_TABLEID}, + }, util::normalize_ident, vdbe::{ builder::{CursorType, ProgramBuilder}, - insn::{Insn, RegisterOrLiteral::*}, + insn::{Insn, RegisterOrLiteral}, }, - Result, + Result, SymbolTable, }; pub fn translate_analyze( target_opt: Option, schema: &Schema, + syms: &SymbolTable, mut program: ProgramBuilder, ) -> Result { let Some(target) = target_opt else { @@ -34,7 +42,15 @@ pub fn translate_analyze( dest_end: None, }); + // After preparing/creating sqlite_stat1, we need to OpenWrite it, and how we acquire + // the necessary BTreeTable for cursor creation and root page for the instruction changes + // depending on which path we take. + let sqlite_stat1_btreetable: Arc; + let sqlite_stat1_source: RegisterOrLiteral<_>; + if let Some(sqlite_stat1) = schema.get_btree_table("sqlite_stat1") { + sqlite_stat1_btreetable = sqlite_stat1.clone(); + sqlite_stat1_source = RegisterOrLiteral::Literal(sqlite_stat1.root_page); // sqlite_stat1 already exists, so we need to remove the row // corresponding to the stats for the table which we're about to // ANALYZE. 
SQLite implements this as a full table scan over
@@ -43,7 +59,7 @@
         let cursor_id = program.alloc_cursor_id(CursorType::BTreeTable(sqlite_stat1.clone()));
         program.emit_insn(Insn::OpenWrite {
             cursor_id,
-            root_page: Literal(sqlite_stat1.root_page),
+            root_page: RegisterOrLiteral::Literal(sqlite_stat1.root_page),
             db: 0,
         });
         let after_loop = program.allocate_label();
@@ -89,7 +105,61 @@
         });
         program.preassign_label_to_next_insn(after_loop);
     } else {
-        bail_parse_error!("ANALYZE without an existing sqlite_stat1 is not supported");
+        // FIXME: Emit ReadCookie 0 3 2
+        // FIXME: Emit If 3 +2 0
+        // FIXME: Emit SetCookie 0 2 4
+        // FIXME: Emit SetCookie 0 5 1
+
+        // See the large comment in schema.rs:translate_create_table about
+        // deviating from SQLite codegen, as the same deviation is being done
+        // here.
+
+        // TODO: this code half-copies translate_create_table, because there's
+        // no way to get the table_root_reg back out, and it's needed for later
+        // codegen to open the table we just created. It's worth a future
+        // refactoring to remove the duplication once the rest of ANALYZE is
+        // implemented.
+        let table_root_reg = program.alloc_register();
+        program.emit_insn(Insn::CreateBtree {
+            db: 0,
+            root: table_root_reg,
+            flags: CreateBTreeFlags::new_table(),
+        });
+        let sql = "CREATE TABLE sqlite_stat1(tbl,idx,stat)";
+        // The root_page of 0 is not the real root page, but we don't rely on it, and there's no
+        // way to initialize it with a correct value.
+        sqlite_stat1_btreetable = Arc::new(BTreeTable::from_sql(sql, 0)?);
+        sqlite_stat1_source = RegisterOrLiteral::Register(table_root_reg);
+
+        let table = schema.get_btree_table(SQLITE_TABLEID).unwrap();
+        let sqlite_schema_cursor_id =
+            program.alloc_cursor_id(CursorType::BTreeTable(table.clone()));
+        program.emit_insn(Insn::OpenWrite {
+            cursor_id: sqlite_schema_cursor_id,
+            root_page: 1usize.into(),
+            db: 0,
+        });
+
+        let resolver = Resolver::new(schema, syms);
+        // Add the table entry to sqlite_schema
+        emit_schema_entry(
+            &mut program,
+            &resolver,
+            sqlite_schema_cursor_id,
+            None,
+            SchemaEntryType::Table,
+            "sqlite_stat1",
+            "sqlite_stat1",
+            table_root_reg,
+            Some(sql.to_string()),
+        )?;
+        //FIXME: Emit SetCookie?
+        let parse_schema_where_clause =
+            "tbl_name = 'sqlite_stat1' AND type != 'trigger'".to_string();
+        program.emit_insn(Insn::ParseSchema {
+            db: sqlite_schema_cursor_id,
+            where_clause: Some(parse_schema_where_clause),
+        });
     };
 
     if target_schema.columns().iter().any(|c| c.primary_key) {
@@ -100,13 +170,11 @@
     }
 
     // Count the number of rows in the target table, and insert it into sqlite_stat1.
- let sqlite_stat1 = schema - .get_btree_table("sqlite_stat1") - .expect("sqlite_stat1 either pre-existed or was just created"); + let sqlite_stat1 = sqlite_stat1_btreetable; let stat_cursor = program.alloc_cursor_id(CursorType::BTreeTable(sqlite_stat1.clone())); program.emit_insn(Insn::OpenWrite { cursor_id: stat_cursor, - root_page: Literal(sqlite_stat1.root_page), + root_page: sqlite_stat1_source, db: 0, }); let target_cursor = program.alloc_cursor_id(CursorType::BTreeTable(target_btree.clone())); diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 53111904a..398b31c3c 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -2206,6 +2206,15 @@ pub fn translate_expr( }); Ok(target_register) } + ast::Expr::Register(src_reg) => { + // For DBSP expression compilation: copy from source register to target + program.emit_insn(Insn::Copy { + src_reg: *src_reg, + dst_reg: target_register, + extra_amount: 0, + }); + Ok(target_register) + } }?; if let Some(span) = constant_span { @@ -2828,7 +2837,8 @@ where | ast::Expr::DoublyQualified(..) | ast::Expr::Name(_) | ast::Expr::Qualified(..) - | ast::Expr::Variable(_) => { + | ast::Expr::Variable(_) + | ast::Expr::Register(_) => { // No nested expressions } } @@ -3004,7 +3014,8 @@ where | ast::Expr::DoublyQualified(..) | ast::Expr::Name(_) | ast::Expr::Qualified(..) - | ast::Expr::Variable(_) => { + | ast::Expr::Variable(_) + | ast::Expr::Register(_) => { // No nested expressions } } diff --git a/core/translate/index.rs b/core/translate/index.rs index ef574f6af..7d332f0d2 100644 --- a/core/translate/index.rs +++ b/core/translate/index.rs @@ -44,6 +44,10 @@ pub fn translate_create_index( // Check if the index is being created on a valid btree table and // the name is globally unique in the schema. if !schema.is_unique_idx_name(&idx_name) { + // If IF NOT EXISTS is specified, silently return without error + if unique_if_not_exists.1 { + return Ok(program); + } crate::bail_parse_error!("Error: index with name '{idx_name}' already exists."); } let Some(tbl) = schema.tables.get(&tbl_name) else { diff --git a/core/translate/mod.rs b/core/translate/mod.rs index 86cd60f52..7d29a8173 100644 --- a/core/translate/mod.rs +++ b/core/translate/mod.rs @@ -148,7 +148,7 @@ pub fn translate_inner( ast::Stmt::AlterTable(alter) => { translate_alter_table(alter, syms, schema, program, connection, input)? } - ast::Stmt::Analyze { name } => translate_analyze(name, schema, program)?, + ast::Stmt::Analyze { name } => translate_analyze(name, schema, syms, program)?, ast::Stmt::Attach { expr, db_name, key } => { attach::translate_attach(&expr, &db_name, &key, schema, syms, program)? } diff --git a/core/translate/optimizer/mod.rs b/core/translate/optimizer/mod.rs index 03da50566..7fb16b81b 100644 --- a/core/translate/optimizer/mod.rs +++ b/core/translate/optimizer/mod.rs @@ -674,6 +674,7 @@ impl Optimizable for ast::Expr { Expr::Subquery(..) => false, Expr::Unary(_, expr) => expr.is_nonnull(tables), Expr::Variable(..) => false, + Expr::Register(..) => false, // Register values can be null } } /// Returns true if the expression is a constant i.e. does not depend on variables or columns etc. 
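Here and in the next hunk the new ast::Expr::Register variant is deliberately opaque to the optimizer: a register is only populated at execution time, so the planner can neither constant-fold it nor prove it non-NULL. A toy sketch of that reasoning (stand-in enum, not the real AST):

// Stand-in types for illustration only.
enum Expr {
    Literal(i64),
    Register(usize),
}

fn is_constant(e: &Expr) -> bool {
    match e {
        Expr::Literal(_) => true,
        // The register's value is only known when the program runs,
        // so the planner must not fold it.
        Expr::Register(_) => false,
    }
}

fn is_nonnull(e: &Expr) -> bool {
    match e {
        Expr::Literal(_) => true,
        // A register may legitimately hold NULL.
        Expr::Register(_) => false,
    }
}

fn main() {
    assert!(is_constant(&Expr::Literal(42)));
    assert!(!is_constant(&Expr::Register(3)));
    assert!(!is_nonnull(&Expr::Register(3)));
}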
@@ -751,6 +752,7 @@ impl Optimizable for ast::Expr { Expr::Subquery(_) => false, Expr::Unary(_, expr) => expr.is_constant(resolver), Expr::Variable(_) => false, + Expr::Register(_) => false, // Register values are not constants } } /// Returns true if the expression is a constant expression that, when evaluated as a condition, is always true or false diff --git a/core/translate/pragma.rs b/core/translate/pragma.rs index 3ef8b84a5..43abf02e7 100644 --- a/core/translate/pragma.rs +++ b/core/translate/pragma.rs @@ -10,7 +10,7 @@ use turso_parser::ast::{PragmaName, QualifiedName}; use super::integrity_check::translate_integrity_check; use crate::pragma::pragma_for; use crate::schema::Schema; -use crate::storage::encryption::EncryptionKey; +use crate::storage::encryption::{CipherMode, EncryptionKey}; use crate::storage::pager::AutoVacuumMode; use crate::storage::pager::Pager; use crate::storage::sqlite3_ondisk::CacheSize; @@ -314,10 +314,16 @@ fn update_pragma( ), PragmaName::EncryptionKey => { let value = parse_string(&value)?; - let key = EncryptionKey::from_string(&value); + let key = EncryptionKey::from_hex_string(&value)?; connection.set_encryption_key(key); Ok((program, TransactionMode::None)) } + PragmaName::EncryptionCipher => { + let value = parse_string(&value)?; + let cipher = CipherMode::try_from(value.as_str())?; + connection.set_encryption_cipher(cipher); + Ok((program, TransactionMode::None)) + } } } @@ -589,6 +595,15 @@ fn query_pragma( program.add_pragma_result_column(pragma.to_string()); Ok((program, TransactionMode::None)) } + PragmaName::EncryptionCipher => { + if let Some(cipher) = connection.get_encryption_cipher_mode() { + let register = program.alloc_register(); + program.emit_string8(cipher.to_string(), register); + program.emit_result_row(register, 1); + program.add_pragma_result_column(pragma.to_string()); + } + Ok((program, TransactionMode::None)) + } } } diff --git a/core/translate/view.rs b/core/translate/view.rs index f2dcf40a8..b339e8961 100644 --- a/core/translate/view.rs +++ b/core/translate/view.rs @@ -96,7 +96,7 @@ pub fn translate_create_materialized_view( // This validation happens before updating sqlite_master to prevent // storing invalid view definitions use crate::incremental::view::IncrementalView; - IncrementalView::can_create_view(select_stmt, schema)?; + IncrementalView::can_create_view(select_stmt)?; // Reconstruct the SQL string let sql = create_materialized_view_to_str(view_name, select_stmt); diff --git a/core/util.rs b/core/util.rs index 200964a87..9b7d9e2c8 100644 --- a/core/util.rs +++ b/core/util.rs @@ -217,14 +217,13 @@ pub fn parse_schema_rows( if should_create_new { // Create a new IncrementalView - if let Ok(incremental_view) = - IncrementalView::from_sql(sql, schema) - { - let referenced_tables = - incremental_view.get_referenced_table_names(); - schema.add_materialized_view(incremental_view); - views_to_process.push((view_name, referenced_tables)); - } + // If this fails, we should propagate the error so the transaction rolls back + let incremental_view = + IncrementalView::from_sql(sql, schema)?; + let referenced_tables = + incremental_view.get_referenced_table_names(); + schema.add_materialized_view(incremental_view); + views_to_process.push((view_name, referenced_tables)); } } Stmt::CreateView { diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index c910d827f..849b37839 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -2170,7 +2170,7 @@ pub fn op_auto_commit( if *auto_commit != conn.auto_commit.get() { if 
*rollback { // TODO(pere): add rollback I/O logic once we implement rollback journal - return_if_io!(pager.end_tx(true, &conn, false)); + return_if_io!(pager.end_tx(true, &conn)); conn.transaction_state.replace(TransactionState::None); conn.auto_commit.replace(true); } else { diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index b99c17cd5..1cdd2fb69 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -315,6 +315,14 @@ impl ProgramState { } } + pub fn set_register(&mut self, idx: usize, value: Register) { + self.registers[idx] = value; + } + + pub fn get_register(&self, idx: usize) -> &Register { + &self.registers[idx] + } + pub fn column_count(&self) -> usize { self.registers.len() } @@ -335,6 +343,10 @@ impl ProgramState { self.parameters.insert(index, value); } + pub fn clear_bindings(&mut self) { + self.parameters.clear(); + } + pub fn get_parameter(&self, index: NonZero) -> Value { self.parameters.get(&index).cloned().unwrap_or(Value::Null) } @@ -430,9 +442,7 @@ impl Program { // Connection is closed for whatever reason, rollback the transaction. let state = self.connection.transaction_state.get(); if let TransactionState::Write { .. } = state { - pager - .io - .block(|| pager.end_tx(true, &self.connection, false))?; + pager.io.block(|| pager.end_tx(true, &self.connection))?; } return Err(LimboError::InternalError("Connection closed".to_string())); } @@ -604,11 +614,7 @@ impl Program { connection: &Connection, rollback: bool, ) -> Result> { - let cacheflush_status = pager.end_tx( - rollback, - connection, - connection.wal_auto_checkpoint_disabled.get(), - )?; + let cacheflush_status = pager.end_tx(rollback, connection)?; match cacheflush_status { IOResult::Done(_) => { if self.change_cnt_on { @@ -853,6 +859,10 @@ pub fn handle_program_error( connection: &Connection, err: &LimboError, ) -> Result<()> { + if connection.is_nested_stmt.get() { + // Errors from nested statements are handled by the parent statement. + return Ok(()); + } match err { // Transaction errors, e.g. trying to start a nested transaction, do not cause a rollback. 
LimboError::TxError(_) => {} @@ -861,7 +871,7 @@ pub fn handle_program_error( _ => { pager .io - .block(|| pager.end_tx(true, connection, false)) + .block(|| pager.end_tx(true, connection)) .inspect_err(|e| { tracing::error!("end_tx failed: {e}"); })?; diff --git a/flake.nix b/flake.nix index e262c09c4..8f59224a9 100644 --- a/flake.nix +++ b/flake.nix @@ -71,6 +71,7 @@ python3 nodejs toolchain + uv ] ++ lib.optionals pkgs.stdenv.isDarwin [ apple-sdk ]; diff --git a/parser/src/ast.rs b/parser/src/ast.rs index ce931027c..144a8be93 100644 --- a/parser/src/ast.rs +++ b/parser/src/ast.rs @@ -3,6 +3,8 @@ pub mod fmt; use strum_macros::{EnumIter, EnumString}; +use crate::ast::fmt::ToTokens; + /// `?` or `$` Prepared statement arg placeholder(s) #[derive(Default)] pub struct ParameterInfo { @@ -345,6 +347,9 @@ pub enum Expr { }, /// binary expression Binary(Box, Operator, Box), + /// Register reference for DBSP expression compilation + /// This is not part of SQL syntax but used internally for incremental computation + Register(usize), /// `CASE` expression Case { /// operand @@ -471,6 +476,76 @@ pub enum Expr { Variable(String), } +impl Expr { + pub fn into_boxed(self) -> Box { + Box::new(self) + } + + pub fn unary(operator: UnaryOperator, expr: Expr) -> Expr { + Expr::Unary(operator, Box::new(expr)) + } + + pub fn binary(lhs: Expr, operator: Operator, rhs: Expr) -> Expr { + Expr::Binary(Box::new(lhs), operator, Box::new(rhs)) + } + + pub fn not_null(expr: Expr) -> Expr { + Expr::NotNull(Box::new(expr)) + } + + pub fn between(lhs: Expr, not: bool, start: Expr, end: Expr) -> Expr { + Expr::Between { + lhs: Box::new(lhs), + not, + start: Box::new(start), + end: Box::new(end), + } + } + + pub fn in_select(lhs: Expr, not: bool, select: Select) -> Expr { + Expr::InSelect { + lhs: Box::new(lhs), + not, + rhs: select, + } + } + + pub fn like( + lhs: Expr, + not: bool, + operator: LikeOperator, + rhs: Expr, + escape: Option, + ) -> Expr { + Expr::Like { + lhs: Box::new(lhs), + not, + op: operator, + rhs: Box::new(rhs), + escape: escape.map(Box::new), + } + } + + pub fn is_null(expr: Expr) -> Expr { + Expr::IsNull(Box::new(expr)) + } + + pub fn collate(expr: Expr, name: Name) -> Expr { + Expr::Collate(Box::new(expr), name) + } + + pub fn cast(expr: Expr, type_name: Option) -> Expr { + Expr::Cast { + expr: Box::new(expr), + type_name, + } + } + + pub fn raise(resolve_type: ResolveType, expr: Option) -> Expr { + Expr::Raise(resolve_type, expr.map(Box::new)) + } +} + /// SQL literal #[derive(Clone, Debug, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] @@ -856,6 +931,41 @@ pub struct QualifiedName { pub alias: Option, // FIXME restrict alias usage (fullname vs xfullname) } +impl QualifiedName { + /// Constructor + pub fn single(name: Name) -> Self { + Self { + db_name: None, + name, + alias: None, + } + } + /// Constructor + pub fn fullname(db_name: Name, name: Name) -> Self { + Self { + db_name: Some(db_name), + name, + alias: None, + } + } + /// Constructor + pub fn xfullname(db_name: Name, name: Name, alias: Name) -> Self { + Self { + db_name: Some(db_name), + name, + alias: Some(alias), + } + } + /// Constructor + pub fn alias(name: Name, alias: Name) -> Self { + Self { + db_name: None, + name, + alias: Some(alias), + } + } +} + /// `ALTER TABLE` body // https://sqlite.org/lang_altertable.html #[derive(Clone, Debug, PartialEq, Eq)] @@ -1027,7 +1137,7 @@ bitflags::bitflags! 
{ } /// Sort orders -#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub enum SortOrder { /// `ASC` @@ -1036,6 +1146,12 @@ pub enum SortOrder { Desc, } +impl core::fmt::Display for SortOrder { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.to_fmt(f) + } +} + /// `NULLS FIRST` or `NULLS LAST` #[derive(Copy, Clone, Debug, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] @@ -1204,6 +1320,10 @@ pub enum PragmaName { AutoVacuum, /// `cache_size` pragma CacheSize, + /// encryption cipher algorithm name for encrypted databases + #[strum(serialize = "cipher")] + #[cfg_attr(feature = "serde", serde(rename = "cipher"))] + EncryptionCipher, /// List databases DatabaseList, /// Encoding - only support utf8 @@ -1214,10 +1334,9 @@ pub enum PragmaName { IntegrityCheck, /// `journal_mode` pragma JournalMode, - /// encryption key for encrypted databases. This is just called `key` because most - /// extensions use this name instead of `encryption_key`. - #[strum(serialize = "key")] - #[cfg_attr(feature = "serde", serde(rename = "key"))] + /// encryption key for encrypted databases, specified as hexadecimal string. + #[strum(serialize = "hexkey")] + #[cfg_attr(feature = "serde", serde(rename = "hexkey"))] EncryptionKey, /// Noop as per SQLite docs LegacyFileFormat, diff --git a/parser/src/ast/fmt.rs b/parser/src/ast/fmt.rs index 224d942cf..9499ac07a 100644 --- a/parser/src/ast/fmt.rs +++ b/parser/src/ast/fmt.rs @@ -716,6 +716,11 @@ impl ToTokens for Expr { op.to_tokens(s, context)?; rhs.to_tokens(s, context) } + Self::Register(reg) => { + // This is for internal use only, not part of SQL syntax + // Use a special notation that won't conflict with SQL + s.append(TK_VARIABLE, Some(&format!("$r{reg}"))) + } Self::Case { base, when_then_pairs, diff --git a/perf/tpc-h/run.sh b/perf/tpc-h/run.sh index 7bea14c23..1d572f388 100755 --- a/perf/tpc-h/run.sh +++ b/perf/tpc-h/run.sh @@ -50,6 +50,8 @@ echo "Starting TPC-H query timing comparison..." echo "The script might ask you to enter the password for sudo, in order to clear system caches." clear_caches +exit_code=0 + for query_file in $(ls "$QUERIES_DIR"/*.sql | sort -V); do if [ -f "$query_file" ]; then query_name=$(basename "$query_file") @@ -85,6 +87,7 @@ for query_file in $(ls "$QUERIES_DIR"/*.sql | sort -V); do if [ -n "$output_diff" ]; then echo "Output difference:" echo "$output_diff" + exit_code=1 else echo "No output difference" fi @@ -96,3 +99,8 @@ done echo "-----------------------------------------------------------" echo "TPC-H query timing comparison completed." 
+ +if [ $exit_code -ne 0 ]; then + echo "Error: Output differences found" + exit $exit_code +fi diff --git a/scripts/merge-pr.py b/scripts/merge-pr.py index 4ff2183a7..3aa48c683 100755 --- a/scripts/merge-pr.py +++ b/scripts/merge-pr.py @@ -9,6 +9,7 @@ import json import os import re +import shlex import subprocess import sys import tempfile @@ -112,8 +113,10 @@ def merge_remote(pr_number: int, commit_message: str, commit_title: str): try: print(f"\nMerging PR #{pr_number} with custom commit message...") + # Use gh pr merge with the commit message file - cmd = f'gh pr merge {pr_number} --merge --subject "{commit_title}" --body-file "{temp_file_path}"' + safe_title = shlex.quote(commit_title) + cmd = f'gh pr merge {pr_number} --merge --subject {safe_title} --body-file "{temp_file_path}"' output, error, returncode = run_command(cmd, capture_output=False) if returncode == 0: diff --git a/simulator-docker-runner/Dockerfile.simulator b/simulator-docker-runner/Dockerfile.simulator index f70a1e7ad..74f579ed6 100644 --- a/simulator-docker-runner/Dockerfile.simulator +++ b/simulator-docker-runner/Dockerfile.simulator @@ -26,6 +26,7 @@ COPY stress ./stress/ COPY tests ./tests/ COPY packages ./packages/ COPY testing/sqlite_test_ext ./testing/sqlite_test_ext +COPY sql_generation ./sql_generation/ RUN cargo chef prepare --bin limbo_sim --recipe-path recipe.json # @@ -43,6 +44,7 @@ COPY --from=planner /app/macros ./macros/ COPY --from=planner /app/parser ./parser/ COPY --from=planner /app/simulator ./simulator/ COPY --from=planner /app/packages ./packages/ +COPY --from=planner /app/sql_generation ./sql_generation/ RUN cargo build --bin limbo_sim --release diff --git a/simulator/Cargo.toml b/simulator/Cargo.toml index 20696fa91..f01896716 100644 --- a/simulator/Cargo.toml +++ b/simulator/Cargo.toml @@ -16,15 +16,14 @@ path = "main.rs" [dependencies] turso_core = { path = "../core", features = ["simulator"]} -rand = "0.8.5" -rand_chacha = "0.3.1" +rand = { workspace = true } +rand_chacha = "0.9.0" log = "0.4.20" env_logger = "0.10.1" regex = "1.11.1" regex-syntax = { version = "0.8.5", default-features = false, features = [ "unicode", ] } -anarchist-readable-name-generator-lib = "=0.1.2" clap = { version = "4.5", features = ["derive"] } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } @@ -32,9 +31,10 @@ notify = "8.0.0" rusqlite.workspace = true dirs = "6.0.0" chrono = { version = "0.4.40", features = ["serde"] } -tracing = "0.1.41" +tracing = { workspace = true } tracing-subscriber = { version = "0.3.19", features = ["env-filter"] } anyhow.workspace = true -turso_sqlite3_parser = { workspace = true, features = ["serde"]} hex = "0.4.3" itertools = "0.14.0" +sql_generation = { workspace = true } +turso_parser = { workspace = true } diff --git a/simulator/generation/mod.rs b/simulator/generation/mod.rs index 6d944b065..79bdf506f 100644 --- a/simulator/generation/mod.rs +++ b/simulator/generation/mod.rs @@ -1,65 +1,10 @@ -use std::{iter::Sum, ops::SubAssign}; +use sql_generation::generation::GenerationContext; -use anarchist_readable_name_generator_lib::readable_name_custom; -use rand::{distributions::uniform::SampleUniform, Rng}; +use crate::runner::env::{SimulatorEnv, SimulatorTables}; -use crate::runner::env::SimulatorTables; - -mod expr; pub mod plan; -mod predicate; pub mod property; pub mod query; -pub mod table; - -type ArbitraryFromFunc<'a, R, T> = Box T + 'a>; -type Choice<'a, R, T> = (usize, Box Option + 'a>); - -/// Arbitrary trait for generating random values 
-/// An implementation of arbitrary is assumed to be a uniform sampling of -/// the possible values of the type, with a bias towards smaller values for -/// practicality. -pub trait Arbitrary { - fn arbitrary(rng: &mut R) -> Self; -} - -/// ArbitrarySized trait for generating random values of a specific size -/// An implementation of arbitrary_sized is assumed to be a uniform sampling of -/// the possible values of the type, with a bias towards smaller values for -/// practicality, but with the additional constraint that the generated value -/// must fit in the given size. This is useful for generating values that are -/// constrained by a specific size, such as integers or strings. -pub trait ArbitrarySized { - fn arbitrary_sized(rng: &mut R, size: usize) -> Self; -} - -/// ArbitraryFrom trait for generating random values from a given value -/// ArbitraryFrom allows for constructing relations, where the generated -/// value is dependent on the given value. These relations could be constraints -/// such as generating an integer within an interval, or a value that fits in a table, -/// or a predicate satisfying a given table row. -pub trait ArbitraryFrom { - fn arbitrary_from(rng: &mut R, t: T) -> Self; -} - -/// ArbitrarySizedFrom trait for generating random values from a given value -/// ArbitrarySizedFrom allows for constructing relations, where the generated -/// value is dependent on the given value and a size constraint. These relations -/// could be constraints such as generating an integer within an interval, -/// or a value that fits in a table, or a predicate satisfying a given table row, -/// but with the additional constraint that the generated value must fit in the given size. -/// This is useful for generating values that are constrained by a specific size, -/// such as integers or strings, while still being dependent on the given value. -pub trait ArbitrarySizedFrom { - fn arbitrary_sized_from(rng: &mut R, t: T, size: usize) -> Self; -} - -/// ArbitraryFromMaybe trait for fallibally generating random values from a given value -pub trait ArbitraryFromMaybe { - fn arbitrary_from_maybe(rng: &mut R, t: T) -> Option - where - Self: Sized; -} /// Shadow trait for types that can be "shadowed" in the simulator environment. /// Shadowing is a process of applying a transformation to the simulator environment @@ -75,108 +20,14 @@ pub(crate) trait Shadow { fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result; } -/// Frequency is a helper function for composing different generators with different frequency -/// of occurrences. -/// The type signature for the `N` parameter is a bit complex, but it -/// roughly corresponds to a type that can be summed, compared, subtracted and sampled, which are -/// the operations we require for the implementation. -// todo: switch to a simpler type signature that can accommodate all integer and float types, which -// should be enough for our purposes. 
-pub(crate) fn frequency< - T, - R: Rng, - N: Sum + PartialOrd + Copy + Default + SampleUniform + SubAssign, ->( - choices: Vec<(N, ArbitraryFromFunc)>, - rng: &mut R, -) -> T { - let total = choices.iter().map(|(weight, _)| *weight).sum::(); - let mut choice = rng.gen_range(N::default()..total); - - for (weight, f) in choices { - if choice < weight { - return f(rng); - } - choice -= weight; +impl GenerationContext for SimulatorEnv { + fn tables(&self) -> &Vec { + &self.tables.tables } - unreachable!() -} - -/// one_of is a helper function for composing different generators with equal probability of occurrence. -pub(crate) fn one_of(choices: Vec>, rng: &mut R) -> T { - let index = rng.gen_range(0..choices.len()); - choices[index](rng) -} - -/// backtrack is a helper function for composing different "failable" generators. -/// The function takes a list of functions that return an Option, along with number of retries -/// to make before giving up. -pub(crate) fn backtrack(mut choices: Vec>, rng: &mut R) -> Option { - loop { - // If there are no more choices left, we give up - let choices_ = choices - .iter() - .enumerate() - .filter(|(_, (retries, _))| *retries > 0) - .collect::>(); - if choices_.is_empty() { - tracing::trace!("backtrack: no more choices left"); - return None; - } - // Run a one_of on the remaining choices - let (choice_index, choice) = pick(&choices_, rng); - let choice_index = *choice_index; - // If the choice returns None, we decrement the number of retries and try again - let result = choice.1(rng); - if result.is_some() { - return result; - } else { - choices[choice_index].0 -= 1; + fn opts(&self) -> sql_generation::generation::Opts { + sql_generation::generation::Opts { + indexes: self.opts.experimental_indexes, } } } - -/// pick is a helper function for uniformly picking a random element from a slice -pub(crate) fn pick<'a, T, R: Rng>(choices: &'a [T], rng: &mut R) -> &'a T { - let index = rng.gen_range(0..choices.len()); - &choices[index] -} - -/// pick_index is typically used for picking an index from a slice to later refer to the element -/// at that index. -pub(crate) fn pick_index(choices: usize, rng: &mut R) -> usize { - rng.gen_range(0..choices) -} - -/// pick_n_unique is a helper function for uniformly picking N unique elements from a range. -/// The elements themselves are usize, typically representing indices. -pub(crate) fn pick_n_unique( - range: std::ops::Range, - n: usize, - rng: &mut R, -) -> Vec { - use rand::seq::SliceRandom; - let mut items: Vec = range.collect(); - items.shuffle(rng); - items.into_iter().take(n).collect() -} - -/// gen_random_text uses `anarchist_readable_name_generator_lib` to generate random -/// readable names for tables, columns, text values etc. 
-pub(crate) fn gen_random_text<T: Rng>(rng: &mut T) -> String {
-    let big_text = rng.gen_ratio(1, 1000);
-    if big_text {
-        // let max_size: u64 = 2 * 1024 * 1024 * 1024;
-        let max_size: u64 = 2 * 1024;
-        let size = rng.gen_range(1024..max_size);
-        let mut name = String::with_capacity(size as usize);
-        for i in 0..size {
-            name.push(((i % 26) as u8 + b'A') as char);
-        }
-        name
-    } else {
-        let name = readable_name_custom("_", rng);
-        name.replace("-", "_")
-    }
-}
diff --git a/simulator/generation/plan.rs b/simulator/generation/plan.rs
index c27becad5..47763657a 100644
--- a/simulator/generation/plan.rs
+++ b/simulator/generation/plan.rs
@@ -8,14 +8,18 @@ use std::{
 
 use serde::{Deserialize, Serialize};
 
+use sql_generation::{
+    generation::{frequency, query::SelectFree, Arbitrary, ArbitraryFrom},
+    model::{
+        query::{update::Update, Create, CreateIndex, Delete, Drop, Insert, Select},
+        table::SimValue,
+    },
+};
 use turso_core::{Connection, Result, StepResult};
 
 use crate::{
-    generation::{query::SelectFree, Shadow},
-    model::{
-        query::{update::Update, Create, CreateIndex, Delete, Drop, Insert, Query, Select},
-        table::SimValue,
-    },
+    generation::Shadow,
+    model::Query,
     runner::{
         env::{SimConnection, SimulationType, SimulatorTables},
         io::SimulatorIO,
@@ -23,8 +27,6 @@ use crate::{
     SimulatorEnv,
 };
 
-use crate::generation::{frequency, Arbitrary, ArbitraryFrom};
-
 use super::property::{remaining, Property};
 
 pub(crate) type ResultSet = Result<Vec<Vec<SimValue>>>;
@@ -661,7 +663,7 @@ impl Interaction {
                 .iter()
                 .any(|file| file.sync_completion.borrow().is_some())
         };
-        let inject_fault = env.rng.gen_bool(current_prob);
+        let inject_fault = env.rng.random_bool(current_prob);
         // TODO: avoid for now injecting faults when syncing
         if inject_fault && !syncing {
             env.io.inject_fault(true);
@@ -811,7 +813,7 @@ fn random_fault<R: Rng>(rng: &mut R, env: &SimulatorEnv) -> Interactions {
     } else {
         vec![Fault::Disconnect, Fault::ReopenDatabase]
     };
-    let fault = faults[rng.gen_range(0..faults.len())].clone();
+    let fault = faults[rng.random_range(0..faults.len())].clone();
     Interactions::Fault(fault)
 }
 
diff --git a/simulator/generation/property.rs b/simulator/generation/property.rs
index 4725aa384..288c4e75d 100644
--- a/simulator/generation/property.rs
+++ b/simulator/generation/property.rs
@@ -1,30 +1,23 @@
 use serde::{Deserialize, Serialize};
-use turso_core::{types, LimboError};
-use turso_sqlite3_parser::ast::{self};
-
-use crate::{
-    generation::Shadow as _,
+use sql_generation::{
+    generation::{frequency, pick, pick_index, ArbitraryFrom},
     model::{
         query::{
             predicate::Predicate,
-            select::{
-                CompoundOperator, CompoundSelect, Distinctness, ResultColumn, SelectBody,
-                SelectInner,
-            },
+            select::{CompoundOperator, CompoundSelect, ResultColumn, SelectBody, SelectInner},
             transaction::{Begin, Commit, Rollback},
             update::Update,
-            Create, Delete, Drop, Insert, Query, Select,
+            Create, Delete, Drop, Insert, Select,
         },
         table::SimValue,
     },
-    runner::env::SimulatorEnv,
 };
+use turso_core::{types, LimboError};
+use turso_parser::ast::{self, Distinctness};
 
-use super::{
-    frequency, pick, pick_index,
-    plan::{Assertion, Interaction, InteractionStats, ResultSet},
-    ArbitraryFrom,
-};
+use crate::{generation::Shadow as _, model::Query, runner::env::SimulatorEnv};
+
+use super::plan::{Assertion, Interaction, InteractionStats, ResultSet};
 
 /// Properties are representations of executable specifications
 /// about the database behavior.
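// --- Illustrative aside (not part of the patch) ---
// The helpers now imported from `sql_generation::generation` (`frequency`,
// `pick`, `pick_index`) compose weighted generators roughly as sketched below.
// This assumes `frequency` accepts integer-weighted boxed closures, as in the
// removed `simulator/generation/mod.rs` code above; the generator bodies and
// the function name `example_weighted` are made up for illustration.
fn example_weighted<R: rand::Rng>(rng: &mut R) -> i64 {
    use sql_generation::generation::{frequency, pick};
    let choices: Vec<(u32, Box<dyn Fn(&mut R) -> i64>)> = vec![
        // ~90% of draws: a small integer
        (90, Box::new(|rng| rng.random_range(0..10_i64))),
        // ~10% of draws: one of a few sentinel values
        (10, Box::new(|rng| *pick(&[42_i64, 1000, -7], rng))),
    ];
    frequency(choices, rng)
}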
@@ -1073,7 +1066,7 @@ fn property_insert_values_select(
     // Get a random table
     let table = pick(&env.tables, rng);
     // Generate rows to insert
-    let rows = (0..rng.gen_range(1..=5))
+    let rows = (0..rng.random_range(1..=5))
         .map(|_| Vec::<SimValue>::arbitrary_from(rng, table))
         .collect::<Vec<_>>();
 
@@ -1088,10 +1081,10 @@ fn property_insert_values_select(
     };
 
     // Choose if we want queries to be executed in an interactive transaction
-    let interactive = if rng.gen_bool(0.5) {
+    let interactive = if rng.random_bool(0.5) {
         Some(InteractiveQueryInfo {
-            start_with_immediate: rng.gen_bool(0.5),
-            end_with_commit: rng.gen_bool(0.5),
+            start_with_immediate: rng.random_bool(0.5),
+            end_with_commit: rng.random_bool(0.5),
         })
     } else {
         None
@@ -1107,7 +1100,7 @@ fn property_insert_values_select(
             immediate: interactive.start_with_immediate,
         }));
     }
-    for _ in 0..rng.gen_range(0..3) {
+    for _ in 0..rng.random_range(0..3) {
         let query = Query::arbitrary_from(rng, (env, remaining));
         match &query {
             Query::Delete(Delete {
@@ -1198,7 +1191,7 @@ fn property_select_limit<R: Rng>(rng: &mut R, env: &SimulatorEnv) -> Property {
         table.name.clone(),
         vec![ResultColumn::Star],
         Predicate::arbitrary_from(rng, table),
-        Some(rng.gen_range(1..=5)),
+        Some(rng.random_range(1..=5)),
         Distinctness::All,
     );
     Property::SelectLimit { select }
@@ -1221,7 +1214,7 @@ fn property_double_create_failure(
     // The interactions in the middle have the following constraints;
     // - [x] There will be no errors in the middle interactions.(best effort)
     // - [ ] Table `t` will not be renamed or dropped.(todo: add this constraint once ALTER or DROP is implemented)
-    for _ in 0..rng.gen_range(0..3) {
+    for _ in 0..rng.random_range(0..3) {
         let query = Query::arbitrary_from(rng, (env, remaining));
         if let Query::Create(Create { table: t }) = &query {
             // There will be no errors in the middle interactions.
@@ -1254,7 +1247,7 @@ fn property_delete_select(
     // - [x] There will be no errors in the middle interactions. (this constraint is impossible to check, so this is just best effort)
     // - [x] A row that holds for the predicate will not be inserted.
     // - [ ] The table `t` will not be renamed, dropped, or altered. (todo: add this constraint once ALTER or DROP is implemented)
-    for _ in 0..rng.gen_range(0..3) {
+    for _ in 0..rng.random_range(0..3) {
         let query = Query::arbitrary_from(rng, (env, remaining));
         match &query {
             Query::Insert(Insert::Values { table: t, values }) => {
@@ -1309,7 +1302,7 @@ fn property_drop_select(
     let mut queries = Vec::new();
     // - [x] There will be no errors in the middle interactions. (this constraint is impossible to check, so this is just best effort)
     // - [-] The table `t` will not be created, no table will be renamed to `t`.
(todo: update this constraint once ALTER is implemented)
-    for _ in 0..rng.gen_range(0..3) {
+    for _ in 0..rng.random_range(0..3) {
         let query = Query::arbitrary_from(rng, (env, remaining));
         if let Query::Create(Create { table: t }) = &query {
             // - The table `t` will not be created
diff --git a/simulator/generation/query.rs b/simulator/generation/query.rs
index 1abd41b1b..bb1344c2a 100644
--- a/simulator/generation/query.rs
+++ b/simulator/generation/query.rs
@@ -1,330 +1,11 @@
-use crate::generation::{Arbitrary, ArbitraryFrom, ArbitrarySizedFrom, Shadow};
-use crate::model::query::predicate::Predicate;
-use crate::model::query::select::{
-    CompoundOperator, CompoundSelect, Distinctness, FromClause, OrderBy, ResultColumn, SelectBody,
-    SelectInner,
-};
-use crate::model::query::update::Update;
-use crate::model::query::{Create, Delete, Drop, Insert, Query, Select};
-use crate::model::table::{JoinTable, JoinType, JoinedTable, SimValue, Table, TableContext};
-use crate::SimulatorEnv;
-use itertools::Itertools;
+use crate::{model::Query, SimulatorEnv};
 use rand::Rng;
-use turso_sqlite3_parser::ast::{Expr, SortOrder};
+use sql_generation::{
+    generation::{frequency, Arbitrary, ArbitraryFrom},
+    model::query::{update::Update, Create, Delete, Insert, Select},
+};
 
 use super::property::Remaining;
-use super::{backtrack, frequency, pick};
-
-impl Arbitrary for Create {
-    fn arbitrary<R: Rng>(rng: &mut R) -> Self {
-        Create {
-            table: Table::arbitrary(rng),
-        }
-    }
-}
-
-impl ArbitraryFrom<&Vec<Table>> for FromClause {
-    fn arbitrary_from<R: Rng>(rng: &mut R, tables: &Vec<Table>
) -> Self {
-        let num_joins = match rng.gen_range(0..=100) {
-            0..=90 => 0,
-            91..=97 => 1,
-            98..=100 => 2,
-            _ => unreachable!(),
-        };
-
-        let mut tables = tables.clone();
-        let mut table = pick(&tables, rng).clone();
-
-        tables.retain(|t| t.name != table.name);
-
-        let name = table.name.clone();
-
-        let mut table_context = JoinTable {
-            tables: Vec::new(),
-            rows: Vec::new(),
-        };
-
-        let joins: Vec<_> = (0..num_joins)
-            .filter_map(|_| {
-                if tables.is_empty() {
-                    return None;
-                }
-                let join_table = pick(&tables, rng).clone();
-                let joined_table_name = join_table.name.clone();
-
-                tables.retain(|t| t.name != join_table.name);
-                table_context.rows = table_context
-                    .rows
-                    .iter()
-                    .cartesian_product(join_table.rows.iter())
-                    .map(|(t_row, j_row)| {
-                        let mut row = t_row.clone();
-                        row.extend(j_row.clone());
-                        row
-                    })
-                    .collect();
-                // TODO: inefficient. use a Deque to push_front?
-                table_context.tables.insert(0, join_table);
-                for row in &mut table.rows {
-                    assert_eq!(
-                        row.len(),
-                        table.columns.len(),
-                        "Row length does not match column length after join"
-                    );
-                }
-
-                let predicate = Predicate::arbitrary_from(rng, &table);
-                Some(JoinedTable {
-                    table: joined_table_name,
-                    join_type: JoinType::Inner,
-                    on: predicate,
-                })
-            })
-            .collect();
-        FromClause { table: name, joins }
-    }
-}
-
-impl ArbitraryFrom<&SimulatorEnv> for SelectInner {
-    fn arbitrary_from<R: Rng>(rng: &mut R, env: &SimulatorEnv) -> Self {
-        let from = FromClause::arbitrary_from(rng, &env.tables);
-        let mut tables = env.tables.clone();
-        // todo: this is a temporary hack because env is not separated from the tables
-        let join_table = from
-            .shadow(&mut tables)
-            .expect("Failed to shadow FromClause");
-        let cuml_col_count = join_table.columns().count();
-
-        let order_by = 'order_by: {
-            if rng.gen_bool(0.3) {
-                let order_by_table_candidates = from
-                    .joins
-                    .iter()
-                    .map(|j| j.table.clone())
-                    .chain(std::iter::once(from.table.clone()))
-                    .collect::<Vec<_>>();
-                let order_by_col_count =
-                    (rng.gen::<f64>() * rng.gen::<f64>() * (cuml_col_count as f64)) as usize; // skew towards 0
-                if order_by_col_count == 0 {
-                    break 'order_by None;
-                }
-                let mut col_names = std::collections::HashSet::new();
-                let mut order_by_cols = Vec::new();
-                while order_by_cols.len() < order_by_col_count {
-                    let table = pick(&order_by_table_candidates, rng);
-                    let table = tables.iter().find(|t| t.name == *table).unwrap();
-                    let col = pick(&table.columns, rng);
-                    let col_name = format!("{}.{}", table.name, col.name);
-                    if col_names.insert(col_name.clone()) {
-                        order_by_cols.push((
-                            col_name,
-                            if rng.gen_bool(0.5) {
-                                SortOrder::Asc
-                            } else {
-                                SortOrder::Desc
-                            },
-                        ));
-                    }
-                }
-                Some(OrderBy {
-                    columns: order_by_cols,
-                })
-            } else {
-                None
-            }
-        };
-
-        SelectInner {
-            distinctness: if env.opts.experimental_indexes {
-                Distinctness::arbitrary(rng)
-            } else {
-                Distinctness::All
-            },
-            columns: vec![ResultColumn::Star],
-            from: Some(from),
-            where_clause: Predicate::arbitrary_from(rng, &join_table),
-            order_by,
-        }
-    }
-}
-
-impl ArbitrarySizedFrom<&SimulatorEnv> for SelectInner {
-    fn arbitrary_sized_from<R: Rng>(
-        rng: &mut R,
-        env: &SimulatorEnv,
-        num_result_columns: usize,
-    ) -> Self {
-        let mut select_inner = SelectInner::arbitrary_from(rng, env);
-        let select_from = &select_inner.from.as_ref().unwrap();
-        let table_names = select_from
-            .joins
-            .iter()
-            .map(|j| j.table.clone())
-            .chain(std::iter::once(select_from.table.clone()))
-            .collect::<Vec<_>>();
-
-        let flat_columns_names = table_names
-            .iter()
-            .flat_map(|t| {
-                env.tables
-                    .iter()
-                    .find(|table|
table.name == *t)
-                    .unwrap()
-                    .columns
-                    .iter()
-                    .map(|c| format!("{}.{}", t.clone(), c.name))
-            })
-            .collect::<Vec<_>>();
-        let selected_columns = pick_unique(&flat_columns_names, num_result_columns, rng);
-        let mut columns = Vec::new();
-        for column_name in selected_columns {
-            columns.push(ResultColumn::Column(column_name.clone()));
-        }
-        select_inner.columns = columns;
-        select_inner
-    }
-}
-
-impl Arbitrary for Distinctness {
-    fn arbitrary<R: Rng>(rng: &mut R) -> Self {
-        match rng.gen_range(0..=5) {
-            0..4 => Distinctness::All,
-            _ => Distinctness::Distinct,
-        }
-    }
-}
-impl Arbitrary for CompoundOperator {
-    fn arbitrary<R: Rng>(rng: &mut R) -> Self {
-        match rng.gen_range(0..=1) {
-            0 => CompoundOperator::Union,
-            1 => CompoundOperator::UnionAll,
-            _ => unreachable!(),
-        }
-    }
-}
-
-/// SelectFree is a wrapper around Select that allows for arbitrary generation
-/// of selects without requiring a specific environment, which is useful for generating
-/// arbitrary expressions without referring to the tables.
-pub(crate) struct SelectFree(pub(crate) Select);
-
-impl ArbitraryFrom<&SimulatorEnv> for SelectFree {
-    fn arbitrary_from<R: Rng>(rng: &mut R, env: &SimulatorEnv) -> Self {
-        let expr = Predicate(Expr::arbitrary_sized_from(rng, env, 8));
-        let select = Select::expr(expr);
-        Self(select)
-    }
-}
-
-impl ArbitraryFrom<&SimulatorEnv> for Select {
-    fn arbitrary_from<R: Rng>(rng: &mut R, env: &SimulatorEnv) -> Self {
-        // Generate a number of selects based on the query size
-        // If experimental indexes are enabled, we can have selects with compounds
-        // Otherwise, we just have a single select with no compounds
-        let num_compound_selects = if env.opts.experimental_indexes {
-            match rng.gen_range(0..=100) {
-                0..=95 => 0,
-                96..=99 => 1,
-                100 => 2,
-                _ => unreachable!(),
-            }
-        } else {
-            0
-        };
-
-        let min_column_count_across_tables =
-            env.tables.iter().map(|t| t.columns.len()).min().unwrap();
-
-        let num_result_columns = rng.gen_range(1..=min_column_count_across_tables);
-
-        let mut first = SelectInner::arbitrary_sized_from(rng, env, num_result_columns);
-
-        let mut rest: Vec<SelectInner> = (0..num_compound_selects)
-            .map(|_| SelectInner::arbitrary_sized_from(rng, env, num_result_columns))
-            .collect();
-
-        if !rest.is_empty() {
-            // ORDER BY is not supported in compound selects yet
-            first.order_by = None;
-            for s in &mut rest {
-                s.order_by = None;
-            }
-        }
-
-        Self {
-            body: SelectBody {
-                select: Box::new(first),
-                compounds: rest
-                    .into_iter()
-                    .map(|s| CompoundSelect {
-                        operator: CompoundOperator::arbitrary(rng),
-                        select: Box::new(s),
-                    })
-                    .collect(),
-            },
-            limit: None,
-        }
-    }
-}
-
-impl ArbitraryFrom<&SimulatorEnv> for Insert {
-    fn arbitrary_from<R: Rng>(rng: &mut R, env: &SimulatorEnv) -> Self {
-        let gen_values = |rng: &mut R| {
-            let table = pick(&env.tables, rng);
-            let num_rows = rng.gen_range(1..10);
-            let values: Vec<Vec<SimValue>> = (0..num_rows)
-                .map(|_| {
-                    table
-                        .columns
-                        .iter()
-                        .map(|c| SimValue::arbitrary_from(rng, &c.column_type))
-                        .collect()
-                })
-                .collect();
-            Some(Insert::Values {
-                table: table.name.clone(),
-                values,
-            })
-        };
-
-        let _gen_select = |rng: &mut R| {
-            // Find a non-empty table
-            let select_table = env.tables.iter().find(|t| !t.rows.is_empty())?;
-            let row = pick(&select_table.rows, rng);
-            let predicate = Predicate::arbitrary_from(rng, (select_table, row));
-            // Pick another table to insert into
-            let select = Select::simple(select_table.name.clone(), predicate);
-            let table = pick(&env.tables, rng);
-            Some(Insert::Select {
-                table: table.name.clone(),
-                select: Box::new(select),
-            })
-        };
-
-        // TODO: Add back gen_select when https://github.com/tursodatabase/turso/issues/2129 is fixed.
-        // Backtrack here cannot return None
-        backtrack(vec![(1, Box::new(gen_values))], rng).unwrap()
-    }
-}
-
-impl ArbitraryFrom<&SimulatorEnv> for Delete {
-    fn arbitrary_from<R: Rng>(rng: &mut R, env: &SimulatorEnv) -> Self {
-        let table = pick(&env.tables, rng);
-        Self {
-            table: table.name.clone(),
-            predicate: Predicate::arbitrary_from(rng, table),
-        }
-    }
-}
-
-impl ArbitraryFrom<&SimulatorEnv> for Drop {
-    fn arbitrary_from<R: Rng>(rng: &mut R, env: &SimulatorEnv) -> Self {
-        let table = pick(&env.tables, rng);
-        Self {
-            table: table.name.clone(),
-        }
-    }
-}
 
 impl ArbitraryFrom<(&SimulatorEnv, &Remaining)> for Query {
     fn arbitrary_from<R: Rng>(rng: &mut R, (env, remaining): (&SimulatorEnv, &Remaining)) -> Self {
@@ -355,43 +36,3 @@ impl ArbitraryFrom<(&SimulatorEnv, &Remaining)> for Query {
         )
     }
 }
-
-fn pick_unique<T: ToOwned + ?Sized>(
-    items: &[T],
-    count: usize,
-    rng: &mut impl rand::Rng,
-) -> Vec<T::Owned>
-where
-    <T as ToOwned>::Owned: PartialEq,
-{
-    let mut picked: Vec<T::Owned> = Vec::new();
-    while picked.len() < count {
-        let item = pick(items, rng);
-        if !picked.contains(&item.to_owned()) {
-            picked.push(item.to_owned());
-        }
-    }
-    picked
-}
-
-impl ArbitraryFrom<&SimulatorEnv> for Update {
-    fn arbitrary_from<R: Rng>(rng: &mut R, env: &SimulatorEnv) -> Self {
-        let table = pick(&env.tables, rng);
-        let num_cols = rng.gen_range(1..=table.columns.len());
-        let columns = pick_unique(&table.columns, num_cols, rng);
-        let set_values: Vec<(String, SimValue)> = columns
-            .iter()
-            .map(|column| {
-                (
-                    column.name.clone(),
-                    SimValue::arbitrary_from(rng, &column.column_type),
-                )
-            })
-            .collect();
-        Update {
-            table: table.name.clone(),
-            set_values,
-            predicate: Predicate::arbitrary_from(rng, table),
-        }
-    }
-}
diff --git a/simulator/main.rs b/simulator/main.rs
index d2a31f099..0aaaf8be7 100644
--- a/simulator/main.rs
+++ b/simulator/main.rs
@@ -2,7 +2,6 @@ use anyhow::anyhow;
 use clap::Parser;
 use generation::plan::{Interaction, InteractionPlan, InteractionPlanState};
-use generation::ArbitraryFrom;
 use notify::event::{DataChange, ModifyKind};
 use notify::{EventKind, RecursiveMode, Watcher};
 use rand::prelude::*;
@@ -11,6 +10,7 @@ use runner::cli::{SimulatorCLI, SimulatorCommand};
 use runner::env::SimulatorEnv;
 use runner::execution::{execute_plans, Execution, ExecutionHistory, ExecutionResult};
 use runner::{differential, watch};
+use sql_generation::generation::ArbitraryFrom;
 use std::any::Any;
 use std::backtrace::Backtrace;
 use std::fs::OpenOptions;
@@ -507,7 +507,7 @@ fn setup_simulation(
         (seed, env, plans)
     } else {
         let seed = cli_opts.seed.unwrap_or_else(|| {
-            let mut rng = rand::thread_rng();
+            let mut rng = rand::rng();
             rng.next_u64()
         });
         tracing::info!("seed={}", seed);
diff --git a/simulator/model/mod.rs b/simulator/model/mod.rs
index e68355ee4..ce249baf5 100644
--- a/simulator/model/mod.rs
+++ b/simulator/model/mod.rs
@@ -1,4 +1,417 @@
-pub mod query;
-pub mod table;
+use std::{collections::HashSet, fmt::Display};
 
-pub(crate) const FAULT_ERROR_MSG: &str = "Injected fault";
+use anyhow::Context;
+use itertools::Itertools;
+use serde::{Deserialize, Serialize};
+use sql_generation::model::{
+    query::{
+        select::{CompoundOperator, FromClause, ResultColumn, SelectInner},
+        transaction::{Begin, Commit, Rollback},
+        update::Update,
+        Create, CreateIndex, Delete, Drop, EmptyContext, Insert, Select,
+    },
+    table::{JoinTable, JoinType, SimValue, Table, TableContext},
+};
+use turso_parser::ast::{fmt::ToTokens, Distinctness};
+
+use
crate::{generation::Shadow, runner::env::SimulatorTables};
+
+// This type represents the potential queries on the database.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub enum Query {
+    Create(Create),
+    Select(Select),
+    Insert(Insert),
+    Delete(Delete),
+    Update(Update),
+    Drop(Drop),
+    CreateIndex(CreateIndex),
+    Begin(Begin),
+    Commit(Commit),
+    Rollback(Rollback),
+}
+
+impl Query {
+    pub fn dependencies(&self) -> HashSet<String> {
+        match self {
+            Query::Select(select) => select.dependencies(),
+            Query::Create(_) => HashSet::new(),
+            Query::Insert(Insert::Select { table, .. })
+            | Query::Insert(Insert::Values { table, .. })
+            | Query::Delete(Delete { table, .. })
+            | Query::Update(Update { table, .. })
+            | Query::Drop(Drop { table, .. }) => HashSet::from_iter([table.clone()]),
+            Query::CreateIndex(CreateIndex { table_name, .. }) => {
+                HashSet::from_iter([table_name.clone()])
+            }
+            Query::Begin(_) | Query::Commit(_) | Query::Rollback(_) => HashSet::new(),
+        }
+    }
+    pub fn uses(&self) -> Vec<String> {
+        match self {
+            Query::Create(Create { table }) => vec![table.name.clone()],
+            Query::Select(select) => select.dependencies().into_iter().collect(),
+            Query::Insert(Insert::Select { table, .. })
+            | Query::Insert(Insert::Values { table, .. })
+            | Query::Delete(Delete { table, .. })
+            | Query::Update(Update { table, .. })
+            | Query::Drop(Drop { table, .. }) => vec![table.clone()],
+            Query::CreateIndex(CreateIndex { table_name, .. }) => vec![table_name.clone()],
+            Query::Begin(..) | Query::Commit(..) | Query::Rollback(..) => vec![],
+        }
+    }
+}
+
+impl Display for Query {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Self::Create(create) => write!(f, "{create}"),
+            Self::Select(select) => write!(f, "{select}"),
+            Self::Insert(insert) => write!(f, "{insert}"),
+            Self::Delete(delete) => write!(f, "{delete}"),
+            Self::Update(update) => write!(f, "{update}"),
+            Self::Drop(drop) => write!(f, "{drop}"),
+            Self::CreateIndex(create_index) => write!(f, "{create_index}"),
+            Self::Begin(begin) => write!(f, "{begin}"),
+            Self::Commit(commit) => write!(f, "{commit}"),
+            Self::Rollback(rollback) => write!(f, "{rollback}"),
+        }
+    }
+}
+
+impl Shadow for Query {
+    type Result = anyhow::Result<Vec<Vec<SimValue>>>;
+
+    fn shadow(&self, env: &mut SimulatorTables) -> Self::Result {
+        match self {
+            Query::Create(create) => create.shadow(env),
+            Query::Insert(insert) => insert.shadow(env),
+            Query::Delete(delete) => delete.shadow(env),
+            Query::Select(select) => select.shadow(env),
+            Query::Update(update) => update.shadow(env),
+            Query::Drop(drop) => drop.shadow(env),
+            Query::CreateIndex(create_index) => Ok(create_index.shadow(env)),
+            Query::Begin(begin) => Ok(begin.shadow(env)),
+            Query::Commit(commit) => Ok(commit.shadow(env)),
+            Query::Rollback(rollback) => Ok(rollback.shadow(env)),
+        }
+    }
+}
+
+impl Shadow for Create {
+    type Result = anyhow::Result<Vec<Vec<SimValue>>>;
+
+    fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result {
+        if !tables.iter().any(|t| t.name == self.table.name) {
+            tables.push(self.table.clone());
+            Ok(vec![])
+        } else {
+            Err(anyhow::anyhow!(
+                "Table {} already exists.
CREATE TABLE statement ignored.",
+                self.table.name
+            ))
+        }
+    }
+}
+
+impl Shadow for CreateIndex {
+    type Result = Vec<Vec<SimValue>>;
+    fn shadow(&self, env: &mut SimulatorTables) -> Vec<Vec<SimValue>> {
+        env.tables
+            .iter_mut()
+            .find(|t| t.name == self.table_name)
+            .unwrap()
+            .indexes
+            .push(self.index_name.clone());
+        vec![]
+    }
+}
+
+impl Shadow for Delete {
+    type Result = anyhow::Result<Vec<Vec<SimValue>>>;
+
+    fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result {
+        let table = tables.tables.iter_mut().find(|t| t.name == self.table);
+
+        if let Some(table) = table {
+            // If the table exists, we can delete from it
+            let t2 = table.clone();
+            table.rows.retain_mut(|r| !self.predicate.test(r, &t2));
+        } else {
+            // If the table does not exist, we return an error
+            return Err(anyhow::anyhow!(
+                "Table {} does not exist. DELETE statement ignored.",
+                self.table
+            ));
+        }
+
+        Ok(vec![])
+    }
+}
+
+impl Shadow for Drop {
+    type Result = anyhow::Result<Vec<Vec<SimValue>>>;
+
+    fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result {
+        if !tables.iter().any(|t| t.name == self.table) {
+            // If the table does not exist, we return an error
+            return Err(anyhow::anyhow!(
+                "Table {} does not exist. DROP statement ignored.",
+                self.table
+            ));
+        }
+
+        tables.tables.retain(|t| t.name != self.table);
+
+        Ok(vec![])
+    }
+}
+
+impl Shadow for Insert {
+    type Result = anyhow::Result<Vec<Vec<SimValue>>>;
+
+    fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result {
+        match self {
+            Insert::Values { table, values } => {
+                if let Some(t) = tables.tables.iter_mut().find(|t| &t.name == table) {
+                    t.rows.extend(values.clone());
+                } else {
+                    return Err(anyhow::anyhow!(
+                        "Table {} does not exist. INSERT statement ignored.",
+                        table
+                    ));
+                }
+            }
+            Insert::Select { table, select } => {
+                let rows = select.shadow(tables)?;
+                if let Some(t) = tables.tables.iter_mut().find(|t| &t.name == table) {
+                    t.rows.extend(rows);
+                } else {
+                    return Err(anyhow::anyhow!(
+                        "Table {} does not exist.
INSERT statement ignored.",
+                        table
+                    ));
+                }
+            }
+        }
+
+        Ok(vec![])
+    }
+}
+
+impl Shadow for FromClause {
+    type Result = anyhow::Result<JoinTable>;
+    fn shadow(&self, env: &mut SimulatorTables) -> Self::Result {
+        let tables = &mut env.tables;
+
+        let first_table = tables
+            .iter()
+            .find(|t| t.name == self.table)
+            .context("Table not found")?;
+
+        let mut join_table = JoinTable {
+            tables: vec![first_table.clone()],
+            rows: Vec::new(),
+        };
+
+        for join in &self.joins {
+            let joined_table = tables
+                .iter()
+                .find(|t| t.name == join.table)
+                .context("Joined table not found")?;
+
+            join_table.tables.push(joined_table.clone());
+
+            match join.join_type {
+                JoinType::Inner => {
+                    // Implement inner join logic
+                    let join_rows = joined_table
+                        .rows
+                        .iter()
+                        .filter(|row| join.on.test(row, joined_table))
+                        .cloned()
+                        .collect::<Vec<_>>();
+                    // take a cartesian product of the rows
+                    let all_row_pairs = join_table
+                        .rows
+                        .clone()
+                        .into_iter()
+                        .cartesian_product(join_rows.iter());
+
+                    for (row1, row2) in all_row_pairs {
+                        let row = row1.iter().chain(row2.iter()).cloned().collect::<Vec<_>>();
+
+                        let is_in = join.on.test(&row, &join_table);
+
+                        if is_in {
+                            join_table.rows.push(row);
+                        }
+                    }
+                }
+                _ => todo!(),
+            }
+        }
+        Ok(join_table)
+    }
+}
+
+impl Shadow for SelectInner {
+    type Result = anyhow::Result<JoinTable>;
+
+    fn shadow(&self, env: &mut SimulatorTables) -> Self::Result {
+        if let Some(from) = &self.from {
+            let mut join_table = from.shadow(env)?;
+            let col_count = join_table.columns().count();
+            for row in &mut join_table.rows {
+                assert_eq!(
+                    row.len(),
+                    col_count,
+                    "Row length does not match column length after join"
+                );
+            }
+            let join_clone = join_table.clone();
+
+            join_table
+                .rows
+                .retain(|row| self.where_clause.test(row, &join_clone));
+
+            if self.distinctness == Distinctness::Distinct {
+                join_table.rows.sort_unstable();
+                join_table.rows.dedup();
+            }
+
+            Ok(join_table)
+        } else {
+            assert!(self
+                .columns
+                .iter()
+                .all(|col| matches!(col, ResultColumn::Expr(_))));
+
+            // If `WHERE` is false, just return an empty table
+            if !self.where_clause.test(&[], &Table::anonymous(vec![])) {
+                return Ok(JoinTable {
+                    tables: Vec::new(),
+                    rows: Vec::new(),
+                });
+            }
+
+            // Compute the results of the column expressions and make a row
+            let mut row = Vec::new();
+            for col in &self.columns {
+                match col {
+                    ResultColumn::Expr(expr) => {
+                        let value = expr.eval(&[], &Table::anonymous(vec![]));
+                        if let Some(value) = value {
+                            row.push(value);
+                        } else {
+                            return Err(anyhow::anyhow!(
+                                "Failed to evaluate expression in free select ({})",
+                                expr.0.format_with_context(&EmptyContext {}).unwrap()
+                            ));
+                        }
+                    }
+                    _ => unreachable!("Only expressions are allowed in free selects"),
+                }
+            }
+
+            Ok(JoinTable {
+                tables: Vec::new(),
+                rows: vec![row],
+            })
+        }
+    }
+}
+
+impl Shadow for Select {
+    type Result = anyhow::Result<Vec<Vec<SimValue>>>;
+
+    fn shadow(&self, env: &mut SimulatorTables) -> Self::Result {
+        let first_result = self.body.select.shadow(env)?;
+
+        let mut rows = first_result.rows;
+
+        for compound in self.body.compounds.iter() {
+            let compound_results = compound.select.shadow(env)?;
+
+            match compound.operator {
+                CompoundOperator::Union => {
+                    // Union means we need to combine the results, removing duplicates
+                    let mut new_rows = compound_results.rows;
+                    new_rows.extend(rows.clone());
+                    new_rows.sort_unstable();
+                    new_rows.dedup();
+                    rows = new_rows;
+                }
+                CompoundOperator::UnionAll => {
+                    // Union all means we just concatenate the results
+                    rows.extend(compound_results.rows.into_iter());
+                }
+            }
+        }
+
Ok(rows)
+    }
+}
+
+impl Shadow for Begin {
+    type Result = Vec<Vec<SimValue>>;
+    fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result {
+        tables.snapshot = Some(tables.tables.clone());
+        vec![]
+    }
+}
+
+impl Shadow for Commit {
+    type Result = Vec<Vec<SimValue>>;
+    fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result {
+        tables.snapshot = None;
+        vec![]
+    }
+}
+
+impl Shadow for Rollback {
+    type Result = Vec<Vec<SimValue>>;
+    fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result {
+        if let Some(tables_) = tables.snapshot.take() {
+            tables.tables = tables_;
+        }
+        vec![]
+    }
+}
+
+impl Shadow for Update {
+    type Result = anyhow::Result<Vec<Vec<SimValue>>>;
+
+    fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result {
+        let table = tables.tables.iter_mut().find(|t| t.name == self.table);
+
+        let table = if let Some(table) = table {
+            table
+        } else {
+            return Err(anyhow::anyhow!(
+                "Table {} does not exist. UPDATE statement ignored.",
+                self.table
+            ));
+        };
+
+        let t2 = table.clone();
+        for row in table
+            .rows
+            .iter_mut()
+            .filter(|r| self.predicate.test(r, &t2))
+        {
+            for (column, set_value) in &self.set_values {
+                if let Some((idx, _)) = table
+                    .columns
+                    .iter()
+                    .enumerate()
+                    .find(|(_, c)| &c.name == column)
+                {
+                    row[idx] = set_value.clone();
+                }
+            }
+        }
+
+        Ok(vec![])
+    }
+}
diff --git a/simulator/model/query/create.rs b/simulator/model/query/create.rs
deleted file mode 100644
index ab0cd9789..000000000
--- a/simulator/model/query/create.rs
+++ /dev/null
@@ -1,45 +0,0 @@
-use std::fmt::Display;
-
-use serde::{Deserialize, Serialize};
-
-use crate::{
-    generation::Shadow,
-    model::table::{SimValue, Table},
-    runner::env::SimulatorTables,
-};
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub(crate) struct Create {
-    pub(crate) table: Table,
-}
-
-impl Shadow for Create {
-    type Result = anyhow::Result<Vec<Vec<SimValue>>>;
-
-    fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result {
-        if !tables.iter().any(|t| t.name == self.table.name) {
-            tables.push(self.table.clone());
-            Ok(vec![])
-        } else {
-            Err(anyhow::anyhow!(
-                "Table {} already exists.
CREATE TABLE statement ignored.",
-                self.table.name
-            ))
-        }
-    }
-}
-
-impl Display for Create {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "CREATE TABLE {} (", self.table.name)?;
-
-        for (i, column) in self.table.columns.iter().enumerate() {
-            if i != 0 {
-                write!(f, ",")?;
-            }
-            write!(f, "{} {}", column.name, column.column_type)?;
-        }
-
-        write!(f, ")")
-    }
-}
diff --git a/simulator/model/query/create_index.rs b/simulator/model/query/create_index.rs
deleted file mode 100644
index 276396d4e..000000000
--- a/simulator/model/query/create_index.rs
+++ /dev/null
@@ -1,106 +0,0 @@
-use crate::{
-    generation::{gen_random_text, pick, pick_n_unique, ArbitraryFrom, Shadow},
-    model::table::SimValue,
-    runner::env::{SimulatorEnv, SimulatorTables},
-};
-use rand::Rng;
-use serde::{Deserialize, Serialize};
-
-#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
-pub enum SortOrder {
-    Asc,
-    Desc,
-}
-
-impl std::fmt::Display for SortOrder {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        match self {
-            SortOrder::Asc => write!(f, "ASC"),
-            SortOrder::Desc => write!(f, "DESC"),
-        }
-    }
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
-pub(crate) struct CreateIndex {
-    pub(crate) index_name: String,
-    pub(crate) table_name: String,
-    pub(crate) columns: Vec<(String, SortOrder)>,
-}
-
-impl Shadow for CreateIndex {
-    type Result = Vec<Vec<SimValue>>;
-    fn shadow(&self, env: &mut SimulatorTables) -> Vec<Vec<SimValue>> {
-        env.tables
-            .iter_mut()
-            .find(|t| t.name == self.table_name)
-            .unwrap()
-            .indexes
-            .push(self.index_name.clone());
-        vec![]
-    }
-}
-
-impl std::fmt::Display for CreateIndex {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(
-            f,
-            "CREATE INDEX {} ON {} ({})",
-            self.index_name,
-            self.table_name,
-            self.columns
-                .iter()
-                .map(|(name, order)| format!("{name} {order}"))
-                .collect::<Vec<_>>()
-                .join(", ")
-        )
-    }
-}
-
-impl ArbitraryFrom<&SimulatorEnv> for CreateIndex {
-    fn arbitrary_from<R: Rng>(rng: &mut R, env: &SimulatorEnv) -> Self {
-        assert!(
-            !env.tables.is_empty(),
-            "Cannot create an index when no tables exist in the environment."
-        );
-
-        let table = pick(&env.tables, rng);
-
-        if table.columns.is_empty() {
-            panic!(
-                "Cannot create an index on table '{}' as it has no columns.",
-                table.name
-            );
-        }
-
-        let num_columns_to_pick = rng.gen_range(1..=table.columns.len());
-        let picked_column_indices = pick_n_unique(0..table.columns.len(), num_columns_to_pick, rng);
-
-        let columns = picked_column_indices
-            .into_iter()
-            .map(|i| {
-                let column = &table.columns[i];
-                (
-                    column.name.clone(),
-                    if rng.gen_bool(0.5) {
-                        SortOrder::Asc
-                    } else {
-                        SortOrder::Desc
-                    },
-                )
-            })
-            .collect::<Vec<_>>();
-
-        let index_name = format!(
-            "idx_{}_{}",
-            table.name,
-            gen_random_text(rng).chars().take(8).collect::<String>()
-        );
-
-        CreateIndex {
-            index_name,
-            table_name: table.name.clone(),
-            columns,
-        }
-    }
-}
diff --git a/simulator/model/query/delete.rs b/simulator/model/query/delete.rs
deleted file mode 100644
index 265cdfe96..000000000
--- a/simulator/model/query/delete.rs
+++ /dev/null
@@ -1,41 +0,0 @@
-use std::fmt::Display;
-
-use serde::{Deserialize, Serialize};
-
-use crate::{generation::Shadow, model::table::SimValue, runner::env::SimulatorTables};
-
-use super::predicate::Predicate;
-
-#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
-pub(crate) struct Delete {
-    pub(crate) table: String,
-    pub(crate) predicate: Predicate,
-}
-
-impl Shadow for Delete {
-    type Result = anyhow::Result<Vec<Vec<SimValue>>>;
-
-    fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result {
-        let table = tables.tables.iter_mut().find(|t| t.name == self.table);
-
-        if let Some(table) = table {
-            // If the table exists, we can delete from it
-            let t2 = table.clone();
-            table.rows.retain_mut(|r| !self.predicate.test(r, &t2));
-        } else {
-            // If the table does not exist, we return an error
-            return Err(anyhow::anyhow!(
-                "Table {} does not exist. DELETE statement ignored.",
-                self.table
-            ));
-        }
-
-        Ok(vec![])
-    }
-}
-
-impl Display for Delete {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "DELETE FROM {} WHERE {}", self.table, self.predicate)
-    }
-}
diff --git a/simulator/model/query/drop.rs b/simulator/model/query/drop.rs
deleted file mode 100644
index 2b4379ff9..000000000
--- a/simulator/model/query/drop.rs
+++ /dev/null
@@ -1,34 +0,0 @@
-use std::fmt::Display;
-
-use serde::{Deserialize, Serialize};
-
-use crate::{generation::Shadow, model::table::SimValue, runner::env::SimulatorTables};
-
-#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
-pub(crate) struct Drop {
-    pub(crate) table: String,
-}
-
-impl Shadow for Drop {
-    type Result = anyhow::Result<Vec<Vec<SimValue>>>;
-
-    fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result {
-        if !tables.iter().any(|t| t.name == self.table) {
-            // If the table does not exist, we return an error
-            return Err(anyhow::anyhow!(
-                "Table {} does not exist.
DROP statement ignored.",
-                self.table
-            ));
-        }
-
-        tables.tables.retain(|t| t.name != self.table);
-
-        Ok(vec![])
-    }
-}
-
-impl Display for Drop {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "DROP TABLE {}", self.table)
-    }
-}
diff --git a/simulator/model/query/mod.rs b/simulator/model/query/mod.rs
deleted file mode 100644
index 38e44073e..000000000
--- a/simulator/model/query/mod.rs
+++ /dev/null
@@ -1,129 +0,0 @@
-use std::{collections::HashSet, fmt::Display};
-
-pub(crate) use create::Create;
-pub(crate) use create_index::CreateIndex;
-pub(crate) use delete::Delete;
-pub(crate) use drop::Drop;
-pub(crate) use insert::Insert;
-pub(crate) use select::Select;
-use serde::{Deserialize, Serialize};
-use turso_sqlite3_parser::to_sql_string::ToSqlContext;
-use update::Update;
-
-use crate::{
-    generation::Shadow,
-    model::{
-        query::transaction::{Begin, Commit, Rollback},
-        table::SimValue,
-    },
-    runner::env::SimulatorTables,
-};
-
-pub mod create;
-pub mod create_index;
-pub mod delete;
-pub mod drop;
-pub mod insert;
-pub mod predicate;
-pub mod select;
-pub mod transaction;
-pub mod update;
-
-// This type represents the potential queries on the database.
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub(crate) enum Query {
-    Create(Create),
-    Select(Select),
-    Insert(Insert),
-    Delete(Delete),
-    Update(Update),
-    Drop(Drop),
-    CreateIndex(CreateIndex),
-    Begin(Begin),
-    Commit(Commit),
-    Rollback(Rollback),
-}
-
-impl Query {
-    pub(crate) fn dependencies(&self) -> HashSet<String> {
-        match self {
-            Query::Select(select) => select.dependencies(),
-            Query::Create(_) => HashSet::new(),
-            Query::Insert(Insert::Select { table, .. })
-            | Query::Insert(Insert::Values { table, .. })
-            | Query::Delete(Delete { table, .. })
-            | Query::Update(Update { table, .. })
-            | Query::Drop(Drop { table, .. }) => HashSet::from_iter([table.clone()]),
-            Query::CreateIndex(CreateIndex { table_name, .. }) => {
-                HashSet::from_iter([table_name.clone()])
-            }
-            Query::Begin(_) | Query::Commit(_) | Query::Rollback(_) => HashSet::new(),
-        }
-    }
-    pub(crate) fn uses(&self) -> Vec<String> {
-        match self {
-            Query::Create(Create { table }) => vec![table.name.clone()],
-            Query::Select(select) => select.dependencies().into_iter().collect(),
-            Query::Insert(Insert::Select { table, .. })
-            | Query::Insert(Insert::Values { table, .. })
-            | Query::Delete(Delete { table, .. })
-            | Query::Update(Update { table, .. })
-            | Query::Drop(Drop { table, .. }) => vec![table.clone()],
-            Query::CreateIndex(CreateIndex { table_name, .. }) => vec![table_name.clone()],
-            Query::Begin(..) | Query::Commit(..) | Query::Rollback(..)
=> vec![],
-        }
-    }
-}
-
-impl Shadow for Query {
-    type Result = anyhow::Result<Vec<Vec<SimValue>>>;
-
-    fn shadow(&self, env: &mut SimulatorTables) -> Self::Result {
-        match self {
-            Query::Create(create) => create.shadow(env),
-            Query::Insert(insert) => insert.shadow(env),
-            Query::Delete(delete) => delete.shadow(env),
-            Query::Select(select) => select.shadow(env),
-            Query::Update(update) => update.shadow(env),
-            Query::Drop(drop) => drop.shadow(env),
-            Query::CreateIndex(create_index) => Ok(create_index.shadow(env)),
-            Query::Begin(begin) => Ok(begin.shadow(env)),
-            Query::Commit(commit) => Ok(commit.shadow(env)),
-            Query::Rollback(rollback) => Ok(rollback.shadow(env)),
-        }
-    }
-}
-
-impl Display for Query {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        match self {
-            Self::Create(create) => write!(f, "{create}"),
-            Self::Select(select) => write!(f, "{select}"),
-            Self::Insert(insert) => write!(f, "{insert}"),
-            Self::Delete(delete) => write!(f, "{delete}"),
-            Self::Update(update) => write!(f, "{update}"),
-            Self::Drop(drop) => write!(f, "{drop}"),
-            Self::CreateIndex(create_index) => write!(f, "{create_index}"),
-            Self::Begin(begin) => write!(f, "{begin}"),
-            Self::Commit(commit) => write!(f, "{commit}"),
-            Self::Rollback(rollback) => write!(f, "{rollback}"),
-        }
-    }
-}
-
-/// Used to print sql strings that already have all the context it needs
-pub(crate) struct EmptyContext;
-
-impl ToSqlContext for EmptyContext {
-    fn get_column_name(
-        &self,
-        _table_id: turso_sqlite3_parser::ast::TableInternalId,
-        _col_idx: usize,
-    ) -> String {
-        unreachable!()
-    }
-
-    fn get_table_name(&self, _id: turso_sqlite3_parser::ast::TableInternalId) -> &str {
-        unreachable!()
-    }
-}
diff --git a/simulator/model/query/select.rs b/simulator/model/query/select.rs
deleted file mode 100644
index 2dcc762a8..000000000
--- a/simulator/model/query/select.rs
+++ /dev/null
@@ -1,497 +0,0 @@
-use std::{collections::HashSet, fmt::Display};
-
-use anyhow::Context;
-pub use ast::Distinctness;
-use itertools::Itertools;
-use serde::{Deserialize, Serialize};
-use turso_sqlite3_parser::ast::{self, fmt::ToTokens, SortOrder};
-
-use crate::{
-    generation::Shadow,
-    model::{
-        query::EmptyContext,
-        table::{JoinTable, JoinType, JoinedTable, SimValue, Table, TableContext},
-    },
-    runner::env::SimulatorTables,
-};
-
-use super::predicate::Predicate;
-
-/// `SELECT` or `RETURNING` result column
-// https://sqlite.org/syntax/result-column.html
-#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
-pub enum ResultColumn {
-    /// expression
-    Expr(Predicate),
-    /// `*`
-    Star,
-    /// column name
-    Column(String),
-}
-
-impl Display for ResultColumn {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        match self {
-            ResultColumn::Expr(expr) => write!(f, "({expr})"),
-            ResultColumn::Star => write!(f, "*"),
-            ResultColumn::Column(name) => write!(f, "{name}"),
-        }
-    }
-}
-#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
-pub(crate) struct Select {
-    pub(crate) body: SelectBody,
-    pub(crate) limit: Option<usize>,
-}
-
-impl Select {
-    pub fn simple(table: String, where_clause: Predicate) -> Self {
-        Self::single(
-            table,
-            vec![ResultColumn::Star],
-            where_clause,
-            None,
-            Distinctness::All,
-        )
-    }
-
-    pub fn expr(expr: Predicate) -> Self {
-        Select {
-            body: SelectBody {
-                select: Box::new(SelectInner {
-                    distinctness: Distinctness::All,
-                    columns: vec![ResultColumn::Expr(expr)],
-                    from: None,
-                    where_clause: Predicate::true_(),
-                    order_by: None,
-                }),
-                compounds: Vec::new(),
-            },
-            limit:
None,
-        }
-    }
-
-    pub fn single(
-        table: String,
-        result_columns: Vec<ResultColumn>,
-        where_clause: Predicate,
-        limit: Option<usize>,
-        distinct: Distinctness,
-    ) -> Self {
-        Select {
-            body: SelectBody {
-                select: Box::new(SelectInner {
-                    distinctness: distinct,
-                    columns: result_columns,
-                    from: Some(FromClause {
-                        table,
-                        joins: Vec::new(),
-                    }),
-                    where_clause,
-                    order_by: None,
-                }),
-                compounds: Vec::new(),
-            },
-            limit,
-        }
-    }
-
-    pub fn compound(left: Select, right: Select, operator: CompoundOperator) -> Self {
-        let mut body = left.body;
-        body.compounds.push(CompoundSelect {
-            operator,
-            select: Box::new(right.body.select.as_ref().clone()),
-        });
-        Select {
-            body,
-            limit: left.limit.or(right.limit),
-        }
-    }
-
-    pub(crate) fn dependencies(&self) -> HashSet<String> {
-        if self.body.select.from.is_none() {
-            return HashSet::new();
-        }
-        let from = self.body.select.from.as_ref().unwrap();
-        let mut tables = HashSet::new();
-        tables.insert(from.table.clone());
-
-        tables.extend(from.dependencies());
-
-        for compound in &self.body.compounds {
-            tables.extend(
-                compound
-                    .select
-                    .from
-                    .as_ref()
-                    .map(|f| f.dependencies())
-                    .unwrap_or(vec![])
-                    .into_iter(),
-            );
-        }
-
-        tables
-    }
-}
-
-#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
-pub struct SelectBody {
-    /// first select
-    pub select: Box<SelectInner>,
-    /// compounds
-    pub compounds: Vec<CompoundSelect>,
-}
-
-#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
-pub struct OrderBy {
-    pub columns: Vec<(String, SortOrder)>,
-}
-
-#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
-pub struct SelectInner {
-    /// `DISTINCT`
-    pub distinctness: Distinctness,
-    /// columns
-    pub columns: Vec<ResultColumn>,
-    /// `FROM` clause
-    pub from: Option<FromClause>,
-    /// `WHERE` clause
-    pub where_clause: Predicate,
-    /// `ORDER BY` clause
-    pub order_by: Option<OrderBy>,
-}
-
-#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
-pub enum CompoundOperator {
-    /// `UNION`
-    Union,
-    /// `UNION ALL`
-    UnionAll,
-}
-
-#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
-pub struct CompoundSelect {
-    /// operator
-    pub operator: CompoundOperator,
-    /// select
-    pub select: Box<SelectInner>,
-}
-
-#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
-pub struct FromClause {
-    /// table
-    pub table: String,
-    /// `JOIN`ed tables
-    pub joins: Vec<JoinedTable>,
-}
-
-impl FromClause {
-    fn to_sql_ast(&self) -> ast::FromClause {
-        ast::FromClause {
-            select: Some(Box::new(ast::SelectTable::Table(
-                ast::QualifiedName::single(ast::Name::from_str(&self.table)),
-                None,
-                None,
-            ))),
-            joins: if self.joins.is_empty() {
-                None
-            } else {
-                Some(
-                    self.joins
-                        .iter()
-                        .map(|join| ast::JoinedSelectTable {
-                            operator: match join.join_type {
-                                JoinType::Inner => {
-                                    ast::JoinOperator::TypedJoin(Some(ast::JoinType::INNER))
-                                }
-                                JoinType::Left => {
-                                    ast::JoinOperator::TypedJoin(Some(ast::JoinType::LEFT))
-                                }
-                                JoinType::Right => {
-                                    ast::JoinOperator::TypedJoin(Some(ast::JoinType::RIGHT))
-                                }
-                                JoinType::Full => {
-                                    ast::JoinOperator::TypedJoin(Some(ast::JoinType::OUTER))
-                                }
-                                JoinType::Cross => {
-                                    ast::JoinOperator::TypedJoin(Some(ast::JoinType::CROSS))
-                                }
-                            },
-                            table: ast::SelectTable::Table(
-                                ast::QualifiedName::single(ast::Name::from_str(&join.table)),
-                                None,
-                                None,
-                            ),
-                            constraint: Some(ast::JoinConstraint::On(join.on.0.clone())),
-                        })
-                        .collect(),
-                )
-            },
-            op: None, // FIXME: this is a temporary fix, we should remove this field
-        }
-    }
-
-    pub(crate) fn dependencies(&self) -> Vec<String> {
-        let mut deps = vec![self.table.clone()];
-        for join in &self.joins {
-            deps.push(join.table.clone());
-        }
deps
-    }
-}
-
-impl Shadow for FromClause {
-    type Result = anyhow::Result<JoinTable>;
-    fn shadow(&self, env: &mut SimulatorTables) -> Self::Result {
-        let tables = &mut env.tables;
-
-        let first_table = tables
-            .iter()
-            .find(|t| t.name == self.table)
-            .context("Table not found")?;
-
-        let mut join_table = JoinTable {
-            tables: vec![first_table.clone()],
-            rows: Vec::new(),
-        };
-
-        for join in &self.joins {
-            let joined_table = tables
-                .iter()
-                .find(|t| t.name == join.table)
-                .context("Joined table not found")?;
-
-            join_table.tables.push(joined_table.clone());
-
-            match join.join_type {
-                JoinType::Inner => {
-                    // Implement inner join logic
-                    let join_rows = joined_table
-                        .rows
-                        .iter()
-                        .filter(|row| join.on.test(row, joined_table))
-                        .cloned()
-                        .collect::<Vec<_>>();
-                    // take a cartesian product of the rows
-                    let all_row_pairs = join_table
-                        .rows
-                        .clone()
-                        .into_iter()
-                        .cartesian_product(join_rows.iter());
-
-                    for (row1, row2) in all_row_pairs {
-                        let row = row1.iter().chain(row2.iter()).cloned().collect::<Vec<_>>();
-
-                        let is_in = join.on.test(&row, &join_table);
-
-                        if is_in {
-                            join_table.rows.push(row);
-                        }
-                    }
-                }
-                _ => todo!(),
-            }
-        }
-        Ok(join_table)
-    }
-}
-
-impl Shadow for SelectInner {
-    type Result = anyhow::Result<JoinTable>;
-
-    fn shadow(&self, env: &mut SimulatorTables) -> Self::Result {
-        if let Some(from) = &self.from {
-            let mut join_table = from.shadow(env)?;
-            let col_count = join_table.columns().count();
-            for row in &mut join_table.rows {
-                assert_eq!(
-                    row.len(),
-                    col_count,
-                    "Row length does not match column length after join"
-                );
-            }
-            let join_clone = join_table.clone();
-
-            join_table
-                .rows
-                .retain(|row| self.where_clause.test(row, &join_clone));
-
-            if self.distinctness == Distinctness::Distinct {
-                join_table.rows.sort_unstable();
-                join_table.rows.dedup();
-            }
-
-            Ok(join_table)
-        } else {
-            assert!(self
-                .columns
-                .iter()
-                .all(|col| matches!(col, ResultColumn::Expr(_))));
-
-            // If `WHERE` is false, just return an empty table
-            if !self.where_clause.test(&[], &Table::anonymous(vec![])) {
-                return Ok(JoinTable {
-                    tables: Vec::new(),
-                    rows: Vec::new(),
-                });
-            }
-
-            // Compute the results of the column expressions and make a row
-            let mut row = Vec::new();
-            for col in &self.columns {
-                match col {
-                    ResultColumn::Expr(expr) => {
-                        let value = expr.eval(&[], &Table::anonymous(vec![]));
-                        if let Some(value) = value {
-                            row.push(value);
-                        } else {
-                            return Err(anyhow::anyhow!(
-                                "Failed to evaluate expression in free select ({})",
-                                expr.0.format_with_context(&EmptyContext {}).unwrap()
-                            ));
-                        }
-                    }
-                    _ => unreachable!("Only expressions are allowed in free selects"),
-                }
-            }
-
-            Ok(JoinTable {
-                tables: Vec::new(),
-                rows: vec![row],
-            })
-        }
-    }
-}
-
-impl Shadow for Select {
-    type Result = anyhow::Result<Vec<Vec<SimValue>>>;
-
-    fn shadow(&self, env: &mut SimulatorTables) -> Self::Result {
-        let first_result = self.body.select.shadow(env)?;
-
-        let mut rows = first_result.rows;
-
-        for compound in self.body.compounds.iter() {
-            let compound_results = compound.select.shadow(env)?;
-
-            match compound.operator {
-                CompoundOperator::Union => {
-                    // Union means we need to combine the results, removing duplicates
-                    let mut new_rows = compound_results.rows;
-                    new_rows.extend(rows.clone());
-                    new_rows.sort_unstable();
-                    new_rows.dedup();
-                    rows = new_rows;
-                }
-                CompoundOperator::UnionAll => {
-                    // Union all means we just concatenate the results
-                    rows.extend(compound_results.rows.into_iter());
-                }
-            }
-        }
-
-        Ok(rows)
-    }
-}
-
-impl Select {
-    pub fn to_sql_ast(&self) ->
ast::Select {
-        ast::Select {
-            with: None,
-            body: ast::SelectBody {
-                select: Box::new(ast::OneSelect::Select(Box::new(ast::SelectInner {
-                    distinctness: if self.body.select.distinctness == Distinctness::Distinct {
-                        Some(ast::Distinctness::Distinct)
-                    } else {
-                        None
-                    },
-                    columns: self
-                        .body
-                        .select
-                        .columns
-                        .iter()
-                        .map(|col| match col {
-                            ResultColumn::Expr(expr) => {
-                                ast::ResultColumn::Expr(expr.0.clone(), None)
-                            }
-                            ResultColumn::Star => ast::ResultColumn::Star,
-                            ResultColumn::Column(name) => ast::ResultColumn::Expr(
-                                ast::Expr::Id(ast::Name::Ident(name.clone())),
-                                None,
-                            ),
-                        })
-                        .collect(),
-                    from: self.body.select.from.as_ref().map(|f| f.to_sql_ast()),
-                    where_clause: Some(self.body.select.where_clause.0.clone()),
-                    group_by: None,
-                    window_clause: None,
-                }))),
-                compounds: Some(
-                    self.body
-                        .compounds
-                        .iter()
-                        .map(|compound| ast::CompoundSelect {
-                            operator: match compound.operator {
-                                CompoundOperator::Union => ast::CompoundOperator::Union,
-                                CompoundOperator::UnionAll => ast::CompoundOperator::UnionAll,
-                            },
-                            select: Box::new(ast::OneSelect::Select(Box::new(ast::SelectInner {
-                                distinctness: Some(compound.select.distinctness),
-                                columns: compound
-                                    .select
-                                    .columns
-                                    .iter()
-                                    .map(|col| match col {
-                                        ResultColumn::Expr(expr) => {
-                                            ast::ResultColumn::Expr(expr.0.clone(), None)
-                                        }
-                                        ResultColumn::Star => ast::ResultColumn::Star,
-                                        ResultColumn::Column(name) => ast::ResultColumn::Expr(
-                                            ast::Expr::Id(ast::Name::Ident(name.clone())),
-                                            None,
-                                        ),
-                                    })
-                                    .collect(),
-                                from: compound.select.from.as_ref().map(|f| f.to_sql_ast()),
-                                where_clause: Some(compound.select.where_clause.0.clone()),
-                                group_by: None,
-                                window_clause: None,
-                            }))),
-                        })
-                        .collect(),
-                ),
-            },
-            order_by: self.body.select.order_by.as_ref().map(|o| {
-                o.columns
-                    .iter()
-                    .map(|(name, order)| ast::SortedColumn {
-                        expr: ast::Expr::Id(ast::Name::Ident(name.clone())),
-                        order: match order {
-                            SortOrder::Asc => Some(ast::SortOrder::Asc),
-                            SortOrder::Desc => Some(ast::SortOrder::Desc),
-                        },
-                        nulls: None,
-                    })
-                    .collect()
-            }),
-            limit: self.limit.map(|l| {
-                Box::new(ast::Limit {
-                    expr: ast::Expr::Literal(ast::Literal::Numeric(l.to_string())),
-                    offset: None,
-                })
-            }),
-        }
-    }
-}
-impl Display for Select {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        self.to_sql_ast().to_fmt_with_context(f, &EmptyContext {})
-    }
-}
-
-#[cfg(test)]
-mod select_tests {
-
-    #[test]
-    fn test_select_display() {}
-}
diff --git a/simulator/model/query/transaction.rs b/simulator/model/query/transaction.rs
deleted file mode 100644
index a73fb076e..000000000
--- a/simulator/model/query/transaction.rs
+++ /dev/null
@@ -1,60 +0,0 @@
-use std::fmt::Display;
-
-use serde::{Deserialize, Serialize};
-
-use crate::{generation::Shadow, model::table::SimValue, runner::env::SimulatorTables};
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub(crate) struct Begin {
-    pub(crate) immediate: bool,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub(crate) struct Commit;
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub(crate) struct Rollback;
-
-impl Shadow for Begin {
-    type Result = Vec<Vec<SimValue>>;
-    fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result {
-        tables.snapshot = Some(tables.tables.clone());
-        vec![]
-    }
-}
-
-impl Shadow for Commit {
-    type Result = Vec<Vec<SimValue>>;
-    fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result {
-        tables.snapshot = None;
-        vec![]
-    }
-}
-
-impl Shadow for Rollback {
-    type Result = Vec<Vec<SimValue>>;
-    fn shadow(&self, tables: &mut
SimulatorTables) -> Self::Result {
-        if let Some(tables_) = tables.snapshot.take() {
-            tables.tables = tables_;
-        }
-        vec![]
-    }
-}
-
-impl Display for Begin {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "BEGIN {}", if self.immediate { "IMMEDIATE" } else { "" })
-    }
-}
-
-impl Display for Commit {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "COMMIT")
-    }
-}
-
-impl Display for Rollback {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "ROLLBACK")
-    }
-}
diff --git a/simulator/model/query/update.rs b/simulator/model/query/update.rs
deleted file mode 100644
index a4cc13fa8..000000000
--- a/simulator/model/query/update.rs
+++ /dev/null
@@ -1,71 +0,0 @@
-use std::fmt::Display;
-
-use serde::{Deserialize, Serialize};
-
-use crate::{generation::Shadow, model::table::SimValue, runner::env::SimulatorTables};
-
-use super::predicate::Predicate;
-
-#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
-pub(crate) struct Update {
-    pub(crate) table: String,
-    pub(crate) set_values: Vec<(String, SimValue)>, // Pair of value for set expressions => SET name=value
-    pub(crate) predicate: Predicate,
-}
-
-impl Update {
-    pub fn table(&self) -> &str {
-        &self.table
-    }
-}
-
-impl Shadow for Update {
-    type Result = anyhow::Result<Vec<Vec<SimValue>>>;
-
-    fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result {
-        let table = tables.tables.iter_mut().find(|t| t.name == self.table);
-
-        let table = if let Some(table) = table {
-            table
-        } else {
-            return Err(anyhow::anyhow!(
-                "Table {} does not exist. UPDATE statement ignored.",
-                self.table
-            ));
-        };
-
-        let t2 = table.clone();
-        for row in table
-            .rows
-            .iter_mut()
-            .filter(|r| self.predicate.test(r, &t2))
-        {
-            for (column, set_value) in &self.set_values {
-                if let Some((idx, _)) = table
-                    .columns
-                    .iter()
-                    .enumerate()
-                    .find(|(_, c)| &c.name == column)
-                {
-                    row[idx] = set_value.clone();
-                }
-            }
-        }
-
-        Ok(vec![])
-    }
-}
-
-impl Display for Update {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "UPDATE {} SET ", self.table)?;
-        for (i, (name, value)) in self.set_values.iter().enumerate() {
-            if i != 0 {
-                write!(f, ", ")?;
-            }
-            write!(f, "{name} = {value}")?;
-        }
-        write!(f, " WHERE {}", self.predicate)?;
-        Ok(())
-    }
-}
diff --git a/simulator/runner/clock.rs b/simulator/runner/clock.rs
index ef687c5c1..871a01346 100644
--- a/simulator/runner/clock.rs
+++ b/simulator/runner/clock.rs
@@ -27,7 +27,7 @@ impl SimulatorClock {
         let nanos = self
             .rng
             .borrow_mut()
-            .gen_range(self.min_tick..self.max_tick);
+            .random_range(self.min_tick..self.max_tick);
         let nanos = std::time::Duration::from_micros(nanos);
         *time += nanos;
         *time
diff --git a/simulator/runner/differential.rs b/simulator/runner/differential.rs
index 7d37babe7..5723418c1 100644
--- a/simulator/runner/differential.rs
+++ b/simulator/runner/differential.rs
@@ -1,14 +1,14 @@
 use std::sync::{Arc, Mutex};
 
+use sql_generation::{generation::pick_index, model::table::SimValue};
 use turso_core::Value;
 
 use crate::{
     generation::{
-        pick_index,
         plan::{Interaction, InteractionPlanState, ResultSet},
         Shadow as _,
     },
-    model::{query::Query, table::SimValue},
+    model::Query,
     runner::execution::ExecutionContinuation,
     InteractionPlan,
 };
diff --git a/simulator/runner/doublecheck.rs b/simulator/runner/doublecheck.rs
index 5ba98ca50..7c9d33b4e 100644
--- a/simulator/runner/doublecheck.rs
+++ b/simulator/runner/doublecheck.rs
@@ -3,9 +3,10 @@ use std::{
     sync::{Arc,
Mutex},
 };
 
+use sql_generation::generation::pick_index;
+
 use crate::{
-    generation::{pick_index, plan::InteractionPlanState},
-    runner::execution::ExecutionContinuation,
+    generation::plan::InteractionPlanState, runner::execution::ExecutionContinuation,
     InteractionPlan,
 };
diff --git a/simulator/runner/env.rs b/simulator/runner/env.rs
index f5787bc57..a29adc591 100644
--- a/simulator/runner/env.rs
+++ b/simulator/runner/env.rs
@@ -7,10 +7,9 @@ use std::sync::Arc;
 
 use rand::{Rng, SeedableRng};
 use rand_chacha::ChaCha8Rng;
+use sql_generation::model::table::Table;
 use turso_core::Database;
 
-use crate::model::table::Table;
-
 use crate::runner::io::SimulatorIO;
 
 use super::cli::SimulatorCLI;
@@ -173,29 +172,29 @@ impl SimulatorEnv {
         let mut delete_percent = 0.0;
         let mut update_percent = 0.0;
 
-        let read_percent = rng.gen_range(0.0..=total);
+        let read_percent = rng.random_range(0.0..=total);
         let write_percent = total - read_percent;
 
         if !cli_opts.disable_create {
             // Create percent should be 5-15% of the write percent
-            create_percent = rng.gen_range(0.05..=0.15) * write_percent;
+            create_percent = rng.random_range(0.05..=0.15) * write_percent;
         }
         if !cli_opts.disable_create_index {
             // Create index percent should be 2-5% of the write percent
-            create_index_percent = rng.gen_range(0.02..=0.05) * write_percent;
+            create_index_percent = rng.random_range(0.02..=0.05) * write_percent;
         }
         if !cli_opts.disable_drop {
             // Drop percent should be 2-5% of the write percent
-            drop_percent = rng.gen_range(0.02..=0.05) * write_percent;
+            drop_percent = rng.random_range(0.02..=0.05) * write_percent;
         }
         if !cli_opts.disable_delete {
             // Delete percent should be 10-20% of the write percent
-            delete_percent = rng.gen_range(0.1..=0.2) * write_percent;
+            delete_percent = rng.random_range(0.1..=0.2) * write_percent;
         }
         if !cli_opts.disable_update {
             // Update percent should be 10-20% of the write percent
             // TODO: freestyling the percentage
-            update_percent = rng.gen_range(0.1..=0.2) * write_percent;
+            update_percent = rng.random_range(0.1..=0.2) * write_percent;
         }
 
         let write_percent = write_percent
@@ -220,10 +219,10 @@ impl SimulatorEnv {
 
         let opts = SimulatorOpts {
             seed,
-            ticks: rng.gen_range(cli_opts.minimum_tests..=cli_opts.maximum_tests),
+            ticks: rng.random_range(cli_opts.minimum_tests..=cli_opts.maximum_tests),
             max_connections: 1, // TODO: for now let's use one connection as we didn't implement
             // correct transactions processing
-            max_tables: rng.gen_range(0..128),
+            max_tables: rng.random_range(0..128),
             create_percent,
             create_index_percent,
             read_percent,
@@ -243,7 +242,7 @@ impl SimulatorEnv {
             disable_fsync_no_wait: cli_opts.disable_fsync_no_wait,
             disable_faulty_query: cli_opts.disable_faulty_query,
             page_size: 4096, // TODO: randomize this too
-            max_interactions: rng.gen_range(cli_opts.minimum_tests..=cli_opts.maximum_tests),
+            max_interactions: rng.random_range(cli_opts.minimum_tests..=cli_opts.maximum_tests),
             max_time_simulation: cli_opts.maximum_time,
             disable_reopen_database: cli_opts.disable_reopen_database,
             latency_probability: cli_opts.latency_probability,
diff --git a/simulator/runner/execution.rs b/simulator/runner/execution.rs
index 9cbac3826..fa3dcbff9 100644
--- a/simulator/runner/execution.rs
+++ b/simulator/runner/execution.rs
@@ -1,10 +1,10 @@
 use std::sync::{Arc, Mutex};
 
+use sql_generation::generation::pick_index;
 use tracing::instrument;
 use turso_core::{Connection, LimboError, Result, StepResult};
 
 use crate::generation::{
-    pick_index,
     plan::{Interaction, InteractionPlan, InteractionPlanState,
ResultSet}, Shadow as _, }; diff --git a/simulator/runner/file.rs b/simulator/runner/file.rs index c8c5ff4fa..bbda05b1d 100644 --- a/simulator/runner/file.rs +++ b/simulator/runner/file.rs @@ -9,7 +9,7 @@ use rand_chacha::ChaCha8Rng; use tracing::{instrument, Level}; use turso_core::{File, Result}; -use crate::{model::FAULT_ERROR_MSG, runner::clock::SimulatorClock}; +use crate::runner::{clock::SimulatorClock, FAULT_ERROR_MSG}; pub(crate) struct SimulatorFile { pub path: String, pub(crate) inner: Arc, @@ -100,10 +100,10 @@ impl SimulatorFile { fn generate_latency_duration(&self) -> Option { let mut rng = self.rng.borrow_mut(); // Chance to introduce some latency - rng.gen_bool(self.latency_probability as f64 / 100.0) + rng.random_bool(self.latency_probability as f64 / 100.0) .then(|| { let now = self.clock.now(); - let sum = now + std::time::Duration::from_millis(rng.gen_range(5..20)); + let sum = now + std::time::Duration::from_millis(rng.random_range(5..20)); sum.into() }) } diff --git a/simulator/runner/mod.rs b/simulator/runner/mod.rs index b56335da5..3eef78331 100644 --- a/simulator/runner/mod.rs +++ b/simulator/runner/mod.rs @@ -9,3 +9,5 @@ pub mod execution; pub mod file; pub mod io; pub mod watch; + +pub const FAULT_ERROR_MSG: &str = "Injected Fault"; diff --git a/simulator/runner/watch.rs b/simulator/runner/watch.rs index 90e8edc68..feab80af1 100644 --- a/simulator/runner/watch.rs +++ b/simulator/runner/watch.rs @@ -1,10 +1,9 @@ use std::sync::{Arc, Mutex}; +use sql_generation::generation::pick_index; + use crate::{ - generation::{ - pick_index, - plan::{Interaction, InteractionPlanState}, - }, + generation::plan::{Interaction, InteractionPlanState}, runner::execution::ExecutionContinuation, }; diff --git a/simulator/shrink/plan.rs b/simulator/shrink/plan.rs index f08ccbb5a..bccd07afd 100644 --- a/simulator/shrink/plan.rs +++ b/simulator/shrink/plan.rs @@ -1,9 +1,9 @@ -use crate::model::query::Query; use crate::{ generation::{ plan::{Interaction, InteractionPlan, Interactions}, property::Property, }, + model::Query, run_simulation, runner::execution::Execution, SandboxedResult, SimulatorEnv, diff --git a/sql_generation/Cargo.toml b/sql_generation/Cargo.toml new file mode 100644 index 000000000..d84d08380 --- /dev/null +++ b/sql_generation/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "sql_generation" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true + +[lib] +path = "lib.rs" + +[dependencies] +hex = "0.4.3" +serde = { workspace = true, features = ["derive"] } +turso_core = { workspace = true, features = ["simulator"] } +turso_parser = { workspace = true, features = ["serde"] } +rand = { workspace = true } +anarchist-readable-name-generator-lib = "0.2.0" +itertools = { workspace = true } +anyhow = { workspace = true } +tracing = { workspace = true } + +[dev-dependencies] +rand_chacha = "0.9.0" diff --git a/simulator/generation/expr.rs b/sql_generation/generation/expr.rs similarity index 85% rename from simulator/generation/expr.rs rename to sql_generation/generation/expr.rs index 682c38d5c..c07d81414 100644 --- a/simulator/generation/expr.rs +++ b/sql_generation/generation/expr.rs @@ -1,14 +1,13 @@ -use turso_sqlite3_parser::ast::{ +use turso_parser::ast::{ self, Expr, LikeOperator, Name, Operator, QualifiedName, Type, UnaryOperator, }; use crate::{ generation::{ frequency, gen_random_text, one_of, pick, pick_index, Arbitrary, ArbitraryFrom, - ArbitrarySizedFrom, + ArbitrarySizedFrom, 
GenerationContext, }, model::table::SimValue, - SimulatorEnv, }; impl Arbitrary for Box @@ -34,7 +33,7 @@ where T: Arbitrary, { fn arbitrary(rng: &mut R) -> Self { - rng.gen_bool(0.5).then_some(T::arbitrary(rng)) + rng.random_bool(0.5).then_some(T::arbitrary(rng)) } } @@ -43,7 +42,7 @@ where T: ArbitrarySizedFrom, { fn arbitrary_sized_from(rng: &mut R, t: A, size: usize) -> Self { - rng.gen_bool(0.5) + rng.random_bool(0.5) .then_some(T::arbitrary_sized_from(rng, t, size)) } } @@ -53,14 +52,14 @@ where T: ArbitraryFrom, { fn arbitrary_from(rng: &mut R, t: A) -> Self { - let size = rng.gen_range(0..5); + let size = rng.random_range(0..5); (0..size).map(|_| T::arbitrary_from(rng, t)).collect() } } // Freestyling generation -impl ArbitrarySizedFrom<&SimulatorEnv> for Expr { - fn arbitrary_sized_from(rng: &mut R, t: &SimulatorEnv, size: usize) -> Self { +impl ArbitrarySizedFrom<&C> for Expr { + fn arbitrary_sized_from(rng: &mut R, t: &C, size: usize) -> Self { frequency( vec![ ( @@ -200,36 +199,23 @@ impl Arbitrary for Type { } } -struct CollateName(String); - -impl Arbitrary for CollateName { - fn arbitrary(rng: &mut R) -> Self { - let choice = rng.gen_range(0..3); - CollateName( - match choice { - 0 => "BINARY", - 1 => "RTRIM", - 2 => "NOCASE", - _ => unreachable!(), - } - .to_string(), - ) - } -} - -impl ArbitraryFrom<&SimulatorEnv> for QualifiedName { - fn arbitrary_from(rng: &mut R, t: &SimulatorEnv) -> Self { +impl ArbitraryFrom<&C> for QualifiedName { + fn arbitrary_from(rng: &mut R, t: &C) -> Self { // TODO: for now just generate table name - let table_idx = pick_index(t.tables.len(), rng); - let table = &t.tables[table_idx]; + let table_idx = pick_index(t.tables().len(), rng); + let table = &t.tables()[table_idx]; // TODO: for now forego alias - Self::single(Name::from_str(&table.name)) + Self { + db_name: None, + name: Name::new(&table.name), + alias: None, + } } } -impl ArbitraryFrom<&SimulatorEnv> for LikeOperator { - fn arbitrary_from(rng: &mut R, _t: &SimulatorEnv) -> Self { - let choice = rng.gen_range(0..4); +impl ArbitraryFrom<&C> for LikeOperator { + fn arbitrary_from(rng: &mut R, _t: &C) -> Self { + let choice = rng.random_range(0..4); match choice { 0 => LikeOperator::Glob, 1 => LikeOperator::Like, @@ -241,17 +227,17 @@ impl ArbitraryFrom<&SimulatorEnv> for LikeOperator { } // Current implementation does not take into account the columns affinity nor if table is Strict -impl ArbitraryFrom<&SimulatorEnv> for ast::Literal { - fn arbitrary_from(rng: &mut R, _t: &SimulatorEnv) -> Self { +impl ArbitraryFrom<&C> for ast::Literal { + fn arbitrary_from(rng: &mut R, _t: &C) -> Self { loop { - let choice = rng.gen_range(0..5); + let choice = rng.random_range(0..5); let lit = match choice { 0 => ast::Literal::Numeric({ - let integer = rng.gen_bool(0.5); + let integer = rng.random_bool(0.5); if integer { - rng.gen_range(i64::MIN..i64::MAX).to_string() + rng.random_range(i64::MIN..i64::MAX).to_string() } else { - rng.gen_range(-1e10..1e10).to_string() + rng.random_range(-1e10..1e10).to_string() } }), 1 => ast::Literal::String(format!("'{}'", gen_random_text(rng))), @@ -279,9 +265,9 @@ impl ArbitraryFrom<&Vec<&SimValue>> for ast::Expr { } } -impl ArbitraryFrom<&SimulatorEnv> for UnaryOperator { - fn arbitrary_from(rng: &mut R, _t: &SimulatorEnv) -> Self { - let choice = rng.gen_range(0..4); +impl ArbitraryFrom<&C> for UnaryOperator { + fn arbitrary_from(rng: &mut R, _t: &C) -> Self { + let choice = rng.random_range(0..4); match choice { 0 => Self::BitwiseNot, 1 => Self::Negative, diff 
--git a/sql_generation/generation/mod.rs b/sql_generation/generation/mod.rs new file mode 100644 index 000000000..25bd7ec09 --- /dev/null +++ b/sql_generation/generation/mod.rs @@ -0,0 +1,188 @@ +use std::{iter::Sum, ops::SubAssign}; + +use anarchist_readable_name_generator_lib::readable_name_custom; +use rand::{distr::uniform::SampleUniform, Rng}; + +use crate::model::table::Table; + +pub mod expr; +pub mod predicate; +pub mod query; +pub mod table; + +#[derive(Debug, Clone, Copy)] +pub struct Opts { + /// Indexes enabled + pub indexes: bool, +} + +/// Trait used to provide context to generation functions +pub trait GenerationContext { + fn tables(&self) -> &Vec
; + fn opts(&self) -> Opts; +} + +type ArbitraryFromFunc<'a, R, T> = Box T + 'a>; +type Choice<'a, R, T> = (usize, Box Option + 'a>); + +/// Arbitrary trait for generating random values +/// An implementation of arbitrary is assumed to be a uniform sampling of +/// the possible values of the type, with a bias towards smaller values for +/// practicality. +pub trait Arbitrary { + fn arbitrary(rng: &mut R) -> Self; +} + +/// ArbitrarySized trait for generating random values of a specific size +/// An implementation of arbitrary_sized is assumed to be a uniform sampling of +/// the possible values of the type, with a bias towards smaller values for +/// practicality, but with the additional constraint that the generated value +/// must fit in the given size. This is useful for generating values that are +/// constrained by a specific size, such as integers or strings. +pub trait ArbitrarySized { + fn arbitrary_sized(rng: &mut R, size: usize) -> Self; +} + +/// ArbitraryFrom trait for generating random values from a given value +/// ArbitraryFrom allows for constructing relations, where the generated +/// value is dependent on the given value. These relations could be constraints +/// such as generating an integer within an interval, or a value that fits in a table, +/// or a predicate satisfying a given table row. +pub trait ArbitraryFrom { + fn arbitrary_from(rng: &mut R, t: T) -> Self; +} + +/// ArbitrarySizedFrom trait for generating random values from a given value +/// ArbitrarySizedFrom allows for constructing relations, where the generated +/// value is dependent on the given value and a size constraint. These relations +/// could be constraints such as generating an integer within an interval, +/// or a value that fits in a table, or a predicate satisfying a given table row, +/// but with the additional constraint that the generated value must fit in the given size. +/// This is useful for generating values that are constrained by a specific size, +/// such as integers or strings, while still being dependent on the given value. +pub trait ArbitrarySizedFrom { + fn arbitrary_sized_from(rng: &mut R, t: T, size: usize) -> Self; +} + +/// ArbitraryFromMaybe trait for fallibly generating random values from a given value +pub trait ArbitraryFromMaybe { + fn arbitrary_from_maybe(rng: &mut R, t: T) -> Option + where + Self: Sized; +} + +/// Frequency is a helper function for composing different generators with different frequency +/// of occurrences. +/// The type signature for the `N` parameter is a bit complex, but it +/// roughly corresponds to a type that can be summed, compared, subtracted and sampled, which are +/// the operations we require for the implementation. +// todo: switch to a simpler type signature that can accommodate all integer and float types, which +// should be enough for our purposes. +pub fn frequency( + choices: Vec<(N, ArbitraryFromFunc)>, + rng: &mut R, +) -> T { + let total = choices.iter().map(|(weight, _)| *weight).sum::(); + let mut choice = rng.random_range(N::default()..total); + + for (weight, f) in choices { + if choice < weight { + return f(rng); + } + choice -= weight; + } + + unreachable!() +} + +/// one_of is a helper function for composing different generators with equal probability of occurrence. +pub fn one_of(choices: Vec>, rng: &mut R) -> T { + let index = rng.random_range(0..choices.len()); + choices[index](rng) +} +
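`frequency` is the workhorse behind the weighted query-kind selection seen earlier (the read/write/create percentages). A self-contained sketch of the same weighted-bucket walk, specialized to `u32` weights and `ThreadRng` for brevity; the names here are illustrative, not the crate's exact signature:

```rust
use rand::{rngs::ThreadRng, Rng};

// Self-contained copy of the weighted-choice idea behind `frequency`:
// draw a number in [0, total) and walk the buckets, subtracting weights
// until the draw lands inside one of them.
fn frequency<T>(choices: Vec<(u32, Box<dyn Fn(&mut ThreadRng) -> T>)>, rng: &mut ThreadRng) -> T {
    let total: u32 = choices.iter().map(|(weight, _)| *weight).sum();
    let mut draw = rng.random_range(0..total);
    for (weight, f) in choices {
        if draw < weight {
            return f(rng);
        }
        draw -= weight;
    }
    unreachable!()
}

fn main() {
    let mut rng = rand::rng();
    // Roughly 90% small numbers, 10% large ones.
    let n: i64 = frequency(
        vec![
            (
                9,
                Box::new(|rng: &mut ThreadRng| rng.random_range(0..10))
                    as Box<dyn Fn(&mut ThreadRng) -> i64>,
            ),
            (1, Box::new(|rng: &mut ThreadRng| rng.random_range(1_000..2_000))),
        ],
        &mut rng,
    );
    println!("{n}");
}
```

+/// backtrack is a helper function for composing different fallible generators.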
+/// The function takes a list of functions that return an Option, along with the number of retries +/// to make before giving up. +pub fn backtrack(mut choices: Vec>, rng: &mut R) -> Option { + loop { + // If there are no more choices left, we give up + let choices_ = choices + .iter() + .enumerate() + .filter(|(_, (retries, _))| *retries > 0) + .collect::>(); + if choices_.is_empty() { + tracing::trace!("backtrack: no more choices left"); + return None; + } + // Run a one_of on the remaining choices + let (choice_index, choice) = pick(&choices_, rng); + let choice_index = *choice_index; + // If the choice returns None, we decrement the number of retries and try again + let result = choice.1(rng); + if result.is_some() { + return result; + } else { + choices[choice_index].0 -= 1; + } + } +} + +/// pick is a helper function for uniformly picking a random element from a slice +pub fn pick<'a, T, R: Rng>(choices: &'a [T], rng: &mut R) -> &'a T { + let index = rng.random_range(0..choices.len()); + &choices[index] +} + +/// pick_index is typically used for picking an index from a slice to later refer to the element +/// at that index. +pub fn pick_index(choices: usize, rng: &mut R) -> usize { + rng.random_range(0..choices) +} + +/// pick_n_unique is a helper function for uniformly picking N unique elements from a range. +/// The elements themselves are usize, typically representing indices. +pub fn pick_n_unique(range: std::ops::Range, n: usize, rng: &mut R) -> Vec { + use rand::seq::SliceRandom; + let mut items: Vec = range.collect(); + items.shuffle(rng); + items.into_iter().take(n).collect() +} + +/// gen_random_text uses `anarchist_readable_name_generator_lib` to generate random +/// readable names for tables, columns, text values etc. +pub fn gen_random_text(rng: &mut T) -> String { + let big_text = rng.random_ratio(1, 1000); + if big_text { + // let max_size: u64 = 2 * 1024 * 1024 * 1024; + let max_size: u64 = 2 * 1024; + let size = rng.random_range(1024..max_size); + let mut name = String::with_capacity(size as usize); + for i in 0..size { + name.push(((i % 26) as u8 + b'A') as char); + } + name + } else { + let name = readable_name_custom("_", rng); + name.replace("-", "_") + } +} + +pub fn pick_unique( + items: &[T], + count: usize, + rng: &mut impl rand::Rng, +) -> Vec +where + ::Owned: PartialEq, +{ + let mut picked: Vec = Vec::new(); + while picked.len() < count { + let item = pick(items, rng); + if !picked.contains(&item.to_owned()) { + picked.push(item.to_owned()); + } + } + picked +}
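`backtrack` is what lets generators like `Insert::arbitrary_from` mix fallible strategies. A simplified, self-contained model of its retry accounting (the real one samples via `pick` and logs through `tracing`); toy types, assuming only rand 0.9:

```rust
use rand::{rngs::ThreadRng, Rng};

type Choice<T> = (usize, Box<dyn Fn(&mut ThreadRng) -> Option<T>>);

// Simplified model of `backtrack`: keep sampling among the choices that
// still have retries left; each `None` result burns one retry for the
// choice that produced it.
fn backtrack<T>(mut choices: Vec<Choice<T>>, rng: &mut ThreadRng) -> Option<T> {
    loop {
        let live: Vec<usize> = choices
            .iter()
            .enumerate()
            .filter(|(_, (retries, _))| *retries > 0)
            .map(|(i, _)| i)
            .collect();
        if live.is_empty() {
            return None; // every retry budget is exhausted
        }
        let idx = live[rng.random_range(0..live.len())];
        match (choices[idx].1)(rng) {
            Some(value) => return Some(value),
            None => choices[idx].0 -= 1,
        }
    }
}

fn main() {
    let mut rng = rand::rng();
    // A flaky generator with three retries, plus an infallible fallback:
    // this call can therefore never return None.
    let flaky: Choice<i32> = (3, Box::new(|rng: &mut ThreadRng| rng.random_bool(0.2).then_some(1)));
    let sure: Choice<i32> = (1, Box::new(|_: &mut ThreadRng| Some(2)));
    println!("{:?}", backtrack(vec![flaky, sure], &mut rng));
}
```

diff --git a/simulator/generation/predicate/binary.rs b/sql_generation/generation/predicate/binary.rs similarity index 87% rename from simulator/generation/predicate/binary.rs rename to sql_generation/generation/predicate/binary.rs index f8ba27236..29c1727a9 100644 --- a/simulator/generation/predicate/binary.rs +++ b/sql_generation/generation/predicate/binary.rs @@ -1,6 +1,6 @@ //!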
Contains code for generation for [ast::Expr::Binary] Predicate -use turso_sqlite3_parser::ast::{self, Expr}; +use turso_parser::ast::{self, Expr}; use crate::{ generation::{ @@ -56,7 +56,7 @@ impl Predicate { /// Produces a true [ast::Expr::Binary] [Predicate] that is true for the provided row in the given table pub fn true_binary(rng: &mut R, t: &Table, row: &[SimValue]) -> Predicate { // Pick a column - let column_index = rng.gen_range(0..t.columns.len()); + let column_index = rng.random_range(0..t.columns.len()); let mut column = t.columns[column_index].clone(); let value = &row[column_index]; @@ -82,8 +82,8 @@ impl Predicate { Box::new(|_| { Some(Expr::Binary( Box::new(ast::Expr::Qualified( - ast::Name::from_str(&table_name), - ast::Name::from_str(&column.name), + ast::Name::new(&table_name), + ast::Name::new(&column.name), )), ast::Operator::Equals, Box::new(Expr::Literal(value.into())), @@ -99,8 +99,8 @@ impl Predicate { } else { Some(Expr::Binary( Box::new(ast::Expr::Qualified( - ast::Name::from_str(&table_name), - ast::Name::from_str(&column.name), + ast::Name::new(&table_name), + ast::Name::new(&column.name), )), ast::Operator::NotEquals, Box::new(Expr::Literal(v.into())), @@ -114,8 +114,8 @@ impl Predicate { let lt_value = LTValue::arbitrary_from(rng, value).0; Some(Expr::Binary( Box::new(ast::Expr::Qualified( - ast::Name::from_str(&table_name), - ast::Name::from_str(&column.name), + ast::Name::new(&table_name), + ast::Name::new(&column.name), )), ast::Operator::Greater, Box::new(Expr::Literal(lt_value.into())), @@ -128,8 +128,8 @@ impl Predicate { let gt_value = GTValue::arbitrary_from(rng, value).0; Some(Expr::Binary( Box::new(ast::Expr::Qualified( - ast::Name::from_str(&table_name), - ast::Name::from_str(&column.name), + ast::Name::new(&table_name), + ast::Name::new(&column.name), )), ast::Operator::Less, Box::new(Expr::Literal(gt_value.into())), @@ -143,8 +143,8 @@ impl Predicate { LikeValue::arbitrary_from_maybe(rng, value).map(|like| { Expr::Like { lhs: Box::new(ast::Expr::Qualified( - ast::Name::from_str(&table_name), - ast::Name::from_str(&column.name), + ast::Name::new(&table_name), + ast::Name::new(&column.name), )), not: false, // TODO: also generate this value eventually op: ast::LikeOperator::Like, @@ -164,7 +164,7 @@ impl Predicate { /// Produces an [ast::Expr::Binary] [Predicate] that is false for the provided row in the given table pub fn false_binary(rng: &mut R, t: &Table, row: &[SimValue]) -> Predicate { // Pick a column - let column_index = rng.gen_range(0..t.columns.len()); + let column_index = rng.random_range(0..t.columns.len()); let mut column = t.columns[column_index].clone(); let mut table_name = t.name.clone(); let value = &row[column_index]; @@ -188,8 +188,8 @@ impl Predicate { Box::new(|_| { Expr::Binary( Box::new(ast::Expr::Qualified( - ast::Name::from_str(&table_name), - ast::Name::from_str(&column.name), + ast::Name::new(&table_name), + ast::Name::new(&column.name), )), ast::Operator::NotEquals, Box::new(Expr::Literal(value.into())), @@ -204,8 +204,8 @@ impl Predicate { }; Expr::Binary( Box::new(ast::Expr::Qualified( - ast::Name::from_str(&table_name), - ast::Name::from_str(&column.name), + ast::Name::new(&table_name), + ast::Name::new(&column.name), )), ast::Operator::Equals, Box::new(Expr::Literal(v.into())), @@ -215,8 +215,8 @@ impl Predicate { let gt_value = GTValue::arbitrary_from(rng, value).0; Expr::Binary( Box::new(ast::Expr::Qualified( - ast::Name::from_str(&table_name), - ast::Name::from_str(&column.name), + ast::Name::new(&table_name), + 
ast::Name::new(&column.name), )), ast::Operator::Greater, Box::new(Expr::Literal(gt_value.into())), @@ -226,8 +226,8 @@ impl Predicate { let lt_value = LTValue::arbitrary_from(rng, value).0; Expr::Binary( Box::new(ast::Expr::Qualified( - ast::Name::from_str(&table_name), - ast::Name::from_str(&column.name), + ast::Name::new(&table_name), + ast::Name::new(&column.name), )), ast::Operator::Less, Box::new(Expr::Literal(lt_value.into())), @@ -249,7 +249,7 @@ impl SimplePredicate { ) -> Self { // Pick a random column let columns = table.columns().collect::>(); - let column_index = rng.gen_range(0..columns.len()); + let column_index = rng.random_range(0..columns.len()); let column = columns[column_index]; let column_value = &row[column_index]; let table_name = column.table_name; @@ -263,8 +263,8 @@ impl SimplePredicate { Box::new(|_rng| { Expr::Binary( Box::new(ast::Expr::Qualified( - ast::Name::from_str(table_name), - ast::Name::from_str(&column.column.name), + ast::Name::new(table_name), + ast::Name::new(&column.column.name), )), ast::Operator::Equals, Box::new(Expr::Literal(column_value.into())), @@ -274,8 +274,8 @@ impl SimplePredicate { let lt_value = LTValue::arbitrary_from(rng, column_value).0; Expr::Binary( Box::new(Expr::Qualified( - ast::Name::from_str(table_name), - ast::Name::from_str(&column.column.name), + ast::Name::new(table_name), + ast::Name::new(&column.column.name), )), ast::Operator::Greater, Box::new(Expr::Literal(lt_value.into())), @@ -285,8 +285,8 @@ impl SimplePredicate { let gt_value = GTValue::arbitrary_from(rng, column_value).0; Expr::Binary( Box::new(Expr::Qualified( - ast::Name::from_str(table_name), - ast::Name::from_str(&column.column.name), + ast::Name::new(table_name), + ast::Name::new(&column.column.name), )), ast::Operator::Less, Box::new(Expr::Literal(gt_value.into())), @@ -306,7 +306,7 @@ impl SimplePredicate { ) -> Self { let columns = table.columns().collect::>(); // Pick a random column - let column_index = rng.gen_range(0..columns.len()); + let column_index = rng.random_range(0..columns.len()); let column = columns[column_index]; let column_value = &row[column_index]; let table_name = column.table_name; @@ -320,8 +320,8 @@ impl SimplePredicate { Box::new(|_rng| { Expr::Binary( Box::new(Expr::Qualified( - ast::Name::from_str(table_name), - ast::Name::from_str(&column.column.name), + ast::Name::new(table_name), + ast::Name::new(&column.column.name), )), ast::Operator::NotEquals, Box::new(Expr::Literal(column_value.into())), @@ -331,8 +331,8 @@ impl SimplePredicate { let gt_value = GTValue::arbitrary_from(rng, column_value).0; Expr::Binary( Box::new(ast::Expr::Qualified( - ast::Name::from_str(table_name), - ast::Name::from_str(&column.column.name), + ast::Name::new(table_name), + ast::Name::new(&column.column.name), )), ast::Operator::Greater, Box::new(Expr::Literal(gt_value.into())), @@ -342,8 +342,8 @@ impl SimplePredicate { let lt_value = LTValue::arbitrary_from(rng, column_value).0; Expr::Binary( Box::new(ast::Expr::Qualified( - ast::Name::from_str(table_name), - ast::Name::from_str(&column.column.name), + ast::Name::new(table_name), + ast::Name::new(&column.column.name), )), ast::Operator::Less, Box::new(Expr::Literal(lt_value.into())), @@ -376,11 +376,11 @@ impl CompoundPredicate { } let row = pick(rows, rng); - let predicate = if rng.gen_bool(0.7) { + let predicate = if rng.random_bool(0.7) { // An AND for true requires each of its children to be true // An AND for false requires at least one of its children to be false if predicate_value { - 
(0..rng.gen_range(1..=3)) + (0..rng.random_range(1..=3)) .map(|_| SimplePredicate::arbitrary_from(rng, (table, row, true)).0) .reduce(|accum, curr| { Predicate(Expr::Binary( @@ -392,15 +392,15 @@ impl CompoundPredicate { .unwrap_or(Predicate::true_()) } else { // Create a vector of random booleans - let mut booleans = (0..rng.gen_range(1..=3)) - .map(|_| rng.gen_bool(0.5)) + let mut booleans = (0..rng.random_range(1..=3)) + .map(|_| rng.random_bool(0.5)) .collect::>(); let len = booleans.len(); // Make sure at least one of them is false if booleans.iter().all(|b| *b) { - booleans[rng.gen_range(0..len)] = false; + booleans[rng.random_range(0..len)] = false; } booleans @@ -420,13 +420,13 @@ impl CompoundPredicate { // An OR for false requires each of its children to be false if predicate_value { // Create a vector of random booleans - let mut booleans = (0..rng.gen_range(1..=3)) - .map(|_| rng.gen_bool(0.5)) + let mut booleans = (0..rng.random_range(1..=3)) + .map(|_| rng.random_bool(0.5)) .collect::>(); let len = booleans.len(); // Make sure at least one of them is true if booleans.iter().all(|b| !*b) { - booleans[rng.gen_range(0..len)] = true; + booleans[rng.random_range(0..len)] = true; } booleans @@ -441,7 +441,7 @@ impl CompoundPredicate { }) .unwrap_or(Predicate::true_()) } else { - (0..rng.gen_range(1..=3)) + (0..rng.random_range(1..=3)) .map(|_| SimplePredicate::arbitrary_from(rng, (table, row, false)).0) .reduce(|accum, curr| { Predicate(Expr::Binary( @@ -483,7 +483,7 @@ mod tests { let mut rng = ChaCha8Rng::seed_from_u64(seed); for _ in 0..10000 { let table = Table::arbitrary(&mut rng); - let num_rows = rng.gen_range(1..10); + let num_rows = rng.random_range(1..10); let values: Vec> = (0..num_rows) .map(|_| { table @@ -509,7 +509,7 @@ mod tests { let mut rng = ChaCha8Rng::seed_from_u64(seed); for _ in 0..10000 { let table = Table::arbitrary(&mut rng); - let num_rows = rng.gen_range(1..10); + let num_rows = rng.random_range(1..10); let values: Vec> = (0..num_rows) .map(|_| { table @@ -535,7 +535,7 @@ mod tests { let mut rng = ChaCha8Rng::seed_from_u64(seed); for _ in 0..10000 { let mut table = Table::arbitrary(&mut rng); - let num_rows = rng.gen_range(1..10); + let num_rows = rng.random_range(1..10); let values: Vec> = (0..num_rows) .map(|_| { table @@ -563,7 +563,7 @@ mod tests { let mut rng = ChaCha8Rng::seed_from_u64(seed); for _ in 0..10000 { let mut table = Table::arbitrary(&mut rng); - let num_rows = rng.gen_range(1..10); + let num_rows = rng.random_range(1..10); let values: Vec> = (0..num_rows) .map(|_| { table diff --git a/simulator/generation/predicate/mod.rs b/sql_generation/generation/predicate/mod.rs similarity index 95% rename from simulator/generation/predicate/mod.rs rename to sql_generation/generation/predicate/mod.rs index 5c5887818..b919ad0bd 100644 --- a/simulator/generation/predicate/mod.rs +++ b/sql_generation/generation/predicate/mod.rs @@ -1,5 +1,5 @@ use rand::{seq::SliceRandom as _, Rng}; -use turso_sqlite3_parser::ast::{self, Expr}; +use turso_parser::ast::{self, Expr}; use crate::model::{ query::predicate::Predicate, @@ -8,8 +8,8 @@ use crate::model::{ use super::{one_of, ArbitraryFrom}; -mod binary; -mod unary; +pub mod binary; +pub mod unary; #[derive(Debug)] struct CompoundPredicate(Predicate); @@ -21,7 +21,7 @@ impl, T: TableContext> ArbitraryFrom<(&T, A, bool)> for Sim fn arbitrary_from(rng: &mut R, (table, row, predicate_value): (&T, A, bool)) -> Self { let row = row.as_ref(); // Pick an operator - let choice = rng.gen_range(0..2); + let choice = 
rng.random_range(0..2); // Pick an operator match predicate_value { true => match choice { @@ -46,7 +46,7 @@ impl ArbitraryFrom<(&T, bool)> for CompoundPredicate { impl ArbitraryFrom<&T> for Predicate { fn arbitrary_from(rng: &mut R, table: &T) -> Self { - let predicate_value = rng.gen_bool(0.5); + let predicate_value = rng.random_bool(0.5); Predicate::arbitrary_from(rng, (table, predicate_value)).parens() } } @@ -70,11 +70,11 @@ impl ArbitraryFrom<(&Table, &Vec)> for Predicate { // are true, some that are false, combine them in ways that correspond to the creation of a true predicate // Produce some true and false predicates - let mut true_predicates = (1..=rng.gen_range(1..=4)) + let mut true_predicates = (1..=rng.random_range(1..=4)) .map(|_| Predicate::true_binary(rng, t, row)) .collect::>(); - let false_predicates = (0..=rng.gen_range(0..=3)) + let false_predicates = (0..=rng.random_range(0..=3)) .map(|_| Predicate::false_binary(rng, t, row)) .collect::>(); @@ -92,7 +92,7 @@ impl ArbitraryFrom<(&Table, &Vec)> for Predicate { while !predicates.is_empty() { // Create a new predicate from at least 1 and at most 3 predicates let context = - predicates[0..rng.gen_range(0..=usize::min(3, predicates.len()))].to_vec(); + predicates[0..rng.random_range(0..=usize::min(3, predicates.len()))].to_vec(); // Shift `predicates` to remove the predicates in the context predicates = predicates[context.len()..].to_vec(); @@ -251,7 +251,7 @@ mod tests { let mut rng = ChaCha8Rng::seed_from_u64(seed); for _ in 0..10000 { let table = Table::arbitrary(&mut rng); - let num_rows = rng.gen_range(1..10); + let num_rows = rng.random_range(1..10); let values: Vec> = (0..num_rows) .map(|_| { table @@ -277,7 +277,7 @@ let mut rng = ChaCha8Rng::seed_from_u64(seed); for _ in 0..10000 { let table = Table::arbitrary(&mut rng); - let num_rows = rng.gen_range(1..10); + let num_rows = rng.random_range(1..10); let values: Vec> = (0..num_rows) .map(|_| { table @@ -303,7 +303,7 @@ let mut rng = ChaCha8Rng::seed_from_u64(seed); for _ in 0..10000 { let table = Table::arbitrary(&mut rng); - let num_rows = rng.gen_range(1..10); + let num_rows = rng.random_range(1..10); let values: Vec> = (0..num_rows) .map(|_| { table @@ -329,7 +329,7 @@ let mut rng = ChaCha8Rng::seed_from_u64(seed); for _ in 0..10000 { let mut table = Table::arbitrary(&mut rng); - let num_rows = rng.gen_range(1..10); + let num_rows = rng.random_range(1..10); let values: Vec> = (0..num_rows) .map(|_| { table @@ -356,7 +356,7 @@ let mut rng = ChaCha8Rng::seed_from_u64(seed); for _ in 0..10000 { let mut table = Table::arbitrary(&mut rng); - let num_rows = rng.gen_range(1..10); + let num_rows = rng.random_range(1..10); let values: Vec> = (0..num_rows) .map(|_| { table
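The compound-predicate generation above relies on one invariant: an AND is false iff at least one child is false, and an OR is true iff at least one child is true. A runnable sketch of the truth-vector trick used for the false-AND case, assuming only rand 0.9:

```rust
use rand::Rng;

fn main() {
    let mut rng = rand::rng();

    // Pick a truth value for each child, then enforce the invariant:
    // a false AND needs at least one false child.
    let mut booleans: Vec<bool> = (0..rng.random_range(1..=3))
        .map(|_| rng.random_bool(0.5))
        .collect();
    let len = booleans.len();
    if booleans.iter().all(|b| *b) {
        booleans[rng.random_range(0..len)] = false;
    }
    assert!(booleans.iter().any(|b| !*b));
    // Each flag would then drive a child generator such as
    // SimplePredicate::arbitrary_from(rng, (table, row, flag)).
    println!("{booleans:?}");
}
```

diff --git a/simulator/generation/predicate/unary.rs b/sql_generation/generation/predicate/unary.rs similarity index 97% rename from simulator/generation/predicate/unary.rs rename to sql_generation/generation/predicate/unary.rs index f7f374b6e..62c6d7d65 100644 --- a/simulator/generation/predicate/unary.rs +++ b/sql_generation/generation/predicate/unary.rs @@ -2,7 +2,7 @@ //! TODO: for now just generating [ast::Literal], but want to also generate Columns and any //!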
arbitrary [ast::Expr] -use turso_sqlite3_parser::ast::{self, Expr}; +use turso_parser::ast::{self, Expr}; use crate::{ generation::{backtrack, pick, predicate::SimplePredicate, ArbitraryFromMaybe}, @@ -64,6 +64,7 @@ impl ArbitraryFromMaybe<&Vec<&SimValue>> for FalseValue { } } +#[allow(dead_code)] pub struct BitNotValue(pub SimValue); impl ArbitraryFromMaybe<(&SimValue, bool)> for BitNotValue { @@ -107,7 +108,7 @@ impl SimplePredicate { ) -> Self { let columns = table.columns().collect::>(); // Pick a random column - let column_index = rng.gen_range(0..columns.len()); + let column_index = rng.random_range(0..columns.len()); let column_value = &row[column_index]; let num_retries = row.len(); // Avoid creation of NULLs @@ -173,7 +174,7 @@ impl SimplePredicate { ) -> Self { let columns = table.columns().collect::>(); // Pick a random column - let column_index = rng.gen_range(0..columns.len()); + let column_index = rng.random_range(0..columns.len()); let column_value = &row[column_index]; let num_retries = row.len(); // Avoid creation of NULLs @@ -255,7 +256,7 @@ mod tests { let mut rng = ChaCha8Rng::seed_from_u64(seed); for _ in 0..10000 { let mut table = Table::arbitrary(&mut rng); - let num_rows = rng.gen_range(1..10); + let num_rows = rng.random_range(1..10); let values: Vec> = (0..num_rows) .map(|_| { table @@ -283,7 +284,7 @@ mod tests { let mut rng = ChaCha8Rng::seed_from_u64(seed); for _ in 0..10000 { let mut table = Table::arbitrary(&mut rng); - let num_rows = rng.gen_range(1..10); + let num_rows = rng.random_range(1..10); let values: Vec> = (0..num_rows) .map(|_| { table diff --git a/sql_generation/generation/query.rs b/sql_generation/generation/query.rs new file mode 100644 index 000000000..d7840a001 --- /dev/null +++ b/sql_generation/generation/query.rs @@ -0,0 +1,391 @@ +use crate::generation::{ + gen_random_text, pick_n_unique, pick_unique, Arbitrary, ArbitraryFrom, ArbitrarySizedFrom, + GenerationContext, +}; +use crate::model::query::predicate::Predicate; +use crate::model::query::select::{ + CompoundOperator, CompoundSelect, Distinctness, FromClause, OrderBy, ResultColumn, SelectBody, + SelectInner, +}; +use crate::model::query::update::Update; +use crate::model::query::{Create, CreateIndex, Delete, Drop, Insert, Select}; +use crate::model::table::{JoinTable, JoinType, JoinedTable, SimValue, Table, TableContext}; +use itertools::Itertools; +use rand::Rng; +use turso_parser::ast::{Expr, SortOrder}; + +use super::{backtrack, pick}; + +impl Arbitrary for Create { + fn arbitrary(rng: &mut R) -> Self { + Create { + table: Table::arbitrary(rng), + } + } +} + +impl ArbitraryFrom<&Vec
> for FromClause { + fn arbitrary_from(rng: &mut R, tables: &Vec
) -> Self { + let num_joins = match rng.random_range(0..=100) { + 0..=90 => 0, + 91..=97 => 1, + 98..=100 => 2, + _ => unreachable!(), + }; + + let mut tables = tables.clone(); + let mut table = pick(&tables, rng).clone(); + + tables.retain(|t| t.name != table.name); + + let name = table.name.clone(); + + let mut table_context = JoinTable { + tables: Vec::new(), + rows: Vec::new(), + }; + + let joins: Vec<_> = (0..num_joins) + .filter_map(|_| { + if tables.is_empty() { + return None; + } + let join_table = pick(&tables, rng).clone(); + let joined_table_name = join_table.name.clone(); + + tables.retain(|t| t.name != join_table.name); + table_context.rows = table_context + .rows + .iter() + .cartesian_product(join_table.rows.iter()) + .map(|(t_row, j_row)| { + let mut row = t_row.clone(); + row.extend(j_row.clone()); + row + }) + .collect(); + // TODO: inefficient. use a Deque to push_front? + table_context.tables.insert(0, join_table); + for row in &mut table.rows { + assert_eq!( + row.len(), + table.columns.len(), + "Row length does not match column length after join" + ); + } + + let predicate = Predicate::arbitrary_from(rng, &table); + Some(JoinedTable { + table: joined_table_name, + join_type: JoinType::Inner, + on: predicate, + }) + }) + .collect(); + FromClause { table: name, joins } + } +} + +impl ArbitraryFrom<&C> for SelectInner { + fn arbitrary_from(rng: &mut R, env: &C) -> Self { + let from = FromClause::arbitrary_from(rng, env.tables()); + let tables = env.tables().clone(); + let join_table = from.into_join_table(&tables); + let cuml_col_count = join_table.columns().count(); + + let order_by = 'order_by: { + if rng.random_bool(0.3) { + let order_by_table_candidates = from + .joins + .iter() + .map(|j| j.table.clone()) + .chain(std::iter::once(from.table.clone())) + .collect::>(); + let order_by_col_count = + (rng.random::() * rng.random::() * (cuml_col_count as f64)) as usize; // skew towards 0 + if order_by_col_count == 0 { + break 'order_by None; + } + let mut col_names = std::collections::HashSet::new(); + let mut order_by_cols = Vec::new(); + while order_by_cols.len() < order_by_col_count { + let table = pick(&order_by_table_candidates, rng); + let table = tables.iter().find(|t| t.name == *table).unwrap(); + let col = pick(&table.columns, rng); + let col_name = format!("{}.{}", table.name, col.name); + if col_names.insert(col_name.clone()) { + order_by_cols.push(( + col_name, + if rng.random_bool(0.5) { + SortOrder::Asc + } else { + SortOrder::Desc + }, + )); + } + } + Some(OrderBy { + columns: order_by_cols, + }) + } else { + None + } + }; + + SelectInner { + distinctness: if env.opts().indexes { + Distinctness::arbitrary(rng) + } else { + Distinctness::All + }, + columns: vec![ResultColumn::Star], + from: Some(from), + where_clause: Predicate::arbitrary_from(rng, &join_table), + order_by, + } + } +} + +impl ArbitrarySizedFrom<&C> for SelectInner { + fn arbitrary_sized_from(rng: &mut R, env: &C, num_result_columns: usize) -> Self { + let mut select_inner = SelectInner::arbitrary_from(rng, env); + let select_from = &select_inner.from.as_ref().unwrap(); + let table_names = select_from + .joins + .iter() + .map(|j| j.table.clone()) + .chain(std::iter::once(select_from.table.clone())) + .collect::>(); + + let flat_columns_names = table_names + .iter() + .flat_map(|t| { + env.tables() + .iter() + .find(|table| table.name == *t) + .unwrap() + .columns + .iter() + .map(|c| format!("{}.{}", t.clone(), c.name)) + }) + .collect::>();
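The ORDER BY width above is drawn as `rng.random::<f64>() * rng.random::<f64>() * column_count`: the product of two independent uniform samples on [0, 1) concentrates near zero (mean 0.25, median about 0.19), so most generated ORDER BY lists stay short. A quick empirical check, assuming only rand 0.9:

```rust
use rand::Rng;

fn main() {
    let mut rng = rand::rng();
    let cols = 10usize;
    let mut histogram = vec![0u32; cols + 1];
    for _ in 0..10_000 {
        // Product of two uniforms skews toward 0, so n is usually small.
        let n = (rng.random::<f64>() * rng.random::<f64>() * cols as f64) as usize;
        histogram[n] += 1;
    }
    // Counts fall off sharply as n grows.
    println!("{histogram:?}");
}
```

+ let selected_columns =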
pick_unique(&flat_columns_names, num_result_columns, rng); + let mut columns = Vec::new(); + for column_name in selected_columns { + columns.push(ResultColumn::Column(column_name.clone())); + } + select_inner.columns = columns; + select_inner + } +} + +impl Arbitrary for Distinctness { + fn arbitrary(rng: &mut R) -> Self { + match rng.random_range(0..=5) { + 0..4 => Distinctness::All, + _ => Distinctness::Distinct, + } + } +} +impl Arbitrary for CompoundOperator { + fn arbitrary(rng: &mut R) -> Self { + match rng.random_range(0..=1) { + 0 => CompoundOperator::Union, + 1 => CompoundOperator::UnionAll, + _ => unreachable!(), + } + } +} + +/// SelectFree is a wrapper around Select that allows for arbitrary generation +/// of selects without requiring a specific environment, which is useful for generating +/// arbitrary expressions without referring to the tables. +pub struct SelectFree(pub Select); + +impl ArbitraryFrom<&C> for SelectFree { + fn arbitrary_from(rng: &mut R, env: &C) -> Self { + let expr = Predicate(Expr::arbitrary_sized_from(rng, env, 8)); + let select = Select::expr(expr); + Self(select) + } +} + +impl ArbitraryFrom<&C> for Select { + fn arbitrary_from(rng: &mut R, env: &C) -> Self { + // Generate a number of selects based on the query size + // If experimental indexes are enabled, we can have selects with compounds + // Otherwise, we just have a single select with no compounds + let num_compound_selects = if env.opts().indexes { + match rng.random_range(0..=100) { + 0..=95 => 0, + 96..=99 => 1, + 100 => 2, + _ => unreachable!(), + } + } else { + 0 + }; + + let min_column_count_across_tables = + env.tables().iter().map(|t| t.columns.len()).min().unwrap(); + + let num_result_columns = rng.random_range(1..=min_column_count_across_tables); + + let mut first = SelectInner::arbitrary_sized_from(rng, env, num_result_columns); + + let mut rest: Vec = (0..num_compound_selects) + .map(|_| SelectInner::arbitrary_sized_from(rng, env, num_result_columns)) + .collect(); + + if !rest.is_empty() { + // ORDER BY is not supported in compound selects yet + first.order_by = None; + for s in &mut rest { + s.order_by = None; + } + } + + Self { + body: SelectBody { + select: Box::new(first), + compounds: rest + .into_iter() + .map(|s| CompoundSelect { + operator: CompoundOperator::arbitrary(rng), + select: Box::new(s), + }) + .collect(), + }, + limit: None, + } + } +} + +impl ArbitraryFrom<&C> for Insert { + fn arbitrary_from(rng: &mut R, env: &C) -> Self { + let gen_values = |rng: &mut R| { + let table = pick(env.tables(), rng); + let num_rows = rng.random_range(1..10); + let values: Vec> = (0..num_rows) + .map(|_| { + table + .columns + .iter() + .map(|c| SimValue::arbitrary_from(rng, &c.column_type)) + .collect() + }) + .collect(); + Some(Insert::Values { + table: table.name.clone(), + values, + }) + }; + + let _gen_select = |rng: &mut R| { + // Find a non-empty table + let select_table = env.tables().iter().find(|t| !t.rows.is_empty())?; + let row = pick(&select_table.rows, rng); + let predicate = Predicate::arbitrary_from(rng, (select_table, row)); + // Pick another table to insert into + let select = Select::simple(select_table.name.clone(), predicate); + let table = pick(env.tables(), rng); + Some(Insert::Select { + table: table.name.clone(), + select: Box::new(select), + }) + }; + + // TODO: Add back gen_select when https://github.com/tursodatabase/turso/issues/2129 is fixed. 
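The `.unwrap()` that follows cannot panic: `backtrack` returns `None` only after every choice's retry budget has been burned by `None` results, and `gen_values` always returns `Some`. A toy illustration of that totality argument (stand-in closures, not the crate's types):

```rust
// Toy totality argument: with a single infallible choice, the first
// attempt always succeeds, so the result is always Some.
fn backtrack_one<T>(budget: usize, generate: impl Fn() -> Option<T>) -> Option<T> {
    (0..budget).find_map(|_| generate())
}

fn main() {
    let gen_values = || Some(42); // stand-in for the INSERT ... VALUES generator
    assert_eq!(backtrack_one(1, gen_values), Some(42));
}
```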
+ // Backtrack here cannot return None + backtrack(vec![(1, Box::new(gen_values))], rng).unwrap() + } +} + +impl ArbitraryFrom<&C> for Delete { + fn arbitrary_from(rng: &mut R, env: &C) -> Self { + let table = pick(env.tables(), rng); + Self { + table: table.name.clone(), + predicate: Predicate::arbitrary_from(rng, table), + } + } +} + +impl ArbitraryFrom<&C> for Drop { + fn arbitrary_from(rng: &mut R, env: &C) -> Self { + let table = pick(env.tables(), rng); + Self { + table: table.name.clone(), + } + } +} + +impl ArbitraryFrom<&C> for CreateIndex { + fn arbitrary_from(rng: &mut R, env: &C) -> Self { + assert!( + !env.tables().is_empty(), + "Cannot create an index when no tables exist in the environment." + ); + + let table = pick(env.tables(), rng); + + if table.columns.is_empty() { + panic!( + "Cannot create an index on table '{}' as it has no columns.", + table.name + ); + } + + let num_columns_to_pick = rng.random_range(1..=table.columns.len()); + let picked_column_indices = pick_n_unique(0..table.columns.len(), num_columns_to_pick, rng); + + let columns = picked_column_indices + .into_iter() + .map(|i| { + let column = &table.columns[i]; + ( + column.name.clone(), + if rng.random_bool(0.5) { + SortOrder::Asc + } else { + SortOrder::Desc + }, + ) + }) + .collect::>(); + + let index_name = format!( + "idx_{}_{}", + table.name, + gen_random_text(rng).chars().take(8).collect::() + ); + + CreateIndex { + index_name, + table_name: table.name.clone(), + columns, + } + } +} + +impl ArbitraryFrom<&C> for Update { + fn arbitrary_from(rng: &mut R, env: &C) -> Self { + let table = pick(env.tables(), rng); + let num_cols = rng.random_range(1..=table.columns.len()); + let columns = pick_unique(&table.columns, num_cols, rng); + let set_values: Vec<(String, SimValue)> = columns + .iter() + .map(|column| { + ( + column.name.clone(), + SimValue::arbitrary_from(rng, &column.column_type), + ) + }) + .collect(); + Update { + table: table.name.clone(), + set_values, + predicate: Predicate::arbitrary_from(rng, table), + } + } +} diff --git a/simulator/generation/table.rs b/sql_generation/generation/table.rs similarity index 80% rename from simulator/generation/table.rs rename to sql_generation/generation/table.rs index 66f48b5ad..d21397cbe 100644 --- a/simulator/generation/table.rs +++ b/sql_generation/generation/table.rs @@ -19,11 +19,11 @@ impl Arbitrary for Table { fn arbitrary(rng: &mut R) -> Self { let name = Name::arbitrary(rng).0; let columns = loop { - let large_table = rng.gen_bool(0.1); + let large_table = rng.random_bool(0.1); let column_size = if large_table { - rng.gen_range(64..125) // todo: make this higher (128+) + rng.random_range(64..125) // todo: make this higher (128+) } else { - rng.gen_range(1..=10) + rng.random_range(1..=10) }; let columns = (1..=column_size) .map(|_| Column::arbitrary(rng)) @@ -90,8 +90,8 @@ impl ArbitraryFrom<&Vec<&SimValue>> for SimValue { impl ArbitraryFrom<&ColumnType> for SimValue { fn arbitrary_from(rng: &mut R, column_type: &ColumnType) -> Self { let value = match column_type { - ColumnType::Integer => Value::Integer(rng.gen_range(i64::MIN..i64::MAX)), - ColumnType::Float => Value::Float(rng.gen_range(-1e10..1e10)), + ColumnType::Integer => Value::Integer(rng.random_range(i64::MIN..i64::MAX)), + ColumnType::Float => Value::Float(rng.random_range(-1e10..1e10)), ColumnType::Text => Value::build_text(gen_random_text(rng)), ColumnType::Blob => Value::Blob(gen_random_text(rng).as_bytes().to_vec()), }; @@ -99,7 +99,7 @@ impl ArbitraryFrom<&ColumnType> for SimValue 
{ } } -pub(crate) struct LTValue(pub(crate) SimValue); +pub struct LTValue(pub SimValue); impl ArbitraryFrom<&Vec<&SimValue>> for LTValue { fn arbitrary_from(rng: &mut R, values: &Vec<&SimValue>) -> Self { @@ -116,21 +116,21 @@ impl ArbitraryFrom<&Vec<&SimValue>> for LTValue { impl ArbitraryFrom<&SimValue> for LTValue { fn arbitrary_from(rng: &mut R, value: &SimValue) -> Self { let new_value = match &value.0 { - Value::Integer(i) => Value::Integer(rng.gen_range(i64::MIN..*i - 1)), - Value::Float(f) => Value::Float(f - rng.gen_range(0.0..1e10)), + Value::Integer(i) => Value::Integer(rng.random_range(i64::MIN..*i - 1)), + Value::Float(f) => Value::Float(f - rng.random_range(0.0..1e10)), value @ Value::Text(..) => { // Either shorten the string, or make at least one character smaller and mutate the rest let mut t = value.to_string(); - if rng.gen_bool(0.01) { + if rng.random_bool(0.01) { t.pop(); Value::build_text(t) } else { let mut t = t.chars().map(|c| c as u32).collect::>(); - let index = rng.gen_range(0..t.len()); + let index = rng.random_range(0..t.len()); t[index] -= 1; // Mutate the rest of the string for val in t.iter_mut().skip(index + 1) { - *val = rng.gen_range('a' as u32..='z' as u32); + *val = rng.random_range('a' as u32..='z' as u32); } let t = t .into_iter() @@ -142,15 +142,15 @@ impl ArbitraryFrom<&SimValue> for LTValue { Value::Blob(b) => { // Either shorten the blob, or make at least one byte smaller and mutate the rest let mut b = b.clone(); - if rng.gen_bool(0.01) { + if rng.random_bool(0.01) { b.pop(); Value::Blob(b) } else { - let index = rng.gen_range(0..b.len()); + let index = rng.random_range(0..b.len()); b[index] -= 1; // Mutate the rest of the blob for val in b.iter_mut().skip(index + 1) { - *val = rng.gen_range(0..=255); + *val = rng.random_range(0..=255); } Value::Blob(b) } @@ -161,7 +161,7 @@ impl ArbitraryFrom<&SimValue> for LTValue { } } -pub(crate) struct GTValue(pub(crate) SimValue); +pub struct GTValue(pub SimValue); impl ArbitraryFrom<&Vec<&SimValue>> for GTValue { fn arbitrary_from(rng: &mut R, values: &Vec<&SimValue>) -> Self { @@ -178,21 +178,21 @@ impl ArbitraryFrom<&Vec<&SimValue>> for GTValue { impl ArbitraryFrom<&SimValue> for GTValue { fn arbitrary_from(rng: &mut R, value: &SimValue) -> Self { let new_value = match &value.0 { - Value::Integer(i) => Value::Integer(rng.gen_range(*i..i64::MAX)), - Value::Float(f) => Value::Float(rng.gen_range(*f..1e10)), + Value::Integer(i) => Value::Integer(rng.random_range(*i..i64::MAX)), + Value::Float(f) => Value::Float(rng.random_range(*f..1e10)), value @ Value::Text(..) 
=> { // Either lengthen the string, or make at least one character larger and mutate the rest let mut t = value.to_string(); - if rng.gen_bool(0.01) { - t.push(rng.gen_range(0..=255) as u8 as char); + if rng.random_bool(0.01) { + t.push(rng.random_range(0..=255) as u8 as char); Value::build_text(t) } else { let mut t = t.chars().map(|c| c as u32).collect::>(); - let index = rng.gen_range(0..t.len()); + let index = rng.random_range(0..t.len()); t[index] += 1; // Mutate the rest of the string for val in t.iter_mut().skip(index + 1) { - *val = rng.gen_range('a' as u32..='z' as u32); + *val = rng.random_range('a' as u32..='z' as u32); } let t = t .into_iter() @@ -204,15 +204,15 @@ impl ArbitraryFrom<&SimValue> for GTValue { Value::Blob(b) => { // Either lengthen the blob, or make at least one byte larger and mutate the rest let mut b = b.clone(); - if rng.gen_bool(0.01) { - b.push(rng.gen_range(0..=255)); + if rng.random_bool(0.01) { + b.push(rng.random_range(0..=255)); Value::Blob(b) } else { - let index = rng.gen_range(0..b.len()); + let index = rng.random_range(0..b.len()); b[index] += 1; // Mutate the rest of the blob for val in b.iter_mut().skip(index + 1) { - *val = rng.gen_range(0..=255); + *val = rng.random_range(0..=255); } Value::Blob(b) } @@ -223,7 +223,7 @@ } } -pub(crate) struct LikeValue(pub(crate) SimValue); +pub struct LikeValue(pub SimValue); impl ArbitraryFromMaybe<&SimValue> for LikeValue { fn arbitrary_from_maybe(rng: &mut R, value: &SimValue) -> Option { @@ -235,18 +235,18 @@ // insert one `%` for the whole substring let mut i = 0; while i < t.len() { - if rng.gen_bool(0.1) { + if rng.random_bool(0.1) { t[i] = '_'; - } else if rng.gen_bool(0.05) { + } else if rng.random_bool(0.05) { t[i] = '%'; // skip a list of characters - for _ in 0..rng.gen_range(0..=3.min(t.len() - i - 1)) { + for _ in 0..rng.random_range(0..=3.min(t.len() - i - 1)) { t.remove(i + 1); } } i += 1; } - let index = rng.gen_range(0..t.len()); + let index = rng.random_range(0..t.len()); t.insert(index, '%'); Some(Self(SimValue(Value::build_text( t.into_iter().collect::(), diff --git a/sql_generation/lib.rs b/sql_generation/lib.rs new file mode 100644 index 000000000..f52cdebdf --- /dev/null +++ b/sql_generation/lib.rs @@ -0,0 +1,2 @@ +pub mod generation; +pub mod model; diff --git a/sql_generation/model/mod.rs b/sql_generation/model/mod.rs new file mode 100644 index 000000000..a29f56382 --- /dev/null +++ b/sql_generation/model/mod.rs @@ -0,0 +1,2 @@ +pub mod query; +pub mod table; diff --git a/sql_generation/model/query/create.rs b/sql_generation/model/query/create.rs new file mode 100644 index 000000000..607d5fe8d --- /dev/null +++ b/sql_generation/model/query/create.rs @@ -0,0 +1,25 @@ +use std::fmt::Display; + +use serde::{Deserialize, Serialize}; + +use crate::model::table::Table; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Create { + pub table: Table, +} + +impl Display for Create { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "CREATE TABLE {} (", self.table.name)?; + + for (i, column) in self.table.columns.iter().enumerate() { + if i != 0 { + write!(f, ",")?; + } + write!(f, "{} {}", column.name, column.column_type)?; + } + + write!(f, ")") + } +}
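`GTValue` and `LTValue` above produce values ordered against a pivot: plain range samples for numerics, lexicographic mutation for text and blobs. A self-contained sketch of the numeric arms, assuming rand 0.9 (note that, as in the original, the greater-than range starts at the pivot itself):

```rust
use rand::Rng;

// Sample on either side of a pivot, mirroring the Integer arms of
// GTValue/LTValue (the text/blob arms instead bump one character or byte
// and randomize the suffix).
fn gt_i64(rng: &mut impl Rng, pivot: i64) -> i64 {
    rng.random_range(pivot..i64::MAX) // inclusive of the pivot, as in the original
}

fn lt_i64(rng: &mut impl Rng, pivot: i64) -> i64 {
    rng.random_range(i64::MIN..pivot) // strictly below the pivot
}

fn main() {
    let mut rng = rand::rng();
    let pivot = 100;
    assert!(gt_i64(&mut rng, pivot) >= pivot);
    assert!(lt_i64(&mut rng, pivot) < pivot);
}
```

diff --git a/sql_generation/model/query/create_index.rs b/sql_generation/model/query/create_index.rs new file mode 100644 index 000000000..db9d15a04 --- /dev/null +++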
b/sql_generation/model/query/create_index.rs @@ -0,0 +1,25 @@ +use serde::{Deserialize, Serialize}; +use turso_parser::ast::SortOrder; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct CreateIndex { + pub index_name: String, + pub table_name: String, + pub columns: Vec<(String, SortOrder)>, +} + +impl std::fmt::Display for CreateIndex { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "CREATE INDEX {} ON {} ({})", + self.index_name, + self.table_name, + self.columns + .iter() + .map(|(name, order)| format!("{name} {order}")) + .collect::>() + .join(", ") + ) + } +} diff --git a/sql_generation/model/query/delete.rs b/sql_generation/model/query/delete.rs new file mode 100644 index 000000000..89ebd61b8 --- /dev/null +++ b/sql_generation/model/query/delete.rs @@ -0,0 +1,17 @@ +use std::fmt::Display; + +use serde::{Deserialize, Serialize}; + +use super::predicate::Predicate; + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct Delete { + pub table: String, + pub predicate: Predicate, +} + +impl Display for Delete { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "DELETE FROM {} WHERE {}", self.table, self.predicate) + } +} diff --git a/sql_generation/model/query/drop.rs b/sql_generation/model/query/drop.rs new file mode 100644 index 000000000..0d0ef31bb --- /dev/null +++ b/sql_generation/model/query/drop.rs @@ -0,0 +1,14 @@ +use std::fmt::Display; + +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct Drop { + pub table: String, +} + +impl Display for Drop { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "DROP TABLE {}", self.table) + } +} diff --git a/simulator/model/query/insert.rs b/sql_generation/model/query/insert.rs similarity index 52% rename from simulator/model/query/insert.rs rename to sql_generation/model/query/insert.rs index 3dc8659df..d69921388 100644 --- a/simulator/model/query/insert.rs +++ b/sql_generation/model/query/insert.rs @@ -2,12 +2,12 @@ use std::fmt::Display; use serde::{Deserialize, Serialize}; -use crate::{generation::Shadow, model::table::SimValue, runner::env::SimulatorTables}; +use crate::model::table::SimValue; use super::select::Select; #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub(crate) enum Insert { +pub enum Insert { Values { table: String, values: Vec>, @@ -18,40 +18,8 @@ pub(crate) enum Insert { }, } -impl Shadow for Insert { - type Result = anyhow::Result>>; - - fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result { - match self { - Insert::Values { table, values } => { - if let Some(t) = tables.tables.iter_mut().find(|t| &t.name == table) { - t.rows.extend(values.clone()); - } else { - return Err(anyhow::anyhow!( - "Table {} does not exist. INSERT statement ignored.", - table - )); - } - } - Insert::Select { table, select } => { - let rows = select.shadow(tables)?; - if let Some(t) = tables.tables.iter_mut().find(|t| &t.name == table) { - t.rows.extend(rows); - } else { - return Err(anyhow::anyhow!( - "Table {} does not exist. INSERT statement ignored.", - table - )); - } - } - } - - Ok(vec![]) - } -} - impl Insert { - pub(crate) fn table(&self) -> &str { + pub fn table(&self) -> &str { match self { Insert::Values { table, .. } | Insert::Select { table, .. 
} => table, } diff --git a/sql_generation/model/query/mod.rs b/sql_generation/model/query/mod.rs new file mode 100644 index 000000000..5bf0cecde --- /dev/null +++ b/sql_generation/model/query/mod.rs @@ -0,0 +1,34 @@ +pub use create::Create; +pub use create_index::CreateIndex; +pub use delete::Delete; +pub use drop::Drop; +pub use insert::Insert; +pub use select::Select; +use turso_parser::ast::fmt::ToSqlContext; + +pub mod create; +pub mod create_index; +pub mod delete; +pub mod drop; +pub mod insert; +pub mod predicate; +pub mod select; +pub mod transaction; +pub mod update; + +/// Used to print SQL strings that already carry all the context they need +pub struct EmptyContext; + +impl ToSqlContext for EmptyContext { + fn get_column_name( + &self, + _table_id: turso_parser::ast::TableInternalId, + _col_idx: usize, + ) -> String { + unreachable!() + } + + fn get_table_name(&self, _id: turso_parser::ast::TableInternalId) -> &str { + unreachable!() + } +}
diff --git a/simulator/model/query/predicate.rs b/sql_generation/model/query/predicate.rs similarity index 85% rename from simulator/model/query/predicate.rs rename to sql_generation/model/query/predicate.rs index c55c45417..29f0b966c 100644 --- a/simulator/model/query/predicate.rs +++ b/sql_generation/model/query/predicate.rs @@ -1,7 +1,7 @@ use std::fmt::Display; use serde::{Deserialize, Serialize}; -use turso_sqlite3_parser::ast::{self, fmt::ToTokens}; +use turso_parser::ast::{self, fmt::ToTokens}; use crate::model::table::{SimValue, Table, TableContext}; @@ -9,27 +9,28 @@ use crate::model::table::{SimValue, Table, TableContext}; pub struct Predicate(pub ast::Expr); impl Predicate { - pub(crate) fn true_() -> Self { + pub fn true_() -> Self { Self(ast::Expr::Literal(ast::Literal::Keyword( "TRUE".to_string(), ))) } - pub(crate) fn false_() -> Self { + pub fn false_() -> Self { Self(ast::Expr::Literal(ast::Literal::Keyword( "FALSE".to_string(), ))) } - pub(crate) fn null() -> Self { + pub fn null() -> Self { Self(ast::Expr::Literal(ast::Literal::Null)) } - pub(crate) fn not(predicate: Predicate) -> Self { + #[allow(clippy::should_implement_trait)] + pub fn not(predicate: Predicate) -> Self { let expr = ast::Expr::Unary(ast::UnaryOperator::Not, Box::new(predicate.0)); Self(expr).parens() } - pub(crate) fn and(predicates: Vec<Predicate>) -> Self { + pub fn and(predicates: Vec<Predicate>) -> Self { if predicates.is_empty() { Self::true_() } else if predicates.len() == 1 { @@ -44,7 +45,7 @@ impl Predicate { } } - pub(crate) fn or(predicates: Vec<Predicate>) -> Self { + pub fn or(predicates: Vec<Predicate>) -> Self { if predicates.is_empty() { Self::false_() } else if predicates.len() == 1 { @@ -59,26 +60,26 @@ impl Predicate { } } - pub(crate) fn eq(lhs: Predicate, rhs: Predicate) -> Self { + pub fn eq(lhs: Predicate, rhs: Predicate) -> Self { let expr = ast::Expr::Binary(Box::new(lhs.0), ast::Operator::Equals, Box::new(rhs.0)); Self(expr).parens() } - pub(crate) fn is(lhs: Predicate, rhs: Predicate) -> Self { + pub fn is(lhs: Predicate, rhs: Predicate) -> Self { let expr = ast::Expr::Binary(Box::new(lhs.0), ast::Operator::Is, Box::new(rhs.0)); Self(expr).parens() } - pub(crate) fn parens(self) -> Self { - let expr = ast::Expr::Parenthesized(vec![self.0]); + pub fn parens(self) -> Self { + let expr = ast::Expr::Parenthesized(vec![Box::new(self.0)]); Self(expr) } - pub(crate) fn eval(&self, row: &[SimValue], table: &Table) -> Option<SimValue> { + pub fn eval(&self, row: &[SimValue], table: &Table) -> Option<SimValue> { expr_to_value(&self.0, row, table) } - pub(crate) fn test<T: TableContext>(&self, row: &[SimValue], table: &T) -> bool { + pub fn test<T: TableContext>(&self, row: &[SimValue], table: &T) -> bool { let value = expr_to_value(&self.0, row, table); value.is_some_and(|value| value.as_bool()) }
diff --git a/sql_generation/model/query/select.rs b/sql_generation/model/query/select.rs new file mode 100644 index 000000000..6c34888ff --- /dev/null +++ b/sql_generation/model/query/select.rs @@ -0,0 +1,378 @@ +use std::{collections::HashSet, fmt::Display}; + +pub use ast::Distinctness; +use itertools::Itertools; +use serde::{Deserialize, Serialize}; +use turso_parser::ast::{self, fmt::ToTokens, SortOrder}; + +use crate::model::{ + query::EmptyContext, + table::{JoinTable, JoinType, JoinedTable, Table}, +}; + +use super::predicate::Predicate; + +/// `SELECT` or `RETURNING` result column +// https://sqlite.org/syntax/result-column.html +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[allow(clippy::large_enum_variant)] +pub enum ResultColumn { + /// expression + Expr(Predicate), + /// `*` + Star, + /// column name + Column(String), +} + +impl Display for ResultColumn { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ResultColumn::Expr(expr) => write!(f, "({expr})"), + ResultColumn::Star => write!(f, "*"), + ResultColumn::Column(name) => write!(f, "{name}"), + } + } +} +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct Select { + pub body: SelectBody, + pub limit: Option<usize>, +} + +impl Select { + pub fn simple(table: String, where_clause: Predicate) -> Self { + Self::single( + table, + vec![ResultColumn::Star], + where_clause, + None, + Distinctness::All, + ) + } + + pub fn expr(expr: Predicate) -> Self { + Select { + body: SelectBody { + select: Box::new(SelectInner { + distinctness: Distinctness::All, + columns: vec![ResultColumn::Expr(expr)], + from: None, + where_clause: Predicate::true_(), + order_by: None, + }), + compounds: Vec::new(), + }, + limit: None, + } + } + + pub fn single( + table: String, + result_columns: Vec<ResultColumn>, + where_clause: Predicate, + limit: Option<usize>, + distinct: Distinctness, + ) -> Self { + Select { + body: SelectBody { + select: Box::new(SelectInner { + distinctness: distinct, + columns: result_columns, + from: Some(FromClause { + table, + joins: Vec::new(), + }), + where_clause, + order_by: None, + }), + compounds: Vec::new(), + }, + limit, + } + } + + pub fn compound(left: Select, right: Select, operator: CompoundOperator) -> Self { + let mut body = left.body; + body.compounds.push(CompoundSelect { + operator, + select: Box::new(right.body.select.as_ref().clone()), + }); + Select { + body, + limit: left.limit.or(right.limit), + } + } + + pub fn dependencies(&self) -> HashSet<String> { + if self.body.select.from.is_none() { + return HashSet::new(); + } + let from = self.body.select.from.as_ref().unwrap(); + let mut tables = HashSet::new(); + tables.insert(from.table.clone()); + + tables.extend(from.dependencies()); + + for compound in &self.body.compounds { + tables.extend( + compound + .select + .from + .as_ref() + .map(|f| f.dependencies()) + .unwrap_or(vec![]) + .into_iter(), + ); + } + + tables + } +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct SelectBody { + /// first select + pub select: Box<SelectInner>, + /// compounds + pub compounds: Vec<CompoundSelect>, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct OrderBy { + pub columns: Vec<(String, SortOrder)>, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct SelectInner { + /// `DISTINCT` + pub distinctness: Distinctness, + /// columns + pub columns: Vec<ResultColumn>, + /// `FROM` clause + pub from: Option<FromClause>, + /// `WHERE` clause + pub where_clause: Predicate, + /// `ORDER BY` clause + pub order_by: Option<OrderBy>, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub enum CompoundOperator { + /// `UNION` + Union, + /// `UNION ALL` + UnionAll, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct CompoundSelect { + /// operator + pub operator: CompoundOperator, + /// select + pub select: Box<SelectInner>, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct FromClause { + /// table + pub table: String, + /// `JOIN`ed tables + pub joins: Vec<JoinedTable>, +} + +impl FromClause { + fn to_sql_ast(&self) -> ast::FromClause { + ast::FromClause { + select: Box::new(ast::SelectTable::Table( + ast::QualifiedName::single(ast::Name::new(&self.table)), + None, + None, + )), + joins: self + .joins + .iter() + .map(|join| ast::JoinedSelectTable { + operator: match join.join_type { + JoinType::Inner => ast::JoinOperator::TypedJoin(Some(ast::JoinType::INNER)), + JoinType::Left => ast::JoinOperator::TypedJoin(Some(ast::JoinType::LEFT)), + JoinType::Right => ast::JoinOperator::TypedJoin(Some(ast::JoinType::RIGHT)), + JoinType::Full => ast::JoinOperator::TypedJoin(Some(ast::JoinType::OUTER)), + JoinType::Cross => ast::JoinOperator::TypedJoin(Some(ast::JoinType::CROSS)), + }, + table: Box::new(ast::SelectTable::Table( + ast::QualifiedName::single(ast::Name::new(&join.table)), + None, + None, + )), + constraint: Some(ast::JoinConstraint::On(Box::new(join.on.0.clone()))), + }) + .collect(), + } + } + + pub fn dependencies(&self) -> Vec<String> { + let mut deps = vec![self.table.clone()]; + for join in &self.joins { + deps.push(join.table.clone()); + } + deps + } + + pub fn into_join_table(&self, tables: &[Table]) -> JoinTable { + let first_table = tables + .iter() + .find(|t| t.name == self.table) + .expect("Table not found"); + + let mut join_table = JoinTable { + tables: vec![first_table.clone()], + rows: Vec::new(), + }; + + for join in &self.joins { + let joined_table = tables + .iter() + .find(|t| t.name == join.table) + .expect("Joined table not found"); + + join_table.tables.push(joined_table.clone()); + + match join.join_type { + JoinType::Inner => { + // Inner join: prefilter right-side rows with the ON predicate, then keep + // cartesian pairs whose combined row also satisfies it + let join_rows = joined_table + .rows + .iter() + .filter(|row| join.on.test(row, joined_table)) + .cloned() + .collect::<Vec<_>>(); + // take a cartesian product of the rows + let all_row_pairs = join_table + .rows + .clone() + .into_iter() + .cartesian_product(join_rows.iter()); + + for (row1, row2) in all_row_pairs { + let row = row1.iter().chain(row2.iter()).cloned().collect::<Vec<_>>(); + + let is_in = join.on.test(&row, &join_table); + + if is_in { + join_table.rows.push(row); + } + } + } + _ => todo!(), + } + } + join_table + } +} + +impl Select { + pub fn to_sql_ast(&self) -> ast::Select { + ast::Select { + with: None, + body: ast::SelectBody { + select: ast::OneSelect::Select { + distinctness: if self.body.select.distinctness == Distinctness::Distinct { + Some(ast::Distinctness::Distinct) + } else { + None + }, + columns: self + .body + .select + .columns + .iter() + .map(|col| match col { + ResultColumn::Expr(expr) => { + ast::ResultColumn::Expr(expr.0.clone().into_boxed(), None) + } + ResultColumn::Star => ast::ResultColumn::Star, + ResultColumn::Column(name) => ast::ResultColumn::Expr( + ast::Expr::Id(ast::Name::Ident(name.clone())).into_boxed(), + None, + ), + }) + .collect(), + from: self.body.select.from.as_ref().map(|f| f.to_sql_ast()),
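+                    // The model keeps WHERE mandatory (Predicate::true_() stands in when a
+                    // query has no filter), so the AST's optional where_clause is always Some.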
where_clause: Some(self.body.select.where_clause.0.clone().into_boxed()), + group_by: None, + window_clause: Vec::new(), + }, + compounds: self + .body + .compounds + .iter() + .map(|compound| ast::CompoundSelect { + operator: match compound.operator { + CompoundOperator::Union => ast::CompoundOperator::Union, + CompoundOperator::UnionAll => ast::CompoundOperator::UnionAll, + }, + select: ast::OneSelect::Select { + distinctness: Some(compound.select.distinctness), + columns: compound + .select + .columns + .iter() + .map(|col| match col { + ResultColumn::Expr(expr) => { + ast::ResultColumn::Expr(expr.0.clone().into_boxed(), None) + } + ResultColumn::Star => ast::ResultColumn::Star, + ResultColumn::Column(name) => ast::ResultColumn::Expr( + ast::Expr::Id(ast::Name::Ident(name.clone())).into_boxed(), + None, + ), + }) + .collect(), + from: compound.select.from.as_ref().map(|f| f.to_sql_ast()), + where_clause: Some(compound.select.where_clause.0.clone().into_boxed()), + group_by: None, + window_clause: Vec::new(), + }, + }) + .collect(), + }, + order_by: self + .body + .select + .order_by + .as_ref() + .map(|o| { + o.columns + .iter() + .map(|(name, order)| ast::SortedColumn { + expr: ast::Expr::Id(ast::Name::Ident(name.clone())).into_boxed(), + order: match order { + SortOrder::Asc => Some(ast::SortOrder::Asc), + SortOrder::Desc => Some(ast::SortOrder::Desc), + }, + nulls: None, + }) + .collect() + }) + .unwrap_or_default(), + limit: self.limit.map(|l| ast::Limit { + expr: ast::Expr::Literal(ast::Literal::Numeric(l.to_string())).into_boxed(), + offset: None, + }), + } + } +} + +impl Display for Select { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.to_sql_ast().to_fmt_with_context(f, &EmptyContext {}) + } +} + +#[cfg(test)] +mod select_tests { + + #[test] + fn test_select_display() {} +} diff --git a/sql_generation/model/query/transaction.rs b/sql_generation/model/query/transaction.rs new file mode 100644 index 000000000..1114200a0 --- /dev/null +++ b/sql_generation/model/query/transaction.rs @@ -0,0 +1,32 @@ +use std::fmt::Display; + +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Begin { + pub immediate: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Commit; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Rollback; + +impl Display for Begin { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "BEGIN {}", if self.immediate { "IMMEDIATE" } else { "" }) + } +} + +impl Display for Commit { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "COMMIT") + } +} + +impl Display for Rollback { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "ROLLBACK") + } +} diff --git a/sql_generation/model/query/update.rs b/sql_generation/model/query/update.rs new file mode 100644 index 000000000..412731bbe --- /dev/null +++ b/sql_generation/model/query/update.rs @@ -0,0 +1,34 @@ +use std::fmt::Display; + +use serde::{Deserialize, Serialize}; + +use crate::model::table::SimValue; + +use super::predicate::Predicate; + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct Update { + pub table: String, + pub set_values: Vec<(String, SimValue)>, // (column, value) pairs for the SET expressions => SET name=value + pub predicate: Predicate, +} + +impl Update { + pub fn table(&self) -> &str { + &self.table + } +} + +impl Display for Update {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "UPDATE {} SET ", self.table)?; + for (i, (name, value)) in self.set_values.iter().enumerate() { + if i != 0 { + write!(f, ", ")?; + } + write!(f, "{name} = {value}")?; + } + write!(f, " WHERE {}", self.predicate)?; + Ok(()) + } +}
diff --git a/simulator/model/table.rs b/sql_generation/model/table.rs similarity index 93% rename from simulator/model/table.rs rename to sql_generation/model/table.rs index b69a197d0..87057b42b 100644 --- a/simulator/model/table.rs +++ b/sql_generation/model/table.rs @@ -2,11 +2,11 @@ use std::{fmt::Display, hash::Hash, ops::Deref}; use serde::{Deserialize, Serialize}; use turso_core::{numeric::Numeric, types}; -use turso_sqlite3_parser::ast; +use turso_parser::ast; use crate::model::query::predicate::Predicate; -pub(crate) struct Name(pub(crate) String); +pub struct Name(pub String); impl Deref for Name { type Target = str; @@ -41,11 +41,11 @@ impl TableContext for Table { } #[derive(Debug, Clone, Serialize, Deserialize)] -pub(crate) struct Table { - pub(crate) name: String, - pub(crate) columns: Vec<Column>, - pub(crate) rows: Vec<Vec<SimValue>>, - pub(crate) indexes: Vec<String>, +pub struct Table { + pub name: String, + pub columns: Vec<Column>, + pub rows: Vec<Vec<SimValue>>, + pub indexes: Vec<String>, } impl Table { @@ -60,11 +60,11 @@ impl Table { } #[derive(Debug, Clone, Serialize, Deserialize)] -pub(crate) struct Column { - pub(crate) name: String, - pub(crate) column_type: ColumnType, - pub(crate) primary: bool, - pub(crate) unique: bool, +pub struct Column { + pub name: String, + pub column_type: ColumnType, + pub primary: bool, + pub unique: bool, } // Uniquely defined by name in this case @@ -83,7 +83,7 @@ impl PartialEq for Column { impl Eq for Column {} #[derive(Debug, Clone, Serialize, Deserialize)] -pub(crate) enum ColumnType { +pub enum ColumnType { Integer, Float, Text, @@ -136,23 +136,8 @@ pub struct JoinTable { pub rows: Vec<Vec<SimValue>>, } -fn float_to_string<S>(float: &f64, serializer: S) -> Result<S::Ok, S::Error> -where - S: serde::Serializer, -{ - serializer.serialize_str(&format!("{float}")) -} - -fn string_to_float<'de, D>(deserializer: D) -> Result<f64, D::Error> -where - D: serde::Deserializer<'de>, -{ - let s = String::deserialize(deserializer)?; - s.parse().map_err(serde::de::Error::custom) -} - #[derive(Clone, Debug, PartialEq, PartialOrd, Eq, Ord, Serialize, Deserialize)] -pub(crate) struct SimValue(pub turso_core::Value); +pub struct SimValue(pub turso_core::Value); fn to_sqlite_blob(bytes: &[u8]) -> String { format!(
diff --git a/sqlite3/include/sqlite3.h b/sqlite3/include/sqlite3.h index 21695d4f1..0d098ce81 100644 --- a/sqlite3/include/sqlite3.h +++ b/sqlite3/include/sqlite3.h @@ -76,6 +76,8 @@ int sqlite3_close(sqlite3 *db); int sqlite3_close_v2(sqlite3 *db); +const char *sqlite3_db_filename(sqlite3 *db, const char *db_name); + int sqlite3_trace_v2(sqlite3 *_db, unsigned int _mask, void (*_callback)(unsigned int, void*, void*, void*), @@ -105,6 +107,8 @@ int sqlite3_stmt_readonly(sqlite3_stmt *_stmt); int sqlite3_stmt_busy(sqlite3_stmt *_stmt); +sqlite3_stmt *sqlite3_next_stmt(sqlite3 *db, sqlite3_stmt *stmt); + int sqlite3_serialize(sqlite3 *_db, const char *_schema, void **_out, int *_out_bytes, unsigned int _flags); int sqlite3_deserialize(sqlite3 *_db, const char *_schema, const void *_in_, int _in_bytes, unsigned int _flags); @@ -153,6 +157,8 @@ int sqlite3_bind_parameter_count(sqlite3_stmt *_stmt); const char *sqlite3_bind_parameter_name(sqlite3_stmt *_stmt, int _idx); +int sqlite3_bind_parameter_index(sqlite3_stmt *_stmt, const char *_name); +
int sqlite3_bind_null(sqlite3_stmt *_stmt, int _idx); int sqlite3_bind_int64(sqlite3_stmt *_stmt, int _idx, int64_t _val);
diff --git a/sqlite3/src/lib.rs b/sqlite3/src/lib.rs index 32a2976e3..b110a3c3e 100644 --- a/sqlite3/src/lib.rs +++ b/sqlite3/src/lib.rs @@ -56,6 +56,8 @@ struct sqlite3Inner { pub(crate) malloc_failed: bool, pub(crate) e_open_state: u8, pub(crate) p_err: *mut ffi::c_void, + pub(crate) filename: CString, + pub(crate) stmt_list: *mut sqlite3_stmt, } impl sqlite3 { @@ -63,6 +65,7 @@ impl sqlite3 { io: Arc<dyn turso_core::IO>, db: Arc<turso_core::Database>, conn: Arc<turso_core::Connection>, + filename: CString, ) -> Self { let inner = sqlite3Inner { _io: io, @@ -73,6 +76,8 @@ impl sqlite3 { malloc_failed: false, e_open_state: SQLITE_STATE_OPEN, p_err: std::ptr::null_mut(), + filename, + stmt_list: std::ptr::null_mut(), }; #[allow(clippy::arc_with_non_send_sync)] let inner = Arc::new(Mutex::new(inner)); @@ -88,6 +93,7 @@ pub struct sqlite3_stmt { Option<unsafe extern "C" fn(*mut ffi::c_void)>, *mut ffi::c_void, )>, + pub(crate) next: *mut sqlite3_stmt, } impl sqlite3_stmt { @@ -96,6 +102,7 @@ impl sqlite3_stmt { db, stmt, destructors: Vec::new(), + next: std::ptr::null_mut(), } } } @@ -132,26 +139,30 @@ pub unsafe extern "C" fn sqlite3_open( if db_out.is_null() { return SQLITE_MISUSE; } - let filename = CStr::from_ptr(filename); - let filename = match filename.to_str() { + let filename_cstr = CStr::from_ptr(filename); + let filename_str = match filename_cstr.to_str() { Ok(s) => s, Err(_) => return SQLITE_MISUSE, }; - let io: Arc<dyn turso_core::IO> = match filename { + let io: Arc<dyn turso_core::IO> = match filename_str { ":memory:" => Arc::new(turso_core::MemoryIO::new()), _ => match turso_core::PlatformIO::new() { Ok(io) => Arc::new(io), Err(_) => return SQLITE_CANTOPEN, }, }; - match turso_core::Database::open_file(io.clone(), filename, false, false) { + match turso_core::Database::open_file(io.clone(), filename_str, false, false) { Ok(db) => { let conn = db.connect().unwrap(); - *db_out = Box::leak(Box::new(sqlite3::new(io, db, conn))); + let filename = match filename_str { + ":memory:" => CString::new("".to_string()).unwrap(), + _ => CString::from(filename_cstr), + }; + *db_out = Box::leak(Box::new(sqlite3::new(io, db, conn, filename))); SQLITE_OK } Err(e) => { - trace!("error opening database {}: {:?}", filename, e); + trace!("error opening database {}: {:?}", filename_str, e); SQLITE_CANTOPEN } } @@ -184,6 +195,25 @@ pub unsafe extern "C" fn sqlite3_close_v2(db: *mut sqlite3) -> ffi::c_int { sqlite3_close(db) } +#[no_mangle] +pub unsafe extern "C" fn sqlite3_db_filename( + db: *mut sqlite3, + db_name: *const ffi::c_char, +) -> *const ffi::c_char { + if db.is_null() { + return std::ptr::null(); + } + if !db_name.is_null() { + let name = CStr::from_ptr(db_name); + if name.to_bytes() != b"main" { + return std::ptr::null(); + } + } + let db = &*db; + let inner = db.inner.lock().unwrap(); + inner.filename.as_ptr() +} + #[no_mangle] pub unsafe extern "C" fn sqlite3_trace_v2( _db: *mut sqlite3, @@ -253,7 +283,12 @@ pub unsafe extern "C" fn sqlite3_prepare_v2( return SQLITE_ERROR; } }; - *out_stmt = Box::leak(Box::new(sqlite3_stmt::new(raw_db, stmt))); + let new_stmt = Box::leak(Box::new(sqlite3_stmt::new(raw_db, stmt))); + + new_stmt.next = db.stmt_list; + db.stmt_list = new_stmt; + + *out_stmt = new_stmt; SQLITE_OK } @@ -264,6 +299,25 @@ pub unsafe extern "C" fn sqlite3_finalize(stmt: *mut sqlite3_stmt) -> ffi::c_int } let stmt_ref = &mut *stmt; + if !stmt_ref.db.is_null() { + let db = &mut *stmt_ref.db; + let mut db_inner = db.inner.lock().unwrap(); + + if db_inner.stmt_list == stmt { + db_inner.stmt_list = stmt_ref.next; + } else { + let mut current =
db_inner.stmt_list; + while !current.is_null() { + let current_ref = &mut *current; + if current_ref.next == stmt { + current_ref.next = stmt_ref.next; + break; + } + current = current_ref.next; + } + } + } + for (_idx, destructor_opt, ptr) in stmt_ref.destructors.drain(..) { if let Some(destructor_fn) = destructor_opt { destructor_fn(ptr); @@ -355,6 +409,25 @@ pub unsafe extern "C" fn sqlite3_stmt_busy(_stmt: *mut sqlite3_stmt) -> ffi::c_i stub!(); } +/// Iterate over all prepared statements in the database. +#[no_mangle] +pub unsafe extern "C" fn sqlite3_next_stmt( + db: *mut sqlite3, + stmt: *mut sqlite3_stmt, +) -> *mut sqlite3_stmt { + if db.is_null() { + return std::ptr::null_mut(); + } + if stmt.is_null() { + let db = &*db; + let db = db.inner.lock().unwrap(); + db.stmt_list + } else { + let stmt = &mut *stmt; + stmt.next + } +} + #[no_mangle] pub unsafe extern "C" fn sqlite3_serialize( _db: *mut sqlite3, @@ -378,8 +451,17 @@ pub unsafe extern "C" fn sqlite3_deserialize( } #[no_mangle] -pub unsafe extern "C" fn sqlite3_get_autocommit(_db: *mut sqlite3) -> ffi::c_int { - stub!(); +pub unsafe extern "C" fn sqlite3_get_autocommit(db: *mut sqlite3) -> ffi::c_int { + if db.is_null() { + return 1; + } + let db: &mut sqlite3 = &mut *db; + let inner = db.inner.lock().unwrap(); + if inner.conn.get_auto_commit() { + 1 + } else { + 0 + } } #[no_mangle] @@ -426,13 +508,24 @@ pub unsafe extern "C" fn sqlite3_limit( } #[no_mangle] -pub unsafe extern "C" fn sqlite3_malloc64(_n: ffi::c_int) -> *mut ffi::c_void { - stub!(); +pub unsafe extern "C" fn sqlite3_malloc(n: ffi::c_int) -> *mut ffi::c_void { + sqlite3_malloc64(n) } #[no_mangle] -pub unsafe extern "C" fn sqlite3_free(_ptr: *mut ffi::c_void) { - stub!(); +pub unsafe extern "C" fn sqlite3_malloc64(n: ffi::c_int) -> *mut ffi::c_void { + if n <= 0 { + return std::ptr::null_mut(); + } + libc::malloc(n as usize) +} + +#[no_mangle] +pub unsafe extern "C" fn sqlite3_free(ptr: *mut ffi::c_void) { + if ptr.is_null() { + return; + } + libc::free(ptr); } /// Returns the error code for the most recent failed API call to connection. 
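The stmt_list head in sqlite3Inner and the next pointer threaded through each sqlite3_stmt give the compat layer SQLite's documented iteration idiom for prepared statements. A minimal sketch of the intended call pattern (illustrative only, not code from this PR; it assumes the same extern declarations the compat tests below bring into scope). The walk restarts from the head after every finalize, because sqlite3_finalize() unlinks the statement and would invalidate a saved cursor:

    // Drain every outstanding prepared statement before sqlite3_close().
    unsafe fn finalize_all(db: *mut sqlite3) {
        loop {
            // A null `stmt` asks for the current head of the list.
            let stmt = sqlite3_next_stmt(db, std::ptr::null_mut());
            if stmt.is_null() {
                break;
            }
            // Finalizing unlinks `stmt`, so the next iteration re-reads the new head.
            sqlite3_finalize(stmt);
        }
    }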
@@ -529,6 +622,28 @@ pub unsafe extern "C" fn sqlite3_bind_parameter_name( } } +#[no_mangle] +pub unsafe extern "C" fn sqlite3_bind_parameter_index( + stmt: *mut sqlite3_stmt, + name: *const ffi::c_char, +) -> ffi::c_int { + if stmt.is_null() || name.is_null() { + return 0; + } + + let stmt = &*stmt; + let name_str = match CStr::from_ptr(name).to_str() { + Ok(s) => s, + Err(_) => return 0, + }; + + if let Some(index) = stmt.stmt.parameter_index(name_str) { + index.get() as ffi::c_int + } else { + 0 + } +} + #[no_mangle] pub unsafe extern "C" fn sqlite3_bind_null(stmt: *mut sqlite3_stmt, idx: ffi::c_int) -> ffi::c_int { if stmt.is_null() { @@ -702,6 +817,18 @@ pub unsafe extern "C" fn sqlite3_bind_blob( SQLITE_OK } +#[no_mangle] +pub unsafe extern "C" fn sqlite3_clear_bindings(stmt: *mut sqlite3_stmt) -> ffi::c_int { + if stmt.is_null() { + return SQLITE_MISUSE; + } + + let stmt_ref = &mut *stmt; + stmt_ref.stmt.clear_bindings(); + + SQLITE_OK +} + #[no_mangle] pub unsafe extern "C" fn sqlite3_column_type( stmt: *mut sqlite3_stmt, diff --git a/sqlite3/tests/compat/mod.rs b/sqlite3/tests/compat/mod.rs index 2192a89f9..52ed3d8fa 100644 --- a/sqlite3/tests/compat/mod.rs +++ b/sqlite3/tests/compat/mod.rs @@ -19,6 +19,7 @@ extern "C" { fn sqlite3_libversion_number() -> i32; fn sqlite3_close(db: *mut sqlite3) -> i32; fn sqlite3_open(filename: *const libc::c_char, db: *mut *mut sqlite3) -> i32; + fn sqlite3_db_filename(db: *mut sqlite3, db_name: *const libc::c_char) -> *const libc::c_char; fn sqlite3_prepare_v2( db: *mut sqlite3, sql: *const libc::c_char, @@ -27,6 +28,7 @@ extern "C" { tail: *mut *const libc::c_char, ) -> i32; fn sqlite3_step(stmt: *mut sqlite3_stmt) -> i32; + fn sqlite3_reset(stmt: *mut sqlite3_stmt) -> i32; fn sqlite3_finalize(stmt: *mut sqlite3_stmt) -> i32; fn sqlite3_wal_checkpoint(db: *mut sqlite3, db_name: *const libc::c_char) -> i32; fn sqlite3_wal_checkpoint_v2( @@ -46,9 +48,12 @@ extern "C" { ) -> i32; fn libsql_wal_disable_checkpoint(db: *mut sqlite3) -> i32; fn sqlite3_column_int(stmt: *mut sqlite3_stmt, idx: i32) -> i64; + fn sqlite3_next_stmt(db: *mut sqlite3, stmt: *mut sqlite3_stmt) -> *mut sqlite3_stmt; fn sqlite3_bind_int(stmt: *mut sqlite3_stmt, idx: i32, val: i64) -> i32; fn sqlite3_bind_parameter_count(stmt: *mut sqlite3_stmt) -> i32; fn sqlite3_bind_parameter_name(stmt: *mut sqlite3_stmt, idx: i32) -> *const libc::c_char; + fn sqlite3_bind_parameter_index(stmt: *mut sqlite3_stmt, name: *const libc::c_char) -> i32; + fn sqlite3_clear_bindings(stmt: *mut sqlite3_stmt) -> i32; fn sqlite3_column_name(stmt: *mut sqlite3_stmt, idx: i32) -> *const libc::c_char; fn sqlite3_last_insert_rowid(db: *mut sqlite3) -> i32; fn sqlite3_column_count(stmt: *mut sqlite3_stmt) -> i32; @@ -71,6 +76,7 @@ extern "C" { fn sqlite3_column_blob(stmt: *mut sqlite3_stmt, idx: i32) -> *const libc::c_void; fn sqlite3_column_type(stmt: *mut sqlite3_stmt, idx: i32) -> i32; fn sqlite3_column_decltype(stmt: *mut sqlite3_stmt, idx: i32) -> *const libc::c_char; + fn sqlite3_get_autocommit(db: *mut sqlite3) -> i32; } const SQLITE_OK: i32 = 0; @@ -986,6 +992,85 @@ mod tests { } } + #[test] + fn test_get_autocommit() { + unsafe { + let temp_file = tempfile::NamedTempFile::with_suffix(".db").unwrap(); + let path = std::ffi::CString::new(temp_file.path().to_str().unwrap()).unwrap(); + let mut db = ptr::null_mut(); + assert_eq!(sqlite3_open(path.as_ptr(), &mut db), SQLITE_OK); + + // Should be in autocommit mode by default + assert_eq!(sqlite3_get_autocommit(db), 1); + + // Begin a transaction + let 
mut stmt = ptr::null_mut(); + assert_eq!( + sqlite3_prepare_v2(db, c"BEGIN".as_ptr(), -1, &mut stmt, ptr::null_mut()), + SQLITE_OK + ); + assert_eq!(sqlite3_step(stmt), SQLITE_DONE); + assert_eq!(sqlite3_finalize(stmt), SQLITE_OK); + + // Should NOT be in autocommit mode during transaction + assert_eq!(sqlite3_get_autocommit(db), 0); + + // Create a table within the transaction + let mut stmt = ptr::null_mut(); + assert_eq!( + sqlite3_prepare_v2( + db, + c"CREATE TABLE test (id INTEGER PRIMARY KEY)".as_ptr(), + -1, + &mut stmt, + ptr::null_mut() + ), + SQLITE_OK + ); + assert_eq!(sqlite3_step(stmt), SQLITE_DONE); + assert_eq!(sqlite3_finalize(stmt), SQLITE_OK); + + // Still not in autocommit mode + assert_eq!(sqlite3_get_autocommit(db), 0); + + // Commit the transaction + let mut stmt = ptr::null_mut(); + assert_eq!( + sqlite3_prepare_v2(db, c"COMMIT".as_ptr(), -1, &mut stmt, ptr::null_mut()), + SQLITE_OK + ); + assert_eq!(sqlite3_step(stmt), SQLITE_DONE); + assert_eq!(sqlite3_finalize(stmt), SQLITE_OK); + + // Should be back in autocommit mode after commit + assert_eq!(sqlite3_get_autocommit(db), 1); + + // Test with ROLLBACK + let mut stmt = ptr::null_mut(); + assert_eq!( + sqlite3_prepare_v2(db, c"BEGIN".as_ptr(), -1, &mut stmt, ptr::null_mut()), + SQLITE_OK + ); + assert_eq!(sqlite3_step(stmt), SQLITE_DONE); + assert_eq!(sqlite3_finalize(stmt), SQLITE_OK); + + assert_eq!(sqlite3_get_autocommit(db), 0); + + let mut stmt = ptr::null_mut(); + assert_eq!( + sqlite3_prepare_v2(db, c"ROLLBACK".as_ptr(), -1, &mut stmt, ptr::null_mut()), + SQLITE_OK + ); + assert_eq!(sqlite3_step(stmt), SQLITE_DONE); + assert_eq!(sqlite3_finalize(stmt), SQLITE_OK); + + // Should be back in autocommit mode after rollback + assert_eq!(sqlite3_get_autocommit(db), 1); + + assert_eq!(sqlite3_close(db), SQLITE_OK); + } + } + #[test] fn test_wal_checkpoint() { let temp_file = tempfile::NamedTempFile::with_suffix(".db").unwrap(); @@ -1095,4 +1180,221 @@ mod tests { } } } + + #[test] + fn test_sqlite3_clear_bindings() { + unsafe { + let mut db: *mut sqlite3 = ptr::null_mut(); + let mut stmt: *mut sqlite3_stmt = ptr::null_mut(); + + assert_eq!(sqlite3_open(c":memory:".as_ptr(), &mut db), SQLITE_OK); + + assert_eq!( + sqlite3_prepare_v2( + db, + c"CREATE TABLE person (id INTEGER, name TEXT, age INTEGER)".as_ptr(), + -1, + &mut stmt, + ptr::null_mut() + ), + SQLITE_OK + ); + assert_eq!(sqlite3_step(stmt), SQLITE_DONE); + assert_eq!(sqlite3_finalize(stmt), SQLITE_OK); + + assert_eq!( + sqlite3_prepare_v2( + db, + c"INSERT INTO person (id, name, age) VALUES (1, 'John', 25), (2, 'Jane', 30)" + .as_ptr(), + -1, + &mut stmt, + ptr::null_mut() + ), + SQLITE_OK + ); + assert_eq!(sqlite3_step(stmt), SQLITE_DONE); + assert_eq!(sqlite3_finalize(stmt), SQLITE_OK); + + assert_eq!( + sqlite3_prepare_v2( + db, + c"SELECT * FROM person WHERE id = ? 
AND age > ?".as_ptr(), + -1, + &mut stmt, + ptr::null_mut() + ), + SQLITE_OK + ); + + // Bind parameters - should find John (id=1, age=25 > 20) + assert_eq!(sqlite3_bind_int(stmt, 1, 1), SQLITE_OK); + assert_eq!(sqlite3_bind_int(stmt, 2, 20), SQLITE_OK); + assert_eq!(sqlite3_step(stmt), SQLITE_ROW); + assert_eq!(sqlite3_column_int(stmt, 0), 1); + assert_eq!(sqlite3_column_int(stmt, 2), 25); + + // Reset and clear bindings, query should return no rows + assert_eq!(sqlite3_reset(stmt), SQLITE_OK); + assert_eq!(sqlite3_clear_bindings(stmt), SQLITE_OK); + assert_eq!(sqlite3_step(stmt), SQLITE_DONE); + + assert_eq!(sqlite3_finalize(stmt), SQLITE_OK); + assert_eq!(sqlite3_close(db), SQLITE_OK); + } + } + + #[test] + fn test_sqlite3_bind_parameter_index() { + const SQLITE_OK: i32 = 0; + + unsafe { + let mut db: *mut sqlite3 = ptr::null_mut(); + let mut stmt: *mut sqlite3_stmt = ptr::null_mut(); + + assert_eq!(sqlite3_open(c":memory:".as_ptr(), &mut db), SQLITE_OK); + + assert_eq!( + sqlite3_prepare_v2( + db, + c"SELECT * FROM sqlite_master WHERE name = :table_name AND type = :object_type" + .as_ptr(), + -1, + &mut stmt, + ptr::null_mut() + ), + SQLITE_OK + ); + + let index1 = sqlite3_bind_parameter_index(stmt, c":table_name".as_ptr()); + assert_eq!(index1, 1); + + let index2 = sqlite3_bind_parameter_index(stmt, c":object_type".as_ptr()); + assert_eq!(index2, 2); + + let index3 = sqlite3_bind_parameter_index(stmt, c":nonexistent".as_ptr()); + assert_eq!(index3, 0); + + let index4 = sqlite3_bind_parameter_index(stmt, ptr::null()); + assert_eq!(index4, 0); + + assert_eq!(sqlite3_finalize(stmt), SQLITE_OK); + } + } + + #[test] + fn test_sqlite3_db_filename() { + const SQLITE_OK: i32 = 0; + + unsafe { + // Test with in-memory database + let mut db: *mut sqlite3 = ptr::null_mut(); + assert_eq!(sqlite3_open(c":memory:".as_ptr(), &mut db), SQLITE_OK); + let filename = sqlite3_db_filename(db, c"main".as_ptr()); + assert!(!filename.is_null()); + let filename_str = std::ffi::CStr::from_ptr(filename).to_str().unwrap(); + assert_eq!(filename_str, ""); + assert_eq!(sqlite3_close(db), SQLITE_OK); + + // Open a file-backed database + let temp_file = tempfile::NamedTempFile::with_suffix(".db").unwrap(); + let path = std::ffi::CString::new(temp_file.path().to_str().unwrap()).unwrap(); + let mut db = ptr::null_mut(); + assert_eq!(sqlite3_open(path.as_ptr(), &mut db), SQLITE_OK); + + // Test with "main" database name + let filename = sqlite3_db_filename(db, c"main".as_ptr()); + assert!(!filename.is_null()); + let filename_str = std::ffi::CStr::from_ptr(filename).to_str().unwrap(); + assert_eq!(filename_str, temp_file.path().to_str().unwrap()); + + // Test with NULL database name (defaults to main) + let filename_default = sqlite3_db_filename(db, ptr::null()); + assert!(!filename_default.is_null()); + assert_eq!(filename, filename_default); + + // Test with non-existent database name + let filename = sqlite3_db_filename(db, c"temp".as_ptr()); + assert!(filename.is_null()); + + assert_eq!(sqlite3_close(db), SQLITE_OK); + } + } + + #[test] + fn test_sqlite3_next_stmt() { + const SQLITE_OK: i32 = 0; + + unsafe { + let mut db: *mut sqlite3 = ptr::null_mut(); + assert_eq!(sqlite3_open(c":memory:".as_ptr(), &mut db), SQLITE_OK); + + // Initially, there should be no prepared statements + let iter = sqlite3_next_stmt(db, ptr::null_mut()); + assert!(iter.is_null()); + + // Prepare first statement + let mut stmt1: *mut sqlite3_stmt = ptr::null_mut(); + assert_eq!( + sqlite3_prepare_v2(db, c"SELECT 1;".as_ptr(), -1, &mut stmt1, 
ptr::null_mut()), + SQLITE_OK + ); + assert!(!stmt1.is_null()); + + // Now there should be one statement + let iter = sqlite3_next_stmt(db, ptr::null_mut()); + assert_eq!(iter, stmt1); + + // And no more after that + let iter = sqlite3_next_stmt(db, stmt1); + assert!(iter.is_null()); + + // Prepare second statement + let mut stmt2: *mut sqlite3_stmt = ptr::null_mut(); + assert_eq!( + sqlite3_prepare_v2(db, c"SELECT 2;".as_ptr(), -1, &mut stmt2, ptr::null_mut()), + SQLITE_OK + ); + assert!(!stmt2.is_null()); + + // Prepare third statement + let mut stmt3: *mut sqlite3_stmt = ptr::null_mut(); + assert_eq!( + sqlite3_prepare_v2(db, c"SELECT 3;".as_ptr(), -1, &mut stmt3, ptr::null_mut()), + SQLITE_OK + ); + assert!(!stmt3.is_null()); + + // Count all statements + let mut count = 0; + let mut iter = sqlite3_next_stmt(db, ptr::null_mut()); + while !iter.is_null() { + count += 1; + iter = sqlite3_next_stmt(db, iter); + } + assert_eq!(count, 3); + + // Finalize the middle statement + assert_eq!(sqlite3_finalize(stmt2), SQLITE_OK); + + // Count should now be 2 + count = 0; + iter = sqlite3_next_stmt(db, ptr::null_mut()); + while !iter.is_null() { + count += 1; + iter = sqlite3_next_stmt(db, iter); + } + assert_eq!(count, 2); + + // Finalize remaining statements + assert_eq!(sqlite3_finalize(stmt1), SQLITE_OK); + assert_eq!(sqlite3_finalize(stmt3), SQLITE_OK); + + // Should be no statements left + let iter = sqlite3_next_stmt(db, ptr::null_mut()); + assert!(iter.is_null()); + + assert_eq!(sqlite3_close(db), SQLITE_OK); + } + } } diff --git a/sqlite3/tests/sqlite3_tests.c b/sqlite3/tests/sqlite3_tests.c index 2cd490a93..9fc1ffc49 100644 --- a/sqlite3/tests/sqlite3_tests.c +++ b/sqlite3/tests/sqlite3_tests.c @@ -18,6 +18,7 @@ void test_sqlite3_bind_text2(); void test_sqlite3_bind_blob(); void test_sqlite3_column_type(); void test_sqlite3_column_decltype(); +void test_sqlite3_next_stmt(); int allocated = 0; @@ -35,6 +36,7 @@ int main(void) test_sqlite3_bind_blob(); test_sqlite3_column_type(); test_sqlite3_column_decltype(); + test_sqlite3_next_stmt(); return 0; } diff --git a/testing/analyze.test b/testing/analyze.test index a1761bf45..7d7f95066 100755 --- a/testing/analyze.test +++ b/testing/analyze.test @@ -5,28 +5,26 @@ source $testdir/tester.tcl # Things that do work: do_execsql_test_on_specific_db {:memory:} empty-table { - CREATE TABLE sqlite_stat1(tbl,idx,stat); CREATE TABLE temp (a integer); ANALYZE temp; SELECT * FROM sqlite_stat1; } {} do_execsql_test_on_specific_db {:memory:} one-row-table { - CREATE TABLE sqlite_stat1(tbl,idx,stat); CREATE TABLE temp (a integer); INSERT INTO temp VALUES (1); ANALYZE temp; SELECT * FROM sqlite_stat1; } {temp||1} -do_execsql_test_on_specific_db {:memory:} analyze-deletes { - CREATE TABLE sqlite_stat1(tbl,idx,stat); - INSERT INTO sqlite_stat1 VALUES ('temp', NULL, 10); +do_execsql_test_on_specific_db {:memory:} analyze-overwrites { CREATE TABLE temp (a integer); INSERT INTO temp VALUES (1); ANALYZE temp; + INSERT INTO temp VALUES (2); + ANALYZE temp; SELECT * FROM sqlite_stat1; -} {temp||1} +} {temp||2} # Things that don't work: @@ -38,25 +36,17 @@ do_execsql_test_in_memory_error analyze-one-database-fails { ANALYZE main; } {.*ANALYZE.*not supported.*} -do_execsql_test_in_memory_error analyze-without-stat-table-fails { - CREATE TABLE temp (a integer); - ANALYZE temp; -} {.*ANALYZE.*not supported.*} - do_execsql_test_in_memory_error analyze-table-with-pk-fails { - CREATE TABLE sqlite_stat1(tbl,idx,stat); CREATE TABLE temp (a integer primary key); ANALYZE temp; 
} {.*ANALYZE.*not supported.*} do_execsql_test_in_memory_error analyze-table-without-rowid-fails { - CREATE TABLE sqlite_stat1(tbl,idx,stat); CREATE TABLE temp (a integer primary key) WITHOUT ROWID; ANALYZE temp; } {.*ANALYZE.*not supported.*} do_execsql_test_in_memory_error analyze-index-fails { - CREATE TABLE sqlite_stat1(tbl,idx,stat); CREATE TABLE temp (a integer, b integer); CREATE INDEX temp_b ON temp (b); ANALYZE temp_b; diff --git a/testing/join.test b/testing/join.test index db25128ec..853ccc875 100755 --- a/testing/join.test +++ b/testing/join.test @@ -302,3 +302,12 @@ do_execsql_test left-join-backwards-iteration { } {12|Alan| 11|Travis|accessories 10|Daniel|coat} + +# regression test for issue 2794: not nulling out rowid properly when left join does not match +do_execsql_test_on_specific_db {:memory:} min-null-regression-test { + create table t (x integer primary key, y); + create table u (x integer primary key, y); + insert into t values (1,1),(2,2); + insert into u values (1,1),(3,3); + select count(u.x) from t left join u using(y); +} {1} \ No newline at end of file diff --git a/testing/materialized_views.test b/testing/materialized_views.test index 36cf63e52..2b6a56be3 100755 --- a/testing/materialized_views.test +++ b/testing/materialized_views.test @@ -367,4 +367,21 @@ do_execsql_test_on_specific_db {:memory:} matview-mixed-operations-sequence { 200|100|1 100|25|1 200|100|1 -300|150|1} \ No newline at end of file +300|150|1} + +do_execsql_test_on_specific_db {:memory:} matview-projections { + CREATE TABLE t(a,b); + + CREATE MATERIALIZED VIEW v AS + SELECT b, a, b + a as c , (b * a) + 10 as d , min(a,b) as e + FROM t + where b > 2; + + INSERT INTO t VALUES (1, 1); + INSERT INTO t VALUES (2, 2); + INSERT INTO t VALUES (3, 4); + INSERT INTO t VALUES (4, 3); + + SELECT * from v; +} {4|3|7|22|3 +3|4|7|22|3} diff --git a/tests/integration/functions/test_wal_api.rs b/tests/integration/functions/test_wal_api.rs index 38537000a..069b0b084 100644 --- a/tests/integration/functions/test_wal_api.rs +++ b/tests/integration/functions/test_wal_api.rs @@ -926,3 +926,53 @@ fn test_db_share_same_file() { ]] ); } + +#[test] +fn test_wal_api_simulate_spilled_frames() { + let (mut rng, _) = rng_from_time(); + let db1 = TempDatabase::new_empty(false); + let conn1 = db1.connect_limbo(); + let db2 = TempDatabase::new_empty(false); + let conn2 = db2.connect_limbo(); + conn1 + .execute("CREATE TABLE t(x INTEGER PRIMARY KEY, y)") + .unwrap(); + conn2 + .execute("CREATE TABLE t(x INTEGER PRIMARY KEY, y)") + .unwrap(); + let watermark = conn1.wal_state().unwrap().max_frame; + for _ in 0..128 { + let key = rng.next_u32(); + let length = rng.next_u32() % 4096 + 1; + conn1 + .execute(format!( + "INSERT INTO t VALUES ({key}, randomblob({length}))" + )) + .unwrap(); + } + let mut frame = [0u8; 24 + 4096]; + conn2 + .checkpoint(CheckpointMode::Truncate { + upper_bound_inclusive: None, + }) + .unwrap(); + conn2.wal_insert_begin().unwrap(); + let frames_count = conn1.wal_state().unwrap().max_frame; + for frame_id in watermark + 1..=frames_count { + let mut info = conn1.wal_get_frame(frame_id, &mut frame).unwrap(); + info.db_size = 0; + info.put_to_frame_header(&mut frame); + conn2 + .wal_insert_frame(frame_id - watermark, &frame) + .unwrap(); + } + for _ in 0..128 { + let key = rng.next_u32(); + let length = rng.next_u32() % 4096 + 1; + conn2 + .execute(format!( + "INSERT INTO t VALUES ({key}, randomblob({length}))" + )) + .unwrap(); + } +} diff --git a/tests/integration/query_processing/encryption.rs 
b/tests/integration/query_processing/encryption.rs index c286bde7c..7fdaec29d 100644 --- a/tests/integration/query_processing/encryption.rs +++ b/tests/integration/query_processing/encryption.rs @@ -16,8 +16,9 @@ fn test_per_page_encryption() -> anyhow::Result<()> { run_query( &tmp_db, &conn, - "PRAGMA key = 'super secret key for encryption';", + "PRAGMA hexkey = 'b1bbfda4f589dc9daaf004fe21111e00dc00c98237102f5c7002a5669fc76327';", )?; + run_query(&tmp_db, &conn, "PRAGMA cipher = 'aegis256';")?; run_query( &tmp_db, &conn, @@ -49,16 +50,73 @@ fn test_per_page_encryption() -> anyhow::Result<()> { should_panic.is_err(), "should panic when accessing encrypted DB without key" ); + + // it should also panic if we specify only the cipher or only the key + let conn = tmp_db.connect_limbo(); + let should_panic = panic::catch_unwind(panic::AssertUnwindSafe(|| { + run_query(&tmp_db, &conn, "PRAGMA cipher = 'aegis256';").unwrap(); + run_query_on_row(&tmp_db, &conn, "SELECT * FROM test", |_: &Row| {}).unwrap(); + })); + assert!( + should_panic.is_err(), + "should panic when accessing encrypted DB without key" + ); + + let conn = tmp_db.connect_limbo(); + let should_panic = panic::catch_unwind(panic::AssertUnwindSafe(|| { + run_query( + &tmp_db, + &conn, + "PRAGMA hexkey = 'b1bbfda4f589dc9daaf004fe21111e00dc00c98237102f5c7002a5669fc76327';", + ).unwrap(); + run_query_on_row(&tmp_db, &conn, "SELECT * FROM test", |_: &Row| {}).unwrap(); + })); + assert!( + should_panic.is_err(), + "should panic when accessing encrypted DB without cipher name" + ); + + // it should panic if we specify the wrong cipher or the wrong key + let conn = tmp_db.connect_limbo(); + let should_panic = panic::catch_unwind(panic::AssertUnwindSafe(|| { + run_query( + &tmp_db, + &conn, + "PRAGMA hexkey = 'b1bbfda4f589dc9daaf004fe21111e00dc00c98237102f5c7002a5669fc76327';", + ).unwrap(); + run_query(&tmp_db, &conn, "PRAGMA cipher = 'aes256gcm';").unwrap(); + run_query_on_row(&tmp_db, &conn, "SELECT * FROM test", |_: &Row| {}).unwrap(); + })); + assert!( + should_panic.is_err(), + "should panic when accessing encrypted DB with incorrect cipher" + ); + + let conn = tmp_db.connect_limbo(); + let should_panic = panic::catch_unwind(panic::AssertUnwindSafe(|| { + run_query(&tmp_db, &conn, "PRAGMA cipher = 'aegis256';").unwrap(); + run_query( + &tmp_db, + &conn, + "PRAGMA hexkey = 'b1bbfda4f589dc9daaf004fe21111e00dc00c98237102f5c7002a5669fc76377';", + ).unwrap(); + run_query_on_row(&tmp_db, &conn, "SELECT * FROM test", |_: &Row| {}).unwrap(); + })); + assert!( + should_panic.is_err(), + "should panic when accessing encrypted DB with incorrect key" + ); } { // let's test the existing db with the key let existing_db = TempDatabase::new_with_existent(&db_path, false); let conn = existing_db.connect_limbo(); + run_query(&tmp_db, &conn, "PRAGMA cipher = 'aegis256';")?; run_query( &existing_db, &conn, - "PRAGMA key = 'super secret key for encryption';", + "PRAGMA hexkey = 'b1bbfda4f589dc9daaf004fe21111e00dc00c98237102f5c7002a5669fc76327';", )?; run_query_on_row(&existing_db, &conn, "SELECT * FROM test", |row: &Row| { assert_eq!(row.get::<i64>(0).unwrap(), 1); diff --git a/tests/integration/query_processing/test_write_path.rs b/tests/integration/query_processing/test_write_path.rs index e2cbfa06e..60cd6495a 100644 --- a/tests/integration/query_processing/test_write_path.rs +++ b/tests/integration/query_processing/test_write_path.rs @@ -284,6 +284,7 @@ fn test_wal_checkpoint() -> anyhow::Result<()> { let conn = tmp_db.connect_limbo(); for i in 0..iterations {
+ log::info!("iteration #{i}"); let insert_query = format!("INSERT INTO test VALUES ({i})"); do_flush(&conn, &tmp_db)?; conn.checkpoint(CheckpointMode::Passive { @@ -823,7 +824,7 @@ pub fn run_query_core( on_row(row) } } - _ => unreachable!(), + r => panic!("unexpected step result: {r:?}"), } } };
diff --git a/vendored/sqlite3-parser/src/parser/ast/fmt.rs b/vendored/sqlite3-parser/src/parser/ast/fmt.rs index 64f722421..96013bf35 100644 --- a/vendored/sqlite3-parser/src/parser/ast/fmt.rs +++ b/vendored/sqlite3-parser/src/parser/ast/fmt.rs @@ -890,6 +890,11 @@ impl ToTokens for Expr { Some(_) => s.append(TK_VARIABLE, Some(&("?".to_owned() + var))), None => s.append(TK_VARIABLE, Some("?")), }, + Self::Register(reg) => { + // This is for internal use only, not part of SQL syntax + // Use a special notation that won't conflict with SQL + s.append(TK_VARIABLE, Some(&format!("$r{reg}"))) + } } } } diff --git a/vendored/sqlite3-parser/src/parser/ast/mod.rs b/vendored/sqlite3-parser/src/parser/ast/mod.rs index de096827f..b67c51507 100644 --- a/vendored/sqlite3-parser/src/parser/ast/mod.rs +++ b/vendored/sqlite3-parser/src/parser/ast/mod.rs @@ -365,6 +365,9 @@ pub enum Expr { }, /// binary expression Binary(Box<Expr>, Operator, Box<Expr>), + /// Register reference for DBSP expression compilation + /// This is not part of SQL syntax but used internally for incremental computation + Register(usize), /// `CASE` expression Case { /// operand