diff --git a/bindings/go/connection.go b/bindings/go/connection.go index 2d7a27e8b..b72ad7e35 100644 --- a/bindings/go/connection.go +++ b/bindings/go/connection.go @@ -1,6 +1,7 @@ package limbo import ( + "context" "database/sql" "database/sql/driver" "errors" @@ -136,7 +137,94 @@ func (c *limboConn) Prepare(query string) (driver.Stmt, error) { return newStmt(stmtPtr, query), nil } -// begin is needed to implement driver.Conn.. for now not implemented -func (c *limboConn) Begin() (driver.Tx, error) { - return nil, errors.New("transactions not implemented") +// limboTx implements driver.Tx +type limboTx struct { + conn *limboConn +} + +// Begin starts a new transaction with default isolation level +func (c *limboConn) Begin() (driver.Tx, error) { + c.Lock() + defer c.Unlock() + + if c.ctx == 0 { + return nil, errors.New("connection closed") + } + + // Execute BEGIN statement + stmtPtr := connPrepare(c.ctx, "BEGIN") + if stmtPtr == 0 { + return nil, c.getError() + } + + stmt := newStmt(stmtPtr, "BEGIN") + defer stmt.Close() + + _, err := stmt.Exec(nil) + if err != nil { + return nil, err + } + + return &limboTx{conn: c}, nil +} + +// BeginTx starts a transaction with the specified options. +// Currently only supports default isolation level and non-read-only transactions. +func (c *limboConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + // Skip handling non-default isolation levels and read-only mode + // for now, letting database/sql package handle these cases + if opts.Isolation != driver.IsolationLevel(sql.LevelDefault) || opts.ReadOnly { + return nil, driver.ErrSkip + } + + // Check for context cancellation + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + return c.Begin() + } +} + +// Commit commits the transaction +func (tx *limboTx) Commit() error { + tx.conn.Lock() + defer tx.conn.Unlock() + + if tx.conn.ctx == 0 { + return errors.New("connection closed") + } + + stmtPtr := connPrepare(tx.conn.ctx, "COMMIT") + if stmtPtr == 0 { + return tx.conn.getError() + } + + stmt := newStmt(stmtPtr, "COMMIT") + defer stmt.Close() + + _, err := stmt.Exec(nil) + return err +} + +// Rollback aborts the transaction. +// Note: This operation is not currently fully supported by Limbo and will return an error. +func (tx *limboTx) Rollback() error { + tx.conn.Lock() + defer tx.conn.Unlock() + + if tx.conn.ctx == 0 { + return errors.New("connection closed") + } + + stmtPtr := connPrepare(tx.conn.ctx, "ROLLBACK") + if stmtPtr == 0 { + return tx.conn.getError() + } + + stmt := newStmt(stmtPtr, "ROLLBACK") + defer stmt.Close() + + _, err := stmt.Exec(nil) + return err } diff --git a/bindings/go/limbo_test.go b/bindings/go/limbo_test.go index 9527faa5f..d22a2e650 100644 --- a/bindings/go/limbo_test.go +++ b/bindings/go/limbo_test.go @@ -9,8 +9,10 @@ import ( _ "github.com/tursodatabase/limbo" ) -var conn *sql.DB -var connErr error +var ( + conn *sql.DB + connErr error +) func TestMain(m *testing.M) { conn, connErr = sql.Open("sqlite3", ":memory:") @@ -59,7 +61,7 @@ func TestQuery(t *testing.T) { t.Errorf("Expected column %d to be %s, got %s", i, expectedCols[i], col) } } - var i = 1 + i := 1 for rows.Next() { var a int var b string @@ -78,7 +80,6 @@ func TestQuery(t *testing.T) { if err = rows.Err(); err != nil { t.Fatalf("Row iteration error: %v", err) } - } func TestFunctions(t *testing.T) { @@ -280,6 +281,81 @@ func TestDriverRowsErrorMessages(t *testing.T) { t.Log("Rows error behavior test passed") } +func TestTransaction(t *testing.T) { + // Open database connection + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Error opening database: %v", err) + } + defer db.Close() + + // Create a test table + _, err = db.Exec("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)") + if err != nil { + t.Fatalf("Error creating table: %v", err) + } + + // Insert initial data + _, err = db.Exec("INSERT INTO test (id, name) VALUES (1, 'Initial')") + if err != nil { + t.Fatalf("Error inserting initial data: %v", err) + } + + // Begin a transaction + tx, err := db.Begin() + if err != nil { + t.Fatalf("Error starting transaction: %v", err) + } + + // Insert data within the transaction + _, err = tx.Exec("INSERT INTO test (id, name) VALUES (2, 'Transaction')") + if err != nil { + t.Fatalf("Error inserting data in transaction: %v", err) + } + + // Commit the transaction + err = tx.Commit() + if err != nil { + t.Fatalf("Error committing transaction: %v", err) + } + + // Verify both rows are visible after commit + rows, err := db.Query("SELECT id, name FROM test ORDER BY id") + if err != nil { + t.Fatalf("Error querying data after commit: %v", err) + } + defer rows.Close() + + expected := []struct { + id int + name string + }{ + {1, "Initial"}, + {2, "Transaction"}, + } + + i := 0 + for rows.Next() { + var id int + var name string + if err := rows.Scan(&id, &name); err != nil { + t.Fatalf("Error scanning row: %v", err) + } + + if id != expected[i].id || name != expected[i].name { + t.Errorf("Row %d: expected (%d, %s), got (%d, %s)", + i, expected[i].id, expected[i].name, id, name) + } + i++ + } + + if i != 2 { + t.Fatalf("Expected 2 rows, got %d", i) + } + + t.Log("Transaction test passed") +} + func slicesAreEq(a, b []byte) bool { if len(a) != len(b) { fmt.Printf("LENGTHS NOT EQUAL: %d != %d\n", len(a), len(b))