Add tests for vector operations and date/time functions in Go adapter

This commit is contained in:
jnesss
2025-05-02 11:30:31 -07:00
parent 6096cfb3d8
commit a9b5fc7f63

View File

@@ -4,13 +4,16 @@ import (
"database/sql"
"fmt"
"log"
"math"
"testing"
_ "github.com/tursodatabase/limbo"
)
var conn *sql.DB
var connErr error
var (
conn *sql.DB
connErr error
)
func TestMain(m *testing.M) {
conn, connErr = sql.Open("sqlite3", ":memory:")
@@ -59,7 +62,7 @@ func TestQuery(t *testing.T) {
t.Errorf("Expected column %d to be %s, got %s", i, expectedCols[i], col)
}
}
var i = 1
i := 1
for rows.Next() {
var a int
var b string
@@ -78,7 +81,6 @@ func TestQuery(t *testing.T) {
if err = rows.Err(); err != nil {
t.Fatalf("Row iteration error: %v", err)
}
}
func TestFunctions(t *testing.T) {
@@ -280,6 +282,321 @@ func TestDriverRowsErrorMessages(t *testing.T) {
t.Log("Rows error behavior test passed")
}
func TestVectorOperations(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("Error opening connection: %v", err)
}
defer db.Close()
// Test creating table with vector columns
_, err = db.Exec(`CREATE TABLE vector_test (id INTEGER PRIMARY KEY, embedding F32_BLOB(64))`)
if err != nil {
t.Fatalf("Error creating vector table: %v", err)
}
// Test vector insertion
_, err = db.Exec(`INSERT INTO vector_test VALUES (1, vector('[0.1, 0.2, 0.3, 0.4, 0.5]'))`)
if err != nil {
t.Fatalf("Error inserting vector: %v", err)
}
// Test vector similarity calculation
var similarity float64
err = db.QueryRow(`SELECT vector_distance_cos(embedding, vector('[0.2, 0.3, 0.4, 0.5, 0.6]')) FROM vector_test WHERE id = 1`).Scan(&similarity)
if err != nil {
t.Fatalf("Error calculating vector similarity: %v", err)
}
if similarity <= 0 || similarity > 1 {
t.Fatalf("Expected similarity between 0 and 1, got %f", similarity)
}
// Test vector extraction
var extracted string
err = db.QueryRow(`SELECT vector_extract(embedding) FROM vector_test WHERE id = 1`).Scan(&extracted)
if err != nil {
t.Fatalf("Error extracting vector: %v", err)
}
fmt.Printf("Extracted vector: %s\n", extracted)
}
func TestTransactions(t *testing.T) {
t.Skip("Skipping transaction tests - transactions not yet implemented in the limbo driver")
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("Error opening connection: %v", err)
}
defer db.Close()
// Create test table
_, err = db.Exec(`CREATE TABLE tx_test (id INTEGER PRIMARY KEY, value TEXT)`)
if err != nil {
t.Fatalf("Error creating table: %v", err)
}
// Test successful transaction
tx, err := db.Begin()
if err != nil {
t.Fatalf("Error beginning transaction: %v", err)
}
_, err = tx.Exec(`INSERT INTO tx_test VALUES (1, 'before commit')`)
if err != nil {
t.Fatalf("Error executing in transaction: %v", err)
}
err = tx.Commit()
if err != nil {
t.Fatalf("Error committing transaction: %v", err)
}
// Verify commit worked
var value string
err = db.QueryRow(`SELECT value FROM tx_test WHERE id = 1`).Scan(&value)
if err != nil {
t.Fatalf("Error querying after commit: %v", err)
}
if value != "before commit" {
t.Fatalf("Expected 'before commit', got '%s'", value)
}
// Test rollback
tx, err = db.Begin()
if err != nil {
t.Fatalf("Error beginning transaction: %v", err)
}
_, err = tx.Exec(`INSERT INTO tx_test VALUES (2, 'should rollback')`)
if err != nil {
t.Fatalf("Error executing in transaction: %v", err)
}
err = tx.Rollback()
if err != nil {
t.Fatalf("Error rolling back transaction: %v", err)
}
// Verify rollback worked
err = db.QueryRow(`SELECT value FROM tx_test WHERE id = 2`).Scan(&value)
if err == nil {
t.Fatalf("Expected error after rollback, got value: %s", value)
}
}
func TestSQLFeatures(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("Error opening connection: %v", err)
}
defer db.Close()
// Create test tables
_, err = db.Exec(`
CREATE TABLE customers (
id INTEGER PRIMARY KEY,
name TEXT,
age INTEGER
)`)
if err != nil {
t.Fatalf("Error creating customers table: %v", err)
}
_, err = db.Exec(`
CREATE TABLE orders (
id INTEGER PRIMARY KEY,
customer_id INTEGER,
amount REAL,
date TEXT
)`)
if err != nil {
t.Fatalf("Error creating orders table: %v", err)
}
// Insert test data
_, err = db.Exec(`
INSERT INTO customers VALUES
(1, 'Alice', 30),
(2, 'Bob', 25),
(3, 'Charlie', 40)`)
if err != nil {
t.Fatalf("Error inserting customers: %v", err)
}
_, err = db.Exec(`
INSERT INTO orders VALUES
(1, 1, 100.50, '2024-01-01'),
(2, 1, 200.75, '2024-02-01'),
(3, 2, 50.25, '2024-01-15'),
(4, 3, 300.00, '2024-02-10')`)
if err != nil {
t.Fatalf("Error inserting orders: %v", err)
}
// Test JOIN
rows, err := db.Query(`
SELECT c.name, o.amount
FROM customers c
INNER JOIN orders o ON c.id = o.customer_id
ORDER BY o.amount DESC`)
if err != nil {
t.Fatalf("Error executing JOIN: %v", err)
}
defer rows.Close()
// Check JOIN results
expectedResults := []struct {
name string
amount float64
}{
{"Charlie", 300.00},
{"Alice", 200.75},
{"Alice", 100.50},
{"Bob", 50.25},
}
i := 0
for rows.Next() {
var name string
var amount float64
if err := rows.Scan(&name, &amount); err != nil {
t.Fatalf("Error scanning JOIN result: %v", err)
}
if i >= len(expectedResults) {
t.Fatalf("Too many rows returned from JOIN")
}
if name != expectedResults[i].name || amount != expectedResults[i].amount {
t.Fatalf("Row %d: expected (%s, %.2f), got (%s, %.2f)",
i, expectedResults[i].name, expectedResults[i].amount, name, amount)
}
i++
}
// Test GROUP BY with aggregation
var count int
var total float64
err = db.QueryRow(`
SELECT COUNT(*), SUM(amount)
FROM orders
WHERE customer_id = 1
GROUP BY customer_id`).Scan(&count, &total)
if err != nil {
t.Fatalf("Error executing GROUP BY: %v", err)
}
if count != 2 || total != 301.25 {
t.Fatalf("GROUP BY gave wrong results: count=%d, total=%.2f", count, total)
}
}
func TestDateTimeFunctions(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("Error opening connection: %v", err)
}
defer db.Close()
// Test date()
var dateStr string
err = db.QueryRow(`SELECT date('now')`).Scan(&dateStr)
if err != nil {
t.Fatalf("Error with date() function: %v", err)
}
fmt.Printf("Current date: %s\n", dateStr)
// Test date arithmetic
err = db.QueryRow(`SELECT date('2024-01-01', '+1 month')`).Scan(&dateStr)
if err != nil {
t.Fatalf("Error with date arithmetic: %v", err)
}
if dateStr != "2024-02-01" {
t.Fatalf("Expected '2024-02-01', got '%s'", dateStr)
}
// Test strftime
var formatted string
err = db.QueryRow(`SELECT strftime('%Y-%m-%d', '2024-01-01')`).Scan(&formatted)
if err != nil {
t.Fatalf("Error with strftime function: %v", err)
}
if formatted != "2024-01-01" {
t.Fatalf("Expected '2024-01-01', got '%s'", formatted)
}
}
func TestMathFunctions(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("Error opening connection: %v", err)
}
defer db.Close()
// Test basic math functions
var result float64
err = db.QueryRow(`SELECT abs(-15.5)`).Scan(&result)
if err != nil {
t.Fatalf("Error with abs function: %v", err)
}
if result != 15.5 {
t.Fatalf("abs(-15.5) should be 15.5, got %f", result)
}
// Test trigonometric functions
err = db.QueryRow(`SELECT round(sin(radians(30)), 4)`).Scan(&result)
if err != nil {
t.Fatalf("Error with sin function: %v", err)
}
if math.Abs(result-0.5) > 0.0001 {
t.Fatalf("sin(30 degrees) should be about 0.5, got %f", result)
}
// Test power functions
err = db.QueryRow(`SELECT pow(2, 3)`).Scan(&result)
if err != nil {
t.Fatalf("Error with pow function: %v", err)
}
if result != 8 {
t.Fatalf("2^3 should be 8, got %f", result)
}
}
func TestJSONFunctions(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("Error opening connection: %v", err)
}
defer db.Close()
// Test json function
var valid int
err = db.QueryRow(`SELECT json_valid('{"name":"John","age":30}')`).Scan(&valid)
if err != nil {
t.Fatalf("Error with json_valid function: %v", err)
}
if valid != 1 {
t.Fatalf("Expected valid JSON to return 1, got %d", valid)
}
// Test json_extract
var name string
err = db.QueryRow(`SELECT json_extract('{"name":"John","age":30}', '$.name')`).Scan(&name)
if err != nil {
t.Fatalf("Error with json_extract function: %v", err)
}
if name != "John" {
t.Fatalf("Expected 'John', got '%s'", name)
}
// Test JSON shorthand
var age int
err = db.QueryRow(`SELECT '{"name":"John","age":30}' -> '$.age'`).Scan(&age)
if err != nil {
t.Fatalf("Error with JSON shorthand: %v", err)
}
if age != 30 {
t.Fatalf("Expected 30, got %d", age)
}
}
func slicesAreEq(a, b []byte) bool {
if len(a) != len(b) {
fmt.Printf("LENGTHS NOT EQUAL: %d != %d\n", len(a), len(b))