From a9b5fc7f6336e762da96b7716cdabc72dcfb5903 Mon Sep 17 00:00:00 2001 From: jnesss Date: Fri, 2 May 2025 11:30:31 -0700 Subject: [PATCH] Add tests for vector operations and date/time functions in Go adapter --- bindings/go/limbo_test.go | 325 +++++++++++++++++++++++++++++++++++++- 1 file changed, 321 insertions(+), 4 deletions(-) diff --git a/bindings/go/limbo_test.go b/bindings/go/limbo_test.go index 9527faa5f..dd0676292 100644 --- a/bindings/go/limbo_test.go +++ b/bindings/go/limbo_test.go @@ -4,13 +4,16 @@ import ( "database/sql" "fmt" "log" + "math" "testing" _ "github.com/tursodatabase/limbo" ) -var conn *sql.DB -var connErr error +var ( + conn *sql.DB + connErr error +) func TestMain(m *testing.M) { conn, connErr = sql.Open("sqlite3", ":memory:") @@ -59,7 +62,7 @@ func TestQuery(t *testing.T) { t.Errorf("Expected column %d to be %s, got %s", i, expectedCols[i], col) } } - var i = 1 + i := 1 for rows.Next() { var a int var b string @@ -78,7 +81,6 @@ func TestQuery(t *testing.T) { if err = rows.Err(); err != nil { t.Fatalf("Row iteration error: %v", err) } - } func TestFunctions(t *testing.T) { @@ -280,6 +282,321 @@ func TestDriverRowsErrorMessages(t *testing.T) { t.Log("Rows error behavior test passed") } +func TestVectorOperations(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Error opening connection: %v", err) + } + defer db.Close() + + // Test creating table with vector columns + _, err = db.Exec(`CREATE TABLE vector_test (id INTEGER PRIMARY KEY, embedding F32_BLOB(64))`) + if err != nil { + t.Fatalf("Error creating vector table: %v", err) + } + + // Test vector insertion + _, err = db.Exec(`INSERT INTO vector_test VALUES (1, vector('[0.1, 0.2, 0.3, 0.4, 0.5]'))`) + if err != nil { + t.Fatalf("Error inserting vector: %v", err) + } + + // Test vector similarity calculation + var similarity float64 + err = db.QueryRow(`SELECT vector_distance_cos(embedding, vector('[0.2, 0.3, 0.4, 0.5, 0.6]')) FROM vector_test WHERE id = 1`).Scan(&similarity) + if err != nil { + t.Fatalf("Error calculating vector similarity: %v", err) + } + if similarity <= 0 || similarity > 1 { + t.Fatalf("Expected similarity between 0 and 1, got %f", similarity) + } + + // Test vector extraction + var extracted string + err = db.QueryRow(`SELECT vector_extract(embedding) FROM vector_test WHERE id = 1`).Scan(&extracted) + if err != nil { + t.Fatalf("Error extracting vector: %v", err) + } + fmt.Printf("Extracted vector: %s\n", extracted) +} + +func TestTransactions(t *testing.T) { + t.Skip("Skipping transaction tests - transactions not yet implemented in the limbo driver") + + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Error opening connection: %v", err) + } + defer db.Close() + + // Create test table + _, err = db.Exec(`CREATE TABLE tx_test (id INTEGER PRIMARY KEY, value TEXT)`) + if err != nil { + t.Fatalf("Error creating table: %v", err) + } + + // Test successful transaction + tx, err := db.Begin() + if err != nil { + t.Fatalf("Error beginning transaction: %v", err) + } + + _, err = tx.Exec(`INSERT INTO tx_test VALUES (1, 'before commit')`) + if err != nil { + t.Fatalf("Error executing in transaction: %v", err) + } + + err = tx.Commit() + if err != nil { + t.Fatalf("Error committing transaction: %v", err) + } + + // Verify commit worked + var value string + err = db.QueryRow(`SELECT value FROM tx_test WHERE id = 1`).Scan(&value) + if err != nil { + t.Fatalf("Error querying after commit: %v", err) + } + if value != "before commit" { + t.Fatalf("Expected 'before commit', got '%s'", value) + } + + // Test rollback + tx, err = db.Begin() + if err != nil { + t.Fatalf("Error beginning transaction: %v", err) + } + + _, err = tx.Exec(`INSERT INTO tx_test VALUES (2, 'should rollback')`) + if err != nil { + t.Fatalf("Error executing in transaction: %v", err) + } + + err = tx.Rollback() + if err != nil { + t.Fatalf("Error rolling back transaction: %v", err) + } + + // Verify rollback worked + err = db.QueryRow(`SELECT value FROM tx_test WHERE id = 2`).Scan(&value) + if err == nil { + t.Fatalf("Expected error after rollback, got value: %s", value) + } +} + +func TestSQLFeatures(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Error opening connection: %v", err) + } + defer db.Close() + + // Create test tables + _, err = db.Exec(` + CREATE TABLE customers ( + id INTEGER PRIMARY KEY, + name TEXT, + age INTEGER + )`) + if err != nil { + t.Fatalf("Error creating customers table: %v", err) + } + + _, err = db.Exec(` + CREATE TABLE orders ( + id INTEGER PRIMARY KEY, + customer_id INTEGER, + amount REAL, + date TEXT + )`) + if err != nil { + t.Fatalf("Error creating orders table: %v", err) + } + + // Insert test data + _, err = db.Exec(` + INSERT INTO customers VALUES + (1, 'Alice', 30), + (2, 'Bob', 25), + (3, 'Charlie', 40)`) + if err != nil { + t.Fatalf("Error inserting customers: %v", err) + } + + _, err = db.Exec(` + INSERT INTO orders VALUES + (1, 1, 100.50, '2024-01-01'), + (2, 1, 200.75, '2024-02-01'), + (3, 2, 50.25, '2024-01-15'), + (4, 3, 300.00, '2024-02-10')`) + if err != nil { + t.Fatalf("Error inserting orders: %v", err) + } + + // Test JOIN + rows, err := db.Query(` + SELECT c.name, o.amount + FROM customers c + INNER JOIN orders o ON c.id = o.customer_id + ORDER BY o.amount DESC`) + if err != nil { + t.Fatalf("Error executing JOIN: %v", err) + } + defer rows.Close() + + // Check JOIN results + expectedResults := []struct { + name string + amount float64 + }{ + {"Charlie", 300.00}, + {"Alice", 200.75}, + {"Alice", 100.50}, + {"Bob", 50.25}, + } + + i := 0 + for rows.Next() { + var name string + var amount float64 + if err := rows.Scan(&name, &amount); err != nil { + t.Fatalf("Error scanning JOIN result: %v", err) + } + if i >= len(expectedResults) { + t.Fatalf("Too many rows returned from JOIN") + } + if name != expectedResults[i].name || amount != expectedResults[i].amount { + t.Fatalf("Row %d: expected (%s, %.2f), got (%s, %.2f)", + i, expectedResults[i].name, expectedResults[i].amount, name, amount) + } + i++ + } + + // Test GROUP BY with aggregation + var count int + var total float64 + err = db.QueryRow(` + SELECT COUNT(*), SUM(amount) + FROM orders + WHERE customer_id = 1 + GROUP BY customer_id`).Scan(&count, &total) + if err != nil { + t.Fatalf("Error executing GROUP BY: %v", err) + } + if count != 2 || total != 301.25 { + t.Fatalf("GROUP BY gave wrong results: count=%d, total=%.2f", count, total) + } +} + +func TestDateTimeFunctions(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Error opening connection: %v", err) + } + defer db.Close() + + // Test date() + var dateStr string + err = db.QueryRow(`SELECT date('now')`).Scan(&dateStr) + if err != nil { + t.Fatalf("Error with date() function: %v", err) + } + fmt.Printf("Current date: %s\n", dateStr) + + // Test date arithmetic + err = db.QueryRow(`SELECT date('2024-01-01', '+1 month')`).Scan(&dateStr) + if err != nil { + t.Fatalf("Error with date arithmetic: %v", err) + } + if dateStr != "2024-02-01" { + t.Fatalf("Expected '2024-02-01', got '%s'", dateStr) + } + + // Test strftime + var formatted string + err = db.QueryRow(`SELECT strftime('%Y-%m-%d', '2024-01-01')`).Scan(&formatted) + if err != nil { + t.Fatalf("Error with strftime function: %v", err) + } + if formatted != "2024-01-01" { + t.Fatalf("Expected '2024-01-01', got '%s'", formatted) + } +} + +func TestMathFunctions(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Error opening connection: %v", err) + } + defer db.Close() + + // Test basic math functions + var result float64 + err = db.QueryRow(`SELECT abs(-15.5)`).Scan(&result) + if err != nil { + t.Fatalf("Error with abs function: %v", err) + } + if result != 15.5 { + t.Fatalf("abs(-15.5) should be 15.5, got %f", result) + } + + // Test trigonometric functions + err = db.QueryRow(`SELECT round(sin(radians(30)), 4)`).Scan(&result) + if err != nil { + t.Fatalf("Error with sin function: %v", err) + } + if math.Abs(result-0.5) > 0.0001 { + t.Fatalf("sin(30 degrees) should be about 0.5, got %f", result) + } + + // Test power functions + err = db.QueryRow(`SELECT pow(2, 3)`).Scan(&result) + if err != nil { + t.Fatalf("Error with pow function: %v", err) + } + if result != 8 { + t.Fatalf("2^3 should be 8, got %f", result) + } +} + +func TestJSONFunctions(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Error opening connection: %v", err) + } + defer db.Close() + + // Test json function + var valid int + err = db.QueryRow(`SELECT json_valid('{"name":"John","age":30}')`).Scan(&valid) + if err != nil { + t.Fatalf("Error with json_valid function: %v", err) + } + if valid != 1 { + t.Fatalf("Expected valid JSON to return 1, got %d", valid) + } + + // Test json_extract + var name string + err = db.QueryRow(`SELECT json_extract('{"name":"John","age":30}', '$.name')`).Scan(&name) + if err != nil { + t.Fatalf("Error with json_extract function: %v", err) + } + if name != "John" { + t.Fatalf("Expected 'John', got '%s'", name) + } + + // Test JSON shorthand + var age int + err = db.QueryRow(`SELECT '{"name":"John","age":30}' -> '$.age'`).Scan(&age) + if err != nil { + t.Fatalf("Error with JSON shorthand: %v", err) + } + if age != 30 { + t.Fatalf("Expected 30, got %d", age) + } +} + func slicesAreEq(a, b []byte) bool { if len(a) != len(b) { fmt.Printf("LENGTHS NOT EQUAL: %d != %d\n", len(a), len(b))