Implement transaction support in Go adapter

This commit is contained in:
jnesss
2025-05-02 14:39:23 -07:00
parent 6096cfb3d8
commit 2f0bbf6b22
2 changed files with 171 additions and 7 deletions

View File

@@ -1,6 +1,7 @@
package limbo
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
@@ -136,7 +137,94 @@ func (c *limboConn) Prepare(query string) (driver.Stmt, error) {
return newStmt(stmtPtr, query), nil
}
// begin is needed to implement driver.Conn.. for now not implemented
func (c *limboConn) Begin() (driver.Tx, error) {
return nil, errors.New("transactions not implemented")
// limboTx implements driver.Tx
type limboTx struct {
conn *limboConn
}
// Begin starts a new transaction with default isolation level
func (c *limboConn) Begin() (driver.Tx, error) {
c.Lock()
defer c.Unlock()
if c.ctx == 0 {
return nil, errors.New("connection closed")
}
// Execute BEGIN statement
stmtPtr := connPrepare(c.ctx, "BEGIN")
if stmtPtr == 0 {
return nil, c.getError()
}
stmt := newStmt(stmtPtr, "BEGIN")
defer stmt.Close()
_, err := stmt.Exec(nil)
if err != nil {
return nil, err
}
return &limboTx{conn: c}, nil
}
// BeginTx starts a transaction with the specified options.
// Currently only supports default isolation level and non-read-only transactions.
func (c *limboConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
// Skip handling non-default isolation levels and read-only mode
// for now, letting database/sql package handle these cases
if opts.Isolation != driver.IsolationLevel(sql.LevelDefault) || opts.ReadOnly {
return nil, driver.ErrSkip
}
// Check for context cancellation
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
return c.Begin()
}
}
// Commit commits the transaction
func (tx *limboTx) Commit() error {
tx.conn.Lock()
defer tx.conn.Unlock()
if tx.conn.ctx == 0 {
return errors.New("connection closed")
}
stmtPtr := connPrepare(tx.conn.ctx, "COMMIT")
if stmtPtr == 0 {
return tx.conn.getError()
}
stmt := newStmt(stmtPtr, "COMMIT")
defer stmt.Close()
_, err := stmt.Exec(nil)
return err
}
// Rollback aborts the transaction.
// Note: This operation is not currently fully supported by Limbo and will return an error.
func (tx *limboTx) Rollback() error {
tx.conn.Lock()
defer tx.conn.Unlock()
if tx.conn.ctx == 0 {
return errors.New("connection closed")
}
stmtPtr := connPrepare(tx.conn.ctx, "ROLLBACK")
if stmtPtr == 0 {
return tx.conn.getError()
}
stmt := newStmt(stmtPtr, "ROLLBACK")
defer stmt.Close()
_, err := stmt.Exec(nil)
return err
}

View File

@@ -9,8 +9,10 @@ import (
_ "github.com/tursodatabase/limbo"
)
var conn *sql.DB
var connErr error
var (
conn *sql.DB
connErr error
)
func TestMain(m *testing.M) {
conn, connErr = sql.Open("sqlite3", ":memory:")
@@ -59,7 +61,7 @@ func TestQuery(t *testing.T) {
t.Errorf("Expected column %d to be %s, got %s", i, expectedCols[i], col)
}
}
var i = 1
i := 1
for rows.Next() {
var a int
var b string
@@ -78,7 +80,6 @@ func TestQuery(t *testing.T) {
if err = rows.Err(); err != nil {
t.Fatalf("Row iteration error: %v", err)
}
}
func TestFunctions(t *testing.T) {
@@ -280,6 +281,81 @@ func TestDriverRowsErrorMessages(t *testing.T) {
t.Log("Rows error behavior test passed")
}
func TestTransaction(t *testing.T) {
// Open database connection
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("Error opening database: %v", err)
}
defer db.Close()
// Create a test table
_, err = db.Exec("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)")
if err != nil {
t.Fatalf("Error creating table: %v", err)
}
// Insert initial data
_, err = db.Exec("INSERT INTO test (id, name) VALUES (1, 'Initial')")
if err != nil {
t.Fatalf("Error inserting initial data: %v", err)
}
// Begin a transaction
tx, err := db.Begin()
if err != nil {
t.Fatalf("Error starting transaction: %v", err)
}
// Insert data within the transaction
_, err = tx.Exec("INSERT INTO test (id, name) VALUES (2, 'Transaction')")
if err != nil {
t.Fatalf("Error inserting data in transaction: %v", err)
}
// Commit the transaction
err = tx.Commit()
if err != nil {
t.Fatalf("Error committing transaction: %v", err)
}
// Verify both rows are visible after commit
rows, err := db.Query("SELECT id, name FROM test ORDER BY id")
if err != nil {
t.Fatalf("Error querying data after commit: %v", err)
}
defer rows.Close()
expected := []struct {
id int
name string
}{
{1, "Initial"},
{2, "Transaction"},
}
i := 0
for rows.Next() {
var id int
var name string
if err := rows.Scan(&id, &name); err != nil {
t.Fatalf("Error scanning row: %v", err)
}
if id != expected[i].id || name != expected[i].name {
t.Errorf("Row %d: expected (%d, %s), got (%d, %s)",
i, expected[i].id, expected[i].name, id, name)
}
i++
}
if i != 2 {
t.Fatalf("Expected 2 rows, got %d", i)
}
t.Log("Transaction test passed")
}
func slicesAreEq(a, b []byte) bool {
if len(a) != len(b) {
fmt.Printf("LENGTHS NOT EQUAL: %d != %d\n", len(a), len(b))