diff --git a/pkg/redb/graph.go b/pkg/redb/graph.go
index 0c2fc72..5bd1c73 100644
--- a/pkg/redb/graph.go
+++ b/pkg/redb/graph.go
@@ -15,20 +15,20 @@ import (

 const (
 	// redis variable names
-	KeyDatabase        string = "database"   // TODO: this can be removed
-	KeyLastNodeID      string = "lastNodeID" // TODO: change it to "next" inside "node" hash
-	KeyKeyIndex        string = "keyIndex"   // TODO: change to key_index
-	KeyNodePrefix      string = "node:"
-	KeyFollowsPrefix   string = "follows:"
-	KeyFollowersPrefix string = "followers:"
+	KeyDatabase        = "database"   // TODO: this can be removed
+	KeyLastNodeID      = "lastNodeID" // TODO: change it to "next" inside "node" hash
+	KeyKeyIndex        = "keyIndex"   // TODO: change to key_index
+	KeyNodePrefix      = "node:"
+	KeyFollowsPrefix   = "follows:"
+	KeyFollowersPrefix = "followers:"

 	// redis node HASH fields
-	NodeID          string = "id"
-	NodePubkey      string = "pubkey"
-	NodeStatus      string = "status"
-	NodePromotionTS string = "promotion_TS" // TODO: change to promotion
-	NodeDemotionTS  string = "demotion_TS"  // TODO: change to demotion
-	NodeAddedTS     string = "added_TS"     // TODO: change to addition
+	NodeID          = "id"
+	NodePubkey      = "pubkey"
+	NodeStatus      = "status"
+	NodePromotionTS = "promotion_TS" // TODO: change to promotion
+	NodeDemotionTS  = "demotion_TS"  // TODO: change to demotion
+	NodeAddedTS     = "added_TS"     // TODO: change to addition
 )

 var (
@@ -41,7 +41,14 @@ type RedisDB struct {
 }

+// New returns a RedisDB that uses a client created from opt.
+// It panics if validateWalks reports an error.
 func New(opt *redis.Options) RedisDB {
-	return RedisDB{client: redis.NewClient(opt)}
+	db := RedisDB{client: redis.NewClient(opt)}
+	err := db.validateWalks()
+	if err != nil {
+		panic(err)
+	}
+	return db
 }

 // Size returns the DBSize of redis, which is the total number of keys
@@ -178,7 +185,7 @@ func (r RedisDB) members(ctx context.Context, key func(graph.ID) string, node gr
 		}
 	}

-	return toIDs(members), nil
+	return toNodes(members), nil
 }

 // FollowCounts returns the number of follows each node has. If a node is not found, it returns 0.
diff --git a/pkg/redb/graph_test.go b/pkg/redb/graph_test.go
index fc99475..3d29c77 100644
--- a/pkg/redb/graph_test.go
+++ b/pkg/redb/graph_test.go
@@ -318,12 +318,17 @@ func TestPubkeys(t *testing.T) {
 // ------------------------------------- HELPERS -------------------------------

+// newTestDB returns a RedisDB connected to testAddress, built directly
+// from a redis client without going through New.
+func newTestDB() RedisDB {
+	return RedisDB{client: redis.NewClient(&redis.Options{Addr: testAddress})}
+}
+
 func Empty() (RedisDB, error) {
-	return New(&redis.Options{Addr: testAddress}), nil
+	return newTestDB(), nil
 }

 func OneNode() (RedisDB, error) {
-	db := New(&redis.Options{Addr: testAddress})
-	if _, err := db.AddNode(context.Background(), "0"); err != nil {
+	db := newTestDB()
+	if _, err := db.AddNode(ctx, "0"); err != nil {
 		db.flushAll()
 		return RedisDB{}, err
 	}
@@ -332,9 +337,7 @@ func OneNode() (RedisDB, error) {
 }

 func Simple() (RedisDB, error) {
-	ctx := context.Background()
-	db := New(&redis.Options{Addr: testAddress})
-
+	db := newTestDB()
 	for _, pk := range []string{"0", "1", "2"} {
 		if _, err := db.AddNode(ctx, pk); err != nil {
 			db.flushAll()
diff --git a/pkg/redb/utils.go b/pkg/redb/utils.go
index 74f4bc5..395cd0a 100644
--- a/pkg/redb/utils.go
+++ b/pkg/redb/utils.go
@@ -2,13 +2,19 @@ package redb

 import (
 	"context"
+	"errors"
 	"github/pippellia-btc/crawler/pkg/graph"
+	"github/pippellia-btc/crawler/pkg/walks"
 	"strconv"
+	"strings"
 	"time"
 )

 var (
 	testAddress = "localhost:6380"
+
+	ErrValueIsNil       = errors.New("value is nil")                 // the value to parse was nil
+	ErrValueIsNotString = errors.New("failed to convert to string") // the value to parse was not a string
 )

 // flushAll deletes all the keys of all existing databases. This command never fails.
@@ -28,8 +34,13 @@ func followers[ID string | graph.ID](id ID) string {
 	return KeyFollowersPrefix + string(id)
 }

-// ids converts a slice of strings to IDs
-func toIDs(s []string) []graph.ID {
+// walksVisiting returns the redis key of the set of walks visiting the node id
+func walksVisiting[ID string | graph.ID](id ID) string {
+	return KeyWalksVisitingPrefix + string(id)
+}
+
+// toNodes converts a slice of strings to node IDs
+func toNodes(s []string) []graph.ID {
 	IDs := make([]graph.ID, len(s))
 	for i, e := range s {
 		IDs[i] = graph.ID(e)
@@ -37,8 +48,17 @@ func toIDs(s []string) []graph.ID {
 	return IDs
 }

-// strings converts graph IDs to a slice of strings
-func toStrings(ids []graph.ID) []string {
+// toWalks converts a slice of strings to walk IDs
+func toWalks(s []string) []walks.ID {
+	IDs := make([]walks.ID, len(s))
+	for i, e := range s {
+		IDs[i] = walks.ID(e)
+	}
+	return IDs
+}
+
+// toStrings converts node or walk IDs to a slice of strings
+func toStrings[ID graph.ID | walks.ID](ids []ID) []string {
 	s := make([]string, len(ids))
 	for i, id := range ids {
 		s[i] = string(id)
@@ -98,3 +118,53 @@ func parseTimestamp(unix string) (time.Time, error) {
 	}
 	return time.Unix(ts, 0), nil
 }
+
+// formatWalk encodes the path of walk as a comma-separated list of node IDs.
+func formatWalk(walk walks.Walk) string {
+	return strings.Join(toStrings(walk.Path), ",")
+}
+
+// parseWalk decodes a path encoded by formatWalk. The walk ID is not set.
+// The empty string decodes to a walk with no nodes.
+func parseWalk(s string) walks.Walk {
+	if s == "" {
+		// strings.Split("", ",") returns [""], which would produce a walk
+		// visiting a single node with an empty ID instead of an empty walk
+		return walks.Walk{}
+	}
+
+	nodes := strings.Split(s, ",")
+	return walks.Walk{Path: toNodes(nodes)}
+}
+
+// parseString returns v as a string. It returns ErrValueIsNil if v is nil,
+// and ErrValueIsNotString if v holds a value of any other type.
+func parseString(v any) (string, error) {
+	if v == nil {
+		return "", ErrValueIsNil
+	}
+
+	str, ok := v.(string)
+	if !ok {
+		return "", ErrValueIsNotString
+	}
+	return str, nil
+}
+
+// parseFloat parses v, which must hold a string, as a 64-bit float.
+func parseFloat(v any) (float64, error) {
+	str, err := parseString(v)
+	if err != nil {
+		return 0, err
+	}
+	return strconv.ParseFloat(str, 64)
+}
+
+// parseInt parses v, which must hold a string, as a base-10 int.
+func parseInt(v any) (int, error) {
+	str, err := parseString(v)
+	if err != nil {
+		return 0, err
+	}
+	return strconv.Atoi(str)
+}
diff --git a/pkg/redb/walks.go b/pkg/redb/walks.go
index
9195bd4..5ea991a 100644
--- a/pkg/redb/walks.go
+++ b/pkg/redb/walks.go
@@ -1 +1,280 @@
 package redb
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"github/pippellia-btc/crawler/pkg/graph"
+	"github/pippellia-btc/crawler/pkg/walks"
+	"math"
+	"strconv"
+)
+
+const (
+	KeyRWS                 = "RWS"            // TODO: this can be removed
+	KeyAlpha               = "alpha"          // TODO: walks:alpha
+	KeyWalksPerNode        = "walksPerNode"   // TODO: walks:N or another
+	KeyLastWalkID          = "lastWalkID"     // TODO: walks:next
+	KeyTotalVisits         = "totalVisits"    // TODO: walks:total_visits
+	KeyWalks               = "walks"
+	KeyWalksVisitingPrefix = "walksVisiting:" // TODO: walks_visiting:
+)
+
+var (
+	ErrWalkNotFound       = errors.New("walk not found")
+	ErrInvalidReplacement = errors.New("invalid walk replacement")
+	ErrInvalidLimit       = errors.New("limit must be a positive integer, or -1 to fetch all walks visiting node")
+)
+
+// Walks returns the walks associated with the IDs, in the same order as the IDs.
+// It returns an error wrapping [ErrWalkNotFound] if any of the IDs is not stored.
+// More than 100000 IDs are fetched in several calls by recursively halving the input.
+func (r RedisDB) Walks(ctx context.Context, IDs ...walks.ID) ([]walks.Walk, error) {
+	switch {
+	case len(IDs) == 0:
+		return nil, nil
+
+	case len(IDs) <= 100000:
+		vals, err := r.client.HMGet(ctx, KeyWalks, toStrings(IDs)...).Result()
+		if err != nil {
+			return nil, fmt.Errorf("failed to fetch walks: %w", err)
+		}
+
+		// named result (not walks) to avoid shadowing the imported walks package
+		result := make([]walks.Walk, len(vals))
+		for i, val := range vals {
+			if val == nil {
+				// walk was not found, so return an error
+				return nil, fmt.Errorf("failed to fetch walk with ID %s: %w", IDs[i], ErrWalkNotFound)
+			}
+
+			path, ok := val.(string)
+			if !ok {
+				// checked assertion: an unexpected reply type becomes an error instead of a panic
+				return nil, fmt.Errorf("failed to fetch walk with ID %s: unexpected value type %T", IDs[i], val)
+			}
+
+			result[i] = parseWalk(path)
+			result[i].ID = IDs[i]
+		}
+
+		return result, nil
+
+	default:
+		// too many walks for a single call, so we split them in two batches
+		mid := len(IDs) / 2
+		batch1, err := r.Walks(ctx, IDs[:mid]...)
+		if err != nil {
+			return nil, err
+		}
+
+		batch2, err := r.Walks(ctx, IDs[mid:]...)
+		if err != nil {
+			return nil, err
+		}
+
+		return append(batch1, batch2...), nil
+	}
+}
+
+// WalksVisiting returns up-to limit walks that visit node.
+// Use limit = -1 to fetch all the walks visiting node.
+func (r RedisDB) WalksVisiting(ctx context.Context, node graph.ID, limit int) ([]walks.Walk, error) {
+	var (
+		key = walksVisiting(node)
+		IDs []string
+		err error
+	)
+
+	switch {
+	case limit == -1:
+		// no limit, so read every member of the set of walks visiting node
+		IDs, err = r.client.SMembers(ctx, key).Result()
+
+	case limit > 0:
+		// ask redis for up-to limit members of the set
+		IDs, err = r.client.SRandMemberN(ctx, key, int64(limit)).Result()
+
+	default:
+		return nil, ErrInvalidLimit
+	}
+
+	if err != nil {
+		return nil, fmt.Errorf("failed to fetch %s: %w", key, err)
+	}
+	return r.Walks(ctx, toWalks(IDs)...)
+}
+
+// AddWalks stores the walks in the database, giving them consecutive IDs
+// that follow the order in which they were passed.
+func (r RedisDB) AddWalks(ctx context.Context, walks ...walks.Walk) error {
+	if len(walks) == 0 {
+		return nil
+	}
+
+	// the IDs are reserved outside the transaction, which implies there might be "holes",
+	// meaning IDs not associated with any walk
+	last, err := r.client.HIncrBy(ctx, KeyRWS, KeyLastWalkID, int64(len(walks))).Result()
+	if err != nil {
+		return fmt.Errorf("failed to add walks: failed to increment ID: %w", err)
+	}
+
+	first := int(last) - len(walks) // ID of the first walk, the others follow in order
+	total := 0                      // sum of the lengths of the added walks
+	pipe := r.client.TxPipeline()
+
+	for i, walk := range walks {
+		ID := first + i
+		total += walk.Len()
+
+		pipe.HSet(ctx, KeyWalks, ID, formatWalk(walk))
+		for _, node := range walk.Path {
+			pipe.SAdd(ctx, walksVisiting(node), ID)
+		}
+	}
+
+	pipe.HIncrBy(ctx, KeyRWS, KeyTotalVisits, int64(total))
+
+	if _, err := pipe.Exec(ctx); err != nil {
+		return fmt.Errorf("failed to add walks: pipeline failed %w", err)
+	}
+	return nil
+}
+
+// RemoveWalks removes all the walks from the database.
+func (r RedisDB) RemoveWalks(ctx context.Context, walks ...walks.Walk) error {
+	if len(walks) == 0 {
+		return nil
+	}
+
+	var visits int
+	pipe := r.client.TxPipeline()
+
+	for _, walk := range walks {
+		pipe.HDel(ctx, KeyWalks, string(walk.ID))
+		for _, node := range walk.Path {
+			pipe.SRem(ctx, walksVisiting(node), string(walk.ID))
+		}
+
+		visits += walk.Len()
+	}
+
+	pipe.HIncrBy(ctx, KeyRWS, KeyTotalVisits, -int64(visits))
+
+	if _, err := pipe.Exec(ctx); err != nil {
+		return fmt.Errorf("failed to remove walks: pipeline failed %w", err)
+	}
+
+	return nil
+}
+
+// ReplaceWalks replaces each walk in before with the walk at the same index in after.
+// The two slices must pass validateReplacement: same length, matching IDs, no repeated ID.
+//
+// Pairs of equal walks are skipped. For the others the stored walk is overwritten,
+// and the sets of walks visiting each node and the total visits are updated only
+// for the nodes from the divergence index onwards.
+//
+// The update is not atomic as a whole: whenever the pipeline holds more than 5000
+// commands it is executed and a new one is started, so an error can leave the
+// earlier partial updates applied.
+func (r RedisDB) ReplaceWalks(ctx context.Context, before, after []walks.Walk) error {
+	if err := validateReplacement(before, after); err != nil {
+		return err
+	}
+
+	var visits int64
+	pipe := r.client.TxPipeline()
+
+	for i := range before {
+		div := walks.Divergence(before[i], after[i])
+		if div == -1 {
+			// the two walks are equal, skip
+			continue
+		}
+
+		prev := before[i]
+		next := after[i]
+		ID := string(after[i].ID)
+
+		pipe.HSet(ctx, KeyWalks, ID, formatWalk(next))
+
+		// removals are queued before additions, so a node present in both tails stays in its set
+		for _, node := range prev.Path[div:] {
+			pipe.SRem(ctx, walksVisiting(node), ID)
+			visits--
+		}
+
+		for _, node := range next.Path[div:] {
+			pipe.SAdd(ctx, walksVisiting(node), ID)
+			visits++
+		}
+
+		if pipe.Len() > 5000 {
+			// execute a partial update when it's too big
+			pipe.HIncrBy(ctx, KeyRWS, KeyTotalVisits, visits)
+
+			if _, err := pipe.Exec(ctx); err != nil {
+				return fmt.Errorf("failed to replace walks: pipeline failed %w", err)
+			}
+
+			pipe = r.client.TxPipeline()
+			visits = 0
+		}
+	}
+
+	pipe.HIncrBy(ctx, KeyRWS, KeyTotalVisits, visits)
+
+	if _, err := pipe.Exec(ctx); err != nil {
+		return fmt.Errorf("failed to replace walks: pipeline failed %w", err)
+	}
+
+	return nil
+}
+
+// validateReplacement returns an error wrapping [ErrInvalidReplacement] unless old and new
+// have the same length, the same ID at every index, and no repeated ID.
+func validateReplacement(old, new []walks.Walk) error {
+	if len(old) != len(new) {
+		return fmt.Errorf("%w: old and new walks must have the same length", ErrInvalidReplacement)
+	}
+
+	seen := make(map[walks.ID]struct{})
+	for i := range old {
+		if old[i].ID != new[i].ID {
+			return fmt.Errorf("%w: IDs don't match at index %d: old=%s, new=%s", ErrInvalidReplacement, i, old[i].ID, new[i].ID)
+		}
+
+		if _, ok := seen[old[i].ID]; ok {
+			return fmt.Errorf("%w: repeated walk ID %s", ErrInvalidReplacement, old[i].ID)
+		}
+
+		seen[old[i].ID] = struct{}{}
+	}
+
+	return nil
+}
+
+// TotalVisits returns the total number of visits, which is the sum of the lengths of all walks.
+// On error the returned count is -1.
+func (r RedisDB) TotalVisits(ctx context.Context) (int, error) {
+	raw, err := r.client.HGet(ctx, KeyRWS, KeyTotalVisits).Result()
+	if err != nil {
+		return -1, fmt.Errorf("failed to get the total number of visits: %w", err)
+	}
+
+	// redis returns the counter as a string
+	visits, err := strconv.Atoi(raw)
+	if err != nil {
+		return -1, fmt.Errorf("failed to parse the total number of visits: %w", err)
+	}
+	return visits, nil
+}
+
+// Visits returns the number of times each specified node was visited during the walks.
+// The returned slice contains counts in the same order as the input nodes.
+// If a node is not found, it returns 0 visits.
+func (r RedisDB) Visits(ctx context.Context, nodes ...graph.ID) ([]int, error) {
+	return r.counts(ctx, walksVisiting, nodes...)
+}
+
+// validateWalks checks that the alpha and walksPerNode stored in the RWS hash
+// match walks.Alpha (within 1e-10) and walks.N.
+func (r RedisDB) validateWalks() error {
+	fields, err := r.client.HMGet(context.Background(), KeyRWS, KeyAlpha, KeyWalksPerNode).Result()
+	if err != nil {
+		return fmt.Errorf("failed to fetch alpha and walksPerNode %w", err)
+	}
+
+	// HMGet returns the values in the same order as the requested fields
+	storedAlpha, err := parseFloat(fields[0])
+	if err != nil {
+		return fmt.Errorf("failed to parse alpha: %w", err)
+	}
+
+	storedN, err := parseInt(fields[1])
+	if err != nil {
+		return fmt.Errorf("failed to parse walksPerNode: %w", err)
+	}
+
+	switch {
+	case math.Abs(storedAlpha-walks.Alpha) > 1e-10:
+		return errors.New("alpha and walks.Alpha are different")
+
+	case storedN != walks.N:
+		return errors.New("N and walks.N are different")
+	}
+	return nil
+}
diff --git a/pkg/redb/walks_test.go b/pkg/redb/walks_test.go
new file mode 100644
index 0000000..5dbca56
--- /dev/null
+++ b/pkg/redb/walks_test.go
@@ -0,0 +1,317 @@
+package redb
+
+import (
+	"errors"
+	"github/pippellia-btc/crawler/pkg/graph"
+	"github/pippellia-btc/crawler/pkg/walks"
+	"reflect"
+	"strconv"
+	"testing"
+
+	"github.com/redis/go-redis/v9"
+)
+
+func TestValidate(t *testing.T) {
+	tests := []struct {
+		name  string
+		setup func() (RedisDB, error)
+		err   error
+	}{
+		{name: "empty", setup: Empty, err: ErrValueIsNil},
+		{name: "valid", setup: SomeWalks(0)},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			db, err := test.setup()
+			if err != nil {
+				t.Fatalf("setup failed: %v", err)
+			}
+			defer db.flushAll()
+
+			if err = db.validateWalks(); !errors.Is(err, test.err) {
+				t.Fatalf("expected error %v, got %v", test.err, err)
+			}
+		})
+	}
+}
+
+func TestWalksVisiting(t *testing.T) {
+	tests := []struct {
+		name          string
+		setup         func() (RedisDB, error)
+		limit         int
+		expectedWalks int // the number of [defaultWalk] returned
+	}{
+		{
+			name:          "empty",
+			setup:         SomeWalks(0),
+			limit:         1,
+			expectedWalks: 0,
+		},
+		{
+			name:          "all walks",
+			setup:         SomeWalks(10),
+			limit:         -1,
+			expectedWalks: 10,
+		},
+		{
+			name:          "some walks",
+			setup:         SomeWalks(100),
+			limit:         33,
+			expectedWalks:
33,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			db, err := test.setup()
+			if err != nil {
+				t.Fatalf("setup failed: %v", err)
+			}
+			defer db.flushAll()
+
+			visiting, err := db.WalksVisiting(ctx, "0", test.limit)
+			if err != nil {
+				t.Fatalf("expected error nil, got %v", err)
+			}
+
+			if len(visiting) != test.expectedWalks {
+				t.Fatalf("expected %d walks, got %d", test.expectedWalks, len(visiting))
+			}
+
+			for _, walk := range visiting {
+				if !reflect.DeepEqual(walk.Path, defaultWalk.Path) {
+					// compare only the paths, not the IDs
+					t.Fatalf("expected walk %v, got %v", defaultWalk, walk)
+				}
+			}
+		})
+	}
+}
+
+func TestAddWalks(t *testing.T) {
+	db, err := SomeWalks(1)()
+	if err != nil {
+		t.Fatalf("setup failed: %v", err)
+	}
+	defer db.flushAll()
+
+	// the walk stored by SomeWalks(1) is expected to take ID 0,
+	// so the walks added here should be assigned IDs 1, 2 and 3
+	expected := []walks.Walk{
+		{ID: "1", Path: []graph.ID{"1", "2", "3"}},
+		{ID: "2", Path: []graph.ID{"4", "5"}},
+		{ID: "3", Path: []graph.ID{"a", "b", "c"}},
+	}
+
+	if err := db.AddWalks(ctx, expected...); err != nil {
+		t.Fatalf("expected error nil, got %v", err)
+	}
+
+	stored, err := db.Walks(ctx, "1", "2", "3")
+	if err != nil {
+		t.Fatalf("expected error nil, got %v", err)
+	}
+
+	if !reflect.DeepEqual(stored, expected) {
+		t.Fatalf("expected walks %v, got %v", expected, stored)
+	}
+
+	// 2 visits from [defaultWalk], plus 3 + 2 + 3 from the walks added above
+	const expectedTotal = 10
+
+	total, err := db.TotalVisits(ctx)
+	if err != nil {
+		t.Fatalf("expected error nil, got %v", err)
+	}
+
+	if total != expectedTotal {
+		t.Fatalf("expected total visits %d, got %d", expectedTotal, total)
+	}
+}
+
+func TestRemoveWalks(t *testing.T) {
+	db, err := SomeWalks(10)()
+	if err != nil {
+		t.Fatalf("setup failed: %v", err)
+	}
+	defer db.flushAll()
+
+	walks := []walks.Walk{
+		{ID: "0", Path: defaultWalk.Path},
+		{ID: "1", Path: defaultWalk.Path},
+	}
+
+	if err := db.RemoveWalks(ctx, walks...); err != nil {
+		t.Fatalf("expected error nil, got %v", err)
+	}
+
+	total, err := db.TotalVisits(ctx)
+	if err != nil {
+		t.Fatalf("expected error nil, got %v", err)
+	}
+
+	expected := (10 - 2) * defaultWalk.Len()
+	if
total != expected {
+		t.Fatalf("expected total %d, got %d", expected, total)
+	}
+
+	visits, err := db.Visits(ctx, "0")
+	if err != nil {
+		t.Fatalf("expected error nil, got %v", err)
+	}
+
+	expected = (10 - 2)
+	if visits[0] != expected {
+		t.Fatalf("expected visits %d, got %d", expected, visits[0])
+	}
+}
+
+func TestReplaceWalks(t *testing.T) {
+	t.Run("simple", func(t *testing.T) {
+		db, err := SomeWalks(2)()
+		if err != nil {
+			t.Fatalf("setup failed: %v", err)
+		}
+		defer db.flushAll()
+
+		before := []walks.Walk{
+			{ID: "0", Path: []graph.ID{"0", "1"}},
+			{ID: "1", Path: []graph.ID{"0", "1"}},
+		}
+
+		after := []walks.Walk{
+			{ID: "0", Path: []graph.ID{"0", "2", "3"}}, // changed
+			{ID: "1", Path: []graph.ID{"0", "1"}},
+		}
+
+		if err := db.ReplaceWalks(ctx, before, after); err != nil {
+			t.Fatalf("expected error nil, got %v", err)
+		}
+
+		walks, err := db.Walks(ctx, "0", "1")
+		if err != nil {
+			t.Fatalf("expected error nil, got %v", err)
+		}
+
+		if !reflect.DeepEqual(walks, after) {
+			t.Fatalf("expected walks %v, got %v", after, walks)
+		}
+
+		expected := []int{2, 1, 1, 1}
+		visits, err := db.Visits(ctx, "0", "1", "2", "3")
+		if err != nil {
+			t.Fatalf("expected error nil, got %v", err)
+		}
+
+		if !reflect.DeepEqual(visits, expected) {
+			t.Fatalf("expected visits %v, got %v", expected, visits)
+		}
+
+		total, err := db.TotalVisits(ctx)
+		if err != nil {
+			t.Fatalf("expected error nil, got %v", err)
+		}
+
+		// the walks after the replacement have lengths 3 and 2
+		if total != 5 {
+			t.Fatalf("expected total %d, got %d", 5, total)
+		}
+	})
+
+	t.Run("mass removal", func(t *testing.T) {
+		num := 10000
+		db, err := SomeWalks(num)()
+		if err != nil {
+			t.Fatalf("setup failed: %v", err)
+		}
+		defer db.flushAll()
+
+		before := make([]walks.Walk, num)
+		after := make([]walks.Walk, num)
+		for i := range num {
+			ID := walks.ID(strconv.Itoa(i))
+			before[i] = walks.Walk{ID: ID, Path: defaultWalk.Path} // the walk in the DB
+			after[i] = walks.Walk{ID: ID}                          // empty walk
+		}
+
+		if err := db.ReplaceWalks(ctx, before, after); err
!= nil {
+			t.Fatalf("expected error nil, got %v", err)
+		}
+		visits, err := db.Visits(ctx, "0", "1")
+		if err != nil {
+			t.Fatalf("expected error nil, got %v", err)
+		}
+
+		if !reflect.DeepEqual(visits, []int{0, 0}) {
+			t.Fatalf("expected visits %v, got %v", []int{0, 0}, visits)
+		}
+
+		total, err := db.TotalVisits(ctx)
+		if err != nil {
+			t.Fatalf("expected error nil, got %v", err)
+		}
+
+		if total != 0 {
+			t.Fatalf("expected total %d, got %d", 0, total)
+		}
+	})
+}
+
+func TestValidateReplacement(t *testing.T) {
+	tests := []struct {
+		name string
+		old  []walks.Walk
+		new  []walks.Walk
+		err  error
+	}{
+		{
+			name: "no walks",
+		},
+		{
+			name: "different length",
+			old:  []walks.Walk{{}},
+			err:  ErrInvalidReplacement,
+		},
+		{
+			name: "different IDs",
+			old:  []walks.Walk{{ID: "0"}, {ID: "1"}},
+			new:  []walks.Walk{{ID: "1"}, {ID: "0"}},
+			err:  ErrInvalidReplacement,
+		},
+		{
+			name: "repeated IDs",
+			old:  []walks.Walk{{ID: "0"}, {ID: "0"}},
+			new:  []walks.Walk{{ID: "0"}, {ID: "0"}},
+			err:  ErrInvalidReplacement,
+		},
+		{
+			name: "valid IDs",
+			old:  []walks.Walk{{ID: "0"}, {ID: "1"}},
+			new:  []walks.Walk{{ID: "0"}, {ID: "1"}},
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			err := validateReplacement(test.old, test.new)
+			if !errors.Is(err, test.err) {
+				t.Fatalf("expected error %v, got %v", test.err, err)
+			}
+		})
+	}
+}
+
+// defaultWalk is the walk stored by SomeWalks; its ID is not set.
+var defaultWalk = walks.Walk{Path: []graph.ID{"0", "1"}}
+
+// SomeWalks returns a setup function that stores walks.Alpha and walks.N in the
+// RWS hash and then adds n copies of [defaultWalk], one AddWalks call each.
+// If any step fails the database is flushed before the error is returned.
+func SomeWalks(n int) func() (RedisDB, error) {
+	return func() (RedisDB, error) {
+		db := RedisDB{client: redis.NewClient(&redis.Options{Addr: testAddress})}
+		if err := db.client.HSet(ctx, KeyRWS, KeyAlpha, walks.Alpha, KeyWalksPerNode, walks.N).Err(); err != nil {
+			db.flushAll()
+			return RedisDB{}, err
+		}
+
+		for range n {
+			if err := db.AddWalks(ctx, defaultWalk); err != nil {
+				// the caller gets an empty RedisDB on error and cannot clean up,
+				// so flush here like OneNode and Simple do
+				db.flushAll()
+				return RedisDB{}, err
+			}
+		}
+
+		return db, nil
+	}
+}
diff --git a/pkg/walks/walks.go b/pkg/walks/walks.go
index 66f302d..0aea710 100644
--- a/pkg/walks/walks.go
+++
b/pkg/walks/walks.go
@@ -2,6 +2,7 @@ package walks

 import (
 	"context"
+	"errors"
 	"fmt"
 	"github/pippellia-btc/crawler/pkg/graph"
 	"math/rand/v2"
@@ -11,6 +12,8 @@ import (
 var (
 	Alpha = 0.85 // the dampening factor
 	N     = 100  // the walks per node
+
+	ErrInvalidRemoval = errors.New("the number of walks to be removed differs from the expected number") // static text: N can change after init
 )

 // ID represent how walks are identified in the storage layer
@@ -28,11 +31,6 @@ type Walker interface {
 	Follows(ctx context.Context, node graph.ID) ([]graph.ID, error)
 }

-// New returns a new walk with a preallocated empty path
-func New(n int) Walk {
-	return Walk{Path: make([]graph.ID, 0, n)}
-}
-
-// Len returns the lenght of the walk
+// Len returns the length of the walk
 func (w Walk) Len() int {
 	return len(w.Path)
@@ -80,6 +78,23 @@ func (w *Walk) Graft(path []graph.ID) {
 	w.Path = w.Path[:pos]
 }

+// Divergence returns the first index where w1 and w2 are different, -1 if equal.
+// If one walk is a prefix of the other, it returns the length of the shorter one.
+func Divergence(w1, w2 Walk) int {
+	shortest := min(w1.Len(), w2.Len()) // not named min, to keep the builtin usable
+	for i := range shortest {
+		if w1.Path[i] != w2.Path[i] {
+			return i
+		}
+	}
+
+	if w1.Len() == w2.Len() {
+		return -1
+	}
+
+	return shortest
+}
+
 // Generate [N] random walks for the specified node, using dampening factor [Alpha].
 // A walk stops early if a cycle is encountered.
 // Walk IDs are not set, because it's the responsibility of the storage layer.
@@ -146,19 +161,18 @@ func generate(ctx context.Context, walker Walker, start ...graph.ID) ([]graph.ID
 	return path, nil
 }

-// ToRemove returns the IDs of walks that needs to be removed.
+// ToRemove returns the walks that need to be removed.
 // It returns an error if the number of walks to remove differs from the expected [N].
-func ToRemove(node graph.ID, walks []Walk) ([]ID, error) {
-	toRemove := make([]ID, 0, N)
-
+func ToRemove(node graph.ID, walks []Walk) ([]Walk, error) {
+	toRemove := make([]Walk, 0, N)
 	for _, walk := range walks {
-		if walk.Index(node) != -1 {
-			toRemove = append(toRemove, walk.ID)
+		if walk.Index(node) == 0 {
+			toRemove = append(toRemove, walk)
 		}
 	}

 	if len(toRemove) != N {
-		return toRemove, fmt.Errorf("walks to be removed (%d) are different than expected (%d)", len(toRemove), N)
+		return nil, fmt.Errorf("ToRemove: %w: expected %d, got %d", ErrInvalidRemoval, N, len(toRemove))
 	}

 	return toRemove, nil
diff --git a/pkg/walks/walks_test.go b/pkg/walks/walks_test.go
index c1a2c61..860240b 100644
--- a/pkg/walks/walks_test.go
+++ b/pkg/walks/walks_test.go
@@ -2,6 +2,7 @@ package walks

 import (
 	"context"
+	"errors"
 	"fmt"
 	"github/pippellia-btc/crawler/pkg/graph"
 	"math"
@@ -53,6 +54,55 @@ func TestGenerate(t *testing.T) {
 	})
 }

+func TestToRemove(t *testing.T) {
+	// N is package-level state, so restore it to avoid leaking the override into other tests
+	defer func(old int) { N = old }(N)
+	N = 3
+	tests := []struct {
+		name     string
+		walks    []Walk
+		toRemove []Walk
+		err      error
+	}{
+		{
+			name: "no walks",
+			err:  ErrInvalidRemoval,
+		},
+		{
+			name:  "too few walks to remove",
+			walks: []Walk{{Path: []graph.ID{"0", "1"}}},
+			err:   ErrInvalidRemoval,
+		},
+		{
+			name: "valid",
+			walks: []Walk{
+				{Path: []graph.ID{"0", "1"}},
+				{Path: []graph.ID{"0", "2"}},
+				{Path: []graph.ID{"0", "3"}},
+				{Path: []graph.ID{"1", "0"}},
+			},
+			toRemove: []Walk{
+				{Path: []graph.ID{"0", "1"}},
+				{Path: []graph.ID{"0", "2"}},
+				{Path: []graph.ID{"0", "3"}},
+			},
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			toRemove, err := ToRemove("0", test.walks)
+			if !errors.Is(err, test.err) {
+				t.Fatalf("expected error %v, got %v", test.err, err)
+			}
+
+			if !reflect.DeepEqual(toRemove, test.toRemove) {
+				t.Fatalf("expected walks to remove %v, got %v", test.toRemove, toRemove)
+			}
+		})
+	}
+}
+
 func TestUpdateRemove(t *testing.T) {
 	walker := NewWalker(map[graph.ID][]graph.ID{
 		"0": {"3"},
@@ -85,6 +135,26 @@ func
TestUpdateRemove(t *testing.T) {
 	}
 }

+func TestDivergence(t *testing.T) {
+	cases := []struct {
+		w1       Walk
+		w2       Walk
+		expected int
+	}{
+		{w1: Walk{Path: []graph.ID{"0"}}, w2: Walk{Path: []graph.ID{"0", "1"}}, expected: 1},
+		{w1: Walk{Path: []graph.ID{"0", "1", "69"}}, w2: Walk{Path: []graph.ID{"0", "1"}}, expected: 2},
+		{w1: Walk{Path: []graph.ID{"0", "1", "69"}}, w2: Walk{Path: []graph.ID{"0", "1", "420"}}, expected: 2},
+		{w1: Walk{Path: []graph.ID{"a", "b", "c"}}, w2: Walk{Path: []graph.ID{"a", "b", "c"}}, expected: -1},
+		{expected: -1},
+	}
+
+	for idx, tc := range cases {
+		if got := Divergence(tc.w1, tc.w2); got != tc.expected {
+			t.Fatalf("test %d: expected %d, got %v", idx, tc.expected, got)
+		}
+	}
+}
+
 func TestFindCycle(t *testing.T) {
 	tests := []struct {
 		list []graph.ID