diff --git a/bindings/javascript/src/lib.rs b/bindings/javascript/src/lib.rs index 1616d6b87..5363ab9ff 100644 --- a/bindings/javascript/src/lib.rs +++ b/bindings/javascript/src/lib.rs @@ -65,7 +65,7 @@ pub struct Database { pub open: bool, #[napi(writable = false)] pub name: String, - _db: Arc, + db: Option>, conn: Arc, _io: Arc, } @@ -108,7 +108,7 @@ impl Database { Ok(Self { readonly: opts.readonly(), memory, - _db: db, + db: Some(db), conn, open: true, name: path, @@ -237,6 +237,7 @@ impl Database { pub fn close(&mut self) -> napi::Result<()> { if self.open { self.conn.close().map_err(into_napi_error)?; + self.db.take(); self.open = false; } Ok(()) diff --git a/bindings/rust/src/lib.rs b/bindings/rust/src/lib.rs index ee0e0d838..a4f758308 100644 --- a/bindings/rust/src/lib.rs +++ b/bindings/rust/src/lib.rs @@ -606,37 +606,39 @@ mod tests { } } // db and conn are dropped here, simulating closing - // Now, re-open the database and check if the data is still there - let db = Builder::new_local(db_path).build().await?; - let conn = db.connect()?; + { + // Now, re-open the database and check if the data is still there + let db = Builder::new_local(db_path).build().await?; + let conn = db.connect()?; - let mut rows = conn - .query("SELECT data FROM test_large_persistence ORDER BY id;", ()) - .await?; + let mut rows = conn + .query("SELECT data FROM test_large_persistence ORDER BY id;", ()) + .await?; - for (i, value) in original_data.iter().enumerate().take(NUM_INSERTS) { - let row = rows - .next() - .await? - .unwrap_or_else(|| panic!("Expected row {i} but found None")); - assert_eq!( - row.get_value(0)?, - Value::Text(value.clone()), - "Mismatch in retrieved data for row {i}" + for (i, value) in original_data.iter().enumerate().take(NUM_INSERTS) { + let row = rows + .next() + .await? + .unwrap_or_else(|| panic!("Expected row {i} but found None")); + assert_eq!( + row.get_value(0)?, + Value::Text(value.clone()), + "Mismatch in retrieved data for row {i}" + ); + } + + assert!( + rows.next().await?.is_none(), + "Expected no more rows after retrieving all inserted data" ); + + // Delete the WAL file only and try to re-open and query + let wal_path = format!("{db_path}-wal"); + std::fs::remove_file(&wal_path) + .map_err(|e| eprintln!("Warning: Failed to delete WAL file for test: {e}")) + .unwrap(); } - assert!( - rows.next().await?.is_none(), - "Expected no more rows after retrieving all inserted data" - ); - - // Delete the WAL file only and try to re-open and query - let wal_path = format!("{db_path}-wal"); - std::fs::remove_file(&wal_path) - .map_err(|e| eprintln!("Warning: Failed to delete WAL file for test: {e}")) - .unwrap(); - // Attempt to re-open the database after deleting WAL and assert that table is missing. let db_after_wal_delete = Builder::new_local(db_path).build().await?; let conn_after_wal_delete = db_after_wal_delete.connect()?; diff --git a/core/lib.rs b/core/lib.rs index 37bdf257f..e3cdea746 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -74,7 +74,7 @@ use std::{ num::NonZero, ops::Deref, rc::Rc, - sync::{Arc, Mutex}, + sync::{Arc, LazyLock, Mutex, Weak}, }; #[cfg(feature = "fs")] use storage::database::DatabaseFile; @@ -110,6 +110,15 @@ pub(crate) type MvStore = mvcc::MvStore; pub(crate) type MvCursor = mvcc::cursor::ScanCursor; +/// The database manager ensures that there is a single, shared +/// `Database` object per a database file. We need because it is not safe +/// to have multiple independent WAL files open because coordination +/// happens at process-level POSIX file advisory locks. +static DATABASE_MANAGER: LazyLock>>> = + LazyLock::new(|| Mutex::new(HashMap::new())); + +/// The `Database` object contains per database file state that is shared +/// between multiple connections. pub struct Database { mv_store: Option>, schema: Mutex>, @@ -226,6 +235,34 @@ impl Database { flags: OpenFlags, enable_mvcc: bool, enable_indexes: bool, + ) -> Result> { + if path == ":memory:" { + return Self::do_open_with_flags(io, path, db_file, flags, enable_mvcc, enable_indexes); + } + + let mut registry = DATABASE_MANAGER.lock().unwrap(); + + let canonical_path = std::fs::canonicalize(path) + .ok() + .and_then(|p| p.to_str().map(|s| s.to_string())) + .unwrap_or_else(|| path.to_string()); + + if let Some(db) = registry.get(&canonical_path).and_then(Weak::upgrade) { + return Ok(db); + } + let db = Self::do_open_with_flags(io, path, db_file, flags, enable_mvcc, enable_indexes)?; + registry.insert(canonical_path, Arc::downgrade(&db)); + Ok(db) + } + + #[allow(clippy::arc_with_non_send_sync)] + fn do_open_with_flags( + io: Arc, + path: &str, + db_file: Arc, + flags: OpenFlags, + enable_mvcc: bool, + enable_indexes: bool, ) -> Result> { let wal_path = format!("{path}-wal"); let maybe_shared_wal = WalFileShared::open_shared_if_exists(&io, wal_path.as_str())?; @@ -1117,6 +1154,26 @@ impl Connection { return Ok(()); } self.closed.set(true); + + match self.transaction_state.get() { + TransactionState::Write { schema_did_change } => { + let _result = self.pager.borrow().end_tx( + true, // rollback = true for close + schema_did_change, + self, + self.wal_checkpoint_disabled.get(), + ); + self.transaction_state.set(TransactionState::None); + } + TransactionState::Read => { + let _result = self.pager.borrow().end_read_tx(); + self.transaction_state.set(TransactionState::None); + } + TransactionState::None => { + // No active transaction + } + } + self.pager .borrow() .checkpoint_shutdown(self.wal_checkpoint_disabled.get()) diff --git a/core/storage/btree.rs b/core/storage/btree.rs index 60a708f38..32f512324 100644 --- a/core/storage/btree.rs +++ b/core/storage/btree.rs @@ -7348,7 +7348,7 @@ mod tests { fn empty_btree() -> (Rc, usize, Arc, Arc) { #[allow(clippy::arc_with_non_send_sync)] let io: Arc = Arc::new(MemoryIO::new()); - let db = Database::open_file(io.clone(), "test.db", false, false).unwrap(); + let db = Database::open_file(io.clone(), ":memory:", false, false).unwrap(); let conn = db.connect().unwrap(); let pager = conn.pager.borrow().clone(); @@ -8274,7 +8274,7 @@ mod tests { let io: Arc = Arc::new(MemoryIO::new()); let db_file = Arc::new(DatabaseFile::new( - io.open_file("test.db", OpenFlags::Create, false).unwrap(), + io.open_file(":memory:", OpenFlags::Create, false).unwrap(), )); let wal_file = io.open_file("test.wal", OpenFlags::Create, false).unwrap(); diff --git a/sqlite3/tests/compat/mod.rs b/sqlite3/tests/compat/mod.rs index a6aba67c5..9b1b1b56d 100644 --- a/sqlite3/tests/compat/mod.rs +++ b/sqlite3/tests/compat/mod.rs @@ -355,6 +355,7 @@ mod tests { ), SQLITE_OK ); + assert_eq!(sqlite3_close(db), SQLITE_OK); } let mut wal_path = temp_file.path().to_path_buf(); assert!(wal_path.set_extension("db-wal")); @@ -380,6 +381,7 @@ mod tests { assert_eq!(sqlite3_step(stmt), SQLITE_DONE); assert_eq!(sqlite3_finalize(stmt), SQLITE_OK); } + assert_eq!(sqlite3_close(db), SQLITE_OK); } } @@ -459,6 +461,7 @@ mod tests { ), SQLITE_OK ); + assert_eq!(sqlite3_close(db), SQLITE_OK); } let mut wal_path = temp_file.path().to_path_buf(); assert!(wal_path.set_extension("db-wal")); @@ -483,6 +486,7 @@ mod tests { assert_eq!(sqlite3_step(stmt), SQLITE_DONE); assert_eq!(sqlite3_finalize(stmt), SQLITE_OK); } + assert_eq!(sqlite3_close(db), SQLITE_OK); } }