diff --git a/core/lib.rs b/core/lib.rs index 85321fbe0..c5fed4399 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -1865,6 +1865,11 @@ impl Connection { .wal_checkpoint(disabled, CheckpointMode::Truncate)?; self.pager.borrow_mut().db_file.copy_to(&*io, file) } + + /// Creates a HashSet of modules that have been loaded + pub fn get_syms_vtab_mods(&self) -> std::collections::HashSet { + self.syms.borrow().vtab_modules.keys().cloned().collect() + } } pub struct Statement { diff --git a/core/pragma.rs b/core/pragma.rs index 19e6862e4..f4b4b3f44 100644 --- a/core/pragma.rs +++ b/core/pragma.rs @@ -58,6 +58,10 @@ pub fn pragma_for(pragma: &PragmaName) -> Pragma { LegacyFileFormat => { unreachable!("pragma_for() called with LegacyFileFormat, which is unsupported") } + ModuleList => Pragma::new( + PragmaFlags::NeedSchema | PragmaFlags::Result0 | PragmaFlags::SchemaReq, + &["module_list"], + ), PageCount => Pragma::new( PragmaFlags::NeedSchema | PragmaFlags::Result0 | PragmaFlags::SchemaReq, &["page_count"], diff --git a/core/translate/pragma.rs b/core/translate/pragma.rs index 63ceb18b8..3f6ffa2f3 100644 --- a/core/translate/pragma.rs +++ b/core/translate/pragma.rs @@ -146,6 +146,7 @@ fn update_pragma( connection, program, ), + PragmaName::ModuleList => Ok((program, TransactionMode::None)), PragmaName::PageCount => query_pragma( PragmaName::PageCount, schema, @@ -407,6 +408,16 @@ fn query_pragma( program.emit_result_row(register, 3); Ok((program, TransactionMode::None)) } + PragmaName::ModuleList => { + let modules = connection.get_syms_vtab_mods(); + for module in modules { + program.emit_string8(module.to_string(), register); + program.emit_result_row(register, 1); + } + + program.add_pragma_result_column(pragma.to_string()); + Ok((program, TransactionMode::None)) + } PragmaName::PageCount => { program.emit_insn(Insn::PageCount { db: 0, diff --git a/testing/cli_tests/extensions.py b/testing/cli_tests/extensions.py index 7116854dc..ebb856130 100755 --- a/testing/cli_tests/extensions.py +++ b/testing/cli_tests/extensions.py @@ -701,7 +701,7 @@ def test_create_virtual_table(): ext_path = "target/debug/libturso_ext_tests" limbo = TestTursoShell() - limbo.execute_dot(f".load {ext_path}") + test_module_list(limbo, ext_path, "kv_store") limbo.run_debug("CREATE VIRTUAL TABLE t1 USING kv_store;") limbo.run_test_fn( @@ -741,8 +741,7 @@ def test_csv(): # open new empty connection explicitly to test whether we can load an extension # with brand new connection/uninitialized database. limbo = TestTursoShell(init_commands="") - ext_path = "./target/debug/liblimbo_csv" - limbo.execute_dot(f".load {ext_path}") + test_module_list(limbo, "target/debug/liblimbo_csv", "csv") limbo.run_test_fn( "CREATE VIRTUAL TABLE temp.csv USING csv(filename=./testing/test_files/test.csv);", @@ -828,6 +827,7 @@ def cleanup(): def test_tablestats(): ext_path = "target/debug/libturso_ext_tests" limbo = TestTursoShell(use_testing_db=True) + test_module_list(limbo, ext_path=ext_path, module_name="tablestats") limbo.execute_dot("CREATE TABLE people(id INTEGER PRIMARY KEY, name TEXT);") limbo.execute_dot("INSERT INTO people(name) VALUES ('Ada'), ('Grace'), ('Linus');") @@ -845,8 +845,6 @@ def test_tablestats(): lambda res: res == "1", "one logs rowverify logs count", ) - # load extension - limbo.execute_dot(f".load {ext_path}") limbo.execute_dot("CREATE VIRTUAL TABLE stats USING tablestats;") def _split(res): @@ -1072,6 +1070,25 @@ def _test_hidden_columns(exec_name, ext_path): limbo.quit() +def test_module_list(turso_shell, ext_path, module_name): + """loads the extension at the provided path and asserts that 'PRAGMA module_list;' displays 'module_name'""" + console.info(f"Running test_module_list for {ext_path}") + + turso_shell.run_test_fn( + "PRAGMA module_list;", + lambda res: "generate_series" in res and module_name not in res, + "lists built in modules but doesn't contain the module name yet", + ) + + turso_shell.run_test_fn("PRAGMA module_list;", lambda res: module_name not in res, "does not include module list") + turso_shell.execute_dot(f".load {ext_path}") + turso_shell.run_test_fn( + "PRAGMA module_list;", + lambda res: module_name in res, + f"includes {module_name} after loading extension", + ) + + def main(): try: test_regexp() diff --git a/testing/pragma.test b/testing/pragma.test index fb1aa7a31..bbe75636d 100755 --- a/testing/pragma.test +++ b/testing/pragma.test @@ -325,3 +325,7 @@ do_execsql_test_in_memory_any_error pragma-max-page-count-enforcement-error { PRAGMA max_page_count = 1; CREATE TABLE test (id INTEGER) } + +do_execsql_test_regex pragma-module-list-nonempty { + SELECT * FROM pragma_module_list; +} {\ngenerate_series\n|^generate_series\n|\ngenerate_series$|^generate_series$} diff --git a/tests/integration/mod.rs b/tests/integration/mod.rs index 9d68aff5d..0f45007c6 100644 --- a/tests/integration/mod.rs +++ b/tests/integration/mod.rs @@ -2,5 +2,6 @@ mod common; mod functions; mod fuzz; mod fuzz_transaction; +mod pragma; mod query_processing; mod wal; diff --git a/tests/integration/pragma.rs b/tests/integration/pragma.rs new file mode 100644 index 000000000..14559d11e --- /dev/null +++ b/tests/integration/pragma.rs @@ -0,0 +1,59 @@ +use crate::common::TempDatabase; +use turso_core::{StepResult, Value}; + +#[test] +fn test_pragma_module_list_returns_list() { + let db = TempDatabase::new_empty(false); + let conn = db.connect_limbo(); + + let mut module_list = conn.query("PRAGMA module_list;").unwrap(); + + let mut counter = 0; + + if let Some(ref mut rows) = module_list { + while let StepResult::Row = rows.step().unwrap() { + counter += 1; + } + } + + assert!(counter > 0) +} + +#[test] +fn test_pragma_module_list_generate_series() { + let db = TempDatabase::new_empty(false); + let conn = db.connect_limbo(); + + let mut rows = conn + .query("SELECT * FROM generate_series(1, 3);") + .expect("generate_series module not available") + .expect("query did not return rows"); + + let mut values = vec![]; + while let StepResult::Row = rows.step().unwrap() { + let row = rows.row().unwrap(); + values.push(row.get_value(0).clone()); + } + + assert_eq!( + values, + vec![Value::Integer(1), Value::Integer(2), Value::Integer(3),] + ); + + let mut module_list = conn.query("PRAGMA module_list;").unwrap(); + let mut found = false; + + if let Some(ref mut rows) = module_list { + while let StepResult::Row = rows.step().unwrap() { + let row = rows.row().unwrap(); + if let Value::Text(name) = row.get_value(0) { + if name.as_str() == "generate_series" { + found = true; + break; + } + } + } + } + + assert!(found, "generate_series should appear in module_list"); +} diff --git a/vendored/sqlite3-parser/src/parser/ast/mod.rs b/vendored/sqlite3-parser/src/parser/ast/mod.rs index 1dfd07734..a56022e18 100644 --- a/vendored/sqlite3-parser/src/parser/ast/mod.rs +++ b/vendored/sqlite3-parser/src/parser/ast/mod.rs @@ -1814,6 +1814,9 @@ pub enum PragmaName { LegacyFileFormat, /// Set or get the maximum number of pages in the database file. MaxPageCount, + /// `module_list` praagma + /// `module_list` lists modules used by virtual tables. + ModuleList, /// Return the total number of pages in the database file. PageCount, /// Return the page size of the database in bytes.