From c7b0d8ff94080c38a00b9337981f08d986645b78 Mon Sep 17 00:00:00 2001 From: pippellia-btc Date: Tue, 3 Jun 2025 15:35:59 +0200 Subject: [PATCH] implemented cached walker --- pkg/graph/graph.go | 6 ++ pkg/pagerank/pagerank.go | 18 +++-- pkg/redb/graph.go | 19 ++--- pkg/redb/graph_test.go | 12 +-- pkg/walks/walker.go | 163 +++++++++++++++++++++++++++++++++++---- pkg/walks/walker_test.go | 128 ++++++++++++++++++++++++++++++ pkg/walks/walks_test.go | 8 -- 7 files changed, 306 insertions(+), 48 deletions(-) create mode 100644 pkg/walks/walker_test.go diff --git a/pkg/graph/graph.go b/pkg/graph/graph.go index 7e7dadc..538d82e 100644 --- a/pkg/graph/graph.go +++ b/pkg/graph/graph.go @@ -3,6 +3,7 @@ package graph import ( + "errors" "time" ) @@ -17,6 +18,11 @@ const ( Demotion int = -1 ) +var ( + ErrNodeNotFound = errors.New("node not found") + ErrNodeAlreadyExists = errors.New("node already exists") +) + type ID string func (id ID) MarshalBinary() ([]byte, error) { return []byte(id), nil } diff --git a/pkg/pagerank/pagerank.go b/pkg/pagerank/pagerank.go index 23fa645..8ede934 100644 --- a/pkg/pagerank/pagerank.go +++ b/pkg/pagerank/pagerank.go @@ -117,19 +117,27 @@ func Personalized( return map[graph.ID]float64{source: 1.0}, nil } - followByNode, err := loader.BulkFollows(ctx, follows) + bulk, err := loader.BulkFollows(ctx, follows) if err != nil { return nil, fmt.Errorf("Personalized: failed to fetch the two-hop network of source: %w", err) } - walker := walks.NewCachedWalker(follows, followByNode, loader) - targetWalks := int(float64(targetLenght) * (1 - walks.Alpha)) + walker := walks.NewWalker( + walks.WithCapacity(10000), + walks.WithFallback(loader), + ) - walks, err := loader.WalksVisitingAny(ctx, append(follows, source), targetWalks) + if err := walker.Load(follows, bulk); err != nil { + return nil, fmt.Errorf("Personalized: failed to load the two-hop network of source: %w", err) + } + + targetWalks := int(float64(targetLenght) * (1 - walks.Alpha)) + visiting, err := loader.WalksVisitingAny(ctx, append(follows, source), targetWalks) if err != nil { return nil, fmt.Errorf("Personalized: failed to fetch the walk: %w", err) } - pool := newWalkPool(walks) + + pool := newWalkPool(visiting) walk, err := personalizedWalk(ctx, walker, pool, source, targetLenght) if err != nil { diff --git a/pkg/redb/graph.go b/pkg/redb/graph.go index 01ef805..db6f48e 100644 --- a/pkg/redb/graph.go +++ b/pkg/redb/graph.go @@ -33,11 +33,6 @@ const ( NodeAddedTS = "added_TS" // TODO: change to addition ) -var ( - ErrNodeNotFound = errors.New("node not found") - ErrNodeAlreadyExists = errors.New("node already exists") -) - type RedisDB struct { client *redis.Client } @@ -89,7 +84,7 @@ func (r RedisDB) Nodes(ctx context.Context, IDs ...graph.ID) ([]*graph.Node, err for i, cmd := range cmds { fields := cmd.Val() if len(fields) == 0 { - return nil, fmt.Errorf("failed to fetch %s: %w", node(IDs[i]), ErrNodeNotFound) + return nil, fmt.Errorf("failed to fetch %s: %w", node(IDs[i]), graph.ErrNodeNotFound) } nodes[i], err = parseNode(fields) @@ -109,7 +104,7 @@ func (r RedisDB) NodeByID(ctx context.Context, ID graph.ID) (*graph.Node, error) } if len(fields) == 0 { - return nil, fmt.Errorf("failed to fetch %s: %w", node(ID), ErrNodeNotFound) + return nil, fmt.Errorf("failed to fetch %s: %w", node(ID), graph.ErrNodeNotFound) } return parseNode(fields) @@ -128,7 +123,7 @@ func (r RedisDB) NodeByKey(ctx context.Context, pubkey string) (*graph.Node, err } if len(fields) == 0 { - return nil, fmt.Errorf("failed to fetch node with pubkey %s: %w", pubkey, ErrNodeNotFound) + return nil, fmt.Errorf("failed to fetch node with pubkey %s: %w", pubkey, graph.ErrNodeNotFound) } return parseNode(fields) @@ -159,7 +154,7 @@ func (r RedisDB) ensureExists(ctx context.Context, IDs ...graph.ID) error { } if int(exists) < len(IDs) { - return ErrNodeNotFound + return graph.ErrNodeNotFound } return nil @@ -173,7 +168,7 @@ func (r RedisDB) AddNode(ctx context.Context, pubkey string) (graph.ID, error) { } if exists { - return "", fmt.Errorf("failed to add node with pubkey %s: %w", pubkey, ErrNodeAlreadyExists) + return "", fmt.Errorf("failed to add node with pubkey %s: %w", pubkey, graph.ErrNodeAlreadyExists) } // get the ID outside the transaction, which implies there might be "holes", @@ -212,12 +207,12 @@ func (r RedisDB) Demote(ctx context.Context, ID graph.ID) error { return nil } -// Follows returns the follow list of node. If node is not found, it returns [ErrNodeNotFound]. +// Follows returns the follow list of node. If node is not found, it returns [graph.ErrNodeNotFound]. func (r RedisDB) Follows(ctx context.Context, node graph.ID) ([]graph.ID, error) { return r.members(ctx, follows, node) } -// Followers returns the list of followers of node. If node is not found, it returns [ErrNodeNotFound]. +// Followers returns the list of followers of node. If node is not found, it returns [graph.ErrNodeNotFound]. func (r RedisDB) Followers(ctx context.Context, node graph.ID) ([]graph.ID, error) { return r.members(ctx, followers, node) } diff --git a/pkg/redb/graph_test.go b/pkg/redb/graph_test.go index 365989d..464e319 100644 --- a/pkg/redb/graph_test.go +++ b/pkg/redb/graph_test.go @@ -84,8 +84,8 @@ func TestAddNode(t *testing.T) { } defer db.flushAll() - if _, err = db.AddNode(ctx, "0"); !errors.Is(err, ErrNodeAlreadyExists) { - t.Fatalf("expected error %v, got %v", ErrNodeAlreadyExists, err) + if _, err = db.AddNode(ctx, "0"); !errors.Is(err, graph.ErrNodeAlreadyExists) { + t.Fatalf("expected error %v, got %v", graph.ErrNodeAlreadyExists, err) } }) @@ -135,13 +135,13 @@ func TestMembers(t *testing.T) { name: "empty database", setup: Empty, node: "0", - err: ErrNodeNotFound, + err: graph.ErrNodeNotFound, }, { name: "node not found", setup: OneNode, node: "1", - err: ErrNodeNotFound, + err: graph.ErrNodeNotFound, }, { name: "dandling node", @@ -189,13 +189,13 @@ func TestBulkMembers(t *testing.T) { name: "empty database", setup: Empty, nodes: []graph.ID{"0"}, - err: ErrNodeNotFound, + err: graph.ErrNodeNotFound, }, { name: "node not found", setup: OneNode, nodes: []graph.ID{"0", "1"}, - err: ErrNodeNotFound, + err: graph.ErrNodeNotFound, }, { name: "dandling node", diff --git a/pkg/walks/walker.go b/pkg/walks/walker.go index 13c7145..80c05e2 100644 --- a/pkg/walks/walker.go +++ b/pkg/walks/walker.go @@ -1,8 +1,11 @@ package walks import ( + "container/list" "context" + "fmt" "github/pippellia-btc/crawler/pkg/graph" + "log" "strconv" ) @@ -38,37 +41,163 @@ func NewCyclicWalker(n int) *SimpleWalker { return &SimpleWalker{follows: follows} } -// CachedWalker is a walker with optional fallback that stores follow relationships +// CachedWalker is a [Walker] with optional fallback that stores follow relationships // in a compact format (uint32) for reduced memory footprint. +// If its size grows larger than capacity, the least recently used (LRU) key is evicted. +// It is not safe for concurrent use. type CachedWalker struct { - follows map[graph.ID][]graph.ID + lookup map[uint32]*list.Element + + // newest at the front, oldest at the back + edgeList *list.List + capacity int + + // for stats + calls, hits, misses int + fallback Walker } -func NewCachedWalker(nodes []graph.ID, follows [][]graph.ID, fallback Walker) *CachedWalker { - w := CachedWalker{ - follows: make(map[graph.ID][]graph.ID, len(nodes)), - fallback: fallback, +type Option func(*CachedWalker) + +func WithCapacity(cap int) Option { return func(c *CachedWalker) { c.capacity = cap } } +func WithFallback(f Walker) Option { return func(c *CachedWalker) { c.fallback = f } } + +func NewWalker(opts ...Option) *CachedWalker { + c := &CachedWalker{ + lookup: make(map[uint32]*list.Element, 10000), + edgeList: list.New(), + } + + for _, opt := range opts { + opt(c) + } + return c +} + +type edges struct { + node uint32 + follows []uint32 +} + +// Add compresses node and follows and adds them to the cache. +// It evicts the LRU element if the capacity has been exeeded. +func (c *CachedWalker) Add(node graph.ID, follows []graph.ID) error { + ID, err := compactID(node) + if err != nil { + return fmt.Errorf("failed to compress node %s: %w", node, err) + } + + IDs, err := compactIDs(follows) + if err != nil { + return fmt.Errorf("failed to compress follows of node %s: %w", node, err) + } + + c.add(ID, IDs) + return nil +} + +// Add node and follows as edges. It evicts the LRU element if the capacity has been exeeded. +func (c *CachedWalker) add(node uint32, follows []uint32) { + c.lookup[node] = c.edgeList.PushFront( + edges{node: node, follows: follows}, + ) + + if c.Size() > c.capacity { + oldest := c.edgeList.Back() + c.edgeList.Remove(oldest) + delete(c.lookup, oldest.Value.(edges).node) + } +} + +func (c *CachedWalker) Size() int { + return c.edgeList.Len() +} + +func (c *CachedWalker) logStats() { + log.Printf("cache: calls %d, hits %d, misses %d", c.calls, c.hits, c.misses) + c.calls, c.hits, c.misses = 0, 0, 0 +} + +func (c *CachedWalker) Follows(ctx context.Context, node graph.ID) ([]graph.ID, error) { + ID, err := compactID(node) + if err != nil { + return nil, fmt.Errorf("failed to fetch follows of %s: %w", node, err) + } + + c.calls++ + if c.calls > 10000 { + defer c.logStats() + } + + element, hit := c.lookup[ID] + if hit { + c.hits++ + c.edgeList.MoveToFront(element) + return nodes(element.Value.(edges).follows), nil + } + + c.misses++ + if c.fallback == nil { + return nil, fmt.Errorf("%w: %s", graph.ErrNodeNotFound, node) + } + + follows, err := c.fallback.Follows(ctx, node) + if err != nil { + return nil, err + } + + IDs, err := compactIDs(follows) + if err != nil { + return nil, fmt.Errorf("failed to fetch follows of %s: %w", node, err) + } + + c.add(ID, IDs) + return follows, nil +} + +func (c *CachedWalker) Load(nodes []graph.ID, follows [][]graph.ID) error { + if len(nodes) != len(follows) { + return fmt.Errorf("failed to load: nodes and follows must have the same lenght") } for i, node := range nodes { - w.follows[node] = follows[i] + if err := c.Add(node, follows[i]); err != nil { + return fmt.Errorf("failed to load: %w", err) + } } - - return &w + return nil } -func (w *CachedWalker) Follows(ctx context.Context, node graph.ID) ([]graph.ID, error) { - follows, exists := w.follows[node] - if !exists { - var err error - follows, err = w.fallback.Follows(ctx, node) +func compactID(node graph.ID) (uint32, error) { + ID, err := strconv.ParseUint(string(node), 10, 32) + if err != nil { + return 0, err + } + return uint32(ID), err +} + +func compactIDs(nodes []graph.ID) ([]uint32, error) { + IDs := make([]uint32, len(nodes)) + var err error + for i, node := range nodes { + IDs[i], err = compactID(node) if err != nil { return nil, err } - - w.follows[node] = follows } - return follows, nil + return IDs, nil +} + +func node(ID uint32) graph.ID { + return graph.ID(strconv.FormatUint(uint64(ID), 10)) +} + +func nodes(IDs []uint32) []graph.ID { + nodes := make([]graph.ID, len(IDs)) + for i, ID := range IDs { + nodes[i] = node(ID) + } + return nodes } diff --git a/pkg/walks/walker_test.go b/pkg/walks/walker_test.go new file mode 100644 index 0000000..90bfb0b --- /dev/null +++ b/pkg/walks/walker_test.go @@ -0,0 +1,128 @@ +package walks + +import ( + "context" + "errors" + "fmt" + "github/pippellia-btc/crawler/pkg/graph" + "math/rand/v2" + "reflect" + "strconv" + "testing" +) + +var ctx = context.Background() + +func TestFollows(t *testing.T) { + tests := []struct { + name string + node graph.ID + fallback Walker + expected []graph.ID + err error + }{ + { + name: "node not found, no fallback", + node: "69", + err: graph.ErrNodeNotFound, + }, + { + name: "node not found, fallback", + node: "1", + fallback: NewSimpleWalker(map[graph.ID][]graph.ID{"1": {"2"}}), + expected: []graph.ID{"2"}, + }, + { + name: "node found", + node: "0", + expected: []graph.ID{"1"}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + walker := NewWalker( + WithCapacity(1), + WithFallback(test.fallback), + ) + walker.add(0, []uint32{1}) + + follows, err := walker.Follows(ctx, test.node) + if !errors.Is(err, test.err) { + t.Fatalf("expected error %v, got %v", test.err, err) + } + + if !reflect.DeepEqual(follows, test.expected) { + t.Fatalf("expected follows %v, got %v", test.expected, follows) + } + + if walker.Size() != 1 { + t.Fatalf("failed to evict keys %d", walker.Size()) + } + }) + } +} + +func BenchmarkAdd(b *testing.B) { + sizes := []int{1000, 10000, 100000} + for _, size := range sizes { + b.Run(fmt.Sprintf("size=%d", size), func(b *testing.B) { + walker := NewWalker( + WithCapacity(size), + ) + + for { + // fill-up the cached walker outside the benchmark + if walker.Size() >= size { + break + } + + node := rand.Uint32N(1000000) + follows := randomCompactIDs(1000) + walker.add(node, follows) + } + + node := rand.Uint32N(1000000) + follows := randomCompactIDs(1000) + + b.ResetTimer() + for range b.N { + walker.add(node, follows) + } + }) + } + + for range b.N { + + } +} + +func BenchmarkCompactIDs(b *testing.B) { + follows := randomFollows(1000) + b.ResetTimer() + + for range b.N { + IDs, err := compactIDs(follows) + if err != nil { + b.Fatalf("failed to convert to compact IDs: %v", err) + } + nodes(IDs) + } +} + +func randomFollows(size int) []graph.ID { + follows := make([]graph.ID, size) + for i := range size { + node := rand.IntN(10000000) + follows[i] = graph.ID(strconv.Itoa(node)) + } + return follows +} + +func randomCompactIDs(size int) []uint32 { + follows := make([]uint32, size) + for i := range size { + follows[i] = rand.Uint32N(1000000) + } + return follows +} diff --git a/pkg/walks/walks_test.go b/pkg/walks/walks_test.go index 9ea7261..2237f88 100644 --- a/pkg/walks/walks_test.go +++ b/pkg/walks/walks_test.go @@ -189,14 +189,6 @@ func TestFindCycle(t *testing.T) { // t.Fatalf("Approx. memory used by map: %.2f MB\n", used) // } -// func randomFollows(size int) []int { -// follows := make([]int, size) -// for i := range size { -// follows[i] = rand.Int() -// } -// return follows -// } - func BenchmarkFindCycle(b *testing.B) { sizes := []int{10, 100, 1000} for _, size := range sizes {