mirror of
https://github.com/aljazceru/crawler_v2.git
synced 2025-12-17 07:24:21 +01:00
code refactor for rws
This commit is contained in:
@@ -15,20 +15,20 @@ import (
|
||||
|
||||
const (
|
||||
// redis variable names
|
||||
KeyDatabase string = "database" // TODO: this can be removed
|
||||
KeyLastNodeID string = "lastNodeID" // TODO: change it to "next" inside "node" hash
|
||||
KeyKeyIndex string = "keyIndex" // TODO: change to key_index
|
||||
KeyNodePrefix string = "node:"
|
||||
KeyFollowsPrefix string = "follows:"
|
||||
KeyFollowersPrefix string = "followers:"
|
||||
KeyDatabase = "database" // TODO: this can be removed
|
||||
KeyLastNodeID = "lastNodeID" // TODO: change it to "next" inside "node" hash
|
||||
KeyKeyIndex = "keyIndex" // TODO: change to key_index
|
||||
KeyNodePrefix = "node:"
|
||||
KeyFollowsPrefix = "follows:"
|
||||
KeyFollowersPrefix = "followers:"
|
||||
|
||||
// redis node HASH fields
|
||||
NodeID string = "id"
|
||||
NodePubkey string = "pubkey"
|
||||
NodeStatus string = "status"
|
||||
NodePromotionTS string = "promotion_TS" // TODO: change to promotion
|
||||
NodeDemotionTS string = "demotion_TS" // TODO: change to demotion
|
||||
NodeAddedTS string = "added_TS" // TODO: change to addition
|
||||
NodeID = "id"
|
||||
NodePubkey = "pubkey"
|
||||
NodeStatus = "status"
|
||||
NodePromotionTS = "promotion_TS" // TODO: change to promotion
|
||||
NodeDemotionTS = "demotion_TS" // TODO: change to demotion
|
||||
NodeAddedTS = "added_TS" // TODO: change to addition
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -41,7 +41,11 @@ type RedisDB struct {
|
||||
}
|
||||
|
||||
func New(opt *redis.Options) RedisDB {
|
||||
return RedisDB{client: redis.NewClient(opt)}
|
||||
r := RedisDB{client: redis.NewClient(opt)}
|
||||
if err := r.validateWalks(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// Size returns the DBSize of redis, which is the total number of keys
|
||||
@@ -178,7 +182,7 @@ func (r RedisDB) members(ctx context.Context, key func(graph.ID) string, node gr
|
||||
}
|
||||
}
|
||||
|
||||
return toIDs(members), nil
|
||||
return toNodes(members), nil
|
||||
}
|
||||
|
||||
// FollowCounts returns the number of follows each node has. If a node is not found, it returns 0.
|
||||
|
||||
@@ -318,12 +318,12 @@ func TestPubkeys(t *testing.T) {
|
||||
// ------------------------------------- HELPERS -------------------------------
|
||||
|
||||
func Empty() (RedisDB, error) {
|
||||
return New(&redis.Options{Addr: testAddress}), nil
|
||||
return RedisDB{client: redis.NewClient(&redis.Options{Addr: testAddress})}, nil
|
||||
}
|
||||
|
||||
func OneNode() (RedisDB, error) {
|
||||
db := New(&redis.Options{Addr: testAddress})
|
||||
if _, err := db.AddNode(context.Background(), "0"); err != nil {
|
||||
db := RedisDB{client: redis.NewClient(&redis.Options{Addr: testAddress})}
|
||||
if _, err := db.AddNode(ctx, "0"); err != nil {
|
||||
db.flushAll()
|
||||
return RedisDB{}, err
|
||||
}
|
||||
@@ -332,9 +332,7 @@ func OneNode() (RedisDB, error) {
|
||||
}
|
||||
|
||||
func Simple() (RedisDB, error) {
|
||||
ctx := context.Background()
|
||||
db := New(&redis.Options{Addr: testAddress})
|
||||
|
||||
db := RedisDB{client: redis.NewClient(&redis.Options{Addr: testAddress})}
|
||||
for _, pk := range []string{"0", "1", "2"} {
|
||||
if _, err := db.AddNode(ctx, pk); err != nil {
|
||||
db.flushAll()
|
||||
|
||||
@@ -2,13 +2,19 @@ package redb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github/pippellia-btc/crawler/pkg/graph"
|
||||
"github/pippellia-btc/crawler/pkg/walks"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
testAddress = "localhost:6380"
|
||||
|
||||
ErrValueIsNil = errors.New("value is nil")
|
||||
ErrValueIsNotString = errors.New("failed to convert to string")
|
||||
)
|
||||
|
||||
// flushAll deletes all the keys of all existing databases. This command never fails.
|
||||
@@ -28,8 +34,12 @@ func followers[ID string | graph.ID](id ID) string {
|
||||
return KeyFollowersPrefix + string(id)
|
||||
}
|
||||
|
||||
// ids converts a slice of strings to IDs
|
||||
func toIDs(s []string) []graph.ID {
|
||||
func walksVisiting[ID string | graph.ID](id ID) string {
|
||||
return KeyWalksVisitingPrefix + string(id)
|
||||
}
|
||||
|
||||
// toNodes converts a slice of strings to node IDs
|
||||
func toNodes(s []string) []graph.ID {
|
||||
IDs := make([]graph.ID, len(s))
|
||||
for i, e := range s {
|
||||
IDs[i] = graph.ID(e)
|
||||
@@ -37,8 +47,17 @@ func toIDs(s []string) []graph.ID {
|
||||
return IDs
|
||||
}
|
||||
|
||||
// toWalks converts a slice of strings to walk IDs
|
||||
func toWalks(s []string) []walks.ID {
|
||||
IDs := make([]walks.ID, len(s))
|
||||
for i, e := range s {
|
||||
IDs[i] = walks.ID(e)
|
||||
}
|
||||
return IDs
|
||||
}
|
||||
|
||||
// strings converts graph IDs to a slice of strings
|
||||
func toStrings(ids []graph.ID) []string {
|
||||
func toStrings[ID graph.ID | walks.ID](ids []ID) []string {
|
||||
s := make([]string, len(ids))
|
||||
for i, id := range ids {
|
||||
s[i] = string(id)
|
||||
@@ -98,3 +117,49 @@ func parseTimestamp(unix string) (time.Time, error) {
|
||||
}
|
||||
return time.Unix(ts, 0), nil
|
||||
}
|
||||
|
||||
func formatWalk(walk walks.Walk) string {
|
||||
nodes := make([]string, walk.Len())
|
||||
for i, node := range walk.Path {
|
||||
nodes[i] = string(node)
|
||||
}
|
||||
return strings.Join(nodes, ",")
|
||||
}
|
||||
|
||||
func parseWalk(s string) walks.Walk {
|
||||
nodes := strings.Split(s, ",")
|
||||
walk := walks.Walk{Path: make([]graph.ID, len(nodes))}
|
||||
for i, node := range nodes {
|
||||
walk.Path[i] = graph.ID(node)
|
||||
}
|
||||
return walk
|
||||
}
|
||||
|
||||
func parseString(v any) (string, error) {
|
||||
if v == nil {
|
||||
return "", ErrValueIsNil
|
||||
}
|
||||
|
||||
str, ok := v.(string)
|
||||
if !ok {
|
||||
return "", ErrValueIsNotString
|
||||
}
|
||||
|
||||
return str, nil
|
||||
}
|
||||
|
||||
func parseFloat(v any) (float64, error) {
|
||||
str, err := parseString(v)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return strconv.ParseFloat(str, 64)
|
||||
}
|
||||
|
||||
func parseInt(v any) (int, error) {
|
||||
str, err := parseString(v)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return strconv.Atoi(str)
|
||||
}
|
||||
|
||||
@@ -1 +1,280 @@
|
||||
package redb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github/pippellia-btc/crawler/pkg/graph"
|
||||
"github/pippellia-btc/crawler/pkg/walks"
|
||||
"math"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
const (
|
||||
KeyRWS = "RWS" // TODO: this can be removed
|
||||
KeyAlpha = "alpha" // TODO: walks:alpha
|
||||
KeyWalksPerNode = "walksPerNode" // TODO: walks:N or another
|
||||
KeyLastWalkID = "lastWalkID" // TODO: walks:next
|
||||
KeyTotalVisits = "totalVisits" // TODO: walks:total_visits
|
||||
KeyWalks = "walks"
|
||||
KeyWalksVisitingPrefix = "walksVisiting:" // TODO: walks_visiting:
|
||||
)
|
||||
|
||||
var (
|
||||
ErrWalkNotFound = errors.New("walk not found")
|
||||
ErrInvalidReplacement = errors.New("invalid walk replacement")
|
||||
ErrInvalidLimit = errors.New("limit must be a positive integer, or -1 to fetch all walks visiting node")
|
||||
)
|
||||
|
||||
// Walks returns the walks associated with the IDs.
|
||||
func (r RedisDB) Walks(ctx context.Context, IDs ...walks.ID) ([]walks.Walk, error) {
|
||||
switch {
|
||||
case len(IDs) == 0:
|
||||
return nil, nil
|
||||
|
||||
case len(IDs) <= 100000:
|
||||
vals, err := r.client.HMGet(ctx, KeyWalks, toStrings(IDs)...).Result()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch walks: %w", err)
|
||||
}
|
||||
|
||||
walks := make([]walks.Walk, len(vals))
|
||||
for i, val := range vals {
|
||||
if val == nil {
|
||||
// walk was not found, so return an error
|
||||
return nil, fmt.Errorf("failed to fetch walk with ID %s: %w", IDs[i], ErrWalkNotFound)
|
||||
}
|
||||
|
||||
walks[i] = parseWalk(val.(string))
|
||||
walks[i].ID = IDs[i]
|
||||
}
|
||||
|
||||
return walks, nil
|
||||
|
||||
default:
|
||||
// too many walks for a single call, so we split them in two batches
|
||||
mid := len(IDs) / 2
|
||||
batch1, err := r.Walks(ctx, IDs[:mid]...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
batch2, err := r.Walks(ctx, IDs[mid:]...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return append(batch1, batch2...), nil
|
||||
}
|
||||
}
|
||||
|
||||
// WalksVisiting returns up-to limit walks that visit node.
|
||||
// Use limit = -1 to fetch all the walks visiting node.
|
||||
func (r RedisDB) WalksVisiting(ctx context.Context, node graph.ID, limit int) ([]walks.Walk, error) {
|
||||
switch {
|
||||
case limit == -1:
|
||||
// return all walks visiting node
|
||||
IDs, err := r.client.SMembers(ctx, walksVisiting(node)).Result()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch %s: %w", walksVisiting(node), err)
|
||||
}
|
||||
|
||||
return r.Walks(ctx, toWalks(IDs)...)
|
||||
|
||||
case limit > 0:
|
||||
IDs, err := r.client.SRandMemberN(ctx, walksVisiting(node), int64(limit)).Result()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch %s: %w", walksVisiting(node), err)
|
||||
}
|
||||
|
||||
return r.Walks(ctx, toWalks(IDs)...)
|
||||
|
||||
default:
|
||||
return nil, ErrInvalidLimit
|
||||
}
|
||||
}
|
||||
|
||||
// AddWalks adds all the walks to the database assigning them progressive IDs.
|
||||
func (r RedisDB) AddWalks(ctx context.Context, walks ...walks.Walk) error {
|
||||
if len(walks) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// get the IDs outside the transaction, which implies there might be "holes",
|
||||
// meaning IDs not associated with any walk
|
||||
next, err := r.client.HIncrBy(ctx, KeyRWS, KeyLastWalkID, int64(len(walks))).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add walks: failed to increment ID: %w", err)
|
||||
}
|
||||
|
||||
var visits, ID int
|
||||
pipe := r.client.TxPipeline()
|
||||
|
||||
for i, walk := range walks {
|
||||
visits += walk.Len()
|
||||
ID = int(next) - len(walks) + i // assigning IDs in the same order
|
||||
|
||||
pipe.HSet(ctx, KeyWalks, ID, formatWalk(walk))
|
||||
for _, node := range walk.Path {
|
||||
pipe.SAdd(ctx, walksVisiting(node), ID)
|
||||
}
|
||||
}
|
||||
|
||||
pipe.HIncrBy(ctx, KeyRWS, KeyTotalVisits, int64(visits))
|
||||
|
||||
if _, err = pipe.Exec(ctx); err != nil {
|
||||
return fmt.Errorf("failed to add walks: pipeline failed %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveWalks removes all the walks from the database.
|
||||
func (r RedisDB) RemoveWalks(ctx context.Context, walks ...walks.Walk) error {
|
||||
if len(walks) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var visits int
|
||||
pipe := r.client.TxPipeline()
|
||||
|
||||
for _, walk := range walks {
|
||||
pipe.HDel(ctx, KeyWalks, string(walk.ID))
|
||||
for _, node := range walk.Path {
|
||||
pipe.SRem(ctx, walksVisiting(node), string(walk.ID))
|
||||
}
|
||||
|
||||
visits += walk.Len()
|
||||
}
|
||||
|
||||
pipe.HIncrBy(ctx, KeyRWS, KeyTotalVisits, -int64(visits))
|
||||
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
return fmt.Errorf("failed to remove walks: pipeline failed %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r RedisDB) ReplaceWalks(ctx context.Context, before, after []walks.Walk) error {
|
||||
if err := validateReplacement(before, after); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var visits int64
|
||||
pipe := r.client.TxPipeline()
|
||||
|
||||
for i := range before {
|
||||
div := walks.Divergence(before[i], after[i])
|
||||
if div == -1 {
|
||||
// the two walks are equal, skip
|
||||
continue
|
||||
}
|
||||
|
||||
prev := before[i]
|
||||
next := after[i]
|
||||
ID := string(after[i].ID)
|
||||
|
||||
pipe.HSet(ctx, KeyWalks, ID, formatWalk(next))
|
||||
|
||||
for _, node := range prev.Path[div:] {
|
||||
pipe.SRem(ctx, walksVisiting(node), ID)
|
||||
visits--
|
||||
}
|
||||
|
||||
for _, node := range next.Path[div:] {
|
||||
pipe.SAdd(ctx, walksVisiting(node), ID)
|
||||
visits++
|
||||
}
|
||||
|
||||
if pipe.Len() > 5000 {
|
||||
// execute a partial update when it's too big
|
||||
pipe.HIncrBy(ctx, KeyRWS, KeyTotalVisits, visits)
|
||||
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
return fmt.Errorf("failed to replace walks: pipeline failed %w", err)
|
||||
}
|
||||
|
||||
pipe = r.client.TxPipeline()
|
||||
visits = 0
|
||||
}
|
||||
}
|
||||
|
||||
pipe.HIncrBy(ctx, KeyRWS, KeyTotalVisits, visits)
|
||||
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
return fmt.Errorf("failed to replace walks: pipeline failed %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateReplacement(old, new []walks.Walk) error {
|
||||
if len(old) != len(new) {
|
||||
return fmt.Errorf("%w: old and new walks must have the same lenght", ErrInvalidReplacement)
|
||||
}
|
||||
|
||||
seen := make(map[walks.ID]struct{})
|
||||
for i := range old {
|
||||
if old[i].ID != new[i].ID {
|
||||
return fmt.Errorf("%w: IDs don't match at index %d: old=%s, new=%s", ErrInvalidReplacement, i, old[i].ID, new[i].ID)
|
||||
}
|
||||
|
||||
if _, ok := seen[old[i].ID]; ok {
|
||||
return fmt.Errorf("%w: repeated walk ID %s", ErrInvalidReplacement, old[i].ID)
|
||||
}
|
||||
|
||||
seen[old[i].ID] = struct{}{}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TotalVisits returns the total number of visits, which is the sum of the lengths of all walks.
|
||||
func (r RedisDB) TotalVisits(ctx context.Context) (int, error) {
|
||||
total, err := r.client.HGet(ctx, KeyRWS, KeyTotalVisits).Result()
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("failed to get the total number of visits: %w", err)
|
||||
}
|
||||
|
||||
tot, err := strconv.Atoi(total)
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("failed to parse the total number of visits: %w", err)
|
||||
}
|
||||
|
||||
return tot, nil
|
||||
}
|
||||
|
||||
// Visits returns the number of times each specified node was visited during the walks.
|
||||
// The returned slice contains counts in the same order as the input nodes.
|
||||
// If a node is not found, it returns 0 visits.
|
||||
func (r RedisDB) Visits(ctx context.Context, nodes ...graph.ID) ([]int, error) {
|
||||
return r.counts(ctx, walksVisiting, nodes...)
|
||||
}
|
||||
|
||||
func (r RedisDB) validateWalks() error {
|
||||
vals, err := r.client.HMGet(context.Background(), KeyRWS, KeyAlpha, KeyWalksPerNode).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch alpha and walksPerNode %w", err)
|
||||
}
|
||||
|
||||
alpha, err := parseFloat(vals[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse alpha: %w", err)
|
||||
}
|
||||
|
||||
N, err := parseInt(vals[1])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse walksPerNode: %w", err)
|
||||
}
|
||||
|
||||
if math.Abs(alpha-walks.Alpha) > 1e-10 {
|
||||
return errors.New("alpha and walks.Alpha are different")
|
||||
}
|
||||
|
||||
if N != walks.N {
|
||||
return errors.New("N and walks.N are different")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
317
pkg/redb/walks_test.go
Normal file
317
pkg/redb/walks_test.go
Normal file
@@ -0,0 +1,317 @@
|
||||
package redb
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github/pippellia-btc/crawler/pkg/graph"
|
||||
"github/pippellia-btc/crawler/pkg/walks"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
func TestValidate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func() (RedisDB, error)
|
||||
err error
|
||||
}{
|
||||
{name: "empty", setup: Empty, err: ErrValueIsNil},
|
||||
{name: "valid", setup: SomeWalks(0)},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
db, err := test.setup()
|
||||
if err != nil {
|
||||
t.Fatalf("setup failed: %v", err)
|
||||
}
|
||||
defer db.flushAll()
|
||||
|
||||
if err = db.validateWalks(); !errors.Is(err, test.err) {
|
||||
t.Fatalf("expected error %v, got %v", test.err, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWalksVisiting(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func() (RedisDB, error)
|
||||
limit int
|
||||
expectedWalks int // the number of [defaultWalk] returned
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
setup: SomeWalks(0),
|
||||
limit: 1,
|
||||
expectedWalks: 0,
|
||||
},
|
||||
{
|
||||
name: "all walks",
|
||||
setup: SomeWalks(10),
|
||||
limit: -1,
|
||||
expectedWalks: 10,
|
||||
},
|
||||
{
|
||||
name: "some walks",
|
||||
setup: SomeWalks(100),
|
||||
limit: 33,
|
||||
expectedWalks: 33,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
db, err := test.setup()
|
||||
if err != nil {
|
||||
t.Fatalf("setup failed: %v", err)
|
||||
}
|
||||
defer db.flushAll()
|
||||
|
||||
visiting, err := db.WalksVisiting(ctx, "0", test.limit)
|
||||
if err != nil {
|
||||
t.Fatalf("expected error nil, got %v", err)
|
||||
}
|
||||
|
||||
if len(visiting) != test.expectedWalks {
|
||||
t.Fatalf("expected %d walks, got %d", test.expectedWalks, len(visiting))
|
||||
}
|
||||
|
||||
for _, walk := range visiting {
|
||||
if !reflect.DeepEqual(walk.Path, defaultWalk.Path) {
|
||||
// compare only the paths, not the IDs
|
||||
t.Fatalf("expected walk %v, got %v", defaultWalk, walk)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddWalks(t *testing.T) {
|
||||
db, err := SomeWalks(1)()
|
||||
if err != nil {
|
||||
t.Fatalf("setup failed: %v", err)
|
||||
}
|
||||
defer db.flushAll()
|
||||
|
||||
walks := []walks.Walk{
|
||||
{ID: "1", Path: []graph.ID{"1", "2", "3"}},
|
||||
{ID: "2", Path: []graph.ID{"4", "5"}},
|
||||
{ID: "3", Path: []graph.ID{"a", "b", "c"}},
|
||||
}
|
||||
|
||||
if err := db.AddWalks(ctx, walks...); err != nil {
|
||||
t.Fatalf("expected error nil, got %v", err)
|
||||
}
|
||||
|
||||
stored, err := db.Walks(ctx, "1", "2", "3")
|
||||
if err != nil {
|
||||
t.Fatalf("expected error nil, got %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(stored, walks) {
|
||||
t.Fatalf("expected walks %v, got %v", walks, stored)
|
||||
}
|
||||
|
||||
total, err := db.TotalVisits(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("expected error nil, got %v", err)
|
||||
}
|
||||
|
||||
if total != 10 {
|
||||
t.Fatalf("expected total visits %d, got %d", 10, total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveWalks(t *testing.T) {
|
||||
db, err := SomeWalks(10)()
|
||||
if err != nil {
|
||||
t.Fatalf("setup failed: %v", err)
|
||||
}
|
||||
defer db.flushAll()
|
||||
|
||||
walks := []walks.Walk{
|
||||
{ID: "0", Path: defaultWalk.Path},
|
||||
{ID: "1", Path: defaultWalk.Path},
|
||||
}
|
||||
|
||||
if err := db.RemoveWalks(ctx, walks...); err != nil {
|
||||
t.Fatalf("expected error nil, got %v", err)
|
||||
}
|
||||
|
||||
total, err := db.TotalVisits(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("expected error nil, got %v", err)
|
||||
}
|
||||
|
||||
expected := (10 - 2) * defaultWalk.Len()
|
||||
if total != expected {
|
||||
t.Fatalf("expected total %d, got %d", expected, total)
|
||||
}
|
||||
|
||||
visits, err := db.Visits(ctx, "0")
|
||||
if err != nil {
|
||||
t.Fatalf("expected error nil, got %v", err)
|
||||
}
|
||||
|
||||
expected = (10 - 2)
|
||||
if visits[0] != expected {
|
||||
t.Fatalf("expected visits %d, got %d", expected, visits[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceWalks(t *testing.T) {
|
||||
t.Run("simple", func(t *testing.T) {
|
||||
db, err := SomeWalks(2)()
|
||||
if err != nil {
|
||||
t.Fatalf("setup failed: %v", err)
|
||||
}
|
||||
defer db.flushAll()
|
||||
|
||||
before := []walks.Walk{
|
||||
{ID: "0", Path: []graph.ID{"0", "1"}},
|
||||
{ID: "1", Path: []graph.ID{"0", "1"}},
|
||||
}
|
||||
|
||||
after := []walks.Walk{
|
||||
{ID: "0", Path: []graph.ID{"0", "2", "3"}}, // changed
|
||||
{ID: "1", Path: []graph.ID{"0", "1"}},
|
||||
}
|
||||
|
||||
if err := db.ReplaceWalks(ctx, before, after); err != nil {
|
||||
t.Fatalf("expected error nil, got %v", err)
|
||||
}
|
||||
|
||||
walks, err := db.Walks(ctx, "0", "1")
|
||||
if err != nil {
|
||||
t.Fatalf("expected error nil, got %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(walks, after) {
|
||||
t.Fatalf("expected walks %v, got %v", after, walks)
|
||||
}
|
||||
|
||||
expected := []int{2, 1, 1, 1}
|
||||
visits, err := db.Visits(ctx, "0", "1", "2", "3")
|
||||
if err != nil {
|
||||
t.Fatalf("expected error nil, got %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(visits, expected) {
|
||||
t.Fatalf("expected visits %v, got %v", expected, visits)
|
||||
}
|
||||
|
||||
total, err := db.TotalVisits(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("expected error nil, got %v", err)
|
||||
}
|
||||
|
||||
if total != 5 {
|
||||
t.Fatalf("expected total %d, got %d", 8, total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("mass removal", func(t *testing.T) {
|
||||
num := 10000
|
||||
db, err := SomeWalks(num)()
|
||||
if err != nil {
|
||||
t.Fatalf("setup failed: %v", err)
|
||||
}
|
||||
defer db.flushAll()
|
||||
|
||||
before := make([]walks.Walk, num)
|
||||
after := make([]walks.Walk, num)
|
||||
for i := range num {
|
||||
ID := walks.ID(strconv.Itoa(i))
|
||||
before[i] = walks.Walk{ID: ID, Path: defaultWalk.Path} // the walk in the DB
|
||||
after[i] = walks.Walk{ID: ID} // empty walk
|
||||
}
|
||||
|
||||
if err := db.ReplaceWalks(ctx, before, after); err != nil {
|
||||
t.Fatalf("expected error nil, got %v", err)
|
||||
}
|
||||
visits, err := db.Visits(ctx, "0", "1")
|
||||
if err != nil {
|
||||
t.Fatalf("expected error nil, got %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(visits, []int{0, 0}) {
|
||||
t.Fatalf("expected visits %v, got %v", []int{0, 0}, visits)
|
||||
}
|
||||
|
||||
total, err := db.TotalVisits(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("expected error nil, got %v", err)
|
||||
}
|
||||
|
||||
if total != 0 {
|
||||
t.Fatalf("expected total %d, got %d", 0, total)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidateReplacement(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
old []walks.Walk
|
||||
new []walks.Walk
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "no walks",
|
||||
},
|
||||
{
|
||||
name: "different lenght",
|
||||
old: []walks.Walk{{}},
|
||||
err: ErrInvalidReplacement,
|
||||
},
|
||||
{
|
||||
name: "different IDs",
|
||||
old: []walks.Walk{{ID: "0"}, {ID: "1"}},
|
||||
new: []walks.Walk{{ID: "1"}, {ID: "0"}},
|
||||
err: ErrInvalidReplacement,
|
||||
},
|
||||
{
|
||||
name: "repeated IDs",
|
||||
old: []walks.Walk{{ID: "0"}, {ID: "0"}},
|
||||
new: []walks.Walk{{ID: "0"}, {ID: "0"}},
|
||||
err: ErrInvalidReplacement,
|
||||
},
|
||||
{
|
||||
name: "valid IDs",
|
||||
old: []walks.Walk{{ID: "0"}, {ID: "1"}},
|
||||
new: []walks.Walk{{ID: "0"}, {ID: "1"}},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
err := validateReplacement(test.old, test.new)
|
||||
if !errors.Is(err, test.err) {
|
||||
t.Fatalf("expected error %v, got %v", test.err, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var defaultWalk = walks.Walk{Path: []graph.ID{"0", "1"}}
|
||||
|
||||
func SomeWalks(n int) func() (RedisDB, error) {
|
||||
return func() (RedisDB, error) {
|
||||
db := RedisDB{client: redis.NewClient(&redis.Options{Addr: testAddress})}
|
||||
if err := db.client.HSet(ctx, KeyRWS, KeyAlpha, walks.Alpha, KeyWalksPerNode, walks.N).Err(); err != nil {
|
||||
return RedisDB{}, err
|
||||
}
|
||||
|
||||
for range n {
|
||||
if err := db.AddWalks(ctx, defaultWalk); err != nil {
|
||||
return RedisDB{}, err
|
||||
}
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package walks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github/pippellia-btc/crawler/pkg/graph"
|
||||
"math/rand/v2"
|
||||
@@ -11,6 +12,8 @@ import (
|
||||
var (
|
||||
Alpha = 0.85 // the dampening factor
|
||||
N = 100 // the walks per node
|
||||
|
||||
ErrInvalidRemoval = errors.New(fmt.Sprintf("the walks to be removed are different than the expected number %d", N))
|
||||
)
|
||||
|
||||
// ID represent how walks are identified in the storage layer
|
||||
@@ -28,11 +31,6 @@ type Walker interface {
|
||||
Follows(ctx context.Context, node graph.ID) ([]graph.ID, error)
|
||||
}
|
||||
|
||||
// New returns a new walk with a preallocated empty path
|
||||
func New(n int) Walk {
|
||||
return Walk{Path: make([]graph.ID, 0, n)}
|
||||
}
|
||||
|
||||
// Len returns the lenght of the walk
|
||||
func (w Walk) Len() int {
|
||||
return len(w.Path)
|
||||
@@ -80,6 +78,23 @@ func (w *Walk) Graft(path []graph.ID) {
|
||||
w.Path = w.Path[:pos]
|
||||
}
|
||||
|
||||
// Divergence returns the first index where w1 and w2 are different, -1 if equal.
|
||||
func Divergence(w1, w2 Walk) int {
|
||||
min := min(w1.Len(), w2.Len())
|
||||
for i := range min {
|
||||
if w1.Path[i] != w2.Path[i] {
|
||||
return i
|
||||
}
|
||||
}
|
||||
|
||||
if w1.Len() == w2.Len() {
|
||||
// they are all equal, so no divergence
|
||||
return -1
|
||||
}
|
||||
|
||||
return min
|
||||
}
|
||||
|
||||
// Generate [N] random walks for the specified node, using dampening factor [Alpha].
|
||||
// A walk stops early if a cycle is encountered.
|
||||
// Walk IDs are not set, because it's the responsibility of the storage layer.
|
||||
@@ -146,19 +161,18 @@ func generate(ctx context.Context, walker Walker, start ...graph.ID) ([]graph.ID
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// ToRemove returns the IDs of walks that needs to be removed.
|
||||
// ToRemove returns the walks that need to be removed.
|
||||
// It returns an error if the number of walks to remove differs from the expected [N].
|
||||
func ToRemove(node graph.ID, walks []Walk) ([]ID, error) {
|
||||
toRemove := make([]ID, 0, N)
|
||||
|
||||
func ToRemove(node graph.ID, walks []Walk) ([]Walk, error) {
|
||||
toRemove := make([]Walk, 0, N)
|
||||
for _, walk := range walks {
|
||||
if walk.Index(node) != -1 {
|
||||
toRemove = append(toRemove, walk.ID)
|
||||
if walk.Index(node) == 0 {
|
||||
toRemove = append(toRemove, walk)
|
||||
}
|
||||
}
|
||||
|
||||
if len(toRemove) != N {
|
||||
return toRemove, fmt.Errorf("walks to be removed (%d) are different than expected (%d)", len(toRemove), N)
|
||||
return nil, fmt.Errorf("ToRemove: %w: %d", ErrInvalidRemoval, len(toRemove))
|
||||
}
|
||||
|
||||
return toRemove, nil
|
||||
|
||||
@@ -2,6 +2,7 @@ package walks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github/pippellia-btc/crawler/pkg/graph"
|
||||
"math"
|
||||
@@ -53,6 +54,53 @@ func TestGenerate(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestToRemove(t *testing.T) {
|
||||
N = 3
|
||||
tests := []struct {
|
||||
name string
|
||||
walks []Walk
|
||||
toRemove []Walk
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "no walks",
|
||||
err: ErrInvalidRemoval,
|
||||
},
|
||||
{
|
||||
name: "too few walks to remove",
|
||||
walks: []Walk{{Path: []graph.ID{"0", "1"}}},
|
||||
err: ErrInvalidRemoval,
|
||||
},
|
||||
{
|
||||
name: "valid",
|
||||
walks: []Walk{
|
||||
{Path: []graph.ID{"0", "1"}},
|
||||
{Path: []graph.ID{"0", "2"}},
|
||||
{Path: []graph.ID{"0", "3"}},
|
||||
{Path: []graph.ID{"1", "0"}},
|
||||
},
|
||||
toRemove: []Walk{
|
||||
{Path: []graph.ID{"0", "1"}},
|
||||
{Path: []graph.ID{"0", "2"}},
|
||||
{Path: []graph.ID{"0", "3"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
toRemove, err := ToRemove("0", test.walks)
|
||||
if !errors.Is(err, test.err) {
|
||||
t.Fatalf("expected error %v, got %v", test.err, err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(toRemove, test.toRemove) {
|
||||
t.Fatalf("expected walks to remove %v, got %v", test.toRemove, toRemove)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateRemove(t *testing.T) {
|
||||
walker := NewWalker(map[graph.ID][]graph.ID{
|
||||
"0": {"3"},
|
||||
@@ -85,6 +133,26 @@ func TestUpdateRemove(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDivergence(t *testing.T) {
|
||||
tests := []struct {
|
||||
w1 Walk
|
||||
w2 Walk
|
||||
expected int
|
||||
}{
|
||||
{w1: Walk{Path: []graph.ID{"0"}}, w2: Walk{Path: []graph.ID{"0", "1"}}, expected: 1},
|
||||
{w1: Walk{Path: []graph.ID{"0", "1", "69"}}, w2: Walk{Path: []graph.ID{"0", "1"}}, expected: 2},
|
||||
{w1: Walk{Path: []graph.ID{"0", "1", "69"}}, w2: Walk{Path: []graph.ID{"0", "1", "420"}}, expected: 2},
|
||||
{w1: Walk{Path: []graph.ID{"a", "b", "c"}}, w2: Walk{Path: []graph.ID{"a", "b", "c"}}, expected: -1},
|
||||
{expected: -1},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
if div := Divergence(test.w1, test.w2); div != test.expected {
|
||||
t.Fatalf("test %d: expected %d, got %v", i, test.expected, div)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindCycle(t *testing.T) {
|
||||
tests := []struct {
|
||||
list []graph.ID
|
||||
|
||||
Reference in New Issue
Block a user