diff --git a/bindings/python/example.py b/bindings/python/example.py index 53bfb9b34..870766e8f 100644 --- a/bindings/python/example.py +++ b/bindings/python/example.py @@ -1,6 +1,38 @@ import limbo -con = limbo.connect("sqlite.db") -cur = con.cursor() -res = cur.execute("SELECT * FROM users") -print(res.fetchone()) +# Use the context manager to automatically close the connection +with limbo.connect("sqlite.db") as con: + cur = con.cursor() + cur.execute(""" + CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + username TEXT NOT NULL, + email TEXT NOT NULL, + role TEXT NOT NULL, + created_at DATETIME NOT NULL DEFAULT (datetime('now')) + ) + """) + + # Insert some sample data + sample_users = [ + ("alice", "alice@example.com", "admin"), + ("bob", "bob@example.com", "user"), + ("charlie", "charlie@example.com", "moderator"), + ("diana", "diana@example.com", "user"), + ] + for username, email, role in sample_users: + cur.execute( + """ + INSERT INTO users (username, email, role) + VALUES (?, ?, ?) + """, + (username, email, role), + ) + + # Use commit to ensure the data is saved + con.commit() + + # Query the table + res = cur.execute("SELECT * FROM users") + record = res.fetchone() + print(record) diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index f88666c6d..6a09e851b 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -193,9 +193,9 @@ impl Cursor { } pub fn close(&self) -> PyResult<()> { - Err(PyErr::new::( - "close() is not supported in this version", - )) + self.conn.close()?; + + Ok(()) } #[pyo3(signature = (sql, parameters=None))] @@ -244,8 +244,12 @@ impl Connection { }) } - pub fn close(&self) { - drop(self.conn.clone()); + pub fn close(&self) -> PyResult<()> { + self.conn.close().map_err(|e| { + PyErr::new::(format!("Failed to close connection: {:?}", e)) + })?; + + Ok(()) } pub fn commit(&self) -> PyResult<()> { @@ -266,6 +270,27 @@ impl Connection { "Transactions are not supported in this version", )) } + + fn __enter__(&self) -> PyResult { + Ok(self.clone()) + } + + fn __exit__( + &self, + _exc_type: Option<&Bound<'_, PyAny>>, + _exc_val: Option<&Bound<'_, PyAny>>, + _exc_tb: Option<&Bound<'_, PyAny>>, + ) -> PyResult<()> { + self.close() + } +} + +impl Drop for Connection { + fn drop(&mut self) { + self.conn + .close() + .expect("Failed to drop (close) connection"); + } } #[allow(clippy::arc_with_non_send_sync)] diff --git a/bindings/python/tests/test_database.py b/bindings/python/tests/test_database.py index a8276b21c..ec565a898 100644 --- a/bindings/python/tests/test_database.py +++ b/bindings/python/tests/test_database.py @@ -12,27 +12,41 @@ def setup_database(): db_wal_path = "tests/database.db-wal" # Ensure the database file is created fresh for each test - if os.path.exists(db_path): - os.remove(db_path) - if os.path.exists(db_wal_path): - os.remove(db_wal_path) + try: + if os.path.exists(db_path): + os.remove(db_path) + if os.path.exists(db_wal_path): + os.remove(db_wal_path) + except PermissionError as e: + print(f"Failed to clean up: {e}") # Create a new database file conn = sqlite3.connect(db_path) cursor = conn.cursor() - cursor.execute("CREATE TABLE users (id INT PRIMARY KEY, username TEXT)") - cursor.execute("INSERT INTO users VALUES (1, 'alice')") - cursor.execute("INSERT INTO users VALUES (2, 'bob')") + cursor.execute("CREATE TABLE IF NOT EXISTS users (id INT PRIMARY KEY, username TEXT)") + cursor.execute(""" + INSERT INTO users (id, username) + SELECT 1, 'alice' + WHERE NOT EXISTS (SELECT 1 FROM users WHERE id = 1) + """) + cursor.execute(""" + INSERT INTO users (id, username) + SELECT 2, 'bob' + WHERE NOT EXISTS (SELECT 1 FROM users WHERE id = 2) + """) conn.commit() conn.close() yield db_path # Cleanup after the test - if os.path.exists(db_path): - os.remove(db_path) - if os.path.exists(db_wal_path): - os.remove(db_wal_path) + try: + if os.path.exists(db_path): + os.remove(db_path) + if os.path.exists(db_wal_path): + os.remove(db_wal_path) + except PermissionError as e: + print(f"Failed to clean up: {e}") @pytest.mark.parametrize("provider", ["sqlite3", "limbo"]) @@ -145,6 +159,18 @@ def test_commit(provider): assert record +@pytest.mark.parametrize("provider", ["sqlite3", "limbo"]) +def test_with_statement(provider): + with connect(provider, "tests/database.db") as conn: + cursor = conn.cursor() + cursor.execute("SELECT MAX(id) FROM users") + + max_id = cursor.fetchone() + + assert max_id + assert max_id == (2,) + + def connect(provider, database): if provider == "limbo": return limbo.connect(database)