mirror of https://github.com/aljazceru/crawler_v2.git
synced 2025-12-17 07:24:21 +01:00
processor is here
@@ -4,6 +4,9 @@ package graph
 import (
 	"errors"
+	"math/rand/v2"
+	"slices"
+	"strconv"
 	"time"
 )
@@ -64,6 +67,46 @@ type Delta struct {
 	Add []ID
 }
 
+// NewDelta returns a delta by computing the relationships to remove, keep and add.
+// Time complexity O(n * logn + m * logm), where n and m are the lengths of the slices.
+// This function is much faster than converting to sets for sizes (n, m) smaller than ~10^6.
+func NewDelta(kind int, node ID, old, new []ID) Delta {
+	delta := Delta{
+		Kind: kind,
+		Node: node,
+	}
+
+	slices.Sort(old)
+	slices.Sort(new)
+	i, j := 0, 0
+	oldLen, newLen := len(old), len(new)
+
+	for i < oldLen && j < newLen {
+		switch {
+		case old[i] < new[j]:
+			// ID is in old but not in new => remove
+			delta.Remove = append(delta.Remove, old[i])
+			i++
+
+		case old[i] > new[j]:
+			// ID is in new but not in old => add
+			delta.Add = append(delta.Add, new[j])
+			j++
+
+		default:
+			// ID is in both => keep
+			delta.Keep = append(delta.Keep, old[i])
+			i++
+			j++
+		}
+	}
+
+	// add all elements not traversed
+	delta.Remove = append(delta.Remove, old[i:]...)
+	delta.Add = append(delta.Add, new[j:]...)
+	return delta
+}
+
 // Size returns the number of relationships changed by delta
 func (d Delta) Size() int {
 	return len(d.Remove) + len(d.Add)
@@ -89,3 +132,13 @@ func (d Delta) Inverse() Delta {
 		Add: d.Remove,
 	}
 }
+
+// RandomIDs of the provided size.
+func RandomIDs(size int) []ID {
+	IDs := make([]ID, size)
+	for i := range size {
+		node := rand.IntN(10000000)
+		IDs[i] = ID(strconv.Itoa(node))
+	}
+	return IDs
+}
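For illustration, a minimal usage sketch of NewDelta (not part of this commit; the kind value 3 — nostr's follow-list kind — and the literal IDs are made up, and the import path is the one used in pkg/pipe/processor.go below):

package main

import (
	"fmt"

	"github/pippellia-btc/crawler/pkg/graph"
)

func main() {
	// old and new follow lists of node "0"
	old := []graph.ID{"1", "2", "3"}
	new := []graph.ID{"2", "3", "4"}

	delta := graph.NewDelta(3, "0", old, new)

	fmt.Println(delta.Remove) // [1]   in old but not in new
	fmt.Println(delta.Keep)   // [2 3] in both
	fmt.Println(delta.Add)    // [4]   in new but not in old
}

Note that NewDelta sorts both input slices in place, so callers should not rely on their original order afterwards.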
118	pkg/graph/graph_test.go	Normal file
@@ -0,0 +1,118 @@
package graph

import (
	"fmt"
	"reflect"
	"testing"
)

func TestNewDelta(t *testing.T) {
	testCases := []struct {
		name     string
		old      []ID
		new      []ID
		expected Delta
	}{
		{
			name:     "nil slices",
			expected: Delta{Kind: 3, Node: "0"},
		},
		{
			name:     "empty slices",
			expected: Delta{Kind: 3, Node: "0"},
		},
		{
			name:     "only removals",
			old:      []ID{"0", "1", "2", "19", "111"},
			new:      []ID{"2", "19"},
			expected: Delta{Kind: 3, Node: "0", Remove: []ID{"0", "1", "111"}, Keep: []ID{"19", "2"}},
		},
		{
			name:     "only additions",
			old:      []ID{"0", "1"},
			new:      []ID{"420", "0", "1", "69"},
			expected: Delta{Kind: 3, Node: "0", Keep: []ID{"0", "1"}, Add: []ID{"420", "69"}},
		},
		{
			name:     "both additions and removals",
			old:      []ID{"0", "1", "111"},
			new:      []ID{"420", "0", "1", "69"},
			expected: Delta{Kind: 3, Node: "0", Remove: []ID{"111"}, Keep: []ID{"0", "1"}, Add: []ID{"420", "69"}},
		},
	}

	for _, test := range testCases {
		t.Run(test.name, func(t *testing.T) {
			delta := NewDelta(3, "0", test.old, test.new)
			if !reflect.DeepEqual(delta, test.expected) {
				t.Errorf("expected delta %v, got %v", test.expected, delta)
			}
		})
	}
}

func BenchmarkNewDelta(b *testing.B) {
	sizes := []int{1000, 10000, 100000}
	for _, size := range sizes {
		b.Run(fmt.Sprintf("size=%d", size), func(b *testing.B) {
			old := RandomIDs(size)
			new := RandomIDs(size)

			b.ResetTimer()
			for range b.N {
				NewDelta(3, "0", old, new)
			}
		})
	}
}

func BenchmarkNewDeltaSets(b *testing.B) {
	sizes := []int{1000, 10000, 100000}
	for _, size := range sizes {
		b.Run(fmt.Sprintf("size=%d", size), func(b *testing.B) {
			old := RandomIDs(size)
			new := RandomIDs(size)

			b.ResetTimer()
			for range b.N {
				newDeltaSet(3, "0", old, new)
			}
		})
	}
}

func newDeltaSet(kind int, node ID, old, new []ID) Delta {
	delta := Delta{
		Kind: kind,
		Node: node,
	}

	oldMap := make(map[ID]struct{}, len(old))
	newMap := make(map[ID]struct{}, len(new))

	// fill the maps
	for _, id := range old {
		oldMap[id] = struct{}{}
	}
	for _, id := range new {
		newMap[id] = struct{}{}
	}

	// find removed and kept
	for _, id := range old {
		if _, found := newMap[id]; found {
			delta.Keep = append(delta.Keep, id)
		} else {
			delta.Remove = append(delta.Remove, id)
		}
	}

	// find added
	for _, id := range new {
		if _, found := oldMap[id]; !found {
			delta.Add = append(delta.Add, id)
		}
	}

	return delta
}
187	pkg/pipe/processor.go	Normal file
@@ -0,0 +1,187 @@
package pipe

import (
	"cmp"
	"context"
	"errors"
	"fmt"
	"github/pippellia-btc/crawler/pkg/graph"
	"github/pippellia-btc/crawler/pkg/redb"
	"github/pippellia-btc/crawler/pkg/walks"
	"log"
	"slices"
	"time"

	"github.com/nbd-wtf/go-nostr"
)

var ErrUnsupportedKind = errors.New("unsupported event kind")

type ProcessorConfig struct {
	PrintEvery int
}

func NewProcessEventsConfig() ProcessorConfig {
	return ProcessorConfig{PrintEvery: 5000}
}

func (c ProcessorConfig) Print() {
	fmt.Printf("Processor\n")
	fmt.Printf(" PrintEvery: %d\n", c.PrintEvery)
}

// Processor processes one event at a time from the events channel, based on its kind.
func Processor(
	ctx context.Context,
	config ProcessorConfig,
	db redb.RedisDB,
	//store *eventstore.Store,
	events chan *nostr.Event) {

	var err error
	var processed int

	cache := walks.NewWalker(
		walks.WithCapacity(10000),
		walks.WithFallback(db),
	)

	for {
		select {
		case <-ctx.Done():
			log.Println("Processor: shutting down...")
			return

		case event := <-events:
			switch event.Kind {
			case nostr.KindFollowList:
				err = processFollowList(cache, db, event)

			case nostr.KindProfileMetadata:
				err = nil //HandleProfileMetadata(eventStore, event)

			default:
				err = ErrUnsupportedKind
			}

			if err != nil {
				log.Printf("Processor: event ID %s, kind %d by %s: %v", event.ID, event.Kind, event.PubKey, err)
			}

			processed++
			if processed%config.PrintEvery == 0 {
				log.Printf("Processor: processed %d events", processed)
			}
		}
	}
}

func processFollowList(cache *walks.CachedWalker, db redb.RedisDB, event *nostr.Event) error {
	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
	defer cancel()

	author, err := db.NodeByKey(ctx, event.PubKey)
	if err != nil {
		return err
	}

	oldFollows, err := cache.Follows(ctx, author.ID)
	if err != nil {
		return err
	}

	pubkeys := parsePubkeys(event)
	onMissing := redb.Ignore
	if author.Status == graph.StatusActive {
		// active nodes are the only ones that can add new pubkeys to the database
		onMissing = redb.AddValid
	}

	newFollows, err := db.Resolve(ctx, pubkeys, onMissing)
	if err != nil {
		return err
	}

	delta := graph.NewDelta(event.Kind, author.ID, oldFollows, newFollows)
	if delta.Size() == 0 {
		// old and new follows are the same, stop
		return nil
	}

	visiting, err := db.WalksVisiting(ctx, author.ID, -1)
	if err != nil {
		return err
	}

	old, new, err := walks.ToUpdate(ctx, cache, delta, visiting)
	if err != nil {
		return err
	}

	if err := db.ReplaceWalks(ctx, old, new); err != nil {
		return err
	}

	if err := db.Update(ctx, delta); err != nil {
		return err
	}

	return cache.Update(ctx, delta)
}

const (
	followPrefix = "p"
	maxFollows   = 50000
)

// parsePubkeys returns the slice of pubkeys that are correctly listed in the nostr.Tags.
// - Badly formatted tags are ignored.
// - Pubkeys are added uniquely (no repetitions).
// - The author of the event is removed from the followed pubkeys, if present.
// - The validity of the pubkeys is NOT checked.
func parsePubkeys(event *nostr.Event) []string {
	pubkeys := make([]string, 0, min(len(event.Tags), maxFollows))
	for _, tag := range event.Tags {
		if len(pubkeys) > maxFollows {
			// stop processing, the list is too big
			break
		}

		if len(tag) < 2 {
			continue
		}

		prefix, pubkey := tag[0], tag[1]
		if prefix != followPrefix {
			continue
		}

		if pubkey == event.PubKey {
			// remove self-follows
			continue
		}

		pubkeys = append(pubkeys, pubkey)
	}

	return unique(pubkeys)
}

// unique returns a slice of the unique elements of the input slice.
func unique[E cmp.Ordered](slice []E) []E {
	if len(slice) == 0 {
		return nil
	}

	slices.Sort(slice)
	unique := make([]E, 0, len(slice))
	unique = append(unique, slice[0])

	for i := 1; i < len(slice); i++ {
		if slice[i] != slice[i-1] {
			unique = append(unique, slice[i])
		}
	}

	return unique
}
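A rough wiring sketch for the new Processor (not part of this commit): the go-redis v9 import path, the Redis address, and the idea that some relay reader fills the events channel are all assumptions made for the example.

package main

import (
	"context"
	"os/signal"
	"syscall"

	"github/pippellia-btc/crawler/pkg/pipe"
	"github/pippellia-btc/crawler/pkg/redb"

	"github.com/nbd-wtf/go-nostr"
	"github.com/redis/go-redis/v9"
)

func main() {
	// cancel the context on SIGINT/SIGTERM so the Processor can shut down cleanly
	ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
	defer stop()

	db := redb.New(&redis.Options{Addr: "localhost:6379"})
	events := make(chan *nostr.Event, 1000) // fed elsewhere, e.g. by a relay firehose

	config := pipe.NewProcessEventsConfig()
	config.Print()

	// Processor blocks until ctx is cancelled (the ctx.Done() case in its select loop)
	pipe.Processor(ctx, config, db, events)
}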
@@ -38,16 +38,16 @@ type RedisDB struct {
 }
 
 func New(opt *redis.Options) RedisDB {
-	r := RedisDB{client: redis.NewClient(opt)}
-	if err := r.validateWalks(); err != nil {
+	db := RedisDB{client: redis.NewClient(opt)}
+	if err := db.validateWalks(); err != nil {
 		panic(err)
 	}
-	return r
+	return db
 }
 
 // Size returns the DBSize of redis, which is the total number of keys
-func (r RedisDB) Size(ctx context.Context) (int, error) {
-	size, err := r.client.DBSize(ctx).Result()
+func (db RedisDB) Size(ctx context.Context) (int, error) {
+	size, err := db.client.DBSize(ctx).Result()
 	if err != nil {
 		return 0, err
 	}
@@ -55,8 +55,8 @@ func (r RedisDB) Size(ctx context.Context) (int, error) {
 }
 
 // NodeCount returns the number of nodes stored in redis (in the keyIndex)
-func (r RedisDB) NodeCount(ctx context.Context) (int, error) {
-	nodes, err := r.client.HLen(ctx, KeyKeyIndex).Result()
+func (db RedisDB) NodeCount(ctx context.Context) (int, error) {
+	nodes, err := db.client.HLen(ctx, KeyKeyIndex).Result()
 	if err != nil {
 		return 0, err
 	}
@@ -64,12 +64,12 @@ func (r RedisDB) NodeCount(ctx context.Context) (int, error) {
 }
 
 // Nodes fetches a slice of nodes by their IDs.
-func (r RedisDB) Nodes(ctx context.Context, IDs ...graph.ID) ([]*graph.Node, error) {
+func (db RedisDB) Nodes(ctx context.Context, IDs ...graph.ID) ([]*graph.Node, error) {
 	if len(IDs) == 0 {
 		return nil, nil
 	}
 
-	pipe := r.client.Pipeline()
+	pipe := db.client.Pipeline()
 	cmds := make([]*redis.MapStringStringCmd, len(IDs))
 	for i, ID := range IDs {
 		cmds[i] = pipe.HGetAll(ctx, node(ID))
@@ -97,8 +97,8 @@ func (r RedisDB) Nodes(ctx context.Context, IDs ...graph.ID) ([]*graph.Node, err
 }
 
 // NodeByID fetches a node by its ID
-func (r RedisDB) NodeByID(ctx context.Context, ID graph.ID) (*graph.Node, error) {
-	fields, err := r.client.HGetAll(ctx, node(ID)).Result()
+func (db RedisDB) NodeByID(ctx context.Context, ID graph.ID) (*graph.Node, error) {
+	fields, err := db.client.HGetAll(ctx, node(ID)).Result()
 	if err != nil {
 		return nil, fmt.Errorf("failed to fetch %s: %w", node(ID), err)
 	}
@@ -111,13 +111,13 @@ func (r RedisDB) NodeByID(ctx context.Context, ID graph.ID) (*graph.Node, error)
 }
 
 // NodeByKey fetches a node by its pubkey
-func (r RedisDB) NodeByKey(ctx context.Context, pubkey string) (*graph.Node, error) {
-	ID, err := r.client.HGet(ctx, KeyKeyIndex, pubkey).Result()
+func (db RedisDB) NodeByKey(ctx context.Context, pubkey string) (*graph.Node, error) {
+	ID, err := db.client.HGet(ctx, KeyKeyIndex, pubkey).Result()
 	if err != nil {
 		return nil, fmt.Errorf("failed to fetch ID of node with pubkey %s: %w", pubkey, err)
 	}
 
-	fields, err := r.client.HGetAll(ctx, node(ID)).Result()
+	fields, err := db.client.HGetAll(ctx, node(ID)).Result()
 	if err != nil {
 		return nil, fmt.Errorf("failed to fetch node with pubkey %s: %w", pubkey, err)
 	}
@@ -130,15 +130,15 @@ func (r RedisDB) NodeByKey(ctx context.Context, pubkey string) (*graph.Node, err
 }
 
 // Exists checks for the existence of the pubkey
-func (r RedisDB) Exists(ctx context.Context, pubkey string) (bool, error) {
-	exists, err := r.client.HExists(ctx, KeyKeyIndex, pubkey).Result()
+func (db RedisDB) Exists(ctx context.Context, pubkey string) (bool, error) {
+	exists, err := db.client.HExists(ctx, KeyKeyIndex, pubkey).Result()
 	if err != nil {
 		return false, fmt.Errorf("failed to check existence of pubkey %s: %w", pubkey, err)
 	}
 	return exists, nil
 }
 
-func (r RedisDB) ensureExists(ctx context.Context, IDs ...graph.ID) error {
+func (db RedisDB) ensureExists(ctx context.Context, IDs ...graph.ID) error {
 	if len(IDs) == 0 {
 		return nil
 	}
@@ -148,7 +148,7 @@ func (r RedisDB) ensureExists(ctx context.Context, IDs ...graph.ID) error {
 		nodes[i] = node(ID)
 	}
 
-	exists, err := r.client.Exists(ctx, nodes...).Result()
+	exists, err := db.client.Exists(ctx, nodes...).Result()
 	if err != nil {
 		return fmt.Errorf("failed to check for the existence of %d nodes: %w", len(IDs), err)
 	}
@@ -161,8 +161,8 @@ func (r RedisDB) ensureExists(ctx context.Context, IDs ...graph.ID) error {
 }
 
 // AddNode adds a new inactive node to the database and returns its assigned ID
-func (r RedisDB) AddNode(ctx context.Context, pubkey string) (graph.ID, error) {
-	exists, err := r.client.HExists(ctx, KeyKeyIndex, pubkey).Result()
+func (db RedisDB) AddNode(ctx context.Context, pubkey string) (graph.ID, error) {
+	exists, err := db.client.HExists(ctx, KeyKeyIndex, pubkey).Result()
 	if err != nil {
 		return "", fmt.Errorf("failed to check for existence of pubkey %s: %w", pubkey, err)
 	}
@@ -173,13 +173,13 @@ func (r RedisDB) AddNode(ctx context.Context, pubkey string) (graph.ID, error) {
 
 	// get the ID outside the transaction, which implies there might be "holes",
 	// meaning IDs not associated with any node
-	next, err := r.client.HIncrBy(ctx, KeyDatabase, KeyLastNodeID, 1).Result()
+	next, err := db.client.HIncrBy(ctx, KeyDatabase, KeyLastNodeID, 1).Result()
 	if err != nil {
 		return "", fmt.Errorf("failed to add node with pubkey %s: failed to increment ID", pubkey)
 	}
 	ID := strconv.FormatInt(next-1, 10)
 
-	pipe := r.client.TxPipeline()
+	pipe := db.client.TxPipeline()
 	pipe.HSetNX(ctx, KeyKeyIndex, pubkey, ID)
 	pipe.HSet(ctx, node(ID), NodeID, ID, NodePubkey, pubkey, NodeStatus, graph.StatusInactive, NodeAddedTS, time.Now().Unix())
 	if _, err := pipe.Exec(ctx); err != nil {
@@ -190,8 +190,8 @@ func (r RedisDB) AddNode(ctx context.Context, pubkey string) (graph.ID, error) {
 }
 
 // Promote changes the node status to active
-func (r RedisDB) Promote(ctx context.Context, ID graph.ID) error {
-	err := r.client.HSet(ctx, node(ID), NodeStatus, graph.StatusActive, NodePromotionTS, time.Now().Unix()).Err()
+func (db RedisDB) Promote(ctx context.Context, ID graph.ID) error {
+	err := db.client.HSet(ctx, node(ID), NodeStatus, graph.StatusActive, NodePromotionTS, time.Now().Unix()).Err()
 	if err != nil {
 		return fmt.Errorf("failed to promote %s: %w", node(ID), err)
 	}
@@ -199,8 +199,8 @@ func (r RedisDB) Promote(ctx context.Context, ID graph.ID) error {
 }
 
 // Demote changes the node status to inactive
-func (r RedisDB) Demote(ctx context.Context, ID graph.ID) error {
-	err := r.client.HSet(ctx, node(ID), NodeStatus, graph.StatusInactive, NodeDemotionTS, time.Now().Unix()).Err()
+func (db RedisDB) Demote(ctx context.Context, ID graph.ID) error {
+	err := db.client.HSet(ctx, node(ID), NodeStatus, graph.StatusInactive, NodeDemotionTS, time.Now().Unix()).Err()
 	if err != nil {
 		return fmt.Errorf("failed to demote %s: %w", node(ID), err)
 	}
@@ -208,24 +208,24 @@ func (r RedisDB) Demote(ctx context.Context, ID graph.ID) error {
 }
 
 // Follows returns the follow list of node. If node is not found, it returns [graph.ErrNodeNotFound].
-func (r RedisDB) Follows(ctx context.Context, node graph.ID) ([]graph.ID, error) {
-	return r.members(ctx, follows, node)
+func (db RedisDB) Follows(ctx context.Context, node graph.ID) ([]graph.ID, error) {
+	return db.members(ctx, follows, node)
 }
 
 // Followers returns the list of followers of node. If node is not found, it returns [graph.ErrNodeNotFound].
-func (r RedisDB) Followers(ctx context.Context, node graph.ID) ([]graph.ID, error) {
-	return r.members(ctx, followers, node)
+func (db RedisDB) Followers(ctx context.Context, node graph.ID) ([]graph.ID, error) {
+	return db.members(ctx, followers, node)
 }
 
-func (r RedisDB) members(ctx context.Context, key func(graph.ID) string, node graph.ID) ([]graph.ID, error) {
-	members, err := r.client.SMembers(ctx, key(node)).Result()
+func (db RedisDB) members(ctx context.Context, key func(graph.ID) string, node graph.ID) ([]graph.ID, error) {
+	members, err := db.client.SMembers(ctx, key(node)).Result()
 	if err != nil {
 		return nil, fmt.Errorf("failed to fetch %s: %w", key(node), err)
 	}
 
 	if len(members) == 0 {
 		// check if there are no members because node doesn't exist
-		if err := r.ensureExists(ctx, node); err != nil {
+		if err := db.ensureExists(ctx, node); err != nil {
 			return nil, fmt.Errorf("failed to fetch %s: %w", key(node), err)
 		}
 	}
@@ -235,17 +235,17 @@ func (r RedisDB) members(ctx context.Context, key func(graph.ID) string, node gr
 
 // BulkFollows returns the follow-lists of all the provided nodes.
 // Do not call on too many nodes (e.g. +100k) to avoid too many recursions.
-func (r RedisDB) BulkFollows(ctx context.Context, nodes []graph.ID) ([][]graph.ID, error) {
-	return r.bulkMembers(ctx, follows, nodes)
+func (db RedisDB) BulkFollows(ctx context.Context, nodes []graph.ID) ([][]graph.ID, error) {
+	return db.bulkMembers(ctx, follows, nodes)
 }
 
-func (r RedisDB) bulkMembers(ctx context.Context, key func(graph.ID) string, nodes []graph.ID) ([][]graph.ID, error) {
+func (db RedisDB) bulkMembers(ctx context.Context, key func(graph.ID) string, nodes []graph.ID) ([][]graph.ID, error) {
 	switch {
 	case len(nodes) == 0:
 		return nil, nil
 
 	case len(nodes) < 10000:
-		pipe := r.client.Pipeline()
+		pipe := db.client.Pipeline()
 		cmds := make([]*redis.StringSliceCmd, len(nodes))
 
 		for i, node := range nodes {
@@ -270,7 +270,7 @@ func (r RedisDB) bulkMembers(ctx context.Context, key func(graph.ID) string, nod
 		}
 
 		if len(empty) > 0 {
-			err := r.ensureExists(ctx, empty...)
+			err := db.ensureExists(ctx, empty...)
 			if err != nil {
 				return nil, fmt.Errorf("failed to fetch the %s of these nodes %v: %w", key(""), empty, err)
 			}
@@ -281,12 +281,12 @@ func (r RedisDB) bulkMembers(ctx context.Context, key func(graph.ID) string, nod
 	default:
 		// too many nodes, split them in two batches
 		mid := len(nodes) / 2
-		batch1, err := r.bulkMembers(ctx, key, nodes[:mid])
+		batch1, err := db.bulkMembers(ctx, key, nodes[:mid])
 		if err != nil {
 			return nil, err
 		}
 
-		batch2, err := r.bulkMembers(ctx, key, nodes[mid:])
+		batch2, err := db.bulkMembers(ctx, key, nodes[mid:])
 		if err != nil {
 			return nil, err
 		}
@@ -296,21 +296,21 @@ func (r RedisDB) bulkMembers(ctx context.Context, key func(graph.ID) string, nod
 }
 
 // FollowCounts returns the number of follows each node has. If a node is not found, it returns 0.
-func (r RedisDB) FollowCounts(ctx context.Context, nodes ...graph.ID) ([]int, error) {
-	return r.counts(ctx, follows, nodes...)
+func (db RedisDB) FollowCounts(ctx context.Context, nodes ...graph.ID) ([]int, error) {
+	return db.counts(ctx, follows, nodes...)
 }
 
 // FollowerCounts returns the number of followers each node has. If a node is not found, it returns 0.
-func (r RedisDB) FollowerCounts(ctx context.Context, nodes ...graph.ID) ([]int, error) {
-	return r.counts(ctx, followers, nodes...)
+func (db RedisDB) FollowerCounts(ctx context.Context, nodes ...graph.ID) ([]int, error) {
+	return db.counts(ctx, followers, nodes...)
 }
 
-func (r RedisDB) counts(ctx context.Context, key func(graph.ID) string, nodes ...graph.ID) ([]int, error) {
+func (db RedisDB) counts(ctx context.Context, key func(graph.ID) string, nodes ...graph.ID) ([]int, error) {
 	if len(nodes) == 0 {
 		return nil, nil
 	}
 
-	pipe := r.client.Pipeline()
+	pipe := db.client.Pipeline()
 	cmds := make([]*redis.IntCmd, len(nodes))
 
 	for i, node := range nodes {
@@ -330,19 +330,19 @@ func (r RedisDB) counts(ctx context.Context, key func(graph.ID) string, nodes ..
 }
 
 // Update applies the delta to the graph.
-func (r RedisDB) Update(ctx context.Context, delta *graph.Delta) error {
+func (db RedisDB) Update(ctx context.Context, delta graph.Delta) error {
 	if delta.Size() == 0 {
 		return nil
 	}
 
-	err := r.ensureExists(ctx, delta.Node)
+	err := db.ensureExists(ctx, delta.Node)
 	if err != nil {
 		return fmt.Errorf("failed to update with delta %v: %w", delta, err)
 	}
 
 	switch delta.Kind {
 	case nostr.KindFollowList:
-		err = r.updateFollows(ctx, delta)
+		err = db.updateFollows(ctx, delta)
 
 	default:
 		err = fmt.Errorf("unsupported kind %d", delta.Kind)
@@ -355,8 +355,8 @@ func (r RedisDB) Update(ctx context.Context, delta *graph.Delta) error {
 	return nil
 }
 
-func (r RedisDB) updateFollows(ctx context.Context, delta *graph.Delta) error {
-	pipe := r.client.TxPipeline()
+func (db RedisDB) updateFollows(ctx context.Context, delta graph.Delta) error {
+	pipe := db.client.TxPipeline()
 	if len(delta.Add) > 0 {
 		// add all node --> added
 		pipe.SAdd(ctx, follows(delta.Node), toStrings(delta.Add))
@@ -370,8 +370,8 @@ func (r RedisDB) updateFollows(ctx context.Context, delta *graph.Delta) error {
 		// remove all node --> removed
 		pipe.SRem(ctx, follows(delta.Node), toStrings(delta.Remove))
 
-		for _, r := range delta.Remove {
-			pipe.SRem(ctx, followers(r), delta.Node)
+		for _, db := range delta.Remove {
+			pipe.SRem(ctx, followers(db), delta.Node)
 		}
 	}
 
@@ -384,12 +384,12 @@ func (r RedisDB) updateFollows(ctx context.Context, delta *graph.Delta) error {
 
 // NodeIDs returns a slice of node IDs associated with the pubkeys.
 // If a pubkey is not found, an empty ID "" is returned
-func (r RedisDB) NodeIDs(ctx context.Context, pubkeys ...string) ([]graph.ID, error) {
+func (db RedisDB) NodeIDs(ctx context.Context, pubkeys ...string) ([]graph.ID, error) {
 	if len(pubkeys) == 0 {
 		return nil, nil
 	}
 
-	IDs, err := r.client.HMGet(ctx, KeyKeyIndex, pubkeys...).Result()
+	IDs, err := db.client.HMGet(ctx, KeyKeyIndex, pubkeys...).Result()
 	if err != nil {
 		return nil, fmt.Errorf("failed to fetch the node IDs of %v: %w", pubkeys, err)
 	}
@@ -398,10 +398,12 @@ func (r RedisDB) NodeIDs(ctx context.Context, pubkeys ...string) ([]graph.ID, er
 	for i, ID := range IDs {
 		switch ID {
 		case nil:
-			nodes[i] = "" // empty ID means missing pubkey
+			// empty ID means missing pubkey
+			nodes[i] = ""
 
 		default:
-			nodes[i] = graph.ID(ID.(string)) // no need to type-assert because everything in redis is a string
+			// direct type conversion because everything in redis is a string
+			nodes[i] = graph.ID(ID.(string))
 		}
 	}
 
@@ -410,12 +412,12 @@ func (r RedisDB) NodeIDs(ctx context.Context, pubkeys ...string) ([]graph.ID, er
 
 // Pubkeys returns a slice of pubkeys associated with the node IDs.
 // If a node ID is not found, an empty pubkey "" is returned
-func (r RedisDB) Pubkeys(ctx context.Context, nodes ...graph.ID) ([]string, error) {
+func (db RedisDB) Pubkeys(ctx context.Context, nodes ...graph.ID) ([]string, error) {
 	if len(nodes) == 0 {
 		return nil, nil
 	}
 
-	pipe := r.client.Pipeline()
+	pipe := db.client.Pipeline()
 	cmds := make([]*redis.StringCmd, len(nodes))
 	for i, ID := range nodes {
 		cmds[i] = pipe.HGet(ctx, node(ID), NodePubkey)
@@ -443,12 +445,56 @@ func (r RedisDB) Pubkeys(ctx context.Context, nodes ...graph.ID) ([]string, erro
 	return pubkeys, nil
 }
 
+type MissingHandler func(ctx context.Context, db RedisDB, pubkey string) (graph.ID, error)
+
+func Ignore(context.Context, RedisDB, string) (graph.ID, error)   { return "", nil }
+func Sentinel(context.Context, RedisDB, string) (graph.ID, error) { return "-1", nil }
+
+func AddValid(ctx context.Context, db RedisDB, pubkey string) (graph.ID, error) {
+	if !nostr.IsValidPublicKey(pubkey) {
+		return "", nil
+	}
+	return db.AddNode(ctx, pubkey)
+}
+
+// Resolve pubkeys into node IDs. If a pubkey is missing (ID = ""), it applies the onMissing handler.
+func (db RedisDB) Resolve(ctx context.Context, pubkeys []string, onMissing MissingHandler) ([]graph.ID, error) {
+	IDs, err := db.NodeIDs(ctx, pubkeys...)
+	if err != nil {
+		return nil, fmt.Errorf("failed to resolve pubkeys: %w", err)
+	}
+
+	j := 0 // write index
+	for i, ID := range IDs {
+		switch ID {
+		case "":
+			ID, err = onMissing(ctx, db, pubkeys[i])
+			if err != nil {
+				return nil, fmt.Errorf("failed to resolve pubkey %q: %w", pubkeys[i], err)
+			}
+
+			if ID != "" {
+				IDs[j] = ID
+				j++
+			}
+
+		default:
+			if j != i {
+				// write only if necessary
+				IDs[j] = ID
+			}
+			j++
+		}
+	}
+	return IDs[:j], nil
+}
+
 // ScanNodes to return a batch of node IDs of size roughly proportional to limit.
 // Limit controls how much "work" is invested in fetching the batch, hence it's not precise.
 // Learn more about scan: https://redis.io/docs/latest/commands/scan/
-func (r RedisDB) ScanNodes(ctx context.Context, cursor uint64, limit int) ([]graph.ID, uint64, error) {
+func (db RedisDB) ScanNodes(ctx context.Context, cursor uint64, limit int) ([]graph.ID, uint64, error) {
 	match := KeyNodePrefix + "*"
-	keys, cursor, err := r.client.Scan(ctx, cursor, match, int64(limit)).Result()
+	keys, cursor, err := db.client.Scan(ctx, cursor, match, int64(limit)).Result()
 	if err != nil {
 		return nil, 0, fmt.Errorf("failed to scan for keys matching %s: %w", match, err)
 	}
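The MissingHandler hook makes the resolution policy pluggable. A hedged sketch of a custom handler (not part of this commit; the package name and logMissing are hypothetical) that logs unknown pubkeys and otherwise behaves like redb.Ignore:

package example

import (
	"context"
	"log"

	"github/pippellia-btc/crawler/pkg/graph"
	"github/pippellia-btc/crawler/pkg/redb"
)

// logMissing logs the unknown pubkey and then drops it, exactly like redb.Ignore does.
func logMissing(ctx context.Context, db redb.RedisDB, pubkey string) (graph.ID, error) {
	log.Printf("pubkey %s is not in the database", pubkey)
	return redb.Ignore(ctx, db, pubkey)
}

// compile-time check that logMissing satisfies the handler contract
var _ redb.MissingHandler = logMissing

A handler like this would then be passed as the onMissing argument, e.g. db.Resolve(ctx, pubkeys, logMissing).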
@@ -238,7 +238,7 @@ func TestUpdateFollows(t *testing.T) {
 	}
 	defer db.flushAll()
 
-	delta := &graph.Delta{
+	delta := graph.Delta{
 		Kind:   nostr.KindFollowList,
 		Node:   "0",
 		Remove: []graph.ID{"1"},
@@ -324,6 +324,57 @@ func TestNodeIDs(t *testing.T) {
 	}
 }
 
+func TestResolve(t *testing.T) {
+	tests := []struct {
+		name      string
+		setup     func() (RedisDB, error)
+		pubkeys   []string
+		onMissing MissingHandler
+		expected  []graph.ID
+	}{
+		{
+			name:      "empty database",
+			setup:     Empty,
+			pubkeys:   []string{"0"},
+			onMissing: Sentinel,
+			expected:  []graph.ID{"-1"},
+		},
+		{
+			name:      "node not found, ignore",
+			setup:     OneNode,
+			pubkeys:   []string{"1"},
+			onMissing: Ignore,
+			expected:  []graph.ID{},
+		},
+		{
+			name:      "valid",
+			setup:     Simple,
+			pubkeys:   []string{"0", "69", "1"},
+			onMissing: Ignore,
+			expected:  []graph.ID{"0", "1"},
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			db, err := test.setup()
+			if err != nil {
+				t.Fatalf("setup failed: %v", err)
+			}
+			defer db.flushAll()
+
+			nodes, err := db.Resolve(ctx, test.pubkeys, test.onMissing)
+			if err != nil {
+				t.Fatalf("expected error nil, got %v", err)
+			}
+
+			if !reflect.DeepEqual(nodes, test.expected) {
+				t.Fatalf("expected nodes %v, got %v", test.expected, nodes)
+			}
+		})
+	}
+}
+
 func TestPubkeys(t *testing.T) {
 	tests := []struct {
 		name string
@@ -77,7 +77,7 @@ func (c *CachedWalker) Add(node graph.ID, follows []graph.ID) error {
 	return nil
 }
 
-// Add node and follows as edges. It evicts the LRU element if the capacity has been exceeded.
+// Add node and follows as [edges]. It evicts the LRU element if the capacity has been exceeded.
 func (c *CachedWalker) add(node uint32, follows []uint32) {
 	if e, ok := c.lookup[node]; ok {
 		// node already present, update value
@@ -34,9 +34,23 @@ func (w Walk) Len() int {
 	return len(w.Path)
 }
 
-// Visits returns whether the walk visited node
-func (w Walk) Visits(node graph.ID) bool {
-	return slices.Contains(w.Path, node)
+// Visits returns whether the walk visits any of the nodes
+func (w Walk) Visits(nodes ...graph.ID) bool {
+	for _, node := range nodes {
+		if slices.Contains(w.Path, node) {
+			return true
+		}
+	}
+	return false
+}
+
+// VisitsAt returns whether the walk visits any of the nodes at the specified step.
+// If the step is outside the bounds of the walk, it returns false.
+func (w Walk) VisitsAt(step int, nodes ...graph.ID) bool {
+	if step < 0 || step >= w.Len() {
+		return false
+	}
+	return slices.Contains(nodes, w.Path[step])
 }
 
 // Index returns the index of node in the walk, or -1 if not present
@@ -89,7 +103,6 @@ func Divergence(w1, w2 Walk) int {
 		// they are all equal, so no divergence
 		return -1
 	}
-
 	return min
 }
 
@@ -176,9 +189,12 @@ func ToRemove(node graph.ID, walks []Walk) ([]Walk, error) {
 	return toRemove, nil
 }
 
-func ToUpdate(ctx context.Context, walker Walker, delta graph.Delta, walks []Walk) ([]Walk, error) {
-	toUpdate := make([]Walk, 0, expectedUpdates(walks, delta))
+// ToUpdate returns how the old walks need to be updated to reflect the changes in the graph.
+// In particular, it corrects invalid steps and resamples in order to maintain the correct distribution.
+func ToUpdate(ctx context.Context, walker Walker, delta graph.Delta, walks []Walk) (old, new []Walk, err error) {
 	resampleProbability := resampleProbability(delta)
+	old = make([]Walk, 0, expectedUpdates(walks, delta))
+	new = make([]Walk, 0, expectedUpdates(walks, delta))
 
 	for _, walk := range walks {
 		pos := walk.Index(delta.Node)
@@ -187,43 +203,44 @@ func ToUpdate(ctx context.Context, walker Walker, delta graph.Delta, walks []Wal
 			continue
 		}
 
-		shouldResample := rand.Float64() < resampleProbability
-		isInvalid := (pos < walk.Len()-1) && slices.Contains(delta.Remove, walk.Path[pos+1])
+		resample := rand.Float64() < resampleProbability
+		invalid := walk.VisitsAt(pos+1, delta.Remove...)
 
 		switch {
-		case shouldResample:
+		case resample:
 			// prune and graft with the added nodes to avoid oversampling of common nodes
 			updated := walk.Copy()
 			updated.Prune(pos + 1)
 
 			if rand.Float64() < Alpha {
-				new, err := generate(ctx, walker, delta.Add...)
+				path, err := generate(ctx, walker, delta.Add...)
 				if err != nil {
-					return nil, fmt.Errorf("ToUpdate: failed to generate new segment: %w", err)
+					return nil, nil, fmt.Errorf("ToUpdate: failed to generate new segment: %w", err)
 				}
 
-				updated.Graft(new)
+				updated.Graft(path)
 			}
 
-			toUpdate = append(toUpdate, updated)
+			old = append(old, walk)
+			new = append(new, updated)
 
-		case isInvalid:
+		case invalid:
 			// prune and graft invalid steps with the common nodes
 			updated := walk.Copy()
 			updated.Prune(pos + 1)
 
-			new, err := generate(ctx, walker, delta.Keep...)
+			path, err := generate(ctx, walker, delta.Keep...)
 			if err != nil {
-				return nil, fmt.Errorf("ToUpdate: failed to generate new segment: %w", err)
+				return nil, nil, fmt.Errorf("ToUpdate: failed to generate new segment: %w", err)
 			}
 
-			updated.Graft(new)
-			toUpdate = append(toUpdate, updated)
+			updated.Graft(path)
+			old = append(old, walk)
+			new = append(new, updated)
 		}
 	}
 
-	return toUpdate, nil
+	return old, new, nil
 }
 
 // The resample probability that a walk needs to be changed to avoid an oversampling of common nodes.
@@ -236,9 +253,9 @@ func resampleProbability(delta graph.Delta) float64 {
 		return 0
 	}
 
-	c := float64(len(delta.Keep))
+	k := float64(len(delta.Keep))
 	a := float64(len(delta.Add))
-	return a / (a + c)
+	return a / (a + k)
 }
 
 func expectedUpdates(walks []Walk, delta graph.Delta) int {
@@ -248,14 +265,14 @@ func expectedUpdates(walks []Walk, delta graph.Delta) int {
 	}
 
 	r := float64(len(delta.Remove))
-	c := float64(len(delta.Keep))
+	k := float64(len(delta.Keep))
 	a := float64(len(delta.Add))
 
-	invalidProbability := Alpha * r / (r + c)
-	resampleProbability := a / (a + c)
-	updateProbability := invalidProbability + resampleProbability - invalidProbability*resampleProbability
-	expectedUpdates := float64(len(walks)) * updateProbability
-	return int(expectedUpdates + 0.5)
+	invalidP := Alpha * r / (r + k)
+	resampleP := a / (a + k)
+	updateP := invalidP + resampleP - invalidP*resampleP
+	expected := float64(len(walks)) * updateP
+	return int(expected + 0.5)
}
 
 // returns a random element of a slice. It panics if the slice is empty or nil.
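To make the expectedUpdates arithmetic concrete: a walk needs an update if it is resampled or if its next step became invalid, and treating the two events as independent gives the inclusion-exclusion formula above. A small standalone computation (the counts and the Alpha value are invented for the example):

package main

import "fmt"

func main() {
	const alpha = 0.85       // hypothetical damping factor
	r, k, a := 2.0, 8.0, 4.0 // removed, kept, added follows in the delta

	invalidP := alpha * r / (r + k)                      // next step lands on a removed follow
	resampleP := a / (a + k)                             // walk is resampled toward the added follows
	updateP := invalidP + resampleP - invalidP*resampleP // union of the two events

	fmt.Printf("update probability: %.3f\n", updateP)                           // ~0.447
	fmt.Printf("expected updates for 1000 walks: %d\n", int(1000*updateP+0.5))  // 447
}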
@@ -124,13 +124,17 @@ func TestUpdateRemove(t *testing.T) {
 	Alpha = 1 // avoid early stopping, which makes the test deterministic
 	expected := []Walk{{ID: "0", Path: []graph.ID{"0", "3", "2"}}}
 
-	toUpdate, err := ToUpdate(context.Background(), walker, delta, walks)
+	old, new, err := ToUpdate(context.Background(), walker, delta, walks)
 	if err != nil {
 		t.Fatalf("expected nil, got %v", err)
 	}
 
-	if !reflect.DeepEqual(toUpdate, expected) {
-		t.Errorf("expected %v, got %v", expected, toUpdate)
+	if !reflect.DeepEqual(old, walks[:1]) {
+		t.Errorf("expected old %v, got %v", walks[:1], old)
+	}
+
+	if !reflect.DeepEqual(new, expected) {
+		t.Errorf("expected new %v, got %v", expected, new)
 	}
 }
@@ -102,11 +102,11 @@ func TestPagerankDynamic(t *testing.T) {
 			inv := delta.Inverse()
 			test.walker.Update(ctx, inv)
 
-			toUpdate, err := walks.ToUpdate(ctx, test.walker, inv, rwalks)
+			_, new, err := walks.ToUpdate(ctx, test.walker, inv, rwalks)
 			if err != nil {
 				t.Fatalf("failed to update the walks: %v", err)
 			}
-			store.ReplaceWalks(toUpdate)
+			store.ReplaceWalks(new)
 
 			global, err := pagerank.Global(ctx, store, test.nodes...)
 			if err != nil {