mirror of
https://github.com/aljazceru/crawler_v2.git
synced 2025-12-17 07:24:21 +01:00
completed redb package
This commit is contained in:
@@ -11,11 +11,17 @@ type cachedWalker struct {
|
||||
fallback walks.Walker
|
||||
}
|
||||
|
||||
func newCachedWalker(followsMap map[graph.ID][]graph.ID, fallback walks.Walker) *cachedWalker {
|
||||
return &cachedWalker{
|
||||
follows: followsMap,
|
||||
func newCachedWalker(nodes []graph.ID, follows [][]graph.ID, fallback walks.Walker) *cachedWalker {
|
||||
w := cachedWalker{
|
||||
follows: make(map[graph.ID][]graph.ID, len(nodes)),
|
||||
fallback: fallback,
|
||||
}
|
||||
|
||||
for i, node := range nodes {
|
||||
w.follows[node] = follows[i]
|
||||
}
|
||||
|
||||
return &w
|
||||
}
|
||||
|
||||
func (w *cachedWalker) Follows(ctx context.Context, node graph.ID) ([]graph.ID, error) {
|
||||
|
||||
@@ -57,7 +57,7 @@ type PersonalizedLoader interface {
|
||||
Follows(ctx context.Context, node graph.ID) ([]graph.ID, error)
|
||||
|
||||
// BulkFollows returns the follow-lists of the specified nodes
|
||||
BulkFollows(ctx context.Context, nodes []graph.ID) (map[graph.ID][]graph.ID, error)
|
||||
BulkFollows(ctx context.Context, nodes []graph.ID) ([][]graph.ID, error)
|
||||
|
||||
// WalksVisitingAny returns up to limit walks that visit the specified nodes.
|
||||
// The walks are distributed evenly among the nodes:
|
||||
@@ -116,7 +116,7 @@ func Personalized(
|
||||
return map[graph.ID]float64{source: 1.0}, nil
|
||||
}
|
||||
|
||||
followMap, err := loader.BulkFollows(ctx, follows)
|
||||
followByNode, err := loader.BulkFollows(ctx, follows)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Personalized: failed to fetch the two-hop network of source: %w", err)
|
||||
}
|
||||
@@ -127,7 +127,7 @@ func Personalized(
|
||||
return nil, fmt.Errorf("Personalized: failed to fetch the walk: %w", err)
|
||||
}
|
||||
|
||||
walker := newCachedWalker(followMap, loader)
|
||||
walker := newCachedWalker(follows, followByNode, loader)
|
||||
pool := newWalkPool(walks)
|
||||
|
||||
walk, err := personalizedWalk(ctx, walker, pool, source, targetLenght)
|
||||
|
||||
@@ -185,6 +185,72 @@ func (r RedisDB) members(ctx context.Context, key func(graph.ID) string, node gr
|
||||
return toNodes(members), nil
|
||||
}
|
||||
|
||||
// BulkFollows returns the follow-lists of all the provided nodes.
|
||||
// Do not call on too many nodes (e.g. +100k) to avoid too many recursions.
|
||||
func (r RedisDB) BulkFollows(ctx context.Context, nodes []graph.ID) ([][]graph.ID, error) {
|
||||
return r.bulkMembers(ctx, follows, nodes)
|
||||
}
|
||||
|
||||
func (r RedisDB) bulkMembers(ctx context.Context, key func(graph.ID) string, nodes []graph.ID) ([][]graph.ID, error) {
|
||||
switch {
|
||||
case len(nodes) == 0:
|
||||
return nil, nil
|
||||
|
||||
case len(nodes) < 10000:
|
||||
pipe := r.client.Pipeline()
|
||||
cmds := make([]*redis.StringSliceCmd, len(nodes))
|
||||
|
||||
for i, node := range nodes {
|
||||
cmds[i] = pipe.SMembers(ctx, key(node))
|
||||
}
|
||||
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch the %s of %d nodes: %w", key(""), len(nodes), err)
|
||||
}
|
||||
|
||||
var empty []string
|
||||
members := make([][]graph.ID, len(nodes))
|
||||
|
||||
for i, cmd := range cmds {
|
||||
m := cmd.Val()
|
||||
if len(m) == 0 {
|
||||
// empty slice might mean node not found.
|
||||
empty = append(empty, node(nodes[i]))
|
||||
}
|
||||
|
||||
members[i] = toNodes(m)
|
||||
}
|
||||
|
||||
if len(empty) > 0 {
|
||||
exists, err := r.client.Exists(ctx, empty...).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if int(exists) < len(empty) {
|
||||
return nil, fmt.Errorf("failed to fetch the %s of these nodes %v: %w", key(""), empty, ErrNodeNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
return members, nil
|
||||
|
||||
default:
|
||||
// too many nodes, split them in two batches
|
||||
mid := len(nodes) / 2
|
||||
batch1, err := r.bulkMembers(ctx, key, nodes[:mid])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
batch2, err := r.bulkMembers(ctx, key, nodes[mid:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return append(batch1, batch2...), nil
|
||||
}
|
||||
}
|
||||
|
||||
// FollowCounts returns the number of follows each node has. If a node is not found, it returns 0.
|
||||
func (r RedisDB) FollowCounts(ctx context.Context, nodes ...graph.ID) ([]int, error) {
|
||||
return r.counts(ctx, follows, nodes...)
|
||||
@@ -202,15 +268,13 @@ func (r RedisDB) counts(ctx context.Context, key func(graph.ID) string, nodes ..
|
||||
|
||||
pipe := r.client.Pipeline()
|
||||
cmds := make([]*redis.IntCmd, len(nodes))
|
||||
keys := make([]string, len(nodes))
|
||||
|
||||
for i, node := range nodes {
|
||||
keys[i] = key(node)
|
||||
cmds[i] = pipe.SCard(ctx, key(node))
|
||||
}
|
||||
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to count the elements of %v: %w", keys, err)
|
||||
return nil, fmt.Errorf("failed to count the elements of %d nodes: %w", len(nodes), err)
|
||||
}
|
||||
|
||||
counts := make([]int, len(nodes))
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"github/pippellia-btc/crawler/pkg/graph"
|
||||
"github/pippellia-btc/crawler/pkg/pagerank"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -175,6 +176,60 @@ func TestMembers(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBulkMembers(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func() (RedisDB, error)
|
||||
nodes []graph.ID
|
||||
expected [][]graph.ID
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "empty database",
|
||||
setup: Empty,
|
||||
nodes: []graph.ID{"0"},
|
||||
err: ErrNodeNotFound,
|
||||
},
|
||||
{
|
||||
name: "node not found",
|
||||
setup: OneNode,
|
||||
nodes: []graph.ID{"0", "1"},
|
||||
err: ErrNodeNotFound,
|
||||
},
|
||||
{
|
||||
name: "dandling node",
|
||||
setup: OneNode,
|
||||
nodes: []graph.ID{"0"},
|
||||
expected: [][]graph.ID{{}},
|
||||
},
|
||||
{
|
||||
name: "valid",
|
||||
setup: Simple,
|
||||
nodes: []graph.ID{"0", "1"},
|
||||
expected: [][]graph.ID{{"1"}, {}},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
db, err := test.setup()
|
||||
if err != nil {
|
||||
t.Fatalf("setup failed: %v", err)
|
||||
}
|
||||
defer db.flushAll()
|
||||
|
||||
follows, err := db.bulkMembers(ctx, follows, test.nodes)
|
||||
if !errors.Is(err, test.err) {
|
||||
t.Fatalf("expected error %v, got %v", test.err, err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(follows, test.expected) {
|
||||
t.Errorf("expected follows %v, got %v", test.expected, follows)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateFollows(t *testing.T) {
|
||||
db, err := Simple()
|
||||
if err != nil {
|
||||
@@ -315,6 +370,10 @@ func TestPubkeys(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestInterfaces(t *testing.T) {
|
||||
var _ pagerank.PersonalizedLoader = RedisDB{}
|
||||
}
|
||||
|
||||
// ------------------------------------- HELPERS -------------------------------
|
||||
|
||||
func Empty() (RedisDB, error) {
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
package redb
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github/pippellia-btc/crawler/pkg/graph"
|
||||
"github/pippellia-btc/crawler/pkg/walks"
|
||||
"math"
|
||||
"slices"
|
||||
"strconv"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -23,7 +27,7 @@ const (
|
||||
var (
|
||||
ErrWalkNotFound = errors.New("walk not found")
|
||||
ErrInvalidReplacement = errors.New("invalid walk replacement")
|
||||
ErrInvalidLimit = errors.New("limit must be a positive integer, or -1 to fetch all walks visiting node")
|
||||
ErrInvalidLimit = errors.New("limit must be a positive integer, or -1 to fetch all walks")
|
||||
)
|
||||
|
||||
// Walks returns the walks associated with the IDs.
|
||||
@@ -94,6 +98,65 @@ func (r RedisDB) WalksVisiting(ctx context.Context, node graph.ID, limit int) ([
|
||||
}
|
||||
}
|
||||
|
||||
// WalksVisitingAny returns up to limit walks that visit the specified nodes.
|
||||
// The walks are distributed evenly among the nodes:
|
||||
// - if limit == -1, all walks are returned (use with few nodes)
|
||||
// - if limit < len(nodes), no walks are returned
|
||||
func (r RedisDB) WalksVisitingAny(ctx context.Context, nodes []graph.ID, limit int) ([]walks.Walk, error) {
|
||||
switch {
|
||||
case limit == -1:
|
||||
// return all walks visiting all nodes
|
||||
pipe := r.client.Pipeline()
|
||||
cmds := make([]*redis.StringSliceCmd, len(nodes))
|
||||
|
||||
for i, node := range nodes {
|
||||
cmds[i] = pipe.SMembers(ctx, walksVisiting(node))
|
||||
}
|
||||
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch all walks visiting %d nodes: %w", len(nodes), err)
|
||||
}
|
||||
|
||||
IDs := make([]string, 0, walks.N*len(nodes))
|
||||
for _, cmd := range cmds {
|
||||
IDs = append(IDs, cmd.Val()...)
|
||||
}
|
||||
|
||||
unique := unique(IDs)
|
||||
return r.Walks(ctx, toWalks(unique)...)
|
||||
|
||||
case limit > 0:
|
||||
// return limit walks uniformely distributed across all nodes
|
||||
nodeLimit := int64(limit / len(nodes))
|
||||
if nodeLimit == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
pipe := r.client.Pipeline()
|
||||
cmds := make([]*redis.StringSliceCmd, len(nodes))
|
||||
|
||||
for i, node := range nodes {
|
||||
cmds[i] = pipe.SRandMemberN(ctx, walksVisiting(node), nodeLimit)
|
||||
}
|
||||
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch %d walks visiting %d nodes: %w", limit, len(nodes), err)
|
||||
}
|
||||
|
||||
IDs := make([]string, 0, limit)
|
||||
for _, cmd := range cmds {
|
||||
IDs = append(IDs, cmd.Val()...)
|
||||
}
|
||||
|
||||
unique := unique(IDs)
|
||||
return r.Walks(ctx, toWalks(unique)...)
|
||||
|
||||
default:
|
||||
// invalid limit
|
||||
return nil, fmt.Errorf("failed to fetch walks visiting any: %w", ErrInvalidLimit)
|
||||
}
|
||||
}
|
||||
|
||||
// AddWalks adds all the walks to the database assigning them progressive IDs.
|
||||
func (r RedisDB) AddWalks(ctx context.Context, walks ...walks.Walk) error {
|
||||
if len(walks) == 0 {
|
||||
@@ -278,3 +341,22 @@ func (r RedisDB) validateWalks() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// unique returns a slice of unique elements of the input slice.
|
||||
func unique[S ~[]E, E cmp.Ordered](slice S) S {
|
||||
if len(slice) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
slices.Sort(slice)
|
||||
unique := make(S, 0, len(slice))
|
||||
unique = append(unique, slice[0])
|
||||
|
||||
for i := 1; i < len(slice); i++ {
|
||||
if slice[i] != slice[i-1] {
|
||||
unique = append(unique, slice[i])
|
||||
}
|
||||
}
|
||||
|
||||
return unique
|
||||
}
|
||||
|
||||
@@ -90,6 +90,61 @@ func TestWalksVisiting(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWalksVisitingAny(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func() (RedisDB, error)
|
||||
limit int
|
||||
expectedWalks int // the number of [defaultWalk] returned
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
setup: SomeWalks(0),
|
||||
limit: 1,
|
||||
expectedWalks: 0,
|
||||
},
|
||||
{
|
||||
name: "all walks",
|
||||
setup: SomeWalks(10),
|
||||
limit: -1,
|
||||
expectedWalks: 10,
|
||||
},
|
||||
{
|
||||
name: "some walks",
|
||||
setup: SomeWalks(100),
|
||||
limit: 20,
|
||||
expectedWalks: 20,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
db, err := test.setup()
|
||||
if err != nil {
|
||||
t.Fatalf("setup failed: %v", err)
|
||||
}
|
||||
defer db.flushAll()
|
||||
|
||||
nodes := []graph.ID{"0", "1"}
|
||||
visiting, err := db.WalksVisitingAny(ctx, nodes, test.limit)
|
||||
if err != nil {
|
||||
t.Fatalf("expected error nil, got %v", err)
|
||||
}
|
||||
|
||||
if len(visiting) > test.expectedWalks {
|
||||
t.Fatalf("expected %d walks, got %d", test.expectedWalks, len(visiting))
|
||||
}
|
||||
|
||||
for _, walk := range visiting {
|
||||
if !reflect.DeepEqual(walk.Path, defaultWalk.Path) {
|
||||
// compare only the paths, not the IDs
|
||||
t.Fatalf("expected walk %v, got %v", defaultWalk, walk)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddWalks(t *testing.T) {
|
||||
db, err := SomeWalks(1)()
|
||||
if err != nil {
|
||||
@@ -297,6 +352,38 @@ func TestValidateReplacement(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnique(t *testing.T) {
|
||||
tests := []struct {
|
||||
slice []walks.ID
|
||||
expected []walks.ID
|
||||
}{
|
||||
{slice: nil, expected: nil},
|
||||
{slice: []walks.ID{}, expected: nil},
|
||||
{slice: []walks.ID{"1", "2", "0"}, expected: []walks.ID{"0", "1", "2"}},
|
||||
{slice: []walks.ID{"1", "2", "0", "3", "1", "0"}, expected: []walks.ID{"0", "1", "2", "3"}},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
unique := unique(test.slice)
|
||||
if !reflect.DeepEqual(unique, test.expected) {
|
||||
t.Errorf("expected %v, got %v", test.expected, unique)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUnique(b *testing.B) {
|
||||
size := 1000000
|
||||
IDs := make([]walks.ID, size)
|
||||
for i := 0; i < size; i++ {
|
||||
IDs[i] = walks.ID(strconv.Itoa(i))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
unique(IDs)
|
||||
}
|
||||
}
|
||||
|
||||
var defaultWalk = walks.Walk{Path: []graph.ID{"0", "1"}}
|
||||
|
||||
func SomeWalks(n int) func() (RedisDB, error) {
|
||||
|
||||
@@ -99,18 +99,18 @@ func (l *mockLoader) Follows(ctx context.Context, node graph.ID) ([]graph.ID, er
|
||||
return l.walker.Follows(ctx, node)
|
||||
}
|
||||
|
||||
func (l *mockLoader) BulkFollows(ctx context.Context, nodes []graph.ID) (map[graph.ID][]graph.ID, error) {
|
||||
followsMap := make(map[graph.ID][]graph.ID, len(nodes))
|
||||
for _, node := range nodes {
|
||||
follows, err := l.walker.Follows(ctx, node)
|
||||
func (l *mockLoader) BulkFollows(ctx context.Context, nodes []graph.ID) ([][]graph.ID, error) {
|
||||
var err error
|
||||
follows := make([][]graph.ID, len(nodes))
|
||||
|
||||
for i, node := range nodes {
|
||||
follows[i], err = l.walker.Follows(ctx, node)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
followsMap[node] = follows
|
||||
}
|
||||
|
||||
return followsMap, nil
|
||||
return follows, nil
|
||||
}
|
||||
|
||||
func (l *mockLoader) AddWalks(w []walks.Walk) {
|
||||
|
||||
Reference in New Issue
Block a user