code refactor for rws

This commit is contained in:
pippellia-btc
2025-05-28 16:55:38 +02:00
parent 85a2eebc95
commit 364432a161
7 changed files with 780 additions and 35 deletions

View File

@@ -15,20 +15,20 @@ import (
const (
// redis variable names
KeyDatabase string = "database" // TODO: this can be removed
KeyLastNodeID string = "lastNodeID" // TODO: change it to "next" inside "node" hash
KeyKeyIndex string = "keyIndex" // TODO: change to key_index
KeyNodePrefix string = "node:"
KeyFollowsPrefix string = "follows:"
KeyFollowersPrefix string = "followers:"
KeyDatabase = "database" // TODO: this can be removed
KeyLastNodeID = "lastNodeID" // TODO: change it to "next" inside "node" hash
KeyKeyIndex = "keyIndex" // TODO: change to key_index
KeyNodePrefix = "node:"
KeyFollowsPrefix = "follows:"
KeyFollowersPrefix = "followers:"
// redis node HASH fields
NodeID string = "id"
NodePubkey string = "pubkey"
NodeStatus string = "status"
NodePromotionTS string = "promotion_TS" // TODO: change to promotion
NodeDemotionTS string = "demotion_TS" // TODO: change to demotion
NodeAddedTS string = "added_TS" // TODO: change to addition
NodeID = "id"
NodePubkey = "pubkey"
NodeStatus = "status"
NodePromotionTS = "promotion_TS" // TODO: change to promotion
NodeDemotionTS = "demotion_TS" // TODO: change to demotion
NodeAddedTS = "added_TS" // TODO: change to addition
)
var (
@@ -41,7 +41,11 @@ type RedisDB struct {
}
func New(opt *redis.Options) RedisDB {
return RedisDB{client: redis.NewClient(opt)}
r := RedisDB{client: redis.NewClient(opt)}
if err := r.validateWalks(); err != nil {
panic(err)
}
return r
}
// Size returns the DBSize of redis, which is the total number of keys
@@ -178,7 +182,7 @@ func (r RedisDB) members(ctx context.Context, key func(graph.ID) string, node gr
}
}
return toIDs(members), nil
return toNodes(members), nil
}
// FollowCounts returns the number of follows each node has. If a node is not found, it returns 0.

View File

