completed redb package

This commit is contained in:
pippellia-btc
2025-05-29 12:53:08 +02:00
parent 921b21d6c1
commit 28be0dbfbd
7 changed files with 315 additions and 17 deletions

View File

@@ -11,11 +11,17 @@ type cachedWalker struct {
fallback walks.Walker
}
func newCachedWalker(followsMap map[graph.ID][]graph.ID, fallback walks.Walker) *cachedWalker {
return &cachedWalker{
follows: followsMap,
// newCachedWalker returns a cachedWalker whose cache associates nodes[i] with
// follows[i], and that keeps fallback as its fallback walker.
//
// follows is indexed positionally for every element of nodes, so it must hold
// at least len(nodes) entries; a shorter slice causes an index-out-of-range panic.
func newCachedWalker(nodes []graph.ID, follows [][]graph.ID, fallback walks.Walker) *cachedWalker {
	cache := make(map[graph.ID][]graph.ID, len(nodes))
	for i := range nodes {
		cache[nodes[i]] = follows[i]
	}

	return &cachedWalker{
		follows:  cache,
		fallback: fallback,
	}
}
func (w *cachedWalker) Follows(ctx context.Context, node graph.ID) ([]graph.ID, error) {

View File

@@ -57,7 +57,7 @@ type PersonalizedLoader interface {
Follows(ctx context.Context, node graph.ID) ([]graph.ID, error)
// BulkFollows returns the follow-lists of the specified nodes
BulkFollows(ctx context.Context, nodes []graph.ID) (map[graph.ID][]graph.ID, error)
BulkFollows(ctx context.Context, nodes []graph.ID) ([][]graph.ID, error)
// WalksVisitingAny returns up to limit walks that visit the specified nodes.
// The walks are distributed evenly among the nodes:
@@ -116,7 +116,7 @@ func Personalized(
return map[graph.ID]float64{source: 1.0}, nil
}
followMap, err := loader.BulkFollows(ctx, follows)
followByNode, err := loader.BulkFollows(ctx, follows)
if err != nil {
return nil, fmt.Errorf("Personalized: failed to fetch the two-hop network of source: %w", err)
}
@@ -127,7 +127,7 @@ func Personalized(
return nil, fmt.Errorf("Personalized: failed to fetch the walk: %w", err)
}
walker := newCachedWalker(followMap, loader)
walker := newCachedWalker(follows, followByNode, loader)
pool := newWalkPool(walks)
walk, err := personalizedWalk(ctx, walker, pool, source, targetLenght)

View File

@@ -185,6 +185,72 @@ func (r RedisDB) members(ctx context.Context, key func(graph.ID) string, node gr
return toNodes(members), nil
}
// BulkFollows returns the follow-lists of all the provided nodes, by delegating
// to bulkMembers with the follows key. The result is positional: the i-th
// follow-list belongs to nodes[i]. It returns nil when nodes is empty, and an
// error wrapping ErrNodeNotFound when one of the nodes does not exist.
//
// Large inputs (10000 nodes or more) are split recursively into halves by
// bulkMembers, so avoid calling it on very large inputs (e.g. +100k nodes).
func (r RedisDB) BulkFollows(ctx context.Context, nodes []graph.ID) ([][]graph.ID, error) {
return r.bulkMembers(ctx, follows, nodes)
}
// bulkMembers returns, for each of the provided nodes, the members of the set
// stored at key(node). The result is positional: members[i] belongs to nodes[i].
//
// Behaviour by input size:
//   - no nodes: returns nil, nil
//   - fewer than 10000 nodes: all sets are fetched with a single pipeline
//   - otherwise: the nodes are split in two halves that are fetched
//     recursively, and the two results are concatenated in order
//
// An empty set might mean that the node was not found, so every node whose set
// came back empty is checked for existence (using the key built by the node
// helper). If at least one of them does not exist, an error wrapping
// ErrNodeNotFound is returned.
func (r RedisDB) bulkMembers(ctx context.Context, key func(graph.ID) string, nodes []graph.ID) ([][]graph.ID, error) {
	switch {
	case len(nodes) == 0:
		return nil, nil

	case len(nodes) < 10000:
		pipe := r.client.Pipeline()
		cmds := make([]*redis.StringSliceCmd, len(nodes))
		for i, node := range nodes {
			cmds[i] = pipe.SMembers(ctx, key(node))
		}

		if _, err := pipe.Exec(ctx); err != nil {
			return nil, fmt.Errorf("failed to fetch the %s of %d nodes: %w", key(""), len(nodes), err)
		}

		// keys of the nodes whose set came back empty
		var emptyKeys []string
		members := make([][]graph.ID, len(nodes))
		for i, cmd := range cmds {
			m := cmd.Val()
			if len(m) == 0 {
				// empty slice might mean node not found.
				emptyKeys = append(emptyKeys, node(nodes[i]))
			}

			members[i] = toNodes(m)
		}

		if len(emptyKeys) > 0 {
			exists, err := r.client.Exists(ctx, emptyKeys...).Result()
			if err != nil {
				// wrap with context, consistently with the other error paths
				return nil, fmt.Errorf("failed to check the existence of %d nodes: %w", len(emptyKeys), err)
			}

			if int(exists) < len(emptyKeys) {
				return nil, fmt.Errorf("failed to fetch the %s of these nodes %v: %w", key(""), emptyKeys, ErrNodeNotFound)
			}
		}

		return members, nil

	default:
		// too many nodes, split them in two batches
		mid := len(nodes) / 2
		batch1, err := r.bulkMembers(ctx, key, nodes[:mid])
		if err != nil {
			return nil, err
		}

		batch2, err := r.bulkMembers(ctx, key, nodes[mid:])
		if err != nil {
			return nil, err
		}

		return append(batch1, batch2...), nil
	}
}
// FollowCounts returns the number of follows each node has. If a node is not found, it returns 0.
func (r RedisDB) FollowCounts(ctx context.Context, nodes ...graph.ID) ([]int, error) {
return r.counts(ctx, follows, nodes...)
@@ -202,15 +268,13 @@ func (r RedisDB) counts(ctx context.Context, key func(graph.ID) string, nodes ..
pipe := r.client.Pipeline()
cmds := make([]*redis.IntCmd, len(nodes))
keys := make([]string, len(nodes))
for i, node := range nodes {
keys[i] = key(node)
cmds[i] = pipe.SCard(ctx, key(node))
}
if _, err := pipe.Exec(ctx); err != nil {
return nil, fmt.Errorf("failed to count the elements of %v: %w", keys, err)
return nil, fmt.Errorf("failed to count the elements of %d nodes: %w", len(nodes), err)
}
counts := make([]int, len(nodes))

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"github/pippellia-btc/crawler/pkg/graph"
"github/pippellia-btc/crawler/pkg/pagerank"
"reflect"
"testing"
"time"
@@ -175,6 +176,60 @@ func TestMembers(t *testing.T) {
}
}
// TestBulkMembers runs bulkMembers with the follows key against several
// database fixtures, checking both the returned error and the member lists.
func TestBulkMembers(t *testing.T) {
	type testCase struct {
		name     string
		setup    func() (RedisDB, error)
		nodes    []graph.ID
		expected [][]graph.ID
		err      error
	}

	cases := []testCase{
		{
			name:  "empty database",
			setup: Empty,
			nodes: []graph.ID{"0"},
			err:   ErrNodeNotFound,
		},
		{
			name:  "node not found",
			setup: OneNode,
			nodes: []graph.ID{"0", "1"},
			err:   ErrNodeNotFound,
		},
		{
			name:     "dandling node",
			setup:    OneNode,
			nodes:    []graph.ID{"0"},
			expected: [][]graph.ID{{}},
		},
		{
			name:     "valid",
			setup:    Simple,
			nodes:    []graph.ID{"0", "1"},
			expected: [][]graph.ID{{"1"}, {}},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			db, setupErr := tc.setup()
			if setupErr != nil {
				t.Fatalf("setup failed: %v", setupErr)
			}
			defer db.flushAll()

			// named members so that the package-level follows key function is not shadowed
			members, err := db.bulkMembers(ctx, follows, tc.nodes)
			if !errors.Is(err, tc.err) {
				t.Fatalf("expected error %v, got %v", tc.err, err)
			}

			if !reflect.DeepEqual(members, tc.expected) {
				t.Errorf("expected follows %v, got %v", tc.expected, members)
			}
		})
	}
}
func TestUpdateFollows(t *testing.T) {
db, err := Simple()
if err != nil {
@@ -315,6 +370,10 @@ func TestPubkeys(t *testing.T) {
}
}
// TestInterfaces is a compile-time check that RedisDB satisfies the
// pagerank.PersonalizedLoader interface. It performs no runtime assertion:
// the package simply fails to build if the interface is not implemented.
func TestInterfaces(t *testing.T) {
var _ pagerank.PersonalizedLoader = RedisDB{}
}
// ------------------------------------- HELPERS -------------------------------
func Empty() (RedisDB, error) {

View File

@@ -1,13 +1,17 @@
package redb
import (
"cmp"
"context"
"errors"
"fmt"
"github/pippellia-btc/crawler/pkg/graph"
"github/pippellia-btc/crawler/pkg/walks"
"math"
"slices"
"strconv"
"github.com/redis/go-redis/v9"
)
const (
@@ -23,7 +27,7 @@ const (
var (
ErrWalkNotFound = errors.New("walk not found")
ErrInvalidReplacement = errors.New("invalid walk replacement")
ErrInvalidLimit = errors.New("limit must be a positive integer, or -1 to fetch all walks visiting node")
ErrInvalidLimit = errors.New("limit must be a positive integer, or -1 to fetch all walks")
)
// Walks returns the walks associated with the IDs.
@@ -94,6 +98,65 @@ func (r RedisDB) WalksVisiting(ctx context.Context, node graph.ID, limit int) ([
}
}
// WalksVisitingAny returns up to limit walks that visit the specified nodes.
// The walks are distributed evenly among the nodes:
//   - if limit == -1, all walks are returned (use with few nodes)
//   - if limit < len(nodes), no walks are returned
//   - if limit > 0 and nodes is empty, no walks are returned
//
// Any other limit results in an error wrapping ErrInvalidLimit.
// The walk IDs collected across the nodes are de-duplicated before the walks
// are fetched, so fewer than limit walks may be returned.
func (r RedisDB) WalksVisitingAny(ctx context.Context, nodes []graph.ID, limit int) ([]walks.Walk, error) {
	switch {
	case limit == -1:
		// return all walks visiting all nodes
		pipe := r.client.Pipeline()
		cmds := make([]*redis.StringSliceCmd, len(nodes))
		for i, node := range nodes {
			cmds[i] = pipe.SMembers(ctx, walksVisiting(node))
		}

		if _, err := pipe.Exec(ctx); err != nil {
			return nil, fmt.Errorf("failed to fetch all walks visiting %d nodes: %w", len(nodes), err)
		}

		IDs := make([]string, 0, walks.N*len(nodes))
		for _, cmd := range cmds {
			IDs = append(IDs, cmd.Val()...)
		}

		uniqueIDs := unique(IDs)
		return r.Walks(ctx, toWalks(uniqueIDs)...)

	case limit > 0:
		// return limit walks uniformly distributed across all nodes
		if len(nodes) == 0 {
			// nothing to visit; also avoids the division by zero below
			return nil, nil
		}

		nodeLimit := int64(limit / len(nodes))
		if nodeLimit == 0 {
			return nil, nil
		}

		pipe := r.client.Pipeline()
		cmds := make([]*redis.StringSliceCmd, len(nodes))
		for i, node := range nodes {
			cmds[i] = pipe.SRandMemberN(ctx, walksVisiting(node), nodeLimit)
		}

		if _, err := pipe.Exec(ctx); err != nil {
			return nil, fmt.Errorf("failed to fetch %d walks visiting %d nodes: %w", limit, len(nodes), err)
		}

		IDs := make([]string, 0, limit)
		for _, cmd := range cmds {
			IDs = append(IDs, cmd.Val()...)
		}

		uniqueIDs := unique(IDs)
		return r.Walks(ctx, toWalks(uniqueIDs)...)

	default:
		// invalid limit
		return nil, fmt.Errorf("failed to fetch walks visiting any: %w", ErrInvalidLimit)
	}
}
// AddWalks adds all the walks to the database assigning them progressive IDs.
func (r RedisDB) AddWalks(ctx context.Context, walks ...walks.Walk) error {
if len(walks) == 0 {
@@ -278,3 +341,22 @@ func (r RedisDB) validateWalks() error {
return nil
}
// unique returns a newly allocated slice holding each distinct element of the
// input exactly once, in ascending order. It returns nil for an empty input.
//
// Note that the input slice is sorted in place as a side effect.
func unique[S ~[]E, E cmp.Ordered](slice S) S {
	if len(slice) == 0 {
		return nil
	}

	// after sorting, equal elements are adjacent, so comparing each element
	// with the last one kept is enough to drop the duplicates.
	slices.Sort(slice)

	deduped := make(S, 1, len(slice))
	deduped[0] = slice[0]
	for _, elem := range slice[1:] {
		if elem != deduped[len(deduped)-1] {
			deduped = append(deduped, elem)
		}
	}
	return deduped
}

View File

@@ -90,6 +90,61 @@ func TestWalksVisiting(t *testing.T) {
}
}
// TestWalksVisitingAny calls WalksVisitingAny on nodes "0" and "1" against
// databases holding a varying number of walks. It checks that no error is
// returned, that the number of walks does not exceed expectedWalks (only an
// upper bound is asserted), and that every returned walk has the same path as
// defaultWalk.
func TestWalksVisitingAny(t *testing.T) {
	type testCase struct {
		name          string
		setup         func() (RedisDB, error)
		limit         int
		expectedWalks int // the number of [defaultWalk] returned
	}

	cases := []testCase{
		{name: "empty", setup: SomeWalks(0), limit: 1, expectedWalks: 0},
		{name: "all walks", setup: SomeWalks(10), limit: -1, expectedWalks: 10},
		{name: "some walks", setup: SomeWalks(100), limit: 20, expectedWalks: 20},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			db, setupErr := tc.setup()
			if setupErr != nil {
				t.Fatalf("setup failed: %v", setupErr)
			}
			defer db.flushAll()

			visiting, err := db.WalksVisitingAny(ctx, []graph.ID{"0", "1"}, tc.limit)
			if err != nil {
				t.Fatalf("expected error nil, got %v", err)
			}

			if got := len(visiting); got > tc.expectedWalks {
				t.Fatalf("expected %d walks, got %d", tc.expectedWalks, got)
			}

			for i := range visiting {
				// compare only the paths, not the IDs
				if !reflect.DeepEqual(visiting[i].Path, defaultWalk.Path) {
					t.Fatalf("expected walk %v, got %v", defaultWalk, visiting[i])
				}
			}
		})
	}
}
func TestAddWalks(t *testing.T) {
db, err := SomeWalks(1)()
if err != nil {
@@ -297,6 +352,38 @@ func TestValidateReplacement(t *testing.T) {
}
}
// TestUnique checks that unique returns the sorted, de-duplicated elements of
// its input, and nil for nil or empty inputs.
func TestUnique(t *testing.T) {
	type testCase struct {
		slice    []walks.ID
		expected []walks.ID
	}

	cases := []testCase{
		{slice: nil, expected: nil},
		{slice: []walks.ID{}, expected: nil},
		{slice: []walks.ID{"1", "2", "0"}, expected: []walks.ID{"0", "1", "2"}},
		{slice: []walks.ID{"1", "2", "0", "3", "1", "0"}, expected: []walks.ID{"0", "1", "2", "3"}},
	}

	for _, tc := range cases {
		// named got so that the unique function is not shadowed
		got := unique(tc.slice)
		if reflect.DeepEqual(got, tc.expected) {
			continue
		}
		t.Errorf("expected %v, got %v", tc.expected, got)
	}
}
// BenchmarkUnique measures unique on a slice of one million distinct IDs,
// built from the decimal representation of 0..999999.
//
// NOTE(review): the same slice is passed on every iteration and unique sorts
// its input in place, so iterations after the first run on already sorted
// data — confirm this is the scenario intended to be measured.
func BenchmarkUnique(b *testing.B) {
	const size = 1000000

	IDs := make([]walks.ID, 0, size)
	for i := 0; i < size; i++ {
		IDs = append(IDs, walks.ID(strconv.Itoa(i)))
	}

	// exclude the setup above from the measurement
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		unique(IDs)
	}
}
var defaultWalk = walks.Walk{Path: []graph.ID{"0", "1"}}
func SomeWalks(n int) func() (RedisDB, error) {

View File

@@ -99,18 +99,18 @@ func (l *mockLoader) Follows(ctx context.Context, node graph.ID) ([]graph.ID, er
return l.walker.Follows(ctx, node)
}
func (l *mockLoader) BulkFollows(ctx context.Context, nodes []graph.ID) (map[graph.ID][]graph.ID, error) {
followsMap := make(map[graph.ID][]graph.ID, len(nodes))
for _, node := range nodes {
follows, err := l.walker.Follows(ctx, node)
func (l *mockLoader) BulkFollows(ctx context.Context, nodes []graph.ID) ([][]graph.ID, error) {
var err error
follows := make([][]graph.ID, len(nodes))
for i, node := range nodes {
follows[i], err = l.walker.Follows(ctx, node)
if err != nil {
return nil, err
}
followsMap[node] = follows
}
return followsMap, nil
return follows, nil
}
func (l *mockLoader) AddWalks(w []walks.Walk) {