mirror of
https://github.com/aljazceru/turso.git
synced 2026-02-02 23:04:23 +01:00
Add tests for vector operations and date/time functions in Go adapter
This commit is contained in:
@@ -4,13 +4,16 @@ import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
_ "github.com/tursodatabase/limbo"
|
||||
)
|
||||
|
||||
var conn *sql.DB
|
||||
var connErr error
|
||||
var (
|
||||
conn *sql.DB
|
||||
connErr error
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
conn, connErr = sql.Open("sqlite3", ":memory:")
|
||||
@@ -59,7 +62,7 @@ func TestQuery(t *testing.T) {
|
||||
t.Errorf("Expected column %d to be %s, got %s", i, expectedCols[i], col)
|
||||
}
|
||||
}
|
||||
var i = 1
|
||||
i := 1
|
||||
for rows.Next() {
|
||||
var a int
|
||||
var b string
|
||||
@@ -78,7 +81,6 @@ func TestQuery(t *testing.T) {
|
||||
if err = rows.Err(); err != nil {
|
||||
t.Fatalf("Row iteration error: %v", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestFunctions(t *testing.T) {
|
||||
@@ -280,6 +282,321 @@ func TestDriverRowsErrorMessages(t *testing.T) {
|
||||
t.Log("Rows error behavior test passed")
|
||||
}
|
||||
|
||||
func TestVectorOperations(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Error opening connection: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Test creating table with vector columns
|
||||
_, err = db.Exec(`CREATE TABLE vector_test (id INTEGER PRIMARY KEY, embedding F32_BLOB(64))`)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating vector table: %v", err)
|
||||
}
|
||||
|
||||
// Test vector insertion
|
||||
_, err = db.Exec(`INSERT INTO vector_test VALUES (1, vector('[0.1, 0.2, 0.3, 0.4, 0.5]'))`)
|
||||
if err != nil {
|
||||
t.Fatalf("Error inserting vector: %v", err)
|
||||
}
|
||||
|
||||
// Test vector similarity calculation
|
||||
var similarity float64
|
||||
err = db.QueryRow(`SELECT vector_distance_cos(embedding, vector('[0.2, 0.3, 0.4, 0.5, 0.6]')) FROM vector_test WHERE id = 1`).Scan(&similarity)
|
||||
if err != nil {
|
||||
t.Fatalf("Error calculating vector similarity: %v", err)
|
||||
}
|
||||
if similarity <= 0 || similarity > 1 {
|
||||
t.Fatalf("Expected similarity between 0 and 1, got %f", similarity)
|
||||
}
|
||||
|
||||
// Test vector extraction
|
||||
var extracted string
|
||||
err = db.QueryRow(`SELECT vector_extract(embedding) FROM vector_test WHERE id = 1`).Scan(&extracted)
|
||||
if err != nil {
|
||||
t.Fatalf("Error extracting vector: %v", err)
|
||||
}
|
||||
fmt.Printf("Extracted vector: %s\n", extracted)
|
||||
}
|
||||
|
||||
func TestTransactions(t *testing.T) {
|
||||
t.Skip("Skipping transaction tests - transactions not yet implemented in the limbo driver")
|
||||
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Error opening connection: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Create test table
|
||||
_, err = db.Exec(`CREATE TABLE tx_test (id INTEGER PRIMARY KEY, value TEXT)`)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating table: %v", err)
|
||||
}
|
||||
|
||||
// Test successful transaction
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Fatalf("Error beginning transaction: %v", err)
|
||||
}
|
||||
|
||||
_, err = tx.Exec(`INSERT INTO tx_test VALUES (1, 'before commit')`)
|
||||
if err != nil {
|
||||
t.Fatalf("Error executing in transaction: %v", err)
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
t.Fatalf("Error committing transaction: %v", err)
|
||||
}
|
||||
|
||||
// Verify commit worked
|
||||
var value string
|
||||
err = db.QueryRow(`SELECT value FROM tx_test WHERE id = 1`).Scan(&value)
|
||||
if err != nil {
|
||||
t.Fatalf("Error querying after commit: %v", err)
|
||||
}
|
||||
if value != "before commit" {
|
||||
t.Fatalf("Expected 'before commit', got '%s'", value)
|
||||
}
|
||||
|
||||
// Test rollback
|
||||
tx, err = db.Begin()
|
||||
if err != nil {
|
||||
t.Fatalf("Error beginning transaction: %v", err)
|
||||
}
|
||||
|
||||
_, err = tx.Exec(`INSERT INTO tx_test VALUES (2, 'should rollback')`)
|
||||
if err != nil {
|
||||
t.Fatalf("Error executing in transaction: %v", err)
|
||||
}
|
||||
|
||||
err = tx.Rollback()
|
||||
if err != nil {
|
||||
t.Fatalf("Error rolling back transaction: %v", err)
|
||||
}
|
||||
|
||||
// Verify rollback worked
|
||||
err = db.QueryRow(`SELECT value FROM tx_test WHERE id = 2`).Scan(&value)
|
||||
if err == nil {
|
||||
t.Fatalf("Expected error after rollback, got value: %s", value)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLFeatures(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Error opening connection: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Create test tables
|
||||
_, err = db.Exec(`
|
||||
CREATE TABLE customers (
|
||||
id INTEGER PRIMARY KEY,
|
||||
name TEXT,
|
||||
age INTEGER
|
||||
)`)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating customers table: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(`
|
||||
CREATE TABLE orders (
|
||||
id INTEGER PRIMARY KEY,
|
||||
customer_id INTEGER,
|
||||
amount REAL,
|
||||
date TEXT
|
||||
)`)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating orders table: %v", err)
|
||||
}
|
||||
|
||||
// Insert test data
|
||||
_, err = db.Exec(`
|
||||
INSERT INTO customers VALUES
|
||||
(1, 'Alice', 30),
|
||||
(2, 'Bob', 25),
|
||||
(3, 'Charlie', 40)`)
|
||||
if err != nil {
|
||||
t.Fatalf("Error inserting customers: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(`
|
||||
INSERT INTO orders VALUES
|
||||
(1, 1, 100.50, '2024-01-01'),
|
||||
(2, 1, 200.75, '2024-02-01'),
|
||||
(3, 2, 50.25, '2024-01-15'),
|
||||
(4, 3, 300.00, '2024-02-10')`)
|
||||
if err != nil {
|
||||
t.Fatalf("Error inserting orders: %v", err)
|
||||
}
|
||||
|
||||
// Test JOIN
|
||||
rows, err := db.Query(`
|
||||
SELECT c.name, o.amount
|
||||
FROM customers c
|
||||
INNER JOIN orders o ON c.id = o.customer_id
|
||||
ORDER BY o.amount DESC`)
|
||||
if err != nil {
|
||||
t.Fatalf("Error executing JOIN: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// Check JOIN results
|
||||
expectedResults := []struct {
|
||||
name string
|
||||
amount float64
|
||||
}{
|
||||
{"Charlie", 300.00},
|
||||
{"Alice", 200.75},
|
||||
{"Alice", 100.50},
|
||||
{"Bob", 50.25},
|
||||
}
|
||||
|
||||
i := 0
|
||||
for rows.Next() {
|
||||
var name string
|
||||
var amount float64
|
||||
if err := rows.Scan(&name, &amount); err != nil {
|
||||
t.Fatalf("Error scanning JOIN result: %v", err)
|
||||
}
|
||||
if i >= len(expectedResults) {
|
||||
t.Fatalf("Too many rows returned from JOIN")
|
||||
}
|
||||
if name != expectedResults[i].name || amount != expectedResults[i].amount {
|
||||
t.Fatalf("Row %d: expected (%s, %.2f), got (%s, %.2f)",
|
||||
i, expectedResults[i].name, expectedResults[i].amount, name, amount)
|
||||
}
|
||||
i++
|
||||
}
|
||||
|
||||
// Test GROUP BY with aggregation
|
||||
var count int
|
||||
var total float64
|
||||
err = db.QueryRow(`
|
||||
SELECT COUNT(*), SUM(amount)
|
||||
FROM orders
|
||||
WHERE customer_id = 1
|
||||
GROUP BY customer_id`).Scan(&count, &total)
|
||||
if err != nil {
|
||||
t.Fatalf("Error executing GROUP BY: %v", err)
|
||||
}
|
||||
if count != 2 || total != 301.25 {
|
||||
t.Fatalf("GROUP BY gave wrong results: count=%d, total=%.2f", count, total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDateTimeFunctions(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Error opening connection: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Test date()
|
||||
var dateStr string
|
||||
err = db.QueryRow(`SELECT date('now')`).Scan(&dateStr)
|
||||
if err != nil {
|
||||
t.Fatalf("Error with date() function: %v", err)
|
||||
}
|
||||
fmt.Printf("Current date: %s\n", dateStr)
|
||||
|
||||
// Test date arithmetic
|
||||
err = db.QueryRow(`SELECT date('2024-01-01', '+1 month')`).Scan(&dateStr)
|
||||
if err != nil {
|
||||
t.Fatalf("Error with date arithmetic: %v", err)
|
||||
}
|
||||
if dateStr != "2024-02-01" {
|
||||
t.Fatalf("Expected '2024-02-01', got '%s'", dateStr)
|
||||
}
|
||||
|
||||
// Test strftime
|
||||
var formatted string
|
||||
err = db.QueryRow(`SELECT strftime('%Y-%m-%d', '2024-01-01')`).Scan(&formatted)
|
||||
if err != nil {
|
||||
t.Fatalf("Error with strftime function: %v", err)
|
||||
}
|
||||
if formatted != "2024-01-01" {
|
||||
t.Fatalf("Expected '2024-01-01', got '%s'", formatted)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMathFunctions(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Error opening connection: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Test basic math functions
|
||||
var result float64
|
||||
err = db.QueryRow(`SELECT abs(-15.5)`).Scan(&result)
|
||||
if err != nil {
|
||||
t.Fatalf("Error with abs function: %v", err)
|
||||
}
|
||||
if result != 15.5 {
|
||||
t.Fatalf("abs(-15.5) should be 15.5, got %f", result)
|
||||
}
|
||||
|
||||
// Test trigonometric functions
|
||||
err = db.QueryRow(`SELECT round(sin(radians(30)), 4)`).Scan(&result)
|
||||
if err != nil {
|
||||
t.Fatalf("Error with sin function: %v", err)
|
||||
}
|
||||
if math.Abs(result-0.5) > 0.0001 {
|
||||
t.Fatalf("sin(30 degrees) should be about 0.5, got %f", result)
|
||||
}
|
||||
|
||||
// Test power functions
|
||||
err = db.QueryRow(`SELECT pow(2, 3)`).Scan(&result)
|
||||
if err != nil {
|
||||
t.Fatalf("Error with pow function: %v", err)
|
||||
}
|
||||
if result != 8 {
|
||||
t.Fatalf("2^3 should be 8, got %f", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONFunctions(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Error opening connection: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Test json function
|
||||
var valid int
|
||||
err = db.QueryRow(`SELECT json_valid('{"name":"John","age":30}')`).Scan(&valid)
|
||||
if err != nil {
|
||||
t.Fatalf("Error with json_valid function: %v", err)
|
||||
}
|
||||
if valid != 1 {
|
||||
t.Fatalf("Expected valid JSON to return 1, got %d", valid)
|
||||
}
|
||||
|
||||
// Test json_extract
|
||||
var name string
|
||||
err = db.QueryRow(`SELECT json_extract('{"name":"John","age":30}', '$.name')`).Scan(&name)
|
||||
if err != nil {
|
||||
t.Fatalf("Error with json_extract function: %v", err)
|
||||
}
|
||||
if name != "John" {
|
||||
t.Fatalf("Expected 'John', got '%s'", name)
|
||||
}
|
||||
|
||||
// Test JSON shorthand
|
||||
var age int
|
||||
err = db.QueryRow(`SELECT '{"name":"John","age":30}' -> '$.age'`).Scan(&age)
|
||||
if err != nil {
|
||||
t.Fatalf("Error with JSON shorthand: %v", err)
|
||||
}
|
||||
if age != 30 {
|
||||
t.Fatalf("Expected 30, got %d", age)
|
||||
}
|
||||
}
|
||||
|
||||
func slicesAreEq(a, b []byte) bool {
|
||||
if len(a) != len(b) {
|
||||
fmt.Printf("LENGTHS NOT EQUAL: %d != %d\n", len(a), len(b))
|
||||
|
||||
Reference in New Issue
Block a user