package redb import ( "errors" "reflect" "strconv" "testing" "github.com/vertex-lab/crawler_v2/pkg/graph" "github.com/vertex-lab/crawler_v2/pkg/walks" "github.com/redis/go-redis/v9" ) func TestInit(t *testing.T) { tests := []struct { name string setup func() (RedisDB, error) err error }{ {name: "seed", setup: Empty}, {name: "invalid", setup: Invalid, err: ErrInvalidWalkParameters}, {name: "valid", setup: SomeWalks(0)}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { db, err := test.setup() if err != nil { t.Fatalf("setup failed: %v", err) } defer db.flushAll() if err = db.init(); !errors.Is(err, test.err) { t.Fatalf("expected error %v, got %v", test.err, err) } }) } } func TestWalksVisiting(t *testing.T) { tests := []struct { name string setup func() (RedisDB, error) limit int expectedWalks int // the number of [defaultWalk] returned }{ { name: "empty", setup: SomeWalks(0), limit: 1, expectedWalks: 0, }, { name: "all walks", setup: SomeWalks(10), limit: -1, expectedWalks: 10, }, { name: "some walks", setup: SomeWalks(100), limit: 33, expectedWalks: 33, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { db, err := test.setup() if err != nil { t.Fatalf("setup failed: %v", err) } defer db.flushAll() visiting, err := db.WalksVisiting(ctx, "0", test.limit) if err != nil { t.Fatalf("expected error nil, got %v", err) } if len(visiting) != test.expectedWalks { t.Fatalf("expected %d walks, got %d", test.expectedWalks, len(visiting)) } for _, walk := range visiting { if !reflect.DeepEqual(walk.Path, defaultWalk.Path) { // compare only the paths, not the IDs t.Fatalf("expected walk %v, got %v", defaultWalk, walk) } } }) } } func TestWalksVisitingAny(t *testing.T) { tests := []struct { name string setup func() (RedisDB, error) limit int expectedWalks int // the number of [defaultWalk] returned }{ { name: "empty", setup: SomeWalks(0), limit: 1, expectedWalks: 0, }, { name: "all walks", setup: SomeWalks(10), limit: -1, expectedWalks: 10, }, { name: "some walks", setup: SomeWalks(100), limit: 20, expectedWalks: 20, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { db, err := test.setup() if err != nil { t.Fatalf("setup failed: %v", err) } defer db.flushAll() nodes := []graph.ID{"0", "1"} visiting, err := db.WalksVisitingAny(ctx, nodes, test.limit) if err != nil { t.Fatalf("expected error nil, got %v", err) } if len(visiting) > test.expectedWalks { t.Fatalf("expected %d walks, got %d", test.expectedWalks, len(visiting)) } for _, walk := range visiting { if !reflect.DeepEqual(walk.Path, defaultWalk.Path) { // compare only the paths, not the IDs t.Fatalf("expected walk %v, got %v", defaultWalk, walk) } } }) } } func TestAddWalks(t *testing.T) { db, err := SomeWalks(1)() if err != nil { t.Fatalf("setup failed: %v", err) } defer db.flushAll() walks := []walks.Walk{ {ID: "1", Path: []graph.ID{"1", "2", "3"}}, {ID: "2", Path: []graph.ID{"4", "5"}}, {ID: "3", Path: []graph.ID{"a", "b", "c"}}, } if err := db.AddWalks(ctx, walks...); err != nil { t.Fatalf("expected error nil, got %v", err) } stored, err := db.Walks(ctx, "1", "2", "3") if err != nil { t.Fatalf("expected error nil, got %v", err) } if !reflect.DeepEqual(stored, walks) { t.Fatalf("expected walks %v, got %v", walks, stored) } total, err := db.TotalVisits(ctx) if err != nil { t.Fatalf("expected error nil, got %v", err) } if total != 10 { t.Fatalf("expected total visits %d, got %d", 10, total) } } func TestRemoveWalks(t *testing.T) { db, err := SomeWalks(10)() if err != nil { t.Fatalf("setup failed: %v", err) } defer db.flushAll() walks := []walks.Walk{ {ID: "0", Path: defaultWalk.Path}, {ID: "1", Path: defaultWalk.Path}, } if err := db.RemoveWalks(ctx, walks...); err != nil { t.Fatalf("expected error nil, got %v", err) } total, err := db.TotalVisits(ctx) if err != nil { t.Fatalf("expected error nil, got %v", err) } expected := (10 - 2) * defaultWalk.Len() if total != expected { t.Fatalf("expected total %d, got %d", expected, total) } visits, err := db.Visits(ctx, "0") if err != nil { t.Fatalf("expected error nil, got %v", err) } expected = (10 - 2) if visits[0] != expected { t.Fatalf("expected visits %d, got %d", expected, visits[0]) } } func TestReplaceWalks(t *testing.T) { t.Run("simple", func(t *testing.T) { db, err := SomeWalks(2)() if err != nil { t.Fatalf("setup failed: %v", err) } defer db.flushAll() before := []walks.Walk{ {ID: "0", Path: []graph.ID{"0", "1"}}, {ID: "1", Path: []graph.ID{"0", "1"}}, } after := []walks.Walk{ {ID: "0", Path: []graph.ID{"0", "2", "3"}}, // changed {ID: "1", Path: []graph.ID{"0", "1"}}, } if err := db.ReplaceWalks(ctx, before, after); err != nil { t.Fatalf("expected error nil, got %v", err) } walks, err := db.Walks(ctx, "0", "1") if err != nil { t.Fatalf("expected error nil, got %v", err) } if !reflect.DeepEqual(walks, after) { t.Fatalf("expected walks %v, got %v", after, walks) } expected := []int{2, 1, 1, 1} visits, err := db.Visits(ctx, "0", "1", "2", "3") if err != nil { t.Fatalf("expected error nil, got %v", err) } if !reflect.DeepEqual(visits, expected) { t.Fatalf("expected visits %v, got %v", expected, visits) } total, err := db.TotalVisits(ctx) if err != nil { t.Fatalf("expected error nil, got %v", err) } if total != 5 { t.Fatalf("expected total %d, got %d", 8, total) } }) t.Run("mass removal", func(t *testing.T) { num := 10000 db, err := SomeWalks(num)() if err != nil { t.Fatalf("setup failed: %v", err) } defer db.flushAll() before := make([]walks.Walk, num) after := make([]walks.Walk, num) for i := range num { ID := walks.ID(strconv.Itoa(i)) before[i] = walks.Walk{ID: ID, Path: defaultWalk.Path} // the walk in the DB after[i] = walks.Walk{ID: ID} // empty walk } if err := db.ReplaceWalks(ctx, before, after); err != nil { t.Fatalf("expected error nil, got %v", err) } visits, err := db.Visits(ctx, "0", "1") if err != nil { t.Fatalf("expected error nil, got %v", err) } if !reflect.DeepEqual(visits, []int{0, 0}) { t.Fatalf("expected visits %v, got %v", []int{0, 0}, visits) } total, err := db.TotalVisits(ctx) if err != nil { t.Fatalf("expected error nil, got %v", err) } if total != 0 { t.Fatalf("expected total %d, got %d", 0, total) } }) } func TestValidateReplacement(t *testing.T) { tests := []struct { name string old []walks.Walk new []walks.Walk err error }{ { name: "no walks", }, { name: "different lenght", old: []walks.Walk{{}}, err: ErrInvalidReplacement, }, { name: "different IDs", old: []walks.Walk{{ID: "0"}, {ID: "1"}}, new: []walks.Walk{{ID: "1"}, {ID: "0"}}, err: ErrInvalidReplacement, }, { name: "repeated IDs", old: []walks.Walk{{ID: "0"}, {ID: "0"}}, new: []walks.Walk{{ID: "0"}, {ID: "0"}}, err: ErrInvalidReplacement, }, { name: "valid IDs", old: []walks.Walk{{ID: "0"}, {ID: "1"}}, new: []walks.Walk{{ID: "0"}, {ID: "1"}}, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { err := validateReplacement(test.old, test.new) if !errors.Is(err, test.err) { t.Fatalf("expected error %v, got %v", test.err, err) } }) } } var defaultWalk = walks.Walk{Path: []graph.ID{"0", "1"}} func SomeWalks(n int) func() (RedisDB, error) { return func() (RedisDB, error) { db := RedisDB{Client: redis.NewClient(&redis.Options{Addr: testAddress})} if err := db.Client.HSet(ctx, KeyRWS, KeyAlpha, walks.Alpha, KeyWalksPerNode, walks.N).Err(); err != nil { return RedisDB{}, err } for range n { if err := db.AddWalks(ctx, defaultWalk); err != nil { return RedisDB{}, err } } return db, nil } } func Invalid() (RedisDB, error) { db := RedisDB{Client: redis.NewClient(&redis.Options{Addr: testAddress})} if err := db.Client.HSet(ctx, KeyRWS, KeyAlpha, 69, KeyWalksPerNode, 420).Err(); err != nil { return RedisDB{}, err } return db, nil }