mirror of
https://github.com/aljazceru/turso.git
synced 2026-02-10 18:54:22 +01:00
Implement transaction support in Go adapter
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package limbo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
@@ -136,7 +137,94 @@ func (c *limboConn) Prepare(query string) (driver.Stmt, error) {
|
||||
return newStmt(stmtPtr, query), nil
|
||||
}
|
||||
|
||||
// begin is needed to implement driver.Conn.. for now not implemented
|
||||
func (c *limboConn) Begin() (driver.Tx, error) {
|
||||
return nil, errors.New("transactions not implemented")
|
||||
// limboTx implements driver.Tx
|
||||
type limboTx struct {
|
||||
conn *limboConn
|
||||
}
|
||||
|
||||
// Begin starts a new transaction with default isolation level
|
||||
func (c *limboConn) Begin() (driver.Tx, error) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
if c.ctx == 0 {
|
||||
return nil, errors.New("connection closed")
|
||||
}
|
||||
|
||||
// Execute BEGIN statement
|
||||
stmtPtr := connPrepare(c.ctx, "BEGIN")
|
||||
if stmtPtr == 0 {
|
||||
return nil, c.getError()
|
||||
}
|
||||
|
||||
stmt := newStmt(stmtPtr, "BEGIN")
|
||||
defer stmt.Close()
|
||||
|
||||
_, err := stmt.Exec(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &limboTx{conn: c}, nil
|
||||
}
|
||||
|
||||
// BeginTx starts a transaction with the specified options.
|
||||
// Currently only supports default isolation level and non-read-only transactions.
|
||||
func (c *limboConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
|
||||
// Skip handling non-default isolation levels and read-only mode
|
||||
// for now, letting database/sql package handle these cases
|
||||
if opts.Isolation != driver.IsolationLevel(sql.LevelDefault) || opts.ReadOnly {
|
||||
return nil, driver.ErrSkip
|
||||
}
|
||||
|
||||
// Check for context cancellation
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
return c.Begin()
|
||||
}
|
||||
}
|
||||
|
||||
// Commit commits the transaction
|
||||
func (tx *limboTx) Commit() error {
|
||||
tx.conn.Lock()
|
||||
defer tx.conn.Unlock()
|
||||
|
||||
if tx.conn.ctx == 0 {
|
||||
return errors.New("connection closed")
|
||||
}
|
||||
|
||||
stmtPtr := connPrepare(tx.conn.ctx, "COMMIT")
|
||||
if stmtPtr == 0 {
|
||||
return tx.conn.getError()
|
||||
}
|
||||
|
||||
stmt := newStmt(stmtPtr, "COMMIT")
|
||||
defer stmt.Close()
|
||||
|
||||
_, err := stmt.Exec(nil)
|
||||
return err
|
||||
}
|
||||
|
||||
// Rollback aborts the transaction.
|
||||
// Note: This operation is not currently fully supported by Limbo and will return an error.
|
||||
func (tx *limboTx) Rollback() error {
|
||||
tx.conn.Lock()
|
||||
defer tx.conn.Unlock()
|
||||
|
||||
if tx.conn.ctx == 0 {
|
||||
return errors.New("connection closed")
|
||||
}
|
||||
|
||||
stmtPtr := connPrepare(tx.conn.ctx, "ROLLBACK")
|
||||
if stmtPtr == 0 {
|
||||
return tx.conn.getError()
|
||||
}
|
||||
|
||||
stmt := newStmt(stmtPtr, "ROLLBACK")
|
||||
defer stmt.Close()
|
||||
|
||||
_, err := stmt.Exec(nil)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -9,8 +9,10 @@ import (
|
||||
_ "github.com/tursodatabase/limbo"
|
||||
)
|
||||
|
||||
var conn *sql.DB
|
||||
var connErr error
|
||||
var (
|
||||
conn *sql.DB
|
||||
connErr error
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
conn, connErr = sql.Open("sqlite3", ":memory:")
|
||||
@@ -59,7 +61,7 @@ func TestQuery(t *testing.T) {
|
||||
t.Errorf("Expected column %d to be %s, got %s", i, expectedCols[i], col)
|
||||
}
|
||||
}
|
||||
var i = 1
|
||||
i := 1
|
||||
for rows.Next() {
|
||||
var a int
|
||||
var b string
|
||||
@@ -78,7 +80,6 @@ func TestQuery(t *testing.T) {
|
||||
if err = rows.Err(); err != nil {
|
||||
t.Fatalf("Row iteration error: %v", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestFunctions(t *testing.T) {
|
||||
@@ -280,6 +281,81 @@ func TestDriverRowsErrorMessages(t *testing.T) {
|
||||
t.Log("Rows error behavior test passed")
|
||||
}
|
||||
|
||||
func TestTransaction(t *testing.T) {
|
||||
// Open database connection
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Error opening database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Create a test table
|
||||
_, err = db.Exec("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)")
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating table: %v", err)
|
||||
}
|
||||
|
||||
// Insert initial data
|
||||
_, err = db.Exec("INSERT INTO test (id, name) VALUES (1, 'Initial')")
|
||||
if err != nil {
|
||||
t.Fatalf("Error inserting initial data: %v", err)
|
||||
}
|
||||
|
||||
// Begin a transaction
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Fatalf("Error starting transaction: %v", err)
|
||||
}
|
||||
|
||||
// Insert data within the transaction
|
||||
_, err = tx.Exec("INSERT INTO test (id, name) VALUES (2, 'Transaction')")
|
||||
if err != nil {
|
||||
t.Fatalf("Error inserting data in transaction: %v", err)
|
||||
}
|
||||
|
||||
// Commit the transaction
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
t.Fatalf("Error committing transaction: %v", err)
|
||||
}
|
||||
|
||||
// Verify both rows are visible after commit
|
||||
rows, err := db.Query("SELECT id, name FROM test ORDER BY id")
|
||||
if err != nil {
|
||||
t.Fatalf("Error querying data after commit: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
expected := []struct {
|
||||
id int
|
||||
name string
|
||||
}{
|
||||
{1, "Initial"},
|
||||
{2, "Transaction"},
|
||||
}
|
||||
|
||||
i := 0
|
||||
for rows.Next() {
|
||||
var id int
|
||||
var name string
|
||||
if err := rows.Scan(&id, &name); err != nil {
|
||||
t.Fatalf("Error scanning row: %v", err)
|
||||
}
|
||||
|
||||
if id != expected[i].id || name != expected[i].name {
|
||||
t.Errorf("Row %d: expected (%d, %s), got (%d, %s)",
|
||||
i, expected[i].id, expected[i].name, id, name)
|
||||
}
|
||||
i++
|
||||
}
|
||||
|
||||
if i != 2 {
|
||||
t.Fatalf("Expected 2 rows, got %d", i)
|
||||
}
|
||||
|
||||
t.Log("Transaction test passed")
|
||||
}
|
||||
|
||||
func slicesAreEq(a, b []byte) bool {
|
||||
if len(a) != len(b) {
|
||||
fmt.Printf("LENGTHS NOT EQUAL: %d != %d\n", len(a), len(b))
|
||||
|
||||
Reference in New Issue
Block a user