diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index 6b38f503c..f88666c6d 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -1,8 +1,11 @@ use anyhow::Result; use errors::*; +use limbo_core::types::Text; +use limbo_core::OwnedValue; use pyo3::prelude::*; use pyo3::types::{PyBytes, PyList, PyTuple}; use std::cell::RefCell; +use std::num::NonZeroUsize; use std::rc::Rc; use std::sync::Arc; @@ -62,13 +65,6 @@ pub struct Cursor { smt: Option>>, } -#[pyclass(unsendable)] -#[derive(Clone)] -pub struct Connection { - conn: Rc, - io: Arc, -} - #[allow(unused_variables, clippy::arc_with_non_send_sync)] #[pymethods] impl Cursor { @@ -83,6 +79,20 @@ impl Cursor { let stmt = Rc::new(RefCell::new(statement)); + Python::with_gil(|py| { + if let Some(params) = parameters { + let obj = params.into_bound(py); + + for (i, elem) in obj.iter().enumerate() { + let value = py_to_owned_value(&elem)?; + stmt.borrow_mut() + .bind_at(NonZeroUsize::new(i + 1).unwrap(), value); + } + } + + Ok::<(), anyhow::Error>(()) + })?; + // For DDL and DML statements, // we need to execute the statement immediately if stmt_is_ddl || stmt_is_dml { @@ -215,6 +225,13 @@ fn stmt_is_ddl(sql: &str) -> bool { sql.starts_with("CREATE") || sql.starts_with("ALTER") || sql.starts_with("DROP") } +#[pyclass(unsendable)] +#[derive(Clone)] +pub struct Connection { + conn: Rc, + io: Arc, +} + #[pymethods] impl Connection { pub fn cursor(&self) -> Result { @@ -232,9 +249,16 @@ impl Connection { } pub fn commit(&self) -> PyResult<()> { - Err(PyErr::new::( - "Transactions are not supported in this version", - )) + if !self.conn.get_auto_commit() { + self.conn.execute("COMMIT").map_err(|e| { + PyErr::new::(format!("Failed to commit: {:?}", e)) + })?; + + self.conn.execute("BEGIN").map_err(|e| { + PyErr::new::(format!("Failed to commit: {:?}", e)) + })?; + } + Ok(()) } pub fn rollback(&self) -> PyResult<()> { @@ -293,6 +317,27 @@ fn row_to_py(py: Python, row: &limbo_core::Row) -> Result { .into()) } +/// Converts a Python object to a Limbo OwnedValue +fn py_to_owned_value(obj: &Bound) -> Result { + if obj.is_none() { + return Ok(OwnedValue::Null); + } else if let Ok(integer) = obj.extract::() { + return Ok(OwnedValue::Integer(integer)); + } else if let Ok(float) = obj.extract::() { + return Ok(OwnedValue::Float(float)); + } else if let Ok(string) = obj.extract::() { + return Ok(OwnedValue::Text(Text::from_str(string))); + } else if let Ok(bytes) = obj.downcast::() { + return Ok(OwnedValue::Blob(Rc::new(bytes.as_bytes().to_vec()))); + } else { + return Err(PyErr::new::(format!( + "Unsupported Python type: {}", + obj.get_type().name()? + )) + .into()); + } +} + #[pymodule] fn _limbo(m: &Bound) -> PyResult<()> { m.add("__version__", env!("CARGO_PKG_VERSION"))?; diff --git a/bindings/python/tests/database.db b/bindings/python/tests/database.db index 6138f6dfc..89397b76e 100644 Binary files a/bindings/python/tests/database.db and b/bindings/python/tests/database.db differ diff --git a/bindings/python/tests/test_database.py b/bindings/python/tests/test_database.py index 63241f19a..4a8301395 100644 --- a/bindings/python/tests/test_database.py +++ b/bindings/python/tests/test_database.py @@ -67,6 +67,42 @@ def test_fetchone_select_max_user_id(provider): assert max_id == (2,) +# Test case for: https://github.com/tursodatabase/limbo/issues/494 +@pytest.mark.parametrize("provider", ["sqlite3", "limbo"]) +def test_commit(provider): + con = connect(provider, "tests/database.db") + cur = con.cursor() + + cur.execute(""" + CREATE TABLE IF NOT EXISTS users_b ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + username TEXT UNIQUE NOT NULL, + email TEXT NOT NULL, + role TEXT NOT NULL, + created_at DATETIME NOT NULL DEFAULT (datetime('now')) + ) + """) + + con.commit() + + sample_users = [ + ("alice", "alice@example.com", "admin"), + ("bob", "bob@example.com", "user"), + ("charlie", "charlie@example.com", "moderator"), + ("diana", "diana@example.com", "user"), + ] + + for username, email, role in sample_users: + cur.execute("INSERT INTO users_b (username, email, role) VALUES (?, ?, ?)", (username, email, role)) + + con.commit() + + # Now query the table + res = cur.execute("SELECT * FROM users_b") + record = res.fetchone() + assert record + + def connect(provider, database): if provider == "limbo": return limbo.connect(database) diff --git a/core/lib.rs b/core/lib.rs index 4a5f0a206..03392f8ba 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -14,7 +14,7 @@ pub mod result; mod schema; mod storage; mod translate; -mod types; +pub mod types; #[allow(dead_code)] mod util; mod vdbe; @@ -525,6 +525,10 @@ impl Connection { } all_vfs } + + pub fn get_auto_commit(&self) -> bool { + *self.auto_commit.borrow() + } } pub struct Statement {