diff --git a/pkg/pagerank/cache.go b/pkg/pagerank/cache.go index d355911..d8e2397 100644 --- a/pkg/pagerank/cache.go +++ b/pkg/pagerank/cache.go @@ -11,11 +11,17 @@ type cachedWalker struct { fallback walks.Walker } -func newCachedWalker(followsMap map[graph.ID][]graph.ID, fallback walks.Walker) *cachedWalker { - return &cachedWalker{ - follows: followsMap, +func newCachedWalker(nodes []graph.ID, follows [][]graph.ID, fallback walks.Walker) *cachedWalker { + w := cachedWalker{ + follows: make(map[graph.ID][]graph.ID, len(nodes)), fallback: fallback, } + + for i, node := range nodes { + w.follows[node] = follows[i] + } + + return &w } func (w *cachedWalker) Follows(ctx context.Context, node graph.ID) ([]graph.ID, error) { diff --git a/pkg/pagerank/pagerank.go b/pkg/pagerank/pagerank.go index 28383c2..36f17fb 100644 --- a/pkg/pagerank/pagerank.go +++ b/pkg/pagerank/pagerank.go @@ -57,7 +57,7 @@ type PersonalizedLoader interface { Follows(ctx context.Context, node graph.ID) ([]graph.ID, error) // BulkFollows returns the follow-lists of the specified nodes - BulkFollows(ctx context.Context, nodes []graph.ID) (map[graph.ID][]graph.ID, error) + BulkFollows(ctx context.Context, nodes []graph.ID) ([][]graph.ID, error) // WalksVisitingAny returns up to limit walks that visit the specified nodes. 
// The walks are distributed evenly among the nodes: @@ -116,7 +116,7 @@ func Personalized( return map[graph.ID]float64{source: 1.0}, nil } - followMap, err := loader.BulkFollows(ctx, follows) + followByNode, err := loader.BulkFollows(ctx, follows) if err != nil { return nil, fmt.Errorf("Personalized: failed to fetch the two-hop network of source: %w", err) } @@ -127,7 +127,7 @@ func Personalized( return nil, fmt.Errorf("Personalized: failed to fetch the walk: %w", err) } - walker := newCachedWalker(followMap, loader) + walker := newCachedWalker(follows, followByNode, loader) pool := newWalkPool(walks) walk, err := personalizedWalk(ctx, walker, pool, source, targetLenght) diff --git a/pkg/redb/graph.go b/pkg/redb/graph.go index 5bd1c73..7c41d19 100644 --- a/pkg/redb/graph.go +++ b/pkg/redb/graph.go @@ -185,6 +185,72 @@ func (r RedisDB) members(ctx context.Context, key func(graph.ID) string, node gr return toNodes(members), nil } +// BulkFollows returns the follow-lists of all the provided nodes. +// Do not call on too many nodes (e.g. +100k) to avoid too many recursions. +func (r RedisDB) BulkFollows(ctx context.Context, nodes []graph.ID) ([][]graph.ID, error) { + return r.bulkMembers(ctx, follows, nodes) +} + +func (r RedisDB) bulkMembers(ctx context.Context, key func(graph.ID) string, nodes []graph.ID) ([][]graph.ID, error) { + switch { + case len(nodes) == 0: + return nil, nil + + case len(nodes) < 10000: + pipe := r.client.Pipeline() + cmds := make([]*redis.StringSliceCmd, len(nodes)) + + for i, node := range nodes { + cmds[i] = pipe.SMembers(ctx, key(node)) + } + + if _, err := pipe.Exec(ctx); err != nil { + return nil, fmt.Errorf("failed to fetch the %s of %d nodes: %w", key(""), len(nodes), err) + } + + var empty []string + members := make([][]graph.ID, len(nodes)) + + for i, cmd := range cmds { + m := cmd.Val() + if len(m) == 0 { + // empty slice might mean node not found. 
+ empty = append(empty, node(nodes[i])) + } + + members[i] = toNodes(m) + } + + if len(empty) > 0 { + exists, err := r.client.Exists(ctx, empty...).Result() + if err != nil { + return nil, err + } + + if int(exists) < len(empty) { + return nil, fmt.Errorf("failed to fetch the %s of these nodes %v: %w", key(""), empty, ErrNodeNotFound) + } + } + + return members, nil + + default: + // too many nodes, split them in two batches + mid := len(nodes) / 2 + batch1, err := r.bulkMembers(ctx, key, nodes[:mid]) + if err != nil { + return nil, err + } + + batch2, err := r.bulkMembers(ctx, key, nodes[mid:]) + if err != nil { + return nil, err + } + + return append(batch1, batch2...), nil + } +} + // FollowCounts returns the number of follows each node has. If a node is not found, it returns 0. func (r RedisDB) FollowCounts(ctx context.Context, nodes ...graph.ID) ([]int, error) { return r.counts(ctx, follows, nodes...) @@ -202,15 +268,13 @@ func (r RedisDB) counts(ctx context.Context, key func(graph.ID) string, nodes .. 
pipe := r.client.Pipeline() cmds := make([]*redis.IntCmd, len(nodes)) - keys := make([]string, len(nodes)) for i, node := range nodes { - keys[i] = key(node) cmds[i] = pipe.SCard(ctx, key(node)) } if _, err := pipe.Exec(ctx); err != nil { - return nil, fmt.Errorf("failed to count the elements of %v: %w", keys, err) + return nil, fmt.Errorf("failed to count the elements of %d nodes: %w", len(nodes), err) } counts := make([]int, len(nodes)) diff --git a/pkg/redb/graph_test.go b/pkg/redb/graph_test.go index 3d29c77..ee6109c 100644 --- a/pkg/redb/graph_test.go +++ b/pkg/redb/graph_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "github/pippellia-btc/crawler/pkg/graph" + "github/pippellia-btc/crawler/pkg/pagerank" "reflect" "testing" "time" @@ -175,6 +176,60 @@ func TestMembers(t *testing.T) { } } +func TestBulkMembers(t *testing.T) { + tests := []struct { + name string + setup func() (RedisDB, error) + nodes []graph.ID + expected [][]graph.ID + err error + }{ + { + name: "empty database", + setup: Empty, + nodes: []graph.ID{"0"}, + err: ErrNodeNotFound, + }, + { + name: "node not found", + setup: OneNode, + nodes: []graph.ID{"0", "1"}, + err: ErrNodeNotFound, + }, + { + name: "dangling node", + setup: OneNode, + nodes: []graph.ID{"0"}, + expected: [][]graph.ID{{}}, + }, + { + name: "valid", + setup: Simple, + nodes: []graph.ID{"0", "1"}, + expected: [][]graph.ID{{"1"}, {}}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + db, err := test.setup() + if err != nil { + t.Fatalf("setup failed: %v", err) + } + defer db.flushAll() + + follows, err := db.bulkMembers(ctx, follows, test.nodes) + if !errors.Is(err, test.err) { + t.Fatalf("expected error %v, got %v", test.err, err) + } + + if !reflect.DeepEqual(follows, test.expected) { + t.Errorf("expected follows %v, got %v", test.expected, follows) + } + }) + } +} + func TestUpdateFollows(t *testing.T) { db, err := Simple() if err != nil { @@ -315,6 +370,10 @@ func TestPubkeys(t
*testing.T) { } } +func TestInterfaces(t *testing.T) { + var _ pagerank.PersonalizedLoader = RedisDB{} +} + // ------------------------------------- HELPERS ------------------------------- func Empty() (RedisDB, error) { diff --git a/pkg/redb/walks.go b/pkg/redb/walks.go index 5ea991a..1b5776a 100644 --- a/pkg/redb/walks.go +++ b/pkg/redb/walks.go @@ -1,13 +1,17 @@ package redb import ( + "cmp" "context" "errors" "fmt" "github/pippellia-btc/crawler/pkg/graph" "github/pippellia-btc/crawler/pkg/walks" "math" + "slices" "strconv" + + "github.com/redis/go-redis/v9" ) const ( @@ -23,7 +27,7 @@ const ( var ( ErrWalkNotFound = errors.New("walk not found") ErrInvalidReplacement = errors.New("invalid walk replacement") - ErrInvalidLimit = errors.New("limit must be a positive integer, or -1 to fetch all walks visiting node") + ErrInvalidLimit = errors.New("limit must be a positive integer, or -1 to fetch all walks") ) // Walks returns the walks associated with the IDs. @@ -94,6 +98,65 @@ func (r RedisDB) WalksVisiting(ctx context.Context, node graph.ID, limit int) ([ } } +// WalksVisitingAny returns up to limit walks that visit the specified nodes. +// The walks are distributed evenly among the nodes: +// - if limit == -1, all walks are returned (use with few nodes) +// - if limit < len(nodes), no walks are returned +func (r RedisDB) WalksVisitingAny(ctx context.Context, nodes []graph.ID, limit int) ([]walks.Walk, error) { + switch { + case limit == -1: + // return all walks visiting all nodes + pipe := r.client.Pipeline() + cmds := make([]*redis.StringSliceCmd, len(nodes)) + + for i, node := range nodes { + cmds[i] = pipe.SMembers(ctx, walksVisiting(node)) + } + + if _, err := pipe.Exec(ctx); err != nil { + return nil, fmt.Errorf("failed to fetch all walks visiting %d nodes: %w", len(nodes), err) + } + + IDs := make([]string, 0, walks.N*len(nodes)) + for _, cmd := range cmds { + IDs = append(IDs, cmd.Val()...) 
+ } + + unique := unique(IDs) + return r.Walks(ctx, toWalks(unique)...) + + case limit > 0: + // return limit walks uniformly distributed across all nodes + nodeLimit := int64(limit / len(nodes)) + if nodeLimit == 0 { + return nil, nil + } + + pipe := r.client.Pipeline() + cmds := make([]*redis.StringSliceCmd, len(nodes)) + + for i, node := range nodes { + cmds[i] = pipe.SRandMemberN(ctx, walksVisiting(node), nodeLimit) + } + + if _, err := pipe.Exec(ctx); err != nil { + return nil, fmt.Errorf("failed to fetch %d walks visiting %d nodes: %w", limit, len(nodes), err) + } + + IDs := make([]string, 0, limit) + for _, cmd := range cmds { + IDs = append(IDs, cmd.Val()...) + } + + unique := unique(IDs) + return r.Walks(ctx, toWalks(unique)...) + + default: + // invalid limit + return nil, fmt.Errorf("failed to fetch walks visiting any: %w", ErrInvalidLimit) + } +} + // AddWalks adds all the walks to the database assigning them progressive IDs. func (r RedisDB) AddWalks(ctx context.Context, walks ...walks.Walk) error { if len(walks) == 0 { @@ -278,3 +341,22 @@ func (r RedisDB) validateWalks() error { return nil } + +// unique returns a slice of unique elements of the input slice.
+func unique[S ~[]E, E cmp.Ordered](slice S) S { + if len(slice) == 0 { + return nil + } + + slices.Sort(slice) + unique := make(S, 0, len(slice)) + unique = append(unique, slice[0]) + + for i := 1; i < len(slice); i++ { + if slice[i] != slice[i-1] { + unique = append(unique, slice[i]) + } + } + + return unique +} diff --git a/pkg/redb/walks_test.go b/pkg/redb/walks_test.go index 5dbca56..06ebe02 100644 --- a/pkg/redb/walks_test.go +++ b/pkg/redb/walks_test.go @@ -90,6 +90,61 @@ func TestWalksVisiting(t *testing.T) { } } +func TestWalksVisitingAny(t *testing.T) { + tests := []struct { + name string + setup func() (RedisDB, error) + limit int + expectedWalks int // the number of [defaultWalk] returned + }{ + { + name: "empty", + setup: SomeWalks(0), + limit: 1, + expectedWalks: 0, + }, + { + name: "all walks", + setup: SomeWalks(10), + limit: -1, + expectedWalks: 10, + }, + { + name: "some walks", + setup: SomeWalks(100), + limit: 20, + expectedWalks: 20, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + db, err := test.setup() + if err != nil { + t.Fatalf("setup failed: %v", err) + } + defer db.flushAll() + + nodes := []graph.ID{"0", "1"} + visiting, err := db.WalksVisitingAny(ctx, nodes, test.limit) + if err != nil { + t.Fatalf("expected error nil, got %v", err) + } + + if len(visiting) > test.expectedWalks { + t.Fatalf("expected %d walks, got %d", test.expectedWalks, len(visiting)) + } + + for _, walk := range visiting { + if !reflect.DeepEqual(walk.Path, defaultWalk.Path) { + // compare only the paths, not the IDs + t.Fatalf("expected walk %v, got %v", defaultWalk, walk) + } + } + }) + } +} + func TestAddWalks(t *testing.T) { db, err := SomeWalks(1)() if err != nil { @@ -297,6 +352,38 @@ func TestValidateReplacement(t *testing.T) { } } +func TestUnique(t *testing.T) { + tests := []struct { + slice []walks.ID + expected []walks.ID + }{ + {slice: nil, expected: nil}, + {slice: []walks.ID{}, expected: nil}, + {slice: 
[]walks.ID{"1", "2", "0"}, expected: []walks.ID{"0", "1", "2"}}, + {slice: []walks.ID{"1", "2", "0", "3", "1", "0"}, expected: []walks.ID{"0", "1", "2", "3"}}, + } + + for _, test := range tests { + unique := unique(test.slice) + if !reflect.DeepEqual(unique, test.expected) { + t.Errorf("expected %v, got %v", test.expected, unique) + } + } +} + +func BenchmarkUnique(b *testing.B) { + size := 1000000 + IDs := make([]walks.ID, size) + for i := 0; i < size; i++ { + IDs[i] = walks.ID(strconv.Itoa(i)) + } + + b.ResetTimer() + for range b.N { + unique(IDs) + } +} + var defaultWalk = walks.Walk{Path: []graph.ID{"0", "1"}} func SomeWalks(n int) func() (RedisDB, error) { diff --git a/tests/random/store_test.go b/tests/random/store_test.go index b894561..ebb6029 100644 --- a/tests/random/store_test.go +++ b/tests/random/store_test.go @@ -99,18 +99,18 @@ func (l *mockLoader) Follows(ctx context.Context, node graph.ID) ([]graph.ID, er return l.walker.Follows(ctx, node) } -func (l *mockLoader) BulkFollows(ctx context.Context, nodes []graph.ID) (map[graph.ID][]graph.ID, error) { - followsMap := make(map[graph.ID][]graph.ID, len(nodes)) - for _, node := range nodes { - follows, err := l.walker.Follows(ctx, node) +func (l *mockLoader) BulkFollows(ctx context.Context, nodes []graph.ID) ([][]graph.ID, error) { + var err error + follows := make([][]graph.ID, len(nodes)) + + for i, node := range nodes { + follows[i], err = l.walker.Follows(ctx, node) if err != nil { return nil, err } - - followsMap[node] = follows } - return followsMap, nil + return follows, nil } func (l *mockLoader) AddWalks(w []walks.Walk) {