implemented cached walker

pippellia-btc
2025-06-03 15:35:59 +02:00
parent 7d82354540
commit c7b0d8ff94
7 changed files with 306 additions and 48 deletions

View File

@@ -3,6 +3,7 @@
package graph
import (
"errors"
"time"
)
@@ -17,6 +18,11 @@ const (
Demotion int = -1
)
var (
ErrNodeNotFound = errors.New("node not found")
ErrNodeAlreadyExists = errors.New("node already exists")
)
type ID string
func (id ID) MarshalBinary() ([]byte, error) { return []byte(id), nil }
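
With the two sentinel errors now defined once in package graph (the hunks below delete them from the Redis layer), callers can check them with errors.Is without importing a backend-specific error. A minimal standalone sketch of that pattern, where lookup is a hypothetical stand-in for any backend method (e.g. RedisDB.NodeByID) that wraps the shared sentinel:

package main

import (
	"errors"
	"fmt"

	"github/pippellia-btc/crawler/pkg/graph"
)

// lookup is a hypothetical stand-in for a backend fetch that wraps the
// shared sentinel error from package graph.
func lookup(id graph.ID) (*graph.Node, error) {
	return nil, fmt.Errorf("failed to fetch node %s: %w", id, graph.ErrNodeNotFound)
}

func main() {
	if _, err := lookup("42"); errors.Is(err, graph.ErrNodeNotFound) {
		fmt.Println("node 42 is not in the graph") // backend-agnostic check
	}
}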

View File

@@ -117,19 +117,27 @@ func Personalized(
return map[graph.ID]float64{source: 1.0}, nil
}
followByNode, err := loader.BulkFollows(ctx, follows)
bulk, err := loader.BulkFollows(ctx, follows)
if err != nil {
return nil, fmt.Errorf("Personalized: failed to fetch the two-hop network of source: %w", err)
}
walker := walks.NewCachedWalker(follows, followByNode, loader)
targetWalks := int(float64(targetLenght) * (1 - walks.Alpha))
walker := walks.NewWalker(
walks.WithCapacity(10000),
walks.WithFallback(loader),
)
walks, err := loader.WalksVisitingAny(ctx, append(follows, source), targetWalks)
if err := walker.Load(follows, bulk); err != nil {
return nil, fmt.Errorf("Personalized: failed to load the two-hop network of source: %w", err)
}
targetWalks := int(float64(targetLenght) * (1 - walks.Alpha))
visiting, err := loader.WalksVisitingAny(ctx, append(follows, source), targetWalks)
if err != nil {
return nil, fmt.Errorf("Personalized: failed to fetch the walk: %w", err)
}
pool := newWalkPool(walks)
pool := newWalkPool(visiting)
walk, err := personalizedWalk(ctx, walker, pool, source, targetLenght)
if err != nil {

View File

@@ -33,11 +33,6 @@ const (
NodeAddedTS = "added_TS" // TODO: change to addition
)
var (
ErrNodeNotFound = errors.New("node not found")
ErrNodeAlreadyExists = errors.New("node already exists")
)
type RedisDB struct {
client *redis.Client
}
@@ -89,7 +84,7 @@ func (r RedisDB) Nodes(ctx context.Context, IDs ...graph.ID) ([]*graph.Node, err
for i, cmd := range cmds {
fields := cmd.Val()
if len(fields) == 0 {
return nil, fmt.Errorf("failed to fetch %s: %w", node(IDs[i]), ErrNodeNotFound)
return nil, fmt.Errorf("failed to fetch %s: %w", node(IDs[i]), graph.ErrNodeNotFound)
}
nodes[i], err = parseNode(fields)
@@ -109,7 +104,7 @@ func (r RedisDB) NodeByID(ctx context.Context, ID graph.ID) (*graph.Node, error)
}
if len(fields) == 0 {
return nil, fmt.Errorf("failed to fetch %s: %w", node(ID), ErrNodeNotFound)
return nil, fmt.Errorf("failed to fetch %s: %w", node(ID), graph.ErrNodeNotFound)
}
return parseNode(fields)
@@ -128,7 +123,7 @@ func (r RedisDB) NodeByKey(ctx context.Context, pubkey string) (*graph.Node, err
}
if len(fields) == 0 {
return nil, fmt.Errorf("failed to fetch node with pubkey %s: %w", pubkey, ErrNodeNotFound)
return nil, fmt.Errorf("failed to fetch node with pubkey %s: %w", pubkey, graph.ErrNodeNotFound)
}
return parseNode(fields)
@@ -159,7 +154,7 @@ func (r RedisDB) ensureExists(ctx context.Context, IDs ...graph.ID) error {
}
if int(exists) < len(IDs) {
return ErrNodeNotFound
return graph.ErrNodeNotFound
}
return nil
@@ -173,7 +168,7 @@ func (r RedisDB) AddNode(ctx context.Context, pubkey string) (graph.ID, error) {
}
if exists {
return "", fmt.Errorf("failed to add node with pubkey %s: %w", pubkey, ErrNodeAlreadyExists)
return "", fmt.Errorf("failed to add node with pubkey %s: %w", pubkey, graph.ErrNodeAlreadyExists)
}
// get the ID outside the transaction, which implies there might be "holes",
@@ -212,12 +207,12 @@ func (r RedisDB) Demote(ctx context.Context, ID graph.ID) error {
return nil
}
// Follows returns the follow list of node. If node is not found, it returns [ErrNodeNotFound].
// Follows returns the follow list of node. If node is not found, it returns [graph.ErrNodeNotFound].
func (r RedisDB) Follows(ctx context.Context, node graph.ID) ([]graph.ID, error) {
return r.members(ctx, follows, node)
}
// Followers returns the list of followers of node. If node is not found, it returns [ErrNodeNotFound].
// Followers returns the list of followers of node. If node is not found, it returns [graph.ErrNodeNotFound].
func (r RedisDB) Followers(ctx context.Context, node graph.ID) ([]graph.ID, error) {
return r.members(ctx, followers, node)
}

View File

@@ -84,8 +84,8 @@ func TestAddNode(t *testing.T) {
}
defer db.flushAll()
if _, err = db.AddNode(ctx, "0"); !errors.Is(err, ErrNodeAlreadyExists) {
t.Fatalf("expected error %v, got %v", ErrNodeAlreadyExists, err)
if _, err = db.AddNode(ctx, "0"); !errors.Is(err, graph.ErrNodeAlreadyExists) {
t.Fatalf("expected error %v, got %v", graph.ErrNodeAlreadyExists, err)
}
})
@@ -135,13 +135,13 @@ func TestMembers(t *testing.T) {
name: "empty database",
setup: Empty,
node: "0",
err: ErrNodeNotFound,
err: graph.ErrNodeNotFound,
},
{
name: "node not found",
setup: OneNode,
node: "1",
err: ErrNodeNotFound,
err: graph.ErrNodeNotFound,
},
{
name: "dandling node",
@@ -189,13 +189,13 @@ func TestBulkMembers(t *testing.T) {
name: "empty database",
setup: Empty,
nodes: []graph.ID{"0"},
err: ErrNodeNotFound,
err: graph.ErrNodeNotFound,
},
{
name: "node not found",
setup: OneNode,
nodes: []graph.ID{"0", "1"},
err: ErrNodeNotFound,
err: graph.ErrNodeNotFound,
},
{
name: "dandling node",

View File

@@ -1,8 +1,11 @@
package walks
import (
"container/list"
"context"
"fmt"
"github/pippellia-btc/crawler/pkg/graph"
"log"
"strconv"
)
@@ -38,37 +41,163 @@ func NewCyclicWalker(n int) *SimpleWalker {
return &SimpleWalker{follows: follows}
}
// CachedWalker is a walker with optional fallback that stores follow relationships
// CachedWalker is a [Walker] with optional fallback that stores follow relationships
// in a compact format (uint32) for reduced memory footprint.
// If its size grows larger than capacity, the least recently used (LRU) key is evicted.
// It is not safe for concurrent use.
type CachedWalker struct {
follows map[graph.ID][]graph.ID
lookup map[uint32]*list.Element
// newest at the front, oldest at the back
edgeList *list.List
capacity int
// for stats
calls, hits, misses int
fallback Walker
}
func NewCachedWalker(nodes []graph.ID, follows [][]graph.ID, fallback Walker) *CachedWalker {
w := CachedWalker{
follows: make(map[graph.ID][]graph.ID, len(nodes)),
fallback: fallback,
type Option func(*CachedWalker)
func WithCapacity(cap int) Option { return func(c *CachedWalker) { c.capacity = cap } }
func WithFallback(f Walker) Option { return func(c *CachedWalker) { c.fallback = f } }
func NewWalker(opts ...Option) *CachedWalker {
c := &CachedWalker{
lookup: make(map[uint32]*list.Element, 10000),
edgeList: list.New(),
}
for i, node := range nodes {
w.follows[node] = follows[i]
for _, opt := range opts {
opt(c)
}
return &w
return c
}
func (w *CachedWalker) Follows(ctx context.Context, node graph.ID) ([]graph.ID, error) {
follows, exists := w.follows[node]
if !exists {
var err error
follows, err = w.fallback.Follows(ctx, node)
type edges struct {
node uint32
follows []uint32
}
// Add compresses node and follows and adds them to the cache.
// It evicts the LRU element if the capacity has been exceeded.
func (c *CachedWalker) Add(node graph.ID, follows []graph.ID) error {
ID, err := compactID(node)
if err != nil {
return fmt.Errorf("failed to compress node %s: %w", node, err)
}
IDs, err := compactIDs(follows)
if err != nil {
return fmt.Errorf("failed to compress follows of node %s: %w", node, err)
}
c.add(ID, IDs)
return nil
}
// add stores node and follows as edges. It evicts the LRU element if the capacity has been exceeded.
func (c *CachedWalker) add(node uint32, follows []uint32) {
c.lookup[node] = c.edgeList.PushFront(
edges{node: node, follows: follows},
)
if c.Size() > c.capacity {
oldest := c.edgeList.Back()
c.edgeList.Remove(oldest)
delete(c.lookup, oldest.Value.(edges).node)
}
}
func (c *CachedWalker) Size() int {
return c.edgeList.Len()
}
func (c *CachedWalker) logStats() {
log.Printf("cache: calls %d, hits %d, misses %d", c.calls, c.hits, c.misses)
c.calls, c.hits, c.misses = 0, 0, 0
}
func (c *CachedWalker) Follows(ctx context.Context, node graph.ID) ([]graph.ID, error) {
ID, err := compactID(node)
if err != nil {
return nil, fmt.Errorf("failed to fetch follows of %s: %w", node, err)
}
c.calls++
if c.calls > 10000 {
defer c.logStats()
}
element, hit := c.lookup[ID]
if hit {
c.hits++
c.edgeList.MoveToFront(element)
return nodes(element.Value.(edges).follows), nil
}
c.misses++
if c.fallback == nil {
return nil, fmt.Errorf("%w: %s", graph.ErrNodeNotFound, node)
}
follows, err := c.fallback.Follows(ctx, node)
if err != nil {
return nil, err
}
w.follows[node] = follows
IDs, err := compactIDs(follows)
if err != nil {
return nil, fmt.Errorf("failed to fetch follows of %s: %w", node, err)
}
c.add(ID, IDs)
return follows, nil
}
func (c *CachedWalker) Load(nodes []graph.ID, follows [][]graph.ID) error {
if len(nodes) != len(follows) {
return fmt.Errorf("failed to load: nodes and follows must have the same lenght")
}
for i, node := range nodes {
if err := c.Add(node, follows[i]); err != nil {
return fmt.Errorf("failed to load: %w", err)
}
}
return nil
}
func compactID(node graph.ID) (uint32, error) {
ID, err := strconv.ParseUint(string(node), 10, 32)
if err != nil {
return 0, err
}
return uint32(ID), nil
}
func compactIDs(nodes []graph.ID) ([]uint32, error) {
IDs := make([]uint32, len(nodes))
var err error
for i, node := range nodes {
IDs[i], err = compactID(node)
if err != nil {
return nil, err
}
}
return IDs, nil
}
func node(ID uint32) graph.ID {
return graph.ID(strconv.FormatUint(uint64(ID), 10))
}
func nodes(IDs []uint32) []graph.ID {
nodes := make([]graph.ID, len(IDs))
for i, ID := range IDs {
nodes[i] = node(ID)
}
return nodes
}
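
Putting the new walker.go pieces together, here is a hedged, standalone sketch of how the cached walker is meant to be used (the in-memory SimpleWalker from this package stands in for the Redis-backed loader that Personalized passes as fallback; the node IDs and the capacity of 2 are made up so the LRU eviction is visible):

package main

import (
	"context"
	"fmt"

	"github/pippellia-btc/crawler/pkg/graph"
	"github/pippellia-btc/crawler/pkg/walks"
)

func main() {
	ctx := context.Background()

	// fallback is consulted on cache misses; SimpleWalker stands in for the
	// Redis loader used in production.
	fallback := walks.NewSimpleWalker(map[graph.ID][]graph.ID{
		"3": {"0", "1"},
	})

	walker := walks.NewWalker(
		walks.WithCapacity(2), // tiny capacity, only to make eviction observable
		walks.WithFallback(fallback),
	)

	// Bulk-load the two-hop network, as Personalized does after BulkFollows.
	if err := walker.Load(
		[]graph.ID{"0", "1"},
		[][]graph.ID{{"1", "2"}, {"2"}},
	); err != nil {
		panic(err)
	}

	follows, _ := walker.Follows(ctx, "0") // cache hit, entry moved to the front
	fmt.Println(follows)                   // [1 2]

	follows, _ = walker.Follows(ctx, "3") // miss: fetched from fallback, then cached
	fmt.Println(follows)                  // [0 1]

	fmt.Println(walker.Size()) // still 2: the least recently used entry ("1") was evicted
}

Keying the cache by uint32 means each cached follow costs 4 bytes rather than a full graph.ID string header plus its data, which is where the reduced memory footprint claimed in the doc comment comes from.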

pkg/walks/walker_test.go (new file, 128 lines added)
View File

@@ -0,0 +1,128 @@
package walks
import (
"context"
"errors"
"fmt"
"github/pippellia-btc/crawler/pkg/graph"
"math/rand/v2"
"reflect"
"strconv"
"testing"
)
var ctx = context.Background()
func TestFollows(t *testing.T) {
tests := []struct {
name string
node graph.ID
fallback Walker
expected []graph.ID
err error
}{
{
name: "node not found, no fallback",
node: "69",
err: graph.ErrNodeNotFound,
},
{
name: "node not found, fallback",
node: "1",
fallback: NewSimpleWalker(map[graph.ID][]graph.ID{"1": {"2"}}),
expected: []graph.ID{"2"},
},
{
name: "node found",
node: "0",
expected: []graph.ID{"1"},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
walker := NewWalker(
WithCapacity(1),
WithFallback(test.fallback),
)
walker.add(0, []uint32{1})
follows, err := walker.Follows(ctx, test.node)
if !errors.Is(err, test.err) {
t.Fatalf("expected error %v, got %v", test.err, err)
}
if !reflect.DeepEqual(follows, test.expected) {
t.Fatalf("expected follows %v, got %v", test.expected, follows)
}
if walker.Size() != 1 {
t.Fatalf("failed to evict keys %d", walker.Size())
}
})
}
}
func BenchmarkAdd(b *testing.B) {
sizes := []int{1000, 10000, 100000}
for _, size := range sizes {
b.Run(fmt.Sprintf("size=%d", size), func(b *testing.B) {
walker := NewWalker(
WithCapacity(size),
)
for {
// fill-up the cached walker outside the benchmark
if walker.Size() >= size {
break
}
node := rand.Uint32N(1000000)
follows := randomCompactIDs(1000)
walker.add(node, follows)
}
node := rand.Uint32N(1000000)
follows := randomCompactIDs(1000)
b.ResetTimer()
for range b.N {
walker.add(node, follows)
}
})
}
}
func BenchmarkCompactIDs(b *testing.B) {
follows := randomFollows(1000)
b.ResetTimer()
for range b.N {
IDs, err := compactIDs(follows)
if err != nil {
b.Fatalf("failed to convert to compact IDs: %v", err)
}
nodes(IDs)
}
}
func randomFollows(size int) []graph.ID {
follows := make([]graph.ID, size)
for i := range size {
node := rand.IntN(10000000)
follows[i] = graph.ID(strconv.Itoa(node))
}
return follows
}
func randomCompactIDs(size int) []uint32 {
follows := make([]uint32, size)
for i := range size {
follows[i] = rand.Uint32N(1000000)
}
return follows
}
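
The new test and benchmarks can be run in isolation with the standard Go tooling, for example:

	go test ./pkg/walks -run TestFollows -bench 'BenchmarkAdd|BenchmarkCompactIDs' -benchmem

The optional -benchmem flag is worth adding here, since the compact uint32 encoding exercised by BenchmarkCompactIDs exists precisely to cut the per-entry memory footprint.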

View File

@@ -189,14 +189,6 @@ func TestFindCycle(t *testing.T) {
// t.Fatalf("Approx. memory used by map: %.2f MB\n", used)
// }
// func randomFollows(size int) []int {
// follows := make([]int, size)
// for i := range size {
// follows[i] = rand.Int()
// }
// return follows
// }
func BenchmarkFindCycle(b *testing.B) {
sizes := []int{10, 100, 1000}
for _, size := range sizes {