@@ -318,12 +318,12 @@ func TestPubkeys(t *testing.T) {
// ------------------------------------- HELPERS -------------------------------
func Empty() (RedisDB, error) {
return New(&redis.Options{Addr: testAddress}), nil
return RedisDB{client: redis.NewClient(&redis.Options{Addr: testAddress})}, nil
}
func OneNode() (RedisDB, error) {
db := New(&redis.Options{Addr: testAddress})
if _, err := db.AddNode(context.Background(), "0"); err != nil {
db := RedisDB{client: redis.NewClient(&redis.Options{Addr: testAddress})}
if _, err := db.AddNode(ctx, "0"); err != nil {
db.flushAll()
return RedisDB{}, err
}
@@ -332,9 +332,7 @@ func OneNode() (RedisDB, error) {
}
func Simple() (RedisDB, error) {
ctx := context.Background()
db := New(&redis.Options{Addr: testAddress})
db := RedisDB{client: redis.NewClient(&redis.Options{Addr: testAddress})}
for _, pk := range []string{"0", "1", "2"} {
if _, err := db.AddNode(ctx, pk); err != nil {
db.flushAll()

View File

@@ -2,13 +2,19 @@ package redb
import (
"context"
"errors"
"github/pippellia-btc/crawler/pkg/graph"
"github/pippellia-btc/crawler/pkg/walks"
"strconv"
"strings"
"time"
)
var (
testAddress = "localhost:6380"
ErrValueIsNil = errors.New("value is nil")
ErrValueIsNotString = errors.New("failed to convert to string")
)
// flushAll deletes all the keys of all existing databases. This command never fails.
@@ -28,8 +34,12 @@ func followers[ID string | graph.ID](id ID) string {
return KeyFollowersPrefix + string(id)
}
// ids converts a slice of strings to IDs
func toIDs(s []string) []graph.ID {
func walksVisiting[ID string | graph.ID](id ID) string {
return KeyWalksVisitingPrefix + string(id)
}
// toNodes converts a slice of strings to node IDs
func toNodes(s []string) []graph.ID {
IDs := make([]graph.ID, len(s))
for i, e := range s {
IDs[i] = graph.ID(e)
@@ -37,8 +47,17 @@ func toIDs(s []string) []graph.ID {
return IDs
}
// toWalks converts a slice of strings to walk IDs
func toWalks(s []string) []walks.ID {
IDs := make([]walks.ID, len(s))
for i, e := range s {
IDs[i] = walks.ID(e)
}
return IDs
}
// strings converts graph IDs to a slice of strings
func toStrings(ids []graph.ID) []string {
func toStrings[ID graph.ID | walks.ID](ids []ID) []string {
s := make([]string, len(ids))
for i, id := range ids {
s[i] = string(id)
@@ -98,3 +117,49 @@ func parseTimestamp(unix string) (time.Time, error) {
}
return time.Unix(ts, 0), nil
}
func formatWalk(walk walks.Walk) string {
nodes := make([]string, walk.Len())
for i, node := range walk.Path {
nodes[i] = string(node)
}
return strings.Join(nodes, ",")
}
func parseWalk(s string) walks.Walk {
nodes := strings.Split(s, ",")
walk := walks.Walk{Path: make([]graph.ID, len(nodes))}
for i, node := range nodes {
walk.Path[i] = graph.ID(node)
}
return walk
}
func parseString(v any) (string, error) {
if v == nil {
return "", ErrValueIsNil
}
str, ok := v.(string)
if !ok {
return "", ErrValueIsNotString
}
return str, nil
}
func parseFloat(v any) (float64, error) {
str, err := parseString(v)
if err != nil {
return 0, err
}
return strconv.ParseFloat(str, 64)
}
func parseInt(v any) (int, error) {
str, err := parseString(v)
if err != nil {
return 0, err
}
return strconv.Atoi(str)
}

View File

@@ -1 +1,280 @@
package redb
import (
"context"
"errors"
"fmt"
"github/pippellia-btc/crawler/pkg/graph"
"github/pippellia-btc/crawler/pkg/walks"
"math"
"strconv"
)
const (
KeyRWS = "RWS" // TODO: this can be removed
KeyAlpha = "alpha" // TODO: walks:alpha
KeyWalksPerNode = "walksPerNode" // TODO: walks:N or another
KeyLastWalkID = "lastWalkID" // TODO: walks:next
KeyTotalVisits = "totalVisits" // TODO: walks:total_visits
KeyWalks = "walks"
KeyWalksVisitingPrefix = "walksVisiting:" // TODO: walks_visiting:
)
var (
ErrWalkNotFound = errors.New("walk not found")
ErrInvalidReplacement = errors.New("invalid walk replacement")
ErrInvalidLimit = errors.New("limit must be a positive integer, or -1 to fetch all walks visiting node")
)
// Walks returns the walks associated with the IDs.
func (r RedisDB) Walks(ctx context.Context, IDs ...walks.ID) ([]walks.Walk, error) {
switch {
case len(IDs) == 0:
return nil, nil
case len(IDs) <= 100000:
vals, err := r.client.HMGet(ctx, KeyWalks, toStrings(IDs)...).Result()
if err != nil {
return nil, fmt.Errorf("failed to fetch walks: %w", err)
}
walks := make([]walks.Walk, len(vals))
for i, val := range vals {
if val == nil {
// walk was not found, so return an error
return nil, fmt.Errorf("failed to fetch walk with ID %s: %w", IDs[i], ErrWalkNotFound)
}
walks[i] = parseWalk(val.(string))
walks[i].ID = IDs[i]
}
return walks, nil
default:
// too many walks for a single call, so we split them in two batches
mid := len(IDs) / 2
batch1, err := r.Walks(ctx, IDs[:mid]...)
if err != nil {
return nil, err
}
batch2, err := r.Walks(ctx, IDs[mid:]...)
if err != nil {
return nil, err
}
return append(batch1, batch2...), nil
}
}
// WalksVisiting returns up-to limit walks that visit node.
// Use limit = -1 to fetch all the walks visiting node.
func (r RedisDB) WalksVisiting(ctx context.Context, node graph.ID, limit int) ([]walks.Walk, error) {
switch {
case limit == -1:
// return all walks visiting node
IDs, err := r.client.SMembers(ctx, walksVisiting(node)).Result()
if err != nil {
return nil, fmt.Errorf("failed to fetch %s: %w", walksVisiting(node), err)
}
return r.Walks(ctx, toWalks(IDs)...)
case limit > 0:
IDs, err := r.client.SRandMemberN(ctx, walksVisiting(node), int64(limit)).Result()
if err != nil {
return nil, fmt.Errorf("failed to fetch %s: %w", walksVisiting(node), err)
}
return r.Walks(ctx, toWalks(IDs)...)
default:
return nil, ErrInvalidLimit
}
}
// AddWalks adds all the walks to the database assigning them progressive IDs.
func (r RedisDB) AddWalks(ctx context.Context, walks ...walks.Walk) error {
if len(walks) == 0 {
return nil
}
// get the IDs outside the transaction, which implies there might be "holes",
// meaning IDs not associated with any walk
next, err := r.client.HIncrBy(ctx, KeyRWS, KeyLastWalkID, int64(len(walks))).Result()
if err != nil {
return fmt.Errorf("failed to add walks: failed to increment ID: %w", err)
}
var visits, ID int
pipe := r.client.TxPipeline()
for i, walk := range walks {
visits += walk.Len()
ID = int(next) - len(walks) + i // assigning IDs in the same order
pipe.HSet(ctx, KeyWalks, ID, formatWalk(walk))
for _, node := range walk.Path {
pipe.SAdd(ctx, walksVisiting(node), ID)
}
}
pipe.HIncrBy(ctx, KeyRWS, KeyTotalVisits, int64(visits))
if _, err = pipe.Exec(ctx); err != nil {
return fmt.Errorf("failed to add walks: pipeline failed %w", err)
}
return nil
}
// RemoveWalks removes all the walks from the database.
func (r RedisDB) RemoveWalks(ctx context.Context, walks ...walks.Walk) error {
if len(walks) == 0 {
return nil
}
var visits int
pipe := r.client.TxPipeline()
for _, walk := range walks {
pipe.HDel(ctx, KeyWalks, string(walk.ID))
for _, node := range walk.Path {
pipe.SRem(ctx, walksVisiting(node), string(walk.ID))
}
visits += walk.Len()
}
pipe.HIncrBy(ctx, KeyRWS, KeyTotalVisits, -int64(visits))
if _, err := pipe.Exec(ctx); err != nil {
return fmt.Errorf("failed to remove walks: pipeline failed %w", err)
}
return nil
}
func (r RedisDB) ReplaceWalks(ctx context.Context, before, after []walks.Walk) error {
if err := validateReplacement(before, after); err != nil {
return err
}
var visits int64
pipe := r.client.TxPipeline()
for i := range before {
div := walks.Divergence(before[i], after[i])
if div == -1 {
// the two walks are equal, skip
continue
}
prev := before[i]
next := after[i]
ID := string(after[i].ID)
pipe.HSet(ctx, KeyWalks, ID, formatWalk(next))
for _, node := range prev.Path[div:] {
pipe.SRem(ctx, walksVisiting(node), ID)
visits--
}
for _, node := range next.Path[div:] {
pipe.SAdd(ctx, walksVisiting(node), ID)
visits++
}
if pipe.Len() > 5000 {
// execute a partial update when it's too big
pipe.HIncrBy(ctx, KeyRWS, KeyTotalVisits, visits)
if _, err := pipe.Exec(ctx); err != nil {
return fmt.Errorf("failed to replace walks: pipeline failed %w", err)
}
pipe = r.client.TxPipeline()
visits = 0
}
}
pipe.HIncrBy(ctx, KeyRWS, KeyTotalVisits, visits)
if _, err := pipe.Exec(ctx); err != nil {
return fmt.Errorf("failed to replace walks: pipeline failed %w", err)
}
return nil
}
func validateReplacement(old, new []walks.Walk) error {
if len(old) != len(new) {
return fmt.Errorf("%w: old and new walks must have the same lenght", ErrInvalidReplacement)
}
seen := make(map[walks.ID]struct{})
for i := range old {
if old[i].ID != new[i].ID {
return fmt.Errorf("%w: IDs don't match at index %d: old=%s, new=%s", ErrInvalidReplacement, i, old[i].ID, new[i].ID)
}
if _, ok := seen[old[i].ID]; ok {
return fmt.Errorf("%w: repeated walk ID %s", ErrInvalidReplacement, old[i].ID)
}
seen[old[i].ID] = struct{}{}
}
return nil
}
// TotalVisits returns the total number of visits, which is the sum of the lengths of all walks.
func (r RedisDB) TotalVisits(ctx context.Context) (int, error) {
total, err := r.client.HGet(ctx, KeyRWS, KeyTotalVisits).Result()
if err != nil {
return -1, fmt.Errorf("failed to get the total number of visits: %w", err)
}
tot, err := strconv.Atoi(total)
if err != nil {
return -1, fmt.Errorf("failed to parse the total number of visits: %w", err)
}
return tot, nil
}
// Visits returns the number of times each specified node was visited during the walks.
// The returned slice contains counts in the same order as the input nodes.
// If a node is not found, it returns 0 visits.
func (r RedisDB) Visits(ctx context.Context, nodes ...graph.ID) ([]int, error) {
return r.counts(ctx, walksVisiting, nodes...)
}
func (r RedisDB) validateWalks() error {
vals, err := r.client.HMGet(context.Background(), KeyRWS, KeyAlpha, KeyWalksPerNode).Result()
if err != nil {
return fmt.Errorf("failed to fetch alpha and walksPerNode %w", err)
}
alpha, err := parseFloat(vals[0])
if err != nil {
return fmt.Errorf("failed to parse alpha: %w", err)
}
N, err := parseInt(vals[1])
if err != nil {
return fmt.Errorf("failed to parse walksPerNode: %w", err)
}
if math.Abs(alpha-walks.Alpha) > 1e-10 {
return errors.New("alpha and walks.Alpha are different")
}
if N != walks.N {
return errors.New("N and walks.N are different")
}
return nil
}

317
pkg/redb/walks_test.go Normal file
View File

@@ -0,0 +1,317 @@
package redb
import (
"errors"
"github/pippellia-btc/crawler/pkg/graph"
"github/pippellia-btc/crawler/pkg/walks"
"reflect"
"strconv"
"testing"
"github.com/redis/go-redis/v9"
)
func TestValidate(t *testing.T) {
tests := []struct {
name string
setup func() (RedisDB, error)
err error
}{
{name: "empty", setup: Empty, err: ErrValueIsNil},
{name: "valid", setup: SomeWalks(0)},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
db, err := test.setup()
if err != nil {
t.Fatalf("setup failed: %v", err)
}
defer db.flushAll()
if err = db.validateWalks(); !errors.Is(err, test.err) {
t.Fatalf("expected error %v, got %v", test.err, err)
}
})
}
}
func TestWalksVisiting(t *testing.T) {
tests := []struct {
name string
setup func() (RedisDB, error)
limit int
expectedWalks int // the number of [defaultWalk] returned
}{
{
name: "empty",
setup: SomeWalks(0),
limit: 1,
expectedWalks: 0,
},
{
name: "all walks",
setup: SomeWalks(10),
limit: -1,
expectedWalks: 10,
},
{
name: "some walks",
setup: SomeWalks(100),
limit: 33,
expectedWalks: 33,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
db, err := test.setup()
if err != nil {
t.Fatalf("setup failed: %v", err)
}
defer db.flushAll()
visiting, err := db.WalksVisiting(ctx, "0", test.limit)
if err != nil {
t.Fatalf("expected error nil, got %v", err)
}
if len(visiting) != test.expectedWalks {
t.Fatalf("expected %d walks, got %d", test.expectedWalks, len(visiting))
}
for _, walk := range visiting {
if !reflect.DeepEqual(walk.Path, defaultWalk.Path) {
// compare only the paths, not the IDs
t.Fatalf("expected walk %v, got %v", defaultWalk, walk)
}
}
})
}
}
func TestAddWalks(t *testing.T) {
db, err := SomeWalks(1)()
if err != nil {
t.Fatalf("setup failed: %v", err)
}
defer db.flushAll()
walks := []walks.Walk{
{ID: "1", Path: []graph.ID{"1", "2", "3"}},
{ID: "2", Path: []graph.ID{"4", "5"}},
{ID: "3", Path: []graph.ID{"a", "b", "c"}},
}
if err := db.AddWalks(ctx, walks...); err != nil {
t.Fatalf("expected error nil, got %v", err)
}
stored, err := db.Walks(ctx, "1", "2", "3")
if err != nil {
t.Fatalf("expected error nil, got %v", err)
}
if !reflect.DeepEqual(stored, walks) {
t.Fatalf("expected walks %v, got %v", walks, stored)
}
total, err := db.TotalVisits(ctx)
if err != nil {
t.Fatalf("expected error nil, got %v", err)
}
if total != 10 {
t.Fatalf("expected total visits %d, got %d", 10, total)
}
}
func TestRemoveWalks(t *testing.T) {
db, err := SomeWalks(10)()
if err != nil {
t.Fatalf("setup failed: %v", err)
}
defer db.flushAll()
walks := []walks.Walk{
{ID: "0", Path: defaultWalk.Path},
{ID: "1", Path: defaultWalk.Path},
}
if err := db.RemoveWalks(ctx, walks...); err != nil {
t.Fatalf("expected error nil, got %v", err)
}
total, err := db.TotalVisits(ctx)
if err != nil {
t.Fatalf("expected error nil, got %v", err)
}
expected := (10 - 2) * defaultWalk.Len()
if total != expected {
t.Fatalf("expected total %d, got %d", expected, total)
}
visits, err := db.Visits(ctx, "0")
if err != nil {
t.Fatalf("expected error nil, got %v", err)
}
expected = (10 - 2)
if visits[0] != expected {
t.Fatalf("expected visits %d, got %d", expected, visits[0])
}
}
func TestReplaceWalks(t *testing.T) {
t.Run("simple", func(t *testing.T) {
db, err := SomeWalks(2)()
if err != nil {
t.Fatalf("setup failed: %v", err)
}
defer db.flushAll()
before := []walks.Walk{
{ID: "0", Path: []graph.ID{"0", "1"}},
{ID: "1", Path: []graph.ID{"0", "1"}},
}
after := []walks.Walk{
{ID: "0", Path: []graph.ID{"0", "2", "3"}}, // changed
{ID: "1", Path: []graph.ID{"0", "1"}},
}
if err := db.ReplaceWalks(ctx, before, after); err != nil {
t.Fatalf("expected error nil, got %v", err)
}
walks, err := db.Walks(ctx, "0", "1")
if err != nil {
t.Fatalf("expected error nil, got %v", err)
}
if !reflect.DeepEqual(walks, after) {
t.Fatalf("expected walks %v, got %v", after, walks)
}
expected := []int{2, 1, 1, 1}
visits, err := db.Visits(ctx, "0", "1", "2", "3")
if err != nil {
t.Fatalf("expected error nil, got %v", err)
}
if !reflect.DeepEqual(visits, expected) {
t.Fatalf("expected visits %v, got %v", expected, visits)
}
total, err := db.TotalVisits(ctx)
if err != nil {
t.Fatalf("expected error nil, got %v", err)
}
if total != 5 {
t.Fatalf("expected total %d, got %d", 8, total)
}
})
t.Run("mass removal", func(t *testing.T) {
num := 10000
db, err := SomeWalks(num)()
if err != nil {
t.Fatalf("setup failed: %v", err)
}
defer db.flushAll()
before := make([]walks.Walk, num)
after := make([]walks.Walk, num)
for i := range num {
ID := walks.ID(strconv.Itoa(i))
before[i] = walks.Walk{ID: ID, Path: defaultWalk.Path} // the walk in the DB
after[i] = walks.Walk{ID: ID} // empty walk
}
if err := db.ReplaceWalks(ctx, before, after); err != nil {
t.Fatalf("expected error nil, got %v", err)
}
visits, err := db.Visits(ctx, "0", "1")
if err != nil {
t.Fatalf("expected error nil, got %v", err)
}
if !reflect.DeepEqual(visits, []int{0, 0}) {
t.Fatalf("expected visits %v, got %v", []int{0, 0}, visits)
}
total, err := db.TotalVisits(ctx)
if err != nil {
t.Fatalf("expected error nil, got %v", err)
}
if total != 0 {
t.Fatalf("expected total %d, got %d", 0, total)
}
})
}
func TestValidateReplacement(t *testing.T) {
tests := []struct {
name string
old []walks.Walk
new []walks.Walk
err error
}{
{
name: "no walks",
},
{
name: "different lenght",
old: []walks.Walk{{}},
err: ErrInvalidReplacement,
},
{
name: "different IDs",
old: []walks.Walk{{ID: "0"}, {ID: "1"}},
new: []walks.Walk{{ID: "1"}, {ID: "0"}},
err: ErrInvalidReplacement,
},
{
name: "repeated IDs",
old: []walks.Walk{{ID: "0"}, {ID: "0"}},
new: []walks.Walk{{ID: "0"}, {ID: "0"}},
err: ErrInvalidReplacement,
},
{
name: "valid IDs",
old: []walks.Walk{{ID: "0"}, {ID: "1"}},
new: []walks.Walk{{ID: "0"}, {ID: "1"}},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
err := validateReplacement(test.old, test.new)
if !errors.Is(err, test.err) {
t.Fatalf("expected error %v, got %v", test.err, err)
}
})
}
}
var defaultWalk = walks.Walk{Path: []graph.ID{"0", "1"}}
func SomeWalks(n int) func() (RedisDB, error) {
return func() (RedisDB, error) {
db := RedisDB{client: redis.NewClient(&redis.Options{Addr: testAddress})}
if err := db.client.HSet(ctx, KeyRWS, KeyAlpha, walks.Alpha, KeyWalksPerNode, walks.N).Err(); err != nil {
return RedisDB{}, err
}
for range n {
if err := db.AddWalks(ctx, defaultWalk); err != nil {
return RedisDB{}, err
}
}
return db, nil
}
}

View File

@@ -2,6 +2,7 @@ package walks
import (
"context"
"errors"
"fmt"
"github/pippellia-btc/crawler/pkg/graph"
"math/rand/v2"
@@ -11,6 +12,8 @@ import (
var (
Alpha = 0.85 // the dampening factor
N = 100 // the walks per node
ErrInvalidRemoval = errors.New(fmt.Sprintf("the walks to be removed are different than the expected number %d", N))
)
// ID represent how walks are identified in the storage layer
@@ -28,11 +31,6 @@ type Walker interface {
Follows(ctx context.Context, node graph.ID) ([]graph.ID, error)
}
// New returns a new walk with a preallocated empty path
func New(n int) Walk {
return Walk{Path: make([]graph.ID, 0, n)}
}
// Len returns the lenght of the walk
func (w Walk) Len() int {
return len(w.Path)
@@ -80,6 +78,23 @@ func (w *Walk) Graft(path []graph.ID) {
w.Path = w.Path[:pos]
}
// Divergence returns the first index where w1 and w2 are different, -1 if equal.
func Divergence(w1, w2 Walk) int {
min := min(w1.Len(), w2.Len())
for i := range min {
if w1.Path[i] != w2.Path[i] {
return i
}
}
if w1.Len() == w2.Len() {
// they are all equal, so no divergence
return -1
}
return min
}
// Generate [N] random walks for the specified node, using dampening factor [Alpha].
// A walk stops early if a cycle is encountered.
// Walk IDs are not set, because it's the responsibility of the storage layer.
@@ -146,19 +161,18 @@ func generate(ctx context.Context, walker Walker, start ...graph.ID) ([]graph.ID
return path, nil
}
// ToRemove returns the IDs of walks that needs to be removed.
// ToRemove returns the walks that need to be removed.
// It returns an error if the number of walks to remove differs from the expected [N].
func ToRemove(node graph.ID, walks []Walk) ([]ID, error) {
toRemove := make([]ID, 0, N)
func ToRemove(node graph.ID, walks []Walk) ([]Walk, error) {
toRemove := make([]Walk, 0, N)
for _, walk := range walks {
if walk.Index(node) != -1 {
toRemove = append(toRemove, walk.ID)
if walk.Index(node) == 0 {
toRemove = append(toRemove, walk)
}
}
if len(toRemove) != N {
return toRemove, fmt.Errorf("walks to be removed (%d) are different than expected (%d)", len(toRemove), N)
return nil, fmt.Errorf("ToRemove: %w: %d", ErrInvalidRemoval, len(toRemove))
}
return toRemove, nil

View File

@@ -2,6 +2,7 @@ package walks
import (
"context"
"errors"
"fmt"
"github/pippellia-btc/crawler/pkg/graph"
"math"
@@ -53,6 +54,53 @@ func TestGenerate(t *testing.T) {
})
}
func TestToRemove(t *testing.T) {
N = 3
tests := []struct {
name string
walks []Walk
toRemove []Walk
err error
}{
{
name: "no walks",
err: ErrInvalidRemoval,
},
{
name: "too few walks to remove",
walks: []Walk{{Path: []graph.ID{"0", "1"}}},
err: ErrInvalidRemoval,
},
{
name: "valid",
walks: []Walk{
{Path: []graph.ID{"0", "1"}},
{Path: []graph.ID{"0", "2"}},
{Path: []graph.ID{"0", "3"}},
{Path: []graph.ID{"1", "0"}},
},
toRemove: []Walk{
{Path: []graph.ID{"0", "1"}},
{Path: []graph.ID{"0", "2"}},
{Path: []graph.ID{"0", "3"}},
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
toRemove, err := ToRemove("0", test.walks)
if !errors.Is(err, test.err) {
t.Fatalf("expected error %v, got %v", test.err, err)
}
if !reflect.DeepEqual(toRemove, test.toRemove) {
t.Fatalf("expected walks to remove %v, got %v", test.toRemove, toRemove)
}
})
}
}
func TestUpdateRemove(t *testing.T) {
walker := NewWalker(map[graph.ID][]graph.ID{
"0": {"3"},
@@ -85,6 +133,26 @@ func TestUpdateRemove(t *testing.T) {
}
}
func TestDivergence(t *testing.T) {
tests := []struct {
w1 Walk
w2 Walk
expected int
}{
{w1: Walk{Path: []graph.ID{"0"}}, w2: Walk{Path: []graph.ID{"0", "1"}}, expected: 1},
{w1: Walk{Path: []graph.ID{"0", "1", "69"}}, w2: Walk{Path: []graph.ID{"0", "1"}}, expected: 2},
{w1: Walk{Path: []graph.ID{"0", "1", "69"}}, w2: Walk{Path: []graph.ID{"0", "1", "420"}}, expected: 2},
{w1: Walk{Path: []graph.ID{"a", "b", "c"}}, w2: Walk{Path: []graph.ID{"a", "b", "c"}}, expected: -1},
{expected: -1},
}
for i, test := range tests {
if div := Divergence(test.w1, test.w2); div != test.expected {
t.Fatalf("test %d: expected %d, got %v", i, test.expected, div)
}
}
}
func TestFindCycle(t *testing.T) {
tests := []struct {
list []graph.ID