diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml new file mode 100644 index 000000000..151ee791c --- /dev/null +++ b/.github/workflows/go.yml @@ -0,0 +1,43 @@ +name: Go Tests + +on: + push: + branches: + - main + tags: + - v* + pull_request: + branches: + - main + +env: + working-directory: bindings/go + +jobs: + test: + runs-on: ubuntu-latest + + defaults: + run: + working-directory: ${{ env.working-directory }} + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Install Rust(stable) + uses: dtolnay/rust-toolchain@stable + + - name: Set up go + uses: actions/setup-go@v4 + with: + go-version: "1.23" + + - name: build Go bindings library + run: cargo build --package limbo-go + + - name: run Go tests + env: + LD_LIBRARY_PATH: ${{ github.workspace }}/target/debug:$LD_LIBRARY_PATH + run: go test + diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index d6e02a69f..f7dc35258 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -15,19 +15,28 @@ env: PIP_DISABLE_PIP_VERSION_CHECK: "true" jobs: + configure-strategy: + runs-on: ubuntu-latest + outputs: + python-versions: ${{ steps.gen-matrix.outputs.python-versions }} + steps: + - id: gen-matrix + run: | + if [ ${{ github.event_name }} == "pull_request" ]; then + echo "python-versions=[\"3.13\"]" >> $GITHUB_OUTPUT + else + echo "python-versions=[\"3.9\",\"3.10\",\"3.11\",\"3.12\",\"3.13\"]" >> $GITHUB_OUTPUT + fi + test: + needs: configure-strategy strategy: matrix: os: - ubuntu-latest - macos-latest - windows-latest - python-version: - - "3.9" - - "3.10" - - "3.11" - - "3.12" - - "3.13" + python-version: ${{ fromJson(needs.configure-strategy.outputs.python-versions) }} runs-on: ${{ matrix.os }} defaults: @@ -123,7 +132,7 @@ jobs: - name: Upload wheels uses: actions/upload-artifact@v4 with: - name: linux-wheels + name: wheels-linux path: bindings/python/dist macos-x86_64: @@ -152,7 +161,7 @@ jobs: - name: Upload wheels uses: 
actions/upload-artifact@v4 with: - name: macos-x86-wheels + name: wheels-macos-x86 path: bindings/python/dist macos-arm64: @@ -181,7 +190,7 @@ jobs: - name: Upload wheels uses: actions/upload-artifact@v4 with: - name: macos-arm64-wheels + name: wheels-macos-arm64 path: bindings/python/dist sdist: @@ -200,7 +209,7 @@ jobs: - name: Upload sdist uses: actions/upload-artifact@v4 with: - name: sdist-wheels + name: wheels-sdist path: bindings/python/dist release: @@ -209,18 +218,11 @@ jobs: if: "startsWith(github.ref, 'refs/tags/')" needs: [linux, macos-arm64, macos-x86_64, sdist] steps: - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4 with: - name: linux-wheels - - uses: actions/download-artifact@v3 - with: - name: macos-x86-wheels - - uses: actions/download-artifact@v3 - with: - name: macos-arm64-wheels - - uses: actions/download-artifact@v3 - with: - name: sdist-wheels + path: bindings/python/dist + pattern: wheels-* + merge-multiple: true - name: Publish to PyPI uses: PyO3/maturin-action@v1 env: diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index ad9d1bbe7..44fb4cff3 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -14,7 +14,7 @@ jobs: - name: Close stale pull requests uses: actions/stale@v6 with: - repo-token: ${{ secrets.STALE_GH_TOKEN }} + repo-token: ${{ secrets.GH_TOKEN }} operations-per-run: 1000 ascending: true stale-pr-message: 'This pull request has been marked as stale due to inactivity. It will be closed in 7 days if no further activity occurs.' 
diff --git a/.gitignore b/.gitignore index 1f1406ceb..c7c56a7ee 100644 --- a/.gitignore +++ b/.gitignore @@ -28,4 +28,7 @@ dist/ .DS_Store # Javascript -**/node_modules/ \ No newline at end of file +**/node_modules/ + +# testing +testing/limbo_output.txt diff --git a/CHANGELOG.md b/CHANGELOG.md index 9186ec8a5..fee38e12c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,79 @@ # Changelog +## 0.0.14 - 2025-02-04 + +### Added + +**Core:** + +* Improve changes() and total_changes() functions and add tests (Ben Li) +* Add support for `json_object` function (Jorge Hermo) +* Implemented json_valid function (Harin) +* Implement Not (Vrishabh) +* Initial support for wal_checkpoint pragma (Sonny) +* Implement Or and And bytecodes (Diego Reis) +* Implement strftime function (Pedro Muniz) +* implement sqlite_source_id function (Glauber Costa) +* json_patch() function implementation (Ihor Andrianov) +* json_remove() function implementation (Ihor Andrianov) +* Implement isnull / not null for filter expressions (Glauber Costa) +* Add support for offset in select queries (Ben Li) +* Support returning column names from prepared statement (Preston Thorpe) +* Implement Concat opcode (Harin) +* Table info (Glauber Costa) +* Pragma list (Glauber Costa) +* Implement Noop bytecode (Pedro Muniz) +* implement is and is not where constraints (Glauber Costa) +* Pagecount (Glauber Costa) +* Support column aliases in GROUP BY, ORDER BY and HAVING (Jussi Saurio) +* Implement json_pretty (Pedro Muniz) + +**Extensions:** + +* Initial pass on vector extension (Pekka Enberg) +* Enable static linking for 'built-in' extensions (Preston Thorpe) + +**Go Bindings:** + +* Initial support for Go database/sql driver (Preston Thorpe) +* Avoid potentially expensive operations on prepare' (Glauber Costa) + +**Java Bindings:** + +* Implement JDBC `ResultSet` (Kim Seon Woo) +* Implement LimboConnection `close()` (Kim Seon Woo) +* Implement close() for `LimboStatement` and `LimboResultSet` (Kim Seon Woo) 
+* Implement methods in `JDBC4ResultSet` (Kim Seon Woo) +* Load native library from Jar (Kim Seon Woo) +* Change logger dependency (Kim Seon Woo) +* Log driver loading error (Pekka Enberg) + +**Simulator:** + +* Implement `--load` and `--watch` flags (Alperen Keleş) + +**Build system and CI:** + +* Add Nyrkiö change point detection to 'cargo bench' workflow (Henrik Ingo) + +### Fixed + +* Fix `select X'1';` causes limbo to go in infinite loop (Krishna Vishal) +* Fix rowid search codegen (Nikita Sivukhin) +* Fix logical codegen (Nikita Sivukhin) +* Fix parser panic when duplicate column names are given to `CREATE TABLE` (Krishna Vishal) +* Fix panic when double quoted strings are used for column names. (Krishna Vishal) +* Fix `SELECT -9223372036854775808` result differs from SQLite (Krishna Vishal) +* Fix `SELECT ABS(-9223372036854775808)` causes limbo to panic. (Krishna Vishal) +* Fix memory leaks, make extension types more efficient (Preston Thorpe) +* Fix table with single column PRIMARY KEY to not create extra btree (Krishna Vishal) +* Fix null cmp codegen (Nikita Sivukhin) +* Fix null expr codegen (Nikita Sivukhin) +* Fix rowid generation (Nikita Sivukhin) +* Fix shr instruction (Nikita Sivukhin) +* Fix strftime function compatibility problems (Pedro Muniz) +* Dont fsync the WAL on read queries (Jussi Saurio) + ## 0.0.13 - 2025-01-19 ### Added diff --git a/COMPAT.md b/COMPAT.md index d7c8b0a2c..1c2b9b227 100644 --- a/COMPAT.md +++ b/COMPAT.md @@ -19,9 +19,12 @@ This document describes the compatibility of Limbo with SQLite. 
- [JSON functions](#json-functions) - [SQLite C API](#sqlite-c-api) - [SQLite VDBE opcodes](#sqlite-vdbe-opcodes) + - [SQLite journaling modes](#sqlite-journaling-modes) - [Extensions](#extensions) - [UUID](#uuid) - [regexp](#regexp) + - [Vector](#vector) + - [Time](#time) ## Features @@ -38,229 +41,229 @@ The current status of Limbo is: ### Statements -| Statement | Status | Comment | -| ------------------------- | ------- | ------- | -| ALTER TABLE | No | | -| ANALYZE | No | | -| ATTACH DATABASE | No | | -| BEGIN TRANSACTION | No | | -| COMMIT TRANSACTION | No | | -| CREATE INDEX | No | | -| CREATE TABLE | Partial | | -| CREATE TRIGGER | No | | -| CREATE VIEW | No | | -| CREATE VIRTUAL TABLE | No | | -| DELETE | No | | -| DETACH DATABASE | No | | -| DROP INDEX | No | | -| DROP TABLE | No | | -| DROP TRIGGER | No | | -| DROP VIEW | No | | -| END TRANSACTION | No | | -| EXPLAIN | Yes | | -| INDEXED BY | No | | -| INSERT | Partial | | -| ON CONFLICT clause | No | | -| REINDEX | No | | -| RELEASE SAVEPOINT | No | | -| REPLACE | No | | -| RETURNING clause | No | | -| ROLLBACK TRANSACTION | No | | -| SAVEPOINT | No | | -| SELECT | Yes | | -| SELECT ... WHERE | Yes | | -| SELECT ... WHERE ... LIKE | Yes | | -| SELECT ... LIMIT | Yes | | -| SELECT ... ORDER BY | Yes | | -| SELECT ... GROUP BY | Yes | | -| SELECT ... HAVING | Yes | | -| SELECT ... 
JOIN | Yes | | +| Statement | Status | Comment | +|---------------------------|---------|-----------------------------------------------------------------------------------| +| ALTER TABLE | No | | +| ANALYZE | No | | +| ATTACH DATABASE | No | | +| BEGIN TRANSACTION | No | | +| COMMIT TRANSACTION | No | | +| CREATE INDEX | No | | +| CREATE TABLE | Partial | | +| CREATE TRIGGER | No | | +| CREATE VIEW | No | | +| CREATE VIRTUAL TABLE | No | | +| DELETE | No | | +| DETACH DATABASE | No | | +| DROP INDEX | No | | +| DROP TABLE | No | | +| DROP TRIGGER | No | | +| DROP VIEW | No | | +| END TRANSACTION | No | | +| EXPLAIN | Yes | | +| INDEXED BY | No | | +| INSERT | Partial | | +| ON CONFLICT clause | No | | +| REINDEX | No | | +| RELEASE SAVEPOINT | No | | +| REPLACE | No | | +| RETURNING clause | No | | +| ROLLBACK TRANSACTION | No | | +| SAVEPOINT | No | | +| SELECT | Yes | | +| SELECT ... WHERE | Yes | | +| SELECT ... WHERE ... LIKE | Yes | | +| SELECT ... LIMIT | Yes | | +| SELECT ... ORDER BY | Yes | | +| SELECT ... GROUP BY | Yes | | +| SELECT ... HAVING | Yes | | +| SELECT ... JOIN | Yes | | | SELECT ... CROSS JOIN | Yes | SQLite CROSS JOIN means "do not reorder joins". We don't support that yet anyway. | -| SELECT ... INNER JOIN | Yes | | -| SELECT ... OUTER JOIN | Partial | no RIGHT JOIN | -| SELECT ... JOIN USING | Yes | | -| SELECT ... NATURAL JOIN | Yes | | -| UPDATE | No | | -| UPSERT | No | | -| VACUUM | No | | -| WITH clause | No | | +| SELECT ... INNER JOIN | Yes | | +| SELECT ... OUTER JOIN | Partial | no RIGHT JOIN | +| SELECT ... JOIN USING | Yes | | +| SELECT ... 
NATURAL JOIN | Yes | | +| UPDATE | No | | +| UPSERT | No | | +| VACUUM | No | | +| WITH clause | No | | #### [PRAGMA](https://www.sqlite.org/pragma.html) -| Statement | Status | Comment | -|----------------------------------|------------|-------------------------------------------------| -| PRAGMA analysis_limit | No | | -| PRAGMA application_id | No | | -| PRAGMA auto_vacuum | No | | -| PRAGMA automatic_index | No | | -| PRAGMA busy_timeout | No | | -| PRAGMA busy_timeout | No | | -| PRAGMA cache_size | Yes | | -| PRAGMA cache_spill | No | | -| PRAGMA case_sensitive_like | Not Needed | deprecated in SQLite | -| PRAGMA cell_size_check | No | | -| PRAGMA checkpoint_fullsync | No | | -| PRAGMA collation_list | No | | -| PRAGMA compile_options | No | | -| PRAGMA count_changes | Not Needed | deprecated in SQLite | -| PRAGMA data_store_directory | Not Needed | deprecated in SQLite | -| PRAGMA data_version | No | | -| PRAGMA database_list | No | | -| PRAGMA default_cache_size | Not Needed | deprecated in SQLite | -| PRAGMA defer_foreign_keys | No | | -| PRAGMA empty_result_callbacks | Not Needed | deprecated in SQLite | -| PRAGMA encoding | No | | -| PRAGMA foreign_key_check | No | | -| PRAGMA foreign_key_list | No | | -| PRAGMA foreign_keys | No | | -| PRAGMA freelist_count | No | | -| PRAGMA full_column_names | Not Needed | deprecated in SQLite | -| PRAGMA fullsync | No | | -| PRAGMA function_list | No | | -| PRAGMA hard_heap_limit | No | | -| PRAGMA ignore_check_constraints | No | | -| PRAGMA incremental_vacuum | No | | -| PRAGMA index_info | No | | -| PRAGMA index_list | No | | -| PRAGMA index_xinfo | No | | -| PRAGMA integrity_check | No | | -| PRAGMA journal_mode | No | | -| PRAGMA journal_size_limit | No | | -| PRAGMA legacy_alter_table | No | | -| PRAGMA legacy_file_format | No | | -| PRAGMA locking_mode | No | | -| PRAGMA max_page_count | No | | -| PRAGMA mmap_size | No | | -| PRAGMA module_list | No | | -| PRAGMA optimize | No | | -| PRAGMA page_count | No | | 
-| PRAGMA page_size | No | | -| PRAGMA parser_trace | No | | -| PRAGMA pragma_list | No | | -| PRAGMA query_only | No | | -| PRAGMA quick_check | No | | -| PRAGMA read_uncommitted | No | | -| PRAGMA recursive_triggers | No | | -| PRAGMA reverse_unordered_selects | No | | -| PRAGMA schema_version | No | | -| PRAGMA secure_delete | No | | -| PRAGMA short_column_names | Not Needed | deprecated in SQLite | -| PRAGMA shrink_memory | No | | -| PRAGMA soft_heap_limit | No | | -| PRAGMA stats | No | Used for testing in SQLite | -| PRAGMA synchronous | No | | -| PRAGMA table_info | No | | -| PRAGMA table_list | No | | -| PRAGMA table_xinfo | No | | -| PRAGMA temp_store | No | | -| PRAGMA temp_store_directory | Not Needed | deprecated in SQLite | -| PRAGMA threads | No | | -| PRAGMA trusted_schema | No | | -| PRAGMA user_version | No | | -| PRAGMA vdbe_addoptrace | No | | -| PRAGMA vdbe_debug | No | | -| PRAGMA vdbe_listing | No | | -| PRAGMA vdbe_trace | No | | -| PRAGMA wal_autocheckpoint | No | | -| PRAGMA wal_checkpoint | Partial | Not supported calling with param (pragma-value) | -| PRAGMA writable_schema | No | | +| Statement | Status | Comment | +|----------------------------------|------------|----------------------------------------------| +| PRAGMA analysis_limit | No | | +| PRAGMA application_id | No | | +| PRAGMA auto_vacuum | No | | +| PRAGMA automatic_index | No | | +| PRAGMA busy_timeout | No | | +| PRAGMA busy_timeout | No | | +| PRAGMA cache_size | Yes | | +| PRAGMA cache_spill | No | | +| PRAGMA case_sensitive_like | Not Needed | deprecated in SQLite | +| PRAGMA cell_size_check | No | | +| PRAGMA checkpoint_fullsync | No | | +| PRAGMA collation_list | No | | +| PRAGMA compile_options | No | | +| PRAGMA count_changes | Not Needed | deprecated in SQLite | +| PRAGMA data_store_directory | Not Needed | deprecated in SQLite | +| PRAGMA data_version | No | | +| PRAGMA database_list | No | | +| PRAGMA default_cache_size | Not Needed | deprecated in SQLite | +| 
PRAGMA defer_foreign_keys | No | | +| PRAGMA empty_result_callbacks | Not Needed | deprecated in SQLite | +| PRAGMA encoding | No | | +| PRAGMA foreign_key_check | No | | +| PRAGMA foreign_key_list | No | | +| PRAGMA foreign_keys | No | | +| PRAGMA freelist_count | No | | +| PRAGMA full_column_names | Not Needed | deprecated in SQLite | +| PRAGMA fullsync | No | | +| PRAGMA function_list | No | | +| PRAGMA hard_heap_limit | No | | +| PRAGMA ignore_check_constraints | No | | +| PRAGMA incremental_vacuum | No | | +| PRAGMA index_info | No | | +| PRAGMA index_list | No | | +| PRAGMA index_xinfo | No | | +| PRAGMA integrity_check | No | | +| PRAGMA journal_mode | Yes | | +| PRAGMA journal_size_limit | No | | +| PRAGMA legacy_alter_table | No | | +| PRAGMA legacy_file_format | No | | +| PRAGMA locking_mode | No | | +| PRAGMA max_page_count | No | | +| PRAGMA mmap_size | No | | +| PRAGMA module_list | No | | +| PRAGMA optimize | No | | +| PRAGMA page_count | Yes | | +| PRAGMA page_size | No | | +| PRAGMA parser_trace | No | | +| PRAGMA pragma_list | Yes | | +| PRAGMA query_only | No | | +| PRAGMA quick_check | No | | +| PRAGMA read_uncommitted | No | | +| PRAGMA recursive_triggers | No | | +| PRAGMA reverse_unordered_selects | No | | +| PRAGMA schema_version | No | | +| PRAGMA secure_delete | No | | +| PRAGMA short_column_names | Not Needed | deprecated in SQLite | +| PRAGMA shrink_memory | No | | +| PRAGMA soft_heap_limit | No | | +| PRAGMA stats | No | Used for testing in SQLite | +| PRAGMA synchronous | No | | +| PRAGMA table_info | Yes | | +| PRAGMA table_list | No | | +| PRAGMA table_xinfo | No | | +| PRAGMA temp_store | No | | +| PRAGMA temp_store_directory | Not Needed | deprecated in SQLite | +| PRAGMA threads | No | | +| PRAGMA trusted_schema | No | | +| PRAGMA user_version | No | | +| PRAGMA vdbe_addoptrace | No | | +| PRAGMA vdbe_debug | No | | +| PRAGMA vdbe_listing | No | | +| PRAGMA vdbe_trace | No | | +| PRAGMA wal_autocheckpoint | No | | +| PRAGMA 
wal_checkpoint | Partial | Not Needed calling with param (pragma-value) | +| PRAGMA writable_schema | No | | ### Expressions Feature support of [sqlite expr syntax](https://www.sqlite.org/lang_expr.html). -| Syntax | Status | Comment | -|------------------------------|---------|---------| -| literals | Yes | | -| schema.table.column | Partial | Schemas aren't supported | -| unary operator | Yes | | -| binary operator | Partial | Only `%`, `!<`, and `!>` are unsupported | -| agg() FILTER (WHERE ...) | No | Is incorrectly ignored | -| ... OVER (...) | No | Is incorrectly ignored | -| (expr) | Yes | | -| CAST (expr AS type) | Yes | | -| COLLATE | No | | -| (NOT) LIKE | No | | -| (NOT) GLOB | No | | -| (NOT) REGEXP | No | | -| (NOT) MATCH | No | | -| IS (NOT) | No | | -| IS (NOT) DISTINCT FROM | No | | -| (NOT) BETWEEN ... AND ... | No | | -| (NOT) IN (subquery) | No | | -| (NOT) EXISTS (subquery) | No | | -| CASE WHEN THEN ELSE END | Yes | | -| RAISE | No | | +| Syntax | Status | Comment | +|---------------------------|---------|------------------------------------------| +| literals | Yes | | +| schema.table.column | Partial | Schemas aren't supported | +| unary operator | Yes | | +| binary operator | Partial | Only `%`, `!<`, and `!>` are unsupported | +| agg() FILTER (WHERE ...) | No | Is incorrectly ignored | +| ... OVER (...) | No | Is incorrectly ignored | +| (expr) | Yes | | +| CAST (expr AS type) | Yes | | +| COLLATE | No | | +| (NOT) LIKE | Yes | | +| (NOT) GLOB | Yes | | +| (NOT) REGEXP | No | | +| (NOT) MATCH | No | | +| IS (NOT) | Yes | | +| IS (NOT) DISTINCT FROM | Yes | | +| (NOT) BETWEEN ... AND ... 
| No | | +| (NOT) IN (subquery) | No | | +| (NOT) EXISTS (subquery) | No | | +| CASE WHEN THEN ELSE END | Yes | | +| RAISE | No | | ### SQL functions #### Scalar functions -| Function | Status | Comment | -|------------------------------|--------|---------| -| abs(X) | Yes | | -| changes() | Partial| Still need to support update statements and triggers | -| char(X1,X2,...,XN) | Yes | | -| coalesce(X,Y,...) | Yes | | -| concat(X,...) | Yes | | -| concat_ws(SEP,X,...) | Yes | | -| format(FORMAT,...) | No | | -| glob(X,Y) | Yes | | -| hex(X) | Yes | | -| ifnull(X,Y) | Yes | | -| iif(X,Y,Z) | Yes | | -| instr(X,Y) | Yes | | -| last_insert_rowid() | Yes | | -| length(X) | Yes | | -| like(X,Y) | Yes | | -| like(X,Y,Z) | Yes | | -| likelihood(X,Y) | No | | -| likely(X) | No | | -| load_extension(X) | Yes | sqlite3 extensions not yet supported | -| load_extension(X,Y) | No | | -| lower(X) | Yes | | -| ltrim(X) | Yes | | -| ltrim(X,Y) | Yes | | -| max(X,Y,...) | Yes | | -| min(X,Y,...) | Yes | | -| nullif(X,Y) | Yes | | -| octet_length(X) | Yes | | -| printf(FORMAT,...) 
| No | | -| quote(X) | Yes | | -| random() | Yes | | -| randomblob(N) | Yes | | -| replace(X,Y,Z) | Yes | | -| round(X) | Yes | | -| round(X,Y) | Yes | | -| rtrim(X) | Yes | | -| rtrim(X,Y) | Yes | | -| sign(X) | Yes | | -| soundex(X) | Yes | | -| sqlite_compileoption_get(N) | No | | -| sqlite_compileoption_used(X) | No | | -| sqlite_offset(X) | No | | -| sqlite_source_id() | No | | -| sqlite_version() | Yes | | -| substr(X,Y,Z) | Yes | | -| substr(X,Y) | Yes | | -| substring(X,Y,Z) | Yes | | -| substring(X,Y) | Yes | | -| total_changes() | Partial| Still need to support update statements and triggers | -| trim(X) | Yes | | -| trim(X,Y) | Yes | | -| typeof(X) | Yes | | -| unhex(X) | Yes | | -| unhex(X,Y) | Yes | | -| unicode(X) | Yes | | -| unlikely(X) | No | | -| upper(X) | Yes | | -| zeroblob(N) | Yes | | +| Function | Status | Comment | +|------------------------------|---------|------------------------------------------------------| +| abs(X) | Yes | | +| changes() | Partial | Still need to support update statements and triggers | +| char(X1,X2,...,XN) | Yes | | +| coalesce(X,Y,...) | Yes | | +| concat(X,...) | Yes | | +| concat_ws(SEP,X,...) | Yes | | +| format(FORMAT,...) | No | | +| glob(X,Y) | Yes | | +| hex(X) | Yes | | +| ifnull(X,Y) | Yes | | +| iif(X,Y,Z) | Yes | | +| instr(X,Y) | Yes | | +| last_insert_rowid() | Yes | | +| length(X) | Yes | | +| like(X,Y) | Yes | | +| like(X,Y,Z) | Yes | | +| likelihood(X,Y) | No | | +| likely(X) | No | | +| load_extension(X) | Yes | sqlite3 extensions not yet supported | +| load_extension(X,Y) | No | | +| lower(X) | Yes | | +| ltrim(X) | Yes | | +| ltrim(X,Y) | Yes | | +| max(X,Y,...) | Yes | | +| min(X,Y,...) | Yes | | +| nullif(X,Y) | Yes | | +| octet_length(X) | Yes | | +| printf(FORMAT,...) 
| Yes | Still need support additional modifiers | +| quote(X) | Yes | | +| random() | Yes | | +| randomblob(N) | Yes | | +| replace(X,Y,Z) | Yes | | +| round(X) | Yes | | +| round(X,Y) | Yes | | +| rtrim(X) | Yes | | +| rtrim(X,Y) | Yes | | +| sign(X) | Yes | | +| soundex(X) | Yes | | +| sqlite_compileoption_get(N) | No | | +| sqlite_compileoption_used(X) | No | | +| sqlite_offset(X) | No | | +| sqlite_source_id() | Yes | | +| sqlite_version() | Yes | | +| substr(X,Y,Z) | Yes | | +| substr(X,Y) | Yes | | +| substring(X,Y,Z) | Yes | | +| substring(X,Y) | Yes | | +| total_changes() | Partial | Still need to support update statements and triggers | +| trim(X) | Yes | | +| trim(X,Y) | Yes | | +| typeof(X) | Yes | | +| unhex(X) | Yes | | +| unhex(X,Y) | Yes | | +| unicode(X) | Yes | | +| unlikely(X) | No | | +| upper(X) | Yes | | +| zeroblob(N) | Yes | | #### Mathematical functions | Function | Status | Comment | -| ---------- | ------ | ------- | +|------------|--------|---------| | acos(X) | Yes | | | acosh(X) | Yes | | | asin(X) | Yes | | @@ -348,7 +351,7 @@ Modifiers: #### JSON functions | Function | Status | Comment | -|------------------------------------|---------|----------------------------------------------------------------------------------------------------------------------------------------------| +| ---------------------------------- | ------- | -------------------------------------------------------------------------------------------------------------------------------------------- | | json(json) | Partial | | | jsonb(json) | | | | json_array(value1,value2,...) | Yes | | @@ -364,14 +367,14 @@ Modifiers: | jsonb_insert(json,path,value,...) | | | | json_object(label1,value1,...) | Yes | When keys are duplicated, only the last one processed is returned. This differs from sqlite, where the keys in the output can be duplicated | | jsonb_object(label1,value1,...) 
| | | -| json_patch(json1,json2) | | | +| json_patch(json1,json2) | Yes | | | jsonb_patch(json1,json2) | | | -| json_pretty(json) | | | -| json_remove(json,path,...) | | | +| json_pretty(json) | Partial | Shares same json(val) limitations. Also, when passing blobs for indentation, conversion is not exactly the same as in SQLite | +| json_remove(json,path,...) | Partial | Uses same json path parser as json_extract so shares same limitations. | | jsonb_remove(json,path,...) | | | | json_replace(json,path,value,...) | | | | jsonb_replace(json,path,value,...) | | | -| json_set(json,path,value,...) | | | +| json_set(json,path,value,...) | Yes | | | jsonb_set(json,path,value,...) | | | | json_type(json) | Yes | | | json_type(json,path) | Yes | | @@ -400,178 +403,193 @@ Modifiers: ## SQLite VDBE opcodes -| Opcode | Status | -|----------------|--------| -| Add | Yes | -| AddImm | No | -| Affinity | No | -| AggFinal | Yes | -| AggStep | Yes | -| AggStep | Yes | -| And | Yes | -| AutoCommit | No | -| BitAnd | Yes | -| BitNot | Yes | -| BitOr | Yes | -| Blob | Yes | -| Checkpoint | No | -| Clear | No | -| Close | No | -| CollSeq | No | -| Column | Yes | -| Compare | Yes | -| Concat | Yes | -| Copy | Yes | -| Count | No | -| CreateIndex | No | -| CreateTable | No | -| DecrJumpZero | Yes | -| Delete | No | -| Destroy | No | -| Divide | Yes | -| DropIndex | No | -| DropTable | No | -| DropTrigger | No | -| EndCoroutine | Yes | -| Eq | Yes | -| Expire | No | -| Explain | No | -| FkCounter | No | -| FkIfZero | No | -| Found | No | -| Function | Yes | -| Ge | Yes | -| Gosub | Yes | -| Goto | Yes | -| Gt | Yes | -| Halt | Yes | -| HaltIfNull | No | -| IdxDelete | No | -| IdxGE | Yes | -| IdxInsert | No | -| IdxLT | No | -| IdxRowid | No | -| If | Yes | -| IfNeg | No | -| IfNot | Yes | -| IfPos | Yes | -| IfZero | No | -| IncrVacuum | No | -| Init | Yes | -| InitCoroutine | Yes | -| Insert | No | -| InsertAsync | Yes | -| InsertAwait | Yes | -| InsertInt | No | -| Int64 | No | -| 
Integer | Yes | -| IntegrityCk | No | -| IsNull | Yes | -| IsUnique | No | -| JournalMode | No | -| Jump | Yes | -| Last | No | -| Le | Yes | -| LoadAnalysis | No | -| Lt | Yes | -| MakeRecord | Yes | -| MaxPgcnt | No | -| MemMax | No | -| Move | No | -| Multiply | Yes | -| MustBeInt | Yes | -| Ne | Yes | -| NewRowid | Yes | -| Next | No | -| NextAsync | Yes | -| NextAwait | Yes | -| Noop | No | -| Not | Yes | -| NotExists | Yes | -| NotFound | No | -| NotNull | Yes | -| Null | Yes | -| NullRow | Yes | -| Once | No | -| OpenAutoindex | No | -| OpenEphemeral | No | -| OpenPseudo | Yes | -| OpenRead | Yes | -| OpenReadAsync | Yes | -| OpenWrite | No | -| OpenWriteAsync | Yes | -| OpenWriteAwait | Yes | -| Or | Yes | -| Pagecount | No | -| Param | No | -| ParseSchema | No | -| Permutation | No | -| Prev | No | -| PrevAsync | Yes | -| PrevAwait | Yes | -| Program | No | -| ReadCookie | No | -| Real | Yes | -| RealAffinity | Yes | -| Remainder | Yes | -| ResetCount | No | -| ResultRow | Yes | -| Return | Yes | -| Rewind | Yes | -| RewindAsync | Yes | -| RewindAwait | Yes | -| RowData | No | -| RowId | Yes | -| RowKey | No | -| RowSetAdd | No | -| RowSetRead | No | -| RowSetTest | No | -| Rowid | Yes | -| SCopy | No | -| Savepoint | No | -| Seek | No | -| SeekGe | Yes | -| SeekGt | Yes | -| SeekLe | No | -| SeekLt | No | -| SeekRowid | Yes | -| Sequence | No | -| SetCookie | No | -| ShiftLeft | Yes | -| ShiftRight | Yes | -| SoftNull | Yes | -| Sort | No | -| SorterCompare | No | -| SorterData | Yes | -| SorterInsert | Yes | -| SorterNext | Yes | -| SorterOpen | Yes | -| SorterSort | Yes | -| String | No | -| String8 | Yes | -| Subtract | Yes | -| TableLock | No | -| ToBlob | No | -| ToInt | No | -| ToNumeric | No | -| ToReal | No | -| ToText | No | -| Trace | No | -| Transaction | Yes | -| VBegin | No | -| VColumn | No | -| VCreate | No | -| VDestroy | No | -| VFilter | No | -| VNext | No | -| VOpen | No | -| VRename | No | -| VUpdate | No | -| Vacuum | No | -| Variable 
| No | -| VerifyCookie | No | -| Yield | Yes | -| ZeroOrNull | Yes | +| Opcode | Status | Comment | +|----------------|--------|---------| +| Add | Yes | | +| AddImm | No | | +| Affinity | No | | +| AggFinal | Yes | | +| AggStep | Yes | | +| AggStep | Yes | | +| And | Yes | | +| AutoCommit | No | | +| BitAnd | Yes | | +| BitNot | Yes | | +| BitOr | Yes | | +| Blob | Yes | | +| Checkpoint | No | | +| Clear | No | | +| Close | No | | +| CollSeq | No | | +| Column | Yes | | +| Compare | Yes | | +| Concat | Yes | | +| Copy | Yes | | +| Count | No | | +| CreateBTree | Partial| no temp databases | +| CreateTable | No | | +| CreateTable | No | | +| DecrJumpZero | Yes | | +| Delete | No | | +| Destroy | No | | +| Divide | Yes | | +| DropIndex | No | | +| DropTable | No | | +| DropTrigger | No | | +| EndCoroutine | Yes | | +| Eq | Yes | | +| Expire | No | | +| Explain | No | | +| FkCounter | No | | +| FkIfZero | No | | +| Found | No | | +| Function | Yes | | +| Ge | Yes | | +| Gosub | Yes | | +| Goto | Yes | | +| Gt | Yes | | +| Halt | Yes | | +| HaltIfNull | No | | +| IdxDelete | No | | +| IdxGE | Yes | | +| IdxInsert | No | | +| IdxLT | No | | +| IdxRowid | No | | +| If | Yes | | +| IfNeg | No | | +| IfNot | Yes | | +| IfPos | Yes | | +| IfZero | No | | +| IncrVacuum | No | | +| Init | Yes | | +| InitCoroutine | Yes | | +| Insert | No | | +| InsertAsync | Yes | | +| InsertAwait | Yes | | +| InsertInt | No | | +| Int64 | No | | +| Integer | Yes | | +| IntegrityCk | No | | +| IsNull | Yes | | +| IsUnique | No | | +| JournalMode | No | | +| Jump | Yes | | +| Last | No | | +| Le | Yes | | +| LoadAnalysis | No | | +| Lt | Yes | | +| MakeRecord | Yes | | +| MaxPgcnt | No | | +| MemMax | No | | +| Move | No | | +| Multiply | Yes | | +| MustBeInt | Yes | | +| Ne | Yes | | +| NewRowid | Yes | | +| Next | No | | +| NextAsync | Yes | | +| NextAwait | Yes | | +| Noop | Yes | | +| Not | Yes | | +| NotExists | Yes | | +| NotFound | No | | +| NotNull | Yes | | +| Null | Yes | | +| 
NullRow | Yes | | +| Once | No | | +| OpenAutoindex | No | | +| OpenEphemeral | No | | +| OpenPseudo | Yes | | +| OpenRead | Yes | | +| OpenReadAsync | Yes | | +| OpenWrite | No | | +| OpenWriteAsync | Yes | | +| OpenWriteAwait | Yes | | +| Or | Yes | | +| Pagecount | Partial| no temp databases | +| Param | No | | +| ParseSchema | No | | +| Permutation | No | | +| Prev | No | | +| PrevAsync | Yes | | +| PrevAwait | Yes | | +| Program | No | | +| ReadCookie | No | | +| Real | Yes | | +| RealAffinity | Yes | | +| Remainder | Yes | | +| ResetCount | No | | +| ResultRow | Yes | | +| Return | Yes | | +| Rewind | Yes | | +| RewindAsync | Yes | | +| RewindAwait | Yes | | +| RowData | No | | +| RowId | Yes | | +| RowKey | No | | +| RowSetAdd | No | | +| RowSetRead | No | | +| RowSetTest | No | | +| Rowid | Yes | | +| SCopy | No | | +| Savepoint | No | | +| Seek | No | | +| SeekGe | Yes | | +| SeekGt | Yes | | +| SeekLe | No | | +| SeekLt | No | | +| SeekRowid | Yes | | +| Sequence | No | | +| SetCookie | No | | +| ShiftLeft | Yes | | +| ShiftRight | Yes | | +| SoftNull | Yes | | +| Sort | No | | +| SorterCompare | No | | +| SorterData | Yes | | +| SorterInsert | Yes | | +| SorterNext | Yes | | +| SorterOpen | Yes | | +| SorterSort | Yes | | +| String | No | | +| String8 | Yes | | +| Subtract | Yes | | +| TableLock | No | | +| ToBlob | No | | +| ToInt | No | | +| ToNumeric | No | | +| ToReal | No | | +| ToText | No | | +| Trace | No | | +| Transaction | Yes | | +| VBegin | No | | +| VColumn | No | | +| VCreate | No | | +| VDestroy | No | | +| VFilter | No | | +| VNext | No | | +| VOpen | No | | +| VRename | No | | +| VUpdate | No | | +| Vacuum | No | | +| Variable | No | | +| VerifyCookie | No | | +| Yield | Yes | | +| ZeroOrNull | Yes | | + +## [SQLite journaling modes](https://www.sqlite.org/pragma.html#pragma_journal_mode) + +We currently don't have plan to support the rollback journal mode as it locks the database file during writes. 
+Therefore, all rollback-type modes (delete, truncate, persist, memory) are marked are `Not Needed` below. + +| Journal mode | Status | Comment | +|--------------|------------|--------------------------------| +| wal | Yes | | +| wal2 | No | experimental feature in sqlite | +| delete | Not Needed | | +| truncate | Not Needed | | +| persist | Not Needed | | +| memory | Not Needed | | ## Extensions @@ -581,7 +599,7 @@ Limbo has in-tree extensions. UUID's in Limbo are `blobs` by default. -| Function | Status | Comment | +| Function | Status | Comment | |-----------------------|--------|---------------------------------------------------------------| | uuid4() | Yes | UUID version 4 | | uuid4_str() | Yes | UUID v4 string alias `gen_random_uuid()` for PG compatibility | @@ -594,10 +612,75 @@ UUID's in Limbo are `blobs` by default. The `regexp` extension is compatible with [sqlean-regexp](https://github.com/nalgeon/sqlean/blob/main/docs/regexp.md). -| Function | Status | Comment | +| Function | Status | Comment | |------------------------------------------------|--------|---------| | regexp(pattern, source) | Yes | | | regexp_like(source, pattern) | Yes | | | regexp_substr(source, pattern) | Yes | | | regexp_capture(source, pattern[, n]) | No | | | regexp_replace(source, pattern, replacement) | No | | + +### Vector + +The `vector` extension is compatible with libSQL native vector search. + +| Function | Status | Comment | +|------------------------------------------------|--------|---------| +| vector(x) | Yes | | +| vector32(x) | Yes | | +| vector64(x) | Yes | | +| vector_extract(x) | Yes | | +| vector_distance_cos(x, y) | Yes | | + +### Time + +The `time` extension is compatible with [sqlean-time](https://github.com/nalgeon/sqlean/blob/main/docs/time.md). 
+ + +| Function | Status | Comment | +| ------------------------------------------------------------------- | ------ | ---------------------------- | +| time_now() | Yes | | +| time_date(year, month, day[, hour, min, sec[, nsec[, offset_sec]]]) | Yes | offset_sec is not normalized | +| time_get_year(t) | Yes | | +| time_get_month(t) | Yes | | +| time_get_day(t) | Yes | | +| time_get_hour(t) | Yes | | +| time_get_minute(t) | Yes | | +| time_get_second(t) | Yes | | +| time_get_nano(t) | Yes | | +| time_get_weekday(t) | Yes | | +| time_get_yearday(t) | Yes | | +| time_get_isoyear(t) | Yes | | +| time_get_isoweek(t) | Yes | | +| time_get(t, field) | Yes | | +| time_unix(sec[, nsec]) | Yes | | +| time_milli(msec) | Yes | | +| time_micro(usec) | Yes | | +| time_nano(nsec) | Yes | | +| time_to_unix(t) | Yes | | +| time_to_milli(t) | Yes | | +| time_to_micro(t) | Yes | | +| time_to_nano(t) | Yes | | +| time_after(t, u) | Yes | | +| time_before(t, u) | Yes | | +| time_compare(t, u) | Yes | | +| time_equal(t, u) | Yes | | +| time_add(t, d) | Yes | | +| time_add_date(t, years[, months[, days]]) | Yes | | +| time_sub(t, u) | Yes | | +| time_since(t) | Yes | | +| time_until(t) | Yes | | +| time_trunc(t, field) | Yes | | +| time_trunc(t, d) | Yes | | +| time_round(t, d) | Yes | | +| time_fmt_iso(t[, offset_sec]) | Yes | | +| time_fmt_datetime(t[, offset_sec]) | Yes | | +| time_fmt_date(t[, offset_sec]) | Yes | | +| time_fmt_time(t[, offset_sec]) | Yes | | +| time_parse(s) | Yes | | +| dur_ns() | Yes | | +| dur_us() | Yes | | +| dur_ms() | Yes | | +| dur_s() | Yes | | +| dur_m() | Yes | | +| dur_h() | Yes | | diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 2f2b99922..cad33a8db 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -4,7 +4,7 @@ We'd love to have you contribute to Limbo! This document is a quick helper to get you going. -## Getting started +## Getting Started Limbo is a rewrite of SQLite in Rust. 
If you are new to SQLite, the following articles and books are a good starting point: @@ -19,7 +19,47 @@ If you are new to Rust, the following books are recommended reading: Examples of contributing -* [How to contribute a SQL function implementation](docs/internals/functions.md) +* [How to contribute a SQL function implementation](docs/contributing/contributing_functions.md) + +To build and run `limbo` cli: + +```shell +cargo run --package limbo --bin limbo database.db +``` + +Run tests: + +```console +cargo test +``` + +Test coverage report: + +``` +cargo tarpaulin -o html +``` + +> [!NOTE] +> Generation of coverage report requires [tarpaulin](https://github.com/xd009642/tarpaulin) binary to be installed. +> You can install it with `cargo install cargo-tarpaulin` + +[//]: # (TODO remove the below tip when the bug is solved) + +> [!TIP] +> If coverage fails with "Test failed during run" error and all of the tests passed it might be the result of tarpaulin [bug](https://github.com/xd009642/tarpaulin/issues/1642). You can temporarily set [dynamic libraries linking manually](https://doc.rust-lang.org/cargo/reference/environment-variables.html#dynamic-library-paths) as a workaround, e.g. for linux `LD_LIBRARY_PATH="$(rustc --print=target-libdir)" cargo tarpaulin -o html`. 
+ +Run benchmarks: + +```console +cargo bench +``` + +Run benchmarks and generate flamegraphs: + +```console +echo -1 | sudo tee /proc/sys/kernel/perf_event_paranoid +cargo bench --bench benchmark -- --profile-time=5 +``` ## Finding things to work on diff --git a/Cargo.lock b/Cargo.lock index 1fda860f2..a412add52 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -24,10 +24,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", - "getrandom", + "getrandom 0.2.15", "once_cell", "version_check", - "zerocopy", + "zerocopy 0.7.35", ] [[package]] @@ -60,7 +60,7 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "18a1e15a87b13ae79e04e07b3714fc41d5f6993dff11662fdbe0b207c6ad0fe0" dependencies = [ - "rand", + "rand 0.8.5", ] [[package]] @@ -224,6 +224,16 @@ dependencies = [ "serde", ] +[[package]] +name = "built" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c360505aed52b7ec96a3636c3f039d99103c37d1d9b4f7a8c743d3ea9ffcd03b" +dependencies = [ + "chrono", + "git2", +] + [[package]] name = "bumpalo" version = "3.16.0" @@ -260,6 +270,8 @@ version = "1.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c8293772165d9345bdaaa39b45b2109591e63fe5e6fbc23c6ff930a048aa310b" dependencies = [ + "jobserver", + "libc", "shlex", ] @@ -451,7 +463,7 @@ checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" [[package]] name = "core_tester" -version = "0.0.13" +version = "0.0.14" dependencies = [ "anyhow", "assert_cmd", @@ -460,6 +472,8 @@ dependencies = [ "env_logger 0.10.2", "limbo_core", "log", + "rand 0.9.0", + "rand_chacha 0.9.0", "rexpect", "rusqlite", "rustyline", @@ -540,6 +554,16 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "crossbeam-skiplist" +version = "0.1.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "df29de440c58ca2cc6e587ec3d22347551a32435fbde9d2bff64e78a9ffa151b" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.21" @@ -639,6 +663,17 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.96", +] + [[package]] name = "doc-comment" version = "0.3.3" @@ -672,6 +707,16 @@ dependencies = [ "log", ] +[[package]] +name = "env_logger" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a19187fea3ac7e84da7dacf48de0c45d63c6a76f9490dae389aead16c243fce3" +dependencies = [ + "log", + "regex", +] + [[package]] name = "env_logger" version = "0.10.2" @@ -806,6 +851,15 @@ dependencies = [ "winapi", ] +[[package]] +name = "form_urlencoded" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" +dependencies = [ + "percent-encoding", +] + [[package]] name = "fragile" version = "2.0.0" @@ -935,16 +989,41 @@ dependencies = [ "cfg-if", "js-sys", "libc", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", "wasm-bindgen", ] +[[package]] +name = "getrandom" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43a49c392881ce6d5c3b8cb70f98717b7c07aabbdff06687b9030dbfbe2725f8" +dependencies = [ + "cfg-if", + "libc", + "wasi 0.13.3+wasi-0.2.2", + "windows-targets 0.52.6", +] + [[package]] name = "gimli" version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" +[[package]] +name = "git2" +version = "0.19.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "b903b73e45dc0c6c596f2d37eccece7c1c8bb6e4407b001096387c63d0d93724" +dependencies = [ + "bitflags 2.8.0", + "libc", + "libgit2-sys", + "log", + "url", +] + [[package]] name = "glob" version = "0.3.2" @@ -1042,6 +1121,145 @@ dependencies = [ "cc", ] +[[package]] +name = "icu_collections" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db2fa452206ebee18c4b5c2274dbf1de17008e874b4dc4f0aea9d01ca79e4526" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locid" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13acbb8371917fc971be86fc8057c41a64b521c184808a698c02acc242dbf637" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_locid_transform" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01d11ac35de8e40fdeda00d9e1e9d92525f3f9d887cdd7aa81d727596788b54e" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_locid_transform_data", + "icu_provider", + "tinystr", + "zerovec", +] + +[[package]] +name = "icu_locid_transform_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdc8ff3388f852bede6b579ad4e978ab004f139284d7b28715f773507b946f6e" + +[[package]] +name = "icu_normalizer" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19ce3e0da2ec68599d193c93d088142efd7f9c5d6fc9b803774855747dc6a84f" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "utf16_iter", + "utf8_iter", + "write16", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"f8cafbf7aa791e9b22bec55a167906f9e1215fd475cd22adfcf660e03e989516" + +[[package]] +name = "icu_properties" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93d6020766cfc6302c15dbbc9c8778c37e62c14427cb7f6e601d849e092aeef5" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_locid_transform", + "icu_properties_data", + "icu_provider", + "tinystr", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67a8effbc3dd3e4ba1afa8ad918d5684b8868b3b26500753effea8d2eed19569" + +[[package]] +name = "icu_provider" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ed421c8a8ef78d3e2dbc98a973be2f3770cb42b606e3ab18d6237c4dfde68d9" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_provider_macros", + "stable_deref_trait", + "tinystr", + "writeable", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_provider_macros" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.96", +] + +[[package]] +name = "idna" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e" +dependencies = [ + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daca1df1c957320b2cf139ac61e7bd64fed304c5040df000a745aa1de3b4ef71" +dependencies = [ + "icu_normalizer", + "icu_properties", +] + [[package]] name = "indexmap" version = "2.7.0" @@ -1147,7 +1365,7 @@ checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674" [[package]] name = "java-limbo" -version = 
"0.0.13" +version = "0.0.14" dependencies = [ "jni", "limbo_core", @@ -1176,6 +1394,15 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" +[[package]] +name = "jobserver" +version = "0.1.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" +dependencies = [ + "libc", +] + [[package]] name = "js-sys" version = "0.3.77" @@ -1197,7 +1424,7 @@ dependencies = [ "itoa", "nom", "ordered-float", - "rand", + "rand 0.8.5", "ryu", "serde_json", ] @@ -1243,6 +1470,18 @@ version = "0.2.169" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" +[[package]] +name = "libgit2-sys" +version = "0.17.0+1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10472326a8a6477c3c20a64547b0059e4b0d086869eee31e6d7da728a8eb7224" +dependencies = [ + "cc", + "libc", + "libz-sys", + "pkg-config", +] + [[package]] name = "libloading" version = "0.8.6" @@ -1285,6 +1524,18 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "libz-sys" +version = "1.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2d16453e800a8cf6dd2fc3eb4bc99b786a9b90c663b8559a5b1a041bf89e472" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "limbo" version = "0.0.13" @@ -1303,14 +1554,14 @@ dependencies = [ [[package]] name = "limbo-go" -version = "0.0.13" +version = "0.0.14" dependencies = [ "limbo_core", ] [[package]] name = "limbo-wasm" -version = "0.0.13" +version = "0.0.14" dependencies = [ "console_error_panic_hook", "js-sys", @@ -1322,14 +1573,16 @@ dependencies = [ [[package]] name = "limbo_core" -version = "0.0.13" +version = "0.0.14" dependencies = [ + "built", "bumpalo", "cfg_block", "chrono", "criterion", + 
"crossbeam-skiplist", "fallible-iterator 0.3.0", - "getrandom", + "getrandom 0.2.15", "hex", "indexmap", "io-uring", @@ -1339,15 +1592,21 @@ dependencies = [ "libloading", "limbo_ext", "limbo_macros", + "limbo_percentile", + "limbo_regexp", + "limbo_time", + "limbo_uuid", + "limbo_vector", "log", "miette", "mimalloc", "mockall", + "parking_lot", "pest", "pest_derive", "polling", "pprof", - "rand", + "rand 0.8.5", "regex", "regex-syntax", "rstest", @@ -1356,14 +1615,15 @@ dependencies = [ "serde", "sieve-cache", "sqlite3-parser", + "strum", "tempfile", "thiserror 1.0.69", - "uuid", + "tracing", ] [[package]] name = "limbo_ext" -version = "0.0.13" +version = "0.0.14" dependencies = [ "limbo_macros", "log", @@ -1371,7 +1631,7 @@ dependencies = [ [[package]] name = "limbo_libsql" -version = "0.0.13" +version = "0.0.14" dependencies = [ "limbo_core", "thiserror 2.0.11", @@ -1380,7 +1640,7 @@ dependencies = [ [[package]] name = "limbo_macros" -version = "0.0.13" +version = "0.0.14" dependencies = [ "proc-macro2", "quote", @@ -1389,23 +1649,25 @@ dependencies = [ [[package]] name = "limbo_percentile" -version = "0.0.13" +version = "0.0.14" dependencies = [ "limbo_ext", + "mimalloc", ] [[package]] name = "limbo_regexp" -version = "0.0.13" +version = "0.0.14" dependencies = [ "limbo_ext", "log", + "mimalloc", "regex", ] [[package]] name = "limbo_sim" -version = "0.0.13" +version = "0.0.14" dependencies = [ "anarchist-readable-name-generator-lib", "clap", @@ -1413,10 +1675,8 @@ dependencies = [ "limbo_core", "log", "notify", - "rand", - "rand_chacha", - "regex", - "regex-syntax", + "rand 0.8.5", + "rand_chacha 0.3.1", "serde", "serde_json", "tempfile", @@ -1424,7 +1684,7 @@ dependencies = [ [[package]] name = "limbo_sqlite3" -version = "0.0.13" +version = "0.0.14" dependencies = [ "env_logger 0.11.6", "libc", @@ -1432,21 +1692,49 @@ dependencies = [ "log", ] +[[package]] +name = "limbo_time" +version = "0.0.14" +dependencies = [ + "chrono", + "limbo_ext", + "mimalloc", + 
"strum", + "strum_macros", + "thiserror 2.0.11", +] + [[package]] name = "limbo_uuid" -version = "0.0.13" +version = "0.0.14" dependencies = [ "limbo_ext", - "log", + "mimalloc", "uuid", ] +[[package]] +name = "limbo_vector" +version = "0.0.14" +dependencies = [ + "limbo_ext", + "quickcheck", + "quickcheck_macros", + "rand 0.8.5", +] + [[package]] name = "linux-raw-sys" version = "0.4.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" +[[package]] +name = "litemap" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104" + [[package]] name = "lock_api" version = "0.4.12" @@ -1550,7 +1838,7 @@ checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" dependencies = [ "libc", "log", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", "windows-sys 0.52.0", ] @@ -1742,6 +2030,12 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "percent-encoding" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" + [[package]] name = "pest" version = "2.7.15" @@ -1813,7 +2107,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d" dependencies = [ "phf_shared", - "rand", + "rand 0.8.5", ] [[package]] @@ -1922,7 +2216,7 @@ version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" dependencies = [ - "zerocopy", + "zerocopy 0.7.35", ] [[package]] @@ -1963,7 +2257,7 @@ dependencies = [ [[package]] name = "py-limbo" -version = "0.0.13" +version = "0.0.14" dependencies = [ "anyhow", "limbo_core", @@ -2045,6 +2339,28 @@ dependencies = [ "memchr", ] 
+[[package]] +name = "quickcheck" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "588f6378e4dd99458b60ec275b4477add41ce4fa9f64dcba6f15adccb19b50d6" +dependencies = [ + "env_logger 0.8.4", + "log", + "rand 0.8.5", +] + +[[package]] +name = "quickcheck_macros" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b22a693222d716a9587786f37ac3f6b4faedb5b80c23914e7303ff5a1d8016e9" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "quote" version = "1.0.38" @@ -2071,8 +2387,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ "libc", - "rand_chacha", - "rand_core", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + +[[package]] +name = "rand" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94" +dependencies = [ + "rand_chacha 0.9.0", + "rand_core 0.9.0", + "zerocopy 0.8.14", ] [[package]] @@ -2082,7 +2409,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core 0.9.0", ] [[package]] @@ -2091,7 +2428,17 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom", + "getrandom 0.2.15", +] + +[[package]] +name = "rand_core" +version = "0.9.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "b08f3c9802962f7e1b25113931d94f43ed9725bebc59db9d0c3e9a23b67e15ff" +dependencies = [ + "getrandom 0.3.1", + "zerocopy 0.8.14", ] [[package]] @@ -2129,7 +2476,7 @@ version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" dependencies = [ - "getrandom", + "getrandom 0.2.15", "libredox", "thiserror 1.0.69", ] @@ -2429,6 +2776,8 @@ dependencies = [ "phf", "phf_codegen", "phf_shared", + "strum", + "strum_macros", "uncased", ] @@ -2456,6 +2805,28 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "strum" +version = "0.26.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" +dependencies = [ + "strum_macros", +] + +[[package]] +name = "strum_macros" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.96", +] + [[package]] name = "supports-color" version = "3.0.2" @@ -2522,6 +2893,17 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "synstructure" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.96", +] + [[package]] name = "target-lexicon" version = "0.12.16" @@ -2536,7 +2918,7 @@ checksum = "9a8a559c81686f576e8cd0290cd2a24a2a9ad80c98b3478856500fcbd7acd704" dependencies = [ "cfg-if", "fastrand", - "getrandom", + "getrandom 0.2.15", "once_cell", "rustix", "windows-sys 0.59.0", @@ -2617,6 +2999,16 @@ 
dependencies = [ "syn 2.0.96", ] +[[package]] +name = "tinystr" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9117f5d4db391c1cf6927e7bea3db74b9a1c1add8f7eda9ffd5364f40f57b82f" +dependencies = [ + "displaydoc", + "zerovec", +] + [[package]] name = "tinytemplate" version = "1.2.1" @@ -2663,14 +3055,29 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" dependencies = [ "pin-project-lite", + "tracing-attributes", "tracing-core", ] +[[package]] +name = "tracing-attributes" +version = "0.1.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.96", +] + [[package]] name = "tracing-core" version = "0.1.33" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c" +dependencies = [ + "once_cell", +] [[package]] name = "typenum" @@ -2723,6 +3130,29 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" +[[package]] +name = "url" +version = "2.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32f8b686cadd1473f4bd0117a5d28d36b1ade384ea9b5069a1c40aefed7fda60" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", +] + +[[package]] +name = "utf16_iter" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8232dd3cdaed5356e0f716d285e4b40b932ac434100fe9b7e0e8e935b9e6246" + +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + [[package]] name = "utf8parse" 
version = "0.2.2" @@ -2735,7 +3165,7 @@ version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "744018581f9a3454a9e15beb8a33b017183f1e7c0cd170232a2d1453b23a51c4" dependencies = [ - "getrandom", + "getrandom 0.2.15", ] [[package]] @@ -2775,6 +3205,15 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasi" +version = "0.13.3+wasi-0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2" +dependencies = [ + "wit-bindgen-rt", +] + [[package]] name = "wasm-bindgen" version = "0.2.100" @@ -3110,6 +3549,51 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "wit-bindgen-rt" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c" +dependencies = [ + "bitflags 2.8.0", +] + +[[package]] +name = "write16" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936" + +[[package]] +name = "writeable" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" + +[[package]] +name = "yoke" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "120e6aef9aa629e3d4f52dc8cc43a015c7724194c97dfaf45180d2daf2b77f40" +dependencies = [ + "serde", + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.7.5" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.96", + "synstructure", +] + [[package]] name = "zerocopy" version = "0.7.35" @@ -3117,7 +3601,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ "byteorder", - "zerocopy-derive", + "zerocopy-derive 0.7.35", +] + +[[package]] +name = "zerocopy" +version = "0.8.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a367f292d93d4eab890745e75a778da40909cab4d6ff8173693812f79c4a2468" +dependencies = [ + "zerocopy-derive 0.8.14", ] [[package]] @@ -3130,3 +3623,57 @@ dependencies = [ "quote", "syn 2.0.96", ] + +[[package]] +name = "zerocopy-derive" +version = "0.8.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3931cb58c62c13adec22e38686b559c86a30565e16ad6e8510a337cedc611e1" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.96", +] + +[[package]] +name = "zerofrom" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cff3ee08c995dee1859d998dea82f7374f2826091dd9cd47def953cae446cd2e" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.96", + "synstructure", +] + +[[package]] +name = "zerovec" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa2b893d79df23bfb12d5461018d408ea19dfafe76c2c7ef6d4eba614f8ff079" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.10.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.96", +] diff --git a/Cargo.toml b/Cargo.toml index 0ffbdf6ac..85b8c46d7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,11 +18,13 @@ members = [ "sqlite3", "tests", "extensions/percentile", + "extensions/vector", + "extensions/time", ] exclude = ["perf/latency/limbo"] [workspace.package] -version = "0.0.13" +version = "0.0.14" authors = ["the Limbo authors"] edition = "2021" license = "MIT" diff --git a/Makefile b/Makefile index 109a3f147..e1fdf6d29 100644 --- a/Makefile +++ b/Makefile @@ -62,11 +62,11 @@ limbo-wasm: cargo build --package limbo-wasm --target wasm32-wasi .PHONY: limbo-wasm -test: limbo test-compat test-sqlite3 test-shell test-extensions +test: limbo test-compat test-vector test-sqlite3 test-shell test-extensions .PHONY: test test-extensions: limbo - cargo build --package limbo_uuid + cargo build --package limbo_regexp ./testing/extensions.py .PHONY: test-extensions @@ -78,6 +78,14 @@ test-compat: SQLITE_EXEC=$(SQLITE_EXEC) ./testing/all.test .PHONY: test-compat +test-vector: + SQLITE_EXEC=$(SQLITE_EXEC) ./testing/vector.test +.PHONY: test-vector + +test-time: + SQLITE_EXEC=$(SQLITE_EXEC) ./testing/time.test +.PHONY: test-time + test-sqlite3: limbo-c LIBS="$(SQLITE_LIB)" HEADERS="$(SQLITE_LIB_HEADERS)" make -C sqlite3/tests test .PHONY: test-sqlite3 diff --git a/README.md b/README.md index 85fb2f1ea..654ba3193 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,10 @@

- Limbo -

Limbo

+ Limbo +

Project Limbo

- Limbo is a work-in-progress, in-process OLTP database management system, compatible with SQLite. + Limbo is a project to build the modern evolution of SQLite.

@@ -17,37 +17,38 @@

- Chat on Discord + Chat with developers on Discord

--- ## Features -Limbo is an in-process OLTP database engine library that has: +Limbo is a _work-in-progress_, in-process OLTP database engine library written in Rust that has: * **Asynchronous I/O** support on Linux with `io_uring` * **SQLite compatibility** [[doc](COMPAT.md)] for SQL dialect, file formats, and the C API -* **Language bindings** for JavaScript/WebAssembly, Rust, Python, and Java +* **Language bindings** for JavaScript/WebAssembly, Rust, Go, Python, and [Java](bindings/java) * **OS support** for Linux, macOS, and Windows ## Getting Started -### CLI +### 💻 Command Line -Install `limbo` with: +You can install the latest `limbo` release with: ```shell curl --proto '=https' --tlsv1.2 -LsSf \ https://github.com/tursodatabase/limbo/releases/latest/download/limbo-installer.sh | sh ``` -Then use the SQL shell to create and query a database: +Then launch the shell to execute SQL statements: ```console -$ limbo database.db -Limbo v0.0.6 +Limbo Enter ".help" for usage hints. +Connected to a transient in-memory database. +Use ".open FILENAME" to reopen on a persistent database limbo> CREATE TABLE users (id INT PRIMARY KEY, username TEXT); limbo> INSERT INTO users VALUES (1, 'alice'); limbo> INSERT INTO users VALUES (2, 'bob'); @@ -56,7 +57,13 @@ limbo> SELECT * FROM users; 2|bob ``` -### JavaScript (wip) +You can also build and run the latest development version with: + +```shell +cargo run +``` + +### ✨ JavaScript (wip) Installation: @@ -75,7 +82,7 @@ const users = stmt.all(); console.log(users); ``` -### Python (wip) +### 🐍 Python (wip) ```console pip install pylimbo @@ -92,63 +99,60 @@ res = cur.execute("SELECT * FROM users") print(res.fetchone()) ``` -## Developing +### 🐹 Go (wip) -Build and run `limbo` cli: - -```shell -cargo run --package limbo --bin limbo database.db +1. Clone the repository +2. 
Build the library and set your LD_LIBRARY_PATH to include limbo's target directory +```console +cargo build --package limbo-go +export LD_LIBRARY_PATH=/path/to/limbo/target/debug:$LD_LIBRARY_PATH ``` - -Run tests: +3. Use the driver ```console -cargo test +go get github.com/tursodatabase/limbo +go install github.com/tursodatabase/limbo ``` -Test coverage report: +Example usage: +```go +import ( + "database/sql" + _ "github.com/tursodatabase/limbo" +) - -``` -cargo tarpaulin -o html +conn, _ := sql.Open("sqlite3", "sqlite.db") +defer conn.Close() + +stmt, _ := conn.Prepare("select * from users") +defer stmt.Close() + +rows, _ := stmt.Query() +for rows.Next() { + var id int + var username string + _ = rows.Scan(&id, &username) + fmt.Printf("User: ID: %d, Username: %s\n", id, username) +} ``` -> [!NOTE] -> Generation of coverage report requires [tarpaulin](https://github.com/xd009642/tarpaulin) binary to be installed. -> You can install it with `cargo install cargo-tarpaulin` +## Contributing -[//]: # (TODO remove the below tip when the bug is solved) - -> [!TIP] -> If coverage fails with "Test failed during run" error and all of the tests passed it might be the result of tarpaulin [bug](https://github.com/xd009642/tarpaulin/issues/1642). You can temporarily set [dynamic libraries linking manually](https://doc.rust-lang.org/cargo/reference/environment-variables.html#dynamic-library-paths) as a workaround, e.g. for linux `LD_LIBRARY_PATH="$(rustc --print=target-libdir)" cargo tarpaulin -o html`. - -Run benchmarks: - -```console -cargo bench -``` - -Run benchmarks and generate flamegraphs: - -```console -echo -1 | sudo tee /proc/sys/kernel/perf_event_paranoid -cargo bench --bench benchmark -- --profile-time=5 -``` +We'd love to have you contribute to Limbo! Please check out the [contribution guide] to get started. ## FAQ -### How is Limbo different from libSQL? +### How is Limbo different from Turso's libSQL? 
-Limbo is a research project to build a SQLite compatible in-process database in Rust with native async support. The libSQL project, on the other hand, is an open source, open contribution fork of SQLite, with focus on production features such as replication, backups, encryption, and so on. There is no hard dependency between the two projects. Of course, if Limbo becomes widely successful, we might consider merging with libSQL, but that is something that will be decided in the future. +Limbo is a project to build the modern evolution of SQLite in Rust, with a strong open contribution focus and features like native async support, vector search, and more. The libSQL project is also an attempt to evolve SQLite in a similar direction, but through a fork rather than a rewrite. + +Rewriting SQLite in Rust started as an unassuming experiment, and due to its incredible success, replaces libSQL as our intended direction. At this point, libSQL is production ready, Limbo is not - although it is evolving rapidly. As the project starts to near production readiness, we plan to rename it to just "Turso". More details [here](https://turso.tech/blog/we-will-rewrite-sqlite-and-we-are-going-all-in). ## Publications * Pekka Enberg, Sasu Tarkoma, Jon Crowcroft Ashwin Rao (2024). Serverless Runtime / Database Co-Design With Asynchronous I/O. In _EdgeSys ‘24_. [[PDF]](https://penberg.org/papers/penberg-edgesys24.pdf) * Pekka Enberg, Sasu Tarkoma, and Ashwin Rao (2023). Towards Database and Serverless Runtime Co-Design. In _CoNEXT-SW ’23_. [[PDF](https://penberg.org/papers/penberg-conext-sw-23.pdf)] [[Slides](https://penberg.org/papers/penberg-conext-sw-23-slides.pdf)] -## Contributing - -We'd love to have you contribute to Limbo! Check out the [contribution guide] to get started. - ## License This project is licensed under the [MIT license]. 
diff --git a/bindings/go/README.md b/bindings/go/README.md new file mode 100644 index 000000000..ab50140bb --- /dev/null +++ b/bindings/go/README.md @@ -0,0 +1,71 @@ +# Limbo driver for Go's `database/sql` library + + +**NOTE:** this is currently __heavily__ W.I.P and is not yet in a usable state. + +This driver uses the awesome [purego](https://github.com/ebitengine/purego) library to call C (in this case Rust with C ABI) functions from Go without the use of `CGO`. + + +## To use: (_UNSTABLE_ testing or development purposes only) + +### Linux | MacOS + +_All commands listed are relative to the bindings/go directory in the limbo repository_ + +``` +cargo build --package limbo-go + +# Your LD_LIBRARY_PATH environment variable must include limbo's `target/debug` directory + +export LD_LIBRARY_PATH="/path/to/limbo/target/debug:$LD_LIBRARY_PATH" + +``` + +## Windows + +``` +cargo build --package limbo-go + +# You must add limbo's `target/debug` directory to your PATH +# or you could build + copy the .dll to a location in your PATH +# or just the CWD of your go module + +cp path\to\limbo\target\debug\lib_limbo_go.dll . 
+ +go test + +``` +**Temporarily** you may have to clone the limbo repository and run: + +`go mod edit -replace github.com/tursodatabase/limbo=/path/to/limbo/bindings/go` + +```go +import ( + "fmt" + "database/sql" + _"github.com/tursodatabase/limbo" +) + +func main() { + conn, err := sql.Open("sqlite3", ":memory:") + if err != nil { + fmt.Printf("Error: %v\n", err) + os.Exit(1) + } + sql := "CREATE table go_limbo (foo INTEGER, bar TEXT)" + _ = conn.Exec(sql) + + sql = "INSERT INTO go_limbo (foo, bar) values (?, ?)" + stmt, _ := conn.Prepare(sql) + defer stmt.Close() + _ = stmt.Exec(42, "limbo") + rows, _ := conn.Query("SELECT * from go_limbo") + defer rows.Close() + for rows.Next() { + var a int + var b string + _ = rows.Scan(&a, &b) + fmt.Printf("%d, %s", a, b) + } +} +``` diff --git a/bindings/go/connection.go b/bindings/go/connection.go new file mode 100644 index 000000000..2d7a27e8b --- /dev/null +++ b/bindings/go/connection.go @@ -0,0 +1,142 @@ +package limbo + +import ( + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "sync" + + "github.com/ebitengine/purego" +) + +func init() { + err := ensureLibLoaded() + if err != nil { + panic(err) + } + sql.Register(driverName, &limboDriver{}) +} + +type limboDriver struct { + sync.Mutex +} + +var ( + libOnce sync.Once + limboLib uintptr + loadErr error + dbOpen func(string) uintptr + dbClose func(uintptr) uintptr + connPrepare func(uintptr, string) uintptr + connGetError func(uintptr) uintptr + freeBlobFunc func(uintptr) + freeStringFunc func(uintptr) + rowsGetColumns func(uintptr) int32 + rowsGetColumnName func(uintptr, int32) uintptr + rowsGetValue func(uintptr, int32) uintptr + rowsGetError func(uintptr) uintptr + closeRows func(uintptr) uintptr + rowsNext func(uintptr) uintptr + stmtQuery func(stmtPtr uintptr, argsPtr uintptr, argCount uint64) uintptr + stmtExec func(stmtPtr uintptr, argsPtr uintptr, argCount uint64, changes uintptr) int32 + stmtParamCount func(uintptr) int32 + stmtGetError 
func(uintptr) uintptr + stmtClose func(uintptr) int32 +) + +// Register all the symbols on library load +func ensureLibLoaded() error { + libOnce.Do(func() { + limboLib, loadErr = loadLibrary() + if loadErr != nil { + return + } + purego.RegisterLibFunc(&dbOpen, limboLib, FfiDbOpen) + purego.RegisterLibFunc(&dbClose, limboLib, FfiDbClose) + purego.RegisterLibFunc(&connPrepare, limboLib, FfiDbPrepare) + purego.RegisterLibFunc(&connGetError, limboLib, FfiDbGetError) + purego.RegisterLibFunc(&freeBlobFunc, limboLib, FfiFreeBlob) + purego.RegisterLibFunc(&freeStringFunc, limboLib, FfiFreeCString) + purego.RegisterLibFunc(&rowsGetColumns, limboLib, FfiRowsGetColumns) + purego.RegisterLibFunc(&rowsGetColumnName, limboLib, FfiRowsGetColumnName) + purego.RegisterLibFunc(&rowsGetValue, limboLib, FfiRowsGetValue) + purego.RegisterLibFunc(&closeRows, limboLib, FfiRowsClose) + purego.RegisterLibFunc(&rowsNext, limboLib, FfiRowsNext) + purego.RegisterLibFunc(&rowsGetError, limboLib, FfiDbGetError) + purego.RegisterLibFunc(&stmtQuery, limboLib, FfiStmtQuery) + purego.RegisterLibFunc(&stmtExec, limboLib, FfiStmtExec) + purego.RegisterLibFunc(&stmtParamCount, limboLib, FfiStmtParameterCount) + purego.RegisterLibFunc(&stmtGetError, limboLib, FfiDbGetError) + purego.RegisterLibFunc(&stmtClose, limboLib, FfiStmtClose) + }) + return loadErr +} + +func (d *limboDriver) Open(name string) (driver.Conn, error) { + d.Lock() + conn, err := openConn(name) + d.Unlock() + if err != nil { + return nil, err + } + return conn, nil +} + +type limboConn struct { + sync.Mutex + ctx uintptr +} + +func openConn(dsn string) (*limboConn, error) { + ctx := dbOpen(dsn) + if ctx == 0 { + return nil, fmt.Errorf("failed to open database for dsn=%q", dsn) + } + return &limboConn{ + sync.Mutex{}, + ctx, + }, loadErr +} + +func (c *limboConn) Close() error { + if c.ctx == 0 { + return nil + } + c.Lock() + dbClose(c.ctx) + c.Unlock() + c.ctx = 0 + return nil +} + +func (c *limboConn) getError() error { + if 
c.ctx == 0 { + return errors.New("connection closed") + } + err := connGetError(c.ctx) + if err == 0 { + return nil + } + defer freeStringFunc(err) + cpy := fmt.Sprintf("%s", GoString(err)) + return errors.New(cpy) +} + +func (c *limboConn) Prepare(query string) (driver.Stmt, error) { + if c.ctx == 0 { + return nil, errors.New("connection closed") + } + c.Lock() + defer c.Unlock() + stmtPtr := connPrepare(c.ctx, query) + if stmtPtr == 0 { + return nil, c.getError() + } + return newStmt(stmtPtr, query), nil +} + +// begin is needed to implement driver.Conn.. for now not implemented +func (c *limboConn) Begin() (driver.Tx, error) { + return nil, errors.New("transactions not implemented") +} diff --git a/bindings/go/go.mod b/bindings/go/go.mod index 589b9a0e3..a9145591b 100644 --- a/bindings/go/go.mod +++ b/bindings/go/go.mod @@ -1,8 +1,8 @@ -module limbo +module github.com/tursodatabase/limbo go 1.23.4 require ( github.com/ebitengine/purego v0.8.2 - golang.org/x/sys/windows v0.29.0 + golang.org/x/sys v0.29.0 ) diff --git a/bindings/go/limbo.go b/bindings/go/limbo.go deleted file mode 100644 index 4011fb1ac..000000000 --- a/bindings/go/limbo.go +++ /dev/null @@ -1,141 +0,0 @@ -package limbo - -import ( - "database/sql" - "database/sql/driver" - "errors" - "fmt" - "log/slog" - "os" - "runtime" - "sync" - "unsafe" - - "github.com/ebitengine/purego" - "golang.org/x/sys/windows" -) - -const limbo = "../../target/debug/lib_limbo_go" -const driverName = "limbo" - -var limboLib uintptr - -func getSystemLibrary() error { - switch runtime.GOOS { - case "darwin": - slib, err := purego.Dlopen(fmt.Sprintf("%s.dylib", limbo), purego.RTLD_LAZY) - if err != nil { - return err - } - limboLib = slib - case "linux": - slib, err := purego.Dlopen(fmt.Sprintf("%s.so", limbo), purego.RTLD_LAZY) - if err != nil { - return err - } - limboLib = slib - case "windows": - slib, err := windows.LoadLibrary(fmt.Sprintf("%s.dll", limbo)) - if err != nil { - return err - } - limboLib = slib - 
default: - panic(fmt.Errorf("GOOS=%s is not supported", runtime.GOOS)) - } - return nil -} - -func init() { - err := getSystemLibrary() - if err != nil { - slog.Error("Error opening limbo library: ", err) - os.Exit(1) - } - sql.Register(driverName, &limboDriver{}) -} - -type limboDriver struct{} - -func (d limboDriver) Open(name string) (driver.Conn, error) { - return openConn(name) -} - -func toCString(s string) uintptr { - b := append([]byte(s), 0) - return uintptr(unsafe.Pointer(&b[0])) -} - -// helper to register an FFI function in the lib_limbo_go library -func getFfiFunc(ptr interface{}, name string) { - purego.RegisterLibFunc(&ptr, limboLib, name) -} - -type limboConn struct { - ctx uintptr - sync.Mutex - prepare func(uintptr, uintptr) uintptr -} - -func newConn(ctx uintptr) *limboConn { - var prepare func(uintptr, uintptr) uintptr - getFfiFunc(&prepare, FfiDbPrepare) - return &limboConn{ - ctx, - sync.Mutex{}, - prepare, - } -} - -func openConn(dsn string) (*limboConn, error) { - var dbOpen func(uintptr) uintptr - getFfiFunc(&dbOpen, FfiDbOpen) - - cStr := toCString(dsn) - defer freeCString(cStr) - - ctx := dbOpen(cStr) - if ctx == 0 { - return nil, fmt.Errorf("failed to open database for dsn=%q", dsn) - } - return &limboConn{ctx: ctx}, nil -} - -func (c *limboConn) Close() error { - if c.ctx == 0 { - return nil - } - var dbClose func(uintptr) uintptr - getFfiFunc(&dbClose, FfiDbClose) - - dbClose(c.ctx) - c.ctx = 0 - return nil -} - -func (c *limboConn) Prepare(query string) (driver.Stmt, error) { - if c.ctx == 0 { - return nil, errors.New("connection closed") - } - if c.prepare == nil { - var dbPrepare func(uintptr, uintptr) uintptr - getFfiFunc(&dbPrepare, FfiDbPrepare) - c.prepare = dbPrepare - } - qPtr := toCString(query) - stmtPtr := c.prepare(c.ctx, qPtr) - freeCString(qPtr) - - if stmtPtr == 0 { - return nil, fmt.Errorf("prepare failed: %q", query) - } - return &limboStmt{ - ctx: stmtPtr, - sql: query, - }, nil -} - -// begin is needed to implement 
driver.Conn.. for now not implemented -func (c *limboConn) Begin() (driver.Tx, error) { - return nil, errors.New("transactions not implemented") -} diff --git a/bindings/go/limbo_test.go b/bindings/go/limbo_test.go new file mode 100644 index 000000000..31b1fc3b4 --- /dev/null +++ b/bindings/go/limbo_test.go @@ -0,0 +1,323 @@ +package limbo_test + +import ( + "database/sql" + "fmt" + "log" + "testing" + + _ "github.com/tursodatabase/limbo" +) + +var conn *sql.DB +var connErr error + +func TestMain(m *testing.M) { + conn, connErr = sql.Open("sqlite3", ":memory:") + if connErr != nil { + panic(connErr) + } + defer conn.Close() + err := createTable(conn) + if err != nil { + log.Fatalf("Error creating table: %v", err) + } + m.Run() +} + +func TestInsertData(t *testing.T) { + err := insertData(conn) + if err != nil { + t.Fatalf("Error inserting data: %v", err) + } +} + +func TestQuery(t *testing.T) { + query := "SELECT * FROM test;" + stmt, err := conn.Prepare(query) + if err != nil { + t.Fatalf("Error preparing query: %v", err) + } + defer stmt.Close() + + rows, err := stmt.Query() + if err != nil { + t.Fatalf("Error executing query: %v", err) + } + defer rows.Close() + + expectedCols := []string{"foo", "bar", "baz"} + cols, err := rows.Columns() + if err != nil { + t.Fatalf("Error getting columns: %v", err) + } + if len(cols) != len(expectedCols) { + t.Fatalf("Expected %d columns, got %d", len(expectedCols), len(cols)) + } + for i, col := range cols { + if col != expectedCols[i] { + t.Errorf("Expected column %d to be %s, got %s", i, expectedCols[i], col) + } + } + var i = 1 + for rows.Next() { + var a int + var b string + var c []byte + err = rows.Scan(&a, &b, &c) + if err != nil { + t.Fatalf("Error scanning row: %v", err) + } + if a != i || b != rowsMap[i] || !slicesAreEq(c, []byte(rowsMap[i])) { + t.Fatalf("Expected %d, %s, %s, got %d, %s, %s", i, rowsMap[i], rowsMap[i], a, b, string(c)) + } + fmt.Println("RESULTS: ", a, b, string(c)) + i++ + } + + if err = 
rows.Err(); err != nil { + t.Fatalf("Row iteration error: %v", err) + } + +} + +func TestFunctions(t *testing.T) { + insert := "INSERT INTO test (foo, bar, baz) VALUES (?, ?, zeroblob(?));" + stmt, err := conn.Prepare(insert) + if err != nil { + t.Fatalf("Error preparing statement: %v", err) + } + _, err = stmt.Exec(60, "TestFunction", 400) + if err != nil { + t.Fatalf("Error executing statment with arguments: %v", err) + } + stmt.Close() + stmt, err = conn.Prepare("SELECT baz FROM test where foo = ?") + if err != nil { + t.Fatalf("Error preparing select stmt: %v", err) + } + defer stmt.Close() + rows, err := stmt.Query(60) + if err != nil { + t.Fatalf("Error executing select stmt: %v", err) + } + defer rows.Close() + for rows.Next() { + var b []byte + err = rows.Scan(&b) + if err != nil { + t.Fatalf("Error scanning row: %v", err) + } + if len(b) != 400 { + t.Fatalf("Expected 100 bytes, got %d", len(b)) + } + } + sql := "SELECT uuid4_str();" + stmt, err = conn.Prepare(sql) + if err != nil { + t.Fatalf("Error preparing statement: %v", err) + } + defer stmt.Close() + rows, err = stmt.Query() + if err != nil { + t.Fatalf("Error executing query: %v", err) + } + defer rows.Close() + var i int + for rows.Next() { + var b string + err = rows.Scan(&b) + if err != nil { + t.Fatalf("Error scanning row: %v", err) + } + if len(b) != 36 { + t.Fatalf("Expected 36 bytes, got %d", len(b)) + } + i++ + fmt.Printf("uuid: %s\n", b) + } + if i != 1 { + t.Fatalf("Expected 1 row, got %d", i) + } + fmt.Println("zeroblob + uuid functions passed") +} + +func TestDuplicateConnection(t *testing.T) { + newConn, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Error opening new connection: %v", err) + } + err = createTable(newConn) + if err != nil { + t.Fatalf("Error creating table: %v", err) + } + err = insertData(newConn) + if err != nil { + t.Fatalf("Error inserting data: %v", err) + } + query := "SELECT * FROM test;" + rows, err := newConn.Query(query) + if err != nil { 
+ t.Fatalf("Error executing query: %v", err) + } + defer rows.Close() + for rows.Next() { + var a int + var b string + var c []byte + err = rows.Scan(&a, &b, &c) + if err != nil { + t.Fatalf("Error scanning row: %v", err) + } + fmt.Println("RESULTS: ", a, b, string(c)) + } +} + +func TestDuplicateConnection2(t *testing.T) { + newConn, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Error opening new connection: %v", err) + } + sql := "CREATE TABLE test (foo INTEGER, bar INTEGER, baz BLOB);" + newConn.Exec(sql) + sql = "INSERT INTO test (foo, bar, baz) VALUES (?, ?, uuid4());" + stmt, err := newConn.Prepare(sql) + stmt.Exec(242345, 2342434) + defer stmt.Close() + query := "SELECT * FROM test;" + rows, err := newConn.Query(query) + if err != nil { + t.Fatalf("Error executing query: %v", err) + } + defer rows.Close() + for rows.Next() { + var a int + var b int + var c []byte + err = rows.Scan(&a, &b, &c) + if err != nil { + t.Fatalf("Error scanning row: %v", err) + } + fmt.Println("RESULTS: ", a, b, string(c)) + if len(c) != 16 { + t.Fatalf("Expected 16 bytes, got %d", len(c)) + } + } +} + +func TestConnectionError(t *testing.T) { + newConn, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Error opening new connection: %v", err) + } + sql := "CREATE TABLE test (foo INTEGER, bar INTEGER, baz BLOB);" + newConn.Exec(sql) + sql = "INSERT INTO test (foo, bar, baz) VALUES (?, ?, notafunction(?));" + _, err = newConn.Prepare(sql) + if err == nil { + t.Fatalf("Expected error, got nil") + } + expectedErr := "Parse error: unknown function notafunction" + if err.Error() != expectedErr { + t.Fatalf("Error test failed, expected: %s, found: %v", expectedErr, err) + } + fmt.Println("Connection error test passed") +} + +func TestStatementError(t *testing.T) { + newConn, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Error opening new connection: %v", err) + } + sql := "CREATE TABLE test (foo INTEGER, bar INTEGER, baz 
BLOB);" + newConn.Exec(sql) + sql = "INSERT INTO test (foo, bar, baz) VALUES (?, ?, ?);" + stmt, err := newConn.Prepare(sql) + if err != nil { + t.Fatalf("Error preparing statement: %v", err) + } + _, err = stmt.Exec(1, 2) + if err == nil { + t.Fatalf("Expected error, got nil") + } + if err.Error() != "sql: expected 3 arguments, got 2" { + t.Fatalf("Unexpected : %v\n", err) + } + fmt.Println("Statement error test passed") +} + +func TestDriverRowsErrorMessages(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("failed to open database: %v", err) + } + defer db.Close() + + _, err = db.Exec("CREATE TABLE test (id INTEGER, name TEXT)") + if err != nil { + t.Fatalf("failed to create table: %v", err) + } + + _, err = db.Exec("INSERT INTO test (id, name) VALUES (?, ?)", 1, "Alice") + if err != nil { + t.Fatalf("failed to insert row: %v", err) + } + + rows, err := db.Query("SELECT id, name FROM test") + if err != nil { + t.Fatalf("failed to query table: %v", err) + } + + if !rows.Next() { + t.Fatalf("expected at least one row") + } + var id int + var name string + err = rows.Scan(&name, &id) + if err == nil { + t.Fatalf("expected error scanning wrong type: %v", err) + } + t.Log("Rows error behavior test passed") +} + +func slicesAreEq(a, b []byte) bool { + if len(a) != len(b) { + fmt.Printf("LENGTHS NOT EQUAL: %d != %d\n", len(a), len(b)) + return false + } + for i := range a { + if a[i] != b[i] { + fmt.Printf("SLICES NOT EQUAL: %v != %v\n", a, b) + return false + } + } + return true +} + +var rowsMap = map[int]string{1: "hello", 2: "world", 3: "foo", 4: "bar", 5: "baz"} + +func createTable(conn *sql.DB) error { + insert := "CREATE TABLE test (foo INT, bar TEXT, baz BLOB);" + stmt, err := conn.Prepare(insert) + if err != nil { + return err + } + defer stmt.Close() + _, err = stmt.Exec() + return err +} + +func insertData(conn *sql.DB) error { + for i := 1; i <= 5; i++ { + insert := "INSERT INTO test (foo, bar, baz) VALUES (?, ?, ?);" 
+ stmt, err := conn.Prepare(insert) + if err != nil { + return err + } + defer stmt.Close() + if _, err = stmt.Exec(i, rowsMap[i], []byte(rowsMap[i])); err != nil { + return err + } + } + return nil +} diff --git a/bindings/go/limbo_unix.go b/bindings/go/limbo_unix.go new file mode 100644 index 000000000..8ab911416 --- /dev/null +++ b/bindings/go/limbo_unix.go @@ -0,0 +1,45 @@ +//go:build linux || darwin + +package limbo + +import ( + "fmt" + "os" + "path/filepath" + "runtime" + "strings" + + "github.com/ebitengine/purego" +) + +func loadLibrary() (uintptr, error) { + var libraryName string + switch runtime.GOOS { + case "darwin": + libraryName = fmt.Sprintf("%s.dylib", libName) + case "linux": + libraryName = fmt.Sprintf("%s.so", libName) + default: + return 0, fmt.Errorf("GOOS=%s is not supported", runtime.GOOS) + } + + libPath := os.Getenv("LD_LIBRARY_PATH") + paths := strings.Split(libPath, ":") + cwd, err := os.Getwd() + if err != nil { + return 0, err + } + paths = append(paths, cwd) + + for _, path := range paths { + libPath := filepath.Join(path, libraryName) + if _, err := os.Stat(libPath); err == nil { + slib, dlerr := purego.Dlopen(libPath, purego.RTLD_NOW|purego.RTLD_GLOBAL) + if dlerr != nil { + return 0, fmt.Errorf("failed to load library at %s: %w", libPath, dlerr) + } + return slib, nil + } + } + return 0, fmt.Errorf("%s library not found in LD_LIBRARY_PATH or CWD", libName) +} diff --git a/bindings/go/limbo_windows.go b/bindings/go/limbo_windows.go new file mode 100644 index 000000000..d56381176 --- /dev/null +++ b/bindings/go/limbo_windows.go @@ -0,0 +1,36 @@ +//go:build windows + +package limbo + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "golang.org/x/sys/windows" +) + +func loadLibrary() (uintptr, error) { + libName := fmt.Sprintf("%s.dll", libName) + pathEnv := os.Getenv("PATH") + paths := strings.Split(pathEnv, ";") + + cwd, err := os.Getwd() + if err != nil { + return 0, err + } + paths = append(paths, cwd) + for _, path := 
range paths { + dllPath := filepath.Join(path, libName) + if _, err := os.Stat(dllPath); err == nil { + slib, loadErr := windows.LoadLibrary(dllPath) + if loadErr != nil { + return 0, fmt.Errorf("failed to load library at %s: %w", dllPath, loadErr) + } + return uintptr(slib), nil + } + } + + return 0, fmt.Errorf("library %s not found in PATH or CWD", libName) +} diff --git a/bindings/go/rows.go b/bindings/go/rows.go new file mode 100644 index 000000000..1d14e0d0c --- /dev/null +++ b/bindings/go/rows.go @@ -0,0 +1,121 @@ +package limbo + +import ( + "database/sql/driver" + "errors" + "fmt" + "io" + "sync" +) + +type limboRows struct { + mu sync.Mutex + ctx uintptr + columns []string + err error + closed bool +} + +func newRows(ctx uintptr) *limboRows { + return &limboRows{ + mu: sync.Mutex{}, + ctx: ctx, + columns: nil, + err: nil, + closed: false, + } +} + +func (r *limboRows) isClosed() bool { + if r.ctx == 0 || r.closed { + return true + } + return false +} + +func (r *limboRows) Columns() []string { + if r.isClosed() { + return nil + } + if r.columns == nil { + r.mu.Lock() + count := rowsGetColumns(r.ctx) + if count > 0 { + columns := make([]string, 0, count) + for i := 0; i < int(count); i++ { + cstr := rowsGetColumnName(r.ctx, int32(i)) + columns = append(columns, fmt.Sprintf("%s", GoString(cstr))) + freeCString(cstr) + } + r.mu.Unlock() + r.columns = columns + } + } + return r.columns +} + +func (r *limboRows) Close() error { + r.err = errors.New(RowsClosedErr) + if r.isClosed() { + return r.err + } + r.mu.Lock() + r.closed = true + closeRows(r.ctx) + r.ctx = 0 + r.mu.Unlock() + return nil +} + +func (r *limboRows) Err() error { + if r.err == nil { + r.mu.Lock() + defer r.mu.Unlock() + r.getError() + } + return r.err +} + +func (r *limboRows) Next(dest []driver.Value) error { + r.mu.Lock() + defer r.mu.Unlock() + if r.isClosed() { + return r.err + } + for { + status := rowsNext(r.ctx) + switch ResultCode(status) { + case Row: + for i := range dest { + valPtr 
:= rowsGetValue(r.ctx, int32(i)) + val := toGoValue(valPtr) + if val == nil { + r.getError() + } + dest[i] = val + } + return nil + case Io: + continue + case Done: + return io.EOF + default: + return r.getError() + } + } +} + +// mutex will already be locked. this is always called after FFI +func (r *limboRows) getError() error { + if r.isClosed() { + return r.err + } + err := rowsGetError(r.ctx) + if err == 0 { + return nil + } + defer freeCString(err) + cpy := fmt.Sprintf("%s", GoString(err)) + r.err = errors.New(cpy) + return r.err +} diff --git a/bindings/go/rs_src/lib.rs b/bindings/go/rs_src/lib.rs index 199ed10c0..6ef640580 100644 --- a/bindings/go/rs_src/lib.rs +++ b/bindings/go/rs_src/lib.rs @@ -26,7 +26,6 @@ pub unsafe extern "C" fn db_open(path: *const c_char) -> *mut c_void { let db = Database::open_file(io.clone(), &db_options.path.to_string()); match db { Ok(db) => { - println!("Opened database: {}", path); let conn = db.connect(); return LimboConn::new(conn, io).to_ptr(); } @@ -43,23 +42,51 @@ pub unsafe extern "C" fn db_open(path: *const c_char) -> *mut c_void { struct LimboConn { conn: Rc, io: Arc, + err: Option, } -impl LimboConn { +impl<'conn> LimboConn { fn new(conn: Rc, io: Arc) -> Self { - LimboConn { conn, io } + LimboConn { + conn, + io, + err: None, + } } + #[allow(clippy::wrong_self_convention)] fn to_ptr(self) -> *mut c_void { Box::into_raw(Box::new(self)) as *mut c_void } - fn from_ptr(ptr: *mut c_void) -> &'static mut LimboConn { + fn from_ptr(ptr: *mut c_void) -> &'conn mut LimboConn { if ptr.is_null() { panic!("Null pointer"); } unsafe { &mut *(ptr as *mut LimboConn) } } + + fn get_error(&mut self) -> *const c_char { + if let Some(err) = &self.err { + let err = format!("{}", err); + let c_str = std::ffi::CString::new(err).unwrap(); + self.err = None; + c_str.into_raw() as *const c_char + } else { + std::ptr::null() + } + } +} +/// Get the error value from the connection, if any, as a null +/// terminated string. 
The caller is responsible for freeing the +/// memory with `free_string`. +#[no_mangle] +pub extern "C" fn db_get_error(ctx: *mut c_void) -> *const c_char { + if ctx.is_null() { + return std::ptr::null(); + } + let conn = LimboConn::from_ptr(ctx); + conn.get_error() } /// Close the database connection diff --git a/bindings/go/rs_src/rows.rs b/bindings/go/rs_src/rows.rs index 456d57bdc..62b9560eb 100644 --- a/bindings/go/rs_src/rows.rs +++ b/bindings/go/rs_src/rows.rs @@ -1,22 +1,24 @@ use crate::{ - statement::LimboStatement, types::{LimboValue, ResultCode}, + LimboConn, }; -use limbo_core::{Statement, StepResult, Value}; +use limbo_core::{LimboError, Row, Statement, StepResult}; use std::ffi::{c_char, c_void}; -pub struct LimboRows<'a> { - rows: Statement, - cursor: Option>>, - stmt: Box>, +pub struct LimboRows<'conn, 'a> { + stmt: Box, + conn: &'conn mut LimboConn, + cursor: Option>, + err: Option, } -impl<'a> LimboRows<'a> { - pub fn new(rows: Statement, stmt: Box>) -> Self { +impl<'conn, 'a> LimboRows<'conn, 'a> { + pub fn new(stmt: Statement, conn: &'conn mut LimboConn) -> Self { LimboRows { - rows, - stmt, + stmt: Box::new(stmt), cursor: None, + conn, + err: None, } } @@ -25,12 +27,23 @@ impl<'a> LimboRows<'a> { Box::into_raw(Box::new(self)) as *mut c_void } - pub fn from_ptr(ptr: *mut c_void) -> &'static mut LimboRows<'a> { + pub fn from_ptr(ptr: *mut c_void) -> &'conn mut LimboRows<'conn, 'a> { if ptr.is_null() { panic!("Null pointer"); } unsafe { &mut *(ptr as *mut LimboRows) } } + + fn get_error(&mut self) -> *const c_char { + if let Some(err) = &self.err { + let err = format!("{}", err); + let c_str = std::ffi::CString::new(err).unwrap(); + self.err = None; + c_str.into_raw() as *const c_char + } else { + std::ptr::null() + } + } } #[no_mangle] @@ -40,19 +53,22 @@ pub extern "C" fn rows_next(ctx: *mut c_void) -> ResultCode { } let ctx = LimboRows::from_ptr(ctx); - match ctx.rows.step() { + match ctx.stmt.step() { Ok(StepResult::Row(row)) => { - 
ctx.cursor = Some(row.values); + ctx.cursor = Some(row); ResultCode::Row } Ok(StepResult::Done) => ResultCode::Done, Ok(StepResult::IO) => { - let _ = ctx.stmt.conn.io.run_once(); + let _ = ctx.conn.io.run_once(); ResultCode::Io } Ok(StepResult::Busy) => ResultCode::Busy, Ok(StepResult::Interrupt) => ResultCode::Interrupt, - Err(_) => ResultCode::Error, + Err(err) => { + ctx.err = Some(err); + ResultCode::Error + } } } @@ -64,9 +80,8 @@ pub extern "C" fn rows_get_value(ctx: *mut c_void, col_idx: usize) -> *const c_v let ctx = LimboRows::from_ptr(ctx); if let Some(ref cursor) = ctx.cursor { - if let Some(value) = cursor.get(col_idx) { - let val = LimboValue::from_value(value); - return val.to_ptr(); + if let Some(value) = cursor.values.get(col_idx) { + return LimboValue::from_value(value).to_ptr(); } } std::ptr::null() @@ -79,60 +94,53 @@ pub extern "C" fn free_string(s: *mut c_char) { } } +/// Function to get the number of expected ResultColumns in the prepared statement. +/// to avoid the needless complexity of returning an array of strings, this instead +/// works like rows_next/rows_get_value #[no_mangle] -pub extern "C" fn rows_get_columns( - rows_ptr: *mut c_void, - out_length: *mut usize, -) -> *mut *const c_char { - if rows_ptr.is_null() || out_length.is_null() { +pub extern "C" fn rows_get_columns(rows_ptr: *mut c_void) -> i32 { + if rows_ptr.is_null() { + return -1; + } + let rows = LimboRows::from_ptr(rows_ptr); + rows.stmt.columns().len() as i32 +} + +/// Returns a pointer to a string with the name of the column at the given index. 
+/// The caller is responsible for freeing the memory, it should be copied on the Go side +/// immediately and 'free_string' called +#[no_mangle] +pub extern "C" fn rows_get_column_name(rows_ptr: *mut c_void, idx: i32) -> *const c_char { + if rows_ptr.is_null() { return std::ptr::null_mut(); } let rows = LimboRows::from_ptr(rows_ptr); - let c_strings: Vec = rows - .rows - .columns() - .iter() - .map(|name| std::ffi::CString::new(name.as_str()).unwrap()) - .collect(); - - let c_ptrs: Vec<*const c_char> = c_strings.iter().map(|s| s.as_ptr()).collect(); - unsafe { - *out_length = c_ptrs.len(); + if idx < 0 || idx as usize >= rows.stmt.columns().len() { + return std::ptr::null_mut(); } - let ptr = c_ptrs.as_ptr(); - std::mem::forget(c_strings); - std::mem::forget(c_ptrs); - ptr as *mut *const c_char + let name = &rows.stmt.columns()[idx as usize]; + let cstr = std::ffi::CString::new(name.as_bytes()).expect("Failed to create CString"); + cstr.into_raw() as *const c_char } #[no_mangle] -pub extern "C" fn rows_close(rows_ptr: *mut c_void) { - if !rows_ptr.is_null() { - let _ = unsafe { Box::from_raw(rows_ptr as *mut LimboRows) }; +pub extern "C" fn rows_get_error(ctx: *mut c_void) -> *const c_char { + if ctx.is_null() { + return std::ptr::null(); } + let ctx = LimboRows::from_ptr(ctx); + ctx.get_error() } #[no_mangle] -pub extern "C" fn free_columns(columns: *mut *const c_char) { - if columns.is_null() { - return; +pub extern "C" fn rows_close(ctx: *mut c_void) { + if !ctx.is_null() { + let rows = LimboRows::from_ptr(ctx); + rows.stmt.reset(); + rows.cursor = None; + rows.err = None; } unsafe { - let mut idx = 0; - while !(*columns.add(idx)).is_null() { - let _ = std::ffi::CString::from_raw(*columns.add(idx) as *mut c_char); - idx += 1; - } - let _ = Box::from_raw(columns); - } -} - -#[no_mangle] -pub extern "C" fn free_rows(rows: *mut c_void) { - if rows.is_null() { - return; - } - unsafe { - let _ = Box::from_raw(rows as *mut Statement); + let _ = 
Box::from_raw(ctx.cast::()); } } diff --git a/bindings/go/rs_src/statement.rs b/bindings/go/rs_src/statement.rs index 82fb55648..897c246fa 100644 --- a/bindings/go/rs_src/statement.rs +++ b/bindings/go/rs_src/statement.rs @@ -1,7 +1,7 @@ use crate::rows::LimboRows; use crate::types::{AllocPool, LimboValue, ResultCode}; use crate::LimboConn; -use limbo_core::{Statement, StepResult}; +use limbo_core::{LimboError, Statement, StepResult}; use std::ffi::{c_char, c_void}; use std::num::NonZero; @@ -13,11 +13,13 @@ pub extern "C" fn db_prepare(ctx: *mut c_void, query: *const c_char) -> *mut c_v let query_str = unsafe { std::ffi::CStr::from_ptr(query) }.to_str().unwrap(); let db = LimboConn::from_ptr(ctx); - - let stmt = db.conn.prepare(query_str.to_string()); + let stmt = db.conn.prepare(query_str); match stmt { - Ok(stmt) => LimboStatement::new(stmt, db).to_ptr(), - Err(_) => std::ptr::null_mut(), + Ok(stmt) => LimboStatement::new(Some(stmt), db).to_ptr(), + Err(err) => { + db.err = Some(err); + std::ptr::null_mut() + } } } @@ -38,21 +40,25 @@ pub extern "C" fn stmt_execute( } else { &[] }; + let mut pool = AllocPool::new(); + let Some(statement) = stmt.statement.as_mut() else { + return ResultCode::Error; + }; for (i, arg) in args.iter().enumerate() { - let val = arg.to_value(&mut stmt.pool); - stmt.statement.bind_at(NonZero::new(i + 1).unwrap(), val); + let val = arg.to_value(&mut pool); + statement.bind_at(NonZero::new(i + 1).unwrap(), val); } loop { - match stmt.statement.step() { + match statement.step() { Ok(StepResult::Row(_)) => { // unexpected row during execution, error out. 
return ResultCode::Error; } Ok(StepResult::Done) => { - stmt.conn.conn.total_changes(); + let total_changes = stmt.conn.conn.total_changes(); if !changes.is_null() { unsafe { - *changes = stmt.conn.conn.total_changes(); + *changes = total_changes; } } return ResultCode::Done; @@ -66,7 +72,8 @@ pub extern "C" fn stmt_execute( Ok(StepResult::Interrupt) => { return ResultCode::Interrupt; } - Err(_) => { + Err(err) => { + stmt.conn.err = Some(err); return ResultCode::Error; } } @@ -79,7 +86,11 @@ pub extern "C" fn stmt_parameter_count(ctx: *mut c_void) -> i32 { return -1; } let stmt = LimboStatement::from_ptr(ctx); - stmt.statement.parameters_count() as i32 + let Some(statement) = stmt.statement.as_ref() else { + stmt.err = Some(LimboError::InternalError("Statement is closed".to_string())); + return -1; + }; + statement.parameters_count() as i32 } #[no_mangle] @@ -97,31 +108,50 @@ pub extern "C" fn stmt_query( } else { &[] }; + let mut pool = AllocPool::new(); + let Some(mut statement) = stmt.statement.take() else { + return std::ptr::null_mut(); + }; for (i, arg) in args.iter().enumerate() { - let val = arg.to_value(&mut stmt.pool); - stmt.statement.bind_at(NonZero::new(i + 1).unwrap(), val); - } - match stmt.statement.query() { - Ok(rows) => { - let stmt = unsafe { Box::from_raw(stmt) }; - LimboRows::new(rows, stmt).to_ptr() - } - Err(_) => std::ptr::null_mut(), + let val = arg.to_value(&mut pool); + statement.bind_at(NonZero::new(i + 1).unwrap(), val); } + // ownership of the statement is transfered to the LimboRows object. 
+ LimboRows::new(statement, stmt.conn).to_ptr() } pub struct LimboStatement<'conn> { - pub statement: Statement, + /// If 'query' is ran on the statement, ownership is transfered to the LimboRows object + pub statement: Option, pub conn: &'conn mut LimboConn, - pub pool: AllocPool, + pub err: Option, +} + +#[no_mangle] +pub extern "C" fn stmt_close(ctx: *mut c_void) -> ResultCode { + if !ctx.is_null() { + let stmt = unsafe { Box::from_raw(ctx as *mut LimboStatement) }; + drop(stmt); + return ResultCode::Ok; + } + ResultCode::Invalid +} + +#[no_mangle] +pub extern "C" fn stmt_get_error(ctx: *mut c_void) -> *const c_char { + if ctx.is_null() { + return std::ptr::null(); + } + let stmt = LimboStatement::from_ptr(ctx); + stmt.get_error() } impl<'conn> LimboStatement<'conn> { - pub fn new(statement: Statement, conn: &'conn mut LimboConn) -> Self { + pub fn new(statement: Option, conn: &'conn mut LimboConn) -> Self { LimboStatement { statement, conn, - pool: AllocPool::new(), + err: None, } } @@ -130,10 +160,21 @@ impl<'conn> LimboStatement<'conn> { Box::into_raw(Box::new(self)) as *mut c_void } - fn from_ptr(ptr: *mut c_void) -> &'static mut LimboStatement<'conn> { + fn from_ptr(ptr: *mut c_void) -> &'conn mut LimboStatement<'conn> { if ptr.is_null() { panic!("Null pointer"); } unsafe { &mut *(ptr as *mut LimboStatement) } } + + fn get_error(&mut self) -> *const c_char { + if let Some(err) = &self.err { + let err = format!("{}", err); + let c_str = std::ffi::CString::new(err).unwrap(); + self.err = None; + c_str.into_raw() as *const c_char + } else { + std::ptr::null() + } + } } diff --git a/bindings/go/rs_src/types.rs b/bindings/go/rs_src/types.rs index 851212c65..11c9b251f 100644 --- a/bindings/go/rs_src/types.rs +++ b/bindings/go/rs_src/types.rs @@ -14,6 +14,9 @@ pub enum ResultCode { ReadOnly = 8, NoData = 9, Done = 10, + SyntaxErr = 11, + ConstraintViolation = 12, + NoSuchEntity = 13, } #[repr(C)] @@ -27,34 +30,29 @@ pub enum ValueType { #[repr(C)] pub struct 
LimboValue { - pub value_type: ValueType, - pub value: ValueUnion, + value_type: ValueType, + value: ValueUnion, } #[repr(C)] -pub union ValueUnion { - pub int_val: i64, - pub real_val: f64, - pub text_ptr: *const c_char, - pub blob_ptr: *const c_void, +union ValueUnion { + int_val: i64, + real_val: f64, + text_ptr: *const c_char, + blob_ptr: *const c_void, } #[repr(C)] -pub struct Blob { - pub data: *const u8, - pub len: usize, -} - -impl Blob { - pub fn to_ptr(&self) -> *const c_void { - self as *const Blob as *const c_void - } +struct Blob { + data: *const u8, + len: i64, } pub struct AllocPool { strings: Vec, blobs: Vec>, } + impl AllocPool { pub fn new() -> Self { AllocPool { @@ -82,21 +80,23 @@ pub extern "C" fn free_blob(blob_ptr: *mut c_void) { let _ = Box::from_raw(blob_ptr as *mut Blob); } } + #[allow(dead_code)] impl ValueUnion { fn from_str(s: &str) -> Self { + let cstr = std::ffi::CString::new(s).expect("Failed to create CString"); ValueUnion { - text_ptr: s.as_ptr() as *const c_char, + text_ptr: cstr.into_raw(), } } fn from_bytes(b: &[u8]) -> Self { + let blob = Box::new(Blob { + data: b.as_ptr(), + len: b.len() as i64, + }); ValueUnion { - blob_ptr: Blob { - data: b.as_ptr(), - len: b.len(), - } - .to_ptr(), + blob_ptr: Box::into_raw(blob) as *const c_void, } } @@ -121,18 +121,25 @@ impl ValueUnion { } pub fn to_str(&self) -> &str { - unsafe { std::ffi::CStr::from_ptr(self.text_ptr).to_str().unwrap() } + unsafe { + if self.text_ptr.is_null() { + return ""; + } + std::ffi::CStr::from_ptr(self.text_ptr) + .to_str() + .unwrap_or("") + } } pub fn to_bytes(&self) -> &[u8] { let blob = unsafe { self.blob_ptr as *const Blob }; let blob = unsafe { &*blob }; - unsafe { std::slice::from_raw_parts(blob.data, blob.len) } + unsafe { std::slice::from_raw_parts(blob.data, blob.len as usize) } } } impl LimboValue { - pub fn new(value_type: ValueType, value: ValueUnion) -> Self { + fn new(value_type: ValueType, value: ValueUnion) -> Self { LimboValue { value_type, 
value } } @@ -157,16 +164,30 @@ impl LimboValue { } } + // The values we get from Go need to be temporarily owned by the statement until they are bound + // then they can be cleaned up immediately afterwards pub fn to_value<'pool>(&self, pool: &'pool mut AllocPool) -> limbo_core::Value<'pool> { match self.value_type { - ValueType::Integer => limbo_core::Value::Integer(unsafe { self.value.int_val }), - ValueType::Real => limbo_core::Value::Float(unsafe { self.value.real_val }), + ValueType::Integer => { + if unsafe { self.value.int_val == 0 } { + return limbo_core::Value::Null; + } + limbo_core::Value::Integer(unsafe { self.value.int_val }) + } + ValueType::Real => { + if unsafe { self.value.real_val == 0.0 } { + return limbo_core::Value::Null; + } + limbo_core::Value::Float(unsafe { self.value.real_val }) + } ValueType::Text => { + if unsafe { self.value.text_ptr.is_null() } { + return limbo_core::Value::Null; + } let cstr = unsafe { std::ffi::CStr::from_ptr(self.value.text_ptr) }; match cstr.to_str() { Ok(utf8_str) => { let owned = utf8_str.to_owned(); - // statement needs to own these strings, will free when closed let borrowed = pool.add_string(owned); limbo_core::Value::Text(borrowed) } @@ -174,15 +195,12 @@ impl LimboValue { } } ValueType::Blob => { - let blob_ptr = unsafe { self.value.blob_ptr as *const Blob }; - if blob_ptr.is_null() { - limbo_core::Value::Null - } else { - let blob = unsafe { &*blob_ptr }; - let data = unsafe { std::slice::from_raw_parts(blob.data, blob.len) }; - let borrowed = pool.add_blob(data.to_vec()); - limbo_core::Value::Blob(borrowed) + if unsafe { self.value.blob_ptr.is_null() } { + return limbo_core::Value::Null; } + let bytes = self.value.to_bytes(); + let borrowed = pool.add_blob(bytes.to_vec()); + limbo_core::Value::Blob(borrowed) } ValueType::Null => limbo_core::Value::Null, } diff --git a/bindings/go/stmt.go b/bindings/go/stmt.go index 30bceefac..9e045175e 100644 --- a/bindings/go/stmt.go +++ b/bindings/go/stmt.go @@ -5,40 
+5,53 @@ import ( "database/sql/driver" "errors" "fmt" - "io" + "sync" "unsafe" ) -// only construct limboStmt with initStmt function to ensure proper initialization type limboStmt struct { - ctx uintptr - sql string - query stmtQueryFn - execute stmtExecuteFn - getParamCount func(uintptr) int32 + mu sync.Mutex + ctx uintptr + sql string + err error } -// Initialize/register the FFI function pointers for the statement methods -func initStmt(ctx uintptr, sql string) *limboStmt { - var query stmtQueryFn - var execute stmtExecuteFn - var getParamCount func(uintptr) int32 - methods := []ExtFunc{{query, FfiStmtQuery}, {execute, FfiStmtExec}, {getParamCount, FfiStmtParameterCount}} - for i := range methods { - methods[i].initFunc() - } +func newStmt(ctx uintptr, sql string) *limboStmt { return &limboStmt{ ctx: uintptr(ctx), sql: sql, + err: nil, } } -func (st *limboStmt) NumInput() int { - return int(st.getParamCount(st.ctx)) +func (ls *limboStmt) NumInput() int { + ls.mu.Lock() + defer ls.mu.Unlock() + res := int(stmtParamCount(ls.ctx)) + if res < 0 { + // set the error from rust + _ = ls.getError() + } + return res } -func (st *limboStmt) Exec(args []driver.Value) (driver.Result, error) { - argArray, err := buildArgs(args) +func (ls *limboStmt) Close() error { + ls.mu.Lock() + defer ls.mu.Unlock() + if ls.ctx == 0 { + return nil + } + res := stmtClose(ls.ctx) + ls.ctx = 0 + if ResultCode(res) != Ok { + return fmt.Errorf("error closing statement: %s", ResultCode(res).String()) + } + return nil +} + +func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) { + argArray, cleanup, err := buildArgs(args) + defer cleanup() if err != nil { return nil, err } @@ -48,9 +61,11 @@ func (st *limboStmt) Exec(args []driver.Value) (driver.Result, error) { argPtr = uintptr(unsafe.Pointer(&argArray[0])) } var changes uint64 - rc := st.execute(st.ctx, argPtr, argCount, uintptr(unsafe.Pointer(&changes))) + ls.mu.Lock() + defer ls.mu.Unlock() + rc := stmtExec(ls.ctx, argPtr, 
argCount, uintptr(unsafe.Pointer(&changes))) switch ResultCode(rc) { - case Ok: + case Ok, Done: return driver.RowsAffected(changes), nil case Error: return nil, errors.New("error executing statement") @@ -61,134 +76,101 @@ func (st *limboStmt) Exec(args []driver.Value) (driver.Result, error) { case Invalid: return nil, errors.New("invalid statement") default: - return nil, fmt.Errorf("unexpected status: %d", rc) + return nil, ls.getError() } } -func (st *limboStmt) Query(args []driver.Value) (driver.Rows, error) { - queryArgs, err := buildArgs(args) +func (ls *limboStmt) Query(args []driver.Value) (driver.Rows, error) { + queryArgs, cleanup, err := buildArgs(args) + defer cleanup() if err != nil { return nil, err } - rowsPtr := st.query(st.ctx, uintptr(unsafe.Pointer(&queryArgs[0])), uint64(len(queryArgs))) - if rowsPtr == 0 { - return nil, fmt.Errorf("query failed for: %q", st.sql) + argPtr := uintptr(0) + if len(args) > 0 { + argPtr = uintptr(unsafe.Pointer(&queryArgs[0])) } - return initRows(rowsPtr), nil + ls.mu.Lock() + defer ls.mu.Unlock() + rowsPtr := stmtQuery(ls.ctx, argPtr, uint64(len(queryArgs))) + if rowsPtr == 0 { + return nil, ls.getError() + } + return newRows(rowsPtr), nil } -func (ts *limboStmt) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { +func (ls *limboStmt) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { stripped := namedValueToValue(args) - argArray, err := getArgsPtr(stripped) + argArray, cleanup, err := getArgsPtr(stripped) + defer cleanup() if err != nil { return nil, err } - var changes uintptr - res := ts.execute(ts.ctx, argArray, uint64(len(args)), changes) - switch ResultCode(res) { - case Ok: - return driver.RowsAffected(changes), nil - case Error: - return nil, errors.New("error executing statement") - case Busy: - return nil, errors.New("busy") - case Interrupt: - return nil, errors.New("interrupted") + ls.mu.Lock() + select { + 
case <-ctx.Done(): + ls.mu.Unlock() + return nil, ctx.Err() default: - return nil, fmt.Errorf("unexpected status: %d", res) + var changes uint64 + defer ls.mu.Unlock() + res := stmtExec(ls.ctx, argArray, uint64(len(args)), uintptr(unsafe.Pointer(&changes))) + switch ResultCode(res) { + case Ok, Done: + changes := uint64(changes) + return driver.RowsAffected(changes), nil + case Busy: + return nil, errors.New("Database is Busy") + case Interrupt: + return nil, errors.New("Interrupted") + default: + return nil, ls.getError() + } } } -func (st *limboStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { - queryArgs, err := buildNamedArgs(args) +func (ls *limboStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + queryArgs, allocs, err := buildNamedArgs(args) + defer allocs() if err != nil { return nil, err } - rowsPtr := st.query(st.ctx, uintptr(unsafe.Pointer(&queryArgs[0])), uint64(len(queryArgs))) - if rowsPtr == 0 { - return nil, fmt.Errorf("query failed for: %q", st.sql) + argsPtr := uintptr(0) + if len(queryArgs) > 0 { + argsPtr = uintptr(unsafe.Pointer(&queryArgs[0])) } - return initRows(rowsPtr), nil -} - -// only construct limboRows with initRows function to ensure proper initialization -type limboRows struct { - ctx uintptr - columns []string - closed bool - getCols func(uintptr, *uint) uintptr - next func(uintptr) uintptr - getValue func(uintptr, int32) uintptr - closeRows func(uintptr) uintptr - freeCols func(uintptr) uintptr -} - -// Initialize/register the FFI function pointers for the rows methods -// DO NOT construct 'limboRows' without this function -func initRows(ctx uintptr) *limboRows { - var getCols func(uintptr, *uint) uintptr - var getValue func(uintptr, int32) uintptr - var closeRows func(uintptr) uintptr - var freeCols func(uintptr) uintptr - var next func(uintptr) uintptr - methods := []ExtFunc{ - {getCols, FfiRowsGetColumns}, - {getValue, FfiRowsGetValue}, - {closeRows, 
FfiRowsClose}, - {freeCols, FfiFreeColumns}, - {next, FfiRowsNext}} - for i := range methods { - methods[i].initFunc() - } - - return &limboRows{ - ctx: ctx, - getCols: getCols, - getValue: getValue, - closeRows: closeRows, - freeCols: freeCols, - next: next, - } -} - -func (r *limboRows) Columns() []string { - if r.columns == nil { - var columnCount uint - colArrayPtr := r.getCols(r.ctx, &columnCount) - if colArrayPtr != 0 && columnCount > 0 { - r.columns = cArrayToGoStrings(colArrayPtr, columnCount) - if r.freeCols == nil { - getFfiFunc(&r.freeCols, FfiFreeColumns) - } - defer r.freeCols(colArrayPtr) - } - } - return r.columns -} - -func (r *limboRows) Close() error { - if r.closed { - return nil - } - r.closed = true - r.closeRows(r.ctx) - r.ctx = 0 - return nil -} - -func (r *limboRows) Next(dest []driver.Value) error { - status := r.next(r.ctx) - switch ResultCode(status) { - case Row: - for i := range dest { - valPtr := r.getValue(r.ctx, int32(i)) - val := toGoValue(valPtr) - dest[i] = val - } - return nil - case Done: - return io.EOF + ls.mu.Lock() + select { + case <-ctx.Done(): + ls.mu.Unlock() + return nil, ctx.Err() default: - return fmt.Errorf("unexpected status: %d", status) + defer ls.mu.Unlock() + rowsPtr := stmtQuery(ls.ctx, argsPtr, uint64(len(queryArgs))) + if rowsPtr == 0 { + return nil, ls.getError() + } + return newRows(rowsPtr), nil } } + +func (ls *limboStmt) Err() error { + if ls.err == nil { + ls.mu.Lock() + defer ls.mu.Unlock() + ls.getError() + } + return ls.err +} + +// mutex should always be locked when calling - always called after FFI +func (ls *limboStmt) getError() error { + err := stmtGetError(ls.ctx) + if err == 0 { + return nil + } + defer freeCString(err) + cpy := fmt.Sprintf("%s", GoString(err)) + ls.err = errors.New(cpy) + return ls.err +} diff --git a/bindings/go/types.go b/bindings/go/types.go index c27832f43..608c41376 100644 --- a/bindings/go/types.go +++ b/bindings/go/types.go @@ -3,39 +3,87 @@ package limbo import ( 
"database/sql/driver" "fmt" + "runtime" "unsafe" ) -type ResultCode int +type ResultCode int32 const ( - Error ResultCode = -1 - Ok ResultCode = 0 - Row ResultCode = 1 - Busy ResultCode = 2 - Io ResultCode = 3 - Interrupt ResultCode = 4 - Invalid ResultCode = 5 - Null ResultCode = 6 - NoMem ResultCode = 7 - ReadOnly ResultCode = 8 - NoData ResultCode = 9 - Done ResultCode = 10 + Error ResultCode = -1 + Ok ResultCode = 0 + Row ResultCode = 1 + Busy ResultCode = 2 + Io ResultCode = 3 + Interrupt ResultCode = 4 + Invalid ResultCode = 5 + Null ResultCode = 6 + NoMem ResultCode = 7 + ReadOnly ResultCode = 8 + NoData ResultCode = 9 + Done ResultCode = 10 + SyntaxErr ResultCode = 11 + ConstraintViolation ResultCode = 12 + NoSuchEntity ResultCode = 13 ) +func (rc ResultCode) String() string { + switch rc { + case Error: + return "Error" + case Ok: + return "Ok" + case Row: + return "Row" + case Busy: + return "Busy" + case Io: + return "Io" + case Interrupt: + return "Query was interrupted" + case Invalid: + return "Invalid" + case Null: + return "Null" + case NoMem: + return "Out of memory" + case ReadOnly: + return "Read Only" + case NoData: + return "No Data" + case Done: + return "Done" + case SyntaxErr: + return "Syntax Error" + case ConstraintViolation: + return "Constraint Violation" + case NoSuchEntity: + return "No such entity" + default: + return "Unknown response code" + } +} + const ( - FfiDbOpen string = "db_open" - FfiDbClose string = "db_close" - FfiDbPrepare string = "db_prepare" - FfiStmtExec string = "stmt_execute" - FfiStmtQuery string = "stmt_query" - FfiStmtParameterCount string = "stmt_parameter_count" - FfiRowsClose string = "rows_close" - FfiRowsGetColumns string = "rows_get_columns" - FfiRowsNext string = "rows_next" - FfiRowsGetValue string = "rows_get_value" - FfiFreeColumns string = "free_columns" - FfiFreeCString string = "free_string" + driverName = "sqlite3" + libName = "lib_limbo_go" + RowsClosedErr = "sql: Rows closed" + FfiDbOpen = 
"db_open" + FfiDbClose = "db_close" + FfiDbPrepare = "db_prepare" + FfiDbGetError = "db_get_error" + FfiStmtExec = "stmt_execute" + FfiStmtQuery = "stmt_query" + FfiStmtParameterCount = "stmt_parameter_count" + FfiStmtClose = "stmt_close" + FfiRowsClose = "rows_close" + FfiRowsGetColumns = "rows_get_columns" + FfiRowsGetColumnName = "rows_get_column_name" + FfiRowsNext = "rows_next" + FfiRowsGetValue = "rows_get_value" + FfiFreeColumns = "free_columns" + FfiFreeCString = "free_string" + FfiFreeBlob = "free_blob" ) // convert a namedValue slice into normal values until named parameters are supported @@ -47,47 +95,56 @@ func namedValueToValue(named []driver.NamedValue) []driver.Value { return out } -func buildNamedArgs(named []driver.NamedValue) ([]limboValue, error) { - args := make([]driver.Value, len(named)) - for i, nv := range named { - args[i] = nv.Value - } +func buildNamedArgs(named []driver.NamedValue) ([]limboValue, func(), error) { + args := namedValueToValue(named) return buildArgs(args) } -type ExtFunc struct { - funcPtr interface{} - funcName string -} - -func (ef *ExtFunc) initFunc() { - getFfiFunc(&ef.funcPtr, ef.funcName) -} - -type valueType int +type valueType int32 const ( - intVal valueType = iota - textVal - blobVal - realVal - nullVal + intVal valueType = 0 + textVal valueType = 1 + blobVal valueType = 2 + realVal valueType = 3 + nullVal valueType = 4 ) +func (vt valueType) String() string { + switch vt { + case intVal: + return "int" + case textVal: + return "text" + case blobVal: + return "blob" + case realVal: + return "real" + case nullVal: + return "null" + default: + return "unknown" + } +} + // struct to pass Go values over FFI type limboValue struct { Type valueType + _ [4]byte Value [8]byte } // struct to pass byte slices over FFI type Blob struct { Data uintptr - Len uint + Len int64 } // convert a limboValue to a native Go value func toGoValue(valPtr uintptr) interface{} { + if valPtr == 0 { + return nil + } val := 
(*limboValue)(unsafe.Pointer(valPtr)) switch val.Type { case intVal: @@ -96,9 +153,11 @@ func toGoValue(valPtr uintptr) interface{} { return *(*float64)(unsafe.Pointer(&val.Value)) case textVal: textPtr := *(*uintptr)(unsafe.Pointer(&val.Value)) + defer freeCString(textPtr) return GoString(textPtr) case blobVal: blobPtr := *(*uintptr)(unsafe.Pointer(&val.Value)) + defer freeBlob(blobPtr) return toGoBlob(blobPtr) case nullVal: return nil @@ -107,15 +166,15 @@ func toGoValue(valPtr uintptr) interface{} { } } -func getArgsPtr(args []driver.Value) (uintptr, error) { +func getArgsPtr(args []driver.Value) (uintptr, func(), error) { if len(args) == 0 { - return 0, nil + return 0, nil, nil } - argSlice, err := buildArgs(args) + argSlice, allocs, err := buildArgs(args) if err != nil { - return 0, err + return 0, allocs, err } - return uintptr(unsafe.Pointer(&argSlice[0])), nil + return uintptr(unsafe.Pointer(&argSlice[0])), allocs, nil } // convert a byte slice to a Blob type that can be sent over FFI @@ -123,11 +182,10 @@ func makeBlob(b []byte) *Blob { if len(b) == 0 { return nil } - blob := &Blob{ + return &Blob{ Data: uintptr(unsafe.Pointer(&b[0])), - Len: uint(len(b)), + Len: int64(len(b)), } - return blob } // converts a blob received via FFI to a native Go byte slice @@ -136,85 +194,64 @@ func toGoBlob(blobPtr uintptr) []byte { return nil } blob := (*Blob)(unsafe.Pointer(blobPtr)) - return unsafe.Slice((*byte)(unsafe.Pointer(blob.Data)), blob.Len) -} - -var freeString func(*byte) - -// free a C style string allocated via FFI -func freeCString(cstr uintptr) { - if cstr == 0 { - return - } - if freeString == nil { - getFfiFunc(&freeString, FfiFreeCString) - } - freeString((*byte)(unsafe.Pointer(cstr))) -} - -func cArrayToGoStrings(arrayPtr uintptr, length uint) []string { - if arrayPtr == 0 || length == 0 { + if blob.Data == 0 || blob.Len == 0 { return nil } + data := unsafe.Slice((*byte)(unsafe.Pointer(blob.Data)), blob.Len) + copied := make([]byte, len(data)) + 
copy(copied, data) + return copied +} - ptrSlice := unsafe.Slice( - (**byte)(unsafe.Pointer(arrayPtr)), - length, - ) - - out := make([]string, 0, length) - for _, cstr := range ptrSlice { - out = append(out, GoString(uintptr(unsafe.Pointer(cstr)))) +func freeBlob(blobPtr uintptr) { + if blobPtr == 0 { + return } - return out + freeBlobFunc(blobPtr) +} + +func freeCString(cstrPtr uintptr) { + if cstrPtr == 0 { + return + } + freeStringFunc(cstrPtr) } // convert a Go slice of driver.Value to a slice of limboValue that can be sent over FFI -func buildArgs(args []driver.Value) ([]limboValue, error) { +// for Blob types, we have to pin them so they are not garbage collected before they can be copied +// into a buffer on the Rust side, so we return a function to unpin them that can be deferred after this call +func buildArgs(args []driver.Value) ([]limboValue, func(), error) { + pinner := new(runtime.Pinner) argSlice := make([]limboValue, len(args)) - for i, v := range args { + limboVal := limboValue{} switch val := v.(type) { case nil: - argSlice[i].Type = nullVal - + limboVal.Type = nullVal case int64: - argSlice[i].Type = intVal - storeInt64(&argSlice[i].Value, val) - + limboVal.Type = intVal + limboVal.Value = *(*[8]byte)(unsafe.Pointer(&val)) case float64: - argSlice[i].Type = realVal - storeFloat64(&argSlice[i].Value, val) + limboVal.Type = realVal + limboVal.Value = *(*[8]byte)(unsafe.Pointer(&val)) case string: - argSlice[i].Type = textVal + limboVal.Type = textVal cstr := CString(val) - storePointer(&argSlice[i].Value, cstr) + pinner.Pin(cstr) + *(*uintptr)(unsafe.Pointer(&limboVal.Value)) = uintptr(unsafe.Pointer(cstr)) case []byte: - argSlice[i].Type = blobVal + limboVal.Type = blobVal blob := makeBlob(val) - *(*uintptr)(unsafe.Pointer(&argSlice[i].Value)) = uintptr(unsafe.Pointer(blob)) + pinner.Pin(blob) + *(*uintptr)(unsafe.Pointer(&limboVal.Value)) = uintptr(unsafe.Pointer(blob)) default: - return nil, fmt.Errorf("unsupported type: %T", v) + return nil, 
pinner.Unpin, fmt.Errorf("unsupported type: %T", v) } + argSlice[i] = limboVal } - return argSlice, nil + return argSlice, pinner.Unpin, nil } -func storeInt64(data *[8]byte, val int64) { - *(*int64)(unsafe.Pointer(data)) = val -} - -func storeFloat64(data *[8]byte, val float64) { - *(*float64)(unsafe.Pointer(data)) = val -} - -func storePointer(data *[8]byte, ptr *byte) { - *(*uintptr)(unsafe.Pointer(data)) = uintptr(unsafe.Pointer(ptr)) -} - -type stmtExecuteFn func(stmtPtr uintptr, argsPtr uintptr, argCount uint64, changes uintptr) int32 -type stmtQueryFn func(stmtPtr uintptr, argsPtr uintptr, argCount uint64) uintptr - /* Credit below (Apache2 License) to: https://github.com/ebitengine/purego/blob/main/internal/strings/strings.go */ diff --git a/bindings/java/.gitignore b/bindings/java/.gitignore index 581bf51fb..969ce0fcb 100644 --- a/bindings/java/.gitignore +++ b/bindings/java/.gitignore @@ -37,3 +37,6 @@ bin/ ### Mac OS ### .DS_Store + +### limbo builds ### +libs diff --git a/bindings/java/Makefile b/bindings/java/Makefile index 4bcbee2c1..fbcaea66b 100644 --- a/bindings/java/Makefile +++ b/bindings/java/Makefile @@ -1,4 +1,37 @@ -.PHONY: java_lint test build_test +RELEASE_DIR := libs +TEMP_DIR := temp + +CARGO_BUILD := cargo build --release + +MACOS_X86_DIR := $(RELEASE_DIR)/macos_x86 +MACOS_ARM64_DIR := $(RELEASE_DIR)/macos_arm64 +WINDOWS_DIR := $(RELEASE_DIR)/windows + +.PHONY: libs macos_x86 macos_arm64 windows lint lint_apply test build_test + +libs: macos_x86 macos_arm64 windows + +macos_x86: + @echo "Building release version for macOS x86_64..." + @mkdir -p $(TEMP_DIR) $(MACOS_X86_DIR) + @CARGO_TARGET_DIR=$(TEMP_DIR) $(CARGO_BUILD) --target x86_64-apple-darwin + @cp $(TEMP_DIR)/x86_64-apple-darwin/release/lib_limbo_java.dylib $(MACOS_X86_DIR) + @rm -rf $(TEMP_DIR) + +macos_arm64: + @echo "Building release version for macOS ARM64..." 
+ @mkdir -p $(TEMP_DIR) $(MACOS_ARM64_DIR) + @CARGO_TARGET_DIR=$(TEMP_DIR) $(CARGO_BUILD) --target aarch64-apple-darwin + @cp $(TEMP_DIR)/aarch64-apple-darwin/release/lib_limbo_java.dylib $(MACOS_ARM64_DIR) + @rm -rf $(TEMP_DIR) + +# windows generates file with name `_limbo_java.dll` unlike others, so we manually add prefix +windows: + @echo "Building release version for Windows..." + @mkdir -p $(TEMP_DIR) $(WINDOWS_DIR) + @CARGO_TARGET_DIR=$(TEMP_DIR) $(CARGO_BUILD) --target x86_64-pc-windows-gnu + @cp $(TEMP_DIR)/x86_64-pc-windows-gnu/release/_limbo_java.dll $(WINDOWS_DIR)/lib_limbo_java.dll + @rm -rf $(TEMP_DIR) lint: ./gradlew spotlessCheck @@ -7,7 +40,10 @@ lint_apply: ./gradlew spotlessApply test: lint build_test - ./gradlew test --info + ./gradlew test build_test: CARGO_TARGET_DIR=src/test/resources/limbo cargo build + +publish_local: + ./gradlew clean publishToMavenLocal diff --git a/bindings/java/README.md b/bindings/java/README.md new file mode 100644 index 000000000..051833325 --- /dev/null +++ b/bindings/java/README.md @@ -0,0 +1,66 @@ +# Limbo JDBC Driver + +The Limbo JDBC driver is a library for accessing and creating Limbo database files using Java. + +## Project Status + +The project is actively developed. Feel free to open issues and contribute. + +To view related works, visit this [issue](https://github.com/tursodatabase/limbo/issues/615). + +## How to use + +Currently, we have not published to the maven central. Instead, you can locally build the jar and deploy it to maven local to use it. 
+ +### Build jar and publish to maven local +```shell +$ cd bindings/java + +# Please select the appropriate target platform, currently supports `macos_x86`, `macos_arm64`, `windows` +$ make macos_x86 + +# deploy to maven local +$ make publish_local +``` + +Now you can use the dependency as follows: +```kotlin +dependencies { + implementation("org.github.tursodatabase:limbo:0.0.1-SNAPSHOT") +} +``` + +## Development + +### How to Run Tests + +To run tests, use the following command: + +```shell +$ make test +``` + +### Code Formatting + +To unify Java's formatting style, we use Spotless. To apply the formatting style, run: + +```shell +$ make lint_apply +``` + +To apply the formatting style for Rust, run the following command: + +```shell +$ cargo fmt +``` + +## Concepts + +Note that this project is actively developed, so the concepts might change in the future. + +- `LimboDB` represents a Limbo database. +- `LimboConnection` represents a connection to `LimboDB`. Multiple `LimboConnections` can be created on the same + `LimboDB`. +- `LimboStatement` represents a Limbo database statement. Multiple `LimboStatements` can be created on the same + `LimboConnection`. +- `LimboResultSet` represents the result of `LimboStatement` execution. It is one-to-one mapped to `LimboStatement`. diff --git a/bindings/java/build.gradle.kts b/bindings/java/build.gradle.kts index a9137c888..6a20b4330 100644 --- a/bindings/java/build.gradle.kts +++ b/bindings/java/build.gradle.kts @@ -6,6 +6,8 @@ import org.gradle.api.tasks.testing.logging.TestLogEvent plugins { java application + `java-library` + `maven-publish` id("net.ltgt.errorprone") version "3.1.0" // If you're stuck on JRE 8, use id 'com.diffplug.spotless' version '6.13.0' or older. 
@@ -20,13 +22,23 @@ java { targetCompatibility = JavaVersion.VERSION_1_8 } +publishing { + publications { + create("mavenJava") { + from(components["java"]) + groupId = "org.github.tursodatabase" + artifactId = "limbo" + version = "0.0.1-SNAPSHOT" + } + } +} + repositories { mavenCentral() } dependencies { - implementation("ch.qos.logback:logback-classic:1.2.13") - implementation("ch.qos.logback:logback-core:1.2.13") + implementation("org.slf4j:slf4j-api:1.7.32") errorprone("com.uber.nullaway:nullaway:0.10.26") // maximum version which supports java 8 errorprone("com.google.errorprone:error_prone_core:2.10.0") // maximum version which supports java 8 @@ -34,12 +46,13 @@ dependencies { testImplementation(platform("org.junit:junit-bom:5.10.0")) testImplementation("org.junit.jupiter:junit-jupiter") testImplementation("org.assertj:assertj-core:3.27.0") + + testImplementation("ch.qos.logback:logback-classic:1.2.13") + testImplementation("ch.qos.logback:logback-core:1.2.13") } application { - mainClass.set("org.github.tursodatabase.Main") - - val limboSystemLibraryPath = System.getenv("LIMBO_SYSTEM_PATH") + val limboSystemLibraryPath = System.getenv("LIMBO_LIBRARY_PATH") if (limboSystemLibraryPath != null) { applicationDefaultJvmArgs = listOf( "-Djava.library.path=${System.getProperty("java.library.path")}:$limboSystemLibraryPath" @@ -47,6 +60,12 @@ application { } } +tasks.jar { + from("libs") { + into("libs") + } +} + tasks.test { useJUnitPlatform() // In order to find rust built file under resources, we need to set it as system path diff --git a/bindings/java/example/.gitignore b/bindings/java/example/.gitignore new file mode 100644 index 000000000..7e659ca6f --- /dev/null +++ b/bindings/java/example/.gitignore @@ -0,0 +1,10 @@ +.gradle +build/ +!gradle/wrapper/gradle-wrapper.jar +!**/src/main/**/build/ +!**/src/test/**/build/ +sample.db +sample.db-wal + +### IntelliJ IDEA ### +.idea diff --git a/bindings/java/example/build.gradle.kts 
b/bindings/java/example/build.gradle.kts new file mode 100644 index 000000000..44e605d11 --- /dev/null +++ b/bindings/java/example/build.gradle.kts @@ -0,0 +1,21 @@ +plugins { + id("java") +} + +group = "org.github.seonwkim" +version = "1.0-SNAPSHOT" + +repositories { + mavenLocal() + mavenCentral() +} + +dependencies { + implementation("org.github.tursodatabase:limbo:0.0.1-SNAPSHOT") + testImplementation(platform("org.junit:junit-bom:5.10.0")) + testImplementation("org.junit.jupiter:junit-jupiter") +} + +tasks.test { + useJUnitPlatform() +} diff --git a/bindings/java/example/gradle/wrapper/gradle-wrapper.jar b/bindings/java/example/gradle/wrapper/gradle-wrapper.jar new file mode 100644 index 000000000..249e5832f Binary files /dev/null and b/bindings/java/example/gradle/wrapper/gradle-wrapper.jar differ diff --git a/bindings/java/example/gradle/wrapper/gradle-wrapper.properties b/bindings/java/example/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 000000000..ec7af5ad3 --- /dev/null +++ b/bindings/java/example/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,6 @@ +#Sun Feb 02 20:06:51 KST 2025 +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-8.10-bin.zip +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/bindings/java/example/gradlew b/bindings/java/example/gradlew new file mode 100755 index 000000000..1b6c78733 --- /dev/null +++ b/bindings/java/example/gradlew @@ -0,0 +1,234 @@ +#!/bin/sh + +# +# Copyright © 2015-2021 the original authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +############################################################################## +# +# Gradle start up script for POSIX generated by Gradle. +# +# Important for running: +# +# (1) You need a POSIX-compliant shell to run this script. If your /bin/sh is +# noncompliant, but you have some other compliant shell such as ksh or +# bash, then to run this script, type that shell name before the whole +# command line, like: +# +# ksh Gradle +# +# Busybox and similar reduced shells will NOT work, because this script +# requires all of these POSIX shell features: +# * functions; +# * expansions «$var», «${var}», «${var:-default}», «${var+SET}», +# «${var#prefix}», «${var%suffix}», and «$( cmd )»; +# * compound commands having a testable exit status, especially «case»; +# * various built-in commands including «command», «set», and «ulimit». +# +# Important for patching: +# +# (2) This script targets any POSIX shell, so it avoids extensions provided +# by Bash, Ksh, etc; in particular arrays are avoided. +# +# The "traditional" practice of packing multiple parameters into a +# space-separated string is a well documented source of bugs and security +# problems, so this is (mostly) avoided, by progressively accumulating +# options in "$@", and eventually passing that to Java. +# +# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS, +# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly; +# see the in-line comments for details. 
+# +# There are tweaks for specific operating systems such as AIX, CygWin, +# Darwin, MinGW, and NonStop. +# +# (3) This script is generated from the Groovy template +# https://github.com/gradle/gradle/blob/master/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt +# within the Gradle project. +# +# You can find Gradle at https://github.com/gradle/gradle/. +# +############################################################################## + +# Attempt to set APP_HOME + +# Resolve links: $0 may be a link +app_path=$0 + +# Need this for daisy-chained symlinks. +while + APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path + [ -h "$app_path" ] +do + ls=$( ls -ld "$app_path" ) + link=${ls#*' -> '} + case $link in #( + /*) app_path=$link ;; #( + *) app_path=$APP_HOME$link ;; + esac +done + +APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit + +APP_NAME="Gradle" +APP_BASE_NAME=${0##*/} + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD=maximum + +warn () { + echo "$*" +} >&2 + +die () { + echo + echo "$*" + echo + exit 1 +} >&2 + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +nonstop=false +case "$( uname )" in #( + CYGWIN* ) cygwin=true ;; #( + Darwin* ) darwin=true ;; #( + MSYS* | MINGW* ) msys=true ;; #( + NONSTOP* ) nonstop=true ;; +esac + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD=$JAVA_HOME/jre/sh/java + else + JAVACMD=$JAVA_HOME/bin/java + fi + if [ ! 
-x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD=java + which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." +fi + +# Increase the maximum file descriptors if we can. +if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then + case $MAX_FD in #( + max*) + MAX_FD=$( ulimit -H -n ) || + warn "Could not query maximum file descriptor limit" + esac + case $MAX_FD in #( + '' | soft) :;; #( + *) + ulimit -n "$MAX_FD" || + warn "Could not set maximum file descriptor limit to $MAX_FD" + esac +fi + +# Collect all arguments for the java command, stacking in reverse order: +# * args from the command line +# * the main class name +# * -classpath +# * -D...appname settings +# * --module-path (only if needed) +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. + +# For Cygwin or MSYS, switch paths to Windows format before running java +if "$cygwin" || "$msys" ; then + APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) + CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" ) + + JAVACMD=$( cygpath --unix "$JAVACMD" ) + + # Now convert the arguments - kludge to limit ourselves to /bin/sh + for arg do + if + case $arg in #( + -*) false ;; # don't mess with options #( + /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath + [ -e "$t" ] ;; #( + *) false ;; + esac + then + arg=$( cygpath --path --ignore --mixed "$arg" ) + fi + # Roll the args list around exactly as many times as the number of + # args, so each arg winds up back in the position where it started, but + # possibly modified. 
+ # + # NB: a `for` loop captures its iteration list before it begins, so + # changing the positional parameters here affects neither the number of + # iterations, nor the values presented in `arg`. + shift # remove old arg + set -- "$@" "$arg" # push replacement arg + done +fi + +# Collect all arguments for the java command; +# * $DEFAULT_JVM_OPTS, $JAVA_OPTS, and $GRADLE_OPTS can contain fragments of +# shell script including quotes and variable substitutions, so put them in +# double quotes to make sure that they get re-expanded; and +# * put everything else in single quotes, so that it's not re-expanded. + +set -- \ + "-Dorg.gradle.appname=$APP_BASE_NAME" \ + -classpath "$CLASSPATH" \ + org.gradle.wrapper.GradleWrapperMain \ + "$@" + +# Use "xargs" to parse quoted args. +# +# With -n1 it outputs one arg per line, with the quotes and backslashes removed. +# +# In Bash we could simply go: +# +# readarray ARGS < <( xargs -n1 <<<"$var" ) && +# set -- "${ARGS[@]}" "$@" +# +# but POSIX shell has neither arrays nor command substitution, so instead we +# post-process each arg (as a line of input to sed) to backslash-escape any +# character that might be a shell metacharacter, then use eval to reverse +# that process (while maintaining the separation between arguments), and wrap +# the whole thing up as a single "set" statement. +# +# This will of course break if any of these variables contains a newline or +# an unmatched quote. +# + +eval "set -- $( + printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | + xargs -n1 | + sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | + tr '\n' ' ' + )" '"$@"' + +exec "$JAVACMD" "$@" diff --git a/bindings/java/example/gradlew.bat b/bindings/java/example/gradlew.bat new file mode 100644 index 000000000..107acd32c --- /dev/null +++ b/bindings/java/example/gradlew.bat @@ -0,0 +1,89 @@ +@rem +@rem Copyright 2015 the original author or authors. 
+@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. +@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. +@rem + +@if "%DEBUG%" == "" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%" == "" set DIRNAME=. +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Resolve any "." and ".." in APP_HOME to make it shorter. +for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if "%ERRORLEVEL%" == "0" goto execute + +echo. +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto execute + +echo. +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% +echo. 
+echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* + +:end +@rem End local scope for the variables with windows NT shell +if "%ERRORLEVEL%"=="0" goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! +if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 +exit /b 1 + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/bindings/java/example/settings.gradle.kts b/bindings/java/example/settings.gradle.kts new file mode 100644 index 000000000..7516803c6 --- /dev/null +++ b/bindings/java/example/settings.gradle.kts @@ -0,0 +1,2 @@ +rootProject.name = "example" + diff --git a/bindings/java/example/src/main/java/org/github/seonwkim/Main.java b/bindings/java/example/src/main/java/org/github/seonwkim/Main.java new file mode 100644 index 000000000..ca1d8bc9d --- /dev/null +++ b/bindings/java/example/src/main/java/org/github/seonwkim/Main.java @@ -0,0 +1,27 @@ +package org.github.seonwkim; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.Statement; + +public class Main { + public static void main(String[] args) { + try (Connection conn = DriverManager.getConnection("jdbc:sqlite:sample.db")) { + Statement stmt = + conn.createStatement( + ResultSet.TYPE_FORWARD_ONLY, + ResultSet.CONCUR_READ_ONLY, + ResultSet.CLOSE_CURSORS_AT_COMMIT); + stmt.execute("CREATE TABLE users (id INT PRIMARY KEY, username TEXT);"); + stmt.execute("INSERT INTO users VALUES (1, 'seonwoo');"); + stmt.execute("INSERT INTO users VALUES (2, 'seonwoo');"); + stmt.execute("INSERT INTO users VALUES 
(3, 'seonwoo');"); + stmt.execute("SELECT * FROM users"); + System.out.println( + "result: " + stmt.getResultSet().getInt(1) + ", " + stmt.getResultSet().getString(2)); + } catch (Exception e) { + System.out.println("Error: " + e); + } + } +} diff --git a/bindings/java/rs_src/limbo_connection.rs b/bindings/java/rs_src/limbo_connection.rs index dd54e9087..048310ce5 100644 --- a/bindings/java/rs_src/limbo_connection.rs +++ b/bindings/java/rs_src/limbo_connection.rs @@ -12,7 +12,6 @@ use std::rc::Rc; use std::sync::Arc; #[derive(Clone)] -#[allow(dead_code)] pub struct LimboConnection { // Because java's LimboConnection is 1:1 mapped to limbo connection, we can use Rc pub(crate) conn: Rc, @@ -29,7 +28,6 @@ impl LimboConnection { Box::into_raw(Box::new(self)) as jlong } - #[allow(dead_code)] pub fn drop(ptr: jlong) { let _boxed = unsafe { Box::from_raw(ptr as *mut LimboConnection) }; } @@ -43,6 +41,15 @@ pub fn to_limbo_connection(ptr: jlong) -> Result<&'static mut LimboConnection> { } } +#[no_mangle] +pub extern "system" fn Java_org_github_tursodatabase_core_LimboConnection__1close<'local>( + _env: JNIEnv<'local>, + _obj: JObject<'local>, + connection_ptr: jlong, +) { + LimboConnection::drop(connection_ptr); +} + #[no_mangle] pub extern "system" fn Java_org_github_tursodatabase_core_LimboConnection_prepareUtf8<'local>( mut env: JNIEnv<'local>, diff --git a/bindings/java/rs_src/limbo_statement.rs b/bindings/java/rs_src/limbo_statement.rs index 7de4b2c19..aed8e7d99 100644 --- a/bindings/java/rs_src/limbo_statement.rs +++ b/bindings/java/rs_src/limbo_statement.rs @@ -29,7 +29,6 @@ impl LimboStatement { Box::into_raw(Box::new(self)) as jlong } - #[allow(dead_code)] pub fn drop(ptr: jlong) { let _boxed = unsafe { Box::from_raw(ptr as *mut LimboStatement) }; } @@ -88,6 +87,15 @@ pub extern "system" fn Java_org_github_tursodatabase_core_LimboStatement_step<'l } } +#[no_mangle] +pub extern "system" fn Java_org_github_tursodatabase_core_LimboStatement__1close<'local>( + _env: 
JNIEnv<'local>, + _obj: JObject<'local>, + stmt_ptr: jlong, +) { + LimboStatement::drop(stmt_ptr); +} + fn row_to_obj_array<'local>( env: &mut JNIEnv<'local>, row: &limbo_core::Row, diff --git a/bindings/java/src/main/java/org/github/tursodatabase/JDBC.java b/bindings/java/src/main/java/org/github/tursodatabase/JDBC.java index 63c6e57d7..6928c32cc 100644 --- a/bindings/java/src/main/java/org/github/tursodatabase/JDBC.java +++ b/bindings/java/src/main/java/org/github/tursodatabase/JDBC.java @@ -3,20 +3,23 @@ package org.github.tursodatabase; import java.sql.*; import java.util.Locale; import java.util.Properties; -import java.util.logging.Logger; import org.github.tursodatabase.annotations.Nullable; import org.github.tursodatabase.annotations.SkipNullableCheck; import org.github.tursodatabase.core.LimboConnection; import org.github.tursodatabase.jdbc4.JDBC4Connection; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class JDBC implements Driver { + private static final Logger logger = LoggerFactory.getLogger(JDBC.class); + private static final String VALID_URL_PREFIX = "jdbc:sqlite:"; static { try { DriverManager.registerDriver(new JDBC()); } catch (Exception e) { - // TODO: log + logger.error("Failed to register driver", e); } } @@ -72,7 +75,7 @@ public class JDBC implements Driver { @Override @SkipNullableCheck - public Logger getParentLogger() throws SQLFeatureNotSupportedException { + public java.util.logging.Logger getParentLogger() throws SQLFeatureNotSupportedException { // TODO return null; } diff --git a/bindings/java/src/main/java/org/github/tursodatabase/core/LimboConnection.java b/bindings/java/src/main/java/org/github/tursodatabase/core/LimboConnection.java index cd200f74f..34e5692e3 100644 --- a/bindings/java/src/main/java/org/github/tursodatabase/core/LimboConnection.java +++ b/bindings/java/src/main/java/org/github/tursodatabase/core/LimboConnection.java @@ -16,6 +16,7 @@ public abstract class LimboConnection implements Connection { 
private final long connectionPtr; private final AbstractDB database; + private boolean closed; public LimboConnection(String url, String filePath) throws SQLException { this(url, filePath, new Properties()); @@ -28,24 +29,8 @@ public abstract class LimboConnection implements Connection { * @param filePath path to file */ public LimboConnection(String url, String filePath, Properties properties) throws SQLException { - AbstractDB db = null; - - try { - db = open(url, filePath, properties); - } catch (Throwable t) { - try { - if (db != null) { - db.close(); - } - } catch (Throwable t2) { - t.addSuppressed(t2); - } - - throw t; - } - - this.database = db; - this.connectionPtr = db.connect(); + this.database = open(url, filePath, properties); + this.connectionPtr = this.database.connect(); } private static AbstractDB open(String url, String filePath, Properties properties) @@ -59,13 +44,18 @@ public abstract class LimboConnection implements Connection { @Override public void close() throws SQLException { - if (isClosed()) return; - database.close(); + if (isClosed()) { + return; + } + this._close(this.connectionPtr); + this.closed = true; } + private native void _close(long connectionPtr); + @Override public boolean isClosed() throws SQLException { - return database.isClosed(); + return closed; } public AbstractDB getDatabase() { @@ -114,12 +104,15 @@ public abstract class LimboConnection implements Connection { */ protected void checkCursor(int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException { - if (resultSetType != ResultSet.TYPE_FORWARD_ONLY) + if (resultSetType != ResultSet.TYPE_FORWARD_ONLY) { throw new SQLException("SQLite only supports TYPE_FORWARD_ONLY cursors"); - if (resultSetConcurrency != ResultSet.CONCUR_READ_ONLY) + } + if (resultSetConcurrency != ResultSet.CONCUR_READ_ONLY) { throw new SQLException("SQLite only supports CONCUR_READ_ONLY cursors"); - if (resultSetHoldability != ResultSet.CLOSE_CURSORS_AT_COMMIT) + } 
+ if (resultSetHoldability != ResultSet.CLOSE_CURSORS_AT_COMMIT) { throw new SQLException("SQLite only supports closing cursors at commit"); + } } public void setBusyTimeout(int busyTimeout) { diff --git a/bindings/java/src/main/java/org/github/tursodatabase/core/LimboDB.java b/bindings/java/src/main/java/org/github/tursodatabase/core/LimboDB.java index c32eaadf9..ac3f03b5c 100644 --- a/bindings/java/src/main/java/org/github/tursodatabase/core/LimboDB.java +++ b/bindings/java/src/main/java/org/github/tursodatabase/core/LimboDB.java @@ -2,6 +2,10 @@ package org.github.tursodatabase.core; import static org.github.tursodatabase.utils.ByteArrayUtils.stringToUtf8ByteArray; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; import java.sql.SQLException; import java.util.concurrent.locks.ReentrantLock; import org.github.tursodatabase.LimboErrorCode; @@ -30,17 +34,136 @@ public final class LimboDB extends AbstractDB { } } - /** Loads the SQLite interface backend. */ + /** + * Enum representing different architectures and their corresponding library paths and file + * extensions. 
+ */ + enum Architecture { + MACOS_ARM64("libs/macos_arm64/lib_limbo_java.dylib", ".dylib"), + MACOS_X86("libs/macos_x86/lib_limbo_java.dylib", ".dylib"), + WINDOWS("libs/windows/lib_limbo_java.dll", ".dll"), + UNSUPPORTED("", ""); + + private final String libPath; + private final String fileExtension; + + Architecture(String libPath, String fileExtension) { + this.libPath = libPath; + this.fileExtension = fileExtension; + } + + public String getLibPath() { + return libPath; + } + + public String getFileExtension() { + return fileExtension; + } + + public static Architecture detect() { + String osName = System.getProperty("os.name").toLowerCase(); + String osArch = System.getProperty("os.arch").toLowerCase(); + + if (osName.contains("mac")) { + if (osArch.contains("aarch64") || osArch.contains("arm64")) { + return MACOS_ARM64; + } else if (osArch.contains("x86_64") || osArch.contains("amd64")) { + return MACOS_X86; + } + } else if (osName.contains("win")) { + return WINDOWS; + } + + return UNSUPPORTED; + } + } + + /** + * This method attempts to load the native library required for Limbo operations. It first tries + * to load the library from the system's library path using {@link #loadFromSystemPath()}. If that + * fails, it attempts to load the library from the JAR file using {@link #loadFromJar()}. If + * either method succeeds, the `isLoaded` flag is set to true. If both methods fail, an {@link + * InternalError} is thrown indicating that the necessary native library could not be loaded. + * + * @throws InternalError if the native library cannot be loaded from either the system path or the + * JAR file. + */ public static void load() { if (isLoaded) { return; } + if (loadFromSystemPath() || loadFromJar()) { + isLoaded = true; + return; + } + + throw new InternalError("Unable to load necessary native library"); + } + + /** + * Load the native library from the system path. + * + *

This method attempts to load the native library named "_limbo_java" from the system's + * library path. If the library is successfully loaded, the `isLoaded` flag is set to true. + * + * @return true if the library was successfully loaded, false otherwise. + */ + private static boolean loadFromSystemPath() { try { System.loadLibrary("_limbo_java"); - } finally { - isLoaded = true; + return true; + } catch (Throwable t) { + logger.info("Unable to load from default path: {}", String.valueOf(t)); } + + return false; + } + + /** + * Load the native library from the JAR file. + * + *

By default, native libraries are packaged within the JAR file. This method extracts the + * appropriate native library for the current operating system and architecture from the JAR and + * loads it. + * + * @return true if the library was successfully loaded, false otherwise. + */ + private static boolean loadFromJar() { + Architecture arch = Architecture.detect(); + if (arch == Architecture.UNSUPPORTED) { + logger.info("Unsupported OS or architecture"); + return false; + } + + try { + InputStream is = LimboDB.class.getClassLoader().getResourceAsStream(arch.getLibPath()); + assert is != null; + File file = convertInputStreamToFile(is, arch); + System.load(file.getPath()); + return true; + } catch (Throwable t) { + logger.info("Unable to load from jar: {}", String.valueOf(t)); + } + + return false; + } + + private static File convertInputStreamToFile(InputStream is, Architecture arch) + throws IOException { + File tempFile = File.createTempFile("lib", arch.getFileExtension()); + tempFile.deleteOnExit(); + + try (FileOutputStream os = new FileOutputStream(tempFile)) { + int read; + byte[] bytes = new byte[1024]; + + while ((read = is.read(bytes)) != -1) { + os.write(bytes, 0, read); + } + } + + return tempFile; } /** @@ -56,8 +179,6 @@ public final class LimboDB extends AbstractDB { super(url, filePath); } - // WRAPPER FUNCTIONS //////////////////////////////////////////// - // TODO: add support for JNI @Override protected native long openUtf8(byte[] file, int openFlags) throws SQLException; @@ -66,9 +187,6 @@ public final class LimboDB extends AbstractDB { @Override protected native void close0() throws SQLException; - // TODO: add support for JNI - native int execUtf8(byte[] sqlUtf8) throws SQLException; - // TODO: add support for JNI @Override public native void interrupt(); diff --git a/bindings/java/src/main/java/org/github/tursodatabase/core/LimboResultSet.java b/bindings/java/src/main/java/org/github/tursodatabase/core/LimboResultSet.java index 
c6cb8d00e..1884d4786 100644 --- a/bindings/java/src/main/java/org/github/tursodatabase/core/LimboResultSet.java +++ b/bindings/java/src/main/java/org/github/tursodatabase/core/LimboResultSet.java @@ -1,5 +1,6 @@ package org.github.tursodatabase.core; +import java.sql.ResultSet; import java.sql.SQLException; import org.github.tursodatabase.annotations.Nullable; import org.slf4j.Logger; @@ -39,6 +40,20 @@ this.statement = statement; } + /** + * Consumes all the rows in this {@link ResultSet} until the {@link #next()} method returns + * `false`. + * + * @throws SQLException if the result set is not open or if an error occurs while iterating. + */ + public void consumeAll() throws SQLException { + if (!open) { + throw new SQLException("The result set is not open"); + } + + while (next()) {} + } + + /** * Moves the cursor forward one row from its current position. A {@link LimboResultSet} cursor is * initially positioned before the first row; the first call to the method next makes @@ -50,7 +65,11 @@ * cursor can only move forward. 
*/ public boolean next() throws SQLException { - if (!open || isEmptyResultSet || pastLastRow) { + if (!open) { + throw new SQLException("The resultSet is not open"); + } + + if (isEmptyResultSet || pastLastRow) { return false; // completed ResultSet } @@ -70,9 +89,6 @@ public class LimboResultSet { } pastLastRow = lastStepResult.isDone(); - if (pastLastRow) { - open = false; - } return !pastLastRow; } @@ -97,6 +113,29 @@ public class LimboResultSet { } } + public void close() throws SQLException { + this.open = false; + } + + // Note that columnIndex starts from 1 + @Nullable + public Object get(int columnIndex) throws SQLException { + if (!this.isOpen()) { + throw new SQLException("ResultSet is not open"); + } + + if (this.lastStepResult == null || this.lastStepResult.getResult() == null) { + throw new SQLException("ResultSet is null"); + } + + final Object[] resultSet = this.lastStepResult.getResult(); + if (columnIndex > resultSet.length || columnIndex < 0) { + throw new SQLException("columnIndex out of bound"); + } + + return resultSet[columnIndex - 1]; + } + @Override public String toString() { return "LimboResultSet{" diff --git a/bindings/java/src/main/java/org/github/tursodatabase/core/LimboStatement.java b/bindings/java/src/main/java/org/github/tursodatabase/core/LimboStatement.java index c749e27cc..fa660b67c 100644 --- a/bindings/java/src/main/java/org/github/tursodatabase/core/LimboStatement.java +++ b/bindings/java/src/main/java/org/github/tursodatabase/core/LimboStatement.java @@ -21,6 +21,8 @@ public class LimboStatement { private final long statementPointer; private final LimboResultSet resultSet; + private boolean closed; + // TODO: what if the statement we ran was DDL, update queries and etc. Should we still create a // resultSet? 
public LimboStatement(String sql, long statementPointer) { @@ -53,6 +55,11 @@ public class LimboStatement { return result; } + /** + * Because Limbo supports async I/O, it is possible to return a {@link LimboStepResult} with + * {@link LimboStepResult#STEP_RESULT_ID_ROW}. However, this is handled by the native side, so you + * can expect that this method will not return a {@link LimboStepResult#STEP_RESULT_ID_ROW}. + */ @Nullable private native LimboStepResult step(long stmtPointer) throws SQLException; @@ -67,6 +74,30 @@ public class LimboStatement { LimboExceptionUtils.throwLimboException(errorCode, errorMessageBytes); } + /** + * Closes the current statement and releases any resources associated with it. This method calls + * the native `_close` method to perform the actual closing operation. + */ + public void close() throws SQLException { + if (closed) { + return; + } + this.resultSet.close(); + _close(statementPointer); + closed = true; + } + + private native void _close(long statementPointer); + + /** + * Checks if the statement is closed. + * + * @return true if the statement is closed, false otherwise. 
+ */ + public boolean isClosed() { + return closed; + } + @Override public String toString() { return "LimboStatement{" diff --git a/bindings/java/src/main/java/org/github/tursodatabase/core/LimboStepResult.java b/bindings/java/src/main/java/org/github/tursodatabase/core/LimboStepResult.java index 93a1878aa..b82750b9a 100644 --- a/bindings/java/src/main/java/org/github/tursodatabase/core/LimboStepResult.java +++ b/bindings/java/src/main/java/org/github/tursodatabase/core/LimboStepResult.java @@ -47,6 +47,11 @@ public class LimboStepResult { || stepResultId == STEP_RESULT_ID_ERROR; } + @Nullable + public Object[] getResult() { + return result; + } + @Override public String toString() { return "LimboStepResult{" diff --git a/bindings/java/src/main/java/org/github/tursodatabase/jdbc4/JDBC4Connection.java b/bindings/java/src/main/java/org/github/tursodatabase/jdbc4/JDBC4Connection.java index b17c6af36..3c32ffaf2 100644 --- a/bindings/java/src/main/java/org/github/tursodatabase/jdbc4/JDBC4Connection.java +++ b/bindings/java/src/main/java/org/github/tursodatabase/jdbc4/JDBC4Connection.java @@ -83,13 +83,12 @@ public class JDBC4Connection extends LimboConnection { @Override public void close() throws SQLException { - // TODO + super.close(); } @Override public boolean isClosed() throws SQLException { - // TODO - return false; + return super.isClosed(); } @Override diff --git a/bindings/java/src/main/java/org/github/tursodatabase/jdbc4/JDBC4ResultSet.java b/bindings/java/src/main/java/org/github/tursodatabase/jdbc4/JDBC4ResultSet.java index 867b2688e..ad3720b8c 100644 --- a/bindings/java/src/main/java/org/github/tursodatabase/jdbc4/JDBC4ResultSet.java +++ b/bindings/java/src/main/java/org/github/tursodatabase/jdbc4/JDBC4ResultSet.java @@ -3,10 +3,26 @@ package org.github.tursodatabase.jdbc4; import java.io.InputStream; import java.io.Reader; import java.math.BigDecimal; +import java.math.RoundingMode; import java.net.URL; -import java.sql.*; +import java.sql.Array; +import 
java.sql.Blob; +import java.sql.Clob; +import java.sql.Date; +import java.sql.NClob; +import java.sql.Ref; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.RowId; +import java.sql.SQLException; +import java.sql.SQLWarning; +import java.sql.SQLXML; +import java.sql.Statement; +import java.sql.Time; +import java.sql.Timestamp; import java.util.Calendar; import java.util.Map; +import org.github.tursodatabase.annotations.Nullable; import org.github.tursodatabase.annotations.SkipNullableCheck; import org.github.tursodatabase.core.LimboResultSet; @@ -25,7 +41,7 @@ public class JDBC4ResultSet implements ResultSet { @Override public void close() throws SQLException { - // TODO + resultSet.close(); } @Override @@ -35,64 +51,99 @@ public class JDBC4ResultSet implements ResultSet { } @Override + @Nullable public String getString(int columnIndex) throws SQLException { - // TODO - return ""; + final Object result = resultSet.get(columnIndex); + if (result == null) { + return null; + } + return wrapTypeConversion(() -> (String) result); } @Override public boolean getBoolean(int columnIndex) throws SQLException { - // TODO - return false; + final Object result = resultSet.get(columnIndex); + if (result == null) { + return false; + } + return wrapTypeConversion(() -> (Long) result != 0); } @Override public byte getByte(int columnIndex) throws SQLException { - // TODO - return 0; + final Object result = resultSet.get(columnIndex); + if (result == null) { + return 0; + } + return wrapTypeConversion(() -> ((Long) result).byteValue()); } @Override public short getShort(int columnIndex) throws SQLException { - // TODO - return 0; + final Object result = resultSet.get(columnIndex); + if (result == null) { + return 0; + } + return wrapTypeConversion(() -> ((Long) result).shortValue()); } @Override public int getInt(int columnIndex) throws SQLException { - // TODO - return 0; + final Object result = resultSet.get(columnIndex); + if (result == null) { + 
return 0; + } + return wrapTypeConversion(() -> ((Long) result).intValue()); } @Override public long getLong(int columnIndex) throws SQLException { - // TODO - return 0; + final Object result = resultSet.get(columnIndex); + if (result == null) { + return 0; + } + return wrapTypeConversion(() -> (long) result); } @Override public float getFloat(int columnIndex) throws SQLException { - // TODO - return 0; + final Object result = resultSet.get(columnIndex); + if (result == null) { + return 0; + } + return wrapTypeConversion(() -> ((Double) result).floatValue()); } @Override public double getDouble(int columnIndex) throws SQLException { - // TODO - return 0; + final Object result = resultSet.get(columnIndex); + if (result == null) { + return 0; + } + return wrapTypeConversion(() -> (double) result); } + // TODO: customize rounding mode? @Override - @SkipNullableCheck + @Nullable public BigDecimal getBigDecimal(int columnIndex, int scale) throws SQLException { - // TODO - return null; + final Object result = resultSet.get(columnIndex); + if (result == null) { + return null; + } + final double doubleResult = wrapTypeConversion(() -> (double) result); + final BigDecimal bigDecimalResult = BigDecimal.valueOf(doubleResult); + return bigDecimalResult.setScale(scale, RoundingMode.HALF_UP); } @Override + @Nullable public byte[] getBytes(int columnIndex) throws SQLException { - // TODO - return new byte[0]; + final Object result = resultSet.get(columnIndex); + if (result == null) { + return null; + } + return wrapTypeConversion(() -> (byte[]) result); } @Override @@ -300,10 +351,14 @@ public class JDBC4ResultSet implements ResultSet { } @Override - @SkipNullableCheck + @Nullable public BigDecimal getBigDecimal(int columnIndex) throws SQLException { - // TODO - return null; + final Object result = resultSet.get(columnIndex); + if (result == null) { + return null; + } + final double doubleResult = wrapTypeConversion(() -> (double) result); + return 
BigDecimal.valueOf(doubleResult); } @Override @@ -866,8 +921,7 @@ public class JDBC4ResultSet implements ResultSet { @Override public boolean isClosed() throws SQLException { - // TODO - return false; + return !resultSet.isOpen(); } @Override @@ -1127,7 +1181,16 @@ public class JDBC4ResultSet implements ResultSet { return false; } - private SQLException throwNotSupportedException() { - return new SQLFeatureNotSupportedException("Not implemented by the driver"); + @FunctionalInterface + public interface ResultSetSupplier { + T get() throws Exception; + } + + private T wrapTypeConversion(ResultSetSupplier supplier) throws SQLException { + try { + return supplier.get(); + } catch (Exception e) { + throw new SQLException("Type conversion failed: " + e); + } } } diff --git a/bindings/java/src/main/java/org/github/tursodatabase/jdbc4/JDBC4Statement.java b/bindings/java/src/main/java/org/github/tursodatabase/jdbc4/JDBC4Statement.java index eee4c95a3..3965e3cae 100644 --- a/bindings/java/src/main/java/org/github/tursodatabase/jdbc4/JDBC4Statement.java +++ b/bindings/java/src/main/java/org/github/tursodatabase/jdbc4/JDBC4Statement.java @@ -19,6 +19,8 @@ public class JDBC4Statement implements Statement { private final LimboConnection connection; @Nullable private LimboStatement statement = null; + // Because JDBC4Statement has different life cycle in compared to LimboStatement, let's use this + // field to manage JDBC4Statement lifecycle private boolean closed; private boolean closeOnCompletion; @@ -51,9 +53,21 @@ public class JDBC4Statement implements Statement { this.resultSetHoldability = resultSetHoldability; } + // TODO: should executeQuery run execute right after preparing the statement? @Override public ResultSet executeQuery(String sql) throws SQLException { - execute(sql); + ensureOpen(); + statement = + this.withConnectionTimeout( + () -> { + try { + // TODO: if sql is a readOnly query, do we still need the locks? 
+ connectionLock.lock(); + return connection.prepare(sql); + } finally { + connectionLock.unlock(); + } + }); requireNonNull(statement, "statement should not be null after running execute method"); return new JDBC4ResultSet(statement.getResultSet()); @@ -65,9 +79,7 @@ public class JDBC4Statement implements Statement { requireNonNull(statement, "statement should not be null after running execute method"); final LimboResultSet resultSet = statement.getResultSet(); - while (resultSet.isOpen()) { - resultSet.next(); - } + resultSet.consumeAll(); // TODO: return update count; return 0; @@ -75,8 +87,14 @@ public class JDBC4Statement implements Statement { @Override public void close() throws SQLException { - clearGeneratedKeys(); - internalClose(); + if (closed) { + return; + } + + if (this.statement != null) { + this.statement.close(); + } + closed = true; } @@ -150,8 +168,7 @@ public class JDBC4Statement implements Statement { */ @Override public boolean execute(String sql) throws SQLException { - internalClose(); - + ensureOpen(); return this.withConnectionTimeout( () -> { try { @@ -298,8 +315,7 @@ public class JDBC4Statement implements Statement { @Override public boolean isClosed() throws SQLException { - // TODO - return false; + return this.closed; } @Override @@ -346,14 +362,6 @@ public class JDBC4Statement implements Statement { return false; } - protected void internalClose() throws SQLException { - // TODO - } - - protected void clearGeneratedKeys() throws SQLException { - // TODO - } - protected void updateGeneratedKeys() throws SQLException { // TODO } @@ -378,4 +386,10 @@ public class JDBC4Statement implements Statement { protected interface SQLCallable { T call() throws SQLException; } + + private void ensureOpen() throws SQLException { + if (closed) { + throw new SQLException("Statement is closed"); + } + } } diff --git a/bindings/java/src/test/java/org/github/tursodatabase/core/LimboStatementTest.java 
b/bindings/java/src/test/java/org/github/tursodatabase/core/LimboStatementTest.java new file mode 100644 index 000000000..fe274b07e --- /dev/null +++ b/bindings/java/src/test/java/org/github/tursodatabase/core/LimboStatementTest.java @@ -0,0 +1,31 @@ +package org.github.tursodatabase.core; + +import static org.junit.jupiter.api.Assertions.*; + +import java.util.Properties; +import org.github.tursodatabase.TestUtils; +import org.github.tursodatabase.jdbc4.JDBC4Connection; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +class LimboStatementTest { + + private JDBC4Connection connection; + + @BeforeEach + void setUp() throws Exception { + String filePath = TestUtils.createTempFile(); + String url = "jdbc:sqlite:" + filePath; + connection = new JDBC4Connection(url, filePath, new Properties()); + } + + @Test + void closing_statement_closes_related_resources() throws Exception { + LimboStatement stmt = connection.prepare("SELECT 1;"); + stmt.execute(); + + stmt.close(); + assertTrue(stmt.isClosed()); + assertFalse(stmt.getResultSet().isOpen()); + } +} diff --git a/bindings/java/src/test/java/org/github/tursodatabase/jdbc4/JDBC4ConnectionTest.java b/bindings/java/src/test/java/org/github/tursodatabase/jdbc4/JDBC4ConnectionTest.java index 60f6ee56e..1bc4fb526 100644 --- a/bindings/java/src/test/java/org/github/tursodatabase/jdbc4/JDBC4ConnectionTest.java +++ b/bindings/java/src/test/java/org/github/tursodatabase/jdbc4/JDBC4ConnectionTest.java @@ -65,4 +65,23 @@ class JDBC4ConnectionTest { void prepare_simple_create_table() throws Exception { connection.prepare("CREATE TABLE users (id INT PRIMARY KEY, username TEXT)"); } + + @Test + void calling_close_multiple_times_throws_no_exception() throws Exception { + assertFalse(connection.isClosed()); + connection.close(); + assertTrue(connection.isClosed()); + connection.close(); + } + + @Test + void calling_methods_on_closed_connection_should_throw_exception() throws Exception { + 
connection.close(); + assertTrue(connection.isClosed()); + assertThrows( + SQLException.class, + () -> + connection.createStatement( + ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY, -1)); + } } diff --git a/bindings/java/src/test/java/org/github/tursodatabase/jdbc4/JDBC4ResultSetTest.java b/bindings/java/src/test/java/org/github/tursodatabase/jdbc4/JDBC4ResultSetTest.java index f764a9361..ec79ce1eb 100644 --- a/bindings/java/src/test/java/org/github/tursodatabase/jdbc4/JDBC4ResultSetTest.java +++ b/bindings/java/src/test/java/org/github/tursodatabase/jdbc4/JDBC4ResultSetTest.java @@ -1,14 +1,25 @@ package org.github.tursodatabase.jdbc4; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +import java.math.BigDecimal; +import java.math.RoundingMode; import java.sql.ResultSet; +import java.sql.SQLException; import java.sql.Statement; import java.util.Properties; +import java.util.stream.Stream; import org.github.tursodatabase.TestUtils; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; class JDBC4ResultSetTest { @@ -57,4 +68,449 @@ class JDBC4ResultSetTest { // as well assertFalse(resultSet.next()); } + + @Test + void close_resultSet_test() throws Exception { + stmt.executeQuery("SELECT 1;"); + ResultSet resultSet = stmt.getResultSet(); + + assertFalse(resultSet.isClosed()); + resultSet.close(); + assertTrue(resultSet.isClosed()); + } + + @Test + void calling_methods_on_closed_resultSet_should_throw_exception() throws Exception { + stmt.executeQuery("SELECT 1;"); + ResultSet 
resultSet = stmt.getResultSet(); + resultSet.close(); + assertTrue(resultSet.isClosed()); + + assertThrows(SQLException.class, resultSet::next); + } + + @Test + void test_getString() throws Exception { + stmt.executeUpdate("CREATE TABLE test_string (string_col TEXT);"); + stmt.executeUpdate("INSERT INTO test_string (string_col) VALUES ('test');"); + + ResultSet resultSet = stmt.executeQuery("SELECT * FROM test_string"); + assertTrue(resultSet.next()); + assertEquals("test", resultSet.getString(1)); + } + + @Test + void test_getString_returns_null_on_null() throws Exception { + stmt.executeUpdate("CREATE TABLE test_null (string_col TEXT);"); + stmt.executeUpdate("INSERT INTO test_null (string_col) VALUES (NULL);"); + + ResultSet resultSet = stmt.executeQuery("SELECT * FROM test_null"); + assertTrue(resultSet.next()); + assertNull(resultSet.getString(1)); + } + + @Test + void test_getBoolean_true() throws Exception { + stmt.executeUpdate("CREATE TABLE test_boolean (boolean_col INTEGER);"); + stmt.executeUpdate("INSERT INTO test_boolean (boolean_col) VALUES (1);"); + stmt.executeUpdate("INSERT INTO test_boolean (boolean_col) VALUES (2);"); + + ResultSet resultSet = stmt.executeQuery("SELECT * FROM test_boolean"); + + assertTrue(resultSet.next()); + assertTrue(resultSet.getBoolean(1)); + + resultSet.next(); + assertTrue(resultSet.getBoolean(1)); + } + + @Test + void test_getBoolean_false() throws Exception { + stmt.executeUpdate("CREATE TABLE test_boolean (boolean_col INTEGER);"); + stmt.executeUpdate("INSERT INTO test_boolean (boolean_col) VALUES (0);"); + + ResultSet resultSet = stmt.executeQuery("SELECT * FROM test_boolean"); + assertTrue(resultSet.next()); + assertFalse(resultSet.getBoolean(1)); + } + + @Test + void test_getBoolean_returns_false_on_null() throws Exception { + stmt.executeUpdate("CREATE TABLE test_null (boolean_col INTEGER);"); + stmt.executeUpdate("INSERT INTO test_null (boolean_col) VALUES (NULL);"); + + ResultSet resultSet = 
stmt.executeQuery("SELECT * FROM test_null"); + assertTrue(resultSet.next()); + assertFalse(resultSet.getBoolean(1)); + } + + @Test + void test_getByte() throws Exception { + stmt.executeUpdate("CREATE TABLE test_byte (byte_col INTEGER);"); + stmt.executeUpdate("INSERT INTO test_byte (byte_col) VALUES (1);"); + stmt.executeUpdate("INSERT INTO test_byte (byte_col) VALUES (128);"); // Exceeds byte size + stmt.executeUpdate("INSERT INTO test_byte (byte_col) VALUES (-129);"); // Exceeds byte size + + ResultSet resultSet = stmt.executeQuery("SELECT * FROM test_byte"); + + // Test value that fits within byte size + assertTrue(resultSet.next()); + assertEquals(1, resultSet.getByte(1)); + + // Test value that exceeds byte size (positive overflow) + assertTrue(resultSet.next()); + assertEquals(-128, resultSet.getByte(1)); // 128 overflows to -128 + + // Test value that exceeds byte size (negative overflow) + assertTrue(resultSet.next()); + assertEquals(127, resultSet.getByte(1)); // -129 overflows to 127 + } + + @Test + void test_getByte_returns_zero_on_null() throws Exception { + stmt.executeUpdate("CREATE TABLE test_null (byte_col INTEGER);"); + stmt.executeUpdate("INSERT INTO test_null (byte_col) VALUES (NULL);"); + + ResultSet resultSet = stmt.executeQuery("SELECT * FROM test_null"); + assertTrue(resultSet.next()); + assertEquals(0, resultSet.getByte(1)); + } + + @Test + void test_getShort() throws Exception { + stmt.executeUpdate("CREATE TABLE test_short (short_col INTEGER);"); + stmt.executeUpdate("INSERT INTO test_short (short_col) VALUES (123);"); + stmt.executeUpdate("INSERT INTO test_short (short_col) VALUES (32767);"); // Max short value + stmt.executeUpdate("INSERT INTO test_short (short_col) VALUES (-32768);"); // Min short value + + ResultSet resultSet = stmt.executeQuery("SELECT * FROM test_short"); + + // Test typical short value + assertTrue(resultSet.next()); + assertEquals(123, resultSet.getShort(1)); + + // Test maximum short value + 
assertTrue(resultSet.next()); + assertEquals(32767, resultSet.getShort(1)); + + // Test minimum short value + assertTrue(resultSet.next()); + assertEquals(-32768, resultSet.getShort(1)); + } + + @Test + void test_getShort_returns_zero_on_null() throws Exception { + stmt.executeUpdate("CREATE TABLE test_null (short_col INTEGER);"); + stmt.executeUpdate("INSERT INTO test_null (short_col) VALUES (NULL);"); + + ResultSet resultSet = stmt.executeQuery("SELECT * FROM test_null"); + assertTrue(resultSet.next()); + assertEquals(0, resultSet.getShort(1)); + } + + @Test + void test_getInt() throws Exception { + stmt.executeUpdate("CREATE TABLE test_int (int_col INT);"); + stmt.executeUpdate("INSERT INTO test_int (int_col) VALUES (12345);"); + stmt.executeUpdate("INSERT INTO test_int (int_col) VALUES (2147483647);"); // Max int value + stmt.executeUpdate("INSERT INTO test_int (int_col) VALUES (-2147483648);"); // Min int value + + ResultSet resultSet = stmt.executeQuery("SELECT * FROM test_int"); + + // Test typical int value + assertTrue(resultSet.next()); + assertEquals(12345, resultSet.getInt(1)); + + // Test maximum int value + assertTrue(resultSet.next()); + assertEquals(2147483647, resultSet.getInt(1)); + + // Test minimum int value + assertTrue(resultSet.next()); + assertEquals(-2147483648, resultSet.getInt(1)); + } + + @Test + void test_getInt_returns_zero_on_null() throws Exception { + stmt.executeUpdate("CREATE TABLE test_null (int_col INTEGER);"); + stmt.executeUpdate("INSERT INTO test_null (int_col) VALUES (NULL);"); + + ResultSet resultSet = stmt.executeQuery("SELECT * FROM test_null"); + assertTrue(resultSet.next()); + assertEquals(0, resultSet.getInt(1)); + } + + @Test + @Disabled("limbo has a bug which sees -9223372036854775808 as double") + void test_getLong() throws Exception { + stmt.executeUpdate("CREATE TABLE test_long (long_col BIGINT);"); + stmt.executeUpdate("INSERT INTO test_long (long_col) VALUES (1234567890);"); + stmt.executeUpdate( + "INSERT INTO 
test_long (long_col) VALUES (9223372036854775807);"); // Max long value + stmt.executeUpdate( + "INSERT INTO test_long (long_col) VALUES (-9223372036854775808);"); // Min long value + + ResultSet resultSet = stmt.executeQuery("SELECT * FROM test_long"); + + // Test typical long value + assertEquals(1234567890L, resultSet.getLong(1)); + + // Test maximum long value + assertTrue(resultSet.next()); + assertEquals(9223372036854775807L, resultSet.getLong(1)); + + // Test minimum long value + assertTrue(resultSet.next()); + assertEquals(-9223372036854775808L, resultSet.getLong(1)); + } + + @Test + void test_getLong_returns_zero_no_null() throws Exception { + stmt.executeUpdate("CREATE TABLE test_null (long_col INTEGER);"); + stmt.executeUpdate("INSERT INTO test_null (long_col) VALUES (NULL);"); + + ResultSet resultSet = stmt.executeQuery("SELECT * FROM test_null"); + assertTrue(resultSet.next()); + assertEquals(0L, resultSet.getLong(1)); + } + + @Test + void test_getFloat() throws Exception { + stmt.executeUpdate("CREATE TABLE test_float (float_col REAL);"); + stmt.executeUpdate("INSERT INTO test_float (float_col) VALUES (1.23);"); + stmt.executeUpdate( + "INSERT INTO test_float (float_col) VALUES (3.4028235E38);"); // Max float value + stmt.executeUpdate( + "INSERT INTO test_float (float_col) VALUES (1.4E-45);"); // Min positive float value + stmt.executeUpdate( + "INSERT INTO test_float (float_col) VALUES (-3.4028235E38);"); // Min negative float value + + ResultSet resultSet = stmt.executeQuery("SELECT * FROM test_float"); + + // Test typical float value + assertTrue(resultSet.next()); + assertEquals(1.23f, resultSet.getFloat(1), 0.0001); + + // Test maximum float value + assertTrue(resultSet.next()); + assertEquals(3.4028235E38f, resultSet.getFloat(1), 0.0001); + + // Test minimum positive float value + assertTrue(resultSet.next()); + assertEquals(1.4E-45f, resultSet.getFloat(1), 0.0001); + + // Test minimum negative float value + assertTrue(resultSet.next()); + 
assertEquals(-3.4028235E38f, resultSet.getFloat(1), 0.0001); + } + + @Test + void test_getFloat_returns_zero_on_null() throws Exception { + stmt.executeUpdate("CREATE TABLE test_null (float_col REAL);"); + stmt.executeUpdate("INSERT INTO test_null (float_col) VALUES (NULL);"); + + ResultSet resultSet = stmt.executeQuery("SELECT * FROM test_null"); + assertTrue(resultSet.next()); + assertEquals(0.0f, resultSet.getFloat(1), 0.0001); + } + + @Test + void test_getDouble() throws Exception { + stmt.executeUpdate("CREATE TABLE test_double (double_col REAL);"); + stmt.executeUpdate("INSERT INTO test_double (double_col) VALUES (1.234567);"); + stmt.executeUpdate( + "INSERT INTO test_double (double_col) VALUES (1.7976931348623157E308);"); // Max double + // value + stmt.executeUpdate( + "INSERT INTO test_double (double_col) VALUES (4.9E-324);"); // Min positive double value + stmt.executeUpdate( + "INSERT INTO test_double (double_col) VALUES (-1.7976931348623157E308);"); // Min negative + // double value + + ResultSet resultSet = stmt.executeQuery("SELECT * FROM test_double"); + + // Test typical double value + assertTrue(resultSet.next()); + assertEquals(1.234567, resultSet.getDouble(1), 0.0001); + + // Test maximum double value + assertTrue(resultSet.next()); + assertEquals(1.7976931348623157E308, resultSet.getDouble(1), 0.0001); + + // Test minimum positive double value + assertTrue(resultSet.next()); + assertEquals(4.9E-324, resultSet.getDouble(1), 0.0001); + + // Test minimum negative double value + assertTrue(resultSet.next()); + assertEquals(-1.7976931348623157E308, resultSet.getDouble(1), 0.0001); + } + + @Test + void test_getDouble_returns_zero_on_null() throws Exception { + stmt.executeUpdate("CREATE TABLE test_null (double_col REAL);"); + stmt.executeUpdate("INSERT INTO test_null (double_col) VALUES (NULL);"); + + ResultSet resultSet = stmt.executeQuery("SELECT * FROM test_null"); + assertTrue(resultSet.next()); + assertEquals(0.0, resultSet.getDouble(1), 
0.0001); + } + + @Test + void test_getBigDecimal() throws Exception { + stmt.executeUpdate("CREATE TABLE test_bigdecimal (bigdecimal_col REAL);"); + stmt.executeUpdate("INSERT INTO test_bigdecimal (bigdecimal_col) VALUES (12345.67);"); + stmt.executeUpdate( + "INSERT INTO test_bigdecimal (bigdecimal_col) VALUES (1.7976931348623157E308);"); // Max + // double + // value + stmt.executeUpdate( + "INSERT INTO test_bigdecimal (bigdecimal_col) VALUES (4.9E-324);"); // Min positive double + // value + stmt.executeUpdate( + "INSERT INTO test_bigdecimal (bigdecimal_col) VALUES (-12345.67);"); // Negative value + + ResultSet resultSet = stmt.executeQuery("SELECT * FROM test_bigdecimal"); + + // Test typical BigDecimal value + assertTrue(resultSet.next()); + assertEquals( + new BigDecimal("12345.67").setScale(2, RoundingMode.HALF_UP), + resultSet.getBigDecimal(1, 2)); + + // Test maximum double value + assertTrue(resultSet.next()); + assertEquals( + new BigDecimal("1.7976931348623157E308").setScale(10, RoundingMode.HALF_UP), + resultSet.getBigDecimal(1, 10)); + + // Test minimum positive double value + assertTrue(resultSet.next()); + assertEquals( + new BigDecimal("4.9E-324").setScale(10, RoundingMode.HALF_UP), + resultSet.getBigDecimal(1, 10)); + + // Test negative BigDecimal value + assertTrue(resultSet.next()); + assertEquals( + new BigDecimal("-12345.67").setScale(2, RoundingMode.HALF_UP), + resultSet.getBigDecimal(1, 2)); + } + + @Test + void test_getBigDecimal_returns_null_on_null() throws Exception { + stmt.executeUpdate("CREATE TABLE test_null (bigdecimal_col REAL);"); + stmt.executeUpdate("INSERT INTO test_null (bigdecimal_col) VALUES (NULL);"); + + ResultSet resultSet = stmt.executeQuery("SELECT * FROM test_null"); + assertTrue(resultSet.next()); + assertNull(resultSet.getBigDecimal(1, 2)); + } + + @ParameterizedTest + @MethodSource("byteArrayProvider") + void test_getBytes(byte[] data) throws Exception { + stmt.executeUpdate("CREATE TABLE test_bytes (bytes_col 
BLOB);"); + executeDMLAndAssert(data); + } + + private static Stream byteArrayProvider() { + return Stream.of( + "Hello".getBytes(), "world".getBytes(), new byte[0], new byte[] {0x00, (byte) 0xFF}); + } + + private void executeDMLAndAssert(byte[] data) throws SQLException { + // Convert byte array to hexadecimal string + StringBuilder hexString = new StringBuilder(); + for (byte b : data) { + hexString.append(String.format("%02X", b)); + } + // Execute DML statement + stmt.executeUpdate("INSERT INTO test_bytes (bytes_col) VALUES (X'" + hexString + "');"); + + // Assert the inserted data + ResultSet resultSet = stmt.executeQuery("SELECT bytes_col FROM test_bytes"); + assertTrue(resultSet.next()); + assertArrayEquals(data, resultSet.getBytes(1)); + } + + @Test + void test_getBytes_returns_null_on_null() throws Exception { + stmt.executeUpdate("CREATE TABLE test_null (bytes_col BLOB);"); + stmt.executeUpdate("INSERT INTO test_null (bytes_col) VALUES (NULL);"); + + ResultSet resultSet = stmt.executeQuery("SELECT * FROM test_null"); + assertTrue(resultSet.next()); + assertNull(resultSet.getBytes(1)); + } + + @Test + void test_getXXX_methods_on_multiple_columns() throws Exception { + stmt.executeUpdate( + "CREATE TABLE test_integration (" + + "string_col TEXT, " + + "boolean_col INTEGER, " + + "byte_col INTEGER, " + + "short_col INTEGER, " + + "int_col INTEGER, " + + "long_col BIGINT, " + + "float_col REAL, " + + "double_col REAL, " + + "bigdecimal_col REAL, " + + "bytes_col BLOB);"); + + stmt.executeUpdate( + "INSERT INTO test_integration VALUES (" + + "'test', " + + "1, " + + "1, " + + "123, " + + "12345, " + + "1234567890, " + + "1.23, " + + "1.234567, " + + "12345.67, " + + "X'48656C6C6F');"); + + ResultSet resultSet = stmt.executeQuery("SELECT * FROM test_integration"); + assertTrue(resultSet.next()); + + // Verify each column + assertEquals("test", resultSet.getString(1)); + assertTrue(resultSet.getBoolean(2)); + assertEquals(1, resultSet.getByte(3)); + 
assertEquals(123, resultSet.getShort(4)); + assertEquals(12345, resultSet.getInt(5)); + assertEquals(1234567890L, resultSet.getLong(6)); + assertEquals(1.23f, resultSet.getFloat(7), 0.0001); + assertEquals(1.234567, resultSet.getDouble(8), 0.0001); + assertEquals( + new BigDecimal("12345.67").setScale(2, RoundingMode.HALF_UP), + resultSet.getBigDecimal(9, 2)); + assertArrayEquals("Hello".getBytes(), resultSet.getBytes(10)); + } + + @Test + void test_invalidColumnIndex_outOfBounds() throws Exception { + stmt.executeUpdate("CREATE TABLE test_invalid (col INTEGER);"); + stmt.executeUpdate("INSERT INTO test_invalid (col) VALUES (1);"); + + ResultSet resultSet = stmt.executeQuery("SELECT * FROM test_invalid"); + assertTrue(resultSet.next()); + + // Test out-of-bounds column index + assertThrows(SQLException.class, () -> resultSet.getInt(2)); + } + + @Test + void test_invalidColumnIndex_negative() throws Exception { + stmt.executeUpdate("CREATE TABLE test_invalid (col INTEGER);"); + stmt.executeUpdate("INSERT INTO test_invalid (col) VALUES (1);"); + + ResultSet resultSet = stmt.executeQuery("SELECT * FROM test_invalid"); + assertTrue(resultSet.next()); + + // Test negative column index + assertThrows(SQLException.class, () -> resultSet.getInt(-1)); + } } diff --git a/bindings/java/src/test/java/org/github/tursodatabase/jdbc4/JDBC4StatementTest.java b/bindings/java/src/test/java/org/github/tursodatabase/jdbc4/JDBC4StatementTest.java index 2a837629d..a48cedea9 100644 --- a/bindings/java/src/test/java/org/github/tursodatabase/jdbc4/JDBC4StatementTest.java +++ b/bindings/java/src/test/java/org/github/tursodatabase/jdbc4/JDBC4StatementTest.java @@ -3,6 +3,7 @@ package org.github.tursodatabase.jdbc4; import static org.junit.jupiter.api.Assertions.*; import java.sql.ResultSet; +import java.sql.SQLException; import java.sql.Statement; import java.util.Properties; import org.github.tursodatabase.TestUtils; @@ -51,4 +52,22 @@ class JDBC4StatementTest { stmt.execute("INSERT INTO 
users VALUES (1, 'limbo');"); assertTrue(stmt.execute("SELECT * FROM users;")); } + + @Test + void close_statement_test() throws Exception { + stmt.close(); + assertTrue(stmt.isClosed()); + } + + @Test + void double_close_is_no_op() throws SQLException { + stmt.close(); + assertDoesNotThrow(() -> stmt.close()); + } + + @Test + void operations_on_closed_statement_should_throw_exception() throws Exception { + stmt.close(); + assertThrows(SQLException.class, () -> stmt.execute("SELECT 1;")); + } } diff --git a/bindings/java/src/main/resources/logback.xml b/bindings/java/src/test/resources/logback.xml similarity index 88% rename from bindings/java/src/main/resources/logback.xml rename to bindings/java/src/test/resources/logback.xml index 5143dd837..1496a4b64 100644 --- a/bindings/java/src/main/resources/logback.xml +++ b/bindings/java/src/test/resources/logback.xml @@ -5,7 +5,7 @@ - + diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index 085d71b3b..16afc5494 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -101,16 +101,17 @@ impl Cursor { // For DDL and DML statements, // we need to execute the statement immediately if stmt_is_ddl || stmt_is_dml { - while stmt - .borrow_mut() - .step() - .map_err(|e| PyErr::new::(format!("Step error: {:?}", e)))? - .eq(&limbo_core::StepResult::IO) - { - self.conn - .io - .run_once() - .map_err(|e| PyErr::new::(format!("IO error: {:?}", e)))?; + loop { + match stmt.borrow_mut().step().map_err(|e| { + PyErr::new::(format!("Step error: {:?}", e)) + })? 
{ + limbo_core::StepResult::IO => { + self.conn.io.run_once().map_err(|e| { + PyErr::new::(format!("IO error: {:?}", e)) + })?; + } + _ => break, + } } } diff --git a/bindings/wasm/integration-tests/package-lock.json b/bindings/wasm/integration-tests/package-lock.json index f2886d47d..ea731abaf 100644 --- a/bindings/wasm/integration-tests/package-lock.json +++ b/bindings/wasm/integration-tests/package-lock.json @@ -15,7 +15,7 @@ }, "..": { "name": "limbo-wasm", - "version": "0.0.13", + "version": "0.0.14", "license": "MIT", "devDependencies": { "@playwright/test": "^1.49.1", diff --git a/bindings/wasm/lib.rs b/bindings/wasm/lib.rs index 1558f09d8..9b34fcb57 100644 --- a/bindings/wasm/lib.rs +++ b/bindings/wasm/lib.rs @@ -192,7 +192,7 @@ fn to_js_value(value: limbo_core::Value) -> JsValue { } limbo_core::Value::Float(f) => JsValue::from(f), limbo_core::Value::Text(t) => JsValue::from_str(t), - limbo_core::Value::Blob(b) => js_sys::Uint8Array::from(b.as_slice()).into(), + limbo_core::Value::Blob(b) => js_sys::Uint8Array::from(b).into(), } } diff --git a/bindings/wasm/package-lock.json b/bindings/wasm/package-lock.json index 83309fcb9..930d9255e 100644 --- a/bindings/wasm/package-lock.json +++ b/bindings/wasm/package-lock.json @@ -1,12 +1,12 @@ { "name": "limbo-wasm", - "version": "0.0.13", + "version": "0.0.14", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "limbo-wasm", - "version": "0.0.13", + "version": "0.0.14", "license": "MIT", "devDependencies": { "@playwright/test": "^1.49.1", diff --git a/bindings/wasm/package.json b/bindings/wasm/package.json index 7dfb91516..2265a7799 100644 --- a/bindings/wasm/package.json +++ b/bindings/wasm/package.json @@ -3,7 +3,7 @@ "collaborators": [ "the Limbo authors" ], - "version": "0.0.13", + "version": "0.0.14", "license": "MIT", "repository": { "type": "git", diff --git a/bindings/wasm/test-limbo-pkg/package.json b/bindings/wasm/test-limbo-pkg/package.json index 424bcede1..7c10b8c28 100644 --- 
a/bindings/wasm/test-limbo-pkg/package.json +++ b/bindings/wasm/test-limbo-pkg/package.json @@ -3,7 +3,7 @@ "private": true, "type": "module", "dependencies": { - "limbo-wasm": "limbo-wasm@0.0.13" + "limbo-wasm": "limbo-wasm@0.0.14" }, "scripts": { "dev": "vite" diff --git a/core/Cargo.toml b/core/Cargo.toml index c378d8bcc..f2bba5f59 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -14,15 +14,19 @@ name = "limbo_core" path = "lib.rs" [features] -default = ["fs", "json", "uuid", "io_uring"] +default = ["fs", "json", "uuid", "vector", "io_uring", "time"] fs = [] json = [ "dep:jsonb", "dep:pest", "dep:pest_derive", ] -uuid = ["dep:uuid"] +uuid = ["limbo_uuid/static"] +vector = ["limbo_vector/static"] io_uring = ["dep:io-uring", "rustix/io_uring"] +percentile = ["limbo_percentile/static"] +regexp = ["limbo_regexp/static"] +time = ["limbo_time/static"] [target.'cfg(target_os = "linux")'.dependencies] io-uring = { version = "0.6.1", optional = true } @@ -33,6 +37,7 @@ rustix = "0.38.34" [target.'cfg(not(target_family = "wasm"))'.dependencies] mimalloc = { version = "*", default-features = false } +libloading = "0.8.6" [dependencies] limbo_ext = { path = "../extensions/core" } @@ -57,9 +62,20 @@ pest_derive = { version = "2.0", optional = true } rand = "0.8.5" bumpalo = { version = "3.16.0", features = ["collections", "boxed"] } limbo_macros = { path = "../macros" } -uuid = { version = "1.11.0", features = ["v4", "v7"], optional = true } +limbo_uuid = { path = "../extensions/uuid", optional = true, features = ["static"] } +limbo_vector = { path = "../extensions/vector", optional = true, features = ["static"] } +limbo_regexp = { path = "../extensions/regexp", optional = true, features = ["static"] } +limbo_percentile = { path = "../extensions/percentile", optional = true, features = ["static"] } +limbo_time = { path = "../extensions/time", optional = true, features = ["static"] } miette = "7.4.0" -libloading = "0.8.6" +strum = "0.26" +parking_lot = "0.12.3" +tracing 
= "0.1.41" +crossbeam-skiplist = "0.1.3" + +[build-dependencies] +chrono = "0.4.38" +built = { version = "0.7.5", features = ["git2", "chrono"] } [target.'cfg(not(target_family = "windows"))'.dev-dependencies] pprof = { version = "0.14.0", features = ["criterion", "flamegraph"] } @@ -78,3 +94,7 @@ tempfile = "3.8.0" [[bench]] name = "benchmark" harness = false + +[[bench]] +name = "mvcc_benchmark" +harness = false diff --git a/core/benches/benchmark.rs b/core/benches/benchmark.rs index 9858a0c56..f25d48144 100644 --- a/core/benches/benchmark.rs +++ b/core/benches/benchmark.rs @@ -1,186 +1,145 @@ -use criterion::{criterion_group, criterion_main, Criterion, Throughput}; +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; use limbo_core::{Database, PlatformIO, IO}; use pprof::criterion::{Output, PProfProfiler}; use std::sync::Arc; -fn bench(c: &mut Criterion) { - limbo_bench(c); - - // https://github.com/penberg/limbo/issues/174 - // The rusqlite benchmark crashes on Mac M1 when using the flamegraph features - if std::env::var("DISABLE_RUSQLITE_BENCHMARK").is_ok() { - return; - } - - rusqlite_bench(c) +fn rusqlite_open() -> rusqlite::Connection { + let sqlite_conn = rusqlite::Connection::open("../testing/testing.db").unwrap(); + sqlite_conn + .pragma_update(None, "locking_mode", "EXCLUSIVE") + .unwrap(); + sqlite_conn } -fn limbo_bench(criterion: &mut Criterion) { - let mut group = criterion.benchmark_group("limbo"); - group.throughput(Throughput::Elements(1)); +fn bench(criterion: &mut Criterion) { + // https://github.com/penberg/limbo/issues/174 + // The rusqlite benchmark crashes on Mac M1 when using the flamegraph features + let enable_rusqlite = std::env::var("DISABLE_RUSQLITE_BENCHMARK").is_err(); + #[allow(clippy::arc_with_non_send_sync)] let io = Arc::new(PlatformIO::new().unwrap()); let db = Database::open_file(io.clone(), "../testing/testing.db").unwrap(); - let conn = db.connect(); + let limbo_conn = db.connect(); - 
group.bench_function("Prepare statement: 'SELECT 1'", |b| { - b.iter(|| { - conn.prepare("SELECT 1").unwrap(); + let queries = [ + "SELECT 1", + "SELECT * FROM users LIMIT 1", + "SELECT first_name, count(1) FROM users GROUP BY first_name HAVING count(1) > 1 ORDER BY count(1) LIMIT 1", + ]; + + for query in queries.iter() { + let mut group = criterion.benchmark_group(format!("Prepare `{}`", query)); + + group.bench_with_input(BenchmarkId::new("Limbo", query), query, |b, query| { + b.iter(|| { + limbo_conn.prepare(query).unwrap(); + }); }); - }); - group.bench_function("Prepare statement: 'SELECT * FROM users LIMIT 1'", |b| { - b.iter(|| { - conn.prepare("SELECT * FROM users LIMIT 1").unwrap(); + if enable_rusqlite { + let sqlite_conn = rusqlite_open(); + + group.bench_with_input(BenchmarkId::new("Sqlite3", query), query, |b, query| { + b.iter(|| { + sqlite_conn.prepare(query).unwrap(); + }); + }); + } + + group.finish(); + } + + let mut group = criterion.benchmark_group("Execute `SELECT * FROM users LIMIT ?`"); + + for i in [1, 10, 50, 100] { + group.bench_with_input(BenchmarkId::new("Limbo", i), &i, |b, i| { + // TODO: LIMIT doesn't support query parameters. 
+ let mut stmt = limbo_conn + .prepare(format!("SELECT * FROM users LIMIT {}", *i)) + .unwrap(); + let io = io.clone(); + b.iter(|| { + loop { + match stmt.step().unwrap() { + limbo_core::StepResult::Row(row) => { + black_box(row); + } + limbo_core::StepResult::IO => { + let _ = io.run_once(); + } + limbo_core::StepResult::Done => { + break; + } + limbo_core::StepResult::Interrupt | limbo_core::StepResult::Busy => { + unreachable!(); + } + } + } + stmt.reset(); + }); }); - }); - group.bench_function("Prepare statement: 'SELECT first_name, count(1) FROM users GROUP BY first_name HAVING count(1) > 1 ORDER BY count(1) LIMIT 1'", |b| { - b.iter(|| { - conn.prepare("SELECT first_name, count(1) FROM users GROUP BY first_name HAVING count(1) > 1 ORDER BY count(1) LIMIT 1").unwrap(); - }); - }); + if enable_rusqlite { + let sqlite_conn = rusqlite_open(); - let mut stmt = conn.prepare("SELECT 1").unwrap(); - group.bench_function("Execute prepared statement: 'SELECT 1'", |b| { + group.bench_with_input(BenchmarkId::new("Sqlite3", i), &i, |b, i| { + // TODO: Use parameters once we fix the above. 
+ let mut stmt = sqlite_conn + .prepare(&format!("SELECT * FROM users LIMIT {}", *i)) + .unwrap(); + b.iter(|| { + let mut rows = stmt.raw_query(); + while let Some(row) = rows.next().unwrap() { + black_box(row); + } + }); + }); + } + } + + group.finish(); + + let mut group = criterion.benchmark_group("Execute `SELECT 1`"); + + group.bench_function("Limbo", |b| { + let mut stmt = limbo_conn.prepare("SELECT 1").unwrap(); let io = io.clone(); b.iter(|| { - let mut rows = stmt.query().unwrap(); - match rows.step().unwrap() { - limbo_core::StepResult::Row(row) => { - assert_eq!(row.get::(0).unwrap(), 1); - } - limbo_core::StepResult::IO => { - io.run_once().unwrap(); - } - limbo_core::StepResult::Interrupt => { - unreachable!(); - } - limbo_core::StepResult::Done => { - unreachable!(); - } - limbo_core::StepResult::Busy => { - unreachable!(); + loop { + match stmt.step().unwrap() { + limbo_core::StepResult::Row(row) => { + black_box(row); + } + limbo_core::StepResult::IO => { + let _ = io.run_once(); + } + limbo_core::StepResult::Done => { + break; + } + limbo_core::StepResult::Interrupt | limbo_core::StepResult::Busy => { + unreachable!(); + } } } stmt.reset(); }); }); - let mut stmt = conn.prepare("SELECT * FROM users LIMIT 1").unwrap(); - group.bench_function( - "Execute prepared statement: 'SELECT * FROM users LIMIT 1'", - |b| { - let io = io.clone(); + if enable_rusqlite { + let sqlite_conn = rusqlite_open(); + + group.bench_function("Sqlite3", |b| { + let mut stmt = sqlite_conn.prepare("SELECT 1").unwrap(); b.iter(|| { - let mut rows = stmt.query().unwrap(); - match rows.step().unwrap() { - limbo_core::StepResult::Row(row) => { - assert_eq!(row.get::(0).unwrap(), 1); - } - limbo_core::StepResult::IO => { - io.run_once().unwrap(); - } - limbo_core::StepResult::Interrupt => { - unreachable!(); - } - limbo_core::StepResult::Done => { - unreachable!(); - } - limbo_core::StepResult::Busy => { - unreachable!() - } + let mut rows = stmt.raw_query(); + while let 
Some(row) = rows.next().unwrap() { + black_box(row); } - stmt.reset(); }); - }, - ); - - let mut stmt = conn.prepare("SELECT * FROM users LIMIT 100").unwrap(); - group.bench_function( - "Execute prepared statement: 'SELECT * FROM users LIMIT 100'", - |b| { - let io = io.clone(); - b.iter(|| { - let mut rows = stmt.query().unwrap(); - match rows.step().unwrap() { - limbo_core::StepResult::Row(row) => { - assert_eq!(row.get::(0).unwrap(), 1); - } - limbo_core::StepResult::IO => { - io.run_once().unwrap(); - } - limbo_core::StepResult::Interrupt => { - unreachable!(); - } - limbo_core::StepResult::Done => { - unreachable!(); - } - limbo_core::StepResult::Busy => { - unreachable!() - } - } - stmt.reset(); - }); - }, - ); -} - -fn rusqlite_bench(criterion: &mut Criterion) { - let mut group = criterion.benchmark_group("rusqlite"); - group.throughput(Throughput::Elements(1)); - - let conn = rusqlite::Connection::open("../testing/testing.db").unwrap(); - - conn.pragma_update(None, "locking_mode", "EXCLUSIVE") - .unwrap(); - group.bench_function("Prepare statement: 'SELECT 1'", |b| { - b.iter(|| { - conn.prepare("SELECT 1").unwrap(); }); - }); + } - group.bench_function("Prepare statement: 'SELECT * FROM users LIMIT 1'", |b| { - b.iter(|| { - conn.prepare("SELECT * FROM users LIMIT 1").unwrap(); - }); - }); - - let mut stmt = conn.prepare("SELECT 1").unwrap(); - group.bench_function("Execute prepared statement: 'SELECT 1'", |b| { - b.iter(|| { - let mut rows = stmt.query(()).unwrap(); - let row = rows.next().unwrap().unwrap(); - let val: i64 = row.get(0).unwrap(); - assert_eq!(val, 1); - }); - }); - - let mut stmt = conn.prepare("SELECT * FROM users LIMIT 1").unwrap(); - group.bench_function( - "Execute prepared statement: 'SELECT * FROM users LIMIT 1'", - |b| { - b.iter(|| { - let mut rows = stmt.query(()).unwrap(); - let row = rows.next().unwrap().unwrap(); - let id: i64 = row.get(0).unwrap(); - assert_eq!(id, 1); - }); - }, - ); - - let mut stmt = conn.prepare("SELECT * 
FROM users LIMIT 100").unwrap(); - group.bench_function( - "Execute prepared statement: 'SELECT * FROM users LIMIT 100'", - |b| { - b.iter(|| { - let mut rows = stmt.query(()).unwrap(); - let row = rows.next().unwrap().unwrap(); - let id: i64 = row.get(0).unwrap(); - assert_eq!(id, 1); - }); - }, - ); + group.finish(); } criterion_group! { diff --git a/core/benches/mvcc_benchmark.rs b/core/benches/mvcc_benchmark.rs new file mode 100644 index 000000000..899c8b82d --- /dev/null +++ b/core/benches/mvcc_benchmark.rs @@ -0,0 +1,129 @@ +use criterion::async_executor::FuturesExecutor; +use criterion::{criterion_group, criterion_main, Criterion, Throughput}; +use limbo_core::mvcc::clock::LocalClock; +use limbo_core::mvcc::database::{Database, Row, RowID}; +use pprof::criterion::{Output, PProfProfiler}; + +fn bench_db() -> Database { + let clock = LocalClock::default(); + let storage = limbo_core::mvcc::persistent_storage::Storage::new_noop(); + Database::new(clock, storage) +} + +fn bench(c: &mut Criterion) { + let mut group = c.benchmark_group("mvcc-ops-throughput"); + group.throughput(Throughput::Elements(1)); + + let db = bench_db(); + group.bench_function("begin_tx + rollback_tx", |b| { + b.to_async(FuturesExecutor).iter(|| async { + let tx_id = db.begin_tx(); + db.rollback_tx(tx_id) + }) + }); + + let db = bench_db(); + group.bench_function("begin_tx + commit_tx", |b| { + b.to_async(FuturesExecutor).iter(|| async { + let tx_id = db.begin_tx(); + db.commit_tx(tx_id) + }) + }); + + let db = bench_db(); + group.bench_function("begin_tx-read-commit_tx", |b| { + b.to_async(FuturesExecutor).iter(|| async { + let tx_id = db.begin_tx(); + db.read( + tx_id, + RowID { + table_id: 1, + row_id: 1, + }, + ) + .unwrap(); + db.commit_tx(tx_id) + }) + }); + + let db = bench_db(); + group.bench_function("begin_tx-update-commit_tx", |b| { + b.to_async(FuturesExecutor).iter(|| async { + let tx_id = db.begin_tx(); + db.update( + tx_id, + Row { + id: RowID { + table_id: 1, + row_id: 1, + 
}, + data: "World".to_string(), + }, + ) + .unwrap(); + db.commit_tx(tx_id) + }) + }); + + let db = bench_db(); + let tx = db.begin_tx(); + db.insert( + tx, + Row { + id: RowID { + table_id: 1, + row_id: 1, + }, + data: "Hello".to_string(), + }, + ) + .unwrap(); + group.bench_function("read", |b| { + b.to_async(FuturesExecutor).iter(|| async { + db.read( + tx, + RowID { + table_id: 1, + row_id: 1, + }, + ) + .unwrap(); + }) + }); + + let db = bench_db(); + let tx = db.begin_tx(); + db.insert( + tx, + Row { + id: RowID { + table_id: 1, + row_id: 1, + }, + data: "Hello".to_string(), + }, + ) + .unwrap(); + group.bench_function("update", |b| { + b.to_async(FuturesExecutor).iter(|| async { + db.update( + tx, + Row { + id: RowID { + table_id: 1, + row_id: 1, + }, + data: "World".to_string(), + }, + ) + .unwrap(); + }) + }); +} + +criterion_group! { + name = benches; + config = Criterion::default().with_profiler(PProfProfiler::new(100, Output::Flamegraph(None))); + targets = bench +} +criterion_main!(benches); diff --git a/core/build.rs b/core/build.rs new file mode 100644 index 000000000..50afee6bf --- /dev/null +++ b/core/build.rs @@ -0,0 +1,21 @@ +use std::fs; +use std::path::PathBuf; + +fn main() { + let out_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap()); + let built_file = out_dir.join("built.rs"); + + built::write_built_file().expect("Failed to acquire build-time information"); + + // So that we don't have to transform at runtime + let sqlite_date = chrono::Utc::now().format("%Y-%m-%d %H:%M:%S").to_string(); + fs::write( + &built_file, + format!( + "{}\npub const BUILT_TIME_SQLITE: &str = \"{}\";\n", + fs::read_to_string(&built_file).unwrap(), + sqlite_date + ), + ) + .expect("Failed to append to built file"); +} diff --git a/core/error.rs b/core/error.rs index ca495eb99..cfe1a827d 100644 --- a/core/error.rs +++ b/core/error.rs @@ -39,12 +39,18 @@ pub enum LimboError { InvalidTime(String), #[error("Modifier parsing error: {0}")] InvalidModifier(String), + 
#[error("Invalid argument supplied: {0}")] + InvalidArgument(String), + #[error("Invalid formatter supplied: {0}")] + InvalidFormatter(String), #[error("Runtime error: {0}")] Constraint(String), #[error("Extension error: {0}")] ExtensionError(String), #[error("Unbound parameter at index {0}")] Unbound(NonZero), + #[error("Runtime error: integer overflow")] + IntegerOverflow, } #[macro_export] diff --git a/core/ext/mod.rs b/core/ext/mod.rs index cf3fa6109..8a9212556 100644 --- a/core/ext/mod.rs +++ b/core/ext/mod.rs @@ -73,4 +73,29 @@ impl Database { register_aggregate_function, } } + + pub fn register_builtins(&self) -> Result<(), String> { + let ext_api = self.build_limbo_ext(); + #[cfg(feature = "uuid")] + if unsafe { !limbo_uuid::register_extension_static(&ext_api).is_ok() } { + return Err("Failed to register uuid extension".to_string()); + } + #[cfg(feature = "vector")] + if unsafe { !limbo_vector::register_extension_static(&ext_api).is_ok() } { + return Err("Failed to register vector extension".to_string()); + } + #[cfg(feature = "percentile")] + if unsafe { !limbo_percentile::register_extension_static(&ext_api).is_ok() } { + return Err("Failed to register percentile extension".to_string()); + } + #[cfg(feature = "regexp")] + if unsafe { !limbo_regexp::register_extension_static(&ext_api).is_ok() } { + return Err("Failed to register regexp extension".to_string()); + } + #[cfg(feature = "time")] + if unsafe { !limbo_time::register_extension_static(&ext_api).is_ok() } { + return Err("Failed to register time extension".to_string()); + } + Ok(()) + } } diff --git a/core/function.rs b/core/function.rs index 1e5386696..8fbe9cb7e 100644 --- a/core/function.rs +++ b/core/function.rs @@ -80,6 +80,10 @@ pub enum JsonFunc { JsonType, JsonErrorPosition, JsonValid, + JsonPatch, + JsonRemove, + JsonPretty, + JsonSet, } #[cfg(feature = "json")] @@ -99,6 +103,10 @@ impl Display for JsonFunc { Self::JsonType => "json_type".to_string(), Self::JsonErrorPosition => 
"json_error_position".to_string(), Self::JsonValid => "json_valid".to_string(), + Self::JsonPatch => "json_patch".to_string(), + Self::JsonRemove => "json_remove".to_string(), + Self::JsonPretty => "json_pretty".to_string(), + Self::JsonSet => "json_set".to_string(), } ) } @@ -206,6 +214,7 @@ pub enum ScalarFunc { Unicode, Quote, SqliteVersion, + SqliteSourceId, UnixEpoch, JulianDay, Hex, @@ -216,6 +225,7 @@ pub enum ScalarFunc { #[cfg(not(target_family = "wasm"))] LoadExtension, StrfTime, + Printf, } impl Display for ScalarFunc { @@ -257,6 +267,7 @@ impl Display for ScalarFunc { Self::Unicode => "unicode".to_string(), Self::Quote => "quote".to_string(), Self::SqliteVersion => "sqlite_version".to_string(), + Self::SqliteSourceId => "sqlite_source_id".to_string(), Self::JulianDay => "julianday".to_string(), Self::UnixEpoch => "unixepoch".to_string(), Self::Hex => "hex".to_string(), @@ -268,6 +279,7 @@ impl Display for ScalarFunc { #[cfg(not(target_family = "wasm"))] Self::LoadExtension => "load_extension".to_string(), Self::StrfTime => "strftime".to_string(), + Self::Printf => "printf".to_string(), }; write!(f, "{}", str) } @@ -506,6 +518,7 @@ impl Func { "unicode" => Ok(Self::Scalar(ScalarFunc::Unicode)), "quote" => Ok(Self::Scalar(ScalarFunc::Quote)), "sqlite_version" => Ok(Self::Scalar(ScalarFunc::SqliteVersion)), + "sqlite_source_id" => Ok(Self::Scalar(ScalarFunc::SqliteSourceId)), "replace" => Ok(Self::Scalar(ScalarFunc::Replace)), #[cfg(feature = "json")] "json" => Ok(Self::Json(JsonFunc::Json)), @@ -523,6 +536,14 @@ impl Func { "json_error_position" => Ok(Self::Json(JsonFunc::JsonErrorPosition)), #[cfg(feature = "json")] "json_valid" => Ok(Self::Json(JsonFunc::JsonValid)), + #[cfg(feature = "json")] + "json_patch" => Ok(Self::Json(JsonFunc::JsonPatch)), + #[cfg(feature = "json")] + "json_remove" => Ok(Self::Json(JsonFunc::JsonRemove)), + #[cfg(feature = "json")] + "json_pretty" => Ok(Self::Json(JsonFunc::JsonPretty)), + #[cfg(feature = "json")] + "json_set" 
=> Ok(Self::Json(JsonFunc::JsonSet)), "unixepoch" => Ok(Self::Scalar(ScalarFunc::UnixEpoch)), "julianday" => Ok(Self::Scalar(ScalarFunc::JulianDay)), "hex" => Ok(Self::Scalar(ScalarFunc::Hex)), @@ -561,6 +582,7 @@ impl Func { #[cfg(not(target_family = "wasm"))] "load_extension" => Ok(Self::Scalar(ScalarFunc::LoadExtension)), "strftime" => Ok(Self::Scalar(ScalarFunc::StrfTime)), + "printf" => Ok(Self::Scalar(ScalarFunc::Printf)), _ => crate::bail_parse_error!("no such function: {}", name), } } diff --git a/core/info.rs b/core/info.rs new file mode 100644 index 000000000..908ddca7b --- /dev/null +++ b/core/info.rs @@ -0,0 +1,3 @@ +pub mod build { + include!(concat!(env!("OUT_DIR"), "/built.rs")); +} diff --git a/core/io/memory.rs b/core/io/memory.rs index 6fc960e13..18decf78a 100644 --- a/core/io/memory.rs +++ b/core/io/memory.rs @@ -3,15 +3,15 @@ use crate::Result; use log::debug; use std::{ - cell::{RefCell, RefMut}, + cell::{Cell, RefCell, UnsafeCell}, collections::BTreeMap, rc::Rc, sync::Arc, }; pub struct MemoryIO { - pages: RefCell>, - size: RefCell, + pages: UnsafeCell>, + size: Cell, } // TODO: page size flag @@ -23,23 +23,23 @@ impl MemoryIO { pub fn new() -> Result> { debug!("Using IO backend 'memory'"); Ok(Arc::new(Self { - pages: RefCell::new(BTreeMap::new()), - size: RefCell::new(0), + pages: BTreeMap::new().into(), + size: 0.into(), })) } - fn get_or_allocate_page(&self, page_no: usize) -> RefMut { - let pages = self.pages.borrow_mut(); - RefMut::map(pages, |p| { - p.entry(page_no).or_insert_with(|| Box::new([0; PAGE_SIZE])) - }) + #[allow(clippy::mut_from_ref)] + fn get_or_allocate_page(&self, page_no: usize) -> &mut MemPage { + unsafe { + let pages = &mut *self.pages.get(); + pages + .entry(page_no) + .or_insert_with(|| Box::new([0; PAGE_SIZE])) + } } - fn get_page(&self, page_no: usize) -> Option> { - match RefMut::filter_map(self.pages.borrow_mut(), |pages| pages.get_mut(&page_no)) { - Ok(page) => Some(page), - Err(_) => None, - } + fn 
get_page(&self, page_no: usize) -> Option<&MemPage> { + unsafe { (*self.pages.get()).get(&page_no) } } } @@ -71,7 +71,6 @@ pub struct MemoryFile { } impl File for MemoryFile { - // no-ops fn lock_file(&self, _exclusive: bool) -> Result<()> { Ok(()) } @@ -90,7 +89,7 @@ impl File for MemoryFile { return Ok(()); } - let file_size = *self.io.size.borrow(); + let file_size = self.io.size.get(); if pos >= file_size { c.complete(0); return Ok(()); @@ -108,15 +107,10 @@ impl File for MemoryFile { let page_offset = offset % PAGE_SIZE; let bytes_to_read = remaining.min(PAGE_SIZE - page_offset); if let Some(page) = self.io.get_page(page_no) { - { - let page_data = &*page; - read_buf.as_mut_slice()[buf_offset..buf_offset + bytes_to_read] - .copy_from_slice(&page_data[page_offset..page_offset + bytes_to_read]); - } + read_buf.as_mut_slice()[buf_offset..buf_offset + bytes_to_read] + .copy_from_slice(&page[page_offset..page_offset + bytes_to_read]); } else { - for b in &mut read_buf.as_mut_slice()[buf_offset..buf_offset + bytes_to_read] { - *b = 0; - } + read_buf.as_mut_slice()[buf_offset..buf_offset + bytes_to_read].fill(0); } offset += bytes_to_read; @@ -147,7 +141,7 @@ impl File for MemoryFile { let bytes_to_write = remaining.min(PAGE_SIZE - page_offset); { - let mut page = self.io.get_or_allocate_page(page_no); + let page = self.io.get_or_allocate_page(page_no); page[page_offset..page_offset + bytes_to_write] .copy_from_slice(&data[buf_offset..buf_offset + bytes_to_write]); } @@ -157,10 +151,9 @@ impl File for MemoryFile { remaining -= bytes_to_write; } - { - let mut size = self.io.size.borrow_mut(); - *size = (*size).max(pos + buf_len); - } + self.io + .size + .set(core::cmp::max(pos + buf_len, self.io.size.get())); c.complete(buf_len as i32); Ok(()) @@ -173,7 +166,7 @@ impl File for MemoryFile { } fn size(&self) -> Result { - Ok(*self.io.size.borrow() as u64) + Ok(self.io.size.get() as u64) } } diff --git a/core/json/de.rs b/core/json/de.rs index 14927bb18..6cc3fb333 100644 
--- a/core/json/de.rs +++ b/core/json/de.rs @@ -504,3 +504,59 @@ impl<'de> de::VariantAccess<'de> for Variant<'de> { } } } + +pub mod ordered_object { + + use crate::json::Val; + use serde::de::{MapAccess, Visitor}; + use serde::{Deserializer, Serializer}; + use std::fmt; + + pub fn serialize(pairs: &Vec<(String, Val)>, serializer: S) -> Result + where + S: Serializer, + { + use serde::ser::SerializeMap; + let mut map = serializer.serialize_map(Some(pairs.len()))?; + for (k, v) in pairs { + if let Val::Removed = v { + continue; + } + map.serialize_entry(k, v)?; + } + map.end() + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: Deserializer<'de>, + { + struct OrderedMapVisitor; + + impl<'de> Visitor<'de> for OrderedMapVisitor { + type Value = Vec<(String, Val)>; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a map") + } + + fn visit_map(self, mut access: M) -> Result + where + M: MapAccess<'de>, + { + let mut pairs = match access.size_hint() { + Some(size) => Vec::with_capacity(size), + None => Vec::new(), + }; + + while let Some((key, value)) = access.next_entry()? { + pairs.push((key, value)); + } + + Ok(pairs) + } + } + + deserializer.deserialize_map(OrderedMapVisitor) + } +} diff --git a/core/json/error.rs b/core/json/error.rs index 870601738..08d17e309 100644 --- a/core/json/error.rs +++ b/core/json/error.rs @@ -106,3 +106,11 @@ pub fn set_location(res: &mut Result, span: &Span<'_>) { } } } + +impl From for crate::LimboError { + fn from(err: Error) -> Self { + match err { + Error::Message { msg, .. 
} => crate::LimboError::ParseError(msg), + } + } +} diff --git a/core/json/json_operations.rs b/core/json/json_operations.rs new file mode 100644 index 000000000..086090fea --- /dev/null +++ b/core/json/json_operations.rs @@ -0,0 +1,594 @@ +use std::collections::VecDeque; + +use crate::{ + json::{mutate_json_by_path, Target}, + types::OwnedValue, +}; + +use super::{convert_json_to_db_type, get_json_value, json_path::json_path, Val}; + +/// Represents a single patch operation in the merge queue. +/// +/// Used internally by the `merge_patch` function to track the path and value +/// for each pending merge operation. +#[derive(Debug, Clone)] +struct PatchOperation { + path_start: usize, + path_len: usize, + patch: Val, +} + +/// The function follows RFC 7386 JSON Merge Patch semantics: +/// * If the patch is null, the target is replaced with null +/// * If the patch contains a scalar value, the target is replaced with that value +/// * If both target and patch are objects, the patch is recursively applied +/// * null values in the patch result in property removal from the target +pub fn json_patch(target: &OwnedValue, patch: &OwnedValue) -> crate::Result { + match (target, patch) { + (OwnedValue::Blob(_), _) | (_, OwnedValue::Blob(_)) => { + crate::bail_constraint_error!("blob is not supported!"); + } + _ => (), + } + + let mut parsed_target = get_json_value(target)?; + let parsed_patch = get_json_value(patch)?; + let mut patcher = JsonPatcher::new(16); + patcher.apply_patch(&mut parsed_target, parsed_patch); + + convert_json_to_db_type(&parsed_target, false) +} + +#[derive(Debug)] +struct JsonPatcher { + queue: VecDeque, + path_storage: Vec, +} + +impl JsonPatcher { + fn new(queue_capacity: usize) -> Self { + Self { + queue: VecDeque::with_capacity(queue_capacity), + path_storage: Vec::with_capacity(256), + } + } + + fn apply_patch(&mut self, target: &mut Val, patch: Val) { + self.queue.push_back(PatchOperation { + path_start: 0, + path_len: 0, + patch, + }); + + 
while let Some(op) = self.queue.pop_front() { + if let Some(current) = self.navigate_to_target(target, op.path_start, op.path_len) { + self.apply_operation(current, op); + } else { + continue; + } + } + } + + fn navigate_to_target<'a>( + &self, + target: &'a mut Val, + path_start: usize, + path_len: usize, + ) -> Option<&'a mut Val> { + let mut current = target; + for i in 0..path_len { + let key = self.path_storage[path_start + i]; + if let Val::Object(ref mut obj) = current { + current = &mut obj + .get_mut(key) + .unwrap_or_else(|| { + panic!("Invalid path at depth {}: key '{}' not found", i, key) + }) + .1; + } else { + return None; + } + } + Some(current) + } + + fn apply_operation(&mut self, current: &mut Val, operation: PatchOperation) { + let path_start = operation.path_start; + let path_len = operation.path_len; + match (current, operation.patch) { + (current_val, Val::Null) => *current_val = Val::Removed, + (Val::Object(target_map), Val::Object(patch_map)) => { + self.merge_objects(target_map, patch_map, path_start, path_len); + } + (current_val, patch_val) => *current_val = patch_val, + } + } + + fn merge_objects( + &mut self, + target_map: &mut Vec<(String, Val)>, + patch_map: Vec<(String, Val)>, + path_start: usize, + path_len: usize, + ) { + for (key, patch_val) in patch_map { + self.process_key_value(target_map, key, patch_val, path_start, path_len); + } + } + + fn process_key_value( + &mut self, + target_map: &mut Vec<(String, Val)>, + key: String, + patch_val: Val, + path_start: usize, + path_len: usize, + ) { + if let Some(pos) = target_map + .iter() + .position(|(target_key, _)| target_key == &key) + { + self.queue_nested_patch(pos, patch_val, path_start, path_len); + } else if !matches!(patch_val, Val::Null) { + target_map.push((key, Val::Object(vec![]))); + self.queue_nested_patch(target_map.len() - 1, patch_val, path_start, path_len) + } + } + + fn queue_nested_patch(&mut self, pos: usize, val: Val, path_start: usize, path_len: usize) { + let 
new_path_start = self.path_storage.len(); + let new_path_len = path_len + 1; + for i in 0..path_len { + self.path_storage.push(self.path_storage[path_start + i]); + } + self.path_storage.push(pos); + self.queue.push_back(PatchOperation { + path_start: new_path_start, + path_len: new_path_len, + patch: val, + }); + } +} + +pub fn json_remove(args: &[OwnedValue]) -> crate::Result { + if args.is_empty() { + return Ok(OwnedValue::Null); + } + + let mut parsed_target = get_json_value(&args[0])?; + if args.len() == 1 { + return Ok(args[0].clone()); + } + + let paths: Result, _> = args[1..] + .iter() + .map(|path| { + if let OwnedValue::Text(path) = path { + json_path(&path.value) + } else { + crate::bail_constraint_error!("bad JSON path: {:?}", path.to_string()) + } + }) + .collect(); + let paths = paths?; + + for path in paths { + mutate_json_by_path(&mut parsed_target, path, |val| match val { + Target::Array(arr, index) => { + arr.remove(index); + } + Target::Value(val) => *val = Val::Removed, + }); + } + + convert_json_to_db_type(&parsed_target, false) +} + +#[cfg(test)] +mod tests { + use std::rc::Rc; + + use crate::types::LimboText; + + use super::*; + + fn create_text(s: &str) -> OwnedValue { + OwnedValue::Text(LimboText::new(Rc::new(s.to_string()))) + } + + fn create_json(s: &str) -> OwnedValue { + OwnedValue::Text(LimboText::json(Rc::new(s.to_string()))) + } + + #[test] + fn test_new_patcher() { + let patcher = JsonPatcher::new(10); + assert_eq!(patcher.queue.capacity(), 10); + assert_eq!(patcher.path_storage.capacity(), 256); + assert!(patcher.queue.is_empty()); + assert!(patcher.path_storage.is_empty()); + } + + #[test] + fn test_simple_value_replacement() { + let mut patcher = JsonPatcher::new(10); + let mut target = Val::Object(vec![("key1".to_string(), Val::String("old".to_string()))]); + let patch = Val::Object(vec![("key1".to_string(), Val::String("new".to_string()))]); + + patcher.apply_patch(&mut target, patch); + + if let Val::Object(map) = target { + 
assert_eq!(map[0].1, Val::String("new".to_string())); + } else { + panic!("Expected object"); + } + } + + #[test] + fn test_nested_object_patch() { + let mut patcher = JsonPatcher::new(10); + let mut target = Val::Object(vec![( + "a".to_string(), + Val::Object(vec![("b".to_string(), Val::String("old".to_string()))]), + )]); + let patch = Val::Object(vec![( + "a".to_string(), + Val::Object(vec![("b".to_string(), Val::String("new".to_string()))]), + )]); + + patcher.apply_patch(&mut target, patch); + + if let Val::Object(map) = target { + if let Val::Object(nested) = &map[0].1 { + assert_eq!(nested[0].1, Val::String("new".to_string())); + } else { + panic!("Expected nested object"); + } + } + } + + #[test] + fn test_null_removal() { + let mut patcher = JsonPatcher::new(10); + let mut target = Val::Object(vec![ + ("keep".to_string(), Val::String("value".to_string())), + ("remove".to_string(), Val::String("value".to_string())), + ]); + let patch = Val::Object(vec![("remove".to_string(), Val::Null)]); + + patcher.apply_patch(&mut target, patch); + + if let Val::Object(map) = target { + assert_eq!(map[0].1, Val::String("value".to_string())); + assert_eq!(map[1].1, Val::Removed); + } + } + + #[test] + fn test_duplicate_keys_first_occurrence() { + let mut patcher = JsonPatcher::new(10); + let mut target = Val::Object(vec![ + ("key".to_string(), Val::String("first".to_string())), + ("key".to_string(), Val::String("second".to_string())), + ("key".to_string(), Val::String("third".to_string())), + ]); + let patch = Val::Object(vec![( + "key".to_string(), + Val::String("modified".to_string()), + )]); + + patcher.apply_patch(&mut target, patch); + + if let Val::Object(map) = target { + assert_eq!(map[0].1, Val::String("modified".to_string())); + assert_eq!(map[1].1, Val::String("second".to_string())); + assert_eq!(map[2].1, Val::String("third".to_string())); + } + } + + #[test] + fn test_add_new_key() { + let mut patcher = JsonPatcher::new(10); + let mut target = 
Val::Object(vec![( + "existing".to_string(), + Val::String("value".to_string()), + )]); + let patch = Val::Object(vec![("new".to_string(), Val::String("value".to_string()))]); + + patcher.apply_patch(&mut target, patch); + + if let Val::Object(map) = target { + assert_eq!(map.len(), 2); + assert_eq!(map[1].0, "new"); + assert_eq!(map[1].1, Val::String("value".to_string())); + } + } + + #[test] + fn test_deep_nested_patch() { + let mut patcher = JsonPatcher::new(10); + let mut target = Val::Object(vec![( + "level1".to_string(), + Val::Object(vec![( + "level2".to_string(), + Val::Object(vec![("level3".to_string(), Val::String("old".to_string()))]), + )]), + )]); + let patch = Val::Object(vec![( + "level1".to_string(), + Val::Object(vec![( + "level2".to_string(), + Val::Object(vec![("level3".to_string(), Val::String("new".to_string()))]), + )]), + )]); + + patcher.apply_patch(&mut target, patch); + + if let Val::Object(l1) = target { + if let Val::Object(l2) = &l1[0].1 { + if let Val::Object(l3) = &l2[0].1 { + assert_eq!(l3[0].1, Val::String("new".to_string())); + } + } + } + } + + #[test] + fn test_null_patch_on_nonexistent_key() { + let mut patcher = JsonPatcher::new(10); + let mut target = Val::Object(vec![( + "existing".to_string(), + Val::String("value".to_string()), + )]); + let patch = Val::Object(vec![("nonexistent".to_string(), Val::Null)]); + + patcher.apply_patch(&mut target, patch); + + if let Val::Object(map) = target { + assert_eq!(map.len(), 1); // Should not add new key for null patch + assert_eq!(map[0].0, "existing"); + } + } + + #[test] + fn test_nested_duplicate_keys() { + let mut patcher = JsonPatcher::new(10); + let mut target = Val::Object(vec![( + "outer".to_string(), + Val::Object(vec![ + ("inner".to_string(), Val::String("first".to_string())), + ("inner".to_string(), Val::String("second".to_string())), + ]), + )]); + let patch = Val::Object(vec![( + "outer".to_string(), + Val::Object(vec![( + "inner".to_string(), + 
Val::String("modified".to_string()), + )]), + )]); + + patcher.apply_patch(&mut target, patch); + + if let Val::Object(outer) = target { + if let Val::Object(inner) = &outer[0].1 { + assert_eq!(inner[0].1, Val::String("modified".to_string())); + assert_eq!(inner[1].1, Val::String("second".to_string())); + } + } + } + + #[test] + #[should_panic(expected = "Invalid path")] + fn test_invalid_path_navigation() { + let mut patcher = JsonPatcher::new(10); + let mut target = Val::Object(vec![("a".to_string(), Val::Object(vec![]))]); + patcher.path_storage.push(0); + patcher.path_storage.push(999); // Invalid index + + patcher.navigate_to_target(&mut target, 0, 2); + } + + #[test] + fn test_merge_empty_objects() { + let mut patcher = JsonPatcher::new(10); + let mut target = Val::Object(vec![]); + let patch = Val::Object(vec![]); + + patcher.apply_patch(&mut target, patch); + + if let Val::Object(map) = target { + assert!(map.is_empty()); + } + } + + #[test] + fn test_path_storage_growth() { + let mut patcher = JsonPatcher::new(10); + let mut target = Val::Object(vec![( + "a".to_string(), + Val::Object(vec![("b".to_string(), Val::Object(vec![]))]), + )]); + let patch = Val::Object(vec![( + "a".to_string(), + Val::Object(vec![("b".to_string(), Val::String("value".to_string()))]), + )]); + + patcher.apply_patch(&mut target, patch); + + // Path storage should contain [0, 0] for accessing a.b + assert_eq!(patcher.path_storage.len(), 3); + assert_eq!(patcher.path_storage[0], 0); + assert_eq!(patcher.path_storage[1], 0); + } + + #[test] + fn test_basic_text_replacement() { + let target = create_text(r#"{"name":"John","age":"30"}"#); + let patch = create_text(r#"{"age":"31"}"#); + + let result = json_patch(&target, &patch).unwrap(); + assert_eq!(result, create_json(r#"{"name":"John","age":"31"}"#)); + } + + #[test] + fn test_null_field_removal() { + let target = create_text(r#"{"name":"John","email":"john@example.com"}"#); + let patch = create_text(r#"{"email":null}"#); + + let 
result = json_patch(&target, &patch).unwrap(); + assert_eq!(result, create_json(r#"{"name":"John"}"#)); + } + + #[test] + fn test_nested_object_merge() { + let target = + create_text(r#"{"user":{"name":"John","details":{"age":"30","score":"95.5"}}}"#); + + let patch = create_text(r#"{"user":{"details":{"score":"97.5"}}}"#); + + let result = json_patch(&target, &patch).unwrap(); + assert_eq!( + result, + create_json(r#"{"user":{"name":"John","details":{"age":"30","score":"97.5"}}}"#) + ); + } + + #[test] + #[should_panic(expected = "blob is not supported!")] + fn test_blob_not_supported() { + let target = OwnedValue::Blob(Rc::new(vec![1, 2, 3])); + let patch = create_text("{}"); + json_patch(&target, &patch).unwrap(); + } + + #[test] + fn test_deep_null_replacement() { + let target = create_text(r#"{"level1":{"level2":{"keep":"value","remove":"value"}}}"#); + + let patch = create_text(r#"{"level1":{"level2":{"remove":null}}}"#); + + let result = json_patch(&target, &patch).unwrap(); + assert_eq!( + result, + create_json(r#"{"level1":{"level2":{"keep":"value"}}}"#) + ); + } + + #[test] + fn test_empty_patch() { + let target = create_json(r#"{"name":"John","age":"30"}"#); + let patch = create_text("{}"); + + let result = json_patch(&target, &patch).unwrap(); + assert_eq!(result, target); + } + + #[test] + fn test_add_new_field() { + let target = create_text(r#"{"existing":"value"}"#); + let patch = create_text(r#"{"new":"field"}"#); + + let result = json_patch(&target, &patch).unwrap(); + assert_eq!(result, create_json(r#"{"existing":"value","new":"field"}"#)); + } + + #[test] + fn test_complete_object_replacement() { + let target = create_text(r#"{"old":{"nested":"value"}}"#); + let patch = create_text(r#"{"old":"new_value"}"#); + + let result = json_patch(&target, &patch).unwrap(); + assert_eq!(result, create_json(r#"{"old":"new_value"}"#)); + } + + #[test] + fn test_json_remove_empty_args() { + let args = vec![]; + assert_eq!(json_remove(&args).unwrap(), 
OwnedValue::Null); + } + + #[test] + fn test_json_remove_array_element() { + let args = vec![create_json(r#"[1,2,3,4,5]"#), create_text("$[2]")]; + + let result = json_remove(&args).unwrap(); + match result { + OwnedValue::Text(t) => assert_eq!(t.value.as_str(), "[1,2,4,5]"), + _ => panic!("Expected Text value"), + } + } + + #[test] + fn test_json_remove_multiple_paths() { + let args = vec![ + create_json(r#"{"a": 1, "b": 2, "c": 3}"#), + create_text("$.a"), + create_text("$.c"), + ]; + + let result = json_remove(&args).unwrap(); + match result { + OwnedValue::Text(t) => assert_eq!(t.value.as_str(), r#"{"b":2}"#), + _ => panic!("Expected Text value"), + } + } + + #[test] + fn test_json_remove_nested_paths() { + let args = vec![ + create_json(r#"{"a": {"b": {"c": 1, "d": 2}}}"#), + create_text("$.a.b.c"), + ]; + + let result = json_remove(&args).unwrap(); + match result { + OwnedValue::Text(t) => assert_eq!(t.value.as_str(), r#"{"a":{"b":{"d":2}}}"#), + _ => panic!("Expected Text value"), + } + } + + #[test] + fn test_json_remove_duplicate_keys() { + let args = vec![ + create_json(r#"{"a": 1, "a": 2, "a": 3}"#), + create_text("$.a"), + ]; + + let result = json_remove(&args).unwrap(); + match result { + OwnedValue::Text(t) => assert_eq!(t.value.as_str(), r#"{"a":2,"a":3}"#), + _ => panic!("Expected Text value"), + } + } + + #[test] + fn test_json_remove_invalid_path() { + let args = vec![ + create_json(r#"{"a": 1}"#), + OwnedValue::Integer(42), // Invalid path type + ]; + + assert!(json_remove(&args).is_err()); + } + + #[test] + fn test_json_remove_complex_case() { + let args = vec![ + create_json(r#"{"a":[1,2,3],"b":{"x":1,"x":2},"c":[{"y":1},{"y":2}]}"#), + create_text("$.a[1]"), + create_text("$.b.x"), + create_text("$.c[0].y"), + ]; + + let result = json_remove(&args).unwrap(); + match result { + OwnedValue::Text(t) => { + let value = t.value.as_str(); + assert!(value.contains(r#"[1,3]"#)); + assert!(value.contains(r#"{"x":2}"#)); + } + _ => panic!("Expected Text 
value"), + } + } +} diff --git a/core/json/json_path.pest b/core/json/json_path.pest index 71a462edc..590a3df23 100644 --- a/core/json/json_path.pest +++ b/core/json/json_path.pest @@ -4,5 +4,5 @@ array_locator = ${ "[" ~ negative_index_indicator? ~ array_offset ~ "]" } relaxed_array_locator = ${ negative_index_indicator? ~ array_offset } root = ${ "$" } -json_path_key = ${ identifier | string } +json_path_key = ${ identifier | string | ASCII_DIGIT+ } path = ${ SOI ~ root ~ (array_locator | "." ~ json_path_key)* ~ EOI } diff --git a/core/json/mod.rs b/core/json/mod.rs index 10e682148..13c2bf98d 100644 --- a/core/json/mod.rs +++ b/core/json/mod.rs @@ -1,17 +1,21 @@ mod de; mod error; +mod json_operations; mod json_path; mod ser; use std::rc::Rc; pub use crate::json::de::from_str; +use crate::json::de::ordered_object; use crate::json::error::Error as JsonError; +pub use crate::json::json_operations::{json_patch, json_remove}; use crate::json::json_path::{json_path, JsonPath, PathElement}; pub use crate::json::ser::to_string; use crate::types::{LimboText, OwnedValue, TextSubtype}; use indexmap::IndexMap; use jsonb::Error as JsonbError; +use ser::to_string_pretty; use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] @@ -23,10 +27,12 @@ pub enum Val { Float(f64), String(String), Array(Vec), - Object(IndexMap), + Removed, + #[serde(with = "ordered_object")] + Object(Vec<(String, Val)>), } -pub fn get_json(json_value: &OwnedValue) -> crate::Result { +pub fn get_json(json_value: &OwnedValue, indent: Option<&str>) -> crate::Result { match json_value { OwnedValue::Text(ref t) => { // optimization: once we know the subtype is a valid JSON, we do not have @@ -36,7 +42,10 @@ pub fn get_json(json_value: &OwnedValue) -> crate::Result { } let json_val = get_json_value(json_value)?; - let json = to_string(&json_val).unwrap(); + let json = match indent { + Some(indent) => to_string_pretty(&json_val, indent)?, + None => 
to_string(&json_val)?, + }; Ok(OwnedValue::Text(LimboText::json(Rc::new(json)))) } @@ -52,7 +61,10 @@ pub fn get_json(json_value: &OwnedValue) -> crate::Result { OwnedValue::Null => Ok(OwnedValue::Null), _ => { let json_val = get_json_value(json_value)?; - let json = to_string(&json_val).unwrap(); + let json = match indent { + Some(indent) => to_string_pretty(&json_val, indent)?, + None => to_string(&json_val)?, + }; Ok(OwnedValue::Text(LimboText::json(Rc::new(json)))) } @@ -143,6 +155,45 @@ pub fn json_array_length( } } +pub fn json_set(json: &OwnedValue, values: &[OwnedValue]) -> crate::Result { + let mut json_value = get_json_value(json)?; + + values + .chunks(2) + .map(|chunk| match chunk { + [path, value] => { + let path = json_path_from_owned_value(path, true)?; + + if let Some(path) = path { + let new_value = match value { + OwnedValue::Text(LimboText { + value, + subtype: TextSubtype::Text, + }) => Val::String(value.to_string()), + _ => get_json_value(value)?, + }; + + let mut new_json_value = json_value.clone(); + + match create_and_mutate_json_by_path(&mut new_json_value, path, |val| match val + { + Target::Array(arr, index) => arr[index] = new_value.clone(), + Target::Value(val) => *val = new_value.clone(), + }) { + Some(_) => json_value = new_json_value, + _ => {} + } + } + + Ok(()) + } + _ => crate::bail_constraint_error!("json_set needs an odd number of arguments"), + }) + .collect::>()?; + + convert_json_to_db_type(&json_value, true) +} + /// Implements the -> operator. Always returns a proper JSON value. 
/// https://sqlite.org/json1.html#the_and_operators pub fn json_arrow_extract(value: &OwnedValue, path: &OwnedValue) -> crate::Result { @@ -154,7 +205,7 @@ pub fn json_arrow_extract(value: &OwnedValue, path: &OwnedValue) -> crate::Resul let extracted = json_extract_single(&json, path, false)?; if let Some(val) = extracted { - let json = to_string(val).unwrap(); + let json = to_string(val)?; Ok(OwnedValue::Text(LimboText::json(Rc::new(json)))) } else { @@ -211,7 +262,7 @@ pub fn json_extract(value: &OwnedValue, paths: &[OwnedValue]) -> crate::Result crate::Result crate::Result { match extracted { + Val::Removed => Ok(OwnedValue::Null), Val::Null => Ok(OwnedValue::Null), Val::Float(f) => Ok(OwnedValue::Float(*f)), Val::Integer(i) => Ok(OwnedValue::Integer(*i)), @@ -248,7 +300,7 @@ fn convert_json_to_db_type(extracted: &Val, all_as_db: bool) -> crate::Result Ok(OwnedValue::Text(LimboText::new(Rc::new(s.clone())))), _ => { - let json = to_string(&extracted).unwrap(); + let json = to_string(&extracted)?; if all_as_db { Ok(OwnedValue::Text(LimboText::new(Rc::new(json)))) } else { @@ -311,6 +363,7 @@ pub fn json_type(value: &OwnedValue, path: Option<&OwnedValue>) -> crate::Result Val::String(_) => "text", Val::Array(_) => "array", Val::Object(_) => "object", + Val::Removed => unreachable!(), }; Ok(OwnedValue::Text(LimboText::json(Rc::new(val.to_string())))) @@ -327,32 +380,9 @@ fn json_extract_single<'a>( path: &OwnedValue, strict: bool, ) -> crate::Result> { - let json_path = if strict { - match path { - OwnedValue::Text(t) => json_path(t.value.as_str())?, - OwnedValue::Null => return Ok(None), - _ => crate::bail_constraint_error!("JSON path error near: {:?}", path.to_string()), - } - } else { - match path { - OwnedValue::Text(t) => { - if t.value.starts_with("$") { - json_path(t.value.as_str())? 
- } else { - JsonPath { - elements: vec![PathElement::Root(), PathElement::Key(t.value.to_string())], - } - } - } - OwnedValue::Null => return Ok(None), - OwnedValue::Integer(i) => JsonPath { - elements: vec![PathElement::Root(), PathElement::ArrayLocator(*i as i32)], - }, - OwnedValue::Float(f) => JsonPath { - elements: vec![PathElement::Root(), PathElement::Key(f.to_string())], - }, - _ => crate::bail_constraint_error!("JSON path error near: {:?}", path.to_string()), - } + let json_path = match json_path_from_owned_value(path, strict)? { + Some(path) => path, + None => return Ok(None), }; let mut current_element = &Val::Null; @@ -367,7 +397,7 @@ fn json_extract_single<'a>( match current_element { Val::Object(map) => { - if let Some(value) = map.get(key) { + if let Some((_, value)) = map.iter().find(|(k, _)| k == key) { current_element = value; } else { return Ok(None); @@ -398,6 +428,182 @@ fn json_extract_single<'a>( Ok(Some(current_element)) } +fn json_path_from_owned_value(path: &OwnedValue, strict: bool) -> crate::Result> { + let json_path = if strict { + match path { + OwnedValue::Text(t) => json_path(t.value.as_str())?, + OwnedValue::Null => return Ok(None), + _ => crate::bail_constraint_error!("JSON path error near: {:?}", path.to_string()), + } + } else { + match path { + OwnedValue::Text(t) => { + if t.value.starts_with("$") { + json_path(t.value.as_str())? 
+ } else { + JsonPath { + elements: vec![PathElement::Root(), PathElement::Key(t.value.to_string())], + } + } + } + OwnedValue::Null => return Ok(None), + OwnedValue::Integer(i) => JsonPath { + elements: vec![PathElement::Root(), PathElement::ArrayLocator(*i as i32)], + }, + OwnedValue::Float(f) => JsonPath { + elements: vec![PathElement::Root(), PathElement::Key(f.to_string())], + }, + _ => crate::bail_constraint_error!("JSON path error near: {:?}", path.to_string()), + } + }; + + Ok(Some(json_path)) +} + +enum Target<'a> { + Array(&'a mut Vec, usize), + Value(&'a mut Val), +} + +fn mutate_json_by_path(json: &mut Val, path: JsonPath, closure: F) -> Option +where + F: FnMut(Target) -> R, +{ + find_target(json, &path).map(closure) +} + +fn find_target<'a>(json: &'a mut Val, path: &JsonPath) -> Option> { + let mut current = json; + for (i, key) in path.elements.iter().enumerate() { + let is_last = i == path.elements.len() - 1; + match key { + PathElement::Root() => continue, + PathElement::ArrayLocator(index) => match current { + Val::Array(arr) => { + if let Some(index) = match index { + i if *i < 0 => arr.len().checked_sub(i.unsigned_abs() as usize), + i => ((*i as usize) < arr.len()).then_some(*i as usize), + } { + if is_last { + return Some(Target::Array(arr, index)); + } else { + current = &mut arr[index]; + } + } else { + return None; + } + } + _ => { + return None; + } + }, + PathElement::Key(key) => match current { + Val::Object(obj) => { + if let Some(pos) = &obj + .iter() + .position(|(k, v)| k == key && !matches!(v, Val::Removed)) + { + let val = &mut obj[*pos].1; + current = val; + } else { + return None; + } + } + _ => { + return None; + } + }, + } + } + Some(Target::Value(current)) +} + +fn create_and_mutate_json_by_path(json: &mut Val, path: JsonPath, closure: F) -> Option +where + F: FnOnce(Target) -> R, +{ + find_or_create_target(json, &path).map(closure) +} + +fn find_or_create_target<'a>(json: &'a mut Val, path: &JsonPath) -> Option> { + let mut 
current = json; + for (i, key) in path.elements.iter().enumerate() { + let is_last = i == path.elements.len() - 1; + match key { + PathElement::Root() => continue, + PathElement::ArrayLocator(index) => match current { + Val::Array(arr) => { + if let Some(index) = match index { + i if *i < 0 => arr.len().checked_sub(i.unsigned_abs() as usize), + i => Some(*i as usize), + } { + if is_last { + if index == arr.len() { + arr.push(Val::Null); + } + + if index >= arr.len() { + return None; + } + + return Some(Target::Array(arr, index)); + } else { + if index == arr.len() { + arr.push( + if matches!(path.elements[i + 1], PathElement::ArrayLocator(_)) + { + Val::Array(vec![]) + } else { + Val::Object(vec![]) + }, + ); + } + + if index >= arr.len() { + return None; + } + + current = &mut arr[index]; + } + } else { + return None; + } + } + _ => { + *current = Val::Array(vec![]); + } + }, + PathElement::Key(key) => match current { + Val::Object(obj) => { + if let Some(pos) = &obj + .iter() + .position(|(k, v)| k == key && !matches!(v, Val::Removed)) + { + let val = &mut obj[*pos].1; + current = val; + } else { + let element = if !is_last + && matches!(path.elements[i + 1], PathElement::ArrayLocator(_)) + { + Val::Array(vec![]) + } else { + Val::Object(vec![]) + }; + + obj.push((key.clone(), element)); + let index = obj.len() - 1; + current = &mut obj[index].1; + } + } + _ => { + return None; + } + }, + } + } + Some(Target::Value(current)) +} + pub fn json_error_position(json: &OwnedValue) -> crate::Result { match json { OwnedValue::Text(t) => match from_str::(&t.value) { @@ -444,10 +650,25 @@ pub fn json_object(values: &[OwnedValue]) -> crate::Result { }) .collect::, _>>()?; - let result = crate::json::to_string(&value_map).unwrap(); + let result = crate::json::to_string(&value_map)?; Ok(OwnedValue::Text(LimboText::json(Rc::new(result)))) } +pub fn is_json_valid(json_value: &OwnedValue) -> crate::Result { + match json_value { + OwnedValue::Text(ref t) => match 
from_str::(&t.value) { + Ok(_) => Ok(OwnedValue::Integer(1)), + Err(_) => Ok(OwnedValue::Integer(0)), + }, + OwnedValue::Blob(b) => match jsonb::from_slice(b) { + Ok(_) => Ok(OwnedValue::Integer(1)), + Err(_) => Ok(OwnedValue::Integer(0)), + }, + OwnedValue::Null => Ok(OwnedValue::Null), + _ => Ok(OwnedValue::Integer(1)), + } +} + #[cfg(test)] mod tests { use super::*; @@ -456,7 +677,7 @@ mod tests { #[test] fn test_get_json_valid_json5() { let input = OwnedValue::build_text(Rc::new("{ key: 'value' }".to_string())); - let result = get_json(&input).unwrap(); + let result = get_json(&input, None).unwrap(); if let OwnedValue::Text(result_str) = result { assert!(result_str.value.contains("\"key\":\"value\"")); assert_eq!(result_str.subtype, TextSubtype::Json); @@ -468,7 +689,7 @@ mod tests { #[test] fn test_get_json_valid_json5_double_single_quotes() { let input = OwnedValue::build_text(Rc::new("{ key: ''value'' }".to_string())); - let result = get_json(&input).unwrap(); + let result = get_json(&input, None).unwrap(); if let OwnedValue::Text(result_str) = result { assert!(result_str.value.contains("\"key\":\"value\"")); assert_eq!(result_str.subtype, TextSubtype::Json); @@ -480,7 +701,7 @@ mod tests { #[test] fn test_get_json_valid_json5_infinity() { let input = OwnedValue::build_text(Rc::new("{ \"key\": Infinity }".to_string())); - let result = get_json(&input).unwrap(); + let result = get_json(&input, None).unwrap(); if let OwnedValue::Text(result_str) = result { assert!(result_str.value.contains("{\"key\":9e999}")); assert_eq!(result_str.subtype, TextSubtype::Json); @@ -492,7 +713,7 @@ mod tests { #[test] fn test_get_json_valid_json5_negative_infinity() { let input = OwnedValue::build_text(Rc::new("{ \"key\": -Infinity }".to_string())); - let result = get_json(&input).unwrap(); + let result = get_json(&input, None).unwrap(); if let OwnedValue::Text(result_str) = result { assert!(result_str.value.contains("{\"key\":-9e999}")); assert_eq!(result_str.subtype, 
TextSubtype::Json); @@ -504,7 +725,7 @@ mod tests { #[test] fn test_get_json_valid_json5_nan() { let input = OwnedValue::build_text(Rc::new("{ \"key\": NaN }".to_string())); - let result = get_json(&input).unwrap(); + let result = get_json(&input, None).unwrap(); if let OwnedValue::Text(result_str) = result { assert!(result_str.value.contains("{\"key\":null}")); assert_eq!(result_str.subtype, TextSubtype::Json); @@ -516,7 +737,7 @@ mod tests { #[test] fn test_get_json_invalid_json5() { let input = OwnedValue::build_text(Rc::new("{ key: value }".to_string())); - let result = get_json(&input); + let result = get_json(&input, None); match result { Ok(_) => panic!("Expected error for malformed JSON"), Err(e) => assert!(e.to_string().contains("malformed JSON")), @@ -526,7 +747,7 @@ mod tests { #[test] fn test_get_json_valid_jsonb() { let input = OwnedValue::build_text(Rc::new("{\"key\":\"value\"}".to_string())); - let result = get_json(&input).unwrap(); + let result = get_json(&input, None).unwrap(); if let OwnedValue::Text(result_str) = result { assert!(result_str.value.contains("\"key\":\"value\"")); assert_eq!(result_str.subtype, TextSubtype::Json); @@ -538,7 +759,7 @@ mod tests { #[test] fn test_get_json_invalid_jsonb() { let input = OwnedValue::build_text(Rc::new("{key:\"value\"".to_string())); - let result = get_json(&input); + let result = get_json(&input, None); match result { Ok(_) => panic!("Expected error for malformed JSON"), Err(e) => assert!(e.to_string().contains("malformed JSON")), @@ -549,7 +770,7 @@ mod tests { fn test_get_json_blob_valid_jsonb() { let binary_json = b"\x40\0\0\x01\x10\0\0\x03\x10\0\0\x03\x61\x73\x64\x61\x64\x66".to_vec(); let input = OwnedValue::Blob(Rc::new(binary_json)); - let result = get_json(&input).unwrap(); + let result = get_json(&input, None).unwrap(); if let OwnedValue::Text(result_str) = result { assert!(result_str.value.contains("\"asd\":\"adf\"")); assert_eq!(result_str.subtype, TextSubtype::Json); @@ -562,7 +783,7 @@ mod 
tests { fn test_get_json_blob_invalid_jsonb() { let binary_json: Vec = vec![0xA2, 0x62, 0x6B, 0x31, 0x62, 0x76]; // Incomplete binary JSON let input = OwnedValue::Blob(Rc::new(binary_json)); - let result = get_json(&input); + let result = get_json(&input, None); match result { Ok(_) => panic!("Expected error for malformed JSON"), Err(e) => assert!(e.to_string().contains("malformed JSON")), @@ -572,7 +793,7 @@ mod tests { #[test] fn test_get_json_non_text() { let input = OwnedValue::Null; - let result = get_json(&input).unwrap(); + let result = get_json(&input, None).unwrap(); if let OwnedValue::Null = result { // Test passed } else { @@ -729,7 +950,7 @@ mod tests { #[test] fn test_json_array_length_simple_json_subtype() { let input = OwnedValue::build_text(Rc::new("[1,2,3]".to_string())); - let wrapped = get_json(&input).unwrap(); + let wrapped = get_json(&input, None).unwrap(); let result = json_array_length(&wrapped, None).unwrap(); if let OwnedValue::Integer(res) = result { @@ -983,18 +1204,484 @@ mod tests { .contains("json_object requires an even number of values")), } } -} -pub fn is_json_valid(json_value: &OwnedValue) -> crate::Result { - match json_value { - OwnedValue::Text(ref t) => match from_str::(&t.value) { - Ok(_) => Ok(OwnedValue::Integer(1)), - Err(_) => Ok(OwnedValue::Integer(0)), - }, - OwnedValue::Blob(b) => match jsonb::from_slice(b) { - Ok(_) => Ok(OwnedValue::Integer(1)), - Err(_) => Ok(OwnedValue::Integer(0)), - }, - OwnedValue::Null => Ok(OwnedValue::Null), - _ => Ok(OwnedValue::Integer(1)), + + #[test] + fn test_find_target_array() { + let mut val = Val::Array(vec![ + Val::String("first".to_string()), + Val::String("second".to_string()), + ]); + let path = JsonPath { + elements: vec![PathElement::ArrayLocator(0)], + }; + + match find_target(&mut val, &path) { + Some(Target::Array(_, idx)) => assert_eq!(idx, 0), + _ => panic!("Expected Array target"), + } + } + + #[test] + fn test_find_target_negative_index() { + let mut val = 
Val::Array(vec![ + Val::String("first".to_string()), + Val::String("second".to_string()), + ]); + let path = JsonPath { + elements: vec![PathElement::ArrayLocator(-1)], + }; + + match find_target(&mut val, &path) { + Some(Target::Array(_, idx)) => assert_eq!(idx, 1), + _ => panic!("Expected Array target"), + } + } + + #[test] + fn test_find_target_object() { + let mut val = Val::Object(vec![("key".to_string(), Val::String("value".to_string()))]); + let path = JsonPath { + elements: vec![PathElement::Key("key".to_string())], + }; + + match find_target(&mut val, &path) { + Some(Target::Value(_)) => {} + _ => panic!("Expected Value target"), + } + } + + #[test] + fn test_find_target_removed() { + let mut val = Val::Object(vec![ + ("key".to_string(), Val::Removed), + ("key".to_string(), Val::String("value".to_string())), + ]); + let path = JsonPath { + elements: vec![PathElement::Key("key".to_string())], + }; + + match find_target(&mut val, &path) { + Some(Target::Value(val)) => assert!(matches!(val, Val::String(_))), + _ => panic!("Expected second value, not removed"), + } + } + + #[test] + fn test_mutate_json() { + let mut val = Val::Array(vec![Val::String("test".to_string())]); + let path = JsonPath { + elements: vec![PathElement::ArrayLocator(0)], + }; + + let result = mutate_json_by_path(&mut val, path, |target| match target { + Target::Array(arr, idx) => { + arr.remove(idx); + "removed" + } + _ => panic!("Expected Array target"), + }); + + assert_eq!(result, Some("removed")); + assert!(matches!(val, Val::Array(arr) if arr.is_empty())); + } + + #[test] + fn test_mutate_json_none() { + let mut val = Val::Array(vec![]); + let path = JsonPath { + elements: vec![PathElement::ArrayLocator(0)], + }; + + let result: Option<()> = mutate_json_by_path(&mut val, path, |_| { + panic!("Should not be called"); + }); + + assert_eq!(result, None); + } + + #[test] + fn test_json_path_from_owned_value_root_strict() { + let path = OwnedValue::Text(LimboText { + value: 
Rc::new("$".to_string()), + subtype: TextSubtype::Text, + }); + + let result = json_path_from_owned_value(&path, true); + assert!(result.is_ok()); + + let result = result.unwrap(); + assert!(result.is_some()); + + let result = result.unwrap(); + match result.elements[..] { + [PathElement::Root()] => {} + _ => panic!("Expected root"), + } + } + + #[test] + fn test_json_path_from_owned_value_root_non_strict() { + let path = OwnedValue::Text(LimboText { + value: Rc::new("$".to_string()), + subtype: TextSubtype::Text, + }); + + let result = json_path_from_owned_value(&path, false); + assert!(result.is_ok()); + + let result = result.unwrap(); + assert!(result.is_some()); + + let result = result.unwrap(); + match result.elements[..] { + [PathElement::Root()] => {} + _ => panic!("Expected root"), + } + } + + #[test] + fn test_json_path_from_owned_value_named_strict() { + let path = OwnedValue::Text(LimboText { + value: Rc::new("field".to_string()), + subtype: TextSubtype::Text, + }); + + assert!(json_path_from_owned_value(&path, true).is_err()); + } + + #[test] + fn test_json_path_from_owned_value_named_non_strict() { + let path = OwnedValue::Text(LimboText { + value: Rc::new("field".to_string()), + subtype: TextSubtype::Text, + }); + + let result = json_path_from_owned_value(&path, false); + assert!(result.is_ok()); + + let result = result.unwrap(); + assert!(result.is_some()); + + let result = result.unwrap(); + match &result.elements[..] 
{ + [PathElement::Root(), PathElement::Key(field)] if *field == "field" => {} + _ => panic!("Expected root and field"), + } + } + + #[test] + fn test_json_path_from_owned_value_integer_strict() { + let path = OwnedValue::Integer(3); + assert!(json_path_from_owned_value(&path, true).is_err()); + } + + #[test] + fn test_json_path_from_owned_value_integer_non_strict() { + let path = OwnedValue::Integer(3); + + let result = json_path_from_owned_value(&path, false); + assert!(result.is_ok()); + + let result = result.unwrap(); + assert!(result.is_some()); + + let result = result.unwrap(); + match &result.elements[..] { + [PathElement::Root(), PathElement::ArrayLocator(index)] if *index == 3 => {} + _ => panic!("Expected root and array locator"), + } + } + + #[test] + fn test_json_path_from_owned_value_null_strict() { + let path = OwnedValue::Null; + + let result = json_path_from_owned_value(&path, true); + assert!(result.is_ok()); + + let result = result.unwrap(); + assert!(result.is_none()); + } + + #[test] + fn test_json_path_from_owned_value_null_non_strict() { + let path = OwnedValue::Null; + + let result = json_path_from_owned_value(&path, false); + assert!(result.is_ok()); + + let result = result.unwrap(); + assert!(result.is_none()); + } + + #[test] + fn test_json_path_from_owned_value_float_strict() { + let path = OwnedValue::Float(1.23); + + assert!(json_path_from_owned_value(&path, true).is_err()); + } + + #[test] + fn test_json_path_from_owned_value_float_non_strict() { + let path = OwnedValue::Float(1.23); + + let result = json_path_from_owned_value(&path, false); + assert!(result.is_ok()); + + let result = result.unwrap(); + assert!(result.is_some()); + + let result = result.unwrap(); + match &result.elements[..] 
{ + [PathElement::Root(), PathElement::Key(field)] if *field == "1.23" => {} + _ => panic!("Expected root and field"), + } + } + + #[test] + fn test_json_set_field_empty_object() { + let result = json_set( + &OwnedValue::build_text(Rc::new("{}".to_string())), + &[ + OwnedValue::build_text(Rc::new("$.field".to_string())), + OwnedValue::build_text(Rc::new("value".to_string())), + ], + ); + + assert!(result.is_ok()); + + assert_eq!( + result.unwrap(), + OwnedValue::build_text(Rc::new(r#"{"field":"value"}"#.to_string())) + ); + } + + #[test] + fn test_json_set_replace_field() { + let result = json_set( + &OwnedValue::build_text(Rc::new(r#"{"field":"old_value"}"#.to_string())), + &[ + OwnedValue::build_text(Rc::new("$.field".to_string())), + OwnedValue::build_text(Rc::new("new_value".to_string())), + ], + ); + + assert!(result.is_ok()); + + assert_eq!( + result.unwrap(), + OwnedValue::build_text(Rc::new(r#"{"field":"new_value"}"#.to_string())) + ); + } + + #[test] + fn test_json_set_set_deeply_nested_key() { + let result = json_set( + &OwnedValue::build_text(Rc::new("{}".to_string())), + &[ + OwnedValue::build_text(Rc::new("$.object.doesnt.exist".to_string())), + OwnedValue::build_text(Rc::new("value".to_string())), + ], + ); + + assert!(result.is_ok()); + + assert_eq!( + result.unwrap(), + OwnedValue::build_text(Rc::new( + r#"{"object":{"doesnt":{"exist":"value"}}}"#.to_string() + )) + ); + } + + #[test] + fn test_json_set_add_value_to_empty_array() { + let result = json_set( + &OwnedValue::build_text(Rc::new("[]".to_string())), + &[ + OwnedValue::build_text(Rc::new("$[0]".to_string())), + OwnedValue::build_text(Rc::new("value".to_string())), + ], + ); + + assert!(result.is_ok()); + + assert_eq!( + result.unwrap(), + OwnedValue::build_text(Rc::new(r#"["value"]"#.to_string())) + ); + } + + #[test] + fn test_json_set_add_value_to_nonexistent_array() { + let result = json_set( + &OwnedValue::build_text(Rc::new("{}".to_string())), + &[ + 
OwnedValue::build_text(Rc::new("$.some_array[0]".to_string())), + OwnedValue::Integer(123), + ], + ); + + assert!(result.is_ok()); + + assert_eq!( + result.unwrap(), + OwnedValue::build_text(Rc::new(r#"{"some_array":[123]}"#.to_string())) + ); + } + + #[test] + fn test_json_set_add_value_to_array() { + let result = json_set( + &OwnedValue::build_text(Rc::new("[123]".to_string())), + &[ + OwnedValue::build_text(Rc::new("$[1]".to_string())), + OwnedValue::Integer(456), + ], + ); + + assert!(result.is_ok()); + + assert_eq!( + result.unwrap(), + OwnedValue::build_text(Rc::new("[123,456]".to_string())) + ); + } + + #[test] + fn test_json_set_add_value_to_array_out_of_bounds() { + let result = json_set( + &OwnedValue::build_text(Rc::new("[123]".to_string())), + &[ + OwnedValue::build_text(Rc::new("$[200]".to_string())), + OwnedValue::Integer(456), + ], + ); + + assert!(result.is_ok()); + + assert_eq!( + result.unwrap(), + OwnedValue::build_text(Rc::new("[123]".to_string())) + ); + } + + #[test] + fn test_json_set_replace_value_in_array() { + let result = json_set( + &OwnedValue::build_text(Rc::new("[123]".to_string())), + &[ + OwnedValue::build_text(Rc::new("$[0]".to_string())), + OwnedValue::Integer(456), + ], + ); + + assert!(result.is_ok()); + + assert_eq!( + result.unwrap(), + OwnedValue::build_text(Rc::new("[456]".to_string())) + ); + } + + #[test] + fn test_json_set_null_path() { + let result = json_set( + &OwnedValue::build_text(Rc::new("{}".to_string())), + &[OwnedValue::Null, OwnedValue::Integer(456)], + ); + + assert!(result.is_ok()); + + assert_eq!( + result.unwrap(), + OwnedValue::build_text(Rc::new("{}".to_string())) + ); + } + + #[test] + fn test_json_set_multiple_keys() { + let result = json_set( + &OwnedValue::build_text(Rc::new("[123]".to_string())), + &[ + OwnedValue::build_text(Rc::new("$[0]".to_string())), + OwnedValue::Integer(456), + OwnedValue::build_text(Rc::new("$[1]".to_string())), + OwnedValue::Integer(789), + ], + ); + + 
assert!(result.is_ok()); + + assert_eq!( + result.unwrap(), + OwnedValue::build_text(Rc::new("[456,789]".to_string())) + ); + } + + #[test] + fn test_json_set_missing_value() { + let result = json_set( + &OwnedValue::build_text(Rc::new("[123]".to_string())), + &[OwnedValue::build_text(Rc::new("$[0]".to_string()))], + ); + + assert!(result.is_err()); + } + + #[test] + fn test_json_set_add_array_in_nested_object() { + let result = json_set( + &OwnedValue::build_text(Rc::new("{}".to_string())), + &[ + OwnedValue::build_text(Rc::new("$.object[0].field".to_string())), + OwnedValue::Integer(123), + ], + ); + + assert!(result.is_ok()); + + assert_eq!( + result.unwrap(), + OwnedValue::build_text(Rc::new(r#"{"object":[{"field":123}]}"#.to_string())) + ); + } + + #[test] + fn test_json_set_add_array_in_array_in_nested_object() { + let result = json_set( + &OwnedValue::build_text(Rc::new("{}".to_string())), + &[ + OwnedValue::build_text(Rc::new("$.object[0][0]".to_string())), + OwnedValue::Integer(123), + ], + ); + + assert!(result.is_ok()); + + assert_eq!( + result.unwrap(), + OwnedValue::build_text(Rc::new(r#"{"object":[[123]]}"#.to_string())) + ); + } + + #[test] + fn test_json_set_add_array_in_array_in_nested_object_out_of_bounds() { + let result = json_set( + &OwnedValue::build_text(Rc::new("{}".to_string())), + &[ + OwnedValue::build_text(Rc::new("$.object[123].another".to_string())), + OwnedValue::build_text(Rc::new("value".to_string())), + OwnedValue::build_text(Rc::new("$.field".to_string())), + OwnedValue::build_text(Rc::new("value".to_string())), + ], + ); + + assert!(result.is_ok()); + + assert_eq!( + result.unwrap(), + OwnedValue::build_text(Rc::new(r#"{"field":"value"}"#.to_string())) + ); } } diff --git a/core/json/ser.rs b/core/json/ser.rs index 3d5646584..680296c55 100644 --- a/core/json/ser.rs +++ b/core/json/ser.rs @@ -1,116 +1,157 @@ use serde::ser::{self, Serialize}; -use std::{f32, f64, num::FpCategory}; +use std::{f32, f64, io, num::FpCategory}; use 
crate::json::error::{Error, Result}; +#[derive(Eq, PartialEq)] +pub enum State { + Empty, + First, + Rest, +} + +struct Map<'a, W: 'a, F: 'a> { + ser: &'a mut Serializer, + state: State, +} + /// Attempts to serialize the input as a JSON5 string (actually a JSON string). pub fn to_string(value: &T) -> Result where T: Serialize, { - let mut serializer = Serializer { - output: String::new(), - }; - value.serialize(&mut serializer)?; - Ok(serializer.output) + let vec = to_vec(value)?; + let string = String::from_utf8(vec).map_err(|err| Error::from(err.utf8_error()))?; + Ok(string) } -struct Serializer { - output: String, - // TODO settings for formatting (single vs double quotes, whitespace etc) +/// Attempts to serialize the input as a JSON5 string (actually a JSON string). +pub fn to_string_pretty(value: &T, indent: &str) -> Result +where + T: Serialize, +{ + let vec = to_vec_pretty(value, indent)?; + let string = String::from_utf8(vec).map_err(|err| Error::from(err.utf8_error()))?; + Ok(string) } -impl Serializer { - fn call_to_string(&mut self, v: &T) -> Result<()> - where - T: ToString, - { - self.output += &v.to_string(); - Ok(()) +struct Serializer { + writer: W, + formatter: F, +} + +impl Serializer +where + W: io::Write, +{ + pub fn new(writer: W) -> Self { + Serializer::with_formatter(writer, CompactFormatter) } } -impl ser::Serializer for &mut Serializer { +impl<'a, W> Serializer> +where + W: io::Write, +{ + /// Creates a new JSON pretty print serializer. + #[inline] + pub fn pretty(writer: W, indent: &'a str) -> Self { + Serializer::with_formatter(writer, PrettyFormatter::with_indent(indent.as_bytes())) + } +} + +impl Serializer +where + W: io::Write, + F: Formatter, +{ + /// Creates a new JSON visitor whose output will be written to the writer + /// specified. 
+ pub fn with_formatter(writer: W, formatter: F) -> Self { + Serializer { writer, formatter } + } +} + +impl<'a, W, F> ser::Serializer for &'a mut Serializer +where + W: io::Write, + F: Formatter, +{ type Ok = (); type Error = Error; - type SerializeSeq = Self; - type SerializeTuple = Self; - type SerializeTupleStruct = Self; - type SerializeTupleVariant = Self; - type SerializeMap = Self; - type SerializeStruct = Self; - type SerializeStructVariant = Self; + type SerializeSeq = Map<'a, W, F>; + type SerializeTuple = Map<'a, W, F>; + type SerializeTupleStruct = Map<'a, W, F>; + type SerializeTupleVariant = Map<'a, W, F>; + type SerializeMap = Map<'a, W, F>; + type SerializeStruct = Map<'a, W, F>; + type SerializeStructVariant = Map<'a, W, F>; fn serialize_bool(self, v: bool) -> Result<()> { - self.call_to_string(&v) + self.formatter + .write_bool(&mut self.writer, v) + .map_err(Error::from) } fn serialize_i8(self, v: i8) -> Result<()> { - self.call_to_string(&v) + self.formatter + .write_i8(&mut self.writer, v) + .map_err(Error::from) } fn serialize_i16(self, v: i16) -> Result<()> { - self.call_to_string(&v) + self.formatter + .write_i16(&mut self.writer, v) + .map_err(Error::from) } fn serialize_i32(self, v: i32) -> Result<()> { - self.call_to_string(&v) + self.formatter + .write_i32(&mut self.writer, v) + .map_err(Error::from) } fn serialize_i64(self, v: i64) -> Result<()> { - self.call_to_string(&v) + self.formatter + .write_i64(&mut self.writer, v) + .map_err(Error::from) } fn serialize_u8(self, v: u8) -> Result<()> { - self.call_to_string(&v) + self.formatter + .write_u8(&mut self.writer, v) + .map_err(Error::from) } fn serialize_u16(self, v: u16) -> Result<()> { - self.call_to_string(&v) + self.formatter + .write_u16(&mut self.writer, v) + .map_err(Error::from) } fn serialize_u32(self, v: u32) -> Result<()> { - self.call_to_string(&v) + self.formatter + .write_u32(&mut self.writer, v) + .map_err(Error::from) } fn serialize_u64(self, v: u64) -> Result<()> { - 
self.call_to_string(&v) + self.formatter + .write_u64(&mut self.writer, v) + .map_err(Error::from) } fn serialize_f32(self, v: f32) -> Result<()> { - match v.classify() { - FpCategory::Nan => self.output += "null", - FpCategory::Infinite => { - let infinity = if v.is_sign_negative() { - "-9e999" - } else { - "9e999" - }; - self.output += infinity - } - _ => self.output += &v.to_string(), - } - Ok(()) + self.formatter + .write_f32(&mut self.writer, v) + .map_err(Error::from) } fn serialize_f64(self, v: f64) -> Result<()> { - match v.classify() { - FpCategory::Nan => self.output += "null", - FpCategory::Infinite => { - let infinity = if v.is_sign_negative() { - "-9e999" - } else { - "9e999" - }; - self.output += infinity - } - _ => { - let str = &format!("{:.1}", v); - self.output += str - } - } - Ok(()) + self.formatter + .write_f64(&mut self.writer, v) + .map_err(Error::from) } fn serialize_char(self, v: char) -> Result<()> { @@ -120,10 +161,7 @@ impl ser::Serializer for &mut Serializer { } fn serialize_str(self, v: &str) -> Result<()> { - self.output += "\""; - self.output += &escape(v); - self.output += "\""; - Ok(()) + format_escaped_str(&mut self.writer, &mut self.formatter, v).map_err(Error::from) } fn serialize_bytes(self, _v: &[u8]) -> Result<()> { @@ -142,8 +180,9 @@ impl ser::Serializer for &mut Serializer { } fn serialize_unit(self) -> Result<()> { - self.output += "null"; - Ok(()) + self.formatter + .write_null(&mut self.writer) + .map_err(Error::from) } fn serialize_unit_struct(self, _name: &'static str) -> Result<()> { @@ -176,17 +215,47 @@ impl ser::Serializer for &mut Serializer { where T: ?Sized + Serialize, { - self.output += "{"; - variant.serialize(&mut *self)?; // TODO drop the quotes where possible - self.output += ":"; + self.formatter + .begin_object(&mut self.writer) + .map_err(Error::from)?; + self.formatter + .begin_object_key(&mut self.writer, true) + .map_err(Error::from)?; + self.serialize_str(variant)?; + self.formatter + 
.end_object_key(&mut self.writer) + .map_err(Error::from)?; + self.formatter + .begin_object_value(&mut self.writer) + .map_err(Error::from)?; value.serialize(&mut *self)?; - self.output += "}"; - Ok(()) + self.formatter + .end_object_value(&mut self.writer) + .map_err(Error::from)?; + self.formatter + .end_object(&mut self.writer) + .map_err(Error::from) } - fn serialize_seq(self, _len: Option) -> Result { - self.output += "["; - Ok(self) + fn serialize_seq(self, len: Option) -> Result { + self.formatter + .begin_array(&mut self.writer) + .map_err(Error::from)?; + + if len == Some(0) { + self.formatter + .end_array(&mut self.writer) + .map_err(Error::from)?; + Ok(Map { + ser: self, + state: State::Empty, + }) + } else { + Ok(Map { + ser: self, + state: State::First, + }) + } } fn serialize_tuple(self, len: usize) -> Result { @@ -206,17 +275,34 @@ impl ser::Serializer for &mut Serializer { _name: &'static str, _variant_index: u32, variant: &'static str, - _len: usize, + len: usize, ) -> Result { - self.output += "{"; - variant.serialize(&mut *self)?; - self.output += ":["; - Ok(self) + self.formatter.begin_object(&mut self.writer)?; + self.formatter.begin_object_key(&mut self.writer, true)?; + self.serialize_str(variant)?; + self.formatter.end_object_key(&mut self.writer)?; + self.formatter.begin_object_value(&mut self.writer)?; + self.serialize_seq(Some(len)) } - fn serialize_map(self, _len: Option) -> Result { - self.output += "{"; - Ok(self) + fn serialize_map(self, len: Option) -> Result { + self.formatter + .begin_object(&mut self.writer) + .map_err(Error::from)?; + if len == Some(0) { + self.formatter + .end_object(&mut self.writer) + .map_err(Error::from)?; + Ok(Map { + ser: self, + state: State::Empty, + }) + } else { + Ok(Map { + ser: self, + state: State::First, + }) + } } fn serialize_struct(self, _name: &'static str, len: usize) -> Result { @@ -228,16 +314,30 @@ impl ser::Serializer for &mut Serializer { _name: &'static str, _variant_index: u32, 
variant: &'static str, - _len: usize, + len: usize, ) -> Result { - self.output += "{"; - variant.serialize(&mut *self)?; - self.output += ":{"; - Ok(self) + self.formatter + .begin_object(&mut self.writer) + .map_err(Error::from)?; + self.formatter + .begin_object_key(&mut self.writer, true) + .map_err(Error::from)?; + self.serialize_str(variant).map_err(Error::from)?; + self.formatter + .end_object_key(&mut self.writer) + .map_err(Error::from)?; + self.formatter + .begin_object_value(&mut self.writer) + .map_err(Error::from)?; + self.serialize_map(Some(len)) } } -impl ser::SerializeSeq for &mut Serializer { +impl ser::SerializeSeq for Map<'_, W, F> +where + W: io::Write, + F: Formatter, +{ type Ok = (); type Error = Error; @@ -245,19 +345,36 @@ impl ser::SerializeSeq for &mut Serializer { where T: ?Sized + Serialize, { - if !self.output.ends_with('[') { - self.output += ","; + self.ser + .formatter + .begin_array_value(&mut self.ser.writer, self.state == State::First) + .map_err(Error::from)?; + + self.state = State::Rest; + value.serialize(&mut *self.ser).map_err(Error::from)?; + self.ser + .formatter + .end_array_value(&mut self.ser.writer) + .map_err(Error::from) + } + + fn end(self) -> Result<()> { + match self.state { + State::Empty => Ok(()), + _ => self + .ser + .formatter + .end_array(&mut self.ser.writer) + .map_err(Error::from), } - value.serialize(&mut **self) - } - - fn end(self) -> Result<()> { - self.output += "]"; - Ok(()) } } -impl ser::SerializeTuple for &mut Serializer { +impl ser::SerializeTuple for Map<'_, W, F> +where + W: io::Write, + F: Formatter, +{ type Ok = (); type Error = Error; @@ -273,7 +390,11 @@ impl ser::SerializeTuple for &mut Serializer { } } -impl ser::SerializeTupleStruct for &mut Serializer { +impl ser::SerializeTupleStruct for Map<'_, W, F> +where + W: io::Write, + F: Formatter, +{ type Ok = (); type Error = Error; @@ -289,7 +410,11 @@ impl ser::SerializeTupleStruct for &mut Serializer { } } -impl ser::SerializeTupleVariant 
for &mut Serializer { +impl ser::SerializeTupleVariant for Map<'_, W, F> +where + W: io::Write, + F: Formatter, +{ type Ok = (); type Error = Error; @@ -301,12 +426,30 @@ impl ser::SerializeTupleVariant for &mut Serializer { } fn end(self) -> Result<()> { - self.output += "]}"; - Ok(()) + match self.state { + State::Empty => {} + _ => self + .ser + .formatter + .end_array(&mut self.ser.writer) + .map_err(Error::from)?, + }; + self.ser + .formatter + .end_object_value(&mut self.ser.writer) + .map_err(Error::from)?; + self.ser + .formatter + .end_object(&mut self.ser.writer) + .map_err(Error::from) } } -impl ser::SerializeMap for &mut Serializer { +impl ser::SerializeMap for Map<'_, W, F> +where + W: io::Write, + F: Formatter, +{ type Ok = (); type Error = Error; @@ -314,27 +457,54 @@ impl ser::SerializeMap for &mut Serializer { where T: ?Sized + Serialize, { - if !self.output.ends_with('{') { - self.output += ","; - } - key.serialize(&mut **self) + self.ser + .formatter + .begin_object_key(&mut self.ser.writer, self.state == State::First) + .map_err(Error::from)?; + self.state = State::Rest; + + key.serialize(&mut *self.ser)?; + + self.ser + .formatter + .end_object_key(&mut self.ser.writer) + .map_err(Error::from) } fn serialize_value(&mut self, value: &T) -> Result<()> where T: ?Sized + Serialize, { - self.output += ":"; - value.serialize(&mut **self) + self.ser + .formatter + .begin_object_value(&mut self.ser.writer) + .map_err(Error::from)?; + + value.serialize(&mut *self.ser)?; + + self.ser + .formatter + .end_object_value(&mut self.ser.writer) + .map_err(Error::from) } fn end(self) -> Result<()> { - self.output += "}"; - Ok(()) + match self.state { + State::Empty => Ok(()), + _ => self + .ser + .formatter + .end_object(&mut self.ser.writer) + .map_err(Error::from), + } } } -impl ser::SerializeStruct for &mut Serializer { +impl ser::SerializeStruct for Map<'_, W, F> +where + W: io::Write, + F: Formatter, +{ type Ok = (); type Error = Error; @@ -351,7 +521,11 @@ 
impl ser::SerializeStruct for &mut Serializer { } } -impl ser::SerializeStructVariant for &mut Serializer { +impl ser::SerializeStructVariant for Map<'_, W, F> +where + W: io::Write, + F: Formatter, +{ type Ok = (); type Error = Error; @@ -363,22 +537,667 @@ impl ser::SerializeStructVariant for &mut Serializer { } fn end(self) -> Result<()> { - self.output += "}}"; + match self.state { + State::Empty => {} + _ => self + .ser + .formatter + .end_object(&mut self.ser.writer) + .map_err(Error::from)?, + }; + self.ser + .formatter + .end_object_value(&mut self.ser.writer) + .map_err(Error::from)?; + self.ser + .formatter + .end_object(&mut self.ser.writer) + .map_err(Error::from) + } +} + +pub fn to_writer(writer: W, value: &T) -> Result<()> +where + W: io::Write, + T: ?Sized + Serialize, +{ + let mut ser = Serializer::new(writer); + value.serialize(&mut ser) +} + +pub fn to_vec(value: &T) -> Result> +where + T: ?Sized + Serialize, +{ + let mut writer = Vec::with_capacity(128); + to_writer(&mut writer, value)?; + Ok(writer) +} + +pub fn to_writer_pretty(writer: W, value: &T, indent: &str) -> Result<()> +where + W: io::Write, + T: ?Sized + Serialize, +{ + let mut ser = Serializer::pretty(writer, indent); + value.serialize(&mut ser) +} + +pub fn to_vec_pretty(value: &T, indent: &str) -> Result> +where + T: ?Sized + Serialize, +{ + let mut writer = Vec::with_capacity(128); + to_writer_pretty(&mut writer, value, indent)?; + Ok(writer) +} + +/// Represents a character escape code in a type-safe manner. 
+pub enum CharEscape { + /// An escaped quote `"` + Quote, + /// An escaped reverse solidus `\` + ReverseSolidus, + /// An escaped solidus `/` + Solidus, + /// An escaped backspace character (usually escaped as `\b`) + Backspace, + /// An escaped form feed character (usually escaped as `\f`) + FormFeed, + /// An escaped line feed character (usually escaped as `\n`) + LineFeed, + /// An escaped carriage return character (usually escaped as `\r`) + CarriageReturn, + /// An escaped tab character (usually escaped as `\t`) + Tab, + /// An escaped ASCII plane control character (usually escaped as + /// `\u00XX` where `XX` are two hex characters) + AsciiControl(u8), +} + +impl CharEscape { + fn from_escape_table(escape: u8, byte: u8) -> CharEscape { + match escape { + self::BB => CharEscape::Backspace, + self::TT => CharEscape::Tab, + self::NN => CharEscape::LineFeed, + self::FF => CharEscape::FormFeed, + self::RR => CharEscape::CarriageReturn, + self::QU => CharEscape::Quote, + self::BS => CharEscape::ReverseSolidus, + self::UU => CharEscape::AsciiControl(byte), + _ => unreachable!(), + } + } +} + +const BB: u8 = b'b'; // \x08 +const TT: u8 = b't'; // \x09 +const NN: u8 = b'n'; // \x0A +const FF: u8 = b'f'; // \x0C +const RR: u8 = b'r'; // \x0D +const QU: u8 = b'"'; // \x22 +const BS: u8 = b'\\'; // \x5C +const UU: u8 = b'u'; // \x00...\x1F except the ones above +const __: u8 = 0; + +pub trait Formatter { + /// Writes a `null` value to the specified writer. + fn write_null(&mut self, writer: &mut W) -> io::Result<()> + where + W: ?Sized + io::Write, + { + writer.write_all(b"null") + } + + /// Writes a `true` or `false` value to the specified writer. + fn write_bool(&mut self, writer: &mut W, value: bool) -> io::Result<()> + where + W: ?Sized + io::Write, + { + let s = if value { + b"true" as &[u8] + } else { + b"false" as &[u8] + }; + writer.write_all(s) + } + + /// Writes an integer value like `-123` to the specified writer. 
+ fn write_i8(&mut self, writer: &mut W, value: i8) -> io::Result<()> + where + W: ?Sized + io::Write, + { + writer.write_all(value.to_string().as_bytes()) + } + + /// Writes an integer value like `-123` to the specified writer. + fn write_i16(&mut self, writer: &mut W, value: i16) -> io::Result<()> + where + W: ?Sized + io::Write, + { + writer.write_all(value.to_string().as_bytes()) + } + + /// Writes an integer value like `-123` to the specified writer. + fn write_i32(&mut self, writer: &mut W, value: i32) -> io::Result<()> + where + W: ?Sized + io::Write, + { + writer.write_all(value.to_string().as_bytes()) + } + + /// Writes an integer value like `-123` to the specified writer. + fn write_i64(&mut self, writer: &mut W, value: i64) -> io::Result<()> + where + W: ?Sized + io::Write, + { + writer.write_all(value.to_string().as_bytes()) + } + + /// Writes an integer value like `-123` to the specified writer. + fn write_i128(&mut self, writer: &mut W, value: i128) -> io::Result<()> + where + W: ?Sized + io::Write, + { + writer.write_all(value.to_string().as_bytes()) + } + + /// Writes an integer value like `123` to the specified writer. + fn write_u8(&mut self, writer: &mut W, value: u8) -> io::Result<()> + where + W: ?Sized + io::Write, + { + writer.write_all(value.to_string().as_bytes()) + } + + /// Writes an integer value like `123` to the specified writer. + fn write_u16(&mut self, writer: &mut W, value: u16) -> io::Result<()> + where + W: ?Sized + io::Write, + { + writer.write_all(value.to_string().as_bytes()) + } + + /// Writes an integer value like `123` to the specified writer. + fn write_u32(&mut self, writer: &mut W, value: u32) -> io::Result<()> + where + W: ?Sized + io::Write, + { + writer.write_all(value.to_string().as_bytes()) + } + + /// Writes an integer value like `123` to the specified writer. 
+ fn write_u64(&mut self, writer: &mut W, value: u64) -> io::Result<()> + where + W: ?Sized + io::Write, + { + writer.write_all(value.to_string().as_bytes()) + } + + /// Writes an integer value like `123` to the specified writer. + fn write_u128(&mut self, writer: &mut W, value: u128) -> io::Result<()> + where + W: ?Sized + io::Write, + { + writer.write_all(value.to_string().as_bytes()) + } + + /// Writes a floating point value like `-31.26e+12` to the specified writer. + fn write_f32(&mut self, writer: &mut W, value: f32) -> io::Result<()> + where + W: ?Sized + io::Write, + { + match value.classify() { + FpCategory::Nan => { + self.write_null(writer)?; + } + FpCategory::Infinite => { + let infinity = if value.is_sign_negative() { + "-9e999" + } else { + "9e999" + }; + writer.write_all(infinity.as_bytes())?; + } + _ => { + writer.write_all(value.to_string().as_bytes())?; + } + } + Ok(()) + } + + /// Writes a floating point value like `-31.26e+12` to the specified writer. + fn write_f64(&mut self, writer: &mut W, value: f64) -> io::Result<()> + where + W: ?Sized + io::Write, + { + match value.classify() { + FpCategory::Nan => { + self.write_null(writer)?; + } + FpCategory::Infinite => { + let infinity = if value.is_sign_negative() { + "-9e999" + } else { + "9e999" + }; + writer.write_all(infinity.as_bytes())?; + } + _ => { + // let mut buffer = ryu::Buffer::new(); + // let s = buffer.format_finite(value); + + // This the previous implementation present in the package + // However, serde_json does it differently above. + // Not sure if there if its done like this because of the precision + let s = &format!("{:.1}", value); + writer.write_all(s.as_bytes())?; + } + } + Ok(()) + } + + /// Writes a number that has already been rendered to a string. 
+ fn write_number_str(&mut self, writer: &mut W, value: &str) -> io::Result<()> + where + W: ?Sized + io::Write, + { + writer.write_all(value.as_bytes()) + } + + /// Called before each series of `write_string_fragment` and + /// `write_char_escape`. Writes a `"` to the specified writer. + fn begin_string(&mut self, writer: &mut W) -> io::Result<()> + where + W: ?Sized + io::Write, + { + writer.write_all(b"\"") + } + + /// Called after each series of `write_string_fragment` and + /// `write_char_escape`. Writes a `"` to the specified writer. + fn end_string(&mut self, writer: &mut W) -> io::Result<()> + where + W: ?Sized + io::Write, + { + writer.write_all(b"\"") + } + + /// Writes a string fragment that doesn't need any escaping to the + /// specified writer. + fn write_string_fragment(&mut self, writer: &mut W, fragment: &str) -> io::Result<()> + where + W: ?Sized + io::Write, + { + writer.write_all(fragment.as_bytes()) + } + + /// Writes a character escape code to the specified writer. + fn write_char_escape(&mut self, writer: &mut W, char_escape: CharEscape) -> io::Result<()> + where + W: ?Sized + io::Write, + { + use self::CharEscape::*; + + let s = match char_escape { + Quote => b"\\\"", + ReverseSolidus => b"\\\\", + Solidus => b"\\/", + Backspace => b"\\b", + FormFeed => b"\\f", + LineFeed => b"\\n", + CarriageReturn => b"\\r", + Tab => b"\\t", + AsciiControl(byte) => { + static HEX_DIGITS: [u8; 16] = *b"0123456789abcdef"; + let bytes = &[ + b'\\', + b'u', + b'0', + b'0', + HEX_DIGITS[(byte >> 4) as usize], + HEX_DIGITS[(byte & 0xF) as usize], + ]; + return writer.write_all(bytes); + } + }; + + writer.write_all(s) + } + + /// Writes the representation of a byte array. Formatters can choose whether + /// to represent bytes as a JSON array of integers (the default), or some + /// JSON string encoding like hex or base64. 
+ fn write_byte_array(&mut self, writer: &mut W, value: &[u8]) -> io::Result<()> + where + W: ?Sized + io::Write, + { + self.begin_array(writer)?; + let mut first = true; + for byte in value { + self.begin_array_value(writer, first)?; + self.write_u8(writer, *byte)?; + self.end_array_value(writer)?; + first = false; + } + self.end_array(writer) + } + + /// Called before every array. Writes a `[` to the specified + /// writer. + fn begin_array(&mut self, writer: &mut W) -> io::Result<()> + where + W: ?Sized + io::Write, + { + writer.write_all(b"[") + } + + /// Called after every array. Writes a `]` to the specified + /// writer. + fn end_array(&mut self, writer: &mut W) -> io::Result<()> + where + W: ?Sized + io::Write, + { + writer.write_all(b"]") + } + + /// Called before every array value. Writes a `,` if needed to + /// the specified writer. + fn begin_array_value(&mut self, writer: &mut W, first: bool) -> io::Result<()> + where + W: ?Sized + io::Write, + { + if first { + Ok(()) + } else { + writer.write_all(b",") + } + } + + /// Called after every array value. + fn end_array_value(&mut self, _writer: &mut W) -> io::Result<()> + where + W: ?Sized + io::Write, + { + Ok(()) + } + + /// Called before every object. Writes a `{` to the specified + /// writer. + fn begin_object(&mut self, writer: &mut W) -> io::Result<()> + where + W: ?Sized + io::Write, + { + writer.write_all(b"{") + } + + /// Called after every object. Writes a `}` to the specified + /// writer. + fn end_object(&mut self, writer: &mut W) -> io::Result<()> + where + W: ?Sized + io::Write, + { + writer.write_all(b"}") + } + + /// Called before every object key. + fn begin_object_key(&mut self, writer: &mut W, first: bool) -> io::Result<()> + where + W: ?Sized + io::Write, + { + if first { + Ok(()) + } else { + writer.write_all(b",") + } + } + + /// Called after every object key. A `:` should be written to the + /// specified writer by either this method or + /// `begin_object_value`. 
+ fn end_object_key(&mut self, _writer: &mut W) -> io::Result<()> + where + W: ?Sized + io::Write, + { + Ok(()) + } + + /// Called before every object value. A `:` should be written to + /// the specified writer by either this method or + /// `end_object_key`. + fn begin_object_value(&mut self, writer: &mut W) -> io::Result<()> + where + W: ?Sized + io::Write, + { + writer.write_all(b":") + } + + /// Called after every object value. + fn end_object_value(&mut self, _writer: &mut W) -> io::Result<()> + where + W: ?Sized + io::Write, + { + Ok(()) + } + + /// Writes a raw JSON fragment that doesn't need any escaping to the + /// specified writer. + fn write_raw_fragment(&mut self, writer: &mut W, fragment: &str) -> io::Result<()> + where + W: ?Sized + io::Write, + { + writer.write_all(fragment.as_bytes()) + } +} + +fn format_escaped_str(writer: &mut W, formatter: &mut F, value: &str) -> io::Result<()> +where + W: ?Sized + io::Write, + F: ?Sized + Formatter, +{ + formatter.begin_string(writer)?; + format_escaped_str_contents(writer, formatter, value)?; + formatter.end_string(writer) +} + +fn format_escaped_str_contents( + writer: &mut W, + formatter: &mut F, + value: &str, +) -> io::Result<()> +where + W: ?Sized + io::Write, + F: ?Sized + Formatter, +{ + let bytes = value.as_bytes(); + + let mut start = 0; + + for (i, &byte) in bytes.iter().enumerate() { + let escape = self::ESCAPE[byte as usize]; + if escape == 0 { + continue; + } + + if start < i { + formatter.write_string_fragment(writer, &value[start..i])?; + } + + let char_escape = CharEscape::from_escape_table(escape, byte); + formatter.write_char_escape(writer, char_escape)?; + + start = i + 1; + } + + if start == bytes.len() { + return Ok(()); + } + + formatter.write_string_fragment(writer, &value[start..]) +} + +// Lookup table of escape sequences. A value of b'x' at index i means that byte +// i is escaped as "\x" in JSON. A value of 0 means that byte i is not escaped. 
+static ESCAPE: [u8; 256] = [ + // 1 2 3 4 5 6 7 8 9 A B C D E F + UU, UU, UU, UU, UU, UU, UU, UU, BB, TT, NN, UU, FF, RR, UU, UU, // 0 + UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, // 1 + __, __, QU, __, __, __, __, __, __, __, __, __, __, __, __, __, // 2 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 3 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 4 + __, __, __, __, __, __, __, __, __, __, __, __, BS, __, __, __, // 5 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 6 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 7 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 8 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 9 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // A + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // B + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // C + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // D + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // E + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // F +]; + +/// This structure compacts a JSON value with no extra whitespace. +#[derive(Clone, Debug)] +pub struct CompactFormatter; + +impl Formatter for CompactFormatter {} + +/// This structure pretty prints a JSON value to make it human readable. +#[derive(Clone, Debug)] +pub struct PrettyFormatter<'a> { + current_indent: usize, + has_value: bool, + indent: &'a [u8], +} + +impl<'a> PrettyFormatter<'a> { + /// Construct a pretty printer formatter that defaults to using two spaces for indentation. + pub fn new() -> Self { + PrettyFormatter::with_indent(b" ") + } + + /// Construct a pretty printer formatter that uses the `indent` string for indentation. 
+ pub fn with_indent(indent: &'a [u8]) -> Self { + PrettyFormatter { + current_indent: 0, + has_value: false, + indent, + } + } +} + +impl Default for PrettyFormatter<'_> { + fn default() -> Self { + PrettyFormatter::new() + } +} + +impl Formatter for PrettyFormatter<'_> { + #[inline] + fn begin_array(&mut self, writer: &mut W) -> io::Result<()> + where + W: ?Sized + io::Write, + { + self.current_indent += 1; + self.has_value = false; + writer.write_all(b"[") + } + + #[inline] + fn end_array(&mut self, writer: &mut W) -> io::Result<()> + where + W: ?Sized + io::Write, + { + self.current_indent -= 1; + + if self.has_value { + writer.write_all(b"\n")?; + indent(writer, self.current_indent, self.indent)?; + } + + writer.write_all(b"]") + } + + #[inline] + fn begin_array_value(&mut self, writer: &mut W, first: bool) -> io::Result<()> + where + W: ?Sized + io::Write, + { + writer.write_all(if first { b"\n" } else { b",\n" })?; + indent(writer, self.current_indent, self.indent) + } + + #[inline] + fn end_array_value(&mut self, _writer: &mut W) -> io::Result<()> + where + W: ?Sized + io::Write, + { + self.has_value = true; + Ok(()) + } + + #[inline] + fn begin_object(&mut self, writer: &mut W) -> io::Result<()> + where + W: ?Sized + io::Write, + { + self.current_indent += 1; + self.has_value = false; + writer.write_all(b"{") + } + + #[inline] + fn end_object(&mut self, writer: &mut W) -> io::Result<()> + where + W: ?Sized + io::Write, + { + self.current_indent -= 1; + + if self.has_value { + writer.write_all(b"\n")?; + indent(writer, self.current_indent, self.indent)?; + } + + writer.write_all(b"}") + } + + #[inline] + fn begin_object_key(&mut self, writer: &mut W, first: bool) -> io::Result<()> + where + W: ?Sized + io::Write, + { + writer.write_all(if first { b"\n" } else { b",\n" })?; + indent(writer, self.current_indent, self.indent) + } + + #[inline] + fn begin_object_value(&mut self, writer: &mut W) -> io::Result<()> + where + W: ?Sized + io::Write, + { + 
writer.write_all(b": ") + } + + #[inline] + fn end_object_value(&mut self, _writer: &mut W) -> io::Result<()> + where + W: ?Sized + io::Write, + { + self.has_value = true; Ok(()) } } -fn escape(v: &str) -> String { - v.chars() - .flat_map(|c| match c { - '"' => vec!['\\', c], - '\n' => vec!['\\', 'n'], - '\r' => vec!['\\', 'r'], - '\t' => vec!['\\', 't'], - '\\' => vec!['\\', '\\'], - '\u{0008}' => vec!['\\', 'b'], - '\u{000c}' => vec!['\\', 'f'], - c => vec![c], - }) - .collect() +fn indent(wr: &mut W, n: usize, s: &[u8]) -> io::Result<()> +where + W: ?Sized + io::Write, +{ + for _ in 0..n { + wr.write_all(s)?; + } + + Ok(()) } diff --git a/core/lib.rs b/core/lib.rs index e6c812110..c2e0f22c6 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -1,9 +1,11 @@ mod error; mod ext; mod function; +mod info; mod io; #[cfg(feature = "json")] mod json; +pub mod mvcc; mod parameters; mod pseudo; mod result; @@ -21,16 +23,16 @@ static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; use fallible_iterator::FallibleIterator; #[cfg(not(target_family = "wasm"))] use libloading::{Library, Symbol}; -#[cfg(not(target_family = "wasm"))] use limbo_ext::{ExtensionApi, ExtensionEntryPoint}; use log::trace; +use parking_lot::RwLock; use schema::Schema; use sqlite3_parser::ast; use sqlite3_parser::{ast::Cmd, lexer::sql::Parser}; use std::cell::Cell; use std::collections::HashMap; use std::num::NonZero; -use std::sync::{Arc, OnceLock, RwLock}; +use std::sync::{Arc, OnceLock}; use std::{cell::RefCell, rc::Rc}; use storage::btree::btree_init_page; #[cfg(feature = "fs")] @@ -42,11 +44,13 @@ pub use storage::wal::WalFile; pub use storage::wal::WalFileShared; pub use types::Value; use util::parse_schema_rows; +use vdbe::builder::QueryMode; pub use error::LimboError; use translate::select::prepare_select_plan; pub type Result = std::result::Result; +use crate::storage::wal::CheckpointResult; use crate::translate::optimizer::optimize_plan; pub use io::OpenFlags; pub use io::PlatformIO; @@ -61,9 +65,10 
@@ pub use storage::pager::Page; pub use storage::pager::Pager; pub use storage::wal::CheckpointStatus; pub use storage::wal::Wal; + pub static DATABASE_VERSION: OnceLock = OnceLock::new(); -#[derive(Clone)] +#[derive(Clone, PartialEq, Eq)] enum TransactionState { Write, Read, @@ -138,6 +143,9 @@ impl Database { _shared_wal: shared_wal.clone(), syms, }; + if let Err(e) = db.register_builtins() { + return Err(LimboError::ExtensionError(e)); + } let db = Arc::new(db); let conn = Rc::new(Connection { db: db.clone(), @@ -253,10 +261,10 @@ pub struct Connection { } impl Connection { - pub fn prepare(self: &Rc, sql: impl Into) -> Result { - let sql = sql.into(); + pub fn prepare(self: &Rc, sql: impl AsRef) -> Result { + let sql = sql.as_ref(); trace!("Preparing: {}", sql); - let db = self.db.clone(); + let db = &self.db; let syms: &SymbolTable = &db.syms.borrow(); let mut parser = Parser::new(sql.as_bytes()); let cmd = parser.next()?; @@ -270,6 +278,7 @@ impl Connection { self.pager.clone(), Rc::downgrade(self), syms, + QueryMode::Normal, )?); Ok(Statement::new(program, self.pager.clone())) } @@ -281,8 +290,8 @@ impl Connection { } } - pub fn query(self: &Rc, sql: impl Into) -> Result> { - let sql = sql.into(); + pub fn query(self: &Rc, sql: impl AsRef) -> Result> { + let sql = sql.as_ref(); trace!("Querying: {}", sql); let mut parser = Parser::new(sql.as_bytes()); let cmd = parser.next()?; @@ -304,6 +313,7 @@ impl Connection { self.pager.clone(), Rc::downgrade(self), syms, + QueryMode::Normal, )?); let stmt = Statement::new(program, self.pager.clone()); Ok(Some(stmt)) @@ -316,6 +326,7 @@ impl Connection { self.pager.clone(), Rc::downgrade(self), syms, + QueryMode::Explain, )?; program.explain(); Ok(None) @@ -328,7 +339,7 @@ impl Connection { *select, &self.db.syms.borrow(), )?; - optimize_plan(&mut plan)?; + optimize_plan(&mut plan, &self.schema.borrow())?; println!("{}", plan); } _ => todo!(), @@ -342,9 +353,9 @@ impl Connection { QueryRunner::new(self, sql) } - pub fn 
execute(self: &Rc, sql: impl Into) -> Result<()> { - let sql = sql.into(); - let db = self.db.clone(); + pub fn execute(self: &Rc, sql: impl AsRef) -> Result<()> { + let sql = sql.as_ref(); + let db = &self.db; let syms: &SymbolTable = &db.syms.borrow(); let mut parser = Parser::new(sql.as_bytes()); let cmd = parser.next()?; @@ -358,6 +369,7 @@ impl Connection { self.pager.clone(), Rc::downgrade(self), syms, + QueryMode::Explain, )?; program.explain(); } @@ -370,6 +382,7 @@ impl Connection { self.pager.clone(), Rc::downgrade(self), syms, + QueryMode::Normal, )?; let mut state = @@ -390,9 +403,9 @@ impl Connection { Ok(()) } - pub fn checkpoint(&self) -> Result<()> { - self.pager.clear_page_cache(); - Ok(()) + pub fn checkpoint(&self) -> Result { + let checkpoint_result = self.pager.clear_page_cache(); + Ok(checkpoint_result) } #[cfg(not(target_family = "wasm"))] @@ -405,7 +418,7 @@ impl Connection { loop { // TODO: make this async? match self.pager.checkpoint()? { - CheckpointStatus::Done => { + CheckpointStatus::Done(_) => { return Ok(()); } CheckpointStatus::IO => { @@ -455,19 +468,7 @@ impl Statement { } pub fn step(&mut self) -> Result> { - let result = self.program.step(&mut self.state, self.pager.clone())?; - match result { - vdbe::StepResult::Row(row) => Ok(StepResult::Row(Row { values: row.values })), - vdbe::StepResult::IO => Ok(StepResult::IO), - vdbe::StepResult::Done => Ok(StepResult::Done), - vdbe::StepResult::Interrupt => Ok(StepResult::Interrupt), - vdbe::StepResult::Busy => Ok(StepResult::Busy), - } - } - - pub fn query(&mut self) -> Result { - let stmt = Statement::new(self.program.clone(), self.pager.clone()); - Ok(stmt) + self.program.step(&mut self.state, self.pager.clone()) } pub fn columns(&self) -> &[String] { @@ -491,19 +492,9 @@ impl Statement { } } -#[derive(PartialEq)] -pub enum StepResult<'a> { - Row(Row<'a>), - IO, - Done, - Interrupt, - Busy, -} +pub type StepResult<'a> = vdbe::StepResult<'a>; -#[derive(PartialEq)] -pub struct Row<'a> 
{ - pub values: Vec>, -} +pub type Row<'a> = types::Record<'a>; impl<'a> Row<'a> { pub fn get + 'a>(&self, idx: usize) -> Result { @@ -557,7 +548,6 @@ impl SymbolTable { pub fn new() -> Self { Self { functions: HashMap::new(), - // TODO: wasm libs will be very different #[cfg(not(target_family = "wasm"))] extensions: Vec::new(), } diff --git a/core/mvcc/clock.rs b/core/mvcc/clock.rs new file mode 100644 index 000000000..7bab1fe5d --- /dev/null +++ b/core/mvcc/clock.rs @@ -0,0 +1,31 @@ +use std::sync::atomic::{AtomicU64, Ordering}; + +/// Logical clock. +pub trait LogicalClock { + fn get_timestamp(&self) -> u64; + fn reset(&self, ts: u64); +} + +/// A node-local clock backed by an atomic counter. +#[derive(Debug, Default)] +pub struct LocalClock { + ts_sequence: AtomicU64, +} + +impl LocalClock { + pub fn new() -> Self { + Self { + ts_sequence: AtomicU64::new(0), + } + } +} + +impl LogicalClock for LocalClock { + fn get_timestamp(&self) -> u64 { + self.ts_sequence.fetch_add(1, Ordering::SeqCst) + } + + fn reset(&self, ts: u64) { + self.ts_sequence.store(ts, Ordering::SeqCst); + } +} diff --git a/core/mvcc/cursor.rs b/core/mvcc/cursor.rs new file mode 100644 index 000000000..4d120e214 --- /dev/null +++ b/core/mvcc/cursor.rs @@ -0,0 +1,67 @@ +use serde::de::DeserializeOwned; +use serde::Serialize; + +use crate::mvcc::clock::LogicalClock; +use crate::mvcc::database::{Database, Result, Row, RowID}; +use std::fmt::Debug; + +#[derive(Debug)] +pub struct ScanCursor< + 'a, + Clock: LogicalClock, + T: Sync + Send + Clone + Serialize + DeserializeOwned + Debug, +> { + pub db: &'a Database, + pub row_ids: Vec, + pub index: usize, + tx_id: u64, +} + +impl< + 'a, + Clock: LogicalClock, + T: Sync + Send + Clone + Serialize + DeserializeOwned + Debug + 'static, + > ScanCursor<'a, Clock, T> +{ + pub fn new( + db: &'a Database, + tx_id: u64, + table_id: u64, + ) -> Result> { + let row_ids = db.scan_row_ids_for_table(table_id)?; + Ok(Self { + db, + tx_id, + row_ids, + index: 0, + }) 
+ } + + pub fn current_row_id(&self) -> Option { + if self.index >= self.row_ids.len() { + return None; + } + Some(self.row_ids[self.index]) + } + + pub fn current_row(&self) -> Result>> { + if self.index >= self.row_ids.len() { + return Ok(None); + } + let id = self.row_ids[self.index]; + self.db.read(self.tx_id, id) + } + + pub fn close(self) -> Result<()> { + Ok(()) + } + + pub fn forward(&mut self) -> bool { + self.index += 1; + self.index < self.row_ids.len() + } + + pub fn is_empty(&self) -> bool { + self.index >= self.row_ids.len() + } +} diff --git a/core/mvcc/database/mod.rs b/core/mvcc/database/mod.rs new file mode 100644 index 000000000..6fa8420f3 --- /dev/null +++ b/core/mvcc/database/mod.rs @@ -0,0 +1,810 @@ +use crate::mvcc::clock::LogicalClock; +use crate::mvcc::errors::DatabaseError; +use crate::mvcc::persistent_storage::Storage; +use crossbeam_skiplist::{SkipMap, SkipSet}; +use std::fmt::Debug; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::RwLock; + +pub type Result = std::result::Result; + +#[cfg(test)] +mod tests; + +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct RowID { + pub table_id: u64, + pub row_id: u64, +} + +#[derive(Clone, Debug, PartialEq, PartialOrd)] + +pub struct Row { + pub id: RowID, + pub data: T, +} + +/// A row version. +#[derive(Clone, Debug, PartialEq)] +pub struct RowVersion { + begin: TxTimestampOrID, + end: Option, + row: Row, +} + +pub type TxID = u64; + +/// A log record contains all the versions inserted and deleted by a transaction. +#[derive(Clone, Debug)] +pub struct LogRecord { + pub(crate) tx_timestamp: TxID, + row_versions: Vec>, +} + +impl LogRecord { + fn new(tx_timestamp: TxID) -> Self { + Self { + tx_timestamp, + row_versions: Vec::new(), + } + } +} + +/// A transaction timestamp or ID. +/// +/// Versions either track a timestamp or a transaction ID, depending on the +/// phase of the transaction. 
During the active phase, new versions track the +/// transaction ID in the `begin` and `end` fields. After a transaction commits, +/// versions switch to tracking timestamps. +#[derive(Clone, Debug, PartialEq, PartialOrd)] +enum TxTimestampOrID { + Timestamp(u64), + TxID(TxID), +} + +/// Transaction +#[derive(Debug)] +pub struct Transaction { + /// The state of the transaction. + state: AtomicTransactionState, + /// The transaction ID. + tx_id: u64, + /// The transaction begin timestamp. + begin_ts: u64, + /// The transaction write set. + write_set: SkipSet, + /// The transaction read set. + read_set: SkipSet, +} + +impl Transaction { + fn new(tx_id: u64, begin_ts: u64) -> Transaction { + Transaction { + state: TransactionState::Active.into(), + tx_id, + begin_ts, + write_set: SkipSet::new(), + read_set: SkipSet::new(), + } + } + + fn insert_to_read_set(&self, id: RowID) { + self.read_set.insert(id); + } + + fn insert_to_write_set(&mut self, id: RowID) { + self.write_set.insert(id); + } +} + +impl std::fmt::Display for Transaction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { + write!( + f, + "{{ state: {}, id: {}, begin_ts: {}, write_set: {:?}, read_set: {:?}", + self.state.load(), + self.tx_id, + self.begin_ts, + // FIXME: I'm sorry, we obviously shouldn't be cloning here. + self.write_set + .iter() + .map(|v| *v.value()) + .collect::>(), + self.read_set + .iter() + .map(|v| *v.value()) + .collect::>() + ) + } +} + +/// Transaction state. +#[derive(Debug, Clone, PartialEq)] +enum TransactionState { + Active, + Preparing, + Aborted, + Terminated, + Committed(u64), +} + +impl TransactionState { + pub fn encode(&self) -> u64 { + match self { + TransactionState::Active => 0, + TransactionState::Preparing => 1, + TransactionState::Aborted => 2, + TransactionState::Terminated => 3, + TransactionState::Committed(ts) => { + // We only support 2*62 - 1 timestamps, because the extra bit + // is used to encode the type. 
+ assert!(ts & 0x8000_0000_0000_0000 == 0); + 0x8000_0000_0000_0000 | ts + } + } + } + + pub fn decode(v: u64) -> Self { + match v { + 0 => TransactionState::Active, + 1 => TransactionState::Preparing, + 2 => TransactionState::Aborted, + 3 => TransactionState::Terminated, + v if v & 0x8000_0000_0000_0000 != 0 => { + TransactionState::Committed(v & 0x7fff_ffff_ffff_ffff) + } + _ => panic!("Invalid transaction state"), + } + } +} + +// Transaction state encoded into a single 64-bit atomic. +#[derive(Debug)] +pub(crate) struct AtomicTransactionState { + pub(crate) state: AtomicU64, +} + +impl From for AtomicTransactionState { + fn from(state: TransactionState) -> Self { + Self { + state: AtomicU64::new(state.encode()), + } + } +} + +impl From for TransactionState { + fn from(state: AtomicTransactionState) -> Self { + let encoded = state.state.load(Ordering::Acquire); + TransactionState::decode(encoded) + } +} + +impl std::cmp::PartialEq for AtomicTransactionState { + fn eq(&self, other: &TransactionState) -> bool { + &self.load() == other + } +} + +impl std::fmt::Display for TransactionState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { + match self { + TransactionState::Active => write!(f, "Active"), + TransactionState::Preparing => write!(f, "Preparing"), + TransactionState::Committed(ts) => write!(f, "Committed({ts})"), + TransactionState::Aborted => write!(f, "Aborted"), + TransactionState::Terminated => write!(f, "Terminated"), + } + } +} + +impl AtomicTransactionState { + fn store(&self, state: TransactionState) { + self.state.store(state.encode(), Ordering::Release); + } + + fn load(&self) -> TransactionState { + TransactionState::decode(self.state.load(Ordering::Acquire)) + } +} + +#[derive(Debug)] +pub struct Database { + rows: SkipMap>>>, + txs: SkipMap>, + tx_ids: AtomicU64, + clock: Clock, + storage: Storage, +} + +impl Database { + /// Creates a new database. 
+ pub fn new(clock: Clock, storage: Storage) -> Self { + Self { + rows: SkipMap::new(), + txs: SkipMap::new(), + tx_ids: AtomicU64::new(1), // let's reserve transaction 0 for special purposes + clock, + storage, + } + } + + // Extracts the begin timestamp from a transaction + fn get_begin_timestamp(&self, ts_or_id: &TxTimestampOrID) -> u64 { + match ts_or_id { + TxTimestampOrID::Timestamp(ts) => *ts, + TxTimestampOrID::TxID(tx_id) => { + self.txs + .get(tx_id) + .unwrap() + .value() + .read() + .unwrap() + .begin_ts + } + } + } + + /// Inserts a new row version into the database, while making sure that + /// the row version is inserted in the correct order. + fn insert_version(&self, id: RowID, row_version: RowVersion) { + let versions = self.rows.get_or_insert_with(id, || RwLock::new(Vec::new())); + let mut versions = versions.value().write().unwrap(); + self.insert_version_raw(&mut versions, row_version) + } + + /// Inserts a new row version into the internal data structure for versions, + /// while making sure that the row version is inserted in the correct order. + fn insert_version_raw(&self, versions: &mut Vec>, row_version: RowVersion) { + // NOTICE: this is an insert a'la insertion sort, with pessimistic linear complexity. + // However, we expect the number of versions to be nearly sorted, so we deem it worthy + // to search linearly for the insertion point instead of paying the price of using + // another data structure, e.g. a BTreeSet. If it proves to be too quadratic empirically, + // we can either switch to a tree-like structure, or at least use partition_point() + // which performs a binary search for the insertion point. 
+ let position = versions + .iter() + .rposition(|v| { + self.get_begin_timestamp(&v.begin) < self.get_begin_timestamp(&row_version.begin) + }) + .map(|p| p + 1) + .unwrap_or(0); + if versions.len() - position > 3 { + tracing::debug!( + "Inserting a row version {} positions from the end", + versions.len() - position + ); + } + versions.insert(position, row_version); + } + + /// Inserts a new row into the database. + /// + /// This function inserts a new `row` into the database within the context + /// of the transaction `tx_id`. + /// + /// # Arguments + /// + /// * `tx_id` - the ID of the transaction in which to insert the new row. + /// * `row` - the row object containing the values to be inserted. + /// + pub fn insert(&self, tx_id: TxID, row: Row) -> Result<()> { + let tx = self + .txs + .get(&tx_id) + .ok_or(DatabaseError::NoSuchTransactionID(tx_id))?; + let mut tx = tx.value().write().unwrap(); + assert_eq!(tx.state, TransactionState::Active); + let id = row.id; + let row_version = RowVersion { + begin: TxTimestampOrID::TxID(tx.tx_id), + end: None, + row, + }; + tx.insert_to_write_set(id); + drop(tx); + self.insert_version(id, row_version); + Ok(()) + } + + /// Updates a row in the database with new values. + /// + /// This function updates an existing row in the database within the + /// context of the transaction `tx_id`. The `row` argument identifies the + /// row to be updated as `id` and contains the new values to be inserted. + /// + /// If the row identified by the `id` does not exist, this function does + /// nothing and returns `false`. Otherwise, the function updates the row + /// with the new values and returns `true`. + /// + /// # Arguments + /// + /// * `tx_id` - the ID of the transaction in which to update the new row. + /// * `row` - the row object containing the values to be updated. + /// + /// # Returns + /// + /// Returns `true` if the row was successfully updated, and `false` otherwise. 
+ pub fn update(&self, tx_id: TxID, row: Row) -> Result { + if !self.delete(tx_id, row.id)? { + return Ok(false); + } + self.insert(tx_id, row)?; + Ok(true) + } + + /// Inserts a row in the database with new values, previously deleting + /// any old data if it existed. Bails on a delete error, e.g. write-write conflict. + pub fn upsert(&self, tx_id: TxID, row: Row) -> Result<()> { + self.delete(tx_id, row.id)?; + self.insert(tx_id, row) + } + + /// Deletes a row from the table with the given `id`. + /// + /// This function deletes an existing row `id` in the database within the + /// context of the transaction `tx_id`. + /// + /// # Arguments + /// + /// * `tx_id` - the ID of the transaction in which to delete the new row. + /// * `id` - the ID of the row to delete. + /// + /// # Returns + /// + /// Returns `true` if the row was successfully deleted, and `false` otherwise. + /// + pub fn delete(&self, tx_id: TxID, id: RowID) -> Result { + let row_versions_opt = self.rows.get(&id); + if let Some(ref row_versions) = row_versions_opt { + let mut row_versions = row_versions.value().write().unwrap(); + for rv in row_versions.iter_mut().rev() { + let tx = self + .txs + .get(&tx_id) + .ok_or(DatabaseError::NoSuchTransactionID(tx_id))?; + let tx = tx.value().read().unwrap(); + assert_eq!(tx.state, TransactionState::Active); + if is_write_write_conflict(&self.txs, &tx, rv) { + drop(row_versions); + drop(row_versions_opt); + drop(tx); + self.rollback_tx(tx_id); + return Err(DatabaseError::WriteWriteConflict); + } + if is_version_visible(&self.txs, &tx, rv) { + rv.end = Some(TxTimestampOrID::TxID(tx.tx_id)); + drop(row_versions); + drop(row_versions_opt); + drop(tx); + let tx = self + .txs + .get(&tx_id) + .ok_or(DatabaseError::NoSuchTransactionID(tx_id))?; + let mut tx = tx.value().write().unwrap(); + tx.insert_to_write_set(id); + return Ok(true); + } + } + } + Ok(false) + } + + /// Retrieves a row from the table with the given `id`. 
+ /// + /// This operation is performed within the scope of the transaction identified + /// by `tx_id`. + /// + /// # Arguments + /// + /// * `tx_id` - The ID of the transaction to perform the read operation in. + /// * `id` - The ID of the row to retrieve. + /// + /// # Returns + /// + /// Returns `Some(row)` with the row data if the row with the given `id` exists, + /// and `None` otherwise. + pub fn read(&self, tx_id: TxID, id: RowID) -> Result>> { + let tx = self.txs.get(&tx_id).unwrap(); + let tx = tx.value().read().unwrap(); + assert_eq!(tx.state, TransactionState::Active); + if let Some(row_versions) = self.rows.get(&id) { + let row_versions = row_versions.value().read().unwrap(); + for rv in row_versions.iter().rev() { + if is_version_visible(&self.txs, &tx, rv) { + tx.insert_to_read_set(id); + return Ok(Some(rv.row.clone())); + } + } + } + Ok(None) + } + + /// Gets all row ids in the database. + pub fn scan_row_ids(&self) -> Result> { + let keys = self.rows.iter().map(|entry| *entry.key()); + Ok(keys.collect()) + } + + /// Gets all row ids in the database for a given table. + pub fn scan_row_ids_for_table(&self, table_id: u64) -> Result> { + Ok(self + .rows + .range( + RowID { + table_id, + row_id: 0, + }..RowID { + table_id, + row_id: u64::MAX, + }, + ) + .map(|entry| *entry.key()) + .collect()) + } + + /// Begins a new transaction in the database. + /// + /// This function starts a new transaction in the database and returns a `TxID` value + /// that you can use to perform operations within the transaction. All changes made within the + /// transaction are isolated from other transactions until you commit the transaction. + pub fn begin_tx(&self) -> TxID { + let tx_id = self.get_tx_id(); + let begin_ts = self.get_timestamp(); + let tx = Transaction::new(tx_id, begin_ts); + tracing::trace!("BEGIN {tx}"); + self.txs.insert(tx_id, RwLock::new(tx)); + tx_id + } + + /// Commits a transaction with the specified transaction ID. 
+ /// + /// This function commits the changes made within the specified transaction and finalizes the + /// transaction. Once a transaction has been committed, all changes made within the transaction + /// are visible to other transactions that access the same data. + /// + /// # Arguments + /// + /// * `tx_id` - The ID of the transaction to commit. + pub fn commit_tx(&self, tx_id: TxID) -> Result<()> { + let end_ts = self.get_timestamp(); + // NOTICE: the first shadowed tx keeps the entry alive in the map + // for the duration of this whole function, which is important for correctness! + let tx = self.txs.get(&tx_id).ok_or(DatabaseError::TxTerminated)?; + let tx = tx.value().write().unwrap(); + match tx.state.load() { + TransactionState::Terminated => return Err(DatabaseError::TxTerminated), + _ => { + assert_eq!(tx.state, TransactionState::Active); + } + } + tx.state.store(TransactionState::Preparing); + tracing::trace!("PREPARE {tx}"); + + /* TODO: The code we have here is sufficient for snapshot isolation. + ** In order to implement serializability, we need the following steps: + ** + ** 1. Validate if all read versions are still visible by inspecting the read_set + ** 2. Validate if there are no phantoms by walking the scans from scan_set (which we don't even have yet) + ** - a phantom is a version that became visible in the middle of our transaction, + ** but wasn't taken into account during one of the scans from the scan_set + ** 3. Wait for commit dependencies, which we don't even track yet... + ** Excerpt from what's a commit dependency and how it's tracked in the original paper: + ** """ + A transaction T1 has a commit dependency on another transaction + T2, if T1 is allowed to commit only if T2 commits. If T2 aborts, + T1 must also abort, so cascading aborts are possible. T1 acquires a + commit dependency either by speculatively reading or speculatively ignoring a version, + instead of waiting for T2 to commit. 
+ We implement commit dependencies by a register-and-report + approach: T1 registers its dependency with T2 and T2 informs T1 + when it has committed or aborted. Each transaction T contains a + counter, CommitDepCounter, that counts how many unresolved + commit dependencies it still has. A transaction cannot commit + until this counter is zero. In addition, T has a Boolean variable + AbortNow that other transactions can set to tell T to abort. Each + transaction T also has a set, CommitDepSet, that stores transaction IDs + of the transactions that depend on T. + To take a commit dependency on a transaction T2, T1 increments + its CommitDepCounter and adds its transaction ID to T2’s CommitDepSet. + When T2 has committed, it locates each transaction in + its CommitDepSet and decrements their CommitDepCounter. If + T2 aborted, it tells the dependent transactions to also abort by + setting their AbortNow flags. If a dependent transaction is not + found, this means that it has already aborted. + Note that a transaction with commit dependencies may not have to + wait at all - the dependencies may have been resolved before it is + ready to commit. Commit dependencies consolidate all waits into + a single wait and postpone the wait to just before commit. + Some transactions may have to wait before commit. + Waiting raises a concern of deadlocks. + However, deadlocks cannot occur because an older transaction never + waits on a younger transaction. In + a wait-for graph the direction of edges would always be from a + younger transaction (higher end timestamp) to an older transaction + (lower end timestamp) so cycles are impossible. + """ + ** If you're wondering when a speculative read happens, here you go: + ** Case 1: speculative read of TB: + """ + If transaction TB is in the Preparing state, it has acquired an end + timestamp TS which will be V’s begin timestamp if TB commits. 
+ A safe approach in this situation would be to have transaction T + wait until transaction TB commits. However, we want to avoid all + blocking during normal processing so instead we continue with + the visibility test and, if the test returns true, allow T to + speculatively read V. Transaction T acquires a commit dependency on + TB, restricting the serialization order of the two transactions. That + is, T is allowed to commit only if TB commits. + """ + ** Case 2: speculative ignore of TE: + """ + If TE’s state is Preparing, it has an end timestamp TS that will become + the end timestamp of V if TE does commit. If TS is greater than the read + time RT, it is obvious that V will be visible if TE commits. If TE + aborts, V will still be visible, because any transaction that updates + V after TE has aborted will obtain an end timestamp greater than + TS. If TS is less than RT, we have a more complicated situation: + if TE commits, V will not be visible to T but if TE aborts, it will + be visible. We could handle this by forcing T to wait until TE + commits or aborts but we want to avoid all blocking during normal processing. + Instead we allow T to speculatively ignore V and + proceed with its processing. Transaction T acquires a commit + dependency (see Section 2.7) on TE, that is, T is allowed to commit + only if TE commits. + """ + */ + tx.state.store(TransactionState::Committed(end_ts)); + tracing::trace!("COMMIT {tx}"); + let tx_begin_ts = tx.begin_ts; + let write_set: Vec = tx.write_set.iter().map(|v| *v.value()).collect(); + drop(tx); + // Postprocessing: inserting row versions and logging the transaction to persistent storage. + // TODO: we should probably save to persistent storage first, and only then update the in-memory structures. 
+ let mut log_record: LogRecord = LogRecord::new(end_ts); + for ref id in write_set { + if let Some(row_versions) = self.rows.get(id) { + let mut row_versions = row_versions.value().write().unwrap(); + for row_version in row_versions.iter_mut() { + if let TxTimestampOrID::TxID(id) = row_version.begin { + if id == tx_id { + row_version.begin = TxTimestampOrID::Timestamp(tx_begin_ts); + self.insert_version_raw( + &mut log_record.row_versions, + row_version.clone(), + ); // FIXME: optimize cloning out + } + } + if let Some(TxTimestampOrID::TxID(id)) = row_version.end { + if id == tx_id { + row_version.end = Some(TxTimestampOrID::Timestamp(end_ts)); + self.insert_version_raw( + &mut log_record.row_versions, + row_version.clone(), + ); // FIXME: optimize cloning out + } + } + } + } + } + tracing::trace!("UPDATED TX{tx_id}"); + // We have now updated all the versions with a reference to the + // transaction ID to a timestamp and can, therefore, remove the + // transaction. Please note that when we move to lockless, the + // invariant doesn't necessarily hold anymore because another thread + // might have speculatively read a version that we want to remove. + // But that's a problem for another day. + // FIXME: it actually just become a problem for today!!! + // TODO: test that reproduces this failure, and then a fix + self.txs.remove(&tx_id); + if !log_record.row_versions.is_empty() { + self.storage.log_tx(log_record)?; + } + tracing::trace!("LOGGED {tx_id}"); + Ok(()) + } + + /// Rolls back a transaction with the specified ID. + /// + /// This function rolls back a transaction with the specified `tx_id` by + /// discarding any changes made by the transaction. + /// + /// # Arguments + /// + /// * `tx_id` - The ID of the transaction to abort. 
+ pub fn rollback_tx(&self, tx_id: TxID) { + let tx_unlocked = self.txs.get(&tx_id).unwrap(); + let tx = tx_unlocked.value().write().unwrap(); + assert_eq!(tx.state, TransactionState::Active); + tx.state.store(TransactionState::Aborted); + tracing::trace!("ABORT {tx}"); + let write_set: Vec = tx.write_set.iter().map(|v| *v.value()).collect(); + drop(tx); + + for ref id in write_set { + if let Some(row_versions) = self.rows.get(id) { + let mut row_versions = row_versions.value().write().unwrap(); + row_versions.retain(|rv| rv.begin != TxTimestampOrID::TxID(tx_id)); + if row_versions.is_empty() { + self.rows.remove(id); + } + } + } + + let tx = tx_unlocked.value().read().unwrap(); + tx.state.store(TransactionState::Terminated); + tracing::trace!("TERMINATE {tx}"); + // FIXME: verify that we can already remove the transaction here! + // Maybe it's fine for snapshot isolation, but too early for serializable? + self.txs.remove(&tx_id); + } + + /// Generates next unique transaction id + pub fn get_tx_id(&self) -> u64 { + self.tx_ids.fetch_add(1, Ordering::SeqCst) + } + + /// Gets current timestamp + pub fn get_timestamp(&self) -> u64 { + self.clock.get_timestamp() + } + + /// Removes unused row versions with very loose heuristics, + /// which sometimes leaves versions intact for too long. + /// Returns the number of removed versions. + pub fn drop_unused_row_versions(&self) -> usize { + tracing::trace!( + "Dropping unused row versions. 
Database stats: transactions: {}; rows: {}", + self.txs.len(), + self.rows.len() + ); + let mut dropped = 0; + let mut to_remove = Vec::new(); + for entry in self.rows.iter() { + let mut row_versions = entry.value().write().unwrap(); + row_versions.retain(|rv| { + // FIXME: should take rv.begin into account as well + let should_stay = match rv.end { + Some(TxTimestampOrID::Timestamp(version_end_ts)) => { + // a transaction started before this row version ended, ergo row version is needed + // NOTICE: O(row_versions x transactions), but also lock-free, so sounds acceptable + self.txs.iter().any(|tx| { + let tx = tx.value().read().unwrap(); + // FIXME: verify! + match tx.state.load() { + TransactionState::Active | TransactionState::Preparing => { + version_end_ts > tx.begin_ts + } + _ => false, + } + }) + } + // Let's skip potentially complex logic if the transafction is still + // active/tracked. We will drop the row version when the transaction + // gets garbage-collected itself, it will always happen eventually. + Some(TxTimestampOrID::TxID(tx_id)) => !self.txs.contains_key(&tx_id), + // this row version is current, ergo visible + None => true, + }; + if !should_stay { + dropped += 1; + tracing::trace!( + "Dropping row version {:?} {:?}-{:?}", + entry.key(), + rv.begin, + rv.end + ); + } + should_stay + }); + if row_versions.is_empty() { + to_remove.push(*entry.key()); + } + } + for id in to_remove { + self.rows.remove(&id); + } + dropped + } + + pub fn recover(&self) -> Result<()> { + let tx_log = self.storage.read_tx_log()?; + for record in tx_log { + tracing::debug!("RECOVERING {:?}", record); + for version in record.row_versions { + self.insert_version(version.row.id, version); + } + self.clock.reset(record.tx_timestamp); + } + Ok(()) + } +} + +/// A write-write conflict happens when transaction T_m attempts to update a +/// row version that is currently being updated by an active transaction T_n. 
+pub(crate) fn is_write_write_conflict( + txs: &SkipMap>, + tx: &Transaction, + rv: &RowVersion, +) -> bool { + match rv.end { + Some(TxTimestampOrID::TxID(rv_end)) => { + let te = txs.get(&rv_end).unwrap(); + let te = te.value().read().unwrap(); + match te.state.load() { + TransactionState::Active | TransactionState::Preparing => tx.tx_id != te.tx_id, + _ => false, + } + } + Some(TxTimestampOrID::Timestamp(_)) => false, + None => false, + } +} + +pub(crate) fn is_version_visible( + txs: &SkipMap>, + tx: &Transaction, + rv: &RowVersion, +) -> bool { + is_begin_visible(txs, tx, rv) && is_end_visible(txs, tx, rv) +} + +fn is_begin_visible( + txs: &SkipMap>, + tx: &Transaction, + rv: &RowVersion, +) -> bool { + match rv.begin { + TxTimestampOrID::Timestamp(rv_begin_ts) => tx.begin_ts >= rv_begin_ts, + TxTimestampOrID::TxID(rv_begin) => { + let tb = txs.get(&rv_begin).unwrap(); + let tb = tb.value().read().unwrap(); + let visible = match tb.state.load() { + TransactionState::Active => tx.tx_id == tb.tx_id && rv.end.is_none(), + TransactionState::Preparing => false, // NOTICE: makes sense for snapshot isolation, not so much for serializable! 
+ TransactionState::Committed(committed_ts) => tx.begin_ts >= committed_ts, + TransactionState::Aborted => false, + TransactionState::Terminated => { + tracing::debug!("TODO: should reread rv's end field - it should have updated the timestamp in the row version by now"); + false + } + }; + tracing::trace!( + "is_begin_visible: tx={tx}, tb={tb} rv = {:?}-{:?} visible = {visible}", + rv.begin, + rv.end + ); + visible + } + } +} + +fn is_end_visible( + txs: &SkipMap>, + tx: &Transaction, + rv: &RowVersion, +) -> bool { + match rv.end { + Some(TxTimestampOrID::Timestamp(rv_end_ts)) => tx.begin_ts < rv_end_ts, + Some(TxTimestampOrID::TxID(rv_end)) => { + let te = txs.get(&rv_end).unwrap(); + let te = te.value().read().unwrap(); + let visible = match te.state.load() { + TransactionState::Active => tx.tx_id != te.tx_id, + TransactionState::Preparing => false, // NOTICE: makes sense for snapshot isolation, not so much for serializable! + TransactionState::Committed(committed_ts) => tx.begin_ts < committed_ts, + TransactionState::Aborted => false, + TransactionState::Terminated => { + tracing::debug!("TODO: should reread rv's end field - it should have updated the timestamp in the row version by now"); + false + } + }; + tracing::trace!( + "is_end_visible: tx={tx}, te={te} rv = {:?}-{:?} visible = {visible}", + rv.begin, + rv.end + ); + visible + } + None => true, + } +} diff --git a/core/mvcc/database/tests.rs b/core/mvcc/database/tests.rs new file mode 100644 index 000000000..741ada4cb --- /dev/null +++ b/core/mvcc/database/tests.rs @@ -0,0 +1,760 @@ +use super::*; +use crate::mvcc::clock::LocalClock; + +fn test_db() -> Database { + let clock = LocalClock::new(); + let storage = crate::mvcc::persistent_storage::Storage::new_noop(); + Database::new(clock, storage) +} + +#[test] +fn test_insert_read() { + let db = test_db(); + + let tx1 = db.begin_tx(); + let tx1_row = Row { + id: RowID { + table_id: 1, + row_id: 1, + }, + data: "Hello".to_string(), + }; + db.insert(tx1, 
tx1_row.clone()).unwrap(); + let row = db + .read( + tx1, + RowID { + table_id: 1, + row_id: 1, + }, + ) + .unwrap() + .unwrap(); + assert_eq!(tx1_row, row); + db.commit_tx(tx1).unwrap(); + + let tx2 = db.begin_tx(); + let row = db + .read( + tx2, + RowID { + table_id: 1, + row_id: 1, + }, + ) + .unwrap() + .unwrap(); + assert_eq!(tx1_row, row); +} + +#[test] +fn test_read_nonexistent() { + let db = test_db(); + let tx = db.begin_tx(); + let row = db.read( + tx, + RowID { + table_id: 1, + row_id: 1, + }, + ); + assert!(row.unwrap().is_none()); +} + +#[test] +fn test_delete() { + let db = test_db(); + + let tx1 = db.begin_tx(); + let tx1_row = Row { + id: RowID { + table_id: 1, + row_id: 1, + }, + data: "Hello".to_string(), + }; + db.insert(tx1, tx1_row.clone()).unwrap(); + let row = db + .read( + tx1, + RowID { + table_id: 1, + row_id: 1, + }, + ) + .unwrap() + .unwrap(); + assert_eq!(tx1_row, row); + db.delete( + tx1, + RowID { + table_id: 1, + row_id: 1, + }, + ) + .unwrap(); + let row = db + .read( + tx1, + RowID { + table_id: 1, + row_id: 1, + }, + ) + .unwrap(); + assert!(row.is_none()); + db.commit_tx(tx1).unwrap(); + + let tx2 = db.begin_tx(); + let row = db + .read( + tx2, + RowID { + table_id: 1, + row_id: 1, + }, + ) + .unwrap(); + assert!(row.is_none()); +} + +#[test] +fn test_delete_nonexistent() { + let db = test_db(); + let tx = db.begin_tx(); + assert!(!db + .delete( + tx, + RowID { + table_id: 1, + row_id: 1 + } + ) + .unwrap()); +} + +#[test] +fn test_commit() { + let db = test_db(); + let tx1 = db.begin_tx(); + let tx1_row = Row { + id: RowID { + table_id: 1, + row_id: 1, + }, + data: "Hello".to_string(), + }; + db.insert(tx1, tx1_row.clone()).unwrap(); + let row = db + .read( + tx1, + RowID { + table_id: 1, + row_id: 1, + }, + ) + .unwrap() + .unwrap(); + assert_eq!(tx1_row, row); + let tx1_updated_row = Row { + id: RowID { + table_id: 1, + row_id: 1, + }, + data: "World".to_string(), + }; + db.update(tx1, tx1_updated_row.clone()).unwrap(); + let 
row = db + .read( + tx1, + RowID { + table_id: 1, + row_id: 1, + }, + ) + .unwrap() + .unwrap(); + assert_eq!(tx1_updated_row, row); + db.commit_tx(tx1).unwrap(); + + let tx2 = db.begin_tx(); + let row = db + .read( + tx2, + RowID { + table_id: 1, + row_id: 1, + }, + ) + .unwrap() + .unwrap(); + db.commit_tx(tx2).unwrap(); + assert_eq!(tx1_updated_row, row); + db.drop_unused_row_versions(); +} + +#[test] +fn test_rollback() { + let db = test_db(); + let tx1 = db.begin_tx(); + let row1 = Row { + id: RowID { + table_id: 1, + row_id: 1, + }, + data: "Hello".to_string(), + }; + db.insert(tx1, row1.clone()).unwrap(); + let row2 = db + .read( + tx1, + RowID { + table_id: 1, + row_id: 1, + }, + ) + .unwrap() + .unwrap(); + assert_eq!(row1, row2); + let row3 = Row { + id: RowID { + table_id: 1, + row_id: 1, + }, + data: "World".to_string(), + }; + db.update(tx1, row3.clone()).unwrap(); + let row4 = db + .read( + tx1, + RowID { + table_id: 1, + row_id: 1, + }, + ) + .unwrap() + .unwrap(); + assert_eq!(row3, row4); + db.rollback_tx(tx1); + let tx2 = db.begin_tx(); + let row5 = db + .read( + tx2, + RowID { + table_id: 1, + row_id: 1, + }, + ) + .unwrap(); + assert_eq!(row5, None); +} + +#[test] +fn test_dirty_write() { + let db = test_db(); + + // T1 inserts a row with ID 1, but does not commit. + let tx1 = db.begin_tx(); + let tx1_row = Row { + id: RowID { + table_id: 1, + row_id: 1, + }, + data: "Hello".to_string(), + }; + db.insert(tx1, tx1_row.clone()).unwrap(); + let row = db + .read( + tx1, + RowID { + table_id: 1, + row_id: 1, + }, + ) + .unwrap() + .unwrap(); + assert_eq!(tx1_row, row); + + // T2 attempts to delete row with ID 1, but fails because T1 has not committed. 
+ let tx2 = db.begin_tx(); + let tx2_row = Row { + id: RowID { + table_id: 1, + row_id: 1, + }, + data: "World".to_string(), + }; + assert!(!db.update(tx2, tx2_row).unwrap()); + + let row = db + .read( + tx1, + RowID { + table_id: 1, + row_id: 1, + }, + ) + .unwrap() + .unwrap(); + assert_eq!(tx1_row, row); +} + +#[test] +fn test_dirty_read() { + let db = test_db(); + + // T1 inserts a row with ID 1, but does not commit. + let tx1 = db.begin_tx(); + let row1 = Row { + id: RowID { + table_id: 1, + row_id: 1, + }, + data: "Hello".to_string(), + }; + db.insert(tx1, row1).unwrap(); + + // T2 attempts to read row with ID 1, but doesn't see one because T1 has not committed. + let tx2 = db.begin_tx(); + let row2 = db + .read( + tx2, + RowID { + table_id: 1, + row_id: 1, + }, + ) + .unwrap(); + assert_eq!(row2, None); +} + +#[test] +fn test_dirty_read_deleted() { + let db = test_db(); + + // T1 inserts a row with ID 1 and commits. + let tx1 = db.begin_tx(); + let tx1_row = Row { + id: RowID { + table_id: 1, + row_id: 1, + }, + data: "Hello".to_string(), + }; + db.insert(tx1, tx1_row.clone()).unwrap(); + db.commit_tx(tx1).unwrap(); + + // T2 deletes row with ID 1, but does not commit. + let tx2 = db.begin_tx(); + assert!(db + .delete( + tx2, + RowID { + table_id: 1, + row_id: 1 + } + ) + .unwrap()); + + // T3 reads row with ID 1, but doesn't see the delete because T2 hasn't committed. + let tx3 = db.begin_tx(); + let row = db + .read( + tx3, + RowID { + table_id: 1, + row_id: 1, + }, + ) + .unwrap() + .unwrap(); + assert_eq!(tx1_row, row); +} + +#[test] +fn test_fuzzy_read() { + let db = test_db(); + + // T1 inserts a row with ID 1 and commits. 
+ let tx1 = db.begin_tx(); + let tx1_row = Row { + id: RowID { + table_id: 1, + row_id: 1, + }, + data: "Hello".to_string(), + }; + db.insert(tx1, tx1_row.clone()).unwrap(); + let row = db + .read( + tx1, + RowID { + table_id: 1, + row_id: 1, + }, + ) + .unwrap() + .unwrap(); + assert_eq!(tx1_row, row); + db.commit_tx(tx1).unwrap(); + + // T2 reads the row with ID 1 within an active transaction. + let tx2 = db.begin_tx(); + let row = db + .read( + tx2, + RowID { + table_id: 1, + row_id: 1, + }, + ) + .unwrap() + .unwrap(); + assert_eq!(tx1_row, row); + + // T3 updates the row and commits. + let tx3 = db.begin_tx(); + let tx3_row = Row { + id: RowID { + table_id: 1, + row_id: 1, + }, + data: "World".to_string(), + }; + db.update(tx3, tx3_row).unwrap(); + db.commit_tx(tx3).unwrap(); + + // T2 still reads the same version of the row as before. + let row = db + .read( + tx2, + RowID { + table_id: 1, + row_id: 1, + }, + ) + .unwrap() + .unwrap(); + assert_eq!(tx1_row, row); +} + +#[test] +fn test_lost_update() { + let db = test_db(); + + // T1 inserts a row with ID 1 and commits. + let tx1 = db.begin_tx(); + let tx1_row = Row { + id: RowID { + table_id: 1, + row_id: 1, + }, + data: "Hello".to_string(), + }; + db.insert(tx1, tx1_row.clone()).unwrap(); + let row = db + .read( + tx1, + RowID { + table_id: 1, + row_id: 1, + }, + ) + .unwrap() + .unwrap(); + assert_eq!(tx1_row, row); + db.commit_tx(tx1).unwrap(); + + // T2 attempts to update row ID 1 within an active transaction. + let tx2 = db.begin_tx(); + let tx2_row = Row { + id: RowID { + table_id: 1, + row_id: 1, + }, + data: "World".to_string(), + }; + assert!(db.update(tx2, tx2_row.clone()).unwrap()); + + // T3 also attempts to update row ID 1 within an active transaction. 
+ let tx3 = db.begin_tx(); + let tx3_row = Row { + id: RowID { + table_id: 1, + row_id: 1, + }, + data: "Hello, world!".to_string(), + }; + assert_eq!( + Err(DatabaseError::WriteWriteConflict), + db.update(tx3, tx3_row) + ); + + db.commit_tx(tx2).unwrap(); + assert_eq!(Err(DatabaseError::TxTerminated), db.commit_tx(tx3)); + + let tx4 = db.begin_tx(); + let row = db + .read( + tx4, + RowID { + table_id: 1, + row_id: 1, + }, + ) + .unwrap() + .unwrap(); + assert_eq!(tx2_row, row); +} + +// Test for the visibility to check if a new transaction can see old committed values. +// This test checks for the typo present in the paper, explained in https://github.com/penberg/mvcc-rs/issues/15 +#[test] +fn test_committed_visibility() { + let db = test_db(); + + // let's add $10 to my account since I like money + let tx1 = db.begin_tx(); + let tx1_row = Row { + id: RowID { + table_id: 1, + row_id: 1, + }, + data: "10".to_string(), + }; + db.insert(tx1, tx1_row.clone()).unwrap(); + db.commit_tx(tx1).unwrap(); + + // but I like more money, so let me try adding $10 more + let tx2 = db.begin_tx(); + let tx2_row = Row { + id: RowID { + table_id: 1, + row_id: 1, + }, + data: "20".to_string(), + }; + assert!(db.update(tx2, tx2_row.clone()).unwrap()); + let row = db + .read( + tx2, + RowID { + table_id: 1, + row_id: 1, + }, + ) + .unwrap() + .unwrap(); + assert_eq!(row, tx2_row); + + // can I check how much money I have? 
+ let tx3 = db.begin_tx(); + let row = db + .read( + tx3, + RowID { + table_id: 1, + row_id: 1, + }, + ) + .unwrap() + .unwrap(); + assert_eq!(tx1_row, row); +} + +// Test to check if a older transaction can see (un)committed future rows +#[test] +fn test_future_row() { + let db = test_db(); + + let tx1 = db.begin_tx(); + + let tx2 = db.begin_tx(); + let tx2_row = Row { + id: RowID { + table_id: 1, + row_id: 1, + }, + data: "10".to_string(), + }; + db.insert(tx2, tx2_row).unwrap(); + + // transaction in progress, so tx1 shouldn't be able to see the value + let row = db + .read( + tx1, + RowID { + table_id: 1, + row_id: 1, + }, + ) + .unwrap(); + assert_eq!(row, None); + + // lets commit the transaction and check if tx1 can see it + db.commit_tx(tx2).unwrap(); + let row = db + .read( + tx1, + RowID { + table_id: 1, + row_id: 1, + }, + ) + .unwrap(); + assert_eq!(row, None); +} + +/* States described in the Hekaton paper *for serializability*: + +Table 1: Case analysis of action to take when version V’s +Begin field contains the ID of transaction TB +------------------------------------------------------------------------------------------------------ +TB’s state | TB’s end timestamp | Action to take when transaction T checks visibility of version V. +------------------------------------------------------------------------------------------------------ +Active | Not set | V is visible only if TB=T and V’s end timestamp equals infinity. +------------------------------------------------------------------------------------------------------ +Preparing | TS | V’s begin timestamp will be TS ut V is not yet committed. Use TS + | as V’s begin time when testing visibility. If the test is true, + | allow T to speculatively read V. Committed TS V’s begin timestamp + | will be TS and V is committed. Use TS as V’s begin time to test + | visibility. 
+------------------------------------------------------------------------------------------------------ +Committed | TS | V’s begin timestamp will be TS and V is committed. Use TS as V’s + | begin time to test visibility. +------------------------------------------------------------------------------------------------------ +Aborted | Irrelevant | Ignore V; it’s a garbage version. +------------------------------------------------------------------------------------------------------ +Terminated | Irrelevant | Reread V’s Begin field. TB has terminated so it must have finalized +or not found | | the timestamp. +------------------------------------------------------------------------------------------------------ + +Table 2: Case analysis of action to take when V's End field +contains a transaction ID TE. +------------------------------------------------------------------------------------------------------ +TE’s state | TE’s end timestamp | Action to take when transaction T checks visibility of a version V + | | as of read time RT. +------------------------------------------------------------------------------------------------------ +Active | Not set | V is visible only if TE is not T. +------------------------------------------------------------------------------------------------------ +Preparing | TS | V’s end timestamp will be TS provided that TE commits. If TS > RT, + | V is visible to T. If TS < RT, T speculatively ignores V. +------------------------------------------------------------------------------------------------------ +Committed | TS | V’s end timestamp will be TS and V is committed. Use TS as V’s end + | timestamp when testing visibility. +------------------------------------------------------------------------------------------------------ +Aborted | Irrelevant | V is visible. +------------------------------------------------------------------------------------------------------ +Terminated | Irrelevant | Reread V’s End field. 
TE has terminated so it must have finalized +or not found | | the timestamp. +*/ + +fn new_tx(tx_id: TxID, begin_ts: u64, state: TransactionState) -> RwLock { + let state = state.into(); + RwLock::new(Transaction { + state, + tx_id, + begin_ts, + write_set: SkipSet::new(), + read_set: SkipSet::new(), + }) +} + +#[test] +fn test_snapshot_isolation_tx_visible1() { + let txs: SkipMap> = SkipMap::from_iter([ + (1, new_tx(1, 1, TransactionState::Committed(2))), + (2, new_tx(2, 2, TransactionState::Committed(5))), + (3, new_tx(3, 3, TransactionState::Aborted)), + (5, new_tx(5, 5, TransactionState::Preparing)), + (6, new_tx(6, 6, TransactionState::Committed(10))), + (7, new_tx(7, 7, TransactionState::Active)), + ]); + + let current_tx = new_tx(4, 4, TransactionState::Preparing); + let current_tx = current_tx.read().unwrap(); + + let rv_visible = |begin: TxTimestampOrID, end: Option| { + let row_version = RowVersion { + begin, + end, + row: Row { + id: RowID { + table_id: 1, + row_id: 1, + }, + data: "testme".to_string(), + }, + }; + tracing::debug!("Testing visibility of {row_version:?}"); + is_version_visible(&txs, ¤t_tx, &row_version) + }; + + // begin visible: transaction committed with ts < current_tx.begin_ts + // end visible: inf + assert!(rv_visible(TxTimestampOrID::TxID(1), None)); + + // begin invisible: transaction committed with ts > current_tx.begin_ts + assert!(!rv_visible(TxTimestampOrID::TxID(2), None)); + + // begin invisible: transaction aborted + assert!(!rv_visible(TxTimestampOrID::TxID(3), None)); + + // begin visible: timestamp < current_tx.begin_ts + // end invisible: transaction committed with ts > current_tx.begin_ts + assert!(!rv_visible( + TxTimestampOrID::Timestamp(0), + Some(TxTimestampOrID::TxID(1)) + )); + + // begin visible: timestamp < current_tx.begin_ts + // end visible: transaction committed with ts < current_tx.begin_ts + assert!(rv_visible( + TxTimestampOrID::Timestamp(0), + Some(TxTimestampOrID::TxID(2)) + )); + + // begin visible: 
timestamp < current_tx.begin_ts + // end invisible: transaction aborted + assert!(!rv_visible( + TxTimestampOrID::Timestamp(0), + Some(TxTimestampOrID::TxID(3)) + )); + + // begin invisible: transaction preparing + assert!(!rv_visible(TxTimestampOrID::TxID(5), None)); + + // begin invisible: transaction committed with ts > current_tx.begin_ts + assert!(!rv_visible(TxTimestampOrID::TxID(6), None)); + + // begin invisible: transaction active + assert!(!rv_visible(TxTimestampOrID::TxID(7), None)); + + // begin invisible: transaction committed with ts > current_tx.begin_ts + assert!(!rv_visible(TxTimestampOrID::TxID(6), None)); + + // begin invisible: transaction active + assert!(!rv_visible(TxTimestampOrID::TxID(7), None)); + + // begin visible: timestamp < current_tx.begin_ts + // end invisible: transaction preparing + assert!(!rv_visible( + TxTimestampOrID::Timestamp(0), + Some(TxTimestampOrID::TxID(5)) + )); + + // begin invisible: timestamp > current_tx.begin_ts + assert!(!rv_visible( + TxTimestampOrID::Timestamp(6), + Some(TxTimestampOrID::TxID(6)) + )); + + // begin visible: timestamp < current_tx.begin_ts + // end visible: some active transaction will eventually overwrite this version, + // but that hasn't happened + // (this is the https://avi.im/blag/2023/hekaton-paper-typo/ case, I believe!) 
+ assert!(rv_visible( + TxTimestampOrID::Timestamp(0), + Some(TxTimestampOrID::TxID(7)) + )); +} diff --git a/core/mvcc/errors.rs b/core/mvcc/errors.rs new file mode 100644 index 000000000..6cdad8ca3 --- /dev/null +++ b/core/mvcc/errors.rs @@ -0,0 +1,13 @@ +use thiserror::Error; + +#[derive(Error, Debug, PartialEq)] +pub enum DatabaseError { + #[error("no such transaction ID: `{0}`")] + NoSuchTransactionID(u64), + #[error("transaction aborted because of a write-write conflict")] + WriteWriteConflict, + #[error("transaction is terminated")] + TxTerminated, + #[error("I/O error: {0}")] + Io(String), +} diff --git a/core/mvcc/mod.rs b/core/mvcc/mod.rs new file mode 100644 index 000000000..53648ec2b --- /dev/null +++ b/core/mvcc/mod.rs @@ -0,0 +1,160 @@ +//! Multiversion concurrency control (MVCC) for Rust. +//! +//! This module implements the main memory MVCC method outlined in the paper +//! "High-Performance Concurrency Control Mechanisms for Main-Memory Databases" +//! by Per-Åke Larson et al (VLDB, 2011). +//! +//! ## Data anomalies +//! +//! * A *dirty write* occurs when transaction T_m updates a value that is written by +//! transaction T_n but not yet committed. The MVCC algorithm prevents dirty +//! writes by validating that a row version is visible to transaction T_m before +//! allowing update to it. +//! +//! * A *dirty read* occurs when transaction T_m reads a value that was written by +//! transaction T_n but not yet committed. The MVCC algorithm prevents dirty +//! reads by validating that a row version is visible to transaction T_m. +//! +//! * A *fuzzy read* (non-repeatable read) occurs when transaction T_m reads a +//! different value in the course of the transaction because another +//! transaction T_n has updated the value. +//! +//! * A *lost update* occurs when transactions T_m and T_n both attempt to update +//! the same value, resulting in one of the updates being lost. The MVCC algorithm +//! 
prevents lost updates by detecting the write-write conflict and letting the +//! first-writer win by aborting the later transaction. +//! +//! TODO: phantom reads, cursor lost updates, read skew, write skew. +//! +//! ## TODO +//! +//! * Optimistic reads and writes +//! * Garbage collection + +pub mod clock; +pub mod cursor; +pub mod database; +pub mod errors; +pub mod persistent_storage; + +#[cfg(test)] +mod tests { + use crate::mvcc::clock::LocalClock; + use crate::mvcc::database::{Database, Row, RowID}; + use std::sync::atomic::AtomicU64; + use std::sync::atomic::Ordering; + use std::sync::Arc; + + static IDS: AtomicU64 = AtomicU64::new(1); + + #[test] + fn test_non_overlapping_concurrent_inserts() { + // Two threads insert to the database concurrently using non-overlapping + // row IDs. + let clock = LocalClock::default(); + let storage = crate::mvcc::persistent_storage::Storage::new_noop(); + let db = Arc::new(Database::new(clock, storage)); + let iterations = 100000; + + let th1 = { + let db = db.clone(); + std::thread::spawn(move || { + for _ in 0..iterations { + let tx = db.begin_tx(); + let id = IDS.fetch_add(1, Ordering::SeqCst); + let id = RowID { + table_id: 1, + row_id: id, + }; + let row = Row { + id, + data: "Hello".to_string(), + }; + db.insert(tx, row.clone()).unwrap(); + db.commit_tx(tx).unwrap(); + let tx = db.begin_tx(); + let committed_row = db.read(tx, id).unwrap(); + db.commit_tx(tx).unwrap(); + assert_eq!(committed_row, Some(row)); + } + }) + }; + let th2 = { + std::thread::spawn(move || { + for _ in 0..iterations { + let tx = db.begin_tx(); + let id = IDS.fetch_add(1, Ordering::SeqCst); + let id = RowID { + table_id: 1, + row_id: id, + }; + let row = Row { + id, + data: "World".to_string(), + }; + db.insert(tx, row.clone()).unwrap(); + db.commit_tx(tx).unwrap(); + let tx = db.begin_tx(); + let committed_row = db.read(tx, id).unwrap(); + db.commit_tx(tx).unwrap(); + assert_eq!(committed_row, Some(row)); + } + }) + }; + th1.join().unwrap(); + 
th2.join().unwrap(); + } + + // FIXME: This test fails sporadically. + #[test] + #[ignore] + fn test_overlapping_concurrent_inserts_read_your_writes() { + let clock = LocalClock::default(); + let storage = crate::mvcc::persistent_storage::Storage::new_noop(); + let db = Arc::new(Database::new(clock, storage)); + let iterations = 100000; + + let work = |prefix: &'static str| { + let db = db.clone(); + std::thread::spawn(move || { + let mut failed_upserts = 0; + for i in 0..iterations { + if i % 1000 == 0 { + tracing::debug!("{prefix}: {i}"); + } + if i % 10000 == 0 { + let dropped = db.drop_unused_row_versions(); + tracing::debug!("garbage collected {dropped} versions"); + } + let tx = db.begin_tx(); + let id = i % 16; + let id = RowID { + table_id: 1, + row_id: id, + }; + let row = Row { + id, + data: format!("{prefix} @{tx}"), + }; + if let Err(e) = db.upsert(tx, row.clone()) { + tracing::trace!("upsert failed: {e}"); + failed_upserts += 1; + continue; + } + let committed_row = db.read(tx, id).unwrap(); + db.commit_tx(tx).unwrap(); + assert_eq!(committed_row, Some(row)); + } + tracing::info!( + "{prefix}'s failed upserts: {failed_upserts}/{iterations} {:.2}%", + (failed_upserts * 100) as f64 / iterations as f64 + ); + }) + }; + + let threads = vec![work("A"), work("B"), work("C"), work("D")]; + for th in threads { + th.join().unwrap(); + } + } +} diff --git a/core/mvcc/persistent_storage/mod.rs b/core/mvcc/persistent_storage/mod.rs new file mode 100644 index 000000000..3f9ff2171 --- /dev/null +++ b/core/mvcc/persistent_storage/mod.rs @@ -0,0 +1,32 @@ +use std::fmt::Debug; + +use crate::mvcc::database::{LogRecord, Result}; +use crate::mvcc::errors::DatabaseError; + +#[derive(Debug)] +pub enum Storage { + Noop, +} + +impl Storage { + pub fn new_noop() -> Self { + Self::Noop + } +} + +impl Storage { + pub fn log_tx(&self, _m: LogRecord) -> Result<()> { + match self { + Self::Noop => (), + } + Ok(()) + } + + pub fn read_tx_log(&self) -> Result>> { + match self { + 
Self::Noop => Err(DatabaseError::Io( + "cannot read from Noop storage".to_string(), + )), + } + } +} diff --git a/core/schema.rs b/core/schema.rs index fda6c12ba..76ee11544 100644 --- a/core/schema.rs +++ b/core/schema.rs @@ -66,8 +66,14 @@ impl Table { pub fn get_column_at(&self, index: usize) -> &Column { match self { - Self::BTree(table) => table.columns.get(index).unwrap(), - Self::Pseudo(table) => table.columns.get(index).unwrap(), + Self::BTree(table) => table + .columns + .get(index) + .expect("column index out of bounds"), + Self::Pseudo(table) => table + .columns + .get(index) + .expect("column index out of bounds"), } } @@ -176,8 +182,11 @@ impl PseudoTable { self.columns.push(Column { name: normalize_ident(name), ty, + ty_str: ty.to_string(), primary_key, is_rowid_alias: false, + notnull: false, + default: None, }); } pub fn get_column(&self, name: &str) -> Option<(usize, &Column)> { @@ -243,47 +252,76 @@ fn create_table( // and the value of this column are the same. // https://www.sqlite.org/lang_createtable.html#rowids_and_the_integer_primary_key let mut typename_exactly_integer = false; - let ty = match col_def.col_type { + let (ty, ty_str) = match col_def.col_type { Some(data_type) => { + let s = data_type.name.as_str(); + let ty_str = if matches!( + s.to_uppercase().as_str(), + "TEXT" | "INT" | "INTEGER" | "BLOB" | "REAL" + ) { + s.to_uppercase().to_string() + } else { + s.to_string() + }; + // https://www.sqlite.org/datatype3.html - let type_name = data_type.name.as_str().to_uppercase(); + let type_name = ty_str.to_uppercase(); if type_name.contains("INT") { typename_exactly_integer = type_name == "INTEGER"; - Type::Integer + (Type::Integer, ty_str) } else if type_name.contains("CHAR") || type_name.contains("CLOB") || type_name.contains("TEXT") { - Type::Text - } else if type_name.contains("BLOB") || type_name.is_empty() { - Type::Blob + (Type::Text, ty_str) + } else if type_name.contains("BLOB") { + (Type::Blob, ty_str) + } else if 
type_name.is_empty() { + (Type::Blob, "".to_string()) } else if type_name.contains("REAL") || type_name.contains("FLOA") || type_name.contains("DOUB") { - Type::Real + (Type::Real, ty_str) } else { - Type::Numeric + (Type::Numeric, ty_str) } } - None => Type::Null, + None => (Type::Null, "".to_string()), }; - let mut primary_key = col_def.constraints.iter().any(|c| { - matches!( - c.constraint, - sqlite3_parser::ast::ColumnConstraint::PrimaryKey { .. } - ) - }); + + let mut default = None; + let mut primary_key = false; + let mut notnull = false; + for c_def in &col_def.constraints { + match &c_def.constraint { + sqlite3_parser::ast::ColumnConstraint::PrimaryKey { .. } => { + primary_key = true; + } + sqlite3_parser::ast::ColumnConstraint::NotNull { .. } => { + notnull = true; + } + sqlite3_parser::ast::ColumnConstraint::Default(expr) => { + default = Some(expr.clone()) + } + _ => {} + } + } + if primary_key { primary_key_column_names.push(name.clone()); } else if primary_key_column_names.contains(&name) { primary_key = true; } + cols.push(Column { name: normalize_ident(&name), ty, + ty_str, primary_key, is_rowid_alias: typename_exactly_integer && primary_key, + notnull, + default, }); } if options.contains(TableOptions::WITHOUT_ROWID) { @@ -330,8 +368,12 @@ pub fn _build_pseudo_table(columns: &[ResultColumn]) -> PseudoTable { pub struct Column { pub name: String, pub ty: Type, + // many sqlite operations like table_info retain the original string + pub ty_str: String, pub primary_key: bool, pub is_rowid_alias: bool, + pub notnull: bool, + pub default: Option, } #[derive(Debug, Clone, Copy, PartialEq)] @@ -347,7 +389,7 @@ pub enum Type { impl fmt::Display for Type { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let s = match self { - Self::Null => "NULL", + Self::Null => "", Self::Text => "TEXT", Self::Numeric => "NUMERIC", Self::Integer => "INTEGER", @@ -368,32 +410,47 @@ pub fn sqlite_schema_table() -> BTreeTable { Column { name: "type".to_string(), 
ty: Type::Text, + ty_str: "TEXT".to_string(), primary_key: false, is_rowid_alias: false, + notnull: false, + default: None, }, Column { name: "name".to_string(), ty: Type::Text, + ty_str: "TEXT".to_string(), primary_key: false, is_rowid_alias: false, + notnull: false, + default: None, }, Column { name: "tbl_name".to_string(), ty: Type::Text, + ty_str: "TEXT".to_string(), primary_key: false, is_rowid_alias: false, + notnull: false, + default: None, }, Column { name: "rootpage".to_string(), ty: Type::Integer, + ty_str: "INT".to_string(), primary_key: false, is_rowid_alias: false, + notnull: false, + default: None, }, Column { name: "sql".to_string(), ty: Type::Text, + ty_str: "TEXT".to_string(), primary_key: false, is_rowid_alias: false, + notnull: false, + default: None, }, ], } @@ -711,6 +768,79 @@ mod tests { Ok(()) } + #[test] + pub fn test_default_value() -> Result<()> { + let sql = r#"CREATE TABLE t1 (a INTEGER DEFAULT 23);"#; + let table = BTreeTable::from_sql(sql, 0)?; + let column = table.get_column("a").unwrap().1; + let default = column.default.clone().unwrap(); + assert_eq!(default.to_string(), "23"); + Ok(()) + } + + #[test] + pub fn test_col_notnull() -> Result<()> { + let sql = r#"CREATE TABLE t1 (a INTEGER NOT NULL);"#; + let table = BTreeTable::from_sql(sql, 0)?; + let column = table.get_column("a").unwrap().1; + assert_eq!(column.notnull, true); + Ok(()) + } + + #[test] + pub fn test_col_notnull_negative() -> Result<()> { + let sql = r#"CREATE TABLE t1 (a INTEGER);"#; + let table = BTreeTable::from_sql(sql, 0)?; + let column = table.get_column("a").unwrap().1; + assert_eq!(column.notnull, false); + Ok(()) + } + + #[test] + pub fn test_col_type_string_integer() -> Result<()> { + let sql = r#"CREATE TABLE t1 (a InTeGeR);"#; + let table = BTreeTable::from_sql(sql, 0)?; + let column = table.get_column("a").unwrap().1; + assert_eq!(column.ty_str, "INTEGER"); + Ok(()) + } + + #[test] + pub fn test_col_type_string_int() -> Result<()> { + let sql = 
r#"CREATE TABLE t1 (a InT);"#; + let table = BTreeTable::from_sql(sql, 0)?; + let column = table.get_column("a").unwrap().1; + assert_eq!(column.ty_str, "INT"); + Ok(()) + } + + #[test] + pub fn test_col_type_string_blob() -> Result<()> { + let sql = r#"CREATE TABLE t1 (a bLoB);"#; + let table = BTreeTable::from_sql(sql, 0)?; + let column = table.get_column("a").unwrap().1; + assert_eq!(column.ty_str, "BLOB"); + Ok(()) + } + + #[test] + pub fn test_col_type_string_empty() -> Result<()> { + let sql = r#"CREATE TABLE t1 (a);"#; + let table = BTreeTable::from_sql(sql, 0)?; + let column = table.get_column("a").unwrap().1; + assert_eq!(column.ty_str, ""); + Ok(()) + } + + #[test] + pub fn test_col_type_string_some_nonsense() -> Result<()> { + let sql = r#"CREATE TABLE t1 (a someNonsenseName);"#; + let table = BTreeTable::from_sql(sql, 0)?; + let column = table.get_column("a").unwrap().1; + assert_eq!(column.ty_str, "someNonsenseName"); + Ok(()) + } + #[test] pub fn test_sqlite_schema() { let expected = r#"CREATE TABLE sqlite_schema ( @@ -783,8 +913,11 @@ mod tests { columns: vec![Column { name: "a".to_string(), ty: Type::Integer, + ty_str: "INT".to_string(), primary_key: false, is_rowid_alias: false, + notnull: false, + default: None, }], }; diff --git a/core/storage/btree.rs b/core/storage/btree.rs index 7e7ee4289..5f53112e1 100644 --- a/core/storage/btree.rs +++ b/core/storage/btree.rs @@ -74,7 +74,7 @@ macro_rules! return_if_locked { /// State machine of a write operation. /// May involve balancing due to overflow. 
-#[derive(Debug)] +#[derive(Debug, Clone)] enum WriteState { Start, BalanceStart, @@ -97,6 +97,40 @@ struct WriteInfo { page_copy: RefCell>, } +impl WriteInfo { + fn new() -> WriteInfo { + WriteInfo { + state: WriteState::Start, + new_pages: RefCell::new(Vec::with_capacity(4)), + scratch_cells: RefCell::new(Vec::new()), + rightmost_pointer: RefCell::new(None), + page_copy: RefCell::new(None), + } + } +} + +/// Holds the state machine for the operation that was in flight when the cursor +/// was suspended due to IO. +enum CursorState { + None, + Write(WriteInfo), +} + +impl CursorState { + fn write_info(&self) -> Option<&WriteInfo> { + match self { + CursorState::Write(x) => Some(x), + _ => None, + } + } + fn mut_write_info(&mut self) -> Option<&mut WriteInfo> { + match self { + CursorState::Write(x) => Some(x), + _ => None, + } + } +} + pub struct BTreeCursor { pager: Rc, /// Page id of the root page used to go back up fast. @@ -109,9 +143,8 @@ pub struct BTreeCursor { /// we just moved to a parent page and the parent page is an internal index page which requires /// to be consumed. going_upwards: bool, - /// Write information kept in case of write yields due to I/O. Needs to be stored somewhere - /// right :). - write_info: WriteInfo, + /// Information maintained across execution attempts when an operation yields due to I/O. + state: CursorState, /// Page stack used to traverse the btree. /// Each cursor has a stack because each cursor traverses the btree independently. 
stack: PageStack, @@ -144,13 +177,7 @@ impl BTreeCursor { record: RefCell::new(None), null_flag: false, going_upwards: false, - write_info: WriteInfo { - state: WriteState::Start, - new_pages: RefCell::new(Vec::with_capacity(4)), - scratch_cells: RefCell::new(Vec::new()), - rightmost_pointer: RefCell::new(None), - page_copy: RefCell::new(None), - }, + state: CursorState::None, stack: PageStack { current_page: RefCell::new(-1), cell_indices: RefCell::new([0; BTCURSOR_MAX_DEPTH + 1]), @@ -676,9 +703,18 @@ impl BTreeCursor { key: &OwnedValue, record: &OwnedRecord, ) -> Result> { - loop { - let state = &self.write_info.state; - match state { + if let CursorState::None = &self.state { + self.state = CursorState::Write(WriteInfo::new()); + } + let ret = loop { + let write_state = { + let write_info = self + .state + .mut_write_info() + .expect("can't insert while counting"); + write_info.state.clone() + }; + match write_state { WriteState::Start => { let page = self.stack.top(); let int_key = match key { @@ -718,10 +754,14 @@ impl BTreeCursor { self.insert_into_cell(contents, cell_payload.as_slice(), cell_idx); contents.overflow_cells.len() }; + let write_info = self + .state + .mut_write_info() + .expect("can't count while inserting"); if overflow > 0 { - self.write_info.state = WriteState::BalanceStart; + write_info.state = WriteState::BalanceStart; } else { - self.write_info.state = WriteState::Finish; + write_info.state = WriteState::Finish; } } WriteState::BalanceStart @@ -731,11 +771,12 @@ impl BTreeCursor { return_if_io!(self.balance()); } WriteState::Finish => { - self.write_info.state = WriteState::Start; - return Ok(CursorResult::Ok(())); + break Ok(CursorResult::Ok(())); } }; - } + }; + self.state = CursorState::None; + return ret; } /// Insert a record into a cell. @@ -879,7 +920,16 @@ impl BTreeCursor { /// It will try to split the page in half by keys not by content. /// Sqlite tries to have a page at least 40% full. 
fn balance(&mut self) -> Result> { - let state = &self.write_info.state; + assert!( + matches!(self.state, CursorState::Write(_)), + "Cursor must be in balancing state" + ); + let state = self + .state + .write_info() + .expect("must be balancing") + .state + .clone(); match state { WriteState::BalanceStart => { // drop divider cells and find right pointer @@ -893,7 +943,8 @@ impl BTreeCursor { // don't continue if there are no overflow cells let page = current_page.get().contents.as_mut().unwrap(); if page.overflow_cells.is_empty() { - self.write_info.state = WriteState::Finish; + let write_info = self.state.mut_write_info().unwrap(); + write_info.state = WriteState::Finish; return Ok(CursorResult::Ok(())); } } @@ -903,7 +954,8 @@ impl BTreeCursor { return Ok(CursorResult::Ok(())); } - self.write_info.state = WriteState::BalanceNonRoot; + let write_info = self.state.mut_write_info().unwrap(); + write_info.state = WriteState::BalanceNonRoot; self.balance_non_root() } WriteState::BalanceNonRoot @@ -915,8 +967,17 @@ impl BTreeCursor { } fn balance_non_root(&mut self) -> Result> { - let state = &self.write_info.state; - match state { + assert!( + matches!(self.state, CursorState::Write(_)), + "Cursor must be in balancing state" + ); + let state = self + .state + .write_info() + .expect("must be balancing") + .state + .clone(); + let (next_write_state, result) = match state { WriteState::Start => todo!(), WriteState::BalanceStart => todo!(), WriteState::BalanceNonRoot => { @@ -935,7 +996,8 @@ impl BTreeCursor { // In memory in order copy of all cells in pages we want to balance. For now let's do a 2 page split. // Right pointer in interior cells should be converted to regular cells if more than 2 pages are used for balancing. 
- let mut scratch_cells = self.write_info.scratch_cells.borrow_mut(); + let write_info = self.state.write_info().unwrap(); + let mut scratch_cells = write_info.scratch_cells.borrow_mut(); scratch_cells.clear(); for cell_idx in 0..page_copy.cell_count() { @@ -952,9 +1014,9 @@ impl BTreeCursor { scratch_cells .insert(overflow_cell.index, to_static_buf(&overflow_cell.payload)); } - *self.write_info.rightmost_pointer.borrow_mut() = page_copy.rightmost_pointer(); - self.write_info.page_copy.replace(Some(page_copy)); + *write_info.rightmost_pointer.borrow_mut() = page_copy.rightmost_pointer(); + write_info.page_copy.replace(Some(page_copy)); // allocate new pages and move cells to those new pages // split procedure @@ -970,15 +1032,9 @@ impl BTreeCursor { let right_page = self.allocate_page(page.page_type(), 0); let right_page_id = right_page.get().id; - self.write_info.new_pages.borrow_mut().clear(); - self.write_info - .new_pages - .borrow_mut() - .push(current_page.clone()); - self.write_info - .new_pages - .borrow_mut() - .push(right_page.clone()); + write_info.new_pages.borrow_mut().clear(); + write_info.new_pages.borrow_mut().push(current_page.clone()); + write_info.new_pages.borrow_mut().push(right_page.clone()); debug!( "splitting left={} right={}", @@ -986,8 +1042,7 @@ impl BTreeCursor { right_page_id ); - self.write_info.state = WriteState::BalanceGetParentPage; - Ok(CursorResult::Ok(())) + (WriteState::BalanceGetParentPage, Ok(CursorResult::Ok(()))) } WriteState::BalanceGetParentPage => { let parent = self.stack.parent(); @@ -1000,8 +1055,7 @@ impl BTreeCursor { return Ok(CursorResult::IO); } parent.set_dirty(); - self.write_info.state = WriteState::BalanceMoveUp; - Ok(CursorResult::Ok(())) + (WriteState::BalanceMoveUp, Ok(CursorResult::Ok(()))) } WriteState::BalanceMoveUp => { let parent = self.stack.parent(); @@ -1046,8 +1100,9 @@ impl BTreeCursor { } } - let mut new_pages = self.write_info.new_pages.borrow_mut(); - let scratch_cells = 
self.write_info.scratch_cells.borrow(); + let write_info = self.state.write_info().unwrap(); + let mut new_pages = write_info.new_pages.borrow_mut(); + let scratch_cells = write_info.scratch_cells.borrow(); // reset pages for page in new_pages.iter() { @@ -1140,7 +1195,7 @@ impl BTreeCursor { let last_page_contents = last_page.get().contents.as_mut().unwrap(); last_page_contents.write_u32( PAGE_HEADER_OFFSET_RIGHTMOST_PTR, - self.write_info.rightmost_pointer.borrow().unwrap(), + write_info.rightmost_pointer.borrow().unwrap(), ); } @@ -1197,12 +1252,14 @@ impl BTreeCursor { parent_contents.write_u32(right_pointer, last_pointer); } self.stack.pop(); - self.write_info.state = WriteState::BalanceStart; - let _ = self.write_info.page_copy.take(); - Ok(CursorResult::Ok(())) + let _ = write_info.page_copy.take(); + (WriteState::BalanceStart, Ok(CursorResult::Ok(()))) } WriteState::Finish => todo!(), - } + }; + let write_info = self.state.mut_write_info().unwrap(); + write_info.state = next_write_state; + result } /// Balance the root page. diff --git a/core/storage/pager.rs b/core/storage/pager.rs index 559e872ae..1be78cdb6 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -2,14 +2,15 @@ use crate::result::LimboResult; use crate::storage::buffer_pool::BufferPool; use crate::storage::database::DatabaseStorage; use crate::storage::sqlite3_ondisk::{self, DatabaseHeader, PageContent}; -use crate::storage::wal::Wal; +use crate::storage::wal::{CheckpointResult, Wal}; use crate::{Buffer, Result}; use log::trace; +use parking_lot::RwLock; use std::cell::{RefCell, UnsafeCell}; use std::collections::HashSet; use std::rc::Rc; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::{Arc, RwLock}; +use std::sync::Arc; use super::page_cache::{DumbLruPageCache, PageCacheKey}; use super::wal::{CheckpointMode, CheckpointStatus}; @@ -207,18 +208,25 @@ impl Pager { } pub fn end_tx(&self) -> Result { - match self.cacheflush()? 
{ - CheckpointStatus::Done => {} - CheckpointStatus::IO => return Ok(CheckpointStatus::IO), - }; + let checkpoint_status = self.cacheflush()?; + match checkpoint_status { + CheckpointStatus::IO => Ok(checkpoint_status), + CheckpointStatus::Done(_) => { + self.wal.borrow().end_read_tx()?; + Ok(checkpoint_status) + } + } + } + + pub fn end_read_tx(&self) -> Result<()> { self.wal.borrow().end_read_tx()?; - Ok(CheckpointStatus::Done) + Ok(()) } /// Reads a page from the database. pub fn read_page(&self, page_idx: usize) -> Result { trace!("read_page(page_idx = {})", page_idx); - let mut page_cache = self.page_cache.write().unwrap(); + let mut page_cache = self.page_cache.write(); let page_key = PageCacheKey::new(page_idx, Some(self.wal.borrow().get_max_frame())); if let Some(page) = page_cache.get(&page_key) { trace!("read_page(page_idx = {}) = cached", page_idx); @@ -254,7 +262,7 @@ impl Pager { pub fn load_page(&self, page: PageRef) -> Result<()> { let id = page.get().id; trace!("load_page(page_idx = {})", id); - let mut page_cache = self.page_cache.write().unwrap(); + let mut page_cache = self.page_cache.write(); page.set_locked(); let page_key = PageCacheKey::new(id, Some(self.wal.borrow().get_max_frame())); if let Some(frame_id) = self.wal.borrow().find_frame(id as u64)? { @@ -290,7 +298,7 @@ impl Pager { /// Changes the size of the page cache. 
pub fn change_page_cache_size(&self, capacity: usize) { - let mut page_cache = self.page_cache.write().unwrap(); + let mut page_cache = self.page_cache.write(); page_cache.resize(capacity); } @@ -301,13 +309,14 @@ impl Pager { } pub fn cacheflush(&self) -> Result { + let mut checkpoint_result = CheckpointResult::new(); loop { let state = self.flush_info.borrow().state.clone(); match state { FlushState::Start => { let db_size = self.db_header.borrow().database_size; for page_id in self.dirty_pages.borrow().iter() { - let mut cache = self.page_cache.write().unwrap(); + let mut cache = self.page_cache.write(); let page_key = PageCacheKey::new(*page_id, Some(self.wal.borrow().get_max_frame())); let page = cache.get(&page_key).expect("we somehow added a page to dirty list but we didn't mark it as dirty, causing cache to drop it."); @@ -334,7 +343,7 @@ impl Pager { FlushState::SyncWal => { match self.wal.borrow_mut().sync() { Ok(CheckpointStatus::IO) => return Ok(CheckpointStatus::IO), - Ok(CheckpointStatus::Done) => {} + Ok(CheckpointStatus::Done(res)) => checkpoint_result = res, Err(e) => return Err(e), } @@ -348,7 +357,8 @@ impl Pager { } FlushState::Checkpoint => { match self.checkpoint()? { - CheckpointStatus::Done => { + CheckpointStatus::Done(res) => { + checkpoint_result = res; self.flush_info.borrow_mut().state = FlushState::SyncDbFile; } CheckpointStatus::IO => return Ok(CheckpointStatus::IO), @@ -368,10 +378,11 @@ impl Pager { } } } - Ok(CheckpointStatus::Done) + Ok(CheckpointStatus::Done(checkpoint_result)) } pub fn checkpoint(&self) -> Result { + let mut checkpoint_result = CheckpointResult::new(); loop { let state = self.checkpoint_state.borrow().clone(); trace!("pager_checkpoint(state={:?})", state); @@ -384,7 +395,8 @@ impl Pager { CheckpointMode::Passive, )? 
{ CheckpointStatus::IO => return Ok(CheckpointStatus::IO), - CheckpointStatus::Done => { + CheckpointStatus::Done(res) => { + checkpoint_result = res; self.checkpoint_state.replace(CheckpointState::SyncDbFile); } }; @@ -408,7 +420,7 @@ impl Pager { Ok(CheckpointStatus::IO) } else { self.checkpoint_state.replace(CheckpointState::Checkpoint); - Ok(CheckpointStatus::Done) + Ok(CheckpointStatus::Done(checkpoint_result)) }; } } @@ -416,7 +428,8 @@ impl Pager { } // WARN: used for testing purposes - pub fn clear_page_cache(&self) { + pub fn clear_page_cache(&self) -> CheckpointResult { + let checkpoint_result: CheckpointResult; loop { match self.wal.borrow_mut().checkpoint( self, @@ -426,14 +439,16 @@ impl Pager { Ok(CheckpointStatus::IO) => { let _ = self.io.run_once(); } - Ok(CheckpointStatus::Done) => { + Ok(CheckpointStatus::Done(res)) => { + checkpoint_result = res; break; } Err(err) => panic!("error while clearing cache {}", err), } } // TODO: only clear cache of things that are really invalidated - self.page_cache.write().unwrap().clear(); + self.page_cache.write().clear(); + checkpoint_result } /* @@ -468,7 +483,7 @@ impl Pager { // setup page and add to cache page.set_dirty(); self.add_dirty(page.get().id); - let mut cache = self.page_cache.write().unwrap(); + let mut cache = self.page_cache.write(); let page_key = PageCacheKey::new(page.get().id, Some(self.wal.borrow().get_max_frame())); cache.insert(page_key, page.clone()); @@ -477,7 +492,7 @@ impl Pager { } pub fn put_loaded_page(&self, id: usize, page: PageRef) { - let mut cache = self.page_cache.write().unwrap(); + let mut cache = self.page_cache.write(); // cache insert invalidates previous page let page_key = PageCacheKey::new(id, Some(self.wal.borrow().get_max_frame())); cache.insert(page_key, page.clone()); @@ -511,7 +526,9 @@ pub fn allocate_page(page_id: usize, buffer_pool: &Rc, offset: usize #[cfg(test)] mod tests { - use std::sync::{Arc, RwLock}; + use std::sync::Arc; + + use parking_lot::RwLock; 
use crate::storage::page_cache::{DumbLruPageCache, PageCacheKey}; @@ -525,13 +542,13 @@ mod tests { let thread = { let cache = cache.clone(); std::thread::spawn(move || { - let mut cache = cache.write().unwrap(); + let mut cache = cache.write(); let page_key = PageCacheKey::new(1, None); cache.insert(page_key, Arc::new(Page::new(1))); }) }; let _ = thread.join(); - let mut cache = cache.write().unwrap(); + let mut cache = cache.write(); let page_key = PageCacheKey::new(1, None); let page = cache.get(&page_key); assert_eq!(page.unwrap().get().id, 1); diff --git a/core/storage/sqlite3_ondisk.rs b/core/storage/sqlite3_ondisk.rs index 5543fa7db..e86d89520 100644 --- a/core/storage/sqlite3_ondisk.rs +++ b/core/storage/sqlite3_ondisk.rs @@ -49,10 +49,11 @@ use crate::storage::pager::Pager; use crate::types::{OwnedRecord, OwnedValue}; use crate::{File, Result}; use log::trace; +use parking_lot::RwLock; use std::cell::RefCell; use std::pin::Pin; use std::rc::Rc; -use std::sync::{Arc, RwLock}; +use std::sync::Arc; use super::pager::PageRef; @@ -1147,7 +1148,7 @@ pub fn begin_read_wal_header(io: &Rc) -> Result> fn finish_read_wal_header(buf: Rc>, header: Arc>) -> Result<()> { let buf = buf.borrow(); let buf = buf.as_slice(); - let mut header = header.write().unwrap(); + let mut header = header.write(); header.magic = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]); header.file_format = u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]); header.page_size = u32::from_be_bytes([buf[8], buf[9], buf[10], buf[11]]); diff --git a/core/storage/wal.rs b/core/storage/wal.rs index 0de3b7590..e42fde2b6 100644 --- a/core/storage/wal.rs +++ b/core/storage/wal.rs @@ -1,9 +1,10 @@ -use std::collections::HashMap; -use std::sync::atomic::{AtomicU32, Ordering}; -use std::sync::RwLock; -use std::{cell::RefCell, rc::Rc, sync::Arc}; - use log::{debug, trace}; +use std::collections::HashMap; + +use parking_lot::RwLock; +use std::fmt::Formatter; +use std::sync::atomic::{AtomicU32, Ordering}; 
+use std::{cell::RefCell, fmt, rc::Rc, sync::Arc}; use crate::io::{File, SyncCompletion, IO}; use crate::result::LimboResult; @@ -25,6 +26,23 @@ pub const NO_LOCK: u32 = 0; pub const SHARED_LOCK: u32 = 1; pub const WRITE_LOCK: u32 = 2; +#[derive(Debug)] +pub struct CheckpointResult { + /// number of frames in WAL + pub num_wal_frames: u64, + /// number of frames moved successfully from WAL to db file after checkpoint + pub num_checkpointed_frames: u64, +} + +impl CheckpointResult { + pub fn new() -> Self { + Self { + num_wal_frames: 0, + num_checkpointed_frames: 0, + } + } +} + #[derive(Debug)] pub enum CheckpointMode { Passive, @@ -159,7 +177,7 @@ pub trait Wal { // Syncing requires a state machine because we need to schedule a sync and then wait until it is // finished. If we don't wait there will be undefined behaviour that no one wants to debug. -#[derive(Copy, Clone)] +#[derive(Copy, Clone, Debug)] enum SyncState { NotSyncing, Syncing, @@ -176,7 +194,7 @@ pub enum CheckpointState { } pub enum CheckpointStatus { - Done, + Done(CheckpointResult), IO, } @@ -196,6 +214,17 @@ struct OngoingCheckpoint { current_page: u64, } +impl fmt::Debug for OngoingCheckpoint { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("OngoingCheckpoint") + .field("state", &self.state) + .field("min_frame", &self.min_frame) + .field("max_frame", &self.max_frame) + .field("current_page", &self.current_page) + .finish() + } +} + #[allow(dead_code)] pub struct WalFile { io: Arc, @@ -218,6 +247,23 @@ pub struct WalFile { min_frame: u64, } +impl fmt::Debug for WalFile { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("WalFile") + .field("sync_state", &self.sync_state) + .field("syncing", &self.syncing) + .field("page_size", &self.page_size) + .field("shared", &self.shared) + .field("ongoing_checkpoint", &self.ongoing_checkpoint) + .field("checkpoint_threshold", &self.checkpoint_threshold) + .field("max_frame_read_lock_index", 
&self.max_frame_read_lock_index) + .field("max_frame", &self.max_frame) + .field("min_frame", &self.min_frame) + // Excluding other fields + .finish() + } +} + // TODO(pere): lock only important parts + pin WalFileShared /// WalFileShared is the part of a WAL that will be shared between threads. A wal has information /// that needs to be communicated between threads so this struct does the job. @@ -248,10 +294,25 @@ pub struct WalFileShared { write_lock: LimboRwLock, } +impl fmt::Debug for WalFileShared { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("WalFileShared") + .field("wal_header", &self.wal_header) + .field("min_frame", &self.min_frame) + .field("max_frame", &self.max_frame) + .field("nbackfills", &self.nbackfills) + .field("frame_cache", &self.frame_cache) + .field("pages_in_frames", &self.pages_in_frames) + .field("last_checksum", &self.last_checksum) + // Excluding `file`, `read_locks`, and `write_lock` + .finish() + } +} + impl Wal for WalFile { /// Begin a read transaction. fn begin_read_tx(&mut self) -> Result { - let mut shared = self.shared.write().unwrap(); + let mut shared = self.shared.write(); let max_frame_in_wal = shared.max_frame; self.min_frame = shared.nbackfills + 1; @@ -305,7 +366,7 @@ impl Wal for WalFile { /// End a read transaction. 
fn end_read_tx(&self) -> Result { - let mut shared = self.shared.write().unwrap(); + let mut shared = self.shared.write(); let read_lock = &mut shared.read_locks[self.max_frame_read_lock_index]; read_lock.unlock(); Ok(LimboResult::Ok) @@ -313,7 +374,7 @@ impl Wal for WalFile { /// Begin a write transaction fn begin_write_tx(&mut self) -> Result { - let mut shared = self.shared.write().unwrap(); + let mut shared = self.shared.write(); let busy = !shared.write_lock.write(); if busy { return Ok(LimboResult::Busy); @@ -323,14 +384,14 @@ impl Wal for WalFile { /// End a write transaction fn end_write_tx(&self) -> Result { - let mut shared = self.shared.write().unwrap(); + let mut shared = self.shared.write(); shared.write_lock.unlock(); Ok(LimboResult::Ok) } /// Find the latest frame containing a page. fn find_frame(&self, page_id: u64) -> Result> { - let shared = self.shared.read().unwrap(); + let shared = self.shared.read(); let frames = shared.frame_cache.get(&page_id); if frames.is_none() { return Ok(None); @@ -348,7 +409,7 @@ impl Wal for WalFile { fn read_frame(&self, frame_id: u64, page: PageRef, buffer_pool: Rc) -> Result<()> { debug!("read_frame({})", frame_id); let offset = self.frame_offset(frame_id); - let shared = self.shared.read().unwrap(); + let shared = self.shared.read(); page.set_locked(); begin_read_wal_frame( &shared.file, @@ -367,7 +428,7 @@ impl Wal for WalFile { write_counter: Rc>, ) -> Result<()> { let page_id = page.get().id; - let mut shared = self.shared.write().unwrap(); + let mut shared = self.shared.write(); let frame_id = if shared.max_frame == 0 { 1 } else { @@ -381,7 +442,7 @@ impl Wal for WalFile { page_id ); let header = shared.wal_header.clone(); - let header = header.read().unwrap(); + let header = header.read(); let checksums = shared.last_checksum; let checksums = begin_write_wal_frame( &shared.file, @@ -408,7 +469,7 @@ impl Wal for WalFile { } fn should_checkpoint(&self) -> bool { - let shared = self.shared.read().unwrap(); + let 
shared = self.shared.read(); let frame_id = shared.max_frame as usize; frame_id >= self.checkpoint_threshold } @@ -430,7 +491,7 @@ impl Wal for WalFile { CheckpointState::Start => { // TODO(pere): check what frames are safe to checkpoint between many readers! self.ongoing_checkpoint.min_frame = self.min_frame; - let mut shared = self.shared.write().unwrap(); + let mut shared = self.shared.write(); let max_frame_in_wal = shared.max_frame as u32; let mut max_safe_frame = shared.max_frame; for read_lock in shared.read_locks.iter_mut() { @@ -455,7 +516,7 @@ impl Wal for WalFile { ); } CheckpointState::ReadFrame => { - let shared = self.shared.read().unwrap(); + let shared = self.shared.read(); assert!( self.ongoing_checkpoint.current_page as usize <= shared.pages_in_frames.len() @@ -514,7 +575,7 @@ impl Wal for WalFile { if *write_counter.borrow() > 0 { return Ok(CheckpointStatus::IO); } - let shared = self.shared.read().unwrap(); + let shared = self.shared.read(); if (self.ongoing_checkpoint.current_page as usize) < shared.pages_in_frames.len() { @@ -527,7 +588,14 @@ impl Wal for WalFile { if *write_counter.borrow() > 0 { return Ok(CheckpointStatus::IO); } - let mut shared = self.shared.write().unwrap(); + let mut shared = self.shared.write(); + + // Record two num pages fields to return as checkpoint result to caller. 
+ // Ref: pnLog, pnCkpt on https://www.sqlite.org/c3ref/wal_checkpoint_v2.html + let checkpoint_result = CheckpointResult { + num_wal_frames: shared.max_frame, + num_checkpointed_frames: self.ongoing_checkpoint.max_frame, + }; let everything_backfilled = shared.max_frame == self.ongoing_checkpoint.max_frame; if everything_backfilled { @@ -541,7 +609,7 @@ impl Wal for WalFile { shared.nbackfills = self.ongoing_checkpoint.max_frame; } self.ongoing_checkpoint.state = CheckpointState::Start; - return Ok(CheckpointStatus::Done); + return Ok(CheckpointStatus::Done(checkpoint_result)); } } } @@ -551,7 +619,7 @@ impl Wal for WalFile { let state = *self.sync_state.borrow(); match state { SyncState::NotSyncing => { - let shared = self.shared.write().unwrap(); + let shared = self.shared.write(); debug!("wal_sync"); { let syncing = self.syncing.clone(); @@ -572,7 +640,11 @@ impl Wal for WalFile { Ok(CheckpointStatus::IO) } else { self.sync_state.replace(SyncState::NotSyncing); - Ok(CheckpointStatus::Done) + let checkpoint_result = CheckpointResult { + num_wal_frames: self.max_frame, + num_checkpointed_frames: self.ongoing_checkpoint.max_frame, + }; + Ok(CheckpointStatus::Done(checkpoint_result)) } } } @@ -685,7 +757,7 @@ impl WalFileShared { Arc::new(RwLock::new(wal_header)) }; let checksum = { - let checksum = header.read().unwrap(); + let checksum = header.read(); (checksum.checksum_1, checksum.checksum_2) }; let shared = WalFileShared { diff --git a/core/translate/aggregation.rs b/core/translate/aggregation.rs index d23caf3ec..7eec4531f 100644 --- a/core/translate/aggregation.rs +++ b/core/translate/aggregation.rs @@ -41,7 +41,7 @@ pub fn emit_ungrouped_aggregation<'a>( // This always emits a ResultRow because currently it can only be used for a single row result // Limit is None because we early exit on limit 0 and the max rows here is 1 - emit_select_result(program, t_ctx, plan, None)?; + emit_select_result(program, t_ctx, plan, None, None)?; Ok(()) } diff --git 
a/core/translate/delete.rs b/core/translate/delete.rs index 373a74024..6af78cdb7 100644 --- a/core/translate/delete.rs +++ b/core/translate/delete.rs @@ -1,13 +1,13 @@ use crate::schema::Table; use crate::translate::emitter::emit_program; use crate::translate::optimizer::optimize_plan; -use crate::translate::plan::{DeletePlan, Plan, SourceOperator}; +use crate::translate::plan::{DeletePlan, Operation, Plan}; use crate::translate::planner::{parse_limit, parse_where}; use crate::vdbe::builder::ProgramBuilder; use crate::{schema::Schema, Result, SymbolTable}; use sqlite3_parser::ast::{Expr, Limit, QualifiedName}; -use super::plan::{TableReference, TableReferenceType}; +use super::plan::TableReference; pub fn translate_delete( program: &mut ProgramBuilder, @@ -18,7 +18,7 @@ pub fn translate_delete( syms: &SymbolTable, ) -> Result<()> { let mut delete_plan = prepare_delete_plan(schema, tbl_name, where_clause, limit)?; - optimize_plan(&mut delete_plan)?; + optimize_plan(&mut delete_plan, schema)?; emit_program(program, delete_plan, syms) } @@ -33,33 +33,28 @@ pub fn prepare_delete_plan( None => crate::bail_corrupt_error!("Parse error: no such table: {}", tbl_name), }; - let btree_table_ref = TableReference { + let table_references = vec![TableReference { table: Table::BTree(table.clone()), - table_identifier: table.name.clone(), - table_index: 0, - reference_type: TableReferenceType::BTreeTable, - }; - let referenced_tables = vec![btree_table_ref.clone()]; + identifier: table.name.clone(), + op: Operation::Scan { iter_dir: None }, + join_info: None, + }]; + + let mut where_predicates = vec![]; // Parse the WHERE clause - let resolved_where_clauses = parse_where(where_clause, &referenced_tables)?; + parse_where(where_clause, &table_references, None, &mut where_predicates)?; - // Parse the LIMIT clause - let resolved_limit = limit.and_then(|l| parse_limit(*l)); + // Parse the LIMIT/OFFSET clause + let (resolved_limit, resolved_offset) = limit.map_or(Ok((None, None)), |l| 
parse_limit(*l))?; let plan = DeletePlan { - source: SourceOperator::Scan { - id: 0, - table_reference: btree_table_ref, - predicates: resolved_where_clauses.clone(), - iter_dir: None, - }, + table_references, result_columns: vec![], - where_clause: resolved_where_clauses, + where_clause: where_predicates, order_by: None, limit: resolved_limit, - referenced_tables, - available_indexes: vec![], + offset: resolved_offset, contains_constant_false_condition: false, }; diff --git a/core/translate/emitter.rs b/core/translate/emitter.rs index 939a287f0..ef559a16c 100644 --- a/core/translate/emitter.rs +++ b/core/translate/emitter.rs @@ -1,8 +1,6 @@ // This module contains code for emitting bytecode instructions for SQL query execution. // It handles translating high-level SQL operations into low-level bytecode that can be executed by the virtual machine. -use std::collections::HashMap; - use sqlite3_parser::ast::{self}; use crate::function::Func; @@ -16,8 +14,8 @@ use super::aggregation::emit_ungrouped_aggregation; use super::group_by::{emit_group_by, init_group_by, GroupByMetadata}; use super::main_loop::{close_loop, emit_loop, init_loop, open_loop, LeftJoinMetadata, LoopLabels}; use super::order_by::{emit_order_by, init_order_by, SortMetadata}; -use super::plan::SelectPlan; -use super::plan::SourceOperator; +use super::plan::Operation; +use super::plan::{SelectPlan, TableReference}; use super::subquery::emit_subqueries; #[derive(Debug)] @@ -58,7 +56,7 @@ impl<'a> Resolver<'a> { #[derive(Debug)] pub struct TranslateCtx<'a> { // A typical query plan is a nested loop. Each loop has its own LoopLabels (see the definition of LoopLabels for more details) - pub labels_main_loop: HashMap, + pub labels_main_loop: Vec, // label for the instruction that jumps to the next phase of the query after the main loop // we don't know ahead of time what that is (GROUP BY, ORDER BY, etc.) 
pub label_main_loop_end: Option, @@ -68,15 +66,20 @@ pub struct TranslateCtx<'a> { pub reg_result_cols_start: Option, // The register holding the limit value, if any. pub reg_limit: Option, + // The register holding the offset value, if any. + pub reg_offset: Option, + // The register holding the limit+offset value, if any. + pub reg_limit_offset_sum: Option, // metadata for the group by operator pub meta_group_by: Option, // metadata for the order by operator pub meta_sort: Option, - // mapping between Join operator id and associated metadata (for left joins only) - pub meta_left_joins: HashMap, + /// mapping between table loop index and associated metadata (for left joins only) + /// this metadata exists for the right table in a given left join + pub meta_left_joins: Vec>, // We need to emit result columns in the order they are present in the SELECT, but they may not be in the same order in the ORDER BY sorter. // This vector holds the indexes of the result columns in the ORDER BY sorter. - pub result_column_indexes_in_orderby_sorter: HashMap, + pub result_column_indexes_in_orderby_sorter: Vec, // We might skip adding a SELECT result column into the ORDER BY sorter if it is an exact match in the ORDER BY keys. // This vector holds the indexes of the result columns that we need to skip. 
pub result_columns_to_skip_in_orderby_sorter: Option>, @@ -97,6 +100,8 @@ pub enum OperationMode { fn prologue<'a>( program: &mut ProgramBuilder, syms: &'a SymbolTable, + table_count: usize, + result_column_count: usize, ) -> Result<(TranslateCtx<'a>, BranchOffset, BranchOffset)> { let init_label = program.allocate_label(); @@ -107,15 +112,17 @@ fn prologue<'a>( let start_offset = program.offset(); let t_ctx = TranslateCtx { - labels_main_loop: HashMap::new(), + labels_main_loop: (0..table_count).map(|_| LoopLabels::new(program)).collect(), label_main_loop_end: None, reg_agg_start: None, reg_limit: None, + reg_offset: None, + reg_limit_offset_sum: None, reg_result_cols_start: None, meta_group_by: None, - meta_left_joins: HashMap::new(), + meta_left_joins: (0..table_count).map(|_| None).collect(), meta_sort: None, - result_column_indexes_in_orderby_sorter: HashMap::new(), + result_column_indexes_in_orderby_sorter: (0..result_column_count).collect(), result_columns_to_skip_in_orderby_sorter: None, resolver: Resolver::new(syms), }; @@ -161,7 +168,12 @@ fn emit_program_for_select( mut plan: SelectPlan, syms: &SymbolTable, ) -> Result<()> { - let (mut t_ctx, init_label, start_offset) = prologue(program, syms)?; + let (mut t_ctx, init_label, start_offset) = prologue( + program, + syms, + plan.table_references.len(), + plan.result_columns.len(), + )?; // Trivial exit on LIMIT 0 if let Some(limit) = plan.limit { @@ -189,17 +201,20 @@ pub fn emit_query<'a>( t_ctx: &'a mut TranslateCtx<'a>, ) -> Result { // Emit subqueries first so the results can be read in the main query loop. 
- emit_subqueries( - program, - t_ctx, - &mut plan.referenced_tables, - &mut plan.source, - )?; + emit_subqueries(program, t_ctx, &mut plan.table_references)?; if t_ctx.reg_limit.is_none() { t_ctx.reg_limit = plan.limit.map(|_| program.alloc_register()); } + if t_ctx.reg_offset.is_none() { + t_ctx.reg_offset = plan.offset.map(|_| program.alloc_register()); + } + + if t_ctx.reg_limit_offset_sum.is_none() { + t_ctx.reg_limit_offset_sum = plan.offset.map(|_| program.alloc_register()); + } + // No rows will be read from source table loops if there is a constant false condition eg. WHERE 0 // however an aggregation might still happen, // e.g. SELECT COUNT(*) WHERE 0 returns a row with 0, not an empty result set @@ -222,16 +237,21 @@ pub fn emit_query<'a>( if let Some(ref mut group_by) = plan.group_by { init_group_by(program, t_ctx, group_by, &plan.aggregates)?; } - init_loop(program, t_ctx, &plan.source, &OperationMode::SELECT)?; + init_loop( + program, + t_ctx, + &plan.table_references, + &OperationMode::SELECT, + )?; // Set up main query execution loop - open_loop(program, t_ctx, &mut plan.source, &plan.referenced_tables)?; + open_loop(program, t_ctx, &plan.table_references, &plan.where_clause)?; // Process result columns and expressions in the inner loop emit_loop(program, t_ctx, plan)?; // Clean up and close the main execution loop - close_loop(program, t_ctx, &plan.source)?; + close_loop(program, t_ctx, &plan.table_references)?; program.resolve_label(after_main_loop_label, program.offset()); @@ -260,7 +280,12 @@ fn emit_program_for_delete( mut plan: DeletePlan, syms: &SymbolTable, ) -> Result<()> { - let (mut t_ctx, init_label, start_offset) = prologue(program, syms)?; + let (mut t_ctx, init_label, start_offset) = prologue( + program, + syms, + plan.table_references.len(), + plan.result_columns.len(), + )?; // No rows will be read from source table loops if there is a constant false condition eg. 
WHERE 0 let after_main_loop_label = program.allocate_label(); @@ -271,20 +296,25 @@ fn emit_program_for_delete( } // Initialize cursors and other resources needed for query execution - init_loop(program, &mut t_ctx, &plan.source, &OperationMode::DELETE)?; + init_loop( + program, + &mut t_ctx, + &plan.table_references, + &OperationMode::DELETE, + )?; // Set up main query execution loop open_loop( program, &mut t_ctx, - &mut plan.source, - &plan.referenced_tables, + &mut plan.table_references, + &plan.where_clause, )?; - emit_delete_insns(program, &mut t_ctx, &plan.source, &plan.limit)?; + emit_delete_insns(program, &mut t_ctx, &plan.table_references, &plan.limit)?; // Clean up and close the main execution loop - close_loop(program, &mut t_ctx, &plan.source)?; + close_loop(program, &mut t_ctx, &plan.table_references)?; program.resolve_label(after_main_loop_label, program.offset()); @@ -301,20 +331,15 @@ fn emit_program_for_delete( fn emit_delete_insns( program: &mut ProgramBuilder, t_ctx: &mut TranslateCtx, - source: &SourceOperator, - limit: &Option, + table_references: &[TableReference], + limit: &Option, ) -> Result<()> { - let cursor_id = match source { - SourceOperator::Scan { - table_reference, .. - } => program.resolve_cursor_id(&table_reference.table_identifier), - SourceOperator::Search { - table_reference, - search, - .. - } => match search { + let table_reference = table_references.first().unwrap(); + let cursor_id = match &table_reference.op { + Operation::Scan { .. } => program.resolve_cursor_id(&table_reference.identifier), + Operation::Search(search) => match search { Search::RowidEq { .. } | Search::RowidSearch { .. } => { - program.resolve_cursor_id(&table_reference.table_identifier) + program.resolve_cursor_id(&table_reference.identifier) } Search::IndexSearch { index, .. 
} => program.resolve_cursor_id(&index.name), }, diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 9fdf5f9c2..c765aa6b7 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -5,11 +5,15 @@ use crate::function::JsonFunc; use crate::function::{Func, FuncCtx, MathFuncArity, ScalarFunc}; use crate::schema::Type; use crate::util::normalize_ident; -use crate::vdbe::{builder::ProgramBuilder, insn::Insn, BranchOffset}; +use crate::vdbe::{ + builder::ProgramBuilder, + insn::{CmpInsFlags, Insn}, + BranchOffset, +}; use crate::Result; use super::emitter::Resolver; -use super::plan::{TableReference, TableReferenceType}; +use super::plan::{Operation, TableReference}; #[derive(Debug, Clone, Copy)] pub struct ConditionMetadata { @@ -47,14 +51,41 @@ macro_rules! emit_cmp_insn { lhs: $lhs, rhs: $rhs, target_pc: $cond.jump_target_when_true, - jump_if_null: false, + flags: CmpInsFlags::default(), }); } else { $program.emit_insn(Insn::$op_false { lhs: $lhs, rhs: $rhs, target_pc: $cond.jump_target_when_false, - jump_if_null: true, + flags: CmpInsFlags::default().jump_if_null(), + }); + } + }}; +} + +macro_rules! 
emit_cmp_null_insn { + ( + $program:expr, + $cond:expr, + $op_true:ident, + $op_false:ident, + $lhs:expr, + $rhs:expr + ) => {{ + if $cond.jump_if_condition_is_true { + $program.emit_insn(Insn::$op_true { + lhs: $lhs, + rhs: $rhs, + target_pc: $cond.jump_target_when_true, + flags: CmpInsFlags::default().null_eq(), + }); + } else { + $program.emit_insn(Insn::$op_false { + lhs: $lhs, + rhs: $rhs, + target_pc: $cond.jump_target_when_false, + flags: CmpInsFlags::default().null_eq(), }); } }}; @@ -226,8 +257,12 @@ pub fn translate_condition_expr( ast::Operator::NotEquals => { emit_cmp_insn!(program, condition_metadata, Ne, Eq, lhs_reg, rhs_reg) } - ast::Operator::Is => todo!(), - ast::Operator::IsNot => todo!(), + ast::Operator::Is => { + emit_cmp_null_insn!(program, condition_metadata, Eq, Ne, lhs_reg, rhs_reg) + } + ast::Operator::IsNot => { + emit_cmp_null_insn!(program, condition_metadata, Ne, Eq, lhs_reg, rhs_reg) + } _ => { todo!("op {:?} not implemented", op); } @@ -326,7 +361,7 @@ pub fn translate_condition_expr( lhs: lhs_reg, rhs: rhs_reg, target_pc: jump_target_when_true, - jump_if_null: false, + flags: CmpInsFlags::default(), }); } else { // If this is the last condition, we need to jump to the 'jump_target_when_false' label if there is no match. @@ -334,7 +369,7 @@ pub fn translate_condition_expr( lhs: lhs_reg, rhs: rhs_reg, target_pc: condition_metadata.jump_target_when_false, - jump_if_null: true, + flags: CmpInsFlags::default().jump_if_null(), }); } } @@ -355,7 +390,7 @@ pub fn translate_condition_expr( lhs: lhs_reg, rhs: rhs_reg, target_pc: condition_metadata.jump_target_when_false, - jump_if_null: true, + flags: CmpInsFlags::default().jump_if_null(), }); } // If we got here, then none of the conditions were a match, so we jump to the 'jump_target_when_true' label if 'jump_if_condition_is_true'. 
@@ -444,6 +479,22 @@ pub fn translate_condition_expr( ); } } + ast::Expr::NotNull(expr) => { + let cur_reg = program.alloc_register(); + translate_expr(program, Some(referenced_tables), expr, cur_reg, resolver)?; + program.emit_insn(Insn::IsNull { + reg: cur_reg, + target_pc: condition_metadata.jump_target_when_false, + }); + } + ast::Expr::IsNull(expr) => { + let cur_reg = program.alloc_register(); + translate_expr(program, Some(referenced_tables), expr, cur_reg, resolver)?; + program.emit_insn(Insn::NotNull { + reg: cur_reg, + target_pc: condition_metadata.jump_target_when_false, + }); + } _ => todo!("op {:?} not implemented", expr), } Ok(()) @@ -482,7 +533,7 @@ pub fn translate_expr( lhs: e1_reg, rhs: e2_reg, target_pc: if_true_label, - jump_if_null: false, + flags: CmpInsFlags::default(), }, target_register, if_true_label, @@ -498,7 +549,7 @@ pub fn translate_expr( lhs: e1_reg, rhs: e2_reg, target_pc: if_true_label, - jump_if_null: false, + flags: CmpInsFlags::default(), }, target_register, if_true_label, @@ -514,7 +565,7 @@ pub fn translate_expr( lhs: e1_reg, rhs: e2_reg, target_pc: if_true_label, - jump_if_null: false, + flags: CmpInsFlags::default(), }, target_register, if_true_label, @@ -530,7 +581,7 @@ pub fn translate_expr( lhs: e1_reg, rhs: e2_reg, target_pc: if_true_label, - jump_if_null: false, + flags: CmpInsFlags::default(), }, target_register, if_true_label, @@ -546,7 +597,7 @@ pub fn translate_expr( lhs: e1_reg, rhs: e2_reg, target_pc: if_true_label, - jump_if_null: false, + flags: CmpInsFlags::default(), }, target_register, if_true_label, @@ -562,7 +613,7 @@ pub fn translate_expr( lhs: e1_reg, rhs: e2_reg, target_pc: if_true_label, - jump_if_null: false, + flags: CmpInsFlags::default(), }, target_register, if_true_label, @@ -655,7 +706,7 @@ pub fn translate_expr( lhs: e1_reg, rhs: e2_reg, target_pc: if_true_label, - jump_if_null: false, + flags: CmpInsFlags::default().null_eq(), }, target_register, if_true_label, @@ -669,7 +720,7 @@ pub fn 
translate_expr( lhs: e1_reg, rhs: e2_reg, target_pc: if_true_label, - jump_if_null: false, + flags: CmpInsFlags::default().null_eq(), }, target_register, if_true_label, @@ -740,7 +791,7 @@ pub fn translate_expr( lhs: base_reg, rhs: expr_reg, target_pc: next_case_label, - jump_if_null: false, + flags: CmpInsFlags::default(), }), // CASE WHEN 0 THEN 0 ELSE 1 becomes ifnot 0 branch to next clause None => program.emit_insn(Insn::IfNot { @@ -867,14 +918,16 @@ pub fn translate_expr( func_ctx, ) } - JsonFunc::JsonArray | JsonFunc::JsonExtract => translate_function( - program, - args.as_deref().unwrap_or_default(), - referenced_tables, - resolver, - target_register, - func_ctx, - ), + JsonFunc::JsonArray | JsonFunc::JsonExtract | JsonFunc::JsonSet => { + translate_function( + program, + args.as_deref().unwrap_or_default(), + referenced_tables, + resolver, + target_register, + func_ctx, + ) + } JsonFunc::JsonArrowExtract | JsonFunc::JsonArrowShiftExtract => { unreachable!( "These two functions are only reachable via the -> and ->> operators" @@ -937,6 +990,45 @@ pub fn translate_expr( target_register, func_ctx, ), + JsonFunc::JsonPatch => { + let args = expect_arguments_exact!(args, 2, j); + translate_function( + program, + args, + referenced_tables, + resolver, + target_register, + func_ctx, + ) + } + JsonFunc::JsonRemove => { + if let Some(args) = args { + for arg in args.iter() { + // register containing result of each argument expression + let _ = + translate_and_mark(program, referenced_tables, arg, resolver)?; + } + } + program.emit_insn(Insn::Function { + constant_mask: 0, + start_reg: target_register + 1, + dest: target_register, + func: func_ctx, + }); + Ok(target_register) + } + JsonFunc::JsonPretty => { + let args = expect_arguments_max!(args, 2, j); + + translate_function( + program, + args, + referenced_tables, + resolver, + target_register, + func_ctx, + ) + } }, Func::Scalar(srf) => { match srf { @@ -1068,9 +1160,10 @@ pub fn translate_expr( temp_reg, 
resolver, )?; + let before_copy_label = program.allocate_label(); program.emit_insn(Insn::NotNull { reg: temp_reg, - target_pc: program.offset().add(2u32), + target_pc: before_copy_label, }); translate_expr( @@ -1080,6 +1173,7 @@ pub fn translate_expr( temp_reg, resolver, )?; + program.resolve_label(before_copy_label, program.offset()); program.emit_insn(Insn::Copy { src_reg: temp_reg, dst_reg: target_register, @@ -1147,15 +1241,21 @@ pub fn translate_expr( srf.to_string() ); }; - for arg in args { - let _ = - translate_and_mark(program, referenced_tables, arg, resolver); + let func_registers = program.alloc_registers(args.len()); + for (i, arg) in args.iter().enumerate() { + let _ = translate_expr( + program, + referenced_tables, + arg, + func_registers + i, + resolver, + )?; } program.emit_insn(Insn::Function { // Only constant patterns for LIKE are supported currently, so this // is always 1 constant_mask: 1, - start_reg: target_register + 1, + start_reg: func_registers, dest: target_register, func: func_ctx, }); @@ -1492,6 +1592,28 @@ pub fn translate_expr( }); Ok(target_register) } + ScalarFunc::SqliteSourceId => { + if args.is_some() { + crate::bail_parse_error!( + "sqlite_source_id function with arguments" + ); + } + + let output_register = program.alloc_register(); + program.emit_insn(Insn::Function { + constant_mask: 0, + start_reg: output_register, + dest: output_register, + func: func_ctx, + }); + + program.emit_insn(Insn::Copy { + src_reg: output_register, + dst_reg: target_register, + amount: 0, + }); + Ok(target_register) + } ScalarFunc::Replace => { let args = if let Some(args) = args { if !args.len() == 3 { @@ -1559,6 +1681,14 @@ pub fn translate_expr( }); Ok(target_register) } + ScalarFunc::Printf => translate_function( + program, + args.as_deref().unwrap_or(&[]), + referenced_tables, + resolver, + target_register, + func_ctx, + ), } } Func::Math(math_func) => match math_func.arity() { @@ -1626,19 +1756,24 @@ pub fn translate_expr( } } 
ast::Expr::FunctionCallStar { .. } => todo!(), - ast::Expr::Id(_) => unreachable!("Id should be resolved to a Column before translation"), + ast::Expr::Id(id) => { + crate::bail_parse_error!( + "no such column: {} - should this be a string literal in single-quotes?", + id.0 + ) + } ast::Expr::Column { database: _, table, column, is_rowid_alias, } => { - let tbl_ref = referenced_tables.as_ref().unwrap().get(*table).unwrap(); - match tbl_ref.reference_type { + let table_reference = referenced_tables.as_ref().unwrap().get(*table).unwrap(); + match table_reference.op { // If we are reading a column from a table, we find the cursor that corresponds to // the table and read the column from the cursor. - TableReferenceType::BTreeTable => { - let cursor_id = program.resolve_cursor_id(&tbl_ref.table_identifier); + Operation::Scan { .. } | Operation::Search(_) => { + let cursor_id = program.resolve_cursor_id(&table_reference.identifier); if *is_rowid_alias { program.emit_insn(Insn::RowId { cursor_id, @@ -1651,13 +1786,13 @@ pub fn translate_expr( dest: target_register, }); } - let column = tbl_ref.table.get_column_at(*column); + let column = table_reference.table.get_column_at(*column); maybe_apply_affinity(column.ty, target_register, program); Ok(target_register) } // If we are reading a column from a subquery, we instead copy the column from the // subquery's result registers. - TableReferenceType::Subquery { + Operation::Subquery { result_columns_start_reg, .. 
} => { @@ -1671,8 +1806,8 @@ pub fn translate_expr( } } ast::Expr::RowId { database: _, table } => { - let tbl_ref = referenced_tables.as_ref().unwrap().get(*table).unwrap(); - let cursor_id = program.resolve_cursor_id(&tbl_ref.table_identifier); + let table_reference = referenced_tables.as_ref().unwrap().get(*table).unwrap(); + let cursor_id = program.resolve_cursor_id(&table_reference.identifier); program.emit_insn(Insn::RowId { cursor_id, dest: target_register, @@ -1768,22 +1903,33 @@ pub fn translate_expr( UnaryOperator::Negative | UnaryOperator::Positive, ast::Expr::Literal(ast::Literal::Numeric(numeric_value)), ) => { - let maybe_int = numeric_value.parse::(); let multiplier = if let UnaryOperator::Negative = op { -1 } else { 1 }; - if let Ok(value) = maybe_int { + + // Special case: if we're negating "9223372036854775808", this is exactly MIN_INT64 + // If we don't do this -1 * 9223372036854775808 will overflow and parse will fail and trigger conversion to Real. + if multiplier == -1 && numeric_value == "9223372036854775808" { program.emit_insn(Insn::Integer { - value: value * multiplier, + value: i64::MIN, dest: target_register, }); } else { - program.emit_insn(Insn::Real { - value: multiplier as f64 * numeric_value.parse::()?, - dest: target_register, - }); + let maybe_int = numeric_value.parse::(); + if let Ok(value) = maybe_int { + program.emit_insn(Insn::Integer { + value: value * multiplier, + dest: target_register, + }); + } else { + let value = numeric_value.parse::()?; + program.emit_insn(Insn::Real { + value: value * multiplier as f64, + dest: target_register, + }); + } } Ok(target_register) } @@ -1971,8 +2117,8 @@ pub fn get_name( } match expr { ast::Expr::Column { table, column, .. 
} => { - let table_ref = referenced_tables.get(*table).unwrap(); - table_ref.table.get_column_at(*column).name.clone() + let table_reference = referenced_tables.get(*table).unwrap(); + table_reference.table.get_column_at(*column).name.clone() } _ => fallback(), } diff --git a/core/translate/group_by.rs b/core/translate/group_by.rs index dfb92f20a..b537257a0 100644 --- a/core/translate/group_by.rs +++ b/core/translate/group_by.rs @@ -165,12 +165,16 @@ pub fn emit_group_by<'a>( .map(|agg| agg.args.len()) .sum::(); // sorter column names do not matter + let ty = crate::schema::Type::Null; let pseudo_columns = (0..sorter_column_count) .map(|i| Column { name: i.to_string(), primary_key: false, - ty: crate::schema::Type::Null, + ty, + ty_str: ty.to_string(), is_rowid_alias: false, + notnull: false, + default: None, }) .collect::>(); @@ -270,7 +274,7 @@ pub fn emit_group_by<'a>( let agg_result_reg = start_reg + i; translate_aggregation_step_groupby( program, - &plan.referenced_tables, + &plan.table_references, pseudo_cursor, cursor_index, agg, @@ -380,7 +384,7 @@ pub fn emit_group_by<'a>( for expr in having.iter() { translate_condition_expr( program, - &plan.referenced_tables, + &plan.table_references, expr, ConditionMetadata { jump_if_condition_is_true: false, @@ -394,7 +398,13 @@ pub fn emit_group_by<'a>( match &plan.order_by { None => { - emit_select_result(program, t_ctx, plan, Some(label_group_by_end))?; + emit_select_result( + program, + t_ctx, + plan, + Some(label_group_by_end), + Some(group_by_end_without_emitting_row_label), + )?; } Some(_) => { order_by_sorter_insert(program, t_ctx, plan)?; diff --git a/core/translate/main_loop.rs b/core/translate/main_loop.rs index 9ce449b9e..b4fff9c7a 100644 --- a/core/translate/main_loop.rs +++ b/core/translate/main_loop.rs @@ -4,7 +4,7 @@ use crate::{ translate::result_row::emit_select_result, vdbe::{ builder::{CursorType, ProgramBuilder}, - insn::Insn, + insn::{CmpInsFlags, Insn}, BranchOffset, }, Result, @@ -16,7 +16,8 @@ 
use super::{ expr::{translate_condition_expr, translate_expr, ConditionMetadata}, order_by::{order_by_sorter_insert, sorter_insert}, plan::{ - IterationDirection, Search, SelectPlan, SelectQueryType, SourceOperator, TableReference, + IterationDirection, Operation, Search, SelectPlan, SelectQueryType, TableReference, + WhereTerm, }, }; @@ -42,122 +43,59 @@ pub struct LoopLabels { loop_end: BranchOffset, } +impl LoopLabels { + pub fn new(program: &mut ProgramBuilder) -> Self { + Self { + loop_start: program.allocate_label(), + next: program.allocate_label(), + loop_end: program.allocate_label(), + } + } +} + /// Initialize resources needed for the source operators (tables, joins, etc) pub fn init_loop( program: &mut ProgramBuilder, t_ctx: &mut TranslateCtx, - source: &SourceOperator, + tables: &[TableReference], mode: &OperationMode, ) -> Result<()> { - let operator_id = source.id(); - let loop_labels = LoopLabels { - next: program.allocate_label(), - loop_start: program.allocate_label(), - loop_end: program.allocate_label(), - }; - t_ctx.labels_main_loop.insert(operator_id, loop_labels); - - match source { - SourceOperator::Subquery { .. } => Ok(()), - SourceOperator::Join { - id, - left, - right, - outer, - .. - } => { - if *outer { + assert!( + t_ctx.meta_left_joins.len() == tables.len(), + "meta_left_joins length does not match tables length" + ); + for (table_index, table) in tables.iter().enumerate() { + // Initialize bookkeeping for OUTER JOIN + if let Some(join_info) = table.join_info.as_ref() { + if join_info.outer { let lj_metadata = LeftJoinMetadata { reg_match_flag: program.alloc_register(), label_match_flag_set_true: program.allocate_label(), label_match_flag_check_value: program.allocate_label(), }; - t_ctx.meta_left_joins.insert(*id, lj_metadata); + t_ctx.meta_left_joins[table_index] = Some(lj_metadata); } - init_loop(program, t_ctx, left, mode)?; - init_loop(program, t_ctx, right, mode)?; - - Ok(()) } - SourceOperator::Scan { - table_reference, .. 
- } => { - let cursor_id = program.alloc_cursor_id( - Some(table_reference.table_identifier.clone()), - CursorType::BTreeTable(table_reference.btree().unwrap().clone()), - ); - let root_page = table_reference.table.get_root_page(); - - match mode { - OperationMode::SELECT => { - program.emit_insn(Insn::OpenReadAsync { - cursor_id, - root_page, - }); - program.emit_insn(Insn::OpenReadAwait {}); - } - OperationMode::DELETE => { - program.emit_insn(Insn::OpenWriteAsync { - cursor_id, - root_page, - }); - program.emit_insn(Insn::OpenWriteAwait {}); - } - _ => { - unimplemented!() - } - } - - Ok(()) - } - SourceOperator::Search { - table_reference, - search, - .. - } => { - let table_cursor_id = program.alloc_cursor_id( - Some(table_reference.table_identifier.clone()), - CursorType::BTreeTable(table_reference.btree().unwrap().clone()), - ); - - match mode { - OperationMode::SELECT => { - program.emit_insn(Insn::OpenReadAsync { - cursor_id: table_cursor_id, - root_page: table_reference.table.get_root_page(), - }); - program.emit_insn(Insn::OpenReadAwait {}); - } - OperationMode::DELETE => { - program.emit_insn(Insn::OpenWriteAsync { - cursor_id: table_cursor_id, - root_page: table_reference.table.get_root_page(), - }); - program.emit_insn(Insn::OpenWriteAwait {}); - } - _ => { - unimplemented!() - } - } - - if let Search::IndexSearch { index, .. } = search { - let index_cursor_id = program.alloc_cursor_id( - Some(index.name.clone()), - CursorType::BTreeIndex(index.clone()), + match &table.op { + Operation::Scan { .. 
} => { + let cursor_id = program.alloc_cursor_id( + Some(table.identifier.clone()), + CursorType::BTreeTable(table.btree().unwrap().clone()), ); + let root_page = table.table.get_root_page(); match mode { OperationMode::SELECT => { program.emit_insn(Insn::OpenReadAsync { - cursor_id: index_cursor_id, - root_page: index.root_page, + cursor_id, + root_page, }); - program.emit_insn(Insn::OpenReadAwait); + program.emit_insn(Insn::OpenReadAwait {}); } OperationMode::DELETE => { program.emit_insn(Insn::OpenWriteAsync { - cursor_id: index_cursor_id, - root_page: index.root_page, + cursor_id, + root_page, }); program.emit_insn(Insn::OpenWriteAwait {}); } @@ -166,11 +104,64 @@ pub fn init_loop( } } } + Operation::Search(search) => { + let table_cursor_id = program.alloc_cursor_id( + Some(table.identifier.clone()), + CursorType::BTreeTable(table.btree().unwrap().clone()), + ); - Ok(()) + match mode { + OperationMode::SELECT => { + program.emit_insn(Insn::OpenReadAsync { + cursor_id: table_cursor_id, + root_page: table.table.get_root_page(), + }); + program.emit_insn(Insn::OpenReadAwait {}); + } + OperationMode::DELETE => { + program.emit_insn(Insn::OpenWriteAsync { + cursor_id: table_cursor_id, + root_page: table.table.get_root_page(), + }); + program.emit_insn(Insn::OpenWriteAwait {}); + } + _ => { + unimplemented!() + } + } + + if let Search::IndexSearch { index, .. } = search { + let index_cursor_id = program.alloc_cursor_id( + Some(index.name.clone()), + CursorType::BTreeIndex(index.clone()), + ); + + match mode { + OperationMode::SELECT => { + program.emit_insn(Insn::OpenReadAsync { + cursor_id: index_cursor_id, + root_page: index.root_page, + }); + program.emit_insn(Insn::OpenReadAwait); + } + OperationMode::DELETE => { + program.emit_insn(Insn::OpenWriteAsync { + cursor_id: index_cursor_id, + root_page: index.root_page, + }); + program.emit_insn(Insn::OpenWriteAwait {}); + } + _ => { + unimplemented!() + } + } + } + } + _ => {} } - SourceOperator::Nothing { .. 
} => Ok(()), } + + Ok(()) } /// Set up the main query execution loop @@ -179,52 +170,64 @@ pub fn init_loop( pub fn open_loop( program: &mut ProgramBuilder, t_ctx: &mut TranslateCtx, - source: &mut SourceOperator, - referenced_tables: &[TableReference], + tables: &[TableReference], + predicates: &[WhereTerm], ) -> Result<()> { - match source { - SourceOperator::Subquery { - id, - predicates, - plan, - .. - } => { - let (yield_reg, coroutine_implementation_start) = match &plan.query_type { - SelectQueryType::Subquery { - yield_reg, - coroutine_implementation_start, - } => (*yield_reg, *coroutine_implementation_start), - _ => unreachable!("Subquery operator with non-subquery query type"), - }; - // In case the subquery is an inner loop, it needs to be reinitialized on each iteration of the outer loop. - program.emit_insn(Insn::InitCoroutine { - yield_reg, - jump_on_definition: BranchOffset::Offset(0), - start_offset: coroutine_implementation_start, - }); - let LoopLabels { - loop_start, - loop_end, - next, - } = *t_ctx - .labels_main_loop - .get(id) - .expect("subquery has no loop labels"); - program.resolve_label(loop_start, program.offset()); - // A subquery within the main loop of a parent query has no cursor, so instead of advancing the cursor, - // it emits a Yield which jumps back to the main loop of the subquery itself to retrieve the next row. - // When the subquery coroutine completes, this instruction jumps to the label at the top of the termination_label_stack, - // which in this case is the end of the Yield-Goto loop in the parent query. - program.emit_insn(Insn::Yield { - yield_reg, - end_offset: loop_end, - }); + for (table_index, table) in tables.iter().enumerate() { + let LoopLabels { + loop_start, + loop_end, + next, + } = *t_ctx + .labels_main_loop + .get(table_index) + .expect("table has no loop labels"); - // These are predicates evaluated outside of the subquery, - // so they are translated here. - // E.g. 
SELECT foo FROM (SELECT bar as foo FROM t1) sub WHERE sub.foo > 10 - if let Some(preds) = predicates { - for expr in preds { + // Each OUTER JOIN has a "match flag" that is initially set to false, + // and is set to true when a match is found for the OUTER JOIN. + // This is used to determine whether to emit actual columns or NULLs for the columns of the right table. + if let Some(join_info) = table.join_info.as_ref() { + if join_info.outer { + let lj_meta = t_ctx.meta_left_joins[table_index].as_ref().unwrap(); + program.emit_insn(Insn::Integer { + value: 0, + dest: lj_meta.reg_match_flag, + }); + } + } + + match &table.op { + Operation::Subquery { plan, .. } => { + let (yield_reg, coroutine_implementation_start) = match &plan.query_type { + SelectQueryType::Subquery { + yield_reg, + coroutine_implementation_start, + } => (*yield_reg, *coroutine_implementation_start), + _ => unreachable!("Subquery operator with non-subquery query type"), + }; + // In case the subquery is an inner loop, it needs to be reinitialized on each iteration of the outer loop. + program.emit_insn(Insn::InitCoroutine { + yield_reg, + jump_on_definition: BranchOffset::Offset(0), + start_offset: coroutine_implementation_start, + }); + program.resolve_label(loop_start, program.offset()); + // A subquery within the main loop of a parent query has no cursor, so instead of advancing the cursor, + // it emits a Yield which jumps back to the main loop of the subquery itself to retrieve the next row. + // When the subquery coroutine completes, this instruction jumps to the label at the top of the termination_label_stack, + // which in this case is the end of the Yield-Goto loop in the parent query. + program.emit_insn(Insn::Yield { + yield_reg, + end_offset: loop_end, + }); + + // These are predicates evaluated outside of the subquery, + // so they are translated here. + // E.g. 
SELECT foo FROM (SELECT bar as foo FROM t1) sub WHERE sub.foo > 10 + for cond in predicates + .iter() + .filter(|cond| cond.eval_at_loop == table_index) + { let jump_target_when_true = program.allocate_label(); let condition_metadata = ConditionMetadata { jump_if_condition_is_true: false, @@ -233,325 +236,253 @@ pub fn open_loop( }; translate_condition_expr( program, - referenced_tables, - expr, + tables, + &cond.expr, condition_metadata, &t_ctx.resolver, )?; program.resolve_label(jump_target_when_true, program.offset()); } } + Operation::Scan { iter_dir } => { + let cursor_id = program.resolve_cursor_id(&table.identifier); + if iter_dir + .as_ref() + .is_some_and(|dir| *dir == IterationDirection::Backwards) + { + program.emit_insn(Insn::LastAsync { cursor_id }); + } else { + program.emit_insn(Insn::RewindAsync { cursor_id }); + } + program.emit_insn( + if iter_dir + .as_ref() + .is_some_and(|dir| *dir == IterationDirection::Backwards) + { + Insn::LastAwait { + cursor_id, + pc_if_empty: loop_end, + } + } else { + Insn::RewindAwait { + cursor_id, + pc_if_empty: loop_end, + } + }, + ); + program.resolve_label(loop_start, program.offset()); - Ok(()) - } - SourceOperator::Join { - id, - left, - right, - predicates, - outer, - .. - } => { - open_loop(program, t_ctx, left, referenced_tables)?; - - let LoopLabels { next, .. 
} = *t_ctx - .labels_main_loop - .get(&right.id()) - .expect("right side of join has no loop labels"); - - let mut jump_target_when_false = next; - - if *outer { - let lj_meta = t_ctx.meta_left_joins.get(id).unwrap(); - program.emit_insn(Insn::Integer { - value: 0, - dest: lj_meta.reg_match_flag, - }); - jump_target_when_false = lj_meta.label_match_flag_check_value; - } - - open_loop(program, t_ctx, right, referenced_tables)?; - - if let Some(predicates) = predicates { - let jump_target_when_true = program.allocate_label(); - let condition_metadata = ConditionMetadata { - jump_if_condition_is_true: false, - jump_target_when_true, - jump_target_when_false, - }; - for predicate in predicates.iter() { + for cond in predicates + .iter() + .filter(|cond| cond.eval_at_loop == table_index) + { + let jump_target_when_true = program.allocate_label(); + let condition_metadata = ConditionMetadata { + jump_if_condition_is_true: false, + jump_target_when_true, + jump_target_when_false: next, + }; translate_condition_expr( program, - referenced_tables, - predicate, + tables, + &cond.expr, condition_metadata, &t_ctx.resolver, )?; + program.resolve_label(jump_target_when_true, program.offset()); } - program.resolve_label(jump_target_when_true, program.offset()); } + Operation::Search(search) => { + let table_cursor_id = program.resolve_cursor_id(&table.identifier); + // Open the loop for the index search. + // Rowid equality point lookups are handled with a SeekRowid instruction which does not loop, since it is a single row lookup. + if !matches!(search, Search::RowidEq { .. }) { + let index_cursor_id = if let Search::IndexSearch { index, .. } = search { + Some(program.resolve_cursor_id(&index.name)) + } else { + None + }; + let cmp_reg = program.alloc_register(); + let (cmp_expr, cmp_op) = match search { + Search::IndexSearch { + cmp_expr, cmp_op, .. + } => (cmp_expr, cmp_op), + Search::RowidSearch { cmp_expr, cmp_op } => (cmp_expr, cmp_op), + Search::RowidEq { .. 
} => unreachable!(), + }; - if *outer { - let lj_meta = t_ctx.meta_left_joins.get(id).unwrap(); + // TODO this only handles ascending indexes + match cmp_op { + ast::Operator::Equals + | ast::Operator::Greater + | ast::Operator::GreaterEquals => { + translate_expr( + program, + Some(tables), + &cmp_expr.expr, + cmp_reg, + &t_ctx.resolver, + )?; + } + ast::Operator::Less | ast::Operator::LessEquals => { + program.emit_insn(Insn::Null { + dest: cmp_reg, + dest_end: None, + }); + } + _ => unreachable!(), + } + // If we try to seek to a key that is not present in the table/index, we exit the loop entirely. + program.emit_insn(match cmp_op { + ast::Operator::Equals | ast::Operator::GreaterEquals => Insn::SeekGE { + is_index: index_cursor_id.is_some(), + cursor_id: index_cursor_id.unwrap_or(table_cursor_id), + start_reg: cmp_reg, + num_regs: 1, + target_pc: loop_end, + }, + ast::Operator::Greater + | ast::Operator::Less + | ast::Operator::LessEquals => Insn::SeekGT { + is_index: index_cursor_id.is_some(), + cursor_id: index_cursor_id.unwrap_or(table_cursor_id), + start_reg: cmp_reg, + num_regs: 1, + target_pc: loop_end, + }, + _ => unreachable!(), + }); + if *cmp_op == ast::Operator::Less || *cmp_op == ast::Operator::LessEquals { + translate_expr( + program, + Some(tables), + &cmp_expr.expr, + cmp_reg, + &t_ctx.resolver, + )?; + } + + program.resolve_label(loop_start, program.offset()); + // TODO: We are currently only handling ascending indexes. + // For conditions like index_key > 10, we have already seeked to the first key greater than 10, and can just scan forward. + // For conditions like index_key < 10, we are at the beginning of the index, and will scan forward and emit IdxGE(10) with a conditional jump to the end. + // For conditions like index_key = 10, we have already seeked to the first key greater than or equal to 10, and can just scan forward and emit IdxGT(10) with a conditional jump to the end. 
+ // For conditions like index_key >= 10, we have already seeked to the first key greater than or equal to 10, and can just scan forward. + // For conditions like index_key <= 10, we are at the beginning of the index, and will scan forward and emit IdxGT(10) with a conditional jump to the end. + // For conditions like index_key != 10, TODO. probably the optimal way is not to use an index at all. + // + // For primary key searches we emit RowId and then compare it to the seek value. + + match cmp_op { + ast::Operator::Equals | ast::Operator::LessEquals => { + if let Some(index_cursor_id) = index_cursor_id { + program.emit_insn(Insn::IdxGT { + cursor_id: index_cursor_id, + start_reg: cmp_reg, + num_regs: 1, + target_pc: loop_end, + }); + } else { + let rowid_reg = program.alloc_register(); + program.emit_insn(Insn::RowId { + cursor_id: table_cursor_id, + dest: rowid_reg, + }); + program.emit_insn(Insn::Gt { + lhs: rowid_reg, + rhs: cmp_reg, + target_pc: loop_end, + flags: CmpInsFlags::default(), + }); + } + } + ast::Operator::Less => { + if let Some(index_cursor_id) = index_cursor_id { + program.emit_insn(Insn::IdxGE { + cursor_id: index_cursor_id, + start_reg: cmp_reg, + num_regs: 1, + target_pc: loop_end, + }); + } else { + let rowid_reg = program.alloc_register(); + program.emit_insn(Insn::RowId { + cursor_id: table_cursor_id, + dest: rowid_reg, + }); + program.emit_insn(Insn::Ge { + lhs: rowid_reg, + rhs: cmp_reg, + target_pc: loop_end, + flags: CmpInsFlags::default(), + }); + } + } + _ => {} + } + + if let Some(index_cursor_id) = index_cursor_id { + program.emit_insn(Insn::DeferredSeek { + index_cursor_id, + table_cursor_id, + }); + } + } + + if let Search::RowidEq { cmp_expr } = search { + let src_reg = program.alloc_register(); + translate_expr( + program, + Some(tables), + &cmp_expr.expr, + src_reg, + &t_ctx.resolver, + )?; + program.emit_insn(Insn::SeekRowid { + cursor_id: table_cursor_id, + src_reg, + target_pc: next, + }); + } + for cond in predicates + 
.iter() + .filter(|cond| cond.eval_at_loop == table_index) + { + let jump_target_when_true = program.allocate_label(); + let condition_metadata = ConditionMetadata { + jump_if_condition_is_true: false, + jump_target_when_true, + jump_target_when_false: next, + }; + translate_condition_expr( + program, + tables, + &cond.expr, + condition_metadata, + &t_ctx.resolver, + )?; + program.resolve_label(jump_target_when_true, program.offset()); + } + } + } + + // Set the match flag to true if this is a LEFT JOIN. + // At this point of execution we are going to emit columns for the left table, + // and either emit columns or NULLs for the right table, depending on whether the null_flag is set + // for the right table's cursor. + if let Some(join_info) = table.join_info.as_ref() { + if join_info.outer { + let lj_meta = t_ctx.meta_left_joins[table_index].as_ref().unwrap(); program.resolve_label(lj_meta.label_match_flag_set_true, program.offset()); program.emit_insn(Insn::Integer { value: 1, dest: lj_meta.reg_match_flag, }); } - - Ok(()) } - SourceOperator::Scan { - id, - table_reference, - predicates, - iter_dir, - } => { - let cursor_id = program.resolve_cursor_id(&table_reference.table_identifier); - if iter_dir - .as_ref() - .is_some_and(|dir| *dir == IterationDirection::Backwards) - { - program.emit_insn(Insn::LastAsync { cursor_id }); - } else { - program.emit_insn(Insn::RewindAsync { cursor_id }); - } - let LoopLabels { - loop_start, - loop_end, - next, - } = *t_ctx - .labels_main_loop - .get(id) - .expect("scan has no loop labels"); - program.emit_insn( - if iter_dir - .as_ref() - .is_some_and(|dir| *dir == IterationDirection::Backwards) - { - Insn::LastAwait { - cursor_id, - pc_if_empty: loop_end, - } - } else { - Insn::RewindAwait { - cursor_id, - pc_if_empty: loop_end, - } - }, - ); - program.resolve_label(loop_start, program.offset()); - - if let Some(preds) = predicates { - for expr in preds { - let jump_target_when_true = program.allocate_label(); - let 
condition_metadata = ConditionMetadata { - jump_if_condition_is_true: false, - jump_target_when_true, - jump_target_when_false: next, - }; - translate_condition_expr( - program, - referenced_tables, - expr, - condition_metadata, - &t_ctx.resolver, - )?; - program.resolve_label(jump_target_when_true, program.offset()); - } - } - - Ok(()) - } - SourceOperator::Search { - id, - table_reference, - search, - predicates, - .. - } => { - let table_cursor_id = program.resolve_cursor_id(&table_reference.table_identifier); - let LoopLabels { - loop_start, - loop_end, - next, - } = *t_ctx - .labels_main_loop - .get(id) - .expect("search has no loop labels"); - // Open the loop for the index search. - // Rowid equality point lookups are handled with a SeekRowid instruction which does not loop, since it is a single row lookup. - if !matches!(search, Search::RowidEq { .. }) { - let index_cursor_id = if let Search::IndexSearch { index, .. } = search { - Some(program.resolve_cursor_id(&index.name)) - } else { - None - }; - let cmp_reg = program.alloc_register(); - let (cmp_expr, cmp_op) = match search { - Search::IndexSearch { - cmp_expr, cmp_op, .. - } => (cmp_expr, cmp_op), - Search::RowidSearch { cmp_expr, cmp_op } => (cmp_expr, cmp_op), - Search::RowidEq { .. } => unreachable!(), - }; - // TODO this only handles ascending indexes - match cmp_op { - ast::Operator::Equals - | ast::Operator::Greater - | ast::Operator::GreaterEquals => { - translate_expr( - program, - Some(referenced_tables), - cmp_expr, - cmp_reg, - &t_ctx.resolver, - )?; - } - ast::Operator::Less | ast::Operator::LessEquals => { - program.emit_insn(Insn::Null { - dest: cmp_reg, - dest_end: None, - }); - } - _ => unreachable!(), - } - // If we try to seek to a key that is not present in the table/index, we exit the loop entirely. 
- program.emit_insn(match cmp_op { - ast::Operator::Equals | ast::Operator::GreaterEquals => Insn::SeekGE { - is_index: index_cursor_id.is_some(), - cursor_id: index_cursor_id.unwrap_or(table_cursor_id), - start_reg: cmp_reg, - num_regs: 1, - target_pc: loop_end, - }, - ast::Operator::Greater | ast::Operator::Less | ast::Operator::LessEquals => { - Insn::SeekGT { - is_index: index_cursor_id.is_some(), - cursor_id: index_cursor_id.unwrap_or(table_cursor_id), - start_reg: cmp_reg, - num_regs: 1, - target_pc: loop_end, - } - } - _ => unreachable!(), - }); - if *cmp_op == ast::Operator::Less || *cmp_op == ast::Operator::LessEquals { - translate_expr( - program, - Some(referenced_tables), - cmp_expr, - cmp_reg, - &t_ctx.resolver, - )?; - } - - program.resolve_label(loop_start, program.offset()); - // TODO: We are currently only handling ascending indexes. - // For conditions like index_key > 10, we have already seeked to the first key greater than 10, and can just scan forward. - // For conditions like index_key < 10, we are at the beginning of the index, and will scan forward and emit IdxGE(10) with a conditional jump to the end. - // For conditions like index_key = 10, we have already seeked to the first key greater than or equal to 10, and can just scan forward and emit IdxGT(10) with a conditional jump to the end. - // For conditions like index_key >= 10, we have already seeked to the first key greater than or equal to 10, and can just scan forward. - // For conditions like index_key <= 10, we are at the beginning of the index, and will scan forward and emit IdxGT(10) with a conditional jump to the end. - // For conditions like index_key != 10, TODO. probably the optimal way is not to use an index at all. - // - // For primary key searches we emit RowId and then compare it to the seek value. 
- - match cmp_op { - ast::Operator::Equals | ast::Operator::LessEquals => { - if let Some(index_cursor_id) = index_cursor_id { - program.emit_insn(Insn::IdxGT { - cursor_id: index_cursor_id, - start_reg: cmp_reg, - num_regs: 1, - target_pc: loop_end, - }); - } else { - let rowid_reg = program.alloc_register(); - program.emit_insn(Insn::RowId { - cursor_id: table_cursor_id, - dest: rowid_reg, - }); - program.emit_insn(Insn::Gt { - lhs: rowid_reg, - rhs: cmp_reg, - target_pc: loop_end, - jump_if_null: false, - }); - } - } - ast::Operator::Less => { - if let Some(index_cursor_id) = index_cursor_id { - program.emit_insn(Insn::IdxGE { - cursor_id: index_cursor_id, - start_reg: cmp_reg, - num_regs: 1, - target_pc: loop_end, - }); - } else { - let rowid_reg = program.alloc_register(); - program.emit_insn(Insn::RowId { - cursor_id: table_cursor_id, - dest: rowid_reg, - }); - program.emit_insn(Insn::Ge { - lhs: rowid_reg, - rhs: cmp_reg, - target_pc: loop_end, - jump_if_null: false, - }); - } - } - _ => {} - } - - if let Some(index_cursor_id) = index_cursor_id { - program.emit_insn(Insn::DeferredSeek { - index_cursor_id, - table_cursor_id, - }); - } - } - - if let Search::RowidEq { cmp_expr } = search { - let src_reg = program.alloc_register(); - translate_expr( - program, - Some(referenced_tables), - cmp_expr, - src_reg, - &t_ctx.resolver, - )?; - program.emit_insn(Insn::SeekRowid { - cursor_id: table_cursor_id, - src_reg, - target_pc: next, - }); - } - if let Some(predicates) = predicates { - for predicate in predicates.iter() { - let jump_target_when_true = program.allocate_label(); - let condition_metadata = ConditionMetadata { - jump_if_condition_is_true: false, - jump_target_when_true, - jump_target_when_false: next, - }; - translate_condition_expr( - program, - referenced_tables, - predicate, - condition_metadata, - &t_ctx.resolver, - )?; - program.resolve_label(jump_target_when_true, program.offset()); - } - } - - Ok(()) - } - SourceOperator::Nothing { .. 
} => Ok(()), } + + Ok(()) } /// SQLite (and so Limbo) processes joins as a nested loop. @@ -620,7 +551,7 @@ fn emit_loop_source( cur_reg += 1; translate_expr( program, - Some(&plan.referenced_tables), + Some(&plan.table_references), expr, key_reg, &t_ctx.resolver, @@ -639,7 +570,7 @@ fn emit_loop_source( cur_reg += 1; translate_expr( program, - Some(&plan.referenced_tables), + Some(&plan.table_references), expr, agg_reg, &t_ctx.resolver, @@ -676,7 +607,7 @@ fn emit_loop_source( let reg = start_reg + i; translate_aggregation_step( program, - &plan.referenced_tables, + &plan.table_references, agg, reg, &t_ctx.resolver, @@ -692,7 +623,7 @@ fn emit_loop_source( let reg = start_reg + num_aggs + i; translate_expr( program, - Some(&plan.referenced_tables), + Some(&plan.table_references), &rc.expr, reg, &t_ctx.resolver, @@ -705,7 +636,18 @@ fn emit_loop_source( plan.aggregates.is_empty(), "We should not get here with aggregates" ); - emit_select_result(program, t_ctx, plan, t_ctx.label_main_loop_end)?; + let offset_jump_to = t_ctx + .labels_main_loop + .get(0) + .map(|l| l.next) + .or_else(|| t_ctx.label_main_loop_end); + emit_select_result( + program, + t_ctx, + plan, + t_ctx.label_main_loop_end, + offset_jump_to, + )?; Ok(()) } @@ -718,33 +660,85 @@ fn emit_loop_source( pub fn close_loop( program: &mut ProgramBuilder, t_ctx: &mut TranslateCtx, - source: &SourceOperator, + tables: &[TableReference], ) -> Result<()> { - let loop_labels = *t_ctx - .labels_main_loop - .get(&source.id()) - .expect("source has no loop labels"); - match source { - SourceOperator::Subquery { .. } => { - program.resolve_label(loop_labels.next, program.offset()); - // A subquery has no cursor to call NextAsync on, so it just emits a Goto - // to the Yield instruction, which in turn jumps back to the main loop of the subquery, - // so that the next row from the subquery can be read. 
- program.emit_insn(Insn::Goto { - target_pc: loop_labels.loop_start, - }); - } - SourceOperator::Join { - id, - left, - right, - outer, - .. - } => { - close_loop(program, t_ctx, right)?; + // We close the loops for all tables in reverse order, i.e. innermost first. + // OPEN t1 + // OPEN t2 + // OPEN t3 + // + // CLOSE t3 + // CLOSE t2 + // CLOSE t1 + for (idx, table) in tables.iter().rev().enumerate() { + let table_index = tables.len() - idx - 1; + let loop_labels = *t_ctx + .labels_main_loop + .get(table_index) + .expect("source has no loop labels"); - if *outer { - let lj_meta = t_ctx.meta_left_joins.get(id).unwrap(); + match &table.op { + Operation::Subquery { .. } => { + program.resolve_label(loop_labels.next, program.offset()); + // A subquery has no cursor to call NextAsync on, so it just emits a Goto + // to the Yield instruction, which in turn jumps back to the main loop of the subquery, + // so that the next row from the subquery can be read. + program.emit_insn(Insn::Goto { + target_pc: loop_labels.loop_start, + }); + } + Operation::Scan { iter_dir, .. } => { + program.resolve_label(loop_labels.next, program.offset()); + let cursor_id = program.resolve_cursor_id(&table.identifier); + if iter_dir + .as_ref() + .is_some_and(|dir| *dir == IterationDirection::Backwards) + { + program.emit_insn(Insn::PrevAsync { cursor_id }); + } else { + program.emit_insn(Insn::NextAsync { cursor_id }); + } + if iter_dir + .as_ref() + .is_some_and(|dir| *dir == IterationDirection::Backwards) + { + program.emit_insn(Insn::PrevAwait { + cursor_id, + pc_if_next: loop_labels.loop_start, + }); + } else { + program.emit_insn(Insn::NextAwait { + cursor_id, + pc_if_next: loop_labels.loop_start, + }); + } + } + Operation::Search(search) => { + program.resolve_label(loop_labels.next, program.offset()); + // Rowid equality point lookups are handled with a SeekRowid instruction which does not loop, so there is no need to emit a NextAsync instruction. 
+ if !matches!(search, Search::RowidEq { .. }) { + let cursor_id = match search { + Search::IndexSearch { index, .. } => program.resolve_cursor_id(&index.name), + Search::RowidSearch { .. } => program.resolve_cursor_id(&table.identifier), + Search::RowidEq { .. } => unreachable!(), + }; + + program.emit_insn(Insn::NextAsync { cursor_id }); + program.emit_insn(Insn::NextAwait { + cursor_id, + pc_if_next: loop_labels.loop_start, + }); + } + } + } + + program.resolve_label(loop_labels.loop_end, program.offset()); + + // Handle OUTER JOIN logic. The reason this comes after the "loop end" mark is that we may need to still jump back + // and emit a row with NULLs for the right table, and then jump back to the next row of the left table. + if let Some(join_info) = table.join_info.as_ref() { + if join_info.outer { + let lj_meta = t_ctx.meta_left_joins[table_index].as_ref().unwrap(); // The left join match flag is set to 1 when there is any match on the right table // (e.g. SELECT * FROM t1 LEFT JOIN t2 ON t1.a = t2.a). // If the left join match flag has been set to 1, we jump to the next row on the outer table, @@ -760,13 +754,9 @@ pub fn close_loop( // but since it's a LEFT JOIN, we still need to emit a row with NULLs for the right table. // In that case, we now enter the routine that does exactly that. // First we set the right table cursor's "pseudo null bit" on, which means any Insn::Column will return NULL - let right_cursor_id = match right.as_ref() { - SourceOperator::Scan { - table_reference, .. - } => program.resolve_cursor_id(&table_reference.table_identifier), - SourceOperator::Search { - table_reference, .. - } => program.resolve_cursor_id(&table_reference.table_identifier), + let right_cursor_id = match &table.op { + Operation::Scan { .. } => program.resolve_cursor_id(&table.identifier), + Operation::Search { .. 
} => program.resolve_cursor_id(&table.identifier), _ => unreachable!(), }; program.emit_insn(Insn::NullRow { @@ -784,66 +774,7 @@ pub fn close_loop( assert_eq!(program.offset(), jump_offset); } - - close_loop(program, t_ctx, left)?; } - SourceOperator::Scan { - table_reference, - iter_dir, - .. - } => { - program.resolve_label(loop_labels.next, program.offset()); - let cursor_id = program.resolve_cursor_id(&table_reference.table_identifier); - if iter_dir - .as_ref() - .is_some_and(|dir| *dir == IterationDirection::Backwards) - { - program.emit_insn(Insn::PrevAsync { cursor_id }); - } else { - program.emit_insn(Insn::NextAsync { cursor_id }); - } - if iter_dir - .as_ref() - .is_some_and(|dir| *dir == IterationDirection::Backwards) - { - program.emit_insn(Insn::PrevAwait { - cursor_id, - pc_if_next: loop_labels.loop_start, - }); - } else { - program.emit_insn(Insn::NextAwait { - cursor_id, - pc_if_next: loop_labels.loop_start, - }); - } - } - SourceOperator::Search { - table_reference, - search, - .. - } => { - program.resolve_label(loop_labels.next, program.offset()); - if matches!(search, Search::RowidEq { .. }) { - // Rowid equality point lookups are handled with a SeekRowid instruction which does not loop, so there is no need to emit a NextAsync instruction. - return Ok(()); - } - let cursor_id = match search { - Search::IndexSearch { index, .. } => program.resolve_cursor_id(&index.name), - Search::RowidSearch { .. } => { - program.resolve_cursor_id(&table_reference.table_identifier) - } - Search::RowidEq { .. } => unreachable!(), - }; - - program.emit_insn(Insn::NextAsync { cursor_id }); - program.emit_insn(Insn::NextAwait { - cursor_id, - pc_if_next: loop_labels.loop_start, - }); - } - SourceOperator::Nothing { .. 
} => {} - }; - - program.resolve_label(loop_labels.loop_end, program.offset()); + } Ok(()) } diff --git a/core/translate/mod.rs b/core/translate/mod.rs index 75c1f9758..82b115c74 100644 --- a/core/translate/mod.rs +++ b/core/translate/mod.rs @@ -18,26 +18,25 @@ pub(crate) mod optimizer; pub(crate) mod order_by; pub(crate) mod plan; pub(crate) mod planner; +pub(crate) mod pragma; pub(crate) mod result_row; pub(crate) mod select; pub(crate) mod subquery; use crate::schema::Schema; use crate::storage::pager::Pager; -use crate::storage::sqlite3_ondisk::{DatabaseHeader, MIN_PAGE_CACHE_SIZE}; -use crate::storage::wal::CheckpointMode; +use crate::storage::sqlite3_ondisk::DatabaseHeader; use crate::translate::delete::translate_delete; use crate::util::PRIMARY_KEY_AUTOMATIC_INDEX_NAME_PREFIX; -use crate::vdbe::builder::CursorType; +use crate::vdbe::builder::{CursorType, QueryMode}; use crate::vdbe::{builder::ProgramBuilder, insn::Insn, Program}; use crate::{bail_parse_error, Connection, LimboError, Result, SymbolTable}; use insert::translate_insert; use select::translate_select; -use sqlite3_parser::ast::{self, fmt::ToTokens, PragmaName}; +use sqlite3_parser::ast::{self, fmt::ToTokens}; use std::cell::RefCell; use std::fmt::Display; use std::rc::{Rc, Weak}; -use std::str::FromStr; /// Translate SQL statement into bytecode program. pub fn translate( @@ -47,8 +46,9 @@ pub fn translate( pager: Rc, connection: Weak, syms: &SymbolTable, + query_mode: QueryMode, ) -> Result { - let mut program = ProgramBuilder::new(); + let mut program = ProgramBuilder::new(query_mode); let mut change_cnt_on = false; match stmt { @@ -90,7 +90,14 @@ pub fn translate( ast::Stmt::DropTrigger { .. } => bail_parse_error!("DROP TRIGGER not supported yet"), ast::Stmt::DropView { .. 
} => bail_parse_error!("DROP VIEW not supported yet"), ast::Stmt::Pragma(name, body) => { - translate_pragma(&mut program, &name, body, database_header.clone(), pager)?; + pragma::translate_pragma( + &mut program, + &schema, + &name, + body, + database_header.clone(), + pager, + )?; } ast::Stmt::Reindex { .. } => bail_parse_error!("REINDEX not supported yet"), ast::Stmt::Release(_) => bail_parse_error!("RELEASE not supported yet"), @@ -197,23 +204,9 @@ fn emit_schema_entry( prev_largest_reg: 0, }); - let type_reg = program.alloc_register(); - program.emit_insn(Insn::String8 { - value: entry_type.as_str().to_string(), - dest: type_reg, - }); - - let name_reg = program.alloc_register(); - program.emit_insn(Insn::String8 { - value: name.to_string(), - dest: name_reg, - }); - - let tbl_name_reg = program.alloc_register(); - program.emit_insn(Insn::String8 { - value: tbl_name.to_string(), - dest: tbl_name_reg, - }); + let type_reg = program.emit_string8_new_reg(entry_type.as_str().to_string()); + program.emit_string8_new_reg(name.to_string()); + program.emit_string8_new_reg(tbl_name.to_string()); let rootpage_reg = program.alloc_register(); program.emit_insn(Insn::Copy { @@ -224,15 +217,9 @@ fn emit_schema_entry( let sql_reg = program.alloc_register(); if let Some(sql) = sql { - program.emit_insn(Insn::String8 { - value: sql, - dest: sql_reg, - }); + program.emit_string8(sql, sql_reg); } else { - program.emit_insn(Insn::Null { - dest: sql_reg, - dest_end: None, - }); + program.emit_null(sql_reg); } let record_reg = program.alloc_register(); @@ -253,6 +240,11 @@ fn emit_schema_entry( }); } +struct PrimaryKeyColumnInfo<'a> { + name: &'a String, + is_descending: bool, +} + /// Check if an automatic PRIMARY KEY index is required for the table. /// If so, create a register for the index root page and return it. /// @@ -282,10 +274,13 @@ fn check_automatic_pk_index_required( columns: pk_cols, .. 
} = &constraint.constraint { - let primary_key_column_results: Vec> = pk_cols + let primary_key_column_results: Vec> = pk_cols .iter() .map(|col| match &col.expr { - ast::Expr::Id(name) => Ok(&name.0), + ast::Expr::Id(name) => Ok(PrimaryKeyColumnInfo { + name: &name.0, + is_descending: matches!(col.order, Some(ast::SortOrder::Desc)), + }), _ => Err(LimboError::ParseError( "expressions prohibited in PRIMARY KEY and UNIQUE constraints" .to_string(), @@ -297,7 +292,9 @@ fn check_automatic_pk_index_required( if let Err(e) = result { bail_parse_error!("{}", e); } - let column_name = result?; + let pk_info = result?; + + let column_name = pk_info.name; let column_def = columns.get(&ast::Name(column_name.clone())); if column_def.is_none() { bail_parse_error!("No such column: {}", column_name); @@ -314,8 +311,11 @@ fn check_automatic_pk_index_required( let column_def = column_def.unwrap(); let typename = column_def.col_type.as_ref().map(|t| t.name.as_str()); - primary_key_definition = - Some(PrimaryKeyDefinitionType::Simple { typename }); + let is_descending = pk_info.is_descending; + primary_key_definition = Some(PrimaryKeyDefinitionType::Simple { + typename, + is_descending, + }); } } } @@ -333,8 +333,10 @@ fn check_automatic_pk_index_required( bail_parse_error!("table {} has more than one primary key", tbl_name); } let typename = col_def.col_type.as_ref().map(|t| t.name.as_str()); - primary_key_definition = - Some(PrimaryKeyDefinitionType::Simple { typename }); + primary_key_definition = Some(PrimaryKeyDefinitionType::Simple { + typename, + is_descending: false, + }); } } } @@ -347,9 +349,13 @@ fn check_automatic_pk_index_required( // Check if we need an automatic index let needs_auto_index = if let Some(primary_key_definition) = &primary_key_definition { match primary_key_definition { - PrimaryKeyDefinitionType::Simple { typename } => { - let is_integer = typename.is_some() && typename.unwrap() == "INTEGER"; - !is_integer + PrimaryKeyDefinitionType::Simple { + 
typename, + is_descending, + } => { + let is_integer = + typename.is_some() && typename.unwrap().to_uppercase() == "INTEGER"; + !is_integer || *is_descending } PrimaryKeyDefinitionType::Composite => true, } @@ -379,21 +385,13 @@ fn translate_create_table( ) -> Result<()> { if schema.get_table(tbl_name.name.0.as_str()).is_some() { if if_not_exists { - let init_label = program.allocate_label(); - program.emit_insn(Insn::Init { - target_pc: init_label, - }); + let init_label = program.emit_init(); let start_offset = program.offset(); - program.emit_insn(Insn::Halt { - err_code: 0, - description: String::new(), - }); + program.emit_halt(); program.resolve_label(init_label, program.offset()); - program.emit_insn(Insn::Transaction { write: true }); + program.emit_transaction(true); program.emit_constant_insns(); - program.emit_insn(Insn::Goto { - target_pc: start_offset, - }); + program.emit_goto(start_offset); return Ok(()); } @@ -403,10 +401,7 @@ fn translate_create_table( let sql = create_table_body_to_str(&tbl_name, &body); let parse_schema_label = program.allocate_label(); - let init_label = program.allocate_label(); - program.emit_insn(Insn::Init { - target_pc: init_label, - }); + let init_label = program.emit_init(); let start_offset = program.offset(); // TODO: ReadCookie // TODO: If @@ -505,183 +500,23 @@ fn translate_create_table( }); // TODO: SqlExec - program.emit_insn(Insn::Halt { - err_code: 0, - description: String::new(), - }); + program.emit_halt(); program.resolve_label(init_label, program.offset()); - program.emit_insn(Insn::Transaction { write: true }); + program.emit_transaction(true); program.emit_constant_insns(); - program.emit_insn(Insn::Goto { - target_pc: start_offset, - }); + program.emit_goto(start_offset); Ok(()) } enum PrimaryKeyDefinitionType<'a> { - Simple { typename: Option<&'a str> }, + Simple { + typename: Option<&'a str>, + is_descending: bool, + }, Composite, } -fn translate_pragma( - program: &mut ProgramBuilder, - name: 
&ast::QualifiedName, - body: Option, - database_header: Rc>, - pager: Rc, -) -> Result<()> { - let init_label = program.allocate_label(); - program.emit_insn(Insn::Init { - target_pc: init_label, - }); - let start_offset = program.offset(); - let mut write = false; - match body { - None => { - let pragma_name = &name.name.0; - query_pragma(pragma_name, database_header.clone(), program)?; - } - Some(ast::PragmaBody::Equals(value)) => { - write = true; - update_pragma(&name.name.0, value, database_header.clone(), pager, program)?; - } - Some(ast::PragmaBody::Call(_)) => { - todo!() - } - }; - program.emit_insn(Insn::Halt { - err_code: 0, - description: String::new(), - }); - program.resolve_label(init_label, program.offset()); - program.emit_insn(Insn::Transaction { write }); - program.emit_constant_insns(); - program.emit_insn(Insn::Goto { - target_pc: start_offset, - }); - - Ok(()) -} - -fn update_pragma( - name: &str, - value: ast::Expr, - header: Rc>, - pager: Rc, - program: &mut ProgramBuilder, -) -> Result<()> { - let pragma = match PragmaName::from_str(name) { - Ok(pragma) => pragma, - Err(()) => bail_parse_error!("Not a valid pragma name"), - }; - match pragma { - PragmaName::CacheSize => { - let cache_size = match value { - ast::Expr::Literal(ast::Literal::Numeric(numeric_value)) => { - numeric_value.parse::()? - } - ast::Expr::Unary(ast::UnaryOperator::Negative, expr) => match *expr { - ast::Expr::Literal(ast::Literal::Numeric(numeric_value)) => { - -numeric_value.parse::()? 
- } - _ => bail_parse_error!("Not a valid value"), - }, - _ => bail_parse_error!("Not a valid value"), - }; - update_cache_size(cache_size, header, pager); - Ok(()) - } - PragmaName::JournalMode => { - query_pragma("journal_mode", header, program)?; - Ok(()) - } - PragmaName::WalCheckpoint => { - query_pragma("wal_checkpoint", header, program)?; - Ok(()) - } - } -} - -fn query_pragma( - name: &str, - database_header: Rc>, - program: &mut ProgramBuilder, -) -> Result<()> { - let pragma = match PragmaName::from_str(name) { - Ok(pragma) => pragma, - Err(()) => bail_parse_error!("Not a valid pragma name"), - }; - let register = program.alloc_register(); - match pragma { - PragmaName::CacheSize => { - program.emit_insn(Insn::Integer { - value: database_header.borrow().default_page_cache_size.into(), - dest: register, - }); - program.emit_insn(Insn::ResultRow { - start_reg: register, - count: 1, - }); - } - PragmaName::JournalMode => { - program.emit_insn(Insn::String8 { - value: "wal".into(), - dest: register, - }); - program.emit_insn(Insn::ResultRow { - start_reg: register, - count: 1, - }); - } - PragmaName::WalCheckpoint => { - // Checkpoint uses 3 registers: P1, P2, P3. Ref Insn::Checkpoint for more info. - // Allocate two more here as one was allocated at the top. 
- program.alloc_register(); - program.alloc_register(); - program.emit_insn(Insn::Checkpoint { - database: 0, - checkpoint_mode: CheckpointMode::Passive, - dest: register, - }); - program.emit_insn(Insn::ResultRow { - start_reg: register, - count: 3, - }); - } - } - - Ok(()) -} - -fn update_cache_size(value: i64, header: Rc>, pager: Rc) { - let mut cache_size_unformatted: i64 = value; - let mut cache_size = if cache_size_unformatted < 0 { - let kb = cache_size_unformatted.abs() * 1024; - kb / 512 // assume 512 page size for now - } else { - value - } as usize; - - if cache_size < MIN_PAGE_CACHE_SIZE { - // update both in memory and stored disk value - cache_size = MIN_PAGE_CACHE_SIZE; - cache_size_unformatted = MIN_PAGE_CACHE_SIZE as i64; - } - - // update in-memory header - header.borrow_mut().default_page_cache_size = cache_size_unformatted - .try_into() - .unwrap_or_else(|_| panic!("invalid value, too big for a i32 {}", value)); - - // update in disk - let header_copy = header.borrow().clone(); - pager.write_database_header(&header_copy); - - // update cache size - pager.change_page_cache_size(cache_size); -} - struct TableFormatter<'a> { body: &'a ast::CreateTableBody, } diff --git a/core/translate/optimizer.rs b/core/translate/optimizer.rs index d6caba85e..63f8cc573 100644 --- a/core/translate/optimizer.rs +++ b/core/translate/optimizer.rs @@ -1,19 +1,21 @@ -use std::rc::Rc; +use std::{collections::HashMap, rc::Rc}; use sqlite3_parser::ast; -use crate::{schema::Index, Result}; - -use super::plan::{ - get_table_ref_bitmask_for_ast_expr, get_table_ref_bitmask_for_operator, DeletePlan, Direction, - IterationDirection, Plan, Search, SelectPlan, SourceOperator, TableReference, - TableReferenceType, +use crate::{ + schema::{Index, Schema}, + Result, }; -pub fn optimize_plan(plan: &mut Plan) -> Result<()> { +use super::plan::{ + DeletePlan, Direction, IterationDirection, Operation, Plan, Search, SelectPlan, TableReference, + WhereTerm, +}; + +pub fn 
optimize_plan(plan: &mut Plan, schema: &Schema) -> Result<()> { match plan { - Plan::Select(plan) => optimize_select_plan(plan), - Plan::Delete(plan) => optimize_delete_plan(plan), + Plan::Select(plan) => optimize_select_plan(plan, schema), + Plan::Delete(plan) => optimize_delete_plan(plan, schema), } } @@ -22,118 +24,90 @@ pub fn optimize_plan(plan: &mut Plan) -> Result<()> { * TODO: these could probably be done in less passes, * but having them separate makes them easier to understand */ -fn optimize_select_plan(plan: &mut SelectPlan) -> Result<()> { - optimize_subqueries(&mut plan.source)?; +fn optimize_select_plan(plan: &mut SelectPlan, schema: &Schema) -> Result<()> { + optimize_subqueries(plan, schema)?; rewrite_exprs_select(plan)?; if let ConstantConditionEliminationResult::ImpossibleCondition = - eliminate_constants(&mut plan.source, &mut plan.where_clause)? + eliminate_constant_conditions(&mut plan.where_clause)? { plan.contains_constant_false_condition = true; return Ok(()); } - push_predicates( - &mut plan.source, - &mut plan.where_clause, - &plan.referenced_tables, - )?; - use_indexes( - &mut plan.source, - &plan.referenced_tables, - &plan.available_indexes, + &mut plan.table_references, + &schema.indexes, + &mut plan.where_clause, )?; - eliminate_unnecessary_orderby( - &mut plan.source, - &mut plan.order_by, - &plan.referenced_tables, - &plan.available_indexes, - )?; + eliminate_unnecessary_orderby(plan, schema)?; Ok(()) } -fn optimize_delete_plan(plan: &mut DeletePlan) -> Result<()> { +fn optimize_delete_plan(plan: &mut DeletePlan, schema: &Schema) -> Result<()> { rewrite_exprs_delete(plan)?; if let ConstantConditionEliminationResult::ImpossibleCondition = - eliminate_constants(&mut plan.source, &mut plan.where_clause)? + eliminate_constant_conditions(&mut plan.where_clause)? 
{ plan.contains_constant_false_condition = true; return Ok(()); } use_indexes( - &mut plan.source, - &plan.referenced_tables, - &plan.available_indexes, + &mut plan.table_references, + &schema.indexes, + &mut plan.where_clause, )?; Ok(()) } -fn optimize_subqueries(operator: &mut SourceOperator) -> Result<()> { - match operator { - SourceOperator::Subquery { plan, .. } => { - optimize_select_plan(&mut *plan)?; - Ok(()) +fn optimize_subqueries(plan: &mut SelectPlan, schema: &Schema) -> Result<()> { + for table in plan.table_references.iter_mut() { + if let Operation::Subquery { plan, .. } = &mut table.op { + optimize_select_plan(&mut *plan, schema)?; } - SourceOperator::Join { left, right, .. } => { - optimize_subqueries(left)?; - optimize_subqueries(right)?; - Ok(()) - } - _ => Ok(()), } + + Ok(()) } -fn _operator_is_already_ordered_by( - operator: &mut SourceOperator, +fn query_is_already_ordered_by( + table_references: &[TableReference], key: &mut ast::Expr, - referenced_tables: &[TableReference], - available_indexes: &Vec>, + available_indexes: &HashMap>>, ) -> Result { - match operator { - SourceOperator::Scan { - table_reference, .. - } => Ok(key.is_rowid_alias_of(table_reference.table_index)), - SourceOperator::Search { - table_reference, - search, - .. - } => match search { - Search::RowidEq { .. } => Ok(key.is_rowid_alias_of(table_reference.table_index)), - Search::RowidSearch { .. } => Ok(key.is_rowid_alias_of(table_reference.table_index)), + let first_table = table_references.first(); + if first_table.is_none() { + return Ok(false); + } + let table_reference = first_table.unwrap(); + match &table_reference.op { + Operation::Scan { .. } => Ok(key.is_rowid_alias_of(0)), + Operation::Search(search) => match search { + Search::RowidEq { .. } => Ok(key.is_rowid_alias_of(0)), + Search::RowidSearch { .. } => Ok(key.is_rowid_alias_of(0)), Search::IndexSearch { index, .. 
} => { - let index_idx = key.check_index_scan( - table_reference.table_index, - referenced_tables, - available_indexes, - )?; - let index_is_the_same = index_idx - .map(|i| Rc::ptr_eq(&available_indexes[i], index)) - .unwrap_or(false); + let index_rc = key.check_index_scan(0, &table_reference, available_indexes)?; + let index_is_the_same = + index_rc.map(|irc| Rc::ptr_eq(index, &irc)).unwrap_or(false); Ok(index_is_the_same) } }, - SourceOperator::Join { left, .. } => { - _operator_is_already_ordered_by(left, key, referenced_tables, available_indexes) - } _ => Ok(false), } } -fn eliminate_unnecessary_orderby( - operator: &mut SourceOperator, - order_by: &mut Option>, - referenced_tables: &[TableReference], - available_indexes: &Vec>, -) -> Result<()> { - if order_by.is_none() { +fn eliminate_unnecessary_orderby(plan: &mut SelectPlan, schema: &Schema) -> Result<()> { + if plan.order_by.is_none() { + return Ok(()); + } + if plan.table_references.len() == 0 { return Ok(()); } - let o = order_by.as_mut().unwrap(); + let o = plan.order_by.as_mut().unwrap(); if o.len() != 1 { // TODO: handle multiple order by keys @@ -143,76 +117,55 @@ fn eliminate_unnecessary_orderby( let (key, direction) = o.first_mut().unwrap(); let already_ordered = - _operator_is_already_ordered_by(operator, key, referenced_tables, available_indexes)?; + query_is_already_ordered_by(&plan.table_references, key, &schema.indexes)?; if already_ordered { - push_scan_direction(operator, direction); - *order_by = None; + push_scan_direction(&mut plan.table_references[0], direction); + plan.order_by = None; } Ok(()) } /** - * Use indexes where possible + * Use indexes where possible. + * Right now we make decisions about using indexes ONLY based on condition expressions, not e.g. ORDER BY or others. + * This is just because we are WIP. + * + * When this function is called, condition expressions from both the actual WHERE clause and the JOIN clauses are in the where_clause vector. 
+ * If we find a condition that can be used to index scan, we pop it off from the where_clause vector and put it into a Search operation. + * We put it there simply because it makes it a bit easier to track during translation. */ fn use_indexes( - operator: &mut SourceOperator, - referenced_tables: &[TableReference], - available_indexes: &[Rc], + table_references: &mut [TableReference], + available_indexes: &HashMap>>, + where_clause: &mut Vec, ) -> Result<()> { - match operator { - SourceOperator::Subquery { .. } => Ok(()), - SourceOperator::Search { .. } => Ok(()), - SourceOperator::Scan { - table_reference, - predicates: filter, - id, - .. - } => { - if filter.is_none() { - return Ok(()); - } + if where_clause.is_empty() { + return Ok(()); + } - let fs = filter.as_mut().unwrap(); - for i in 0..fs.len() { - let f = fs[i].take_ownership(); - let table_index = referenced_tables - .iter() - .position(|t| t.table_identifier == table_reference.table_identifier) - .unwrap(); - match try_extract_index_search_expression( - f, + 'outer: for (table_index, table_reference) in table_references.iter_mut().enumerate() { + if let Operation::Scan { .. } = &mut table_reference.op { + let mut i = 0; + while i < where_clause.len() { + let cond = where_clause.get_mut(i).unwrap(); + if let Some(index_search) = try_extract_index_search_expression( + cond, table_index, - referenced_tables, + &table_reference, available_indexes, )? { - Either::Left(non_index_using_expr) => { - fs[i] = non_index_using_expr; - } - Either::Right(index_search) => { - fs.remove(i); - *operator = SourceOperator::Search { - id: *id, - table_reference: table_reference.clone(), - predicates: Some(fs.clone()), - search: index_search, - }; - - return Ok(()); - } + where_clause.remove(i); + table_reference.op = Operation::Search(index_search); + continue 'outer; } + i += 1; } - - Ok(()) } - SourceOperator::Join { left, right, .. 
} => { - use_indexes(left, referenced_tables, available_indexes)?; - use_indexes(right, referenced_tables, available_indexes)?; - Ok(()) - } - SourceOperator::Nothing { .. } => Ok(()), } + + Ok(()) } #[derive(Debug, PartialEq, Clone)] @@ -221,377 +174,38 @@ enum ConstantConditionEliminationResult { ImpossibleCondition, } -// removes predicates that are always true -// returns a ConstantEliminationResult indicating whether any predicates are always false -fn eliminate_constants( - operator: &mut SourceOperator, - where_clause: &mut Option>, +/// Removes predicates that are always true. +/// Returns a ConstantEliminationResult indicating whether any predicates are always false. +/// This is used to determine whether the query can be aborted early. +fn eliminate_constant_conditions( + where_clause: &mut Vec, ) -> Result { - if let Some(predicates) = where_clause { - let mut i = 0; - while i < predicates.len() { - let predicate = &predicates[i]; - if predicate.is_always_true()? { - // true predicates can be removed since they don't affect the result - predicates.remove(i); - } else if predicate.is_always_false()? { - // any false predicate in a list of conjuncts (AND-ed predicates) will make the whole list false - predicates.truncate(0); - return Ok(ConstantConditionEliminationResult::ImpossibleCondition); - } else { + let mut i = 0; + while i < where_clause.len() { + let predicate = &where_clause[i]; + if predicate.expr.is_always_true()? { + // true predicates can be removed since they don't affect the result + where_clause.remove(i); + } else if predicate.expr.is_always_false()? { + // any false predicate in a list of conjuncts (AND-ed predicates) will make the whole list false, + // except an outer join condition, because that just results in NULLs, not skipping the whole loop + if predicate.from_outer_join { i += 1; - } - } - } - match operator { - SourceOperator::Subquery { .. 
} => Ok(ConstantConditionEliminationResult::Continue), - SourceOperator::Join { - left, - right, - predicates, - outer, - .. - } => { - if eliminate_constants(left, where_clause)? - == ConstantConditionEliminationResult::ImpossibleCondition - { - return Ok(ConstantConditionEliminationResult::ImpossibleCondition); - } - if eliminate_constants(right, where_clause)? - == ConstantConditionEliminationResult::ImpossibleCondition - && !*outer - { - return Ok(ConstantConditionEliminationResult::ImpossibleCondition); - } - - if predicates.is_none() { - return Ok(ConstantConditionEliminationResult::Continue); - } - - let predicates = predicates.as_mut().unwrap(); - - let mut i = 0; - while i < predicates.len() { - let predicate = &mut predicates[i]; - if predicate.is_always_true()? { - predicates.remove(i); - } else if predicate.is_always_false()? { - if !*outer { - predicates.truncate(0); - return Ok(ConstantConditionEliminationResult::ImpossibleCondition); - } - // in an outer join, we can't skip rows, so just replace all constant false predicates with 0 - // so we don't later have to evaluate anything more complex or special-case the identifiers true and false - // which are just aliases for 1 and 0 - *predicate = ast::Expr::Literal(ast::Literal::Numeric("0".to_string())); - i += 1; - } else { - i += 1; - } - } - - Ok(ConstantConditionEliminationResult::Continue) - } - SourceOperator::Scan { predicates, .. } => { - if let Some(ps) = predicates { - let mut i = 0; - while i < ps.len() { - let predicate = &ps[i]; - if predicate.is_always_true()? { - // true predicates can be removed since they don't affect the result - ps.remove(i); - } else if predicate.is_always_false()? 
{ - // any false predicate in a list of conjuncts (AND-ed predicates) will make the whole list false - ps.truncate(0); - return Ok(ConstantConditionEliminationResult::ImpossibleCondition); - } else { - i += 1; - } - } - - if ps.is_empty() { - *predicates = None; - } - } - Ok(ConstantConditionEliminationResult::Continue) - } - SourceOperator::Search { predicates, .. } => { - if let Some(predicates) = predicates { - let mut i = 0; - while i < predicates.len() { - let predicate = &predicates[i]; - if predicate.is_always_true()? { - // true predicates can be removed since they don't affect the result - predicates.remove(i); - } else if predicate.is_always_false()? { - // any false predicate in a list of conjuncts (AND-ed predicates) will make the whole list false - predicates.truncate(0); - return Ok(ConstantConditionEliminationResult::ImpossibleCondition); - } else { - i += 1; - } - } - } - - Ok(ConstantConditionEliminationResult::Continue) - } - SourceOperator::Nothing { .. } => Ok(ConstantConditionEliminationResult::Continue), - } -} - -/** - Recursively pushes predicates down the tree, as far as possible. - Where a predicate is pushed determines at which loop level it will be evaluated. - For example, in SELECT * FROM t1 JOIN t2 JOIN t3 WHERE t1.a = t2.a AND t2.b = t3.b AND t1.c = 1 - the predicate t1.c = 1 can be pushed to t1 and will be evaluated in the first (outermost) loop, - the predicate t1.a = t2.a can be pushed to t2 and will be evaluated in the second loop - while t2.b = t3.b will be evaluated in the third loop. 
-*/ -fn push_predicates( - operator: &mut SourceOperator, - where_clause: &mut Option>, - referenced_tables: &Vec, -) -> Result<()> { - // First try to push down any predicates from the WHERE clause - if let Some(predicates) = where_clause { - let mut i = 0; - while i < predicates.len() { - // Take ownership of predicate to try pushing it down - let predicate = predicates[i].take_ownership(); - // If predicate was successfully pushed (None returned), remove it from WHERE - let Some(predicate) = push_predicate(operator, predicate, referenced_tables)? else { - predicates.remove(i); continue; - }; - predicates[i] = predicate; + } + where_clause.truncate(0); + return Ok(ConstantConditionEliminationResult::ImpossibleCondition); + } else { i += 1; } - // Clean up empty WHERE clause - if predicates.is_empty() { - *where_clause = None; - } } - match operator { - SourceOperator::Subquery { .. } => Ok(()), - SourceOperator::Join { - left, - right, - predicates, - outer, - .. - } => { - // Recursively push predicates down both sides of join - push_predicates(left, where_clause, referenced_tables)?; - push_predicates(right, where_clause, referenced_tables)?; - - if predicates.is_none() { - return Ok(()); - } - - let predicates = predicates.as_mut().unwrap(); - - let mut i = 0; - while i < predicates.len() { - let predicate_owned = predicates[i].take_ownership(); - - // For a join like SELECT * FROM left INNER JOIN right ON left.id = right.id AND left.name = 'foo' - // the predicate 'left.name = 'foo' can already be evaluated in the outer loop (left side of join) - // because the row can immediately be skipped if left.name != 'foo'. - // But for a LEFT JOIN, we can't do this since we need to ensure that all rows from the left table are included, - // even if there are no matching rows from the right table. This is why we can't push LEFT JOIN predicates to the left side. 
- let push_result = if *outer { - Some(predicate_owned) - } else { - push_predicate(left, predicate_owned, referenced_tables)? - }; - - // Try pushing to left side first (see comment above for reasoning) - let Some(predicate) = push_result else { - predicates.remove(i); - continue; - }; - - // Then try right side - let Some(predicate) = push_predicate(right, predicate, referenced_tables)? else { - predicates.remove(i); - continue; - }; - - // If neither side could take it, keep in join predicates (not sure if this actually happens in practice) - // this is effectively the same as pushing to the right side, so maybe it could be removed and assert here - // that we don't reach this code - predicates[i] = predicate; - i += 1; - } - - Ok(()) - } - // Base cases - nowhere else to push to - SourceOperator::Scan { .. } => Ok(()), - SourceOperator::Search { .. } => Ok(()), - SourceOperator::Nothing { .. } => Ok(()), - } + Ok(ConstantConditionEliminationResult::Continue) } -/** - Push a single predicate down the tree, as far as possible. - Returns Ok(None) if the predicate was pushed, otherwise returns itself as Ok(Some(predicate)) -*/ -fn push_predicate( - operator: &mut SourceOperator, - predicate: ast::Expr, - referenced_tables: &Vec, -) -> Result> { - match operator { - SourceOperator::Subquery { - predicates, - table_reference, - .. - } => { - // **TODO**: we are currently just evaluating the predicate after the subquery yields, - // and not trying to do anythign more sophisticated. - // E.g. literally: SELECT * FROM (SELECT * FROM t1) sub WHERE sub.col = 'foo' - // - // It is possible, and not overly difficult, to determine that we can also push the - // predicate into the subquery coroutine itself before it yields. 
The above query would - // effectively become: SELECT * FROM (SELECT * FROM t1 WHERE col = 'foo') sub - // - // This matters more in cases where the subquery builds some kind of sorter/index in memory - // (or on disk) and in those cases pushing the predicate down to the coroutine will make the - // subquery produce less intermediate data. In cases where no intermediate data structures are - // built, it doesn't matter. - // - // Moreover, in many cases the subquery can even be completely eliminated, e.g. the above original - // query would become: SELECT * FROM t1 WHERE col = 'foo' without the subquery. - // **END TODO** - - // Find position of this subquery in referenced_tables array - let subquery_index = referenced_tables - .iter() - .position(|t| { - t.table_identifier == table_reference.table_identifier - && matches!(t.reference_type, TableReferenceType::Subquery { .. }) - }) - .unwrap(); - - // Get bitmask showing which tables this predicate references - let predicate_bitmask = - get_table_ref_bitmask_for_ast_expr(referenced_tables, &predicate)?; - - // Each table has a bit position based on join order from left to right - // e.g. in SELECT * FROM t1 JOIN t2 JOIN t3 - // t1 is position 0 (001), t2 is position 1 (010), t3 is position 2 (100) - // To push a predicate to a given table, it can only reference that table and tables to its left - // Example: For table t2 at position 1 (bit 010): - // - Can push: 011 (t2 + t1), 001 (just t1), 010 (just t2) - // - Can't push: 110 (t2 + t3) - let next_table_on_the_right_in_join_bitmask = 1 << (subquery_index + 1); - if predicate_bitmask >= next_table_on_the_right_in_join_bitmask { - return Ok(Some(predicate)); - } - - if predicates.is_none() { - predicates.replace(vec![predicate]); - } else { - predicates.as_mut().unwrap().push(predicate); - } - - Ok(None) - } - SourceOperator::Scan { - predicates, - table_reference, - .. 
- } => { - // Find position of this table in referenced_tables array - let table_index = referenced_tables - .iter() - .position(|t| { - t.table_identifier == table_reference.table_identifier - && t.reference_type == TableReferenceType::BTreeTable - }) - .unwrap(); - - // Get bitmask showing which tables this predicate references - let predicate_bitmask = - get_table_ref_bitmask_for_ast_expr(referenced_tables, &predicate)?; - - // Each table has a bit position based on join order from left to right - // e.g. in SELECT * FROM t1 JOIN t2 JOIN t3 - // t1 is position 0 (001), t2 is position 1 (010), t3 is position 2 (100) - // To push a predicate to a given table, it can only reference that table and tables to its left - // Example: For table t2 at position 1 (bit 010): - // - Can push: 011 (t2 + t1), 001 (just t1), 010 (just t2) - // - Can't push: 110 (t2 + t3) - let next_table_on_the_right_in_join_bitmask = 1 << (table_index + 1); - if predicate_bitmask >= next_table_on_the_right_in_join_bitmask { - return Ok(Some(predicate)); - } - - // Add predicate to this table's filters - if predicates.is_none() { - predicates.replace(vec![predicate]); - } else { - predicates.as_mut().unwrap().push(predicate); - } - - Ok(None) - } - // Search nodes don't exist yet at this point; Scans are transformed to Search in use_indexes() - SourceOperator::Search { .. } => unreachable!(), - SourceOperator::Join { - left, - right, - predicates: join_on_preds, - outer, - .. 
- } => { - // Try pushing to left side first - let push_result_left = push_predicate(left, predicate, referenced_tables)?; - if push_result_left.is_none() { - return Ok(None); - } - // Then try right side - let push_result_right = - push_predicate(right, push_result_left.unwrap(), referenced_tables)?; - if push_result_right.is_none() { - return Ok(None); - } - - // For LEFT JOIN, predicates must stay at join level - if *outer { - return Ok(Some(push_result_right.unwrap())); - } - - let pred = push_result_right.unwrap(); - - // Get bitmasks for tables referenced in predicate and both sides of join - let table_refs_bitmask = get_table_ref_bitmask_for_ast_expr(referenced_tables, &pred)?; - let left_bitmask = get_table_ref_bitmask_for_operator(referenced_tables, left)?; - let right_bitmask = get_table_ref_bitmask_for_operator(referenced_tables, right)?; - - // If predicate doesn't reference tables from both sides, it can't be a join condition - if table_refs_bitmask & left_bitmask == 0 || table_refs_bitmask & right_bitmask == 0 { - return Ok(Some(pred)); - } - - // Add as join predicate since it references both sides - if join_on_preds.is_none() { - join_on_preds.replace(vec![pred]); - } else { - join_on_preds.as_mut().unwrap().push(pred); - } - - Ok(None) - } - SourceOperator::Nothing { .. } => Ok(Some(predicate)), - } -} - -fn push_scan_direction(operator: &mut SourceOperator, direction: &Direction) { - match operator { - SourceOperator::Scan { iter_dir, .. } => { +fn push_scan_direction(table: &mut TableReference, direction: &Direction) { + match &mut table.op { + Operation::Scan { iter_dir, .. 
} => { if iter_dir.is_none() { match direction { Direction::Ascending => *iter_dir = Some(IterationDirection::Forwards), @@ -599,22 +213,19 @@ fn push_scan_direction(operator: &mut SourceOperator, direction: &Direction) { } } } - _ => todo!(), + _ => {} } } fn rewrite_exprs_select(plan: &mut SelectPlan) -> Result<()> { - rewrite_source_operator_exprs(&mut plan.source)?; for rc in plan.result_columns.iter_mut() { rewrite_expr(&mut rc.expr)?; } for agg in plan.aggregates.iter_mut() { rewrite_expr(&mut agg.original_expr)?; } - if let Some(predicates) = &mut plan.where_clause { - for expr in predicates { - rewrite_expr(expr)?; - } + for cond in plan.where_clause.iter_mut() { + rewrite_expr(&mut cond.expr)?; } if let Some(group_by) = &mut plan.group_by { for expr in group_by.exprs.iter_mut() { @@ -631,57 +242,12 @@ fn rewrite_exprs_select(plan: &mut SelectPlan) -> Result<()> { } fn rewrite_exprs_delete(plan: &mut DeletePlan) -> Result<()> { - rewrite_source_operator_exprs(&mut plan.source)?; - if let Some(predicates) = &mut plan.where_clause { - for expr in predicates { - rewrite_expr(expr)?; - } + for cond in plan.where_clause.iter_mut() { + rewrite_expr(&mut cond.expr)?; } - Ok(()) } -fn rewrite_source_operator_exprs(operator: &mut SourceOperator) -> Result<()> { - match operator { - SourceOperator::Join { - left, - right, - predicates, - .. - } => { - rewrite_source_operator_exprs(left)?; - rewrite_source_operator_exprs(right)?; - - if let Some(predicates) = predicates { - for expr in predicates.iter_mut() { - rewrite_expr(expr)?; - } - } - - Ok(()) - } - SourceOperator::Scan { predicates, .. } | SourceOperator::Search { predicates, .. } => { - if let Some(predicates) = predicates { - for expr in predicates.iter_mut() { - rewrite_expr(expr)?; - } - } - - Ok(()) - } - SourceOperator::Subquery { predicates, .. 
} => { - if let Some(predicates) = predicates { - for expr in predicates.iter_mut() { - rewrite_expr(expr)?; - } - } - - Ok(()) - } - SourceOperator::Nothing { .. } => Ok(()), - } -} - #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum ConstantPredicate { AlwaysTrue, @@ -709,9 +275,9 @@ pub trait Optimizable { fn check_index_scan( &mut self, table_index: usize, - referenced_tables: &[TableReference], - available_indexes: &[Rc], - ) -> Result>; + table_reference: &TableReference, + available_indexes: &HashMap>>, + ) -> Result>>; } impl Optimizable for ast::Expr { @@ -728,21 +294,23 @@ impl Optimizable for ast::Expr { fn check_index_scan( &mut self, table_index: usize, - referenced_tables: &[TableReference], - available_indexes: &[Rc], - ) -> Result> { + table_reference: &TableReference, + available_indexes: &HashMap>>, + ) -> Result>> { match self { Self::Column { table, column, .. } => { if *table != table_index { return Ok(None); } - for (idx, index) in available_indexes.iter().enumerate() { - let table_ref = &referenced_tables[*table]; - if index.table_name == table_ref.table.get_name() { - let column = table_ref.table.get_column_at(*column); - if index.columns.first().unwrap().name == column.name { - return Ok(Some(idx)); - } + let Some(available_indexes_for_table) = + available_indexes.get(table_reference.table.get_name()) + else { + return Ok(None); + }; + let column = table_reference.table.get_column_at(*column); + for index in available_indexes_for_table.iter() { + if index.columns.first().unwrap().name == column.name { + return Ok(Some(index.clone())); } } Ok(None) @@ -766,12 +334,12 @@ impl Optimizable for ast::Expr { return Ok(None); } let lhs_index = - lhs.check_index_scan(table_index, referenced_tables, available_indexes)?; + lhs.check_index_scan(table_index, &table_reference, available_indexes)?; if lhs_index.is_some() { return Ok(lhs_index); } let rhs_index = - rhs.check_index_scan(table_index, referenced_tables, available_indexes)?; + 
rhs.check_index_scan(table_index, &table_reference, available_indexes)?; if rhs_index.is_some() { // swap lhs and rhs let swapped_operator = match *op { @@ -911,31 +479,52 @@ impl Optimizable for ast::Expr { } } -pub enum Either { - Left(T), - Right(U), +fn opposite_cmp_op(op: ast::Operator) -> ast::Operator { + match op { + ast::Operator::Equals => ast::Operator::Equals, + ast::Operator::Greater => ast::Operator::Less, + ast::Operator::GreaterEquals => ast::Operator::LessEquals, + ast::Operator::Less => ast::Operator::Greater, + ast::Operator::LessEquals => ast::Operator::GreaterEquals, + _ => panic!("unexpected operator: {:?}", op), + } } pub fn try_extract_index_search_expression( - expr: ast::Expr, + cond: &mut WhereTerm, table_index: usize, - referenced_tables: &[TableReference], - available_indexes: &[Rc], -) -> Result> { - match expr { - ast::Expr::Binary(mut lhs, operator, mut rhs) => { + table_reference: &TableReference, + available_indexes: &HashMap>>, +) -> Result> { + if cond.eval_at_loop != table_index { + return Ok(None); + } + match &mut cond.expr { + ast::Expr::Binary(lhs, operator, rhs) => { if lhs.is_rowid_alias_of(table_index) { match operator { ast::Operator::Equals => { - return Ok(Either::Right(Search::RowidEq { cmp_expr: *rhs })); + let rhs_owned = rhs.take_ownership(); + return Ok(Some(Search::RowidEq { + cmp_expr: WhereTerm { + expr: rhs_owned, + from_outer_join: cond.from_outer_join, + eval_at_loop: cond.eval_at_loop, + }, + })); } ast::Operator::Greater | ast::Operator::GreaterEquals | ast::Operator::Less | ast::Operator::LessEquals => { - return Ok(Either::Right(Search::RowidSearch { - cmp_op: operator, - cmp_expr: *rhs, + let rhs_owned = rhs.take_ownership(); + return Ok(Some(Search::RowidSearch { + cmp_op: *operator, + cmp_expr: WhereTerm { + expr: rhs_owned, + from_outer_join: cond.from_outer_join, + eval_at_loop: cond.eval_at_loop, + }, })); } _ => {} @@ -945,23 +534,35 @@ pub fn try_extract_index_search_expression( if 
rhs.is_rowid_alias_of(table_index) { match operator { ast::Operator::Equals => { - return Ok(Either::Right(Search::RowidEq { cmp_expr: *lhs })); + let lhs_owned = lhs.take_ownership(); + return Ok(Some(Search::RowidEq { + cmp_expr: WhereTerm { + expr: lhs_owned, + from_outer_join: cond.from_outer_join, + eval_at_loop: cond.eval_at_loop, + }, + })); } ast::Operator::Greater | ast::Operator::GreaterEquals | ast::Operator::Less | ast::Operator::LessEquals => { - return Ok(Either::Right(Search::RowidSearch { - cmp_op: operator, - cmp_expr: *lhs, + let lhs_owned = lhs.take_ownership(); + return Ok(Some(Search::RowidSearch { + cmp_op: opposite_cmp_op(*operator), + cmp_expr: WhereTerm { + expr: lhs_owned, + from_outer_join: cond.from_outer_join, + eval_at_loop: cond.eval_at_loop, + }, })); } _ => {} } } - if let Some(index_index) = - lhs.check_index_scan(table_index, referenced_tables, available_indexes)? + if let Some(index_rc) = + lhs.check_index_scan(table_index, &table_reference, available_indexes)? { match operator { ast::Operator::Equals @@ -969,18 +570,23 @@ pub fn try_extract_index_search_expression( | ast::Operator::GreaterEquals | ast::Operator::Less | ast::Operator::LessEquals => { - return Ok(Either::Right(Search::IndexSearch { - index: available_indexes[index_index].clone(), - cmp_op: operator, - cmp_expr: *rhs, + let rhs_owned = rhs.take_ownership(); + return Ok(Some(Search::IndexSearch { + index: index_rc, + cmp_op: *operator, + cmp_expr: WhereTerm { + expr: rhs_owned, + from_outer_join: cond.from_outer_join, + eval_at_loop: cond.eval_at_loop, + }, })); } _ => {} } } - if let Some(index_index) = - rhs.check_index_scan(table_index, referenced_tables, available_indexes)? + if let Some(index_rc) = + rhs.check_index_scan(table_index, &table_reference, available_indexes)? 
{ match operator { ast::Operator::Equals @@ -988,19 +594,24 @@ pub fn try_extract_index_search_expression( | ast::Operator::GreaterEquals | ast::Operator::Less | ast::Operator::LessEquals => { - return Ok(Either::Right(Search::IndexSearch { - index: available_indexes[index_index].clone(), - cmp_op: operator, - cmp_expr: *lhs, + let lhs_owned = lhs.take_ownership(); + return Ok(Some(Search::IndexSearch { + index: index_rc, + cmp_op: opposite_cmp_op(*operator), + cmp_expr: WhereTerm { + expr: lhs_owned, + from_outer_join: cond.from_outer_join, + eval_at_loop: cond.eval_at_loop, + }, })); } _ => {} } } - Ok(Either::Left(ast::Expr::Binary(lhs, operator, rhs))) + Ok(None) } - _ => Ok(Either::Left(expr)), + _ => Ok(None), } } @@ -1080,6 +691,10 @@ fn rewrite_expr(expr: &mut ast::Expr) -> Result<()> { } Ok(()) } + ast::Expr::Unary(_, arg) => { + rewrite_expr(arg)?; + Ok(()) + } _ => Ok(()), } } diff --git a/core/translate/order_by.rs b/core/translate/order_by.rs index 1d02639d2..06e411175 100644 --- a/core/translate/order_by.rs +++ b/core/translate/order_by.rs @@ -17,7 +17,7 @@ use super::{ emitter::TranslateCtx, expr::translate_expr, plan::{Direction, ResultSetColumn, SelectPlan}, - result_row::emit_result_row_and_limit, + result_row::{emit_offset, emit_result_row_and_limit}, }; // Metadata for handling ORDER BY operations @@ -63,15 +63,20 @@ pub fn emit_order_by( let order_by = plan.order_by.as_ref().unwrap(); let result_columns = &plan.result_columns; let sort_loop_start_label = program.allocate_label(); + let sort_loop_next_label = program.allocate_label(); let sort_loop_end_label = program.allocate_label(); let mut pseudo_columns = vec![]; for (i, _) in order_by.iter().enumerate() { + let ty = crate::schema::Type::Null; pseudo_columns.push(Column { // Names don't matter. We are tracking which result column is in which position in the ORDER BY clause in m.result_column_indexes_in_orderby_sorter. 
name: format!("sort_key_{}", i), primary_key: false, - ty: crate::schema::Type::Null, + ty, + ty_str: ty.to_string(), is_rowid_alias: false, + notnull: false, + default: None, }); } for (i, rc) in result_columns.iter().enumerate() { @@ -81,11 +86,15 @@ pub fn emit_order_by( continue; } } + let ty = crate::schema::Type::Null; pseudo_columns.push(Column { name: rc.expr.to_string(), primary_key: false, - ty: crate::schema::Type::Null, + ty, + ty_str: ty.to_string(), is_rowid_alias: false, + notnull: false, + default: None, }); } @@ -117,6 +126,8 @@ pub fn emit_order_by( }); program.resolve_label(sort_loop_start_label, program.offset()); + emit_offset(program, t_ctx, plan, sort_loop_next_label)?; + program.emit_insn(Insn::SorterData { cursor_id: sort_cursor, dest_reg: reg_sorter_data, @@ -131,13 +142,14 @@ pub fn emit_order_by( let reg = start_reg + i; program.emit_insn(Insn::Column { cursor_id, - column: t_ctx.result_column_indexes_in_orderby_sorter[&i], + column: t_ctx.result_column_indexes_in_orderby_sorter[i], dest: reg, }); } emit_result_row_and_limit(program, t_ctx, plan, start_reg, Some(sort_loop_end_label))?; + program.resolve_label(sort_loop_next_label, program.offset()); program.emit_insn(Insn::SorterNext { cursor_id: sort_cursor, pc_if_next: sort_loop_start_label, @@ -172,7 +184,7 @@ pub fn order_by_sorter_insert( let key_reg = start_reg + i; translate_expr( program, - Some(&plan.referenced_tables), + Some(&plan.table_references), expr, key_reg, &t_ctx.resolver, @@ -193,7 +205,7 @@ pub fn order_by_sorter_insert( } translate_expr( program, - Some(&plan.referenced_tables), + Some(&plan.table_references), &rc.expr, cur_reg, &t_ctx.resolver, diff --git a/core/translate/plan.rs b/core/translate/plan.rs index 6164428fb..6209d37a8 100644 --- a/core/translate/plan.rs +++ b/core/translate/plan.rs @@ -9,7 +9,6 @@ use crate::{ function::AggFunc, schema::{BTreeTable, Column, Index, Table}, vdbe::BranchOffset, - Result, }; use crate::{ schema::{PseudoTable, Type}, @@ 
-31,6 +30,28 @@ pub struct GroupBy { pub having: Option>, } +/// In a query plan, WHERE clause conditions and JOIN conditions are all folded into a vector of WhereTerm. +/// This is done so that we can evaluate the conditions at the correct loop depth. +/// We also need to keep track of whether the condition came from an OUTER JOIN. Take this example: +/// SELECT * FROM users u LEFT JOIN products p ON u.id = 5. +/// Even though the condition only refers to 'u', we CANNOT evaluate it at the users loop, because we need to emit NULL +/// values for the columns of 'p', for EVERY row in 'u', instead of completely skipping any rows in 'u' where the condition is false. +#[derive(Debug, Clone)] +pub struct WhereTerm { + /// The original condition expression. + pub expr: ast::Expr, + /// Is this condition originally from an OUTER JOIN? + /// If so, we need to evaluate it at the loop of the right table in that JOIN, + /// regardless of which tables it references. + /// We also cannot e.g. short circuit the entire query in the optimizer if the condition is statically false. + pub from_outer_join: bool, + /// The loop index where to evaluate the condition. + /// For example, in `SELECT * FROM u JOIN p WHERE u.id = 5`, the condition can already be evaluated at the first loop (idx 0), + /// because that is the rightmost table that it references. + pub eval_at_loop: usize, +} + +/// A query plan is either a SELECT or a DELETE (for now) #[derive(Debug, Clone)] pub enum Plan { Select(SelectPlan), @@ -51,12 +72,13 @@ pub enum SelectQueryType { #[derive(Debug, Clone)] pub struct SelectPlan { - /// A tree of sources (tables). - pub source: SourceOperator, + /// List of table references in loop order, outermost first. + pub table_references: Vec, /// the columns inside SELECT ... FROM pub result_columns: Vec, - /// where clause split into a vec at 'AND' boundaries. - pub where_clause: Option>, + /// where clause split into a vec at 'AND' boundaries. 
all join conditions also get shoved in here, + /// and we keep track of which join they came from (mainly for OUTER JOIN processing) + pub where_clause: Vec, /// group by clause pub group_by: Option, /// order by clause @@ -64,11 +86,9 @@ pub struct SelectPlan { /// all the aggregates collected from the result columns, order by, and (TODO) having clauses pub aggregates: Vec, /// limit clause - pub limit: Option, - /// all the tables referenced in the query - pub referenced_tables: Vec, - /// all the indexes available - pub available_indexes: Vec>, + pub limit: Option, + /// offset clause + pub offset: Option, /// query contains a constant condition that is always false pub contains_constant_false_condition: bool, /// query type (top level or subquery) @@ -78,207 +98,148 @@ pub struct SelectPlan { #[allow(dead_code)] #[derive(Debug, Clone)] pub struct DeletePlan { - /// A tree of sources (tables). - pub source: SourceOperator, + /// List of table references. Delete is always a single table. + pub table_references: Vec, /// the columns inside SELECT ... FROM pub result_columns: Vec, /// where clause split into a vec at 'AND' boundaries. 
- pub where_clause: Option>, + pub where_clause: Vec, /// order by clause pub order_by: Option>, /// limit clause - pub limit: Option, - /// all the tables referenced in the query - pub referenced_tables: Vec, - /// all the indexes available - pub available_indexes: Vec>, + pub limit: Option, + /// offset clause + pub offset: Option, /// query contains a constant condition that is always false pub contains_constant_false_condition: bool, } -impl Display for Plan { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - match self { - Select(select_plan) => write!(f, "{}", select_plan.source), - Delete(delete_plan) => write!(f, "{}", delete_plan.source), - } - } -} - #[derive(Clone, Debug, PartialEq, Eq)] pub enum IterationDirection { Forwards, Backwards, } -impl SourceOperator { - pub fn select_star(&self, out_columns: &mut Vec) { - for (table_index, col, idx) in self.select_star_helper() { - out_columns.push(ResultSetColumn { - name: col.name.clone(), - expr: ast::Expr::Column { - database: None, - table: table_index, - column: idx, - is_rowid_alias: col.is_rowid_alias, - }, - contains_aggregates: false, - }); - } - } - - /// All this ceremony is required to deduplicate columns when joining with USING - fn select_star_helper(&self) -> Vec<(usize, &Column, usize)> { - match self { - SourceOperator::Join { - left, right, using, .. - } => { - let mut columns = left.select_star_helper(); - - // Join columns are filtered out from the right side - // in the case of a USING join. - if let Some(using_cols) = using { - let right_columns = right.select_star_helper(); - - for (table_index, col, idx) in right_columns { - if !using_cols - .iter() - .any(|using_col| col.name.eq_ignore_ascii_case(&using_col.0)) - { - columns.push((table_index, col, idx)); - } - } - } else { - columns.extend(right.select_star_helper()); - } - columns - } - SourceOperator::Scan { - table_reference, .. - } - | SourceOperator::Search { - table_reference, .. 
- } - | SourceOperator::Subquery { - table_reference, .. - } => table_reference +pub fn select_star(tables: &[TableReference], out_columns: &mut Vec) { + for (current_table_index, table) in tables.iter().enumerate() { + let maybe_using_cols = table + .join_info + .as_ref() + .and_then(|join_info| join_info.using.as_ref()); + out_columns.extend( + table .columns() .iter() .enumerate() - .map(|(i, col)| (table_reference.table_index, col, i)) - .collect(), - SourceOperator::Nothing { .. } => Vec::new(), - } + .filter(|(_, col)| { + // If we are joining with USING, we need to deduplicate the columns from the right table + // that are also present in the USING clause. + if let Some(using_cols) = maybe_using_cols { + !using_cols + .iter() + .any(|using_col| col.name.eq_ignore_ascii_case(&using_col.0)) + } else { + true + } + }) + .map(|(i, col)| ResultSetColumn { + name: col.name.clone(), + expr: ast::Expr::Column { + database: None, + table: current_table_index, + column: i, + is_rowid_alias: col.is_rowid_alias, + }, + contains_aggregates: false, + }), + ); } } +/// Join information for a table reference. +#[derive(Debug, Clone)] +pub struct JoinInfo { + /// Whether this is an OUTER JOIN. + pub outer: bool, + /// The USING clause for the join, if any. NATURAL JOIN is transformed into USING (col1, col2, ...). + pub using: Option, +} + +/// A table reference in the query plan. +/// For example, SELECT * FROM users u JOIN products p JOIN (SELECT * FROM users) sub +/// has three table references: +/// 1. op=Scan, table=users, identifier=u, join_info=None +/// 2. op=Scan, table=products, identifier=p, join_info=Some(JoinInfo { outer: false, using: None }), +/// 3. op=Subquery, table=users, identifier=sub, join_info=None +#[derive(Debug, Clone)] +pub struct TableReference { + /// The operation that this table reference performs.
+ pub op: Operation, + /// Table object, which contains metadata about the table, e.g. columns. + pub table: Table, + /// The name of the table as referred to in the query, either the literal name or an alias e.g. "users" or "u" + pub identifier: String, + /// The join info for this table reference, if it is the right side of a join (which all except the first table reference have) + pub join_info: Option, +} + /** - A SourceOperator is a Node in the query plan that reads data from a table. + An Operation specifies how a table reference in the query plan is read: by scan, by search (rowid or index), or by subquery. */ #[derive(Clone, Debug)] -pub enum SourceOperator { - // Join operator - // This operator is used to join two source operators. - // It takes a left and right source operator, a list of predicates to evaluate, - // and a boolean indicating whether it is an outer join. - Join { - id: usize, - left: Box, - right: Box, - predicates: Option>, - outer: bool, - using: Option, - }, - // Scan operator - // This operator is used to scan a table. - // It takes a table to scan and an optional list of predicates to evaluate. - // The predicates are used to filter rows from the table. - // e.g. SELECT * FROM t1 WHERE t1.foo = 5 +pub enum Operation { + // Scan operation + // This operation is used to scan a table. // The iter_dir are uset to indicate the direction of the iterator. // The use of Option for iter_dir is aimed at implementing a conservative optimization strategy: it only pushes // iter_dir down to Scan when iter_dir is None, to prevent potential result set errors caused by multiple - // assignments.
for more detailed discussions, please refer to https://github.com/tursodatabase/limbo/pull/376 Scan { - id: usize, - table_reference: TableReference, - predicates: Option>, iter_dir: Option, }, - // Search operator - // This operator is used to search for a row in a table using an index + // Search operation + // This operation is used to search for a row in a table using an index // (i.e. a primary key or a secondary index) - Search { - id: usize, - table_reference: TableReference, - search: Search, - predicates: Option>, - }, + Search(Search), + /// Subquery operation + /// This operation is used to represent a subquery in the query plan. + /// The subquery itself (recursively) contains an arbitrary SelectPlan. Subquery { - id: usize, - table_reference: TableReference, plan: Box, - predicates: Option>, - }, - // Nothing operator - // This operator is used to represent an empty query. - // e.g. SELECT * from foo WHERE 0 will eventually be optimized to Nothing. - Nothing { - id: usize, - }, -} - -/// The type of the table reference, either BTreeTable or Subquery -#[derive(Clone, Debug, PartialEq, Eq)] -pub enum TableReferenceType { - /// A BTreeTable is a table that is stored on disk in a B-tree index. - BTreeTable, - /// A subquery. - Subquery { - /// The index of the first register in the query plan that contains the result columns of the subquery. result_columns_start_reg: usize, }, } -/// A query plan has a list of TableReference objects, each of which represents a table or subquery. -#[derive(Clone, Debug)] -pub struct TableReference { - /// Table object, which contains metadata about the table, e.g. columns. - pub table: Table, - /// The name of the table as referred to in the query, either the literal name or an alias e.g. 
"users" or "u" - pub table_identifier: String, - /// The index of this reference in the list of TableReference objects in the query plan - /// The reference at index 0 is the first table in the FROM clause, the reference at index 1 is the second table in the FROM clause, etc. - /// So, the index is relevant for determining when predicates (WHERE, ON filters etc.) should be evaluated. - pub table_index: usize, - /// The type of the table reference, either BTreeTable or Subquery - pub reference_type: TableReferenceType, -} - impl TableReference { + /// Returns the btree table for this table reference, if it is a BTreeTable. pub fn btree(&self) -> Option> { - match self.reference_type { - TableReferenceType::BTreeTable => self.table.btree(), - TableReferenceType::Subquery { .. } => None, - } + self.table.btree() } - pub fn new_subquery(identifier: String, table_index: usize, plan: &SelectPlan) -> Self { + + /// Creates a new TableReference for a subquery. + pub fn new_subquery(identifier: String, plan: SelectPlan, join_info: Option) -> Self { + let table = Table::Pseudo(Rc::new(PseudoTable::new_with_columns( + plan.result_columns + .iter() + .map(|rc| Column { + name: rc.name.clone(), + ty: Type::Text, // FIXME: infer proper type + ty_str: "TEXT".to_string(), + is_rowid_alias: false, + primary_key: false, + notnull: false, + default: None, + }) + .collect(), + ))); Self { - table: Table::Pseudo(Rc::new(PseudoTable::new_with_columns( - plan.result_columns - .iter() - .map(|rc| Column { - name: rc.name.clone(), - ty: Type::Text, // FIXME: infer proper type - is_rowid_alias: false, - primary_key: false, - }) - .collect(), - ))), - table_identifier: identifier.clone(), - table_index, - reference_type: TableReferenceType::Subquery { + op: Operation::Subquery { + plan: Box::new(plan), result_columns_start_reg: 0, // Will be set in the bytecode emission phase }, + table, + identifier: identifier.clone(), + join_info, } } @@ -293,32 +254,20 @@ impl TableReference { 
#[derive(Clone, Debug)] pub enum Search { /// A rowid equality point lookup. This is a special case that uses the SeekRowid bytecode instruction and does not loop. - RowidEq { cmp_expr: ast::Expr }, + RowidEq { cmp_expr: WhereTerm }, /// A rowid search. Uses bytecode instructions like SeekGT, SeekGE etc. RowidSearch { cmp_op: ast::Operator, - cmp_expr: ast::Expr, + cmp_expr: WhereTerm, }, /// A secondary index search. Uses bytecode instructions like SeekGE, SeekGT etc. IndexSearch { index: Rc, cmp_op: ast::Operator, - cmp_expr: ast::Expr, + cmp_expr: WhereTerm, }, } -impl SourceOperator { - pub fn id(&self) -> usize { - match self { - SourceOperator::Join { id, .. } => *id, - SourceOperator::Scan { id, .. } => *id, - SourceOperator::Search { id, .. } => *id, - SourceOperator::Subquery { id, .. } => *id, - SourceOperator::Nothing { id } => *id, - } - } -} - #[derive(Clone, Copy, Debug, PartialEq)] pub enum Direction { Ascending, @@ -353,204 +302,98 @@ impl Display for Aggregate { } } -// For EXPLAIN QUERY PLAN -impl Display for SourceOperator { +/// For EXPLAIN QUERY PLAN +impl Display for Plan { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Select(select_plan) => select_plan.fmt(f), + Delete(delete_plan) => delete_plan.fmt(f), + } + } +} + +impl Display for SelectPlan { fn fmt(&self, f: &mut Formatter) -> fmt::Result { - fn fmt_operator( - operator: &SourceOperator, - f: &mut Formatter, - level: usize, - last: bool, - ) -> fmt::Result { - let indent = if level == 0 { - if last { "`--" } else { "|--" }.to_string() + writeln!(f, "QUERY PLAN")?; + + // Print each table reference with appropriate indentation based on join depth + for (i, reference) in self.table_references.iter().enumerate() { + let is_last = i == self.table_references.len() - 1; + let indent = if i == 0 { + if is_last { "`--" } else { "|--" }.to_string() } else { format!( " {}{}", - "| ".repeat(level - 1), - if last { "`--" } else { "|--" } + "| ".repeat(i - 1), + if is_last 
{ "`--" } else { "|--" } ) }; - match operator { - SourceOperator::Join { - left, - right, - predicates, - outer, - .. - } => { - let join_name = if *outer { "OUTER JOIN" } else { "JOIN" }; - match predicates - .as_ref() - .and_then(|ps| if ps.is_empty() { None } else { Some(ps) }) - { - Some(ps) => { - let predicates_string = ps - .iter() - .map(|p| p.to_string()) - .collect::>() - .join(" AND "); - writeln!(f, "{}{} ON {}", indent, join_name, predicates_string)?; - } - None => writeln!(f, "{}{}", indent, join_name)?, + match &reference.op { + Operation::Scan { .. } => { + let table_name = if reference.table.get_name() == reference.identifier { + reference.identifier.clone() + } else { + format!("{} AS {}", reference.table.get_name(), reference.identifier) + }; + + writeln!(f, "{}SCAN {}", indent, table_name)?; + } + Operation::Search(search) => match search { + Search::RowidEq { .. } | Search::RowidSearch { .. } => { + writeln!( + f, + "{}SEARCH {} USING INTEGER PRIMARY KEY (rowid=?)", + indent, reference.identifier + )?; } - fmt_operator(left, f, level + 1, false)?; - fmt_operator(right, f, level + 1, true) - } - SourceOperator::Scan { - table_reference, - predicates: filter, - .. - } => { - let table_name = - if table_reference.table.get_name() == table_reference.table_identifier { - table_reference.table_identifier.clone() - } else { - format!( - "{} AS {}", - &table_reference.table.get_name(), - &table_reference.table_identifier - ) - }; - let filter_string = filter.as_ref().map(|f| { - let filters_string = f - .iter() - .map(|p| p.to_string()) - .collect::>() - .join(" AND "); - format!("FILTER {}", filters_string) - }); - match filter_string { - Some(fs) => writeln!(f, "{}SCAN {} {}", indent, table_name, fs), - None => writeln!(f, "{}SCAN {}", indent, table_name), - }?; - Ok(()) - } - SourceOperator::Search { - table_reference, - search, - .. - } => { - match search { - Search::RowidEq { .. } | Search::RowidSearch { .. 
} => { - writeln!( - f, - "{}SEARCH {} USING INTEGER PRIMARY KEY (rowid=?)", - indent, table_reference.table_identifier - )?; - } - Search::IndexSearch { index, .. } => { - writeln!( - f, - "{}SEARCH {} USING INDEX {}", - indent, table_reference.table_identifier, index.name - )?; - } + Search::IndexSearch { index, .. } => { + writeln!( + f, + "{}SEARCH {} USING INDEX {}", + indent, reference.identifier, index.name + )?; + } + }, + Operation::Subquery { plan, .. } => { + writeln!(f, "{}SUBQUERY {}", indent, reference.identifier)?; + // Indent and format the subquery plan + for line in format!("{}", plan).lines() { + writeln!(f, "{} {}", indent, line)?; } - Ok(()) } - SourceOperator::Subquery { plan, .. } => { - fmt_operator(&plan.source, f, level + 1, last) - } - SourceOperator::Nothing { .. } => Ok(()), } } + Ok(()) + } +} + +impl Display for DeletePlan { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { writeln!(f, "QUERY PLAN")?; - fmt_operator(self, f, 0, true) - } -} -/** - Returns a bitmask where each bit corresponds to a table in the `tables` vector. - If a table is referenced in the given Operator, the corresponding bit is set to 1. - Example: - if tables = [(table1, "t1"), (table2, "t2"), (table3, "t3")], - and the Operator is a join between table2 and table3, - then the return value will be (in bits): 110 -*/ -pub fn get_table_ref_bitmask_for_operator<'a>( - tables: &'a Vec, - operator: &'a SourceOperator, -) -> Result { - let mut table_refs_mask = 0; - match operator { - SourceOperator::Join { left, right, .. } => { - table_refs_mask |= get_table_ref_bitmask_for_operator(tables, left)?; - table_refs_mask |= get_table_ref_bitmask_for_operator(tables, right)?; - } - SourceOperator::Scan { - table_reference, .. - } => { - table_refs_mask |= 1 - << tables - .iter() - .position(|t| t.table_identifier == table_reference.table_identifier) - .unwrap(); - } - SourceOperator::Search { - table_reference, .. 
- } => { - table_refs_mask |= 1 - << tables - .iter() - .position(|t| t.table_identifier == table_reference.table_identifier) - .unwrap(); - } - SourceOperator::Subquery { .. } => {} - SourceOperator::Nothing { .. } => {} - } - Ok(table_refs_mask) -} + // Delete plan should only have one table reference + if let Some(reference) = self.table_references.first() { + let indent = "`--"; -/** - Returns a bitmask where each bit corresponds to a table in the `tables` vector. - If a table is referenced in the given AST expression, the corresponding bit is set to 1. - Example: - if tables = [(table1, "t1"), (table2, "t2"), (table3, "t3")], - and predicate = "t1.a = t2.b" - then the return value will be (in bits): 011 -*/ -#[allow(clippy::only_used_in_recursion)] -pub fn get_table_ref_bitmask_for_ast_expr<'a>( - tables: &'a Vec, - predicate: &'a ast::Expr, -) -> Result { - let mut table_refs_mask = 0; - match predicate { - ast::Expr::Binary(e1, _, e2) => { - table_refs_mask |= get_table_ref_bitmask_for_ast_expr(tables, e1)?; - table_refs_mask |= get_table_ref_bitmask_for_ast_expr(tables, e2)?; - } - ast::Expr::Column { table, .. } => { - table_refs_mask |= 1 << table; - } - ast::Expr::Id(_) => unreachable!("Id should be resolved to a Column before optimizer"), - ast::Expr::Qualified(_, _) => { - unreachable!("Qualified should be resolved to a Column before optimizer") - } - ast::Expr::Literal(_) => {} - ast::Expr::Like { lhs, rhs, .. } => { - table_refs_mask |= get_table_ref_bitmask_for_ast_expr(tables, lhs)?; - table_refs_mask |= get_table_ref_bitmask_for_ast_expr(tables, rhs)?; - } - ast::Expr::FunctionCall { - args: Some(args), .. - } => { - for arg in args { - table_refs_mask |= get_table_ref_bitmask_for_ast_expr(tables, arg)?; - } - } - ast::Expr::InList { lhs, rhs, .. 
} => { - table_refs_mask |= get_table_ref_bitmask_for_ast_expr(tables, lhs)?; - if let Some(rhs_list) = rhs { - for rhs_expr in rhs_list { - table_refs_mask |= get_table_ref_bitmask_for_ast_expr(tables, rhs_expr)?; + match &reference.op { + Operation::Scan { .. } => { + let table_name = if reference.table.get_name() == reference.identifier { + reference.identifier.clone() + } else { + format!("{} AS {}", reference.table.get_name(), reference.identifier) + }; + + writeln!(f, "{}DELETE FROM {}", indent, table_name)?; + } + Operation::Search { .. } => { + panic!("DELETE plans should not contain search operations"); + } + Operation::Subquery { .. } => { + panic!("DELETE plans should not contain subqueries"); } } } - _ => {} + Ok(()) } - - Ok(table_refs_mask) } diff --git a/core/translate/planner.rs b/core/translate/planner.rs index 86132fb60..1dc44b8a8 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -1,5 +1,8 @@ use super::{ - plan::{Aggregate, Plan, SelectQueryType, SourceOperator, TableReference, TableReferenceType}, + plan::{ + Aggregate, JoinInfo, Operation, Plan, ResultSetColumn, SelectQueryType, TableReference, + WhereTerm, + }, select::prepare_select_plan, SymbolTable, }; @@ -10,25 +13,10 @@ use crate::{ vdbe::BranchOffset, Result, }; -use sqlite3_parser::ast::{self, Expr, FromClause, JoinType, Limit}; +use sqlite3_parser::ast::{self, Expr, FromClause, JoinType, Limit, UnaryOperator}; pub const ROWID: &str = "rowid"; -pub struct OperatorIdCounter { - id: usize, -} - -impl OperatorIdCounter { - pub fn new() -> Self { - Self { id: 1 } - } - pub fn get_next_id(&mut self) -> usize { - let id = self.id; - self.id += 1; - id - } -} - pub fn resolve_aggregates(expr: &Expr, aggs: &mut Vec) -> bool { if aggs .iter() @@ -93,7 +81,11 @@ pub fn resolve_aggregates(expr: &Expr, aggs: &mut Vec) -> bool { } } -pub fn bind_column_references(expr: &mut Expr, referenced_tables: &[TableReference]) -> Result<()> { +pub fn bind_column_references( + expr: 
&mut Expr, + referenced_tables: &[TableReference], + result_columns: Option<&[ResultSetColumn]>, +) -> Result<()> { match expr { Expr::Id(id) => { // true and false are special constants that are effectively aliases for 1 and 0 @@ -126,24 +118,31 @@ pub fn bind_column_references(expr: &mut Expr, referenced_tables: &[TableReferen match_result = Some((tbl_idx, col_idx.unwrap(), col.is_rowid_alias)); } } - if match_result.is_none() { - crate::bail_parse_error!("Column {} not found", id.0); + if let Some((tbl_idx, col_idx, is_rowid_alias)) = match_result { + *expr = Expr::Column { + database: None, // TODO: support different databases + table: tbl_idx, + column: col_idx, + is_rowid_alias, + }; + return Ok(()); } - let (tbl_idx, col_idx, is_rowid_alias) = match_result.unwrap(); - *expr = Expr::Column { - database: None, // TODO: support different databases - table: tbl_idx, - column: col_idx, - is_rowid_alias, - }; - Ok(()) + + if let Some(result_columns) = result_columns { + for result_column in result_columns.iter() { + if result_column.name == normalized_id { + *expr = result_column.expr.clone(); + return Ok(()); + } + } + } + crate::bail_parse_error!("Column {} not found", id.0); } Expr::Qualified(tbl, id) => { let normalized_table_name = normalize_ident(tbl.0.as_str()); - let matching_tbl_idx = referenced_tables.iter().position(|t| { - t.table_identifier - .eq_ignore_ascii_case(&normalized_table_name) - }); + let matching_tbl_idx = referenced_tables + .iter() + .position(|t| t.identifier.eq_ignore_ascii_case(&normalized_table_name)); if matching_tbl_idx.is_none() { crate::bail_parse_error!("Table {} not found", normalized_table_name); } @@ -180,14 +179,14 @@ pub fn bind_column_references(expr: &mut Expr, referenced_tables: &[TableReferen start, end, } => { - bind_column_references(lhs, referenced_tables)?; - bind_column_references(start, referenced_tables)?; - bind_column_references(end, referenced_tables)?; + bind_column_references(lhs, referenced_tables, 
result_columns)?; + bind_column_references(start, referenced_tables, result_columns)?; + bind_column_references(end, referenced_tables, result_columns)?; Ok(()) } Expr::Binary(expr, _operator, expr1) => { - bind_column_references(expr, referenced_tables)?; - bind_column_references(expr1, referenced_tables)?; + bind_column_references(expr, referenced_tables, result_columns)?; + bind_column_references(expr1, referenced_tables, result_columns)?; Ok(()) } Expr::Case { @@ -196,19 +195,23 @@ pub fn bind_column_references(expr: &mut Expr, referenced_tables: &[TableReferen else_expr, } => { if let Some(base) = base { - bind_column_references(base, referenced_tables)?; + bind_column_references(base, referenced_tables, result_columns)?; } for (when, then) in when_then_pairs { - bind_column_references(when, referenced_tables)?; - bind_column_references(then, referenced_tables)?; + bind_column_references(when, referenced_tables, result_columns)?; + bind_column_references(then, referenced_tables, result_columns)?; } if let Some(else_expr) = else_expr { - bind_column_references(else_expr, referenced_tables)?; + bind_column_references(else_expr, referenced_tables, result_columns)?; } Ok(()) } - Expr::Cast { expr, type_name: _ } => bind_column_references(expr, referenced_tables), - Expr::Collate(expr, _string) => bind_column_references(expr, referenced_tables), + Expr::Cast { expr, type_name: _ } => { + bind_column_references(expr, referenced_tables, result_columns) + } + Expr::Collate(expr, _string) => { + bind_column_references(expr, referenced_tables, result_columns) + } Expr::FunctionCall { name: _, distinctness: _, @@ -218,7 +221,7 @@ pub fn bind_column_references(expr: &mut Expr, referenced_tables: &[TableReferen } => { if let Some(args) = args { for arg in args { - bind_column_references(arg, referenced_tables)?; + bind_column_references(arg, referenced_tables, result_columns)?; } } Ok(()) @@ -229,10 +232,10 @@ pub fn bind_column_references(expr: &mut Expr, 
referenced_tables: &[TableReferen Expr::Exists(_) => todo!(), Expr::FunctionCallStar { .. } => Ok(()), Expr::InList { lhs, not: _, rhs } => { - bind_column_references(lhs, referenced_tables)?; + bind_column_references(lhs, referenced_tables, result_columns)?; if let Some(rhs) = rhs { for arg in rhs { - bind_column_references(arg, referenced_tables)?; + bind_column_references(arg, referenced_tables, result_columns)?; } } Ok(()) @@ -240,30 +243,30 @@ pub fn bind_column_references(expr: &mut Expr, referenced_tables: &[TableReferen Expr::InSelect { .. } => todo!(), Expr::InTable { .. } => todo!(), Expr::IsNull(expr) => { - bind_column_references(expr, referenced_tables)?; + bind_column_references(expr, referenced_tables, result_columns)?; Ok(()) } Expr::Like { lhs, rhs, .. } => { - bind_column_references(lhs, referenced_tables)?; - bind_column_references(rhs, referenced_tables)?; + bind_column_references(lhs, referenced_tables, result_columns)?; + bind_column_references(rhs, referenced_tables, result_columns)?; Ok(()) } Expr::Literal(_) => Ok(()), Expr::Name(_) => todo!(), Expr::NotNull(expr) => { - bind_column_references(expr, referenced_tables)?; + bind_column_references(expr, referenced_tables, result_columns)?; Ok(()) } Expr::Parenthesized(expr) => { for e in expr.iter_mut() { - bind_column_references(e, referenced_tables)?; + bind_column_references(e, referenced_tables, result_columns)?; } Ok(()) } Expr::Raise(_, _) => todo!(), Expr::Subquery(_) => todo!(), Expr::Unary(_, expr) => { - bind_column_references(expr, referenced_tables)?; + bind_column_references(expr, referenced_tables, result_columns)?; Ok(()) } Expr::Variable(_) => Ok(()), @@ -273,10 +276,9 @@ pub fn bind_column_references(expr: &mut Expr, referenced_tables: &[TableReferen fn parse_from_clause_table( schema: &Schema, table: ast::SelectTable, - operator_id_counter: &mut OperatorIdCounter, cur_table_index: usize, syms: &SymbolTable, -) -> Result<(TableReference, SourceOperator)> { +) -> Result { match 
table { ast::SelectTable::Table(qualified_name, maybe_alias, _) => { let normalized_qualified_name = normalize_ident(qualified_name.name.0.as_str()); @@ -289,21 +291,12 @@ fn parse_from_clause_table( ast::As::Elided(id) => id, }) .map(|a| a.0); - let table_reference = TableReference { + Ok(TableReference { + op: Operation::Scan { iter_dir: None }, table: Table::BTree(table.clone()), - table_identifier: alias.unwrap_or(normalized_qualified_name), - table_index: cur_table_index, - reference_type: TableReferenceType::BTreeTable, - }; - Ok(( - table_reference.clone(), - SourceOperator::Scan { - table_reference, - predicates: None, - id: operator_id_counter.get_next_id(), - iter_dir: None, - }, - )) + identifier: alias.unwrap_or(normalized_qualified_name), + join_info: None, + }) } ast::SelectTable::Select(subselect, maybe_alias) => { let Plan::Select(mut subplan) = prepare_select_plan(schema, *subselect, syms)? else { @@ -319,17 +312,8 @@ fn parse_from_clause_table( ast::As::Elided(id) => id.0.clone(), }) .unwrap_or(format!("subquery_{}", cur_table_index)); - let table_reference = - TableReference::new_subquery(identifier.clone(), cur_table_index, &subplan); - Ok(( - table_reference.clone(), - SourceOperator::Subquery { - id: operator_id_counter.get_next_id(), - table_reference, - plan: Box::new(subplan), - predicates: None, - }, - )) + let table_reference = TableReference::new_subquery(identifier, subplan, None); + Ok(table_reference) } _ => todo!(), } @@ -338,99 +322,125 @@ fn parse_from_clause_table( pub fn parse_from( schema: &Schema, mut from: Option, - operator_id_counter: &mut OperatorIdCounter, syms: &SymbolTable, -) -> Result<(SourceOperator, Vec)> { + out_where_clause: &mut Vec, +) -> Result> { if from.as_ref().and_then(|f| f.select.as_ref()).is_none() { - return Ok(( - SourceOperator::Nothing { - id: operator_id_counter.get_next_id(), - }, - vec![], - )); + return Ok(vec![]); } - let mut table_index = 0; let mut tables = vec![]; let mut from_owned = 
std::mem::take(&mut from).unwrap(); let select_owned = *std::mem::take(&mut from_owned.select).unwrap(); let joins_owned = std::mem::take(&mut from_owned.joins).unwrap_or_default(); - let (table_reference, mut operator) = - parse_from_clause_table(schema, select_owned, operator_id_counter, table_index, syms)?; - + let table_reference = parse_from_clause_table(schema, select_owned, 0, syms)?; tables.push(table_reference); - table_index += 1; for join in joins_owned.into_iter() { - let JoinParseResult { - source_operator: right, - is_outer_join: outer, - using, - predicates, - } = parse_join( - schema, - join, - operator_id_counter, - &mut tables, - table_index, - syms, - )?; - operator = SourceOperator::Join { - left: Box::new(operator), - right: Box::new(right), - predicates, - outer, - using, - id: operator_id_counter.get_next_id(), - }; - table_index += 1; + parse_join(schema, join, syms, &mut tables, out_where_clause)?; } - Ok((operator, tables)) + Ok(tables) } pub fn parse_where( where_clause: Option, - referenced_tables: &[TableReference], -) -> Result>> { + table_references: &[TableReference], + result_columns: Option<&[ResultSetColumn]>, + out_where_clause: &mut Vec, +) -> Result<()> { if let Some(where_expr) = where_clause { let mut predicates = vec![]; break_predicate_at_and_boundaries(where_expr, &mut predicates); for expr in predicates.iter_mut() { - bind_column_references(expr, referenced_tables)?; + bind_column_references(expr, table_references, result_columns)?; } - Ok(Some(predicates)) + for expr in predicates { + let eval_at_loop = get_rightmost_table_referenced_in_expr(&expr)?; + out_where_clause.push(WhereTerm { + expr, + from_outer_join: false, + eval_at_loop, + }); + } + Ok(()) } else { - Ok(None) + Ok(()) } } -struct JoinParseResult { - source_operator: SourceOperator, - is_outer_join: bool, - using: Option, - predicates: Option>, +/** + Returns the rightmost table index that is referenced in the given AST expression. 
+ Rightmost = innermost loop. + This is used to determine where we should evaluate a given condition expression, + and it needs to be the rightmost table referenced in the expression, because otherwise + the condition would be evaluated before a row is read from that table. +*/ +fn get_rightmost_table_referenced_in_expr<'a>(predicate: &'a ast::Expr) -> Result { + let mut max_table_idx = 0; + match predicate { + ast::Expr::Binary(e1, _, e2) => { + max_table_idx = max_table_idx.max(get_rightmost_table_referenced_in_expr(e1)?); + max_table_idx = max_table_idx.max(get_rightmost_table_referenced_in_expr(e2)?); + } + ast::Expr::Column { table, .. } => { + max_table_idx = max_table_idx.max(*table); + } + ast::Expr::Id(_) => { + /* Id referring to column will already have been rewritten as an Expr::Column */ + /* we only get here with literal 'true' or 'false' etc */ + } + ast::Expr::Qualified(_, _) => { + unreachable!("Qualified should be resolved to a Column before optimizer") + } + ast::Expr::Literal(_) => {} + ast::Expr::Like { lhs, rhs, .. } => { + max_table_idx = max_table_idx.max(get_rightmost_table_referenced_in_expr(lhs)?); + max_table_idx = max_table_idx.max(get_rightmost_table_referenced_in_expr(rhs)?); + } + ast::Expr::FunctionCall { + args: Some(args), .. + } => { + for arg in args { + max_table_idx = max_table_idx.max(get_rightmost_table_referenced_in_expr(arg)?); + } + } + ast::Expr::InList { lhs, rhs, .. 
} => { + max_table_idx = max_table_idx.max(get_rightmost_table_referenced_in_expr(lhs)?); + if let Some(rhs_list) = rhs { + for rhs_expr in rhs_list { + max_table_idx = + max_table_idx.max(get_rightmost_table_referenced_in_expr(rhs_expr)?); + } + } + } + _ => {} + } + + Ok(max_table_idx) } fn parse_join( schema: &Schema, join: ast::JoinedSelectTable, - operator_id_counter: &mut OperatorIdCounter, - tables: &mut Vec, - table_index: usize, syms: &SymbolTable, -) -> Result { + tables: &mut Vec, + out_where_clause: &mut Vec, +) -> Result<()> { let ast::JoinedSelectTable { operator: join_operator, table, constraint, } = join; - let (table_reference, source_operator) = - parse_from_clause_table(schema, table, operator_id_counter, table_index, syms)?; - - tables.push(table_reference); + let cur_table_index = tables.len(); + tables.push(parse_from_clause_table( + schema, + table, + cur_table_index, + syms, + )?); let (outer, natural) = match join_operator { ast::JoinOperator::TypedJoin(Some(join_type)) => { @@ -442,23 +452,21 @@ fn parse_join( }; let mut using = None; - let mut predicates = None; if natural && constraint.is_some() { crate::bail_parse_error!("NATURAL JOIN cannot be combined with ON or USING clause"); } let constraint = if natural { + assert!(tables.len() >= 2); + let rightmost_table = tables.last().unwrap(); // NATURAL JOIN is first transformed into a USING join with the common columns - let left_tables = &tables[..table_index]; - assert!(!left_tables.is_empty()); - let right_table = &tables[table_index]; - let right_cols = &right_table.columns(); + let right_cols = rightmost_table.columns(); let mut distinct_names: Option = None; // TODO: O(n^2) maybe not great for large tables or big multiway joins for right_col in right_cols.iter() { let mut found_match = false; - for left_table in left_tables.iter() { + for left_table in tables.iter().take(tables.len() - 1) { for left_col in left_table.columns().iter() { if left_col.name == right_col.name { if let 
Some(distinct_names) = distinct_names.as_mut() { @@ -493,18 +501,30 @@ fn parse_join( let mut preds = vec![]; break_predicate_at_and_boundaries(expr, &mut preds); for predicate in preds.iter_mut() { - bind_column_references(predicate, tables)?; + bind_column_references(predicate, tables, None)?; + } + for pred in preds { + let cur_table_idx = tables.len() - 1; + let eval_at_loop = if outer { + cur_table_idx + } else { + get_rightmost_table_referenced_in_expr(&pred)? + }; + out_where_clause.push(WhereTerm { + expr: pred, + from_outer_join: outer, + eval_at_loop, + }); } - predicates = Some(preds); } ast::JoinConstraint::Using(distinct_names) => { // USING join is replaced with a list of equality predicates - let mut using_predicates = vec![]; for distinct_name in distinct_names.iter() { let name_normalized = normalize_ident(distinct_name.0.as_str()); - let left_tables = &tables[..table_index]; + let cur_table_idx = tables.len() - 1; + let left_tables = &tables[..cur_table_idx]; assert!(!left_tables.is_empty()); - let right_table = &tables[table_index]; + let right_table = tables.last().unwrap(); let mut left_col = None; for (left_table_idx, left_table) in left_tables.iter().enumerate() { left_col = left_table @@ -536,7 +556,7 @@ fn parse_join( } let (left_table_idx, left_col_idx, left_col) = left_col.unwrap(); let (right_col_idx, right_col) = right_col.unwrap(); - using_predicates.push(Expr::Binary( + let expr = Expr::Binary( Box::new(Expr::Column { database: None, table: left_table_idx, @@ -546,39 +566,68 @@ fn parse_join( ast::Operator::Equals, Box::new(Expr::Column { database: None, - table: right_table.table_index, + table: cur_table_idx, column: right_col_idx, is_rowid_alias: right_col.is_rowid_alias, }), - )); + ); + let eval_at_loop = if outer { + cur_table_idx + } else { + get_rightmost_table_referenced_in_expr(&expr)? 
+ }; + out_where_clause.push(WhereTerm { + expr, + from_outer_join: outer, + eval_at_loop, + }); } - predicates = Some(using_predicates); using = Some(distinct_names); } } } - Ok(JoinParseResult { - source_operator, - is_outer_join: outer, - using, - predicates, - }) + assert!(tables.len() >= 2); + let last_idx = tables.len() - 1; + let rightmost_table = tables.get_mut(last_idx).unwrap(); + rightmost_table.join_info = Some(JoinInfo { outer, using }); + + Ok(()) } -pub fn parse_limit(limit: Limit) -> Option { +pub fn parse_limit(limit: Limit) -> Result<(Option, Option)> { + let offset_val = match limit.offset { + Some(offset_expr) => match offset_expr { + Expr::Literal(ast::Literal::Numeric(n)) => n.parse().ok(), + // If OFFSET is negative, the result is as if OFFSET is zero + Expr::Unary(UnaryOperator::Negative, expr) => match *expr { + Expr::Literal(ast::Literal::Numeric(n)) => n.parse::().ok().map(|num| -num), + _ => crate::bail_parse_error!("Invalid OFFSET clause"), + }, + _ => crate::bail_parse_error!("Invalid OFFSET clause"), + }, + None => Some(0), + }; + if let Expr::Literal(ast::Literal::Numeric(n)) = limit.expr { - n.parse().ok() + Ok((n.parse().ok(), offset_val)) + } else if let Expr::Unary(UnaryOperator::Negative, expr) = limit.expr { + if let Expr::Literal(ast::Literal::Numeric(n)) = *expr { + let limit_val = n.parse::().ok().map(|num| -num); + Ok((limit_val, offset_val)) + } else { + crate::bail_parse_error!("Invalid LIMIT clause"); + } } else if let Expr::Id(id) = limit.expr { if id.0.eq_ignore_ascii_case("true") { - Some(1) + Ok((Some(1), offset_val)) } else if id.0.eq_ignore_ascii_case("false") { - Some(0) + Ok((Some(0), offset_val)) } else { - None + crate::bail_parse_error!("Invalid LIMIT clause"); } } else { - None + crate::bail_parse_error!("Invalid LIMIT clause"); } } diff --git a/core/translate/pragma.rs b/core/translate/pragma.rs new file mode 100644 index 000000000..df0041331 --- /dev/null +++ b/core/translate/pragma.rs @@ -0,0 +1,269 @@ 
+//! VDBE bytecode generation for pragma statements. +//! More info: https://www.sqlite.org/pragma.html. + +use sqlite3_parser::ast; +use sqlite3_parser::ast::PragmaName; +use std::cell::RefCell; +use std::rc::Rc; + +use crate::schema::Schema; +use crate::storage::sqlite3_ondisk::{DatabaseHeader, MIN_PAGE_CACHE_SIZE}; +use crate::storage::wal::CheckpointMode; +use crate::util::normalize_ident; +use crate::vdbe::builder::ProgramBuilder; +use crate::vdbe::insn::Insn; +use crate::vdbe::BranchOffset; +use crate::{bail_parse_error, Pager}; +use std::str::FromStr; +use strum::IntoEnumIterator; + +fn list_pragmas( + program: &mut ProgramBuilder, + init_label: BranchOffset, + start_offset: BranchOffset, +) { + for x in PragmaName::iter() { + let register = program.emit_string8_new_reg(x.to_string()); + program.emit_result_row(register, 1); + } + + program.emit_halt(); + program.resolve_label(init_label, program.offset()); + program.emit_constant_insns(); + program.emit_goto(start_offset); +} + +pub fn translate_pragma( + program: &mut ProgramBuilder, + schema: &Schema, + name: &ast::QualifiedName, + body: Option, + database_header: Rc>, + pager: Rc, +) -> crate::Result<()> { + let init_label = program.emit_init(); + let start_offset = program.offset(); + let mut write = false; + + if name.name.0.to_lowercase() == "pragma_list" { + list_pragmas(program, init_label, start_offset); + return Ok(()); + } + + let pragma = match PragmaName::from_str(&name.name.0) { + Ok(pragma) => pragma, + Err(_) => bail_parse_error!("Not a valid pragma name"), + }; + + match body { + None => { + query_pragma(pragma, schema, None, database_header.clone(), program)?; + } + Some(ast::PragmaBody::Equals(value)) => match pragma { + PragmaName::TableInfo => { + query_pragma( + pragma, + schema, + Some(value), + database_header.clone(), + program, + )?; + } + _ => { + write = true; + update_pragma( + pragma, + schema, + value, + database_header.clone(), + pager, + program, + )?; + } + }, + 
Some(ast::PragmaBody::Call(value)) => match pragma { + PragmaName::TableInfo => { + query_pragma( + pragma, + schema, + Some(value), + database_header.clone(), + program, + )?; + } + _ => { + todo!() + } + }, + }; + program.emit_halt(); + program.resolve_label(init_label, program.offset()); + program.emit_transaction(write); + program.emit_constant_insns(); + program.emit_goto(start_offset); + + Ok(()) +} + +fn update_pragma( + pragma: PragmaName, + schema: &Schema, + value: ast::Expr, + header: Rc>, + pager: Rc, + program: &mut ProgramBuilder, +) -> crate::Result<()> { + match pragma { + PragmaName::CacheSize => { + let cache_size = match value { + ast::Expr::Literal(ast::Literal::Numeric(numeric_value)) => { + numeric_value.parse::()? + } + ast::Expr::Unary(ast::UnaryOperator::Negative, expr) => match *expr { + ast::Expr::Literal(ast::Literal::Numeric(numeric_value)) => { + -numeric_value.parse::()? + } + _ => bail_parse_error!("Not a valid value"), + }, + _ => bail_parse_error!("Not a valid value"), + }; + update_cache_size(cache_size, header, pager); + Ok(()) + } + PragmaName::JournalMode => { + query_pragma(PragmaName::JournalMode, schema, None, header, program)?; + Ok(()) + } + PragmaName::WalCheckpoint => { + query_pragma(PragmaName::WalCheckpoint, schema, None, header, program)?; + Ok(()) + } + PragmaName::PageCount => { + query_pragma(PragmaName::PageCount, schema, None, header, program)?; + Ok(()) + } + PragmaName::TableInfo => { + // because we need control over the write parameter for the transaction, + // this should be unreachable. 
We have to force-call query_pragma before + // getting here + unreachable!(); + } + } +} + +fn query_pragma( + pragma: PragmaName, + schema: &Schema, + value: Option, + database_header: Rc>, + program: &mut ProgramBuilder, +) -> crate::Result<()> { + let register = program.alloc_register(); + match pragma { + PragmaName::CacheSize => { + program.emit_int( + database_header.borrow().default_page_cache_size.into(), + register, + ); + program.emit_result_row(register, 1); + } + PragmaName::JournalMode => { + program.emit_string8("wal".into(), register); + program.emit_result_row(register, 1); + } + PragmaName::WalCheckpoint => { + // Checkpoint uses 3 registers: P1, P2, P3. Ref Insn::Checkpoint for more info. + // Allocate two more here as one was allocated at the top. + program.alloc_register(); + program.alloc_register(); + program.emit_insn(Insn::Checkpoint { + database: 0, + checkpoint_mode: CheckpointMode::Passive, + dest: register, + }); + program.emit_result_row(register, 3); + } + PragmaName::PageCount => { + program.emit_insn(Insn::PageCount { + db: 0, + dest: register, + }); + program.emit_result_row(register, 1); + } + PragmaName::TableInfo => { + let table = match value { + Some(ast::Expr::Name(name)) => { + let tbl = normalize_ident(&name.0); + schema.get_table(&tbl) + } + _ => None, + }; + + let base_reg = register; + program.alloc_register(); + program.alloc_register(); + program.alloc_register(); + program.alloc_register(); + program.alloc_register(); + if let Some(table) = table { + for (i, column) in table.columns.iter().enumerate() { + // cid + program.emit_int(i as i64, base_reg); + // name + program.emit_string8(column.name.clone(), base_reg + 1); + + // type + program.emit_string8(column.ty_str.clone(), base_reg + 2); + + // notnull + program.emit_bool(column.notnull, base_reg + 3); + + // dflt_value + match &column.default { + None => { + program.emit_null(base_reg + 4); + } + Some(expr) => { + program.emit_string8(expr.to_string(), base_reg + 
4); + } + } + + // pk + program.emit_bool(column.primary_key, base_reg + 5); + + program.emit_result_row(base_reg, 6); + } + } + } + } + + Ok(()) +} + +fn update_cache_size(value: i64, header: Rc>, pager: Rc) { + let mut cache_size_unformatted: i64 = value; + let mut cache_size = if cache_size_unformatted < 0 { + let kb = cache_size_unformatted.abs() * 1024; + kb / 512 // assume 512 page size for now + } else { + value + } as usize; + + if cache_size < MIN_PAGE_CACHE_SIZE { + // update both in memory and stored disk value + cache_size = MIN_PAGE_CACHE_SIZE; + cache_size_unformatted = MIN_PAGE_CACHE_SIZE as i64; + } + + // update in-memory header + header.borrow_mut().default_page_cache_size = cache_size_unformatted + .try_into() + .unwrap_or_else(|_| panic!("invalid value, too big for a i32 {}", value)); + + // update in disk + let header_copy = header.borrow().clone(); + pager.write_database_header(&header_copy); + + // update cache size + pager.change_page_cache_size(cache_size); +} diff --git a/core/translate/result_row.rs b/core/translate/result_row.rs index 5c76f3008..ad8454c25 100644 --- a/core/translate/result_row.rs +++ b/core/translate/result_row.rs @@ -18,13 +18,18 @@ pub fn emit_select_result( t_ctx: &mut TranslateCtx, plan: &SelectPlan, label_on_limit_reached: Option, + offset_jump_to: Option, ) -> Result<()> { + if let (Some(jump_to), Some(_)) = (offset_jump_to, label_on_limit_reached) { + emit_offset(program, t_ctx, plan, jump_to)?; + } + let start_reg = t_ctx.reg_result_cols_start.unwrap(); for (i, rc) in plan.result_columns.iter().enumerate() { let reg = start_reg + i; translate_expr( program, - Some(&plan.referenced_tables), + Some(&plan.table_references), &rc.expr, reg, &t_ctx.resolver, @@ -71,6 +76,22 @@ pub fn emit_result_row_and_limit( dest: t_ctx.reg_limit.unwrap(), }); program.mark_last_insn_constant(); + + if let Some(offset) = plan.offset { + program.emit_insn(Insn::Integer { + value: offset as i64, + dest: t_ctx.reg_offset.unwrap(), + }); 
+ program.mark_last_insn_constant(); + + program.emit_insn(Insn::OffsetLimit { + limit_reg: t_ctx.reg_limit.unwrap(), + combined_reg: t_ctx.reg_limit_offset_sum.unwrap(), + offset_reg: t_ctx.reg_offset.unwrap(), + }); + program.mark_last_insn_constant(); + } + program.emit_insn(Insn::DecrJumpZero { reg: t_ctx.reg_limit.unwrap(), target_pc: label_on_limit_reached.unwrap(), @@ -78,3 +99,23 @@ pub fn emit_result_row_and_limit( } Ok(()) } + +pub fn emit_offset( + program: &mut ProgramBuilder, + t_ctx: &mut TranslateCtx, + plan: &SelectPlan, + jump_to: BranchOffset, +) -> Result<()> { + match plan.offset { + Some(offset) if offset > 0 => { + program.add_comment(program.offset(), "OFFSET"); + program.emit_insn(Insn::IfPos { + reg: t_ctx.reg_offset.unwrap(), + target_pc: jump_to, + decrement_by: 1, + }); + } + _ => {} + } + Ok(()) +} diff --git a/core/translate/select.rs b/core/translate/select.rs index 196a1d284..b02179620 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -1,12 +1,12 @@ use super::emitter::emit_program; use super::expr::get_name; -use super::plan::SelectQueryType; +use super::plan::{select_star, SelectQueryType}; use crate::function::{AggFunc, ExtFunc, Func}; use crate::translate::optimizer::optimize_plan; use crate::translate::plan::{Aggregate, Direction, GroupBy, Plan, ResultSetColumn, SelectPlan}; use crate::translate::planner::{ bind_column_references, break_predicate_at_and_boundaries, parse_from, parse_limit, - parse_where, resolve_aggregates, OperatorIdCounter, + parse_where, resolve_aggregates, }; use crate::util::normalize_ident; use crate::SymbolTable; @@ -21,7 +21,7 @@ pub fn translate_select( syms: &SymbolTable, ) -> Result<()> { let mut select_plan = prepare_select_plan(schema, select, syms)?; - optimize_plan(&mut select_plan)?; + optimize_plan(&mut select_plan, schema)?; emit_program(program, select_plan, syms) } @@ -43,51 +43,68 @@ pub fn prepare_select_plan( crate::bail_parse_error!("SELECT without columns is not 
allowed"); } - let mut operator_id_counter = OperatorIdCounter::new(); + let mut where_predicates = vec![]; - // Parse the FROM clause - let (source, referenced_tables) = - parse_from(schema, from, &mut operator_id_counter, syms)?; + // Parse the FROM clause into a vec of TableReferences. Fold all the join conditions expressions into the WHERE clause. + let table_references = parse_from(schema, from, syms, &mut where_predicates)?; + + // Preallocate space for the result columns + let result_columns = Vec::with_capacity( + columns + .iter() + .map(|c| match c { + // Allocate space for all columns in all tables + ResultColumn::Star => { + table_references.iter().map(|t| t.columns().len()).sum() + } + // Guess 5 columns if we can't find the table using the identifier (maybe it's in [brackets] or `tick_quotes`, or miXeDcAse) + ResultColumn::TableStar(n) => table_references + .iter() + .find(|t| t.identifier == n.0) + .map(|t| t.columns().len()) + .unwrap_or(5), + // Otherwise allocate space for 1 column + ResultColumn::Expr(_, _) => 1, + }) + .sum(), + ); let mut plan = SelectPlan { - source, - result_columns: vec![], - where_clause: None, + table_references, + result_columns, + where_clause: where_predicates, group_by: None, order_by: None, aggregates: vec![], limit: None, - referenced_tables, - available_indexes: schema.indexes.clone().into_values().flatten().collect(), + offset: None, contains_constant_false_condition: false, query_type: SelectQueryType::TopLevel, }; - // Parse the WHERE clause - plan.where_clause = parse_where(where_clause, &plan.referenced_tables)?; - let mut aggregate_expressions = Vec::new(); for (result_column_idx, column) in columns.iter_mut().enumerate() { match column { ResultColumn::Star => { - plan.source.select_star(&mut plan.result_columns); + select_star(&plan.table_references, &mut plan.result_columns); } ResultColumn::TableStar(name) => { let name_normalized = normalize_ident(name.0.as_str()); let referenced_table = plan - 
.referenced_tables + .table_references .iter() - .find(|t| t.table_identifier == name_normalized); + .enumerate() + .find(|(_, t)| t.identifier == name_normalized); if referenced_table.is_none() { crate::bail_parse_error!("Table {} not found", name.0); } - let table_reference = referenced_table.unwrap(); - for (idx, col) in table_reference.columns().iter().enumerate() { + let (table_index, table) = referenced_table.unwrap(); + for (idx, col) in table.columns().iter().enumerate() { plan.result_columns.push(ResultSetColumn { expr: ast::Expr::Column { database: None, // TODO: support different databases - table: table_reference.table_index, + table: table_index, column: idx, is_rowid_alias: col.is_rowid_alias, }, @@ -97,7 +114,11 @@ pub fn prepare_select_plan( } } ResultColumn::Expr(ref mut expr, maybe_alias) => { - bind_column_references(expr, &plan.referenced_tables)?; + bind_column_references( + expr, + &plan.table_references, + Some(&plan.result_columns), + )?; match expr { ast::Expr::FunctionCall { name, @@ -140,7 +161,7 @@ pub fn prepare_select_plan( name: get_name( maybe_alias.as_ref(), expr, - &plan.referenced_tables, + &plan.table_references, || format!("expr_{}", result_column_idx), ), expr: expr.clone(), @@ -154,7 +175,7 @@ pub fn prepare_select_plan( name: get_name( maybe_alias.as_ref(), expr, - &plan.referenced_tables, + &plan.table_references, || format!("expr_{}", result_column_idx), ), expr: expr.clone(), @@ -173,7 +194,7 @@ pub fn prepare_select_plan( name: get_name( maybe_alias.as_ref(), expr, - &plan.referenced_tables, + &plan.table_references, || format!("expr_{}", result_column_idx), ), expr: expr.clone(), @@ -190,7 +211,7 @@ pub fn prepare_select_plan( name: get_name( maybe_alias.as_ref(), expr, - &plan.referenced_tables, + &plan.table_references, || format!("expr_{}", result_column_idx), ), expr: expr.clone(), @@ -224,7 +245,7 @@ pub fn prepare_select_plan( name: get_name( maybe_alias.as_ref(), expr, - &plan.referenced_tables, + 
&plan.table_references, || format!("expr_{}", result_column_idx), ), expr: expr.clone(), @@ -244,7 +265,7 @@ pub fn prepare_select_plan( name: get_name( maybe_alias.as_ref(), expr, - &plan.referenced_tables, + &plan.table_references, || format!("expr_{}", result_column_idx), ), expr: expr.clone(), @@ -255,9 +276,22 @@ pub fn prepare_select_plan( } } } + + // Parse the actual WHERE clause and add its conditions to the plan WHERE clause that already contains the join conditions. + parse_where( + where_clause, + &plan.table_references, + Some(&plan.result_columns), + &mut plan.where_clause, + )?; + if let Some(mut group_by) = group_by { for expr in group_by.exprs.iter_mut() { - bind_column_references(expr, &plan.referenced_tables)?; + bind_column_references( + expr, + &plan.table_references, + Some(&plan.result_columns), + )?; } plan.group_by = Some(GroupBy { @@ -266,7 +300,11 @@ pub fn prepare_select_plan( let mut predicates = vec![]; break_predicate_at_and_boundaries(having, &mut predicates); for expr in predicates.iter_mut() { - bind_column_references(expr, &plan.referenced_tables)?; + bind_column_references( + expr, + &plan.table_references, + Some(&plan.result_columns), + )?; let contains_aggregates = resolve_aggregates(expr, &mut aggregate_expressions); if !contains_aggregates { @@ -312,7 +350,11 @@ pub fn prepare_select_plan( o.expr }; - bind_column_references(&mut expr, &plan.referenced_tables)?; + bind_column_references( + &mut expr, + &plan.table_references, + Some(&plan.result_columns), + )?; resolve_aggregates(&expr, &mut plan.aggregates); key.push(( @@ -326,8 +368,9 @@ pub fn prepare_select_plan( plan.order_by = Some(key); } - // Parse the LIMIT clause - plan.limit = select.limit.and_then(|l| parse_limit(*l)); + // Parse the LIMIT/OFFSET clause + (plan.limit, plan.offset) = + select.limit.map_or(Ok((None, None)), |l| parse_limit(*l))?; // Return the unoptimized query plan Ok(Plan::Select(plan)) diff --git a/core/translate/subquery.rs 
b/core/translate/subquery.rs index fd0e6667b..1730312be 100644 --- a/core/translate/subquery.rs +++ b/core/translate/subquery.rs @@ -1,5 +1,3 @@ -use std::collections::HashMap; - use crate::{ vdbe::{builder::ProgramBuilder, insn::Insn}, Result, @@ -7,7 +5,8 @@ use crate::{ use super::{ emitter::{emit_query, Resolver, TranslateCtx}, - plan::{SelectPlan, SelectQueryType, SourceOperator, TableReference, TableReferenceType}, + main_loop::LoopLabels, + plan::{Operation, SelectPlan, SelectQueryType, TableReference}, }; /// Emit the subqueries contained in the FROM clause. @@ -15,42 +14,23 @@ use super::{ pub fn emit_subqueries( program: &mut ProgramBuilder, t_ctx: &mut TranslateCtx, - referenced_tables: &mut [TableReference], - source: &mut SourceOperator, + tables: &mut [TableReference], ) -> Result<()> { - match source { - SourceOperator::Subquery { - table_reference, + for table in tables.iter_mut() { + if let Operation::Subquery { plan, - .. - } => { + result_columns_start_reg, + } = &mut table.op + { // Emit the subquery and get the start register of the result columns. let result_columns_start = emit_subquery(program, plan, t_ctx)?; - // Set the result_columns_start_reg in the TableReference object. + // Set the start register of the subquery's result columns. // This is done so that translate_expr() can read the result columns of the subquery, // as if it were reading from a regular table. - let table_ref = referenced_tables - .iter_mut() - .find(|t| t.table_identifier == table_reference.table_identifier) - .unwrap(); - if let TableReferenceType::Subquery { - result_columns_start_reg, - .. - } = &mut table_ref.reference_type - { - *result_columns_start_reg = result_columns_start; - } else { - unreachable!("emit_subqueries called on non-subquery"); - } - Ok(()) + *result_columns_start_reg = result_columns_start; } - SourceOperator::Join { left, right, .. 
} => { - emit_subqueries(program, t_ctx, referenced_tables, left)?; - emit_subqueries(program, t_ctx, referenced_tables, right)?; - Ok(()) - } - _ => Ok(()), } + Ok(()) } /// Emit a subquery and return the start register of the result columns. @@ -87,16 +67,20 @@ pub fn emit_subquery<'a>( } let end_coroutine_label = program.allocate_label(); let mut metadata = TranslateCtx { - labels_main_loop: HashMap::new(), + labels_main_loop: (0..plan.table_references.len()) + .map(|_| LoopLabels::new(program)) + .collect(), label_main_loop_end: None, meta_group_by: None, - meta_left_joins: HashMap::new(), + meta_left_joins: (0..plan.table_references.len()).map(|_| None).collect(), meta_sort: None, reg_agg_start: None, reg_result_cols_start: None, - result_column_indexes_in_orderby_sorter: HashMap::new(), + result_column_indexes_in_orderby_sorter: (0..plan.result_columns.len()).collect(), result_columns_to_skip_in_orderby_sorter: None, reg_limit: plan.limit.map(|_| program.alloc_register()), + reg_offset: plan.offset.map(|_| program.alloc_register()), + reg_limit_offset_sum: plan.offset.map(|_| program.alloc_register()), resolver: Resolver::new(t_ctx.resolver.symbol_table), }; let subquery_body_end_label = program.allocate_label(); diff --git a/core/types.rs b/core/types.rs index 92e4714c7..1d478050e 100644 --- a/core/types.rs +++ b/core/types.rs @@ -15,8 +15,8 @@ pub enum Value<'a> { Null, Integer(i64), Float(f64), - Text(&'a String), - Blob(&'a Vec), + Text(&'a str), + Blob(&'a [u8]), } impl Display for Value<'_> { @@ -130,38 +130,41 @@ impl OwnedValue { } } - pub fn from_ffi(v: &ExtValue) -> Self { + pub fn from_ffi(v: &ExtValue) -> Result { match v.value_type() { - ExtValueType::Null => OwnedValue::Null, + ExtValueType::Null => Ok(OwnedValue::Null), ExtValueType::Integer => { let Some(int) = v.to_integer() else { - return OwnedValue::Null; + return Ok(OwnedValue::Null); }; - OwnedValue::Integer(int) + Ok(OwnedValue::Integer(int)) } ExtValueType::Float => { let Some(float) = 
v.to_float() else { - return OwnedValue::Null; + return Ok(OwnedValue::Null); }; - OwnedValue::Float(float) + Ok(OwnedValue::Float(float)) } ExtValueType::Text => { let Some(text) = v.to_text() else { - return OwnedValue::Null; + return Ok(OwnedValue::Null); }; - OwnedValue::build_text(Rc::new(text)) + Ok(OwnedValue::build_text(Rc::new(text.to_string()))) } ExtValueType::Blob => { let Some(blob) = v.to_blob() else { - return OwnedValue::Null; + return Ok(OwnedValue::Null); }; - OwnedValue::Blob(Rc::new(blob)) + Ok(OwnedValue::Blob(Rc::new(blob))) } ExtValueType::Error => { - let Some(err) = v.to_error() else { - return OwnedValue::Null; + let Some(err) = v.to_error_details() else { + return Ok(OwnedValue::Null); }; - OwnedValue::Text(LimboText::new(Rc::new(err))) + match err { + (_, Some(msg)) => Err(LimboError::ExtensionError(msg)), + (code, None) => Err(LimboError::ExtensionError(code.to_string())), + } } } } @@ -181,13 +184,15 @@ pub enum AggContext { const NULL: OwnedValue = OwnedValue::Null; impl AggContext { - pub fn compute_external(&mut self) { + pub fn compute_external(&mut self) -> Result<()> { if let Self::External(ext_state) = self { if ext_state.finalized_value.is_none() { let final_value = unsafe { (ext_state.finalize_fn)(ext_state.state) }; - ext_state.cache_final_value(OwnedValue::from_ffi(&final_value)); + ext_state.cache_final_value(OwnedValue::from_ffi(&final_value)?); + unsafe { final_value.free() }; } } + Ok(()) } pub fn final_value(&self) -> &OwnedValue { diff --git a/core/vdbe/builder.rs b/core/vdbe/builder.rs index 0af4d1182..36e6de7ec 100644 --- a/core/vdbe/builder.rs +++ b/core/vdbe/builder.rs @@ -15,7 +15,6 @@ use super::{BranchOffset, CursorID, Insn, InsnReference, Program}; #[allow(dead_code)] pub struct ProgramBuilder { next_free_register: usize, - next_free_label: i32, next_free_cursor_id: usize, insns: Vec, // for temporarily storing instructions that will be put after Transaction opcode @@ -23,12 +22,12 @@ pub struct ProgramBuilder 
{ next_insn_label: Option, // Cursors that are referenced by the program. Indexed by CursorID. pub cursor_ref: Vec<(Option, CursorType)>, - // Hashmap of label to insn reference. Resolved in build(). - label_to_resolved_offset: HashMap, + /// A vector where index=label number, value=resolved offset. Resolved in build(). + label_to_resolved_offset: Vec>, // Bitmask of cursors that have emitted a SeekRowid instruction. seekrowid_emitted_bitmask: u64, - // map of instruction index to manual comment (used in EXPLAIN) - comments: HashMap, + // map of instruction index to manual comment (used in EXPLAIN only) + comments: Option>, pub parameters: Parameters, pub columns: Vec, } @@ -47,19 +46,28 @@ impl CursorType { } } +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum QueryMode { + Normal, + Explain, +} + impl ProgramBuilder { - pub fn new() -> Self { + pub fn new(query_mode: QueryMode) -> Self { Self { next_free_register: 1, - next_free_label: 0, next_free_cursor_id: 0, insns: Vec::new(), next_insn_label: None, cursor_ref: Vec::new(), constant_insns: Vec::new(), - label_to_resolved_offset: HashMap::new(), + label_to_resolved_offset: Vec::with_capacity(4), // 4 is arbitrary, we guess to assign at least this much seekrowid_emitted_bitmask: 0, - comments: HashMap::new(), + comments: if query_mode == QueryMode::Explain { + Some(HashMap::new()) + } else { + None + }, parameters: Parameters::new(), columns: Vec::new(), } @@ -91,15 +99,83 @@ impl ProgramBuilder { pub fn emit_insn(&mut self, insn: Insn) { if let Some(label) = self.next_insn_label { - self.label_to_resolved_offset - .insert(label.to_label_value(), self.insns.len() as InsnReference); + self.label_to_resolved_offset.insert( + label.to_label_value() as usize, + Some(self.insns.len() as InsnReference), + ); self.next_insn_label = None; } self.insns.push(insn); } + pub fn emit_string8(&mut self, value: String, dest: usize) { + self.emit_insn(Insn::String8 { value, dest }); + } + + pub fn emit_string8_new_reg(&mut 
self, value: String) -> usize { + let dest = self.alloc_register(); + self.emit_insn(Insn::String8 { value, dest }); + dest + } + + pub fn emit_int(&mut self, value: i64, dest: usize) { + self.emit_insn(Insn::Integer { value, dest }); + } + + pub fn emit_bool(&mut self, value: bool, dest: usize) { + self.emit_insn(Insn::Integer { + value: if value { 1 } else { 0 }, + dest, + }); + } + + pub fn emit_null(&mut self, dest: usize) { + self.emit_insn(Insn::Null { + dest, + dest_end: None, + }); + } + + pub fn emit_result_row(&mut self, start_reg: usize, count: usize) { + self.emit_insn(Insn::ResultRow { start_reg, count }); + } + + pub fn emit_halt(&mut self) { + self.emit_insn(Insn::Halt { + err_code: 0, + description: String::new(), + }); + } + + // no users yet, but I want to avoid someone else in the future + // just adding parameters to emit_halt! If you use this, remove the + // clippy warning please. + #[allow(dead_code)] + pub fn emit_halt_err(&mut self, err_code: usize, description: String) { + self.emit_insn(Insn::Halt { + err_code, + description, + }); + } + + pub fn emit_init(&mut self) -> BranchOffset { + let target_pc = self.allocate_label(); + self.emit_insn(Insn::Init { target_pc }); + target_pc + } + + pub fn emit_transaction(&mut self, write: bool) { + self.emit_insn(Insn::Transaction { write }); + } + + pub fn emit_goto(&mut self, target_pc: BranchOffset) { + self.emit_insn(Insn::Goto { target_pc }); + } + pub fn add_comment(&mut self, insn_index: BranchOffset, comment: &'static str) { - self.comments.insert(insn_index.to_offset_int(), comment); + if let Some(comments) = &mut self.comments { + comments.insert(insn_index.to_offset_int(), comment); + } } // Emit an instruction that will be put at the end of the program (after Transaction statement). 
@@ -119,8 +195,9 @@ impl ProgramBuilder { } pub fn allocate_label(&mut self) -> BranchOffset { - self.next_free_label -= 1; - BranchOffset::Label(self.next_free_label) + let label_n = self.label_to_resolved_offset.len(); + self.label_to_resolved_offset.push(None); + BranchOffset::Label(label_n as u32) } // Effectively a GOTO without the need to emit an explicit GOTO instruction. @@ -133,8 +210,8 @@ impl ProgramBuilder { pub fn resolve_label(&mut self, label: BranchOffset, to_offset: BranchOffset) { assert!(matches!(label, BranchOffset::Label(_))); assert!(matches!(to_offset, BranchOffset::Offset(_))); - self.label_to_resolved_offset - .insert(label.to_label_value(), to_offset.to_offset_int()); + self.label_to_resolved_offset[label.to_label_value() as usize] = + Some(to_offset.to_offset_int()); } /// Resolve unresolved labels to a specific offset in the instruction list. @@ -145,10 +222,16 @@ impl ProgramBuilder { pub fn resolve_labels(&mut self) { let resolve = |pc: &mut BranchOffset, insn_name: &str| { if let BranchOffset::Label(label) = pc { - let to_offset = *self.label_to_resolved_offset.get(label).unwrap_or_else(|| { - panic!("Reference to undefined label in {}: {}", insn_name, label) - }); - *pc = BranchOffset::Offset(to_offset); + let to_offset = self + .label_to_resolved_offset + .get(*label as usize) + .unwrap_or_else(|| { + panic!("Reference to undefined label in {}: {}", insn_name, label) + }); + *pc = BranchOffset::Offset( + to_offset + .unwrap_or_else(|| panic!("Unresolved label in {}: {}", insn_name, label)), + ); } }; for insn in self.insns.iter_mut() { @@ -310,7 +393,7 @@ impl ProgramBuilder { Insn::IdxGT { target_pc, .. 
} => { resolve(target_pc, "IdxGT"); } - Insn::IsNull { src: _, target_pc } => { + Insn::IsNull { reg: _, target_pc } => { resolve(target_pc, "IsNull"); } _ => continue, diff --git a/core/vdbe/datetime.rs b/core/vdbe/datetime.rs index 4cb89a65d..a4fe2a680 100644 --- a/core/vdbe/datetime.rs +++ b/core/vdbe/datetime.rs @@ -122,6 +122,7 @@ fn format_dt(dt: NaiveDateTime, output_type: DateTimeOutput, subsec: bool) -> St // Not as fast as if the formatting was native to chrono, but a good enough // for now, just to have the feature implemented fn strftime_format(dt: &NaiveDateTime, format_str: &str) -> String { + use super::strftime::CustomStrftimeItems; use std::fmt::Write; // Necessary to remove %f and %J that are exclusive formatters to sqlite // Chrono does not support them, so it is necessary to replace the modifiers manually @@ -130,13 +131,13 @@ fn strftime_format(dt: &NaiveDateTime, format_str: &str) -> String { let copy_format = format_str .to_string() .replace("%J", &format!("{:.9}", to_julian_day_exact(dt))); - // Just change the formatting here to have fractional seconds using chrono builtin modifier - let copy_format = copy_format.replace("%f", "%S.%3f"); + + let items = CustomStrftimeItems::new(©_format); // The write! macro is used here as chrono's format can panic if the formatting string contains // unknown specifiers. 
By using a writer, we can catch the panic and handle the error let mut formatted = String::new(); - match write!(formatted, "{}", dt.format(©_format)) { + match write!(formatted, "{}", dt.format_with_items(items)) { Ok(_) => formatted, // On sqlite when the formatting fails nothing is printed Err(_) => "".to_string(), diff --git a/core/vdbe/explain.rs b/core/vdbe/explain.rs index 89967608a..309745b89 100644 --- a/core/vdbe/explain.rs +++ b/core/vdbe/explain.rs @@ -957,6 +957,22 @@ pub fn insn_to_str( 0, "".to_string(), ), + Insn::OffsetLimit { + limit_reg, + combined_reg, + offset_reg, + } => ( + "OffsetLimit", + *limit_reg as i32, + *combined_reg as i32, + *offset_reg as i32, + OwnedValue::build_text(Rc::new("".to_string())), + 0, + format!( + "if r[{}]>0 then r[{}]=r[{}]+max(0,r[{}]) else r[{}]=(-1)", + limit_reg, combined_reg, limit_reg, offset_reg, combined_reg + ), + ), Insn::OpenWriteAsync { cursor_id, root_page, @@ -1018,14 +1034,14 @@ pub fn insn_to_str( 0, "".to_string(), ), - Insn::IsNull { src, target_pc } => ( + Insn::IsNull { reg, target_pc } => ( "IsNull", - *src as i32, + *reg as i32, target_pc.to_debug_int(), 0, OwnedValue::build_text(Rc::new("".to_string())), 0, - format!("if (r[{}]==NULL) goto {}", src, target_pc.to_debug_int()), + format!("if (r[{}]==NULL) goto {}", reg, target_pc.to_debug_int()), ), Insn::ParseSchema { db, where_clause } => ( "ParseSchema", @@ -1138,6 +1154,24 @@ pub fn insn_to_str( 0, format!("r[{}]=(r[{}] || r[{}])", dest, lhs, rhs), ), + Insn::Noop => ( + "Noop", + 0, + 0, + 0, + OwnedValue::build_text(Rc::new("".to_string())), + 0, + String::new(), + ), + Insn::PageCount { db, dest } => ( + "Pagecount", + *db as i32, + *dest as i32, + 0, + OwnedValue::build_text(Rc::new("".to_string())), + 0, + "".to_string(), + ), }; format!( "{:<4} {:<17} {:<4} {:<4} {:<4} {:<13} {:<2} {}", diff --git a/core/vdbe/insn.rs b/core/vdbe/insn.rs index f17a1a354..fd033b11a 100644 --- a/core/vdbe/insn.rs +++ b/core/vdbe/insn.rs @@ -6,6 +6,37 @@ 
use crate::storage::wal::CheckpointMode; use crate::types::{OwnedRecord, OwnedValue}; use limbo_macros::Description; +/// Flags provided to comparison instructions (e.g. Eq, Ne) which determine behavior related to NULL values. +#[derive(Clone, Copy, Debug, Default)] +pub struct CmpInsFlags(usize); + +impl CmpInsFlags { + const NULL_EQ: usize = 0x80; + const JUMP_IF_NULL: usize = 0x10; + + fn has(&self, flag: usize) -> bool { + (self.0 & flag) != 0 + } + + pub fn null_eq(mut self) -> Self { + self.0 |= CmpInsFlags::NULL_EQ; + self + } + + pub fn jump_if_null(mut self) -> Self { + self.0 |= CmpInsFlags::JUMP_IF_NULL; + self + } + + pub fn has_jump_if_null(&self) -> bool { + self.has(CmpInsFlags::JUMP_IF_NULL) + } + + pub fn has_nulleq(&self) -> bool { + self.has(CmpInsFlags::NULL_EQ) + } +} + #[derive(Description, Debug)] pub enum Insn { // Initialize the program state and jump to the given PC. @@ -108,52 +139,56 @@ pub enum Insn { lhs: usize, rhs: usize, target_pc: BranchOffset, - /// Jump if either of the operands is null. Used for "jump when false" logic. + /// CmpInsFlags are nulleq (null = null) or jump_if_null. + /// + /// jump_if_null jumps if either of the operands is null. Used for "jump when false" logic. /// Eg. "SELECT * FROM users WHERE id = NULL" becomes: /// /// Without the jump_if_null flag it would not jump because the logical comparison "id != NULL" is never true. /// This flag indicates that if either is null we should still jump. - jump_if_null: bool, + flags: CmpInsFlags, }, // Compare two registers and jump to the given PC if they are not equal. Ne { lhs: usize, rhs: usize, target_pc: BranchOffset, - /// Jump if either of the operands is null. Used for "jump when false" logic. - jump_if_null: bool, + /// CmpInsFlags are nulleq (null = null) or jump_if_null. + /// + /// jump_if_null jumps if either of the operands is null. Used for "jump when false" logic. 
+ flags: CmpInsFlags, }, // Compare two registers and jump to the given PC if the left-hand side is less than the right-hand side. Lt { lhs: usize, rhs: usize, target_pc: BranchOffset, - /// Jump if either of the operands is null. Used for "jump when false" logic. - jump_if_null: bool, + /// jump_if_null: Jump if either of the operands is null. Used for "jump when false" logic. + flags: CmpInsFlags, }, // Compare two registers and jump to the given PC if the left-hand side is less than or equal to the right-hand side. Le { lhs: usize, rhs: usize, target_pc: BranchOffset, - /// Jump if either of the operands is null. Used for "jump when false" logic. - jump_if_null: bool, + /// jump_if_null: Jump if either of the operands is null. Used for "jump when false" logic. + flags: CmpInsFlags, }, // Compare two registers and jump to the given PC if the left-hand side is greater than the right-hand side. Gt { lhs: usize, rhs: usize, target_pc: BranchOffset, - /// Jump if either of the operands is null. Used for "jump when false" logic. - jump_if_null: bool, + /// jump_if_null: Jump if either of the operands is null. Used for "jump when false" logic. + flags: CmpInsFlags, }, // Compare two registers and jump to the given PC if the left-hand side is greater than or equal to the right-hand side. Ge { lhs: usize, rhs: usize, target_pc: BranchOffset, - /// Jump if either of the operands is null. Used for "jump when false" logic. - jump_if_null: bool, + /// jump_if_null: Jump if either of the operands is null. Used for "jump when false" logic. + flags: CmpInsFlags, }, /// Jump to target_pc if r\[reg\] != 0 or (r\[reg\] == NULL && r\[jump_if_null\] != 0) If { @@ -473,6 +508,12 @@ pub enum Insn { target_pc: BranchOffset, }, + OffsetLimit { + limit_reg: usize, + combined_reg: usize, + offset_reg: usize, + }, + OpenWriteAsync { cursor_id: CursorID, root_page: PageIdx, @@ -504,7 +545,7 @@ pub enum Insn { /// Check if the register is null. IsNull { /// Source register (P1). 
- src: usize, + reg: usize, /// Jump to this PC if the register is null (P2). target_pc: BranchOffset, @@ -563,6 +604,12 @@ pub enum Insn { rhs: usize, dest: usize, }, + Noop, + /// Write the current number of pages in database P1 to memory cell P2. + PageCount { + db: usize, + dest: usize, + }, } fn cast_text_to_numerical(value: &str) -> OwnedValue { @@ -833,16 +880,7 @@ pub fn exec_shift_left(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue } fn compute_shl(lhs: i64, rhs: i64) -> i64 { - if rhs == 0 { - lhs - } else if rhs >= 64 || rhs <= -64 { - 0 - } else if rhs < 0 { - // if negative do right shift - lhs >> (-rhs) - } else { - lhs << rhs - } + compute_shr(lhs, -rhs) } pub fn exec_shift_right(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue { @@ -880,11 +918,15 @@ pub fn exec_shift_right(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValu } } +// compute binary shift to the right if rhs >= 0 and binary shift to the left - if rhs < 0 +// note, that binary shift to the right is sign-extended fn compute_shr(lhs: i64, rhs: i64) -> i64 { if rhs == 0 { lhs - } else if rhs >= 64 || rhs <= -64 { + } else if rhs >= 64 && lhs >= 0 || rhs <= -64 { 0 + } else if rhs >= 64 && lhs < 0 { + -1 } else if rhs < 0 { // if negative do left shift lhs << (-rhs) diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index 36cc8e4cb..f1945e865 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -22,14 +22,18 @@ mod datetime; pub mod explain; pub mod insn; pub mod likeop; +mod printf; pub mod sorter; +mod strftime; use crate::error::{LimboError, SQLITE_CONSTRAINT_PRIMARYKEY}; use crate::ext::ExtValue; use crate::function::{AggFunc, ExtFunc, FuncCtx, MathFunc, MathFuncArity, ScalarFunc}; +use crate::info; use crate::pseudo::PseudoCursor; use crate::result::LimboResult; use crate::storage::sqlite3_ondisk::DatabaseHeader; +use crate::storage::wal::CheckpointResult; use crate::storage::{btree::BTreeCursor, pager::Pager}; use crate::types::{ AggContext, Cursor, 
CursorResult, ExternalAggState, OwnedRecord, OwnedValue, Record, SeekKey, @@ -42,7 +46,8 @@ use crate::vdbe::insn::Insn; use crate::{ function::JsonFunc, json::get_json, json::is_json_valid, json::json_array, json::json_array_length, json::json_arrow_extract, json::json_arrow_shift_extract, - json::json_error_position, json::json_extract, json::json_object, json::json_type, + json::json_error_position, json::json_extract, json::json_object, json::json_patch, + json::json_remove, json::json_set, json::json_type, }; use crate::{resolve_ext_path, Connection, Result, TransactionState, DATABASE_VERSION}; use datetime::{ @@ -54,6 +59,7 @@ use insn::{ exec_subtract, }; use likeop::{construct_like_escape_arg, exec_glob, exec_like_with_escape}; +use printf::exec_printf; use rand::distributions::{Distribution, Uniform}; use rand::{thread_rng, Rng}; use regex::{Regex, RegexBuilder}; @@ -71,7 +77,7 @@ pub enum BranchOffset { /// A label is a named location in the program. /// If there are references to it, it must always be resolved to an Offset /// via program.resolve_label(). - Label(i32), + Label(u32), /// An offset is a direct index into the instruction list. Offset(InsnReference), /// A placeholder is a temporary value to satisfy the compiler. @@ -100,7 +106,7 @@ impl BranchOffset { } /// Returns the label value. Panics if the branch offset is an offset or placeholder. - pub fn to_label_value(&self) -> i32 { + pub fn to_label_value(&self) -> u32 { match self { BranchOffset::Label(v) => *v, BranchOffset::Offset(_) => unreachable!("Offset cannot be converted to label value"), @@ -113,7 +119,7 @@ impl BranchOffset { /// label or placeholder. 
pub fn to_debug_int(&self) -> i32 { match self { - BranchOffset::Label(v) => *v, + BranchOffset::Label(v) => *v as i32, BranchOffset::Offset(v) => *v as i32, BranchOffset::Placeholder => i32::MAX, } @@ -134,6 +140,7 @@ pub type PageIdx = usize; // Index of insn in list of insns type InsnReference = u32; +#[derive(Debug)] pub enum StepResult<'a> { Done, IO, @@ -163,8 +170,16 @@ macro_rules! call_external_function { ) => {{ if $arg_count == 0 { let result_c_value: ExtValue = unsafe { ($func_ptr)(0, std::ptr::null()) }; - let result_ov = OwnedValue::from_ffi(&result_c_value); - $state.registers[$dest_register] = result_ov; + match OwnedValue::from_ffi(&result_c_value) { + Ok(result_ov) => { + $state.registers[$dest_register] = result_ov; + unsafe { result_c_value.free() }; + } + Err(e) => { + unsafe { result_c_value.free() }; + return Err(e); + } + } } else { let register_slice = &$state.registers[$start_reg..$start_reg + $arg_count]; let mut ext_values: Vec = Vec::with_capacity($arg_count); @@ -174,8 +189,16 @@ macro_rules! 
call_external_function { } let argv_ptr = ext_values.as_ptr(); let result_c_value: ExtValue = unsafe { ($func_ptr)($arg_count as i32, argv_ptr) }; - let result_ov = OwnedValue::from_ffi(&result_c_value); - $state.registers[$dest_register] = result_ov; + match OwnedValue::from_ffi(&result_c_value) { + Ok(result_ov) => { + $state.registers[$dest_register] = result_ov; + unsafe { result_c_value.free() }; + } + Err(e) => { + unsafe { result_c_value.free() }; + return Err(e); + } + } } }}; } @@ -358,7 +381,7 @@ pub struct Program { pub insns: Vec, pub cursor_ref: Vec<(Option, CursorType)>, pub database_header: Rc>, - pub comments: HashMap, + pub comments: Option>, pub parameters: crate::parameters::Parameters, pub connection: Weak, pub auto_commit: bool, @@ -397,7 +420,6 @@ impl Program { } let insn = &self.insns[state.pc as usize]; trace_insn(self, state.pc as InsnReference, insn); - let mut cursors = state.cursors.borrow_mut(); match insn { Insn::Init { target_pc } => { assert!(target_pc.is_offset()); @@ -449,15 +471,18 @@ impl Program { } => { let result = self.connection.upgrade().unwrap().checkpoint(); match result { - Ok(()) => { + Ok(CheckpointResult { + num_wal_frames: num_wal_pages, + num_checkpointed_frames: num_checkpointed_pages, + }) => { // https://sqlite.org/pragma.html#pragma_wal_checkpoint - // TODO make 2nd and 3rd cols available through checkpoint method // 1st col: 1 (checkpoint SQLITE_BUSY) or 0 (not busy). 
state.registers[*dest] = OwnedValue::Integer(0); // 2nd col: # modified pages written to wal file - state.registers[*dest + 1] = OwnedValue::Integer(0); + state.registers[*dest + 1] = OwnedValue::Integer(num_wal_pages as i64); // 3rd col: # pages moved to db after checkpoint - state.registers[*dest + 2] = OwnedValue::Integer(0); + state.registers[*dest + 2] = + OwnedValue::Integer(num_checkpointed_pages as i64); } Err(_err) => state.registers[*dest] = OwnedValue::Integer(1), } @@ -475,6 +500,7 @@ impl Program { state.pc += 1; } Insn::NullRow { cursor_id } => { + let mut cursors = state.cursors.borrow_mut(); let cursor = must_be_btree_cursor!(*cursor_id, self.cursor_ref, cursors, "NullRow"); cursor.set_null_flag(true); @@ -585,15 +611,18 @@ impl Program { lhs, rhs, target_pc, - jump_if_null, + flags, } => { assert!(target_pc.is_offset()); let lhs = *lhs; let rhs = *rhs; let target_pc = *target_pc; + let cond = state.registers[lhs] == state.registers[rhs]; + let nulleq = flags.has_nulleq(); + let jump_if_null = flags.has_jump_if_null(); match (&state.registers[lhs], &state.registers[rhs]) { (_, OwnedValue::Null) | (OwnedValue::Null, _) => { - if *jump_if_null { + if (nulleq && cond) || (!nulleq && jump_if_null) { state.pc = target_pc.to_offset_int(); } else { state.pc += 1; @@ -612,15 +641,18 @@ impl Program { lhs, rhs, target_pc, - jump_if_null, + flags, } => { assert!(target_pc.is_offset()); let lhs = *lhs; let rhs = *rhs; let target_pc = *target_pc; + let cond = state.registers[lhs] != state.registers[rhs]; + let nulleq = flags.has_nulleq(); + let jump_if_null = flags.has_jump_if_null(); match (&state.registers[lhs], &state.registers[rhs]) { (_, OwnedValue::Null) | (OwnedValue::Null, _) => { - if *jump_if_null { + if (nulleq && cond) || (!nulleq && jump_if_null) { state.pc = target_pc.to_offset_int(); } else { state.pc += 1; @@ -639,15 +671,16 @@ impl Program { lhs, rhs, target_pc, - jump_if_null, + flags, } => { assert!(target_pc.is_offset()); let lhs = *lhs; let 
rhs = *rhs; let target_pc = *target_pc; + let jump_if_null = flags.has_jump_if_null(); match (&state.registers[lhs], &state.registers[rhs]) { (_, OwnedValue::Null) | (OwnedValue::Null, _) => { - if *jump_if_null { + if jump_if_null { state.pc = target_pc.to_offset_int(); } else { state.pc += 1; @@ -666,15 +699,16 @@ impl Program { lhs, rhs, target_pc, - jump_if_null, + flags, } => { assert!(target_pc.is_offset()); let lhs = *lhs; let rhs = *rhs; let target_pc = *target_pc; + let jump_if_null = flags.has_jump_if_null(); match (&state.registers[lhs], &state.registers[rhs]) { (_, OwnedValue::Null) | (OwnedValue::Null, _) => { - if *jump_if_null { + if jump_if_null { state.pc = target_pc.to_offset_int(); } else { state.pc += 1; @@ -693,15 +727,16 @@ impl Program { lhs, rhs, target_pc, - jump_if_null, + flags, } => { assert!(target_pc.is_offset()); let lhs = *lhs; let rhs = *rhs; let target_pc = *target_pc; + let jump_if_null = flags.has_jump_if_null(); match (&state.registers[lhs], &state.registers[rhs]) { (_, OwnedValue::Null) | (OwnedValue::Null, _) => { - if *jump_if_null { + if jump_if_null { state.pc = target_pc.to_offset_int(); } else { state.pc += 1; @@ -720,15 +755,16 @@ impl Program { lhs, rhs, target_pc, - jump_if_null, + flags, } => { assert!(target_pc.is_offset()); let lhs = *lhs; let rhs = *rhs; let target_pc = *target_pc; + let jump_if_null = flags.has_jump_if_null(); match (&state.registers[lhs], &state.registers[rhs]) { (_, OwnedValue::Null) | (OwnedValue::Null, _) => { - if *jump_if_null { + if jump_if_null { state.pc = target_pc.to_offset_int(); } else { state.pc += 1; @@ -773,6 +809,7 @@ impl Program { } => { let (_, cursor_type) = self.cursor_ref.get(*cursor_id).unwrap(); let cursor = BTreeCursor::new(pager.clone(), *root_page); + let mut cursors = state.cursors.borrow_mut(); match cursor_type { CursorType::BTreeTable(_) => { cursors @@ -803,6 +840,7 @@ impl Program { content_reg: _, num_fields: _, } => { + let mut cursors = 
state.cursors.borrow_mut(); let cursor = PseudoCursor::new(); cursors .get_mut(*cursor_id) @@ -811,12 +849,14 @@ impl Program { state.pc += 1; } Insn::RewindAsync { cursor_id } => { + let mut cursors = state.cursors.borrow_mut(); let cursor = must_be_btree_cursor!(*cursor_id, self.cursor_ref, cursors, "RewindAsync"); return_if_io!(cursor.rewind()); state.pc += 1; } Insn::LastAsync { cursor_id } => { + let mut cursors = state.cursors.borrow_mut(); let cursor = must_be_btree_cursor!(*cursor_id, self.cursor_ref, cursors, "LastAsync"); return_if_io!(cursor.last()); @@ -827,6 +867,7 @@ impl Program { pc_if_empty, } => { assert!(pc_if_empty.is_offset()); + let mut cursors = state.cursors.borrow_mut(); let cursor = must_be_btree_cursor!(*cursor_id, self.cursor_ref, cursors, "LastAwait"); cursor.wait_for_completion()?; @@ -841,6 +882,7 @@ impl Program { pc_if_empty, } => { assert!(pc_if_empty.is_offset()); + let mut cursors = state.cursors.borrow_mut(); let cursor = must_be_btree_cursor!(*cursor_id, self.cursor_ref, cursors, "RewindAwait"); cursor.wait_for_completion()?; @@ -855,6 +897,7 @@ impl Program { column, dest, } => { + let mut cursors = state.cursors.borrow_mut(); if let Some((index_cursor_id, table_cursor_id)) = state.deferred_seek.take() { let index_cursor = get_cursor_as_index_mut(&mut cursors, index_cursor_id); let rowid = index_cursor.rowid()?; @@ -922,6 +965,7 @@ impl Program { return Ok(StepResult::Row(record)); } Insn::NextAsync { cursor_id } => { + let mut cursors = state.cursors.borrow_mut(); let cursor = must_be_btree_cursor!(*cursor_id, self.cursor_ref, cursors, "NextAsync"); cursor.set_null_flag(false); @@ -929,6 +973,7 @@ impl Program { state.pc += 1; } Insn::PrevAsync { cursor_id } => { + let mut cursors = state.cursors.borrow_mut(); let cursor = must_be_btree_cursor!(*cursor_id, self.cursor_ref, cursors, "PrevAsync"); cursor.set_null_flag(false); @@ -939,6 +984,7 @@ impl Program { cursor_id, pc_if_next, } => { + let mut cursors = 
state.cursors.borrow_mut(); assert!(pc_if_next.is_offset()); let cursor = must_be_btree_cursor!(*cursor_id, self.cursor_ref, cursors, "PrevAwait"); @@ -954,6 +1000,7 @@ impl Program { pc_if_next, } => { assert!(pc_if_next.is_offset()); + let mut cursors = state.cursors.borrow_mut(); let cursor = must_be_btree_cursor!(*cursor_id, self.cursor_ref, cursors, "NextAwait"); cursor.wait_for_completion()?; @@ -983,10 +1030,19 @@ impl Program { } } log::trace!("Halt auto_commit {}", self.auto_commit); + let connection = self + .connection + .upgrade() + .expect("only weak ref to connection?"); + let current_state = connection.transaction_state.borrow().clone(); + if current_state == TransactionState::Read { + pager.end_read_tx()?; + return Ok(StepResult::Done); + } return if self.auto_commit { match pager.end_tx() { Ok(crate::storage::wal::CheckpointStatus::IO) => Ok(StepResult::IO), - Ok(crate::storage::wal::CheckpointStatus::Done) => { + Ok(crate::storage::wal::CheckpointStatus::Done(_)) => { if self.change_cnt_on { if let Some(conn) = self.connection.upgrade() { conn.set_changes(self.n_change.get()); @@ -1084,6 +1140,7 @@ impl Program { state.pc += 1; } Insn::RowId { cursor_id, dest } => { + let mut cursors = state.cursors.borrow_mut(); if let Some((index_cursor_id, table_cursor_id)) = state.deferred_seek.take() { let index_cursor = get_cursor_as_index_mut(&mut cursors, index_cursor_id); let rowid = index_cursor.rowid()?; @@ -1111,6 +1168,7 @@ impl Program { target_pc, } => { assert!(target_pc.is_offset()); + let mut cursors = state.cursors.borrow_mut(); let cursor = get_cursor_as_table_mut(&mut cursors, *cursor_id); let rowid = match &state.registers[*src_reg] { OwnedValue::Integer(rowid) => *rowid as u64, @@ -1146,6 +1204,7 @@ impl Program { is_index, } => { assert!(target_pc.is_offset()); + let mut cursors = state.cursors.borrow_mut(); if *is_index { let cursor = get_cursor_as_index_mut(&mut cursors, *cursor_id); let record_from_regs: OwnedRecord = @@ -1191,6 +1250,7 
@@ impl Program { is_index, } => { assert!(target_pc.is_offset()); + let mut cursors = state.cursors.borrow_mut(); if *is_index { let cursor = get_cursor_as_index_mut(&mut cursors, *cursor_id); let record_from_regs: OwnedRecord = @@ -1235,6 +1295,7 @@ impl Program { target_pc, } => { assert!(target_pc.is_offset()); + let mut cursors = state.cursors.borrow_mut(); let cursor = get_cursor_as_index_mut(&mut cursors, *cursor_id); let record_from_regs: OwnedRecord = make_owned_record(&state.registers, start_reg, num_regs); @@ -1249,7 +1310,7 @@ impl Program { } } else { state.pc = target_pc.to_offset_int(); - } + }; } Insn::IdxGT { cursor_id, @@ -1258,6 +1319,7 @@ impl Program { target_pc, } => { assert!(target_pc.is_offset()); + let mut cursors = state.cursors.borrow_mut(); let cursor = get_cursor_as_index_mut(&mut cursors, *cursor_id); let record_from_regs: OwnedRecord = make_owned_record(&state.registers, start_reg, num_regs); @@ -1272,7 +1334,7 @@ impl Program { } } else { state.pc = target_pc.to_offset_int(); - } + }; } Insn::DecrJumpZero { reg, target_pc } => { assert!(target_pc.is_offset()); @@ -1553,7 +1615,7 @@ impl Program { AggFunc::Min => {} AggFunc::GroupConcat | AggFunc::StringAgg => {} AggFunc::External(_) => { - agg.compute_external(); + agg.compute_external()?; } }, OwnedValue::Null => { @@ -1588,6 +1650,7 @@ impl Program { }) .collect(); let cursor = Sorter::new(order); + let mut cursors = state.cursors.borrow_mut(); cursors .get_mut(*cursor_id) .unwrap() @@ -1599,6 +1662,7 @@ impl Program { dest_reg, pseudo_cursor, } => { + let mut cursors = state.cursors.borrow_mut(); let sorter_cursor = get_cursor_as_sorter_mut(&mut cursors, *cursor_id); let record = match sorter_cursor.record() { Some(record) => record.clone(), @@ -1616,6 +1680,7 @@ impl Program { cursor_id, record_reg, } => { + let mut cursors = state.cursors.borrow_mut(); let cursor = get_cursor_as_sorter_mut(&mut cursors, *cursor_id); let record = match &state.registers[*record_reg] { 
OwnedValue::Record(record) => record, @@ -1628,6 +1693,7 @@ impl Program { cursor_id, pc_if_empty, } => { + let mut cursors = state.cursors.borrow_mut(); let cursor = get_cursor_as_sorter_mut(&mut cursors, *cursor_id); if cursor.is_empty() { state.pc = pc_if_empty.to_offset_int(); @@ -1641,6 +1707,7 @@ impl Program { pc_if_next, } => { assert!(pc_if_next.is_offset()); + let mut cursors = state.cursors.borrow_mut(); let cursor = get_cursor_as_sorter_mut(&mut cursors, *cursor_id); cursor.next(); if cursor.has_more() { @@ -1661,7 +1728,7 @@ impl Program { crate::function::Func::Json(json_func) => match json_func { JsonFunc::Json => { let json_value = &state.registers[*start_reg]; - let json_str = get_json(json_value); + let json_str = get_json(json_value, None); match json_str { Ok(json) => state.registers[*dest] = json, Err(e) => return Err(e), @@ -1746,6 +1813,64 @@ impl Program { let json_value = &state.registers[*start_reg]; state.registers[*dest] = is_json_valid(json_value)?; } + JsonFunc::JsonPatch => { + assert_eq!(arg_count, 2); + assert!(*start_reg + 1 < state.registers.len()); + let target = &state.registers[*start_reg]; + let patch = &state.registers[*start_reg + 1]; + state.registers[*dest] = json_patch(target, patch)?; + } + JsonFunc::JsonRemove => { + state.registers[*dest] = json_remove( + &state.registers[*start_reg..*start_reg + arg_count], + )?; + } + JsonFunc::JsonPretty => { + let json_value = &state.registers[*start_reg]; + let indent = if arg_count > 1 { + Some(&state.registers[*start_reg + 1]) + } else { + None + }; + + // Blob should be converted to Ascii in a lossy way + // However, Rust strings uses utf-8 + // so the behavior at the moment is slightly different + // To the way blobs are parsed here in SQLite. 
+ let indent = match indent { + Some(value) => match value { + OwnedValue::Text(text) => text.value.as_str(), + OwnedValue::Integer(val) => &val.to_string(), + OwnedValue::Float(val) => &val.to_string(), + OwnedValue::Blob(val) => &String::from_utf8_lossy(val), + OwnedValue::Agg(ctx) => match ctx.final_value() { + OwnedValue::Text(text) => text.value.as_str(), + OwnedValue::Integer(val) => &val.to_string(), + OwnedValue::Float(val) => &val.to_string(), + OwnedValue::Blob(val) => &String::from_utf8_lossy(val), + _ => " ", + }, + _ => " ", + }, + // If the second argument is omitted or is NULL, then indentation is four spaces per level + None => " ", + }; + + let json_str = get_json(json_value, Some(indent))?; + state.registers[*dest] = json_str; + } + JsonFunc::JsonSet => { + let reg_values = + &state.registers[*start_reg + 1..*start_reg + arg_count]; + + let json_result = + json_set(&state.registers[*start_reg], reg_values); + + match json_result { + Ok(json) => state.registers[*dest] = json, + Err(e) => return Err(e), + } + } }, crate::function::Func::Scalar(scalar_func) => match scalar_func { ScalarFunc::Cast => { @@ -1879,7 +2004,7 @@ impl Program { let reg_value = state.registers[*start_reg].borrow_mut(); let result = match scalar_func { ScalarFunc::Sign => exec_sign(reg_value), - ScalarFunc::Abs => exec_abs(reg_value), + ScalarFunc::Abs => Some(exec_abs(reg_value)?), ScalarFunc::Lower => exec_lower(reg_value), ScalarFunc::Upper => exec_upper(reg_value), ScalarFunc::Length => Some(exec_length(reg_value)), @@ -2038,6 +2163,14 @@ impl Program { let version = execute_sqlite_version(version_integer); state.registers[*dest] = OwnedValue::build_text(Rc::new(version)); } + ScalarFunc::SqliteSourceId => { + let src_id = format!( + "{} {}", + info::build::BUILT_TIME_SQLITE, + info::build::GIT_COMMIT_HASH.unwrap_or("unknown") + ); + state.registers[*dest] = OwnedValue::build_text(Rc::new(src_id)); + } ScalarFunc::Replace => { assert_eq!(arg_count, 3); let source = 
&state.registers[*start_reg]; @@ -2059,6 +2192,12 @@ impl Program { ); state.registers[*dest] = result; } + ScalarFunc::Printf => { + let result = exec_printf( + &state.registers[*start_reg..*start_reg + arg_count], + )?; + state.registers[*dest] = result; + } }, crate::function::Func::External(f) => match f.func { ExtFunc::Scalar(f) => { @@ -2180,6 +2319,7 @@ impl Program { record_reg, flag: _, } => { + let mut cursors = state.cursors.borrow_mut(); let cursor = get_cursor_as_table_mut(&mut cursors, *cursor); let record = match &state.registers[*record_reg] { OwnedValue::Record(r) => r, @@ -2190,6 +2330,7 @@ impl Program { state.pc += 1; } Insn::InsertAwait { cursor_id } => { + let mut cursors = state.cursors.borrow_mut(); let cursor = get_cursor_as_table_mut(&mut cursors, *cursor_id); cursor.wait_for_completion()?; // Only update last_insert_rowid for regular table inserts, not schema modifications @@ -2205,11 +2346,13 @@ impl Program { state.pc += 1; } Insn::DeleteAsync { cursor_id } => { + let mut cursors = state.cursors.borrow_mut(); let cursor = get_cursor_as_table_mut(&mut cursors, *cursor_id); return_if_io!(cursor.delete()); state.pc += 1; } Insn::DeleteAwait { cursor_id } => { + let mut cursors = state.cursors.borrow_mut(); let cursor = get_cursor_as_table_mut(&mut cursors, *cursor_id); cursor.wait_for_completion()?; let prev_changes = self.n_change.get(); @@ -2219,6 +2362,7 @@ impl Program { Insn::NewRowid { cursor, rowid_reg, .. 
} => { + let mut cursors = state.cursors.borrow_mut(); let cursor = get_cursor_as_table_mut(&mut cursors, *cursor); // TODO: make io handle rng let rowid = return_if_io!(get_new_rowid(cursor, thread_rng())); @@ -2259,6 +2403,7 @@ impl Program { rowid_reg, target_pc, } => { + let mut cursors = state.cursors.borrow_mut(); let cursor = must_be_btree_cursor!(*cursor, self.cursor_ref, cursors, "NotExists"); let exists = return_if_io!(cursor.exists(&state.registers[*rowid_reg])); @@ -2268,6 +2413,37 @@ impl Program { state.pc = target_pc.to_offset_int(); } } + Insn::OffsetLimit { + limit_reg, + combined_reg, + offset_reg, + } => { + let limit_val = match state.registers[*limit_reg] { + OwnedValue::Integer(val) => val, + _ => { + return Err(LimboError::InternalError( + "OffsetLimit: the value in limit_reg is not an integer".into(), + )); + } + }; + let offset_val = match state.registers[*offset_reg] { + OwnedValue::Integer(val) if val < 0 => 0, + OwnedValue::Integer(val) if val >= 0 => val, + _ => { + return Err(LimboError::InternalError( + "OffsetLimit: the value in offset_reg is not an integer".into(), + )); + } + }; + + let offset_limit_sum = limit_val.overflowing_add(offset_val); + if limit_val <= 0 || offset_limit_sum.1 { + state.registers[*combined_reg] = OwnedValue::Integer(-1); + } else { + state.registers[*combined_reg] = OwnedValue::Integer(offset_limit_sum.0); + } + state.pc += 1; + } // this cursor may be reused for next insert // Update: tablemoveto is used to travers on not exists, on insert depending on flags if nonseek it traverses again. // If not there might be some optimizations obviously. 
@@ -2276,6 +2452,7 @@ impl Program { root_page, } => { let (_, cursor_type) = self.cursor_ref.get(*cursor_id).unwrap(); + let mut cursors = state.cursors.borrow_mut(); let is_index = cursor_type.is_index(); let cursor = BTreeCursor::new(pager.clone(), *root_page); if is_index { @@ -2316,16 +2493,31 @@ impl Program { state.pc += 1; } Insn::Close { cursor_id } => { + let mut cursors = state.cursors.borrow_mut(); cursors.get_mut(*cursor_id).unwrap().take(); state.pc += 1; } - Insn::IsNull { src, target_pc } => { - if matches!(state.registers[*src], OwnedValue::Null) { + Insn::IsNull { reg, target_pc } => { + if matches!(state.registers[*reg], OwnedValue::Null) { state.pc = target_pc.to_offset_int(); } else { state.pc += 1; } } + Insn::PageCount { db, dest } => { + if *db > 0 { + // TODO: implement temp databases + todo!("temp databases not implemented yet"); + } + // SQLite returns "0" on an empty database, and 2 on the first insertion, + // so we'll mimick that behavior. + let mut pages = pager.db_header.borrow().database_size.into(); + if pages == 1 { + pages = 0; + } + state.registers[*dest] = OwnedValue::Integer(pages); + state.pc += 1; + } Insn::ParseSchema { db: _, where_clause, @@ -2387,6 +2579,11 @@ impl Program { exec_or(&state.registers[*lhs], &state.registers[*rhs]); state.pc += 1; } + Insn::Noop => { + // Do nothing + // Advance the program counter for the next opcode + state.pc += 1 + } } } } @@ -2397,7 +2594,11 @@ fn get_new_rowid(cursor: &mut BTreeCursor, mut rng: R) -> Result {} CursorResult::IO => return Ok(CursorResult::IO), } - let mut rowid = cursor.rowid()?.unwrap_or(0) + 1; + let mut rowid = cursor + .rowid()? 
+ .unwrap_or(0) // if BTree is empty - use 0 as initial value for rowid + .checked_add(1) // add 1 but be careful with overflows + .unwrap_or(u64::MAX); // in case of overflow - use u64::MAX if rowid > i64::MAX.try_into().unwrap() { let distribution = Uniform::from(1..=i64::MAX); let max_attempts = 100; @@ -2448,7 +2649,10 @@ fn trace_insn(program: &Program, addr: InsnReference, insn: &Insn) { addr, insn, String::new(), - program.comments.get(&{ addr }).copied() + program + .comments + .as_ref() + .and_then(|comments| comments.get(&{ addr }).copied()) ) ); } @@ -2459,7 +2663,10 @@ fn print_insn(program: &Program, addr: InsnReference, insn: &Insn, indent: Strin addr, insn, indent, - program.comments.get(&{ addr }).copied(), + program + .comments + .as_ref() + .and_then(|comments| comments.get(&{ addr }).copied()), ); println!("{}", s); } @@ -2696,24 +2903,25 @@ pub fn exec_soundex(reg: &OwnedValue) -> OwnedValue { OwnedValue::build_text(Rc::new(result.to_uppercase())) } -fn exec_abs(reg: &OwnedValue) -> Option { +fn exec_abs(reg: &OwnedValue) -> Result { match reg { OwnedValue::Integer(x) => { - if x < &0 { - Some(OwnedValue::Integer(-x)) - } else { - Some(OwnedValue::Integer(*x)) + match i64::checked_abs(*x) { + Some(y) => Ok(OwnedValue::Integer(y)), + // Special case: if we do the abs of "-9223372036854775808", it causes overflow. 
+ // return IntegerOverflow error + None => Err(LimboError::IntegerOverflow), } } OwnedValue::Float(x) => { if x < &0.0 { - Some(OwnedValue::Float(-x)) + Ok(OwnedValue::Float(-x)) } else { - Some(OwnedValue::Float(*x)) + Ok(OwnedValue::Float(*x)) } } - OwnedValue::Null => Some(OwnedValue::Null), - _ => Some(OwnedValue::Float(0.0)), + OwnedValue::Null => Ok(OwnedValue::Null), + _ => Ok(OwnedValue::Float(0.0)), } } @@ -3746,6 +3954,9 @@ mod tests { OwnedValue::Float(0.0) ); assert_eq!(exec_abs(&OwnedValue::Null).unwrap(), OwnedValue::Null); + + // ABS(i64::MIN) should return RuntimeError + assert!(exec_abs(&OwnedValue::Integer(i64::MIN)).is_err()); } #[test] diff --git a/core/vdbe/printf.rs b/core/vdbe/printf.rs new file mode 100644 index 000000000..c4fb6a153 --- /dev/null +++ b/core/vdbe/printf.rs @@ -0,0 +1,265 @@ +use std::rc::Rc; + +use crate::types::OwnedValue; +use crate::LimboError; + +#[inline(always)] +pub fn exec_printf(values: &[OwnedValue]) -> crate::Result { + if values.is_empty() { + return Ok(OwnedValue::Null); + } + let format_str = match &values[0] { + OwnedValue::Text(t) => &t.value, + _ => return Ok(OwnedValue::Null), + }; + + let mut result = String::new(); + let mut args_index = 1; + let mut chars = format_str.chars().peekable(); + + while let Some(c) = chars.next() { + if c != '%' { + result.push(c); + continue; + } + + match chars.next() { + Some('%') => { + result.push('%'); + continue; + } + Some('d') => { + if args_index >= values.len() { + return Err(LimboError::InvalidArgument("not enough arguments".into())); + } + match &values[args_index] { + OwnedValue::Integer(i) => result.push_str(&i.to_string()), + OwnedValue::Float(f) => result.push_str(&f.to_string()), + _ => result.push_str("0".into()), + } + args_index += 1; + } + Some('s') => { + if args_index >= values.len() { + return Err(LimboError::InvalidArgument("not enough arguments".into())); + } + match &values[args_index] { + OwnedValue::Text(t) => result.push_str(&t.value), + 
OwnedValue::Null => result.push_str("(null)"), + v => result.push_str(&v.to_string()), + } + args_index += 1; + } + Some('f') => { + if args_index >= values.len() { + return Err(LimboError::InvalidArgument("not enough arguments".into())); + } + match &values[args_index] { + OwnedValue::Float(f) => result.push_str(&f.to_string()), + OwnedValue::Integer(i) => result.push_str(&(*i as f64).to_string()), + _ => result.push_str("0.0".into()), + } + args_index += 1; + } + None => { + return Err(LimboError::InvalidArgument( + "incomplete format specifier".into(), + )) + } + _ => { + return Err(LimboError::InvalidFormatter( + "this formatter is not supported".into(), + )); + } + } + } + Ok(OwnedValue::build_text(Rc::new(result))) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::rc::Rc; + + fn text(value: &str) -> OwnedValue { + OwnedValue::build_text(Rc::new(value.to_string())) + } + + fn integer(value: i64) -> OwnedValue { + OwnedValue::Integer(value) + } + + fn float(value: f64) -> OwnedValue { + OwnedValue::Float(value) + } + + #[test] + fn test_printf_no_args() { + assert_eq!(exec_printf(&[]).unwrap(), OwnedValue::Null); + } + + #[test] + fn test_printf_basic_string() { + assert_eq!( + exec_printf(&[text("Hello World")]).unwrap(), + text("Hello World") + ); + } + + #[test] + fn test_printf_string_formatting() { + let test_cases = vec![ + // Simple string substitution + ( + vec![text("Hello, %s!"), text("World")], + text("Hello, World!"), + ), + // Multiple string substitutions + ( + vec![text("%s %s!"), text("Hello"), text("World")], + text("Hello World!"), + ), + // String with null value + ( + vec![text("Hello, %s!"), OwnedValue::Null], + text("Hello, (null)!"), + ), + // String with number conversion + (vec![text("Value: %s"), integer(42)], text("Value: 42")), + // Escaping percent sign + (vec![text("100%% complete")], text("100% complete")), + ]; + for (input, output) in test_cases { + assert_eq!(exec_printf(&input).unwrap(), output); + } + } + + #[test] + 
fn test_printf_integer_formatting() { + let test_cases = vec![ + // Basic integer formatting + (vec![text("Number: %d"), integer(42)], text("Number: 42")), + // Negative integer + (vec![text("Number: %d"), integer(-42)], text("Number: -42")), + // Multiple integers + ( + vec![text("%d + %d = %d"), integer(2), integer(3), integer(5)], + text("2 + 3 = 5"), + ), + // Non-numeric value defaults to 0 + ( + vec![text("Number: %d"), text("not a number")], + text("Number: 0"), + ), + ]; + for (input, output) in test_cases { + assert_eq!(exec_printf(&input).unwrap(), output) + } + } + + #[test] + fn test_printf_float_formatting() { + let test_cases = vec![ + // Basic float formatting + (vec![text("Number: %f"), float(42.5)], text("Number: 42.5")), + // Negative float + ( + vec![text("Number: %f"), float(-42.5)], + text("Number: -42.5"), + ), + // Integer as float + (vec![text("Number: %f"), integer(42)], text("Number: 42")), + // Multiple floats + ( + vec![text("%f + %f = %f"), float(2.5), float(3.5), float(6.0)], + text("2.5 + 3.5 = 6"), + ), + // Non-numeric value defaults to 0.0 + ( + vec![text("Number: %f"), text("not a number")], + text("Number: 0.0"), + ), + ]; + + for (input, expected) in test_cases { + assert_eq!(exec_printf(&input).unwrap(), expected); + } + } + + #[test] + fn test_printf_mixed_formatting() { + let test_cases = vec![ + // Mix of string and integer + ( + vec![text("%s: %d"), text("Count"), integer(42)], + text("Count: 42"), + ), + // Mix of all types + ( + vec![ + text("%s: %d (%f%%)"), + text("Progress"), + integer(75), + float(75.5), + ], + text("Progress: 75 (75.5%)"), + ), + // Complex format + ( + vec![ + text("Name: %s, ID: %d, Score: %f"), + text("John"), + integer(123), + float(95.5), + ], + text("Name: John, ID: 123, Score: 95.5"), + ), + ]; + + for (input, expected) in test_cases { + assert_eq!(exec_printf(&input).unwrap(), expected); + } + } + + #[test] + fn test_printf_error_cases() { + let error_cases = vec![ + // Not enough arguments + 
vec![text("%d %d"), integer(42)], + // Invalid format string + vec![text("%z"), integer(42)], + // Incomplete format specifier + vec![text("incomplete %")], + ]; + + for case in error_cases { + assert!(exec_printf(&case).is_err()); + } + } + + #[test] + fn test_printf_edge_cases() { + let test_cases = vec![ + // Empty format string + (vec![text("")], text("")), + // Only percent signs + (vec![text("%%%%")], text("%%")), + // String with no format specifiers + (vec![text("No substitutions")], text("No substitutions")), + // Multiple consecutive format specifiers + ( + vec![text("%d%d%d"), integer(1), integer(2), integer(3)], + text("123"), + ), + // Format string with special characters + ( + vec![text("Special chars: %s"), text("\n\t\r")], + text("Special chars: \n\t\r"), + ), + ]; + + for (input, expected) in test_cases { + assert_eq!(exec_printf(&input).unwrap(), expected); + } + } +} diff --git a/core/vdbe/strftime.rs b/core/vdbe/strftime.rs new file mode 100644 index 000000000..f54914e78 --- /dev/null +++ b/core/vdbe/strftime.rs @@ -0,0 +1,228 @@ +//! Code adapted from Chrono StrftimeItems but for sqlite strftime compatibility +//! Sqlite reference https://www.sqlite.org/lang_datefunc.html + +use chrono::format::{Fixed, Item, Numeric, Pad}; + +const fn num(numeric: Numeric) -> Item<'static> { + Item::Numeric(numeric, Pad::None) +} + +const fn num0(numeric: Numeric) -> Item<'static> { + Item::Numeric(numeric, Pad::Zero) +} + +const fn nums(numeric: Numeric) -> Item<'static> { + Item::Numeric(numeric, Pad::Space) +} + +const fn fixed(fixed: Fixed) -> Item<'static> { + Item::Fixed(fixed) +} + +#[derive(Clone, Debug)] +pub struct CustomStrftimeItems<'a> { + // Remaining portion of the string. + remainder: &'a str, + /// If the current specifier is composed of multiple formatting items (e.g. `%+`), + /// `queue` stores a slice of `Item`s that have to be returned one by one. 
+ queue: &'static [Item<'static>], +} + +impl<'a> CustomStrftimeItems<'a> { + pub const fn new(s: &'a str) -> CustomStrftimeItems<'a> { + CustomStrftimeItems { + remainder: s, + queue: &[], + } + } +} + +// const HAVE_ALTERNATES: &str = "z"; + +impl<'a> Iterator for CustomStrftimeItems<'a> { + type Item = Item<'a>; + + fn next(&mut self) -> Option> { + // We have items queued to return from a specifier composed of multiple formatting items. + if let Some((item, remainder)) = self.queue.split_first() { + self.queue = remainder; + return Some(item.clone()); + } + + // Normal: we are parsing the formatting string. + let (remainder, item) = self.parse_next_item(self.remainder)?; + self.remainder = remainder; + Some(item) + } +} + +impl<'a> CustomStrftimeItems<'a> { + fn parse_next_item(&mut self, mut remainder: &'a str) -> Option<(&'a str, Item<'a>)> { + // use InternalInternal::*; + use Item::{Literal, Space}; + use Numeric::*; + + match remainder.chars().next() { + // we are done + None => None, + + // the next item is a specifier + Some('%') => { + remainder = &remainder[1..]; + + macro_rules! next { + () => { + match remainder.chars().next() { + Some(x) => { + remainder = &remainder[x.len_utf8()..]; + x + } + None => return Some((remainder, Item::Error)), // premature end of string + } + }; + } + + let spec = next!(); + let pad_override = match spec { + '-' => Some(Pad::None), + '0' => Some(Pad::Zero), + '_' => Some(Pad::Space), + _ => None, + }; + + // let is_alternate = spec == '#'; + // let spec = if pad_override.is_some() || is_alternate { next!() } else { spec }; + // if is_alternate && !HAVE_ALTERNATES.contains(spec) { + // return Some((remainder, Item::Error)); + // } + + macro_rules! queue { + [$head:expr, $($tail:expr),+ $(,)*] => ({ + const QUEUE: &'static [Item<'static>] = &[$($tail),+]; + self.queue = QUEUE; + $head + }) + } + + // macro_rules! 
queue_from_slice { + // ($slice:expr) => {{ + // self.queue = &$slice[1..]; + // $slice[0].clone() + // }}; + // } + + let item = match spec { + // day of month: 01-31 + 'd' => num0(Day), + // day of month without leading zero: 1-31 + 'e' => nums(Day), + // fractional seconds: SS.SSS + 'f' => { + queue![num0(Second), fixed(Fixed::Nanosecond3)] + } + // ISO 8601 date: YYYY-MM-DD + 'F' => queue![ + num0(Year), + Literal("-"), + num0(Month), + Literal("-"), + num0(Day) + ], + // ISO 8601 year corresponding to %V + 'G' => num0(IsoYear), + // 2-digit ISO 8601 year corresponding to %V + 'g' => num0(IsoYearMod100), + // hour: 00-24 + 'H' => num0(Hour), + // hour for 12-hour clock: 01-12 + 'I' => num0(Hour12), + // day of year: 001-366 + 'j' => num0(Ordinal), + // hour without leading zero: 0-24 + 'k' => nums(Hour), + // %I without leading zero: 1-12 + 'l' => nums(Hour12), + // month: 01-12 + 'm' => num0(Month), + // minute: 00-59 + 'M' => num0(Minute), + // "AM" or "PM" depending on the hour + 'p' => fixed(Fixed::UpperAmPm), + // "am" or "pm" depending on the hour + 'P' => fixed(Fixed::LowerAmPm), + // ISO 8601 time: HH:MM + 'R' => queue![num0(Hour), Literal(":"), num0(Minute)], + // seconds since 1970-01-01 + 's' => num(Timestamp), + // seconds: 00-59 + 'S' => num0(Second), + // ISO 8601 time: HH:MM:SS + 'T' => { + queue![ + num0(Hour), + Literal(":"), + num0(Minute), + Literal(":"), + num0(Second) + ] + } + // week of year (00-53) - week 01 starts on the first Sunday + 'U' => num0(WeekFromSun), + // day of week 1-7 with Monday==1 + 'u' => num(WeekdayFromMon), + // ISO 8601 week of year + 'V' => num0(IsoWeek), + // day of week 0-6 with Sunday==0 + 'w' => num(NumDaysFromSun), + // week of year (00-53) - week 01 starts on the first Monday + 'W' => num0(WeekFromMon), + // year: 0000-9999 + 'Y' => num0(Year), + // % + '%' => Literal("%"), + // TODO instead of doing a preprocessing of the %J specifier, it could be done post as postprocessing + // step by just emitting the 
formatter again to the string + // 'J' => Literal("%J"), + _ => Item::Error, // no such specifier + }; + + // Adjust `item` if we have any padding modifier. + // Not allowed on non-numeric items or on specifiers composed out of multiple + // formatting items. + if let Some(new_pad) = pad_override { + match item { + Item::Numeric(ref kind, _pad) if self.queue.is_empty() => { + Some((remainder, Item::Numeric(kind.clone(), new_pad))) + } + _ => Some((remainder, Item::Error)), + } + } else { + Some((remainder, item)) + } + } + + // the next item is space + Some(c) if c.is_whitespace() => { + // `%` is not a whitespace, so `c != '%'` is redundant + let nextspec = remainder + .find(|c: char| !c.is_whitespace()) + .unwrap_or(remainder.len()); + assert!(nextspec > 0); + let item = Space(&remainder[..nextspec]); + remainder = &remainder[nextspec..]; + Some((remainder, item)) + } + + // the next item is literal + _ => { + let nextspec = remainder + .find(|c: char| c.is_whitespace() || c == '%') + .unwrap_or(remainder.len()); + assert!(nextspec > 0); + let item = Literal(&remainder[..nextspec]); + remainder = &remainder[nextspec..]; + Some((remainder, item)) + } + } + } +} diff --git a/docs/internals/functions.md b/docs/contributing/contributing_functions.md similarity index 87% rename from docs/internals/functions.md rename to docs/contributing/contributing_functions.md index ff71b6864..a4386cab5 100644 --- a/docs/internals/functions.md +++ b/docs/contributing/contributing_functions.md @@ -6,6 +6,11 @@ Steps 3. Implement the function in a feature branch. 4. Push it as a Merge Request, get it review. 
+Sample Pull Requests of function contributing +- [partial support for datetime() and julianday()](https://github.com/tursodatabase/limbo/pull/600) +- [support for changes() and total_changes()](https://github.com/tursodatabase/limbo/pull/589) +- [support for unhex(X)](https://github.com/tursodatabase/limbo/pull/353) + ## An example with function `date(..)` > Note that the files, code location, steps might be not exactly the same because of refactor but the idea of the changes needed in each layer stays. @@ -13,13 +18,24 @@ Steps [Issue #158](https://github.com/tursodatabase/limbo/issues/158) was created for it. Refer to commit [4ff7058](https://github.com/tursodatabase/limbo/commit/4ff705868a054643f6113cbe009655c32bc5f235). +![limbo_architecture.png](limbo_architecture.png) + +To add a function we generally need to touch at least the following modules +- SQL Command Processor + - The `SQL Command Processor` module is responsible for turning sql function string into a sequence of instructions to be executed by the `Virtual Machine` module. + - we need the following things: function definition, how the `bytecode generator` in `core/translate` generates bytecode program for this function to be executed. +- Virtual Machine `core/vdbe` + - we need to add logic of how the `vdbe` should execute the logic of this function in Rust and write result to destination register of the vm. 
+ - [more info](https://www.sqlite.org/opcode.html) +- Tests + ``` -sql function: string ---Parser--> -enum Func ---translate--> -Instruction ---VDBE--> +SQL function string +--Tokenizer and Parser--> +AST (enum Func) +--Bytecode Generator (core/translate)--> +Bytecode Instructions +--Virtual Machine--> Result ``` diff --git a/docs/contributing/limbo_architecture.png b/docs/contributing/limbo_architecture.png new file mode 100644 index 000000000..630bfb633 Binary files /dev/null and b/docs/contributing/limbo_architecture.png differ diff --git a/docs/internals.md b/docs/internals.md new file mode 100644 index 000000000..0063abf5d --- /dev/null +++ b/docs/internals.md @@ -0,0 +1,82 @@ +# Limbo Database System Design and Implementation + +This is a work-in-progress book on the design and implementation of Limbo. + +## Limbo Overview + +Limbo is an in-process OLTP database system with SQLite compatibility. Unlike +client-server database systems such as PostgreSQL or MySQL, which require +applications to communicate over network protocols for SQL execution, an +in-process database is in your application memory space. This embedded +architecture eliminates network communication overhead, allowing for the best +case of low read and write latencies in the order of sub-microseconds. + +Limbo's architecture resembles SQLite's but differs primarily in its +asynchronous I/O model. This asynchronous design enables applications to +leverage modern I/O interfaces like `io_uring,` maximizing storage device +performance. While an in-process database offers significant performance +advantages, integration with cloud services remains crucial for operations +like backups. Limbo's asynchronous I/O model facilitates this by supporting +networked storage capabilities. 
+ +The high-level interface to Limbo is the same as in SQLite: + +* SQLite query language +* The `sqlite3_prepare()` function for translating SQL statements to programs + ("prepared statements") +* The `sqlite3_step()` function for executing programs + +If we start with the SQLite query language, you can use the `limbo` +command, for example, to evaluate SQL statements in the shell: + +``` +limbo> SELECT 'hello, world'; +hello, world +``` + +To execute this SQL statement, the shell uses the `sqlite3_prepare()` +interface to parse the statement and generate a bytecode program, a step +called preparing a statement. When a statement is prepared, it can be executed +using the `sqlite3_step()` function. + +To inspect the bytecode program for a SQL statement, you can use the +`EXPLAIN` command in the shell. For our example SQL statement, the bytecode +looks as follows: + +``` +limbo> EXPLAIN SELECT 'hello, world'; +addr opcode p1 p2 p3 p4 p5 comment +---- ----------------- ---- ---- ---- ------------- -- ------- +0 Init 0 4 0 0 Start at 4 +1 String8 0 1 0 hello, world 0 r[1]='hello, world' +2 ResultRow 1 1 0 0 output=r[1] +3 Halt 0 0 0 0 +4 Transaction 0 0 0 0 +5 Goto 0 1 0 0 +``` + +The instruction set of the virtual machine consists of domain specific +instructions for a database system. Every instruction consists of an +opcode that describes the operation and up to 5 operands. In the example +above, execution starts at offset zero with the `Init` instruction. The +instruction sets up the program and branches to a instruction at address +specified in operand `p2`. In our example, address 4 has the +`Transaction` instruction, which begins a transaction. After that, the +`Goto` instruction then branches to address 1 where we load a string +constant `'hello, world'` to register `r[1]`. The `ResultRow` instruction +produces a SQL query result using contents of `r[1]`. Finally, the +program terminates with the `Halt` instruction. 
+ +## Frontend + +### Parser + +### Code generator + +### Query optimizer + +## Virtual Machine + +## Pager + +## I/O diff --git a/docs/internals/mvcc/DESIGN.md b/docs/internals/mvcc/DESIGN.md new file mode 100644 index 000000000..37943d992 --- /dev/null +++ b/docs/internals/mvcc/DESIGN.md @@ -0,0 +1,19 @@ +# Design + +## Persistent storage + +Persistent storage must implement the `Storage` trait that the MVCC module uses for transaction logging. + +Figure 1 shows an example of write-ahead log across three transactions. +The first transaction T0 executes a `INSERT (id) VALUES (1)` statement, which results in a log record with `id` set to `1`, begin timestamp to 0 (which is the transaction ID) and end timestamp as infinity (meaning the row version is still visible). +The second transaction T1 executes another `INSERT` statement, which adds another log record to the transaction log with `id` set to `2`, begin timesstamp to 1 and end timestamp as infinity, similar to what T0 did. +Finally, a third transaction T2 executes two statements: `DELETE WHERE id = 1` and `INSERT (id) VALUES (3)`. The first one results in a log record with `id` set to `1` and begin timestamp set to 0 (which is the transaction that created the entry). However, the end timestamp is now set to 2 (the current transaction), which means the entry is now deleted. +The second statement results in an entry in the transaction log similar to the `INSERT` statements in T0 and T1. + +![Transactions](figures/transactions.png) +

+Figure 1. Transaction log of three transactions. +

+ +When MVCC bootstraps or recovers, it simply redos the transaction log. +If the transaction log grows big, we can checkpoint it it by dropping all entries that are no longer visible after the the latest transaction and create a snapshot. diff --git a/docs/internals/mvcc/figures/transactions.excalidraw b/docs/internals/mvcc/figures/transactions.excalidraw new file mode 100644 index 000000000..cee1947f9 --- /dev/null +++ b/docs/internals/mvcc/figures/transactions.excalidraw @@ -0,0 +1,656 @@ +{ + "type": "excalidraw", + "version": 2, + "source": "https://excalidraw.com", + "elements": [ + { + "id": "tFvpBUMWe3qPFUTQVV14X", + "type": "text", + "x": 233.14035848761839, + "y": 205.73272444200816, + "width": 278.57781982421875, + "height": 25, + "angle": 0, + "strokeColor": "#087f5b", + "backgroundColor": "#82c91e", + "fillStyle": "hachure", + "strokeWidth": 1, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "roundness": null, + "seed": 94988319, + "version": 510, + "versionNonce": 1210831775, + "isDeleted": false, + "boundElements": null, + "updated": 1683370319070, + "link": null, + "locked": false, + "text": "", + "fontSize": 20, + "fontFamily": 1, + "textAlign": "left", + "verticalAlign": "top", + "baseline": 18, + "containerId": null, + "originalText": "", + "lineHeight": 1.25 + }, + { + "type": "text", + "version": 515, + "versionNonce": 1881893969, + "isDeleted": false, + "id": "7i88n1PIb89NxUbVQmTTi", + "fillStyle": "hachure", + "strokeWidth": 1, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "angle": 0, + "x": 938.4614491858606, + "y": 311.23272444200813, + "strokeColor": "#0b7285", + "backgroundColor": "#82c91e", + "width": 279.0400085449219, + "height": 25, + "seed": 1123646321, + "groupIds": [], + "roundness": null, + "boundElements": [], + "updated": 1683370316909, + "link": null, + "locked": false, + "fontSize": 20, + "fontFamily": 1, + "text": "", + "textAlign": "left", + "verticalAlign": "top", + 
"containerId": null, + "originalText": "", + "lineHeight": 1.25, + "baseline": 18 + }, + { + "type": "text", + "version": 556, + "versionNonce": 153125934, + "isDeleted": false, + "id": "Yh8XLtKqXUUYmcmG4SEXn", + "fillStyle": "hachure", + "strokeWidth": 1, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "angle": 0, + "x": 581.1603475012903, + "y": 256.23272444200813, + "strokeColor": "#e67700", + "backgroundColor": "#82c91e", + "width": 270.71783447265625, + "height": 25, + "seed": 1685524017, + "groupIds": [], + "roundness": null, + "boundElements": [], + "updated": 1683371076075, + "link": null, + "locked": false, + "fontSize": 20, + "fontFamily": 1, + "text": "", + "textAlign": "left", + "verticalAlign": "top", + "containerId": null, + "originalText": "", + "lineHeight": 1.25, + "baseline": 18 + }, + { + "id": "8l0CCJzCAtOLt_2GRcNpa", + "type": "text", + "x": 256.1403584876185, + "y": 409.73272444200813, + "width": 234.41998291015625, + "height": 75, + "angle": 0, + "strokeColor": "#087f5b", + "backgroundColor": "#82c91e", + "fillStyle": "hachure", + "strokeWidth": 1, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "roundness": null, + "seed": 583129809, + "version": 570, + "versionNonce": 561756721, + "isDeleted": false, + "boundElements": null, + "updated": 1683370316909, + "link": null, + "locked": false, + "text": "BEGIN\nINSERT (id) VALUEs (1)\nCOMMIT", + "fontSize": 20, + "fontFamily": 1, + "textAlign": "left", + "verticalAlign": "top", + "baseline": 68, + "containerId": null, + "originalText": "BEGIN\nINSERT (id) VALUEs (1)\nCOMMIT", + "lineHeight": 1.25 + }, + { + "type": "text", + "version": 628, + "versionNonce": 282656095, + "isDeleted": false, + "id": "3m7VluAP5tair6-60b_sp", + "fillStyle": "hachure", + "strokeWidth": 1, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "angle": 0, + "x": 962.0903554358606, + "y": 416.23272444200813, + "strokeColor": "#0b7285", + "backgroundColor": 
"#82c91e", + "width": 243.91998291015625, + "height": 100, + "seed": 479705617, + "groupIds": [], + "roundness": null, + "boundElements": [], + "updated": 1683370316909, + "link": null, + "locked": false, + "fontSize": 20, + "fontFamily": 1, + "text": "BEGIN\nDELETE WHERE id =1\nINSERT (id) VALUES (3)\nCOMMIT", + "textAlign": "left", + "verticalAlign": "top", + "containerId": null, + "originalText": "BEGIN\nDELETE WHERE id =1\nINSERT (id) VALUES (3)\nCOMMIT", + "lineHeight": 1.25, + "baseline": 93 + }, + { + "type": "text", + "version": 574, + "versionNonce": 1128746001, + "isDeleted": false, + "id": "Z-Mh1kti2oC6sIMnuGluo", + "fillStyle": "hachure", + "strokeWidth": 1, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "angle": 0, + "x": 613.0903554358607, + "y": 417.23272444200813, + "strokeColor": "#e67700", + "backgroundColor": "#82c91e", + "width": 243.239990234375, + "height": 75, + "seed": 580440625, + "groupIds": [], + "roundness": null, + "boundElements": [], + "updated": 1683370316909, + "link": null, + "locked": false, + "fontSize": 20, + "fontFamily": 1, + "text": "BEGIN\nINSERT (id) VALUEs (2)\nCOMMIT", + "textAlign": "left", + "verticalAlign": "top", + "containerId": null, + "originalText": "BEGIN\nINSERT (id) VALUEs (2)\nCOMMIT", + "lineHeight": 1.25, + "baseline": 68 + }, + { + "type": "line", + "version": 1502, + "versionNonce": 1835608607, + "isDeleted": false, + "id": "VuJNZCgz1Y0WEWwug7pGk", + "fillStyle": "hachure", + "strokeWidth": 1, + "strokeStyle": "solid", + "roughness": 0, + "opacity": 100, + "angle": 0, + "x": 226.3083636621349, + "y": 173.11701218356845, + "strokeColor": "#000000", + "backgroundColor": "transparent", + "width": 20.336010349032712, + "height": 203.23377930246647, + "seed": 1879839231, + "groupIds": [], + "roundness": null, + "boundElements": [], + "updated": 1683370316909, + "link": null, + "locked": false, + "startBinding": null, + "endBinding": null, + "lastCommittedPoint": null, + "startArrowhead": null, 
+ "endArrowhead": null, + "points": [ + [ + 0, + 0 + ], + [ + -20.264781987976257, + -0.0011773927935071482 + ], + [ + -20.336010349032712, + 203.23260190967298 + ], + [ + -0.07239358683375485, + 203.135377672515 + ] + ] + }, + { + "type": "line", + "version": 1755, + "versionNonce": 1487752017, + "isDeleted": false, + "id": "GpZg3Rw4Hszxzxf38Q4Hn", + "fillStyle": "hachure", + "strokeWidth": 1, + "strokeStyle": "solid", + "roughness": 0, + "opacity": 100, + "angle": 3.141592653589793, + "x": 539.3083636621348, + "y": 178.11701218356845, + "strokeColor": "#000000", + "backgroundColor": "transparent", + "width": 20.336010349032712, + "height": 203.23377930246647, + "seed": 470135121, + "groupIds": [], + "roundness": null, + "boundElements": [], + "updated": 1683370316909, + "link": null, + "locked": false, + "startBinding": null, + "endBinding": null, + "lastCommittedPoint": null, + "startArrowhead": null, + "endArrowhead": null, + "points": [ + [ + 0, + 0 + ], + [ + -20.264781987976257, + -0.0011773927935071482 + ], + [ + -20.336010349032712, + 203.23260190967298 + ], + [ + -0.07239358683375485, + 203.135377672515 + ] + ] + }, + { + "type": "text", + "version": 528, + "versionNonce": 1276939839, + "isDeleted": false, + "id": "AGEyNvBxBm2cwm1WRW8n8", + "fillStyle": "hachure", + "strokeWidth": 1, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "angle": 0, + "x": 576.6403584876185, + "y": 210.23272444200816, + "strokeColor": "#087f5b", + "backgroundColor": "#82c91e", + "width": 278.57781982421875, + "height": 25, + "seed": 877528401, + "groupIds": [], + "roundness": null, + "boundElements": [], + "updated": 1683370316909, + "link": null, + "locked": false, + "fontSize": 20, + "fontFamily": 1, + "text": "", + "textAlign": "left", + "verticalAlign": "top", + "containerId": null, + "originalText": "", + "lineHeight": 1.25, + "baseline": 18 + }, + { + "type": "line", + "version": 1557, + "versionNonce": 773679889, + "isDeleted": false, + "id": 
"Q8E0gAcLvq6VXqMDZhLdA", + "fillStyle": "hachure", + "strokeWidth": 1, + "strokeStyle": "solid", + "roughness": 0, + "opacity": 100, + "angle": 0, + "x": 581.8083636621351, + "y": 177.61701218356845, + "strokeColor": "#000000", + "backgroundColor": "transparent", + "width": 20.336010349032712, + "height": 203.23377930246647, + "seed": 153279217, + "groupIds": [], + "roundness": null, + "boundElements": [], + "updated": 1683370316909, + "link": null, + "locked": false, + "startBinding": null, + "endBinding": null, + "lastCommittedPoint": null, + "startArrowhead": null, + "endArrowhead": null, + "points": [ + [ + 0, + 0 + ], + [ + -20.264781987976257, + -0.0011773927935071482 + ], + [ + -20.336010349032712, + 203.23260190967298 + ], + [ + -0.07239358683375485, + 203.135377672515 + ] + ] + }, + { + "type": "line", + "version": 1810, + "versionNonce": 1561283199, + "isDeleted": false, + "id": "uhh3ZkPO6bwwf0-AI8syI", + "fillStyle": "hachure", + "strokeWidth": 1, + "strokeStyle": "solid", + "roughness": 0, + "opacity": 100, + "angle": 3.141592653589793, + "x": 894.8083636621349, + "y": 182.61701218356845, + "strokeColor": "#000000", + "backgroundColor": "transparent", + "width": 20.336010349032712, + "height": 203.23377930246647, + "seed": 315380945, + "groupIds": [], + "roundness": null, + "boundElements": [], + "updated": 1683370316909, + "link": null, + "locked": false, + "startBinding": null, + "endBinding": null, + "lastCommittedPoint": null, + "startArrowhead": null, + "endArrowhead": null, + "points": [ + [ + 0, + 0 + ], + [ + -20.264781987976257, + -0.0011773927935071482 + ], + [ + -20.336010349032712, + 203.23260190967298 + ], + [ + -0.07239358683375485, + 203.135377672515 + ] + ] + }, + { + "type": "text", + "version": 575, + "versionNonce": 910156017, + "isDeleted": false, + "id": "jI5YKyaOdGYYKiBWZmCMs", + "fillStyle": "hachure", + "strokeWidth": 1, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "angle": 0, + "x": 929.6403584876182, + "y": 
215.23272444200813, + "strokeColor": "#087f5b", + "backgroundColor": "#82c91e", + "width": 278.57781982421875, + "height": 25, + "seed": 121503167, + "groupIds": [], + "roundness": null, + "boundElements": [], + "updated": 1683370316909, + "link": null, + "locked": false, + "fontSize": 20, + "fontFamily": 1, + "text": "", + "textAlign": "left", + "verticalAlign": "top", + "containerId": null, + "originalText": "", + "lineHeight": 1.25, + "baseline": 18 + }, + { + "type": "line", + "version": 1604, + "versionNonce": 19920575, + "isDeleted": false, + "id": "QqIk7VTnRWYq499wkttvv", + "fillStyle": "hachure", + "strokeWidth": 1, + "strokeStyle": "solid", + "roughness": 0, + "opacity": 100, + "angle": 0, + "x": 934.8083636621348, + "y": 182.61701218356842, + "strokeColor": "#000000", + "backgroundColor": "transparent", + "width": 20.336010349032712, + "height": 203.23377930246647, + "seed": 2012037663, + "groupIds": [], + "roundness": null, + "boundElements": [], + "updated": 1683370316909, + "link": null, + "locked": false, + "startBinding": null, + "endBinding": null, + "lastCommittedPoint": null, + "startArrowhead": null, + "endArrowhead": null, + "points": [ + [ + 0, + 0 + ], + [ + -20.264781987976257, + -0.0011773927935071482 + ], + [ + -20.336010349032712, + 203.23260190967298 + ], + [ + -0.07239358683375485, + 203.135377672515 + ] + ] + }, + { + "type": "line", + "version": 1857, + "versionNonce": 1660885169, + "isDeleted": false, + "id": "gk89VsYpnf9Jby9KEUBd3", + "fillStyle": "hachure", + "strokeWidth": 1, + "strokeStyle": "solid", + "roughness": 0, + "opacity": 100, + "angle": 3.141592653589793, + "x": 1247.808363662135, + "y": 187.61701218356842, + "strokeColor": "#000000", + "backgroundColor": "transparent", + "width": 20.336010349032712, + "height": 203.23377930246647, + "seed": 509453887, + "groupIds": [], + "roundness": null, + "boundElements": [], + "updated": 1683370316909, + "link": null, + "locked": false, + "startBinding": null, + "endBinding": null, 
+ "lastCommittedPoint": null, + "startArrowhead": null, + "endArrowhead": null, + "points": [ + [ + 0, + 0 + ], + [ + -20.264781987976257, + -0.0011773927935071482 + ], + [ + -20.336010349032712, + 203.23260190967298 + ], + [ + -0.07239358683375485, + 203.135377672515 + ] + ] + }, + { + "type": "text", + "version": 620, + "versionNonce": 1588681010, + "isDeleted": false, + "id": "a1c-iZI0SafCiy0u4xieZ", + "fillStyle": "hachure", + "strokeWidth": 1, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "angle": 0, + "x": 934.3714375891809, + "y": 261.23272444200813, + "strokeColor": "#e67700", + "backgroundColor": "#82c91e", + "width": 270.71783447265625, + "height": 25, + "seed": 1742829553, + "groupIds": [], + "roundness": null, + "boundElements": [], + "updated": 1683371080181, + "link": null, + "locked": false, + "fontSize": 20, + "fontFamily": 1, + "text": "", + "textAlign": "left", + "verticalAlign": "top", + "containerId": null, + "originalText": "", + "lineHeight": 1.25, + "baseline": 18 + }, + { + "type": "text", + "version": 564, + "versionNonce": 1968863633, + "isDeleted": false, + "id": "hdhhgp5nA06o5EcSgHQE8", + "fillStyle": "hachure", + "strokeWidth": 1, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "angle": 0, + "x": 937.6203542151575, + "y": 354.23272444200813, + "strokeColor": "#0b7285", + "backgroundColor": "#82c91e", + "width": 287.73785400390625, + "height": 25, + "seed": 309558367, + "groupIds": [], + "roundness": null, + "boundElements": [], + "updated": 1683370363648, + "link": null, + "locked": false, + "fontSize": 20, + "fontFamily": 1, + "text": "", + "textAlign": "left", + "verticalAlign": "top", + "containerId": null, + "originalText": "", + "lineHeight": 1.25, + "baseline": 18 + } + ], + "appState": { + "gridSize": null, + "viewBackgroundColor": "#ffffff" + }, + "files": {} +} \ No newline at end of file diff --git a/docs/internals/mvcc/figures/transactions.png b/docs/internals/mvcc/figures/transactions.png 
new file mode 100644 index 000000000..3b8fe59bc Binary files /dev/null and b/docs/internals/mvcc/figures/transactions.png differ diff --git a/extensions/core/Cargo.toml b/extensions/core/Cargo.toml index f56436f5c..3194bcadb 100644 --- a/extensions/core/Cargo.toml +++ b/extensions/core/Cargo.toml @@ -6,6 +6,11 @@ edition.workspace = true license.workspace = true repository.workspace = true + +[features] +default = [] +static = [] + [dependencies] log = "0.4.20" limbo_macros = { path = "../../macros" } diff --git a/extensions/core/README.md b/extensions/core/README.md index 6e87743e0..bcb7ff86f 100644 --- a/extensions/core/README.md +++ b/extensions/core/README.md @@ -19,13 +19,18 @@ Add the crate to your `Cargo.toml`: ```toml [dependencies] limbo_ext = { path = "path/to/limbo/extensions/core" } # temporary until crate is published + +# mimalloc is required if you intend on linking dynamically. It is imported for you by the register_extension +# macro, so no configuration is needed. But it must be added to your Cargo.toml +[target.'cfg(not(target_family = "wasm"))'.dependencies] +mimalloc = { version = "*", default-features = false } ``` -**NOTE** Crate must be of type `cdylib` +**NOTE** Crate must be of type `cdylib` if you wish to link dynamically ``` [lib] -crate-type = ["cdylib"] +crate-type = ["cdylib", "lib"] ``` `cargo build` will output a shared library that can be loaded with `.load target/debug/libyour_crate_name` diff --git a/extensions/core/src/types.rs b/extensions/core/src/types.rs index 9d69aa942..74fa670ad 100644 --- a/extensions/core/src/types.rs +++ b/extensions/core/src/types.rs @@ -1,7 +1,8 @@ -use std::{fmt::Display, os::raw::c_void}; +use std::fmt::Display; /// Error type is of type ExtError which can be /// either a user defined error or an error code +#[derive(Clone, Copy)] #[repr(C)] pub enum ResultCode { OK = 0, @@ -18,12 +19,17 @@ pub enum ResultCode { Unimplemented = 11, Internal = 12, Unavailable = 13, + CustomError = 14, } impl 
ResultCode { pub fn is_ok(&self) -> bool { matches!(self, ResultCode::OK) } + + pub fn has_error_set(&self) -> bool { + matches!(self, ResultCode::CustomError) + } } impl Display for ResultCode { @@ -31,7 +37,7 @@ impl Display for ResultCode { match self { ResultCode::OK => write!(f, "OK"), ResultCode::Error => write!(f, "Error"), - ResultCode::InvalidArgs => write!(f, "InvalidArgs"), + ResultCode::InvalidArgs => write!(f, "Invalid Argument"), ResultCode::Unknown => write!(f, "Unknown"), ResultCode::OoM => write!(f, "Out of Memory"), ResultCode::Corrupt => write!(f, "Corrupt"), @@ -43,6 +49,7 @@ impl Display for ResultCode { ResultCode::Unimplemented => write!(f, "Unimplemented"), ResultCode::Internal => write!(f, "Internal Error"), ResultCode::Unavailable => write!(f, "Unavailable"), + ResultCode::CustomError => write!(f, "Error "), } } } @@ -61,31 +68,35 @@ pub enum ValueType { #[repr(C)] pub struct Value { value_type: ValueType, - value: *mut c_void, + value: ValueData, +} + +#[repr(C)] +union ValueData { + int: i64, + float: f64, + text: *const TextValue, + blob: *const Blob, + error: *const ErrValue, } impl std::fmt::Debug for Value { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if self.value.is_null() { - return write!(f, "{:?}: Null", self.value_type); - } match self.value_type { ValueType::Null => write!(f, "Value {{ Null }}"), - ValueType::Integer => write!(f, "Value {{ Integer: {} }}", unsafe { - *(self.value as *const i64) - }), - ValueType::Float => write!(f, "Value {{ Float: {} }}", unsafe { - *(self.value as *const f64) - }), - ValueType::Text => write!(f, "Value {{ Text: {:?} }}", unsafe { - &*(self.value as *const TextValue) - }), - ValueType::Blob => write!(f, "Value {{ Blob: {:?} }}", unsafe { - &*(self.value as *const Blob) - }), - ValueType::Error => write!(f, "Value {{ Error: {:?} }}", unsafe { - &*(self.value as *const TextValue) - }), + ValueType::Integer => write!( + f, + "Value {{ Integer: {} }}", + 
self.to_integer().unwrap_or_default() + ), + ValueType::Float => write!( + f, + "Value {{ Float: {} }}", + self.to_float().unwrap_or_default() + ), + ValueType::Text => write!(f, "Value {{ Text: {:?} }}", self.to_text()), + ValueType::Blob => write!(f, "Value {{ Blob: {:?} }}", self.to_blob()), + ValueType::Error => write!(f, "Value {{ Error }}"), } } } @@ -123,6 +134,17 @@ impl TextValue { } } + pub(crate) fn new_boxed(s: String) -> Box { + let buffer = s.into_boxed_str(); + let ptr = buffer.as_ptr(); + let len = buffer.len(); + std::mem::forget(buffer); + Box::new(Self { + text: ptr, + len: len as u32, + }) + } + fn as_str(&self) -> &str { if self.text.is_null() { return ""; @@ -133,6 +155,39 @@ impl TextValue { } } +#[repr(C)] +pub struct ErrValue { + code: ResultCode, + message: *mut TextValue, +} + +impl ErrValue { + fn new(code: ResultCode) -> Self { + Self { + code, + message: std::ptr::null_mut(), + } + } + + fn new_with_message(code: ResultCode, message: String) -> Self { + let buffer = message.into_boxed_str(); + let ptr = buffer.as_ptr(); + let len = buffer.len(); + std::mem::forget(buffer); + let text_value = TextValue::new(ptr, len); + Self { + code, + message: Box::into_raw(Box::new(text_value)), + } + } + + unsafe fn free(self) { + if !self.message.is_null() { + let _ = Box::from_raw(self.message); // Freed by the same library + } + } +} + #[repr(C)] pub struct Blob { data: *const u8, @@ -156,40 +211,42 @@ impl Value { pub fn null() -> Self { Self { value_type: ValueType::Null, - value: std::ptr::null_mut(), + value: ValueData { int: 0 }, } } /// Returns the value type of the Value + /// # Safety + /// This function accesses the value_type field of the union. 
+ /// it is safe to call this function as long as the value was properly + /// constructed with one of the provided methods pub fn value_type(&self) -> ValueType { self.value_type } - /// Returns the float value if the Value is the proper type + /// Returns the float value or casts the relevant value to a float pub fn to_float(&self) -> Option { - if self.value.is_null() { - return None; - } match self.value_type { - ValueType::Float => Some(unsafe { *(self.value as *const f64) }), - ValueType::Integer => Some(unsafe { *(self.value as *const i64) as f64 }), + ValueType::Float => Some(unsafe { self.value.float }), + ValueType::Integer => Some(unsafe { self.value.int } as f64), ValueType::Text => { - let txt = unsafe { &*(self.value as *const TextValue) }; - txt.as_str().parse().ok() + let txt = self.to_text().unwrap_or_default(); + txt.parse().ok() } _ => None, } } + /// Returns the text value if the Value is the proper type - pub fn to_text(&self) -> Option { - if self.value_type != ValueType::Text { - return None; + pub fn to_text(&self) -> Option<&str> { + unsafe { + if self.value_type == ValueType::Text && !self.value.text.is_null() { + let txt = &*self.value.text; + Some(txt.as_str()) + } else { + None + } } - if self.value.is_null() { - return None; - } - let txt = unsafe { &*(self.value as *const TextValue) }; - Some(String::from(txt.as_str())) } /// Returns the blob value if the Value is the proper type @@ -197,119 +254,120 @@ impl Value { if self.value_type != ValueType::Blob { return None; } - if self.value.is_null() { + if unsafe { self.value.blob.is_null() } { return None; } - let blob = unsafe { &*(self.value as *const Blob) }; + let blob = unsafe { &*(self.value.blob) }; let slice = unsafe { std::slice::from_raw_parts(blob.data, blob.size as usize) }; Some(slice.to_vec()) } /// Returns the integer value if the Value is the proper type pub fn to_integer(&self) -> Option { - if self.value.is_null() { - return None; - } match self.value_type() { - 
ValueType::Integer => Some(unsafe { *(self.value as *const i64) }), - ValueType::Float => Some(unsafe { *(self.value as *const f64) } as i64), - ValueType::Text => { - let txt = unsafe { &*(self.value as *const TextValue) }; - txt.as_str().parse().ok() - } + ValueType::Integer => Some(unsafe { self.value.int }), + ValueType::Float => Some(unsafe { self.value.float } as i64), + ValueType::Text => self + .to_text() + .map(|txt| txt.parse::().unwrap_or_default()), _ => None, } } - /// Returns the error message if the value is an error - pub fn to_error(&self) -> Option { + /// Returns the error code if the value is an error + pub fn to_error(&self) -> Option { if self.value_type != ValueType::Error { return None; } - if self.value.is_null() { + if unsafe { self.value.error.is_null() } { return None; } - let err = unsafe { &*(self.value as *const ExtError) }; - match &err.error_type { - ErrorType::User => { - if err.message.is_null() { - return None; - } - let txt = unsafe { &*(err.message as *const TextValue) }; - Some(txt.as_str().to_string()) - } - ErrorType::ErrCode { code } => Some(format!("{}", code)), + let err = unsafe { &*self.value.error }; + Some(err.code) + } + + /// Returns the error code and optional message if the value is an error + pub fn to_error_details(&self) -> Option<(ResultCode, Option)> { + if self.value_type != ValueType::Error || unsafe { self.value.error.is_null() } { + return None; + } + let err_val = unsafe { &*(self.value.error) }; + let code = err_val.code; + + if err_val.message.is_null() { + Some((code, None)) + } else { + let txt = unsafe { &*(err_val.message as *const TextValue) }; + let msg = txt.as_str().to_owned(); + Some((code, Some(msg))) } } /// Creates a new integer Value from an i64 - pub fn from_integer(value: i64) -> Self { - let boxed = Box::new(value); + pub fn from_integer(i: i64) -> Self { Self { value_type: ValueType::Integer, - value: Box::into_raw(boxed) as *mut c_void, + value: ValueData { int: i }, } } /// Creates a 
new float Value from an f64 pub fn from_float(value: f64) -> Self { - let boxed = Box::new(value); Self { value_type: ValueType::Float, - value: Box::into_raw(boxed) as *mut c_void, + value: ValueData { float: value }, } } + /// Creates a new text Value from a String + /// This function allocates/leaks the string + /// and must be free'd manually pub fn from_text(s: String) -> Self { - let buffer = s.into_boxed_str(); - let ptr = buffer.as_ptr(); - let len = buffer.len(); - std::mem::forget(buffer); - let text_value = TextValue::new(ptr, len); - let text_box = Box::new(text_value); + let txt_value = TextValue::new_boxed(s); + let ptr = Box::into_raw(txt_value); Self { value_type: ValueType::Text, - value: Box::into_raw(text_box) as *mut c_void, + value: ValueData { text: ptr }, } } /// Creates a new error Value from a ResultCode - pub fn error(err: ResultCode) -> Self { - let error = ExtError { - error_type: ErrorType::ErrCode { code: err }, - message: std::ptr::null_mut(), - }; + /// This function allocates/leaks the error + /// and must be free'd manually + pub fn error(code: ResultCode) -> Self { + let err_val = ErrValue::new(code); Self { value_type: ValueType::Error, - value: Box::into_raw(Box::new(error)) as *mut c_void, + value: ValueData { + error: Box::into_raw(Box::new(err_val)) as *const ErrValue, + }, } } - /// Create a new user defined error Value with a message - pub fn custom_error(s: String) -> Self { - let buffer = s.into_boxed_str(); - let ptr = buffer.as_ptr(); - let len = buffer.len(); - std::mem::forget(buffer); - let text_value = TextValue::new(ptr, len); - let text_box = Box::new(text_value); - let error = ExtError { - error_type: ErrorType::User, - message: Box::into_raw(text_box) as *mut c_void, - }; + /// Creates a new error Value from a ResultCode and a message + /// This function allocates/leaks the error, must be free'd manually + pub fn error_with_message(message: String) -> Self { + let err_value = 
ErrValue::new_with_message(ResultCode::CustomError, message); + let err_box = Box::new(err_value); Self { value_type: ValueType::Error, - value: Box::into_raw(Box::new(error)) as *mut c_void, + value: ValueData { + error: Box::into_raw(err_box) as *const ErrValue, + }, } } /// Creates a new blob Value from a Vec + /// This function allocates/leaks the blob + /// and must be free'd manually pub fn from_blob(value: Vec) -> Self { let boxed = Box::new(Blob::new(value.as_ptr(), value.len() as u64)); std::mem::forget(value); Self { value_type: ValueType::Blob, - value: Box::into_raw(boxed) as *mut c_void, + value: ValueData { + blob: Box::into_raw(boxed) as *const Blob, + }, } } @@ -318,41 +376,18 @@ impl Value { /// however this does assume that the type was properly constructed with /// the appropriate value_type and value. pub unsafe fn free(self) { - if self.value.is_null() { - return; - } match self.value_type { - ValueType::Integer => { - let _ = Box::from_raw(self.value as *mut i64); - } - ValueType::Float => { - let _ = Box::from_raw(self.value as *mut f64); - } ValueType::Text => { - let _ = Box::from_raw(self.value as *mut TextValue); + let _ = Box::from_raw(self.value.text as *mut TextValue); } ValueType::Blob => { - let _ = Box::from_raw(self.value as *mut Blob); + let _ = Box::from_raw(self.value.blob as *mut Blob); } ValueType::Error => { - let _ = Box::from_raw(self.value as *mut ExtError); + let err_val = Box::from_raw(self.value.error as *mut ErrValue); + err_val.free(); } - ValueType::Null => {} + _ => {} } } } - -#[repr(C)] -pub struct ExtError { - pub error_type: ErrorType, - pub message: *mut std::ffi::c_void, -} - -#[repr(C)] -pub enum ErrorType { - User, - /// User type has a user provided message - ErrCode { - code: ResultCode, - }, -} diff --git a/extensions/percentile/Cargo.toml b/extensions/percentile/Cargo.toml index 91340b813..614578944 100644 --- a/extensions/percentile/Cargo.toml +++ b/extensions/percentile/Cargo.toml @@ -7,7 +7,13 @@ 
license.workspace = true repository.workspace = true [lib] -crate-type = ["cdylib"] +crate-type = ["cdylib", "lib"] + +[features] +static = ["limbo_ext/static"] [dependencies] -limbo_ext = { path = "../core" } +limbo_ext = { path = "../core", features = ["static"] } + +[target.'cfg(not(target_family = "wasm"))'.dependencies] +mimalloc = { version = "*", default-features = false } diff --git a/extensions/percentile/src/lib.rs b/extensions/percentile/src/lib.rs index ceecc0e3b..9f81a6674 100644 --- a/extensions/percentile/src/lib.rs +++ b/extensions/percentile/src/lib.rs @@ -1,4 +1,4 @@ -use limbo_ext::{register_extension, AggFunc, AggregateDerive, ResultCode, Value}; +use limbo_ext::{register_extension, AggFunc, AggregateDerive, Value}; register_extension! { aggregates: { Median, Percentile, PercentileCont, PercentileDisc } @@ -41,7 +41,7 @@ impl AggFunc for Median { struct Percentile; impl AggFunc for Percentile { - type State = (Vec, Option, Option<()>); + type State = (Vec, Option, Option<&'static str>); const NAME: &'static str = "percentile"; const ARGS: i32 = 2; @@ -53,13 +53,13 @@ impl AggFunc for Percentile { args.get(1).and_then(Value::to_float), ) { if !(0.0..=100.0).contains(&p) { - err_value.get_or_insert(()); + err_value.get_or_insert("Invalid percentile value"); return; } if let Some(existing_p) = *p_value { if (existing_p - p).abs() >= 0.001 { - err_value.get_or_insert(()); + err_value.get_or_insert("Inconsistent percentile values across rows"); return; } } else { @@ -74,8 +74,8 @@ impl AggFunc for Percentile { if values.is_empty() { return Value::null(); } - if err_value.is_some() { - return Value::error(ResultCode::Error); + if let Some(err) = err_value { + return Value::error_with_message(err.into()); } if values.len() == 1 { return Value::from_float(values[0]); @@ -101,7 +101,7 @@ impl AggFunc for Percentile { struct PercentileCont; impl AggFunc for PercentileCont { - type State = (Vec, Option, Option<()>); + type State = (Vec, Option, 
Option<&'static str>); const NAME: &'static str = "percentile_cont"; const ARGS: i32 = 2; @@ -113,13 +113,13 @@ impl AggFunc for PercentileCont { args.get(1).and_then(Value::to_float), ) { if !(0.0..=1.0).contains(&p) { - err_state.get_or_insert(()); + err_state.get_or_insert("Percentile value must be between 0.0 and 1.0 inclusive"); return; } if let Some(existing_p) = *p_value { if (existing_p - p).abs() >= 0.001 { - err_state.get_or_insert(()); + err_state.get_or_insert("Inconsistent percentile values across rows"); return; } } else { @@ -134,8 +134,8 @@ impl AggFunc for PercentileCont { if values.is_empty() { return Value::null(); } - if err_state.is_some() { - return Value::error(ResultCode::Error); + if let Some(err) = err_state { + return Value::error_with_message(err.into()); } if values.len() == 1 { return Value::from_float(values[0]); @@ -161,7 +161,7 @@ impl AggFunc for PercentileCont { struct PercentileDisc; impl AggFunc for PercentileDisc { - type State = (Vec, Option, Option<()>); + type State = (Vec, Option, Option<&'static str>); const NAME: &'static str = "percentile_disc"; const ARGS: i32 = 2; @@ -175,8 +175,8 @@ impl AggFunc for PercentileDisc { if values.is_empty() { return Value::null(); } - if err_value.is_some() { - return Value::error(ResultCode::Error); + if let Some(err) = err_value { + return Value::error_with_message(err.into()); } let p = p_value.unwrap(); diff --git a/extensions/regexp/Cargo.toml b/extensions/regexp/Cargo.toml index dc2f87c3b..c8288e601 100644 --- a/extensions/regexp/Cargo.toml +++ b/extensions/regexp/Cargo.toml @@ -6,11 +6,18 @@ edition.workspace = true license.workspace = true repository.workspace = true +[features] +static = ["limbo_ext/static"] +defaults = [] + [lib] crate-type = ["cdylib", "lib"] [dependencies] -limbo_ext = { path = "../core"} +limbo_ext = { path = "../core", features = ["static"] } regex = "1.11.1" log = "0.4.20" + +[target.'cfg(not(target_family = "wasm"))'.dependencies] +mimalloc = { version = 
"*", default-features = false } diff --git a/extensions/regexp/src/lib.rs b/extensions/regexp/src/lib.rs index a4531acbc..6f037e4d4 100644 --- a/extensions/regexp/src/lib.rs +++ b/extensions/regexp/src/lib.rs @@ -19,11 +19,11 @@ fn regex(pattern: &Value, haystack: &Value) -> Value { let Some(haystack) = haystack.to_text() else { return Value::null(); }; - let re = match Regex::new(&pattern) { + let re = match Regex::new(pattern) { Ok(re) => re, Err(_) => return Value::null(), }; - Value::from_integer(re.is_match(&haystack) as i64) + Value::from_integer(re.is_match(haystack) as i64) } _ => Value::null(), } diff --git a/extensions/time/Cargo.toml b/extensions/time/Cargo.toml new file mode 100644 index 000000000..6220ddffd --- /dev/null +++ b/extensions/time/Cargo.toml @@ -0,0 +1,23 @@ +[package] +authors.workspace = true +edition.workspace = true +license.workspace = true +name = "limbo_time" +repository.workspace = true +version.workspace = true + +[lib] +crate-type = ["cdylib", "lib"] + +[features] +static = ["limbo_ext/static"] + +[target.'cfg(not(target_family = "wasm"))'.dependencies] +mimalloc = { version = "*", default-features = false } + +[dependencies] +chrono = "0.4.39" +limbo_ext = { path = "../core", features = ["static"] } +strum = "0.26.3" +strum_macros = "0.26.3" +thiserror = "2.0.11" diff --git a/extensions/time/src/lib.rs b/extensions/time/src/lib.rs new file mode 100644 index 000000000..5c4d5a383 --- /dev/null +++ b/extensions/time/src/lib.rs @@ -0,0 +1,1019 @@ +use std::str::FromStr as _; + +use chrono::prelude::*; +use core::cmp::Ordering; +use limbo_ext::ValueType; +use thiserror::Error; + +use limbo_ext::{register_extension, scalar, ResultCode, Value}; + +mod time; + +use time::*; + +register_extension! 
{ + scalars: { + time_now, + time_date, + make_date, + make_timestamp, + time_get, + time_get_year, + time_get_month, + time_get_day, + time_get_hour, + time_get_minute, + time_get_second, + time_get_nano, + time_get_weekday, + time_get_yearday, + time_get_isoyear, + time_get_isoweek, + time_unix, + to_timestamp, + time_milli, + time_micro, + time_nano, + time_to_unix, + time_to_milli, + time_to_micro, + time_to_nano, + time_after, + time_before, + time_compare, + time_equal, + dur_ns, + dur_us, + dur_ms, + dur_s, + dur_m, + dur_h, + time_add, + time_add_date, + time_sub, + time_since, + time_until, + time_trunc, + time_round, + time_fmt_iso, + time_fmt_datetime, + time_fmt_date, + time_fmt_time, + time_parse, + }, +} + +macro_rules! ok_tri { + ($e:expr) => { + match $e { + Some(val) => val, + None => return Value::error(ResultCode::Error), + } + }; + ($e:expr, $msg:expr) => { + match $e { + Some(val) => val, + None => return Value::error_with_message($msg.to_string()), + } + }; +} + +macro_rules! tri { + ($e:expr) => { + match $e { + Ok(val) => val, + Err(err) => return Value::error_with_message(err.to_string()), + } + }; + ($e:expr, $msg:expr) => { + match $e { + Ok(val) => val, + Err(_) => return Value::error_with_message($msg.to_string()), + } + }; +} + +/// Checks to see if e's enum is of type val +macro_rules! 
value_tri { + ($e:expr, $val:pat) => { + match $e { + $val => (), + _ => return Value::error(ResultCode::InvalidArgs), + } + }; + ($e:expr, $val:pat, $msg:expr) => { + match $e { + $val => (), + _ => return Value::error_with_message($msg.to_string()), + } + }; +} + +#[derive(Error, Debug)] +pub enum TimeError { + /// Timezone offset is invalid + #[error("invalid timezone offset")] + InvalidOffset, + #[error("invalid datetime format")] + InvalidFormat, + /// Blob is not size of `TIME_BLOB_SIZE` + #[error("invalid time blob size")] + InvalidSize, + /// Blob time version not matching + #[error("mismatch time blob version")] + MismatchVersion, + #[error("unknown field")] + UnknownField(#[from] ::Err), + #[error("rounding error")] + RoundingError(#[from] chrono::RoundingError), + #[error("time creation error")] + CreationError, +} + +type Result = core::result::Result; + +#[scalar(name = "time_now", alias = "now")] +fn time_now(args: &[Value]) -> Value { + if !args.is_empty() { + return Value::error(ResultCode::InvalidArgs); + } + let t = Time::new(); + + t.into_blob() +} + +/// ```text +/// time_date(year, month, day[, hour, min, sec[, nsec[, offset_sec]]]) +/// ``` +/// +/// Returns the Time corresponding to a given date/time. The time part (hour+minute+second), the nanosecond part, and the timezone offset part are all optional. +/// +/// The `month`, `day`, `hour`, `min`, `sec`, and `nsec` values may be outside their usual ranges and will be normalized during the conversion. For example, October 32 converts to November 1. +/// +/// If `offset_sec` is not 0, the source time is treated as being in a given timezone (with an offset in seconds east of UTC) and converted back to UTC. 
+fn time_date_internal(args: &[Value]) -> Value { + if args.len() != 3 && args.len() != 6 && args.len() != 7 && args.len() != 8 { + return Value::error(ResultCode::InvalidArgs); + } + + for arg in args { + value_tri!( + arg.value_type(), + ValueType::Integer, + "all parameters should be integers" + ); + } + + let year = ok_tri!(args[0].to_integer()); + let month = ok_tri!(args[1].to_integer()); + let day = ok_tri!(args[2].to_integer()); + let mut hour = 0; + let mut minutes = 0; + let mut seconds = 0; + let mut nano_secs = 0; + let mut offset = FixedOffset::east_opt(0).unwrap(); + + if args.len() >= 6 { + hour = ok_tri!(args[3].to_integer()); + minutes = ok_tri!(args[4].to_integer()); + seconds = ok_tri!(args[5].to_integer()); + } + + if args.len() >= 7 { + nano_secs = ok_tri!(args[6].to_integer()); + } + + if args.len() == 8 { + let offset_sec = ok_tri!(args[7].to_integer()) as i32; + // TODO offset is not normalized. Maybe could just increase/decrease the number of seconds + // instead of relying in this offset + offset = ok_tri!(FixedOffset::east_opt(offset_sec)); + } + + let t = Time::time_date( + year as i32, + month as i32, + day, + hour, + minutes, + seconds, + nano_secs, + offset, + ); + + let t = tri!(t); + + t.into_blob() +} + +#[scalar(name = "time_date")] +fn time_date(args: &[Value]) { + time_date_internal(args) +} + +#[scalar(name = "make_date")] +fn make_date(args: &[Value]) -> Value { + if args.len() != 3 { + return Value::error(ResultCode::InvalidArgs); + } + + time_date_internal(args) +} + +#[scalar(name = "make_timestamp")] +fn make_timestamp(args: &[Value]) -> Value { + if args.len() != 6 { + return Value::error(ResultCode::InvalidArgs); + } + + time_date_internal(args) +} + +#[scalar(name = "time_get", alias = "date_part")] +fn time_get(args: &[Value]) -> Value { + if args.len() != 2 { + return Value::error(ResultCode::InvalidArgs); + } + + let blob = ok_tri!(args[0].to_blob(), "1st parameter: should be a time blob"); + + let t = 
tri!(Time::try_from(blob)); + + let field = ok_tri!(args[1].to_text(), "2nd parameter: should be a field name"); + + let field = tri!(TimeField::from_str(field)); + + t.time_get(field) +} + +#[scalar(name = "time_get_year")] +fn time_get_year(args: &[Value]) -> Value { + if args.len() != 1 { + return Value::error(ResultCode::InvalidArgs); + } + + let blob = ok_tri!(args[0].to_blob(), "parameter should be a time blob"); + + let t = tri!(Time::try_from(blob)); + + t.time_get(TimeField::Year) +} + +#[scalar(name = "time_get_month")] +fn time_get_month(args: &[Value]) -> Value { + if args.len() != 1 { + return Value::error(ResultCode::InvalidArgs); + } + + let blob = ok_tri!(args[0].to_blob(), "parameter should be a time blob"); + + let t = tri!(Time::try_from(blob)); + + t.time_get(TimeField::Month) +} + +#[scalar(name = "time_get_day")] +fn time_get_day(args: &[Value]) -> Value { + if args.len() != 1 { + return Value::error(ResultCode::InvalidArgs); + } + + let blob = ok_tri!(args[0].to_blob(), "parameter should be a time blob"); + + let t = tri!(Time::try_from(blob)); + + t.time_get(TimeField::Day) +} + +#[scalar(name = "time_get_hour")] +fn time_get_hour(args: &[Value]) -> Value { + if args.len() != 1 { + return Value::error(ResultCode::InvalidArgs); + } + + let blob = ok_tri!(args[0].to_blob(), "parameter should be a time blob"); + + let t = tri!(Time::try_from(blob)); + + t.time_get(TimeField::Hour) +} + +#[scalar(name = "time_get_minute")] +fn time_get_minute(args: &[Value]) -> Value { + if args.len() != 1 { + return Value::error(ResultCode::InvalidArgs); + } + + let blob = ok_tri!(args[0].to_blob(), "parameter should be a time blob"); + + let t = tri!(Time::try_from(blob)); + + t.time_get(TimeField::Minute) +} + +#[scalar(name = "time_get_second")] +fn time_get_second(args: &[Value]) -> Value { + if args.len() != 1 { + return Value::error(ResultCode::InvalidArgs); + } + + let blob = ok_tri!(args[0].to_blob(), "parameter should be a time blob"); + + let t = 
tri!(Time::try_from(blob)); + + Value::from_integer(t.get_second()) +} + +#[scalar(name = "time_get_nano")] +fn time_get_nano(args: &[Value]) -> Value { + if args.len() != 1 { + return Value::error(ResultCode::InvalidArgs); + } + + let blob = ok_tri!(args[0].to_blob(), "parameter should be a time blob"); + + let t = tri!(Time::try_from(blob)); + + Value::from_integer(t.get_nanosecond()) +} + +#[scalar(name = "time_get_weekday")] +fn time_get_weekday(args: &[Value]) -> Value { + if args.len() != 1 { + return Value::error(ResultCode::InvalidArgs); + } + + let blob = ok_tri!(args[0].to_blob(), "parameter should be a time blob"); + + let t = tri!(Time::try_from(blob)); + + t.time_get(TimeField::WeekDay) +} + +#[scalar(name = "time_get_yearday")] +fn time_get_yearday(args: &[Value]) -> Value { + if args.len() != 1 { + return Value::error(ResultCode::InvalidArgs); + } + + let blob = ok_tri!(args[0].to_blob(), "parameter should be a time blob"); + + let t = tri!(Time::try_from(blob)); + + t.time_get(TimeField::YearDay) +} + +#[scalar(name = "time_get_isoyear")] +fn time_get_isoyear(args: &[Value]) -> Value { + if args.len() != 1 { + return Value::error(ResultCode::InvalidArgs); + } + + let blob = ok_tri!(args[0].to_blob(), "parameter should be a time blob"); + + let t = tri!(Time::try_from(blob)); + + t.time_get(TimeField::IsoYear) +} + +#[scalar(name = "time_get_isoweek")] +fn time_get_isoweek(args: &[Value]) -> Value { + if args.len() != 1 { + return Value::error(ResultCode::InvalidArgs); + } + + let blob = ok_tri!(args[0].to_blob(), "parameter should be a time blob"); + + let t = tri!(Time::try_from(blob)); + + t.time_get(TimeField::IsoWeek) +} + +fn time_unix_internal(args: &[Value]) -> Value { + if args.len() != 1 && args.len() != 2 { + return Value::error(ResultCode::InvalidArgs); + } + + for arg in args { + value_tri!( + arg.value_type(), + ValueType::Integer, + "all parameters should be integers" + ); + } + + let seconds = ok_tri!(args[0].to_integer()); + + let 
mut nano_sec = 0; + + if args.len() == 2 { + nano_sec = ok_tri!(args[1].to_integer()); + } + + let dt = ok_tri!(DateTime::from_timestamp(seconds, nano_sec as u32)); + + let t = Time::from_datetime(dt); + + t.into_blob() +} + +#[scalar(name = "time_unix")] +fn time_unix(args: &[Value]) -> Value { + time_unix_internal(args) +} + +#[scalar(name = "to_timestamp")] +fn to_timestamp(args: &[Value]) -> Value { + if args.len() != 1 { + return Value::error(ResultCode::InvalidArgs); + } + + time_unix_internal(args) +} + +#[scalar(name = "time_milli")] +fn time_milli(args: &[Value]) -> Value { + if args.len() != 1 { + return Value::error(ResultCode::InvalidArgs); + } + + value_tri!( + &args[0].value_type(), + ValueType::Integer, + "parameter should be an integer" + ); + + let millis = ok_tri!(args[0].to_integer()); + + let dt = ok_tri!(DateTime::from_timestamp_millis(millis)); + + let t = Time::from_datetime(dt); + + t.into_blob() +} + +#[scalar(name = "time_micro")] +fn time_micro(args: &[Value]) -> Value { + if args.len() != 1 { + return Value::error(ResultCode::InvalidArgs); + } + + value_tri!( + &args[0].value_type(), + ValueType::Integer, + "parameter should be an integer" + ); + + let micros = ok_tri!(args[0].to_integer()); + + let dt = ok_tri!(DateTime::from_timestamp_micros(micros)); + + let t = Time::from_datetime(dt); + + t.into_blob() +} + +#[scalar(name = "time_nano")] +fn time_nano(args: &[Value]) -> Value { + if args.len() != 1 { + return Value::error(ResultCode::InvalidArgs); + } + + value_tri!( + &args[0].value_type(), + ValueType::Integer, + "parameter should be an integer" + ); + + let nanos = ok_tri!(args[0].to_integer()); + + let dt = DateTime::from_timestamp_nanos(nanos); + + let t = Time::from_datetime(dt); + + t.into_blob() +} + +#[scalar(name = "time_to_unix")] +fn time_to_unix(args: &[Value]) -> Value { + if args.len() != 1 { + return Value::error(ResultCode::InvalidArgs); + } + + let blob = ok_tri!(args[0].to_blob(), "parameter should be a time 
blob"); + + let t = tri!(Time::try_from(blob)); + + Value::from_integer(t.to_unix()) +} + +#[scalar(name = "time_to_milli")] +fn time_to_milli(args: &[Value]) -> Value { + if args.len() != 1 { + return Value::error(ResultCode::InvalidArgs); + } + + let blob = ok_tri!(args[0].to_blob(), "parameter should be a time blob"); + + let t = tri!(Time::try_from(blob)); + + Value::from_integer(t.to_unix_milli()) +} + +#[scalar(name = "time_to_micro")] +fn time_to_micro(args: &[Value]) -> Value { + if args.len() != 1 { + return Value::error(ResultCode::InvalidArgs); + } + + let blob = ok_tri!(args[0].to_blob(), "parameter should be a time blob"); + + let t = tri!(Time::try_from(blob)); + + Value::from_integer(t.to_unix_micro()) +} + +#[scalar(name = "time_to_nano")] +fn time_to_nano(args: &[Value]) -> Value { + if args.len() != 1 { + return Value::error(ResultCode::InvalidArgs); + } + + let blob = ok_tri!(args[0].to_blob(), "parameter should be a time blob"); + + let t = tri!(Time::try_from(blob)); + + Value::from_integer(ok_tri!(t.to_unix_nano())) +} + +// Comparisons + +#[scalar(name = "time_after")] +fn time_after(args: &[Value]) -> Value { + if args.len() != 2 { + return Value::error(ResultCode::InvalidArgs); + } + + let blob = ok_tri!(args[0].to_blob(), "1st parameter: should be a time blob"); + + let t = tri!(Time::try_from(blob)); + + let blob = ok_tri!(args[1].to_blob(), "2nd parameter: should be a time blob"); + + let u = tri!(Time::try_from(blob)); + + Value::from_integer((t > u).into()) +} + +#[scalar(name = "time_before")] +fn time_before(args: &[Value]) -> Value { + if args.len() != 2 { + return Value::error(ResultCode::InvalidArgs); + } + + let blob = ok_tri!(args[0].to_blob(), "1st parameter: should be a time blob"); + + let t = tri!(Time::try_from(blob)); + + let blob = ok_tri!(args[1].to_blob(), "2nd parameter: should be a time blob"); + + let u = tri!(Time::try_from(blob)); + + Value::from_integer((t < u).into()) +} + +#[scalar(name = "time_compare")] +fn 
time_compare(args: &[Value]) -> Value { + if args.len() != 2 { + return Value::error(ResultCode::InvalidArgs); + } + + let blob = ok_tri!(args[0].to_blob(), "1st parameter: should be a time blob"); + + let t = tri!(Time::try_from(blob)); + + let blob = ok_tri!(args[1].to_blob(), "2nd parameter: should be a time blob"); + + let u = tri!(Time::try_from(blob)); + + let cmp = match ok_tri!(t.partial_cmp(&u)) { + Ordering::Less => -1, + Ordering::Greater => 1, + Ordering::Equal => 0, + }; + + Value::from_integer(cmp) +} + +#[scalar(name = "time_equal")] +fn time_equal(args: &[Value]) -> Value { + if args.len() != 2 { + return Value::error(ResultCode::InvalidArgs); + } + + let blob = ok_tri!(args[0].to_blob(), "1st parameter: should be a time blob"); + + let t = tri!(Time::try_from(blob)); + + let blob = ok_tri!(args[1].to_blob(), "2nd parameter: should be a time blob"); + + let u = tri!(Time::try_from(blob)); + + Value::from_integer(t.eq(&u).into()) +} + +// Duration Constants + +/// 1 nanosecond +#[scalar(name = "dur_ns")] +fn dur_ns(args: &[Value]) -> Value { + if !args.is_empty() { + return Value::error(ResultCode::InvalidArgs); + } + + Value::from_integer(chrono::Duration::nanoseconds(1).num_nanoseconds().unwrap()) +} + +/// 1 microsecond +#[scalar(name = "dur_us")] +fn dur_us(args: &[Value]) -> Value { + if !args.is_empty() { + return Value::error(ResultCode::InvalidArgs); + } + + Value::from_integer(chrono::Duration::microseconds(1).num_nanoseconds().unwrap()) +} + +/// 1 millisecond +#[scalar(name = "dur_ms")] +fn dur_ms(args: &[Value]) -> Value { + if !args.is_empty() { + return Value::error(ResultCode::InvalidArgs); + } + + Value::from_integer(chrono::Duration::milliseconds(1).num_nanoseconds().unwrap()) +} + +/// 1 second +#[scalar(name = "dur_s")] +fn dur_s(args: &[Value]) -> Value { + if !args.is_empty() { + return Value::error(ResultCode::InvalidArgs); + } + + Value::from_integer(chrono::Duration::seconds(1).num_nanoseconds().unwrap()) +} + +/// 1 minute 
+#[scalar(name = "dur_m")] +fn dur_m(args: &[Value]) -> Value { + if !args.is_empty() { + return Value::error(ResultCode::InvalidArgs); + } + + Value::from_integer(chrono::Duration::minutes(1).num_nanoseconds().unwrap()) +} + +/// 1 hour +#[scalar(name = "dur_h")] +fn dur_h(args: &[Value]) -> Value { + if !args.is_empty() { + return Value::error(ResultCode::InvalidArgs); + } + + Value::from_integer(chrono::Duration::hours(1).num_nanoseconds().unwrap()) +} + +// Time Arithmetic + +/// Do not use `time_add` to add days, months or years. Use `time_add_date` instead. +#[scalar(name = "time_add", alias = "date_add")] +fn time_add(args: &[Value]) -> Value { + if args.len() != 2 { + return Value::error(ResultCode::InvalidArgs); + } + + let blob = ok_tri!(args[0].to_blob(), "1st parameter: should be a time blob"); + let t = tri!(Time::try_from(blob)); + + value_tri!( + args[1].value_type(), + ValueType::Integer, + "2nd parameter: should be an integer" + ); + + let d = ok_tri!(args[1].to_integer()); + + let d = Duration::from(d); + + t.add_duration(d).into_blob() +} + +#[scalar(name = "time_add_date")] +fn time_add_date(args: &[Value]) -> Value { + if args.len() != 2 && args.len() != 3 && args.len() != 4 { + return Value::error(ResultCode::InvalidArgs); + } + + let blob = ok_tri!(args[0].to_blob(), "1st parameter: should be a time blob"); + let t = tri!(Time::try_from(blob)); + + value_tri!( + args[1].value_type(), + ValueType::Integer, + "2nd parameter: should be an integer" + ); + + let years = ok_tri!(args[1].to_integer()); + let mut months = 0; + let mut days = 0; + + if args.len() >= 3 { + value_tri!( + args[2].value_type(), + ValueType::Integer, + "3rd parameter: should be an integer" + ); + + months = ok_tri!(args[2].to_integer()); + } + + if args.len() == 4 { + value_tri!( + args[3].value_type(), + ValueType::Integer, + "4th parameter: should be an integer" + ); + + days = ok_tri!(args[3].to_integer()); + } + + let t: Time = tri!(t.time_add_date(years as i32, months 
as i32, days)); + + t.into_blob() +} + +/// Returns the duration between two time values t and u (in nanoseconds). +/// If the result exceeds the maximum (or minimum) value that can be stored in a Duration, +/// the maximum (or minimum) duration will be returned. +fn time_sub_internal(t: Time, u: Time) -> Value { + let cmp = ok_tri!(t.partial_cmp(&u)); + + let diff = t - u; + + let nano_secs = match diff.num_nanoseconds() { + Some(nano) => nano, + None => match cmp { + Ordering::Equal => ok_tri!(diff.num_nanoseconds()), + Ordering::Less => i64::MIN, + Ordering::Greater => i64::MAX, + }, + }; + + Value::from_integer(nano_secs) +} + +#[scalar(name = "time_sub", alias = "age")] +fn time_sub(args: &[Value]) -> Value { + if args.len() != 2 { + return Value::error(ResultCode::InvalidArgs); + } + + let blob = ok_tri!(args[0].to_blob(), "1st parameter: should be a time blob"); + let t = tri!(Time::try_from(blob)); + + let blob = ok_tri!(args[1].to_blob(), "2nd parameter: should be a time blob"); + let u = tri!(Time::try_from(blob)); + + time_sub_internal(t, u) +} + +#[scalar(name = "time_since")] +fn time_since(args: &[Value]) -> Value { + if args.len() != 1 { + return Value::error(ResultCode::InvalidArgs); + } + + let now = Time::new(); + + let blob = ok_tri!(args[0].to_blob(), "parameter should be a time blob"); + let t = tri!(Time::try_from(blob)); + + time_sub_internal(now, t) +} + +#[scalar(name = "time_until")] +fn time_until(args: &[Value]) -> Value { + if args.len() != 1 { + return Value::error(ResultCode::InvalidArgs); + } + + let now = Time::new(); + + let blob = ok_tri!(args[0].to_blob(), "parameter should be a time blob"); + let t = tri!(Time::try_from(blob)); + + time_sub_internal(t, now) +} + +// Rouding + +#[scalar(name = "time_trunc", alias = "date_trunc")] +fn time_trunc(args: &[Value]) -> Value { + if args.len() != 2 { + return Value::error(ResultCode::InvalidArgs); + } + + let blob = ok_tri!(args[0].to_blob(), "1st parameter: should be a time blob"); + + 
let t = tri!(Time::try_from(blob)); + + match args[1].value_type() { + ValueType::Text => { + let field = ok_tri!(args[1].to_text()); + + let field = tri!(TimeRoundField::from_str(field)); + + tri!(t.trunc_field(field)).into_blob() + } + ValueType::Integer => { + let duration = ok_tri!(args[1].to_integer()); + let duration = Duration::from(duration); + + tri!(t.trunc_duration(duration)).into_blob() + } + _ => Value::error_with_message("2nd parameter: should be a field name".to_string()), + } +} + +#[scalar(name = "time_round")] +fn time_round(args: &[Value]) -> Value { + if args.len() != 2 { + return Value::error(ResultCode::InvalidArgs); + } + + let blob = ok_tri!(args[0].to_blob(), "1st parameter: should be a time blob"); + + let t = tri!(Time::try_from(blob)); + + value_tri!( + args[1].value_type(), + ValueType::Integer, + "2nd parameter: should be an integer" + ); + + let duration = ok_tri!(args[1].to_integer()); + let duration = Duration::from(duration); + + tri!(t.round_duration(duration)).into_blob() +} + +// Formatting + +#[scalar(name = "time_fmt_iso")] +fn time_fmt_iso(args: &[Value]) -> Value { + if args.len() != 1 && args.len() != 2 { + return Value::error(ResultCode::InvalidArgs); + } + let blob = ok_tri!(args[0].to_blob(), "1st parameter: should be a time blob"); + + let t = tri!(Time::try_from(blob)); + + let offset_sec = { + if args.len() == 2 { + value_tri!( + args[1].value_type(), + ValueType::Integer, + "2nd parameter: should be an integer" + ); + ok_tri!(args[1].to_integer()) as i32 + } else { + 0 + } + }; + + let fmt_str = tri!(t.fmt_iso(offset_sec)); + + Value::from_text(fmt_str) +} + +#[scalar(name = "time_fmt_datetime")] +fn time_fmt_datetime(args: &[Value]) -> Value { + if args.len() != 1 && args.len() != 2 { + return Value::error(ResultCode::InvalidArgs); + } + let blob = ok_tri!(args[0].to_blob(), "1st parameter: should be a time blob"); + + let t = tri!(Time::try_from(blob)); + + let offset_sec = { + if args.len() == 2 { + value_tri!( + 
args[1].value_type(),
                ValueType::Integer,
                "2nd parameter: should be an integer"
            );
            ok_tri!(args[1].to_integer()) as i32
        } else {
            0
        }
    };

    let fmt_str = tri!(t.fmt_datetime(offset_sec));

    Value::from_text(fmt_str)
}

/// Formats a time blob as "YYYY-MM-DD", optionally rendered at a UTC offset
/// given in seconds.
#[scalar(name = "time_fmt_date")]
fn time_fmt_date(args: &[Value]) -> Value {
    if args.len() != 1 && args.len() != 2 {
        return Value::error(ResultCode::InvalidArgs);
    }

    let blob = ok_tri!(args[0].to_blob(), "1st parameter: should be a time blob");
    let t = tri!(Time::try_from(blob));

    let offset_sec = if args.len() == 2 {
        value_tri!(
            args[1].value_type(),
            ValueType::Integer,
            "2nd parameter: should be an integer"
        );
        ok_tri!(args[1].to_integer()) as i32
    } else {
        0
    };

    Value::from_text(tri!(t.fmt_date(offset_sec)))
}

/// Formats a time blob as "HH:MM:SS", optionally rendered at a UTC offset
/// given in seconds.
#[scalar(name = "time_fmt_time")]
fn time_fmt_time(args: &[Value]) -> Value {
    if args.len() != 1 && args.len() != 2 {
        return Value::error(ResultCode::InvalidArgs);
    }

    let blob = ok_tri!(args[0].to_blob(), "1st parameter: should be a time blob");
    let t = tri!(Time::try_from(blob));

    let offset_sec = if args.len() == 2 {
        value_tri!(
            args[1].value_type(),
            ValueType::Integer,
            "2nd parameter: should be an integer"
        );
        ok_tri!(args[1].to_integer()) as i32
    } else {
        0
    };

    Value::from_text(tri!(t.fmt_time(offset_sec)))
}

/// Parses a datetime string into a time blob. Accepted layouts, tried in
/// order: RFC 3339, "YYYY-MM-DD HH:MM:SS", "YYYY-MM-DD", "HH:MM:SS".
#[scalar(name = "time_parse")]
fn time_parse(args: &[Value]) -> Value {
    if args.len() != 1 {
        return Value::error(ResultCode::InvalidArgs);
    }

    let dt_str = ok_tri!(args[0].to_text());

    if let Ok(dt) = chrono::DateTime::parse_from_rfc3339(dt_str) {
        return Time::from_datetime(dt.to_utc()).into_blob();
    }

    if let Ok(dt) = chrono::NaiveDateTime::parse_from_str(dt_str, "%Y-%m-%d %H:%M:%S") {
        // Zeroing the (already whole-second) nanoseconds cannot fail.
        let dt = dt.with_nanosecond(0).unwrap();
        return Time::from_datetime(dt.and_utc()).into_blob();
    }

    if let Ok(date) = chrono::NaiveDate::parse_from_str(dt_str, "%Y-%m-%d") {
        // Midnight always exists, so both unwraps are safe.
        let dt = date
            .and_hms_opt(0, 0, 0)
            .unwrap()
            .with_nanosecond(0)
            .unwrap();
        return Time::from_datetime(dt.and_utc()).into_blob();
    }

    let time = tri!(
        chrono::NaiveTime::parse_from_str(dt_str, "%H:%M:%S"),
        "error parsing datetime string"
    );
    let dt = NaiveDateTime::new(NaiveDate::from_ymd_opt(1, 1, 1).unwrap(), time)
        .with_nanosecond(0)
        .unwrap();

    Time::from_datetime(dt.and_utc()).into_blob()
}
diff --git a/extensions/time/src/time.rs b/extensions/time/src/time.rs
new file mode 100644
index 000000000..238d599d5
--- /dev/null
+++ b/extensions/time/src/time.rs
@@ -0,0 +1,560 @@
use std::ops::{Deref, Sub};

use chrono::{self, DateTime, Timelike, Utc};
use chrono::{prelude::*, DurationRound};

use limbo_ext::Value;

use crate::{Result, TimeError};

/// Days from 0001-01-01 to the Unix epoch (1970-01-01).
const DAYS_BEFORE_EPOCH: i64 = 719162;
/// Serialized size of a time blob, in bytes (matches `Time::into_blob`).
const TIME_BLOB_SIZE: usize = 13;
/// Format version tag for serialized time blobs.
const VERSION: u8 = 1;

/// A point in time, stored internally as a UTC datetime.
#[derive(Debug, PartialEq, PartialOrd, Eq)]
pub struct Time {
    inner: DateTime<Utc>,
}

/// A span of time, backed by `chrono::Duration`.
#[derive(Debug, PartialEq, Eq, PartialOrd)]
pub struct Duration {
    inner: chrono::Duration,
}

/// Fields understood by `time_get`; string forms come from strum.
#[derive(strum_macros::Display, strum_macros::EnumString)]
pub enum TimeField {
    #[strum(to_string = "millennium")]
    Millennium,
    #[strum(to_string = "century")]
    Century,
    #[strum(to_string = "decade")]
    Decade,
    #[strum(to_string = "year")]
    Year,
    #[strum(to_string = "quarter")]
    Quarter,
    #[strum(to_string = "month")]
    Month,
    #[strum(to_string = "day")]
    Day,
    #[strum(to_string = "hour")]
    Hour,
    #[strum(to_string = "minute")]
    Minute,
    #[strum(to_string = "second")]
    Second,
    #[strum(to_string = "millisecond")]
    MilliSecond,
    #[strum(to_string = "milli")]
    Milli,
    #[strum(to_string = "microsecond")]
    MicroSecond,
    #[strum(to_string = "micro")]
    Micro,
    #[strum(to_string = "nanosecond")]
    NanoSecond,
    #[strum(to_string = "nano")]
    Nano,
    #[strum(to_string = "isoyear")]
    IsoYear,
+ #[strum(to_string = "isoweek")] + IsoWeek, + #[strum(to_string = "isodow")] + IsoDow, + #[strum(to_string = "yearday")] + YearDay, + #[strum(to_string = "weekday")] + WeekDay, + #[strum(to_string = "epoch")] + Epoch, +} + +#[derive(strum_macros::Display, strum_macros::EnumString)] +pub enum TimeRoundField { + #[strum(to_string = "millennium")] + Millennium, + #[strum(to_string = "century")] + Century, + #[strum(to_string = "decade")] + Decade, + #[strum(to_string = "year")] + Year, + #[strum(to_string = "quarter")] + Quarter, + #[strum(to_string = "month")] + Month, + #[strum(to_string = "week")] + Week, + #[strum(to_string = "day")] + Day, + #[strum(to_string = "hour")] + Hour, + #[strum(to_string = "minute")] + Minute, + #[strum(to_string = "second")] + Second, + #[strum(to_string = "millisecond")] + MilliSecond, + #[strum(to_string = "milli")] + Milli, + #[strum(to_string = "microsecond")] + MicroSecond, + #[strum(to_string = "micro")] + Micro, +} + +impl Time { + /// Returns a new instance of Time with tracking UTC::now + pub fn new() -> Self { + Self { inner: Utc::now() } + } + + pub fn into_blob(self) -> Value { + let blob: [u8; 13] = self.into(); + Value::from_blob(blob.to_vec()) + } + + pub fn fmt_iso(&self, offset_sec: i32) -> Result { + if offset_sec == 0 { + if self.inner.nanosecond() == 0 { + return Ok(self.inner.format("%FT%TZ").to_string()); + } else { + return Ok(self.inner.format("%FT%T%.9fZ").to_string()); + } + } + // I do not see how this can error + let offset = &FixedOffset::east_opt(offset_sec).ok_or(TimeError::InvalidFormat)?; + + let timezone_date = self.inner.with_timezone(offset); + + if timezone_date.nanosecond() == 0 { + Ok(timezone_date.format("%FT%T%:z").to_string()) + } else { + Ok(timezone_date.format("%FT%T%.9f%:z").to_string()) + } + } + + pub fn fmt_datetime(&self, offset_sec: i32) -> Result { + let fmt = "%F %T"; + + if offset_sec == 0 { + return Ok(self.inner.format(fmt).to_string()); + } + // I do not see how this can error + 
let offset = &FixedOffset::east_opt(offset_sec).ok_or(TimeError::InvalidFormat)?; + + let timezone_date = self.inner.with_timezone(offset); + + Ok(timezone_date.format(fmt).to_string()) + } + + pub fn fmt_date(&self, offset_sec: i32) -> Result { + let fmt = "%F"; + + if offset_sec == 0 { + return Ok(self.inner.format(fmt).to_string()); + } + // I do not see how this can error + let offset = &FixedOffset::east_opt(offset_sec).ok_or(TimeError::InvalidFormat)?; + + let timezone_date = self.inner.with_timezone(offset); + + Ok(timezone_date.format(fmt).to_string()) + } + + pub fn fmt_time(&self, offset_sec: i32) -> Result { + let fmt = "%T"; + + if offset_sec == 0 { + return Ok(self.inner.format(fmt).to_string()); + } + // I do not see how this can error + let offset = &FixedOffset::east_opt(offset_sec).ok_or(TimeError::InvalidFormat)?; + + let timezone_date = self.inner.with_timezone(offset); + + Ok(timezone_date.format(fmt).to_string()) + } + + /// Adjust the datetime to the offset + pub fn from_datetime(dt: DateTime) -> Self { + Self { inner: dt } + } + + // + #[allow(clippy::too_many_arguments)] + pub fn time_date( + year: i32, + month: i32, + day: i64, + hour: i64, + minutes: i64, + seconds: i64, + nano_secs: i64, + offset: FixedOffset, + ) -> Result { + let mut dt: NaiveDateTime = NaiveDate::from_ymd_opt(1, 1, 1) + .unwrap() + .and_hms_opt(0, 0, 0) + .unwrap(); + + match year.cmp(&0) { + std::cmp::Ordering::Greater => { + dt = dt + .checked_add_months(chrono::Months::new((year - 1).unsigned_abs() * 12)) + .ok_or(TimeError::CreationError)? + } + std::cmp::Ordering::Less => { + dt = dt + .checked_sub_months(chrono::Months::new((year - 1).unsigned_abs() * 12)) + .ok_or(TimeError::CreationError)? + } + std::cmp::Ordering::Equal => (), + }; + + match month.cmp(&0) { + std::cmp::Ordering::Greater => { + dt = dt + .checked_add_months(chrono::Months::new((month - 1).unsigned_abs())) + .ok_or(TimeError::CreationError)? 
+ } + std::cmp::Ordering::Less => { + dt = dt + .checked_sub_months(chrono::Months::new((month - 1).unsigned_abs())) + .ok_or(TimeError::CreationError)? + } + std::cmp::Ordering::Equal => (), + }; + + dt += chrono::Duration::try_days(day - 1).ok_or(TimeError::CreationError)?; + + dt += chrono::Duration::try_hours(hour).ok_or(TimeError::CreationError)?; + dt += chrono::Duration::try_minutes(minutes).ok_or(TimeError::CreationError)?; + dt += chrono::Duration::try_seconds(seconds).ok_or(TimeError::CreationError)?; + + dt += chrono::Duration::nanoseconds(nano_secs); + + dt = dt + .and_local_timezone(offset) + .single() + .ok_or(TimeError::CreationError)? + .naive_utc(); + + Ok(dt.into()) + } + + pub fn time_add_date(self, years: i32, months: i32, days: i64) -> Result { + let mut dt: NaiveDateTime = self.into(); + + match years.cmp(&0) { + std::cmp::Ordering::Greater => { + dt = dt + .checked_add_months(chrono::Months::new(years.unsigned_abs() * 12)) + .ok_or(TimeError::CreationError)?; + } + std::cmp::Ordering::Less => { + dt = dt + .checked_sub_months(chrono::Months::new(years.unsigned_abs() * 12)) + .ok_or(TimeError::CreationError)?; + } + std::cmp::Ordering::Equal => (), + }; + + match months.cmp(&0) { + std::cmp::Ordering::Greater => { + dt = dt + .checked_add_months(chrono::Months::new(months.unsigned_abs())) + .ok_or(TimeError::CreationError)? + } + std::cmp::Ordering::Less => { + dt = dt + .checked_sub_months(chrono::Months::new(months.unsigned_abs())) + .ok_or(TimeError::CreationError)? 
+ } + std::cmp::Ordering::Equal => (), + }; + + dt += chrono::Duration::try_days(days).ok_or(TimeError::CreationError)?; + + Ok(dt.into()) + } + + pub fn get_second(&self) -> i64 { + self.inner.second() as i64 + } + + pub fn get_nanosecond(&self) -> i64 { + self.inner.timestamp_subsec_nanos() as i64 + } + + pub fn to_unix(&self) -> i64 { + self.inner.timestamp() + } + + pub fn to_unix_milli(&self) -> i64 { + self.inner.timestamp_millis() + } + + pub fn to_unix_micro(&self) -> i64 { + self.inner.timestamp_micros() + } + + pub fn to_unix_nano(&self) -> Option { + self.inner.timestamp_nanos_opt() + } + + pub fn add_duration(&self, d: Duration) -> Self { + Self { + inner: self.inner + d.inner, + } + } + + pub fn sub_duration(&self, d: Duration) -> Self { + Self { + inner: self.inner - d.inner, + } + } + + pub fn trunc_duration(&self, d: Duration) -> Result { + Ok(Self { + inner: self.inner.duration_trunc(d.inner)?, + }) + } + + pub fn trunc_field(&self, field: TimeRoundField) -> Result { + use TimeRoundField::*; + + let year: i32; + let mut month: i32 = 1; + let mut week: i32 = 0; + let mut day: i64 = 1; + let mut hour: i64 = 0; + let mut minutes: i64 = 0; + let mut seconds: i64 = 0; + let mut nano_secs: i64 = 0; + let offset = FixedOffset::east_opt(0).unwrap(); // UTC + + match field { + Millennium => { + let millennium = (self.inner.year() / 1000) * 1000; + year = millennium; + } + Century => { + let century = (self.inner.year() / 100) * 100; + year = century; + } + Decade => { + let decade = (self.inner.year() / 10) * 10; + year = decade; + } + Year => { + year = self.inner.year(); + } + Quarter => { + let quarter = ((self.inner.month() - 1) / 3) as i32; + year = self.inner.year(); + month = (quarter * 3) + 1; + } + Month => { + year = self.inner.year(); + month = self.inner.month() as i32; + } + Week => { + let isoweek = self.inner.iso_week(); + year = isoweek.year(); + week = isoweek.week() as i32; + } + Day => { + year = self.inner.year(); + month = 
self.inner.month() as i32; + day = self.inner.day() as i64; + } + Hour => { + year = self.inner.year(); + month = self.inner.month() as i32; + day = self.inner.day() as i64; + hour = self.inner.hour() as i64; + } + Minute => { + year = self.inner.year(); + month = self.inner.month() as i32; + day = self.inner.day() as i64; + hour = self.inner.hour() as i64; + minutes = self.inner.minute() as i64; + } + Second => { + year = self.inner.year(); + month = self.inner.month() as i32; + day = self.inner.day() as i64; + hour = self.inner.hour() as i64; + minutes = self.inner.minute() as i64; + seconds = self.inner.second() as i64; + } + MilliSecond | Milli => { + year = self.inner.year(); + month = self.inner.month() as i32; + day = self.inner.day() as i64; + hour = self.inner.hour() as i64; + minutes = self.inner.minute() as i64; + seconds = self.inner.second() as i64; + nano_secs = (self.inner.nanosecond() / 1_000_000 * 1_000_000) as i64; + } + MicroSecond | Micro => { + year = self.inner.year(); + month = self.inner.month() as i32; + day = self.inner.day() as i64; + hour = self.inner.hour() as i64; + minutes = self.inner.minute() as i64; + seconds = self.inner.second() as i64; + nano_secs = (self.inner.nanosecond() / 1_000 * 1_000) as i64; + } + }; + + let mut ret = Self::time_date(year, month, day, hour, minutes, seconds, nano_secs, offset)?; + + // Means we have to adjust for the week + if week != 0 { + ret = ret.time_add_date(0, 0, ((week - 1) * 7) as i64)?; + } + + Ok(ret) + } + + pub fn round_duration(&self, d: Duration) -> Result { + Ok(Self { + inner: self.inner.duration_round(d.inner)?, + }) + } + + pub fn time_get(&self, field: TimeField) -> Value { + use TimeField::*; + + match field { + Millennium => Value::from_integer((self.inner.year() / 1000) as i64), + Century => Value::from_integer((self.inner.year() / 100) as i64), + Decade => Value::from_integer((self.inner.year() / 10) as i64), + Year => Value::from_integer(self.inner.year() as i64), + Quarter => 
Value::from_integer(self.inner.month().div_ceil(3) as i64), + Month => Value::from_integer(self.inner.month() as i64), + Day => Value::from_integer(self.inner.day() as i64), + Hour => Value::from_integer(self.inner.hour() as i64), + Minute => Value::from_integer(self.inner.minute() as i64), + Second => Value::from_float( + self.inner.second() as f64 + (self.inner.nanosecond() as f64) / (1_000_000_000_f64), + ), + MilliSecond | Milli => { + Value::from_integer((self.inner.nanosecond() / 1_000_000 % 1_000) as i64) + } + MicroSecond | Micro => { + Value::from_integer((self.inner.nanosecond() / 1_000 % 1_000_000) as i64) + } + NanoSecond | Nano => { + Value::from_integer((self.inner.nanosecond() % 1_000_000_000) as i64) + } + IsoYear => Value::from_integer(self.inner.iso_week().year() as i64), + IsoWeek => Value::from_integer(self.inner.iso_week().week() as i64), + IsoDow => Value::from_integer(self.inner.weekday().days_since(Weekday::Sun) as i64), + YearDay => Value::from_integer(self.inner.ordinal() as i64), + WeekDay => Value::from_integer(self.inner.weekday().num_days_from_sunday() as i64), + Epoch => Value::from_float( + self.inner.timestamp() as f64 + self.inner.nanosecond() as f64 / 1_000_000_000_f64, + ), + } + } +} + +impl From