processor is here

This commit is contained in:
pippellia-btc
2025-06-04 15:48:07 +02:00
parent 9ba3e0553f
commit 29ef016392
9 changed files with 572 additions and 96 deletions

View File

@@ -4,6 +4,9 @@ package graph
import ( import (
"errors" "errors"
"math/rand/v2"
"slices"
"strconv"
"time" "time"
) )
@@ -64,6 +67,46 @@ type Delta struct {
Add []ID Add []ID
} }
// NewDelta returns a delta by computing the relationships to remove, keep and add.
// Time complexity O(n * logn + m * logm), where n and m are the lengths of the slices.
// This function is much faster than converting to sets for sizes (n, m) smaller than ~10^6.
func NewDelta(kind int, node ID, old, new []ID) Delta {
delta := Delta{
Kind: kind,
Node: node,
}
slices.Sort(old)
slices.Sort(new)
i, j := 0, 0
oldLen, newLen := len(old), len(new)
for i < oldLen && j < newLen {
switch {
case old[i] < new[j]:
// ID is in old but not in new => remove
delta.Remove = append(delta.Remove, old[i])
i++
case old[i] > new[j]:
// ID is in new but not in old => add
delta.Add = append(delta.Add, new[j])
j++
default:
// ID is in both => keep
delta.Keep = append(delta.Keep, old[i])
i++
j++
}
}
// add all elements not traversed
delta.Remove = append(delta.Remove, old[i:]...)
delta.Add = append(delta.Add, new[j:]...)
return delta
}
// Size returns the number of relationships changed by delta // Size returns the number of relationships changed by delta
func (d Delta) Size() int { func (d Delta) Size() int {
return len(d.Remove) + len(d.Add) return len(d.Remove) + len(d.Add)
@@ -89,3 +132,13 @@ func (d Delta) Inverse() Delta {
Add: d.Remove, Add: d.Remove,
} }
} }
// RandomIDs of the provided size.
func RandomIDs(size int) []ID {
IDs := make([]ID, size)
for i := range size {
node := rand.IntN(10000000)
IDs[i] = ID(strconv.Itoa(node))
}
return IDs
}

118
pkg/graph/graph_test.go Normal file
View File

@@ -0,0 +1,118 @@
package graph
import (
"fmt"
"reflect"
"testing"
)
func TestNewDelta(t *testing.T) {
testCases := []struct {
name string
old []ID
new []ID
expected Delta
}{
{
name: "nil slices",
expected: Delta{Kind: 3, Node: "0"},
},
{
name: "empty slices",
expected: Delta{Kind: 3, Node: "0"},
},
{
name: "only removals",
old: []ID{"0", "1", "2", "19", "111"},
new: []ID{"2", "19"},
expected: Delta{Kind: 3, Node: "0", Remove: []ID{"0", "1", "111"}, Keep: []ID{"19", "2"}},
},
{
name: "only additions",
old: []ID{"0", "1"},
new: []ID{"420", "0", "1", "69"},
expected: Delta{Kind: 3, Node: "0", Keep: []ID{"0", "1"}, Add: []ID{"420", "69"}},
},
{
name: "both additions",
old: []ID{"0", "1", "111"},
new: []ID{"420", "0", "1", "69"},
expected: Delta{Kind: 3, Node: "0", Remove: []ID{"111"}, Keep: []ID{"0", "1"}, Add: []ID{"420", "69"}},
},
}
for _, test := range testCases {
t.Run(test.name, func(t *testing.T) {
delta := NewDelta(3, "0", test.old, test.new)
if !reflect.DeepEqual(delta, test.expected) {
t.Errorf("expected delta %v, got %v", test.expected, delta)
}
})
}
}
func BenchmarkNewDelta(b *testing.B) {
sizes := []int{1000, 10000, 100000}
for _, size := range sizes {
b.Run(fmt.Sprintf("size=%d", size), func(b *testing.B) {
old := RandomIDs(size)
new := RandomIDs(size)
b.ResetTimer()
for range b.N {
NewDelta(3, "0", old, new)
}
})
}
}
func BenchmarkNewDeltaSets(b *testing.B) {
sizes := []int{1000, 10000, 100000}
for _, size := range sizes {
b.Run(fmt.Sprintf("size=%d", size), func(b *testing.B) {
old := RandomIDs(size)
new := RandomIDs(size)
b.ResetTimer()
for range b.N {
newDeltaSet(3, "0", old, new)
}
})
}
}
func newDeltaSet(kind int, node ID, old, new []ID) Delta {
delta := Delta{
Kind: kind,
Node: node,
}
oldMap := make(map[ID]struct{}, len(old))
newMap := make(map[ID]struct{}, len(new))
// Fill maps
for _, id := range old {
oldMap[id] = struct{}{}
}
for _, id := range new {
newMap[id] = struct{}{}
}
// Find removed and kept
for _, id := range old {
if _, found := newMap[id]; found {
delta.Keep = append(delta.Keep, id)
} else {
delta.Remove = append(delta.Remove, id)
}
}
// Find added
for _, id := range new {
if _, found := oldMap[id]; !found {
delta.Add = append(delta.Add, id)
}
}
return delta
}

187
pkg/pipe/processor.go Normal file
View File

@@ -0,0 +1,187 @@
package pipe
import (
"cmp"
"context"
"errors"
"fmt"
"github/pippellia-btc/crawler/pkg/graph"
"github/pippellia-btc/crawler/pkg/redb"
"github/pippellia-btc/crawler/pkg/walks"
"log"
"slices"
"time"
"github.com/nbd-wtf/go-nostr"
)
var ErrUnsupportedKind = errors.New("unsupported event kind")
type ProcessorConfig struct {
PrintEvery int
}
func NewProcessEventsConfig() ProcessorConfig {
return ProcessorConfig{PrintEvery: 5000}
}
func (c ProcessorConfig) Print() {
fmt.Printf("Processor\n")
fmt.Printf(" PrintEvery: %d\n", c.PrintEvery)
}
// Processor() process one event at the time from the eventChannel, based on their kind.
func Processor(
ctx context.Context,
config ProcessorConfig,
db redb.RedisDB,
//store *eventstore.Store,
events chan *nostr.Event) {
var err error
var processed int
cache := walks.NewWalker(
walks.WithCapacity(10000),
walks.WithFallback(db),
)
for {
select {
case <-ctx.Done():
log.Println("Processor: shutting down...")
return
case event := <-events:
switch event.Kind {
case nostr.KindFollowList:
err = processFollowList(cache, db, event)
case nostr.KindProfileMetadata:
err = nil //HandleProfileMetadata(eventStore, event)
default:
err = ErrUnsupportedKind
}
if err != nil {
log.Printf("Processor: event ID %s, kind %d by %s: %v", event.ID, event.Kind, event.PubKey, err)
}
processed++
if processed%config.PrintEvery == 0 {
log.Printf("Processor: processed %d events", processed)
}
}
}
}
func processFollowList(cache *walks.CachedWalker, db redb.RedisDB, event *nostr.Event) error {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
author, err := db.NodeByKey(ctx, event.PubKey)
if err != nil {
return err
}
oldFollows, err := cache.Follows(ctx, author.ID)
if err != nil {
return err
}
pubkeys := parsePubkeys(event)
onMissing := redb.Ignore
if author.Status == graph.StatusActive {
// active nodes are the only ones that can add new pubkeys to the database
onMissing = redb.AddValid
}
newFollows, err := db.Resolve(ctx, pubkeys, onMissing)
if err != nil {
return err
}
delta := graph.NewDelta(event.Kind, author.ID, oldFollows, newFollows)
if delta.Size() == 0 {
// old and new follows are the same, stop
return nil
}
visiting, err := db.WalksVisiting(ctx, author.ID, -1)
if err != nil {
return err
}
old, new, err := walks.ToUpdate(ctx, cache, delta, visiting)
if err != nil {
return err
}
if err := db.ReplaceWalks(ctx, old, new); err != nil {
return err
}
if err := db.Update(ctx, delta); err != nil {
return err
}
return cache.Update(ctx, delta)
}
const (
followPrefix = "p"
maxFollows = 50000
)
// ParsePubkeys returns the slice of pubkeys that are correctly listed in the nostr.Tags.
// - Badly formatted tags are ignored.
// - Pubkeys will be uniquely added (no repetitions).
// - The author of the event will be removed from the followed pubkeys if present.
// - NO CHECKING the validity of the pubkeys
func parsePubkeys(event *nostr.Event) []string {
pubkeys := make([]string, 0, min(len(event.Tags), maxFollows))
for _, tag := range event.Tags {
if len(pubkeys) > maxFollows {
// stop processing, list is too big
break
}
if len(tag) < 2 {
continue
}
prefix, pubkey := tag[0], tag[1]
if prefix != followPrefix {
continue
}
if pubkey == event.PubKey {
// remove self-follows
continue
}
pubkeys = append(pubkeys, pubkey)
}
return unique(pubkeys)
}
// Unique returns a slice of unique elements of the input slice.
func unique[E cmp.Ordered](slice []E) []E {
if len(slice) == 0 {
return nil
}
slices.Sort(slice)
unique := make([]E, 0, len(slice))
unique = append(unique, slice[0])
for i := 1; i < len(slice); i++ {
if slice[i] != slice[i-1] {
unique = append(unique, slice[i])
}
}
return unique
}

View File

@@ -38,16 +38,16 @@ type RedisDB struct {
} }
func New(opt *redis.Options) RedisDB { func New(opt *redis.Options) RedisDB {
r := RedisDB{client: redis.NewClient(opt)} db := RedisDB{client: redis.NewClient(opt)}
if err := r.validateWalks(); err != nil { if err := db.validateWalks(); err != nil {
panic(err) panic(err)
} }
return r return db
} }
// Size returns the DBSize of redis, which is the total number of keys // Size returns the DBSize of redis, which is the total number of keys
func (r RedisDB) Size(ctx context.Context) (int, error) { func (db RedisDB) Size(ctx context.Context) (int, error) {
size, err := r.client.DBSize(ctx).Result() size, err := db.client.DBSize(ctx).Result()
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -55,8 +55,8 @@ func (r RedisDB) Size(ctx context.Context) (int, error) {
} }
// NodeCount returns the number of nodes stored in redis (in the keyIndex) // NodeCount returns the number of nodes stored in redis (in the keyIndex)
func (r RedisDB) NodeCount(ctx context.Context) (int, error) { func (db RedisDB) NodeCount(ctx context.Context) (int, error) {
nodes, err := r.client.HLen(ctx, KeyKeyIndex).Result() nodes, err := db.client.HLen(ctx, KeyKeyIndex).Result()
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -64,12 +64,12 @@ func (r RedisDB) NodeCount(ctx context.Context) (int, error) {
} }
// Nodes fetches a slice of nodes by their IDs. // Nodes fetches a slice of nodes by their IDs.
func (r RedisDB) Nodes(ctx context.Context, IDs ...graph.ID) ([]*graph.Node, error) { func (db RedisDB) Nodes(ctx context.Context, IDs ...graph.ID) ([]*graph.Node, error) {
if len(IDs) == 0 { if len(IDs) == 0 {
return nil, nil return nil, nil
} }
pipe := r.client.Pipeline() pipe := db.client.Pipeline()
cmds := make([]*redis.MapStringStringCmd, len(IDs)) cmds := make([]*redis.MapStringStringCmd, len(IDs))
for i, ID := range IDs { for i, ID := range IDs {
cmds[i] = pipe.HGetAll(ctx, node(ID)) cmds[i] = pipe.HGetAll(ctx, node(ID))
@@ -97,8 +97,8 @@ func (r RedisDB) Nodes(ctx context.Context, IDs ...graph.ID) ([]*graph.Node, err
} }
// NodeByID fetches a node by its ID // NodeByID fetches a node by its ID
func (r RedisDB) NodeByID(ctx context.Context, ID graph.ID) (*graph.Node, error) { func (db RedisDB) NodeByID(ctx context.Context, ID graph.ID) (*graph.Node, error) {
fields, err := r.client.HGetAll(ctx, node(ID)).Result() fields, err := db.client.HGetAll(ctx, node(ID)).Result()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to fetch %s: %w", node(ID), err) return nil, fmt.Errorf("failed to fetch %s: %w", node(ID), err)
} }
@@ -111,13 +111,13 @@ func (r RedisDB) NodeByID(ctx context.Context, ID graph.ID) (*graph.Node, error)
} }
// NodeByKey fetches a node by its pubkey // NodeByKey fetches a node by its pubkey
func (r RedisDB) NodeByKey(ctx context.Context, pubkey string) (*graph.Node, error) { func (db RedisDB) NodeByKey(ctx context.Context, pubkey string) (*graph.Node, error) {
ID, err := r.client.HGet(ctx, KeyKeyIndex, pubkey).Result() ID, err := db.client.HGet(ctx, KeyKeyIndex, pubkey).Result()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to fetch ID of node with pubkey %s: %w", pubkey, err) return nil, fmt.Errorf("failed to fetch ID of node with pubkey %s: %w", pubkey, err)
} }
fields, err := r.client.HGetAll(ctx, node(ID)).Result() fields, err := db.client.HGetAll(ctx, node(ID)).Result()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to fetch node with pubkey %s: %w", pubkey, err) return nil, fmt.Errorf("failed to fetch node with pubkey %s: %w", pubkey, err)
} }
@@ -130,15 +130,15 @@ func (r RedisDB) NodeByKey(ctx context.Context, pubkey string) (*graph.Node, err
} }
// Exists checks for the existance of the pubkey // Exists checks for the existance of the pubkey
func (r RedisDB) Exists(ctx context.Context, pubkey string) (bool, error) { func (db RedisDB) Exists(ctx context.Context, pubkey string) (bool, error) {
exists, err := r.client.HExists(ctx, KeyKeyIndex, pubkey).Result() exists, err := db.client.HExists(ctx, KeyKeyIndex, pubkey).Result()
if err != nil { if err != nil {
return false, fmt.Errorf("failed to check existance of pubkey %s: %w", pubkey, err) return false, fmt.Errorf("failed to check existance of pubkey %s: %w", pubkey, err)
} }
return exists, nil return exists, nil
} }
func (r RedisDB) ensureExists(ctx context.Context, IDs ...graph.ID) error { func (db RedisDB) ensureExists(ctx context.Context, IDs ...graph.ID) error {
if len(IDs) == 0 { if len(IDs) == 0 {
return nil return nil
} }
@@ -148,7 +148,7 @@ func (r RedisDB) ensureExists(ctx context.Context, IDs ...graph.ID) error {
nodes[i] = node(ID) nodes[i] = node(ID)
} }
exists, err := r.client.Exists(ctx, nodes...).Result() exists, err := db.client.Exists(ctx, nodes...).Result()
if err != nil { if err != nil {
return fmt.Errorf("failed to check for the existence of %d nodes: %w", len(IDs), err) return fmt.Errorf("failed to check for the existence of %d nodes: %w", len(IDs), err)
} }
@@ -161,8 +161,8 @@ func (r RedisDB) ensureExists(ctx context.Context, IDs ...graph.ID) error {
} }
// AddNode adds a new inactive node to the database and returns its assigned ID // AddNode adds a new inactive node to the database and returns its assigned ID
func (r RedisDB) AddNode(ctx context.Context, pubkey string) (graph.ID, error) { func (db RedisDB) AddNode(ctx context.Context, pubkey string) (graph.ID, error) {
exists, err := r.client.HExists(ctx, KeyKeyIndex, pubkey).Result() exists, err := db.client.HExists(ctx, KeyKeyIndex, pubkey).Result()
if err != nil { if err != nil {
return "", fmt.Errorf("failed to check for existence of pubkey %s: %w", pubkey, err) return "", fmt.Errorf("failed to check for existence of pubkey %s: %w", pubkey, err)
} }
@@ -173,13 +173,13 @@ func (r RedisDB) AddNode(ctx context.Context, pubkey string) (graph.ID, error) {
// get the ID outside the transaction, which implies there might be "holes", // get the ID outside the transaction, which implies there might be "holes",
// meaning IDs not associated with any node // meaning IDs not associated with any node
next, err := r.client.HIncrBy(ctx, KeyDatabase, KeyLastNodeID, 1).Result() next, err := db.client.HIncrBy(ctx, KeyDatabase, KeyLastNodeID, 1).Result()
if err != nil { if err != nil {
return "", fmt.Errorf("failed to add node with pubkey %s: failed to increment ID", pubkey) return "", fmt.Errorf("failed to add node with pubkey %s: failed to increment ID", pubkey)
} }
ID := strconv.FormatInt(next-1, 10) ID := strconv.FormatInt(next-1, 10)
pipe := r.client.TxPipeline() pipe := db.client.TxPipeline()
pipe.HSetNX(ctx, KeyKeyIndex, pubkey, ID) pipe.HSetNX(ctx, KeyKeyIndex, pubkey, ID)
pipe.HSet(ctx, node(ID), NodeID, ID, NodePubkey, pubkey, NodeStatus, graph.StatusInactive, NodeAddedTS, time.Now().Unix()) pipe.HSet(ctx, node(ID), NodeID, ID, NodePubkey, pubkey, NodeStatus, graph.StatusInactive, NodeAddedTS, time.Now().Unix())
if _, err := pipe.Exec(ctx); err != nil { if _, err := pipe.Exec(ctx); err != nil {
@@ -190,8 +190,8 @@ func (r RedisDB) AddNode(ctx context.Context, pubkey string) (graph.ID, error) {
} }
// Promote changes the node status to active // Promote changes the node status to active
func (r RedisDB) Promote(ctx context.Context, ID graph.ID) error { func (db RedisDB) Promote(ctx context.Context, ID graph.ID) error {
err := r.client.HSet(ctx, node(ID), NodeStatus, graph.StatusActive, NodePromotionTS, time.Now().Unix()).Err() err := db.client.HSet(ctx, node(ID), NodeStatus, graph.StatusActive, NodePromotionTS, time.Now().Unix()).Err()
if err != nil { if err != nil {
return fmt.Errorf("failed to promote %s: %w", node(ID), err) return fmt.Errorf("failed to promote %s: %w", node(ID), err)
} }
@@ -199,8 +199,8 @@ func (r RedisDB) Promote(ctx context.Context, ID graph.ID) error {
} }
// Demote changes the node status to inactive // Demote changes the node status to inactive
func (r RedisDB) Demote(ctx context.Context, ID graph.ID) error { func (db RedisDB) Demote(ctx context.Context, ID graph.ID) error {
err := r.client.HSet(ctx, node(ID), NodeStatus, graph.StatusInactive, NodeDemotionTS, time.Now().Unix()).Err() err := db.client.HSet(ctx, node(ID), NodeStatus, graph.StatusInactive, NodeDemotionTS, time.Now().Unix()).Err()
if err != nil { if err != nil {
return fmt.Errorf("failed to demote %s: %w", node(ID), err) return fmt.Errorf("failed to demote %s: %w", node(ID), err)
} }
@@ -208,24 +208,24 @@ func (r RedisDB) Demote(ctx context.Context, ID graph.ID) error {
} }
// Follows returns the follow list of node. If node is not found, it returns [graph.ErrNodeNotFound]. // Follows returns the follow list of node. If node is not found, it returns [graph.ErrNodeNotFound].
func (r RedisDB) Follows(ctx context.Context, node graph.ID) ([]graph.ID, error) { func (db RedisDB) Follows(ctx context.Context, node graph.ID) ([]graph.ID, error) {
return r.members(ctx, follows, node) return db.members(ctx, follows, node)
} }
// Followers returns the list of followers of node. If node is not found, it returns [graph.ErrNodeNotFound]. // Followers returns the list of followers of node. If node is not found, it returns [graph.ErrNodeNotFound].
func (r RedisDB) Followers(ctx context.Context, node graph.ID) ([]graph.ID, error) { func (db RedisDB) Followers(ctx context.Context, node graph.ID) ([]graph.ID, error) {
return r.members(ctx, followers, node) return db.members(ctx, followers, node)
} }
func (r RedisDB) members(ctx context.Context, key func(graph.ID) string, node graph.ID) ([]graph.ID, error) { func (db RedisDB) members(ctx context.Context, key func(graph.ID) string, node graph.ID) ([]graph.ID, error) {
members, err := r.client.SMembers(ctx, key(node)).Result() members, err := db.client.SMembers(ctx, key(node)).Result()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to fetch %s: %w", key(node), err) return nil, fmt.Errorf("failed to fetch %s: %w", key(node), err)
} }
if len(members) == 0 { if len(members) == 0 {
// check if there are no members because node doesn't exists // check if there are no members because node doesn't exists
if err := r.ensureExists(ctx, node); err != nil { if err := db.ensureExists(ctx, node); err != nil {
return nil, fmt.Errorf("failed to fetch %s: %w", key(node), err) return nil, fmt.Errorf("failed to fetch %s: %w", key(node), err)
} }
} }
@@ -235,17 +235,17 @@ func (r RedisDB) members(ctx context.Context, key func(graph.ID) string, node gr
// BulkFollows returns the follow-lists of all the provided nodes. // BulkFollows returns the follow-lists of all the provided nodes.
// Do not call on too many nodes (e.g. +100k) to avoid too many recursions. // Do not call on too many nodes (e.g. +100k) to avoid too many recursions.
func (r RedisDB) BulkFollows(ctx context.Context, nodes []graph.ID) ([][]graph.ID, error) { func (db RedisDB) BulkFollows(ctx context.Context, nodes []graph.ID) ([][]graph.ID, error) {
return r.bulkMembers(ctx, follows, nodes) return db.bulkMembers(ctx, follows, nodes)
} }
func (r RedisDB) bulkMembers(ctx context.Context, key func(graph.ID) string, nodes []graph.ID) ([][]graph.ID, error) { func (db RedisDB) bulkMembers(ctx context.Context, key func(graph.ID) string, nodes []graph.ID) ([][]graph.ID, error) {
switch { switch {
case len(nodes) == 0: case len(nodes) == 0:
return nil, nil return nil, nil
case len(nodes) < 10000: case len(nodes) < 10000:
pipe := r.client.Pipeline() pipe := db.client.Pipeline()
cmds := make([]*redis.StringSliceCmd, len(nodes)) cmds := make([]*redis.StringSliceCmd, len(nodes))
for i, node := range nodes { for i, node := range nodes {
@@ -270,7 +270,7 @@ func (r RedisDB) bulkMembers(ctx context.Context, key func(graph.ID) string, nod
} }
if len(empty) > 0 { if len(empty) > 0 {
err := r.ensureExists(ctx, empty...) err := db.ensureExists(ctx, empty...)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to fetch the %s of these nodes %v: %w", key(""), empty, err) return nil, fmt.Errorf("failed to fetch the %s of these nodes %v: %w", key(""), empty, err)
} }
@@ -281,12 +281,12 @@ func (r RedisDB) bulkMembers(ctx context.Context, key func(graph.ID) string, nod
default: default:
// too many nodes, split them in two batches // too many nodes, split them in two batches
mid := len(nodes) / 2 mid := len(nodes) / 2
batch1, err := r.bulkMembers(ctx, key, nodes[:mid]) batch1, err := db.bulkMembers(ctx, key, nodes[:mid])
if err != nil { if err != nil {
return nil, err return nil, err
} }
batch2, err := r.bulkMembers(ctx, key, nodes[mid:]) batch2, err := db.bulkMembers(ctx, key, nodes[mid:])
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -296,21 +296,21 @@ func (r RedisDB) bulkMembers(ctx context.Context, key func(graph.ID) string, nod
} }
// FollowCounts returns the number of follows each node has. If a node is not found, it returns 0. // FollowCounts returns the number of follows each node has. If a node is not found, it returns 0.
func (r RedisDB) FollowCounts(ctx context.Context, nodes ...graph.ID) ([]int, error) { func (db RedisDB) FollowCounts(ctx context.Context, nodes ...graph.ID) ([]int, error) {
return r.counts(ctx, follows, nodes...) return db.counts(ctx, follows, nodes...)
} }
// FollowerCounts returns the number of followers each node has. If a node is not found, it returns 0. // FollowerCounts returns the number of followers each node has. If a node is not found, it returns 0.
func (r RedisDB) FollowerCounts(ctx context.Context, nodes ...graph.ID) ([]int, error) { func (db RedisDB) FollowerCounts(ctx context.Context, nodes ...graph.ID) ([]int, error) {
return r.counts(ctx, followers, nodes...) return db.counts(ctx, followers, nodes...)
} }
func (r RedisDB) counts(ctx context.Context, key func(graph.ID) string, nodes ...graph.ID) ([]int, error) { func (db RedisDB) counts(ctx context.Context, key func(graph.ID) string, nodes ...graph.ID) ([]int, error) {
if len(nodes) == 0 { if len(nodes) == 0 {
return nil, nil return nil, nil
} }
pipe := r.client.Pipeline() pipe := db.client.Pipeline()
cmds := make([]*redis.IntCmd, len(nodes)) cmds := make([]*redis.IntCmd, len(nodes))
for i, node := range nodes { for i, node := range nodes {
@@ -330,19 +330,19 @@ func (r RedisDB) counts(ctx context.Context, key func(graph.ID) string, nodes ..
} }
// Update applies the delta to the graph. // Update applies the delta to the graph.
func (r RedisDB) Update(ctx context.Context, delta *graph.Delta) error { func (db RedisDB) Update(ctx context.Context, delta graph.Delta) error {
if delta.Size() == 0 { if delta.Size() == 0 {
return nil return nil
} }
err := r.ensureExists(ctx, delta.Node) err := db.ensureExists(ctx, delta.Node)
if err != nil { if err != nil {
return fmt.Errorf("failed to update with delta %v: %w", delta, err) return fmt.Errorf("failed to update with delta %v: %w", delta, err)
} }
switch delta.Kind { switch delta.Kind {
case nostr.KindFollowList: case nostr.KindFollowList:
err = r.updateFollows(ctx, delta) err = db.updateFollows(ctx, delta)
default: default:
err = fmt.Errorf("unsupported kind %d", delta.Kind) err = fmt.Errorf("unsupported kind %d", delta.Kind)
@@ -355,8 +355,8 @@ func (r RedisDB) Update(ctx context.Context, delta *graph.Delta) error {
return nil return nil
} }
func (r RedisDB) updateFollows(ctx context.Context, delta *graph.Delta) error { func (db RedisDB) updateFollows(ctx context.Context, delta graph.Delta) error {
pipe := r.client.TxPipeline() pipe := db.client.TxPipeline()
if len(delta.Add) > 0 { if len(delta.Add) > 0 {
// add all node --> added // add all node --> added
pipe.SAdd(ctx, follows(delta.Node), toStrings(delta.Add)) pipe.SAdd(ctx, follows(delta.Node), toStrings(delta.Add))
@@ -370,8 +370,8 @@ func (r RedisDB) updateFollows(ctx context.Context, delta *graph.Delta) error {
// remove all node --> removed // remove all node --> removed
pipe.SRem(ctx, follows(delta.Node), toStrings(delta.Remove)) pipe.SRem(ctx, follows(delta.Node), toStrings(delta.Remove))
for _, r := range delta.Remove { for _, db := range delta.Remove {
pipe.SRem(ctx, followers(r), delta.Node) pipe.SRem(ctx, followers(db), delta.Node)
} }
} }
@@ -384,12 +384,12 @@ func (r RedisDB) updateFollows(ctx context.Context, delta *graph.Delta) error {
// NodeIDs returns a slice of node IDs assosiated with the pubkeys. // NodeIDs returns a slice of node IDs assosiated with the pubkeys.
// If a pubkey is not found, an empty ID "" is returned // If a pubkey is not found, an empty ID "" is returned
func (r RedisDB) NodeIDs(ctx context.Context, pubkeys ...string) ([]graph.ID, error) { func (db RedisDB) NodeIDs(ctx context.Context, pubkeys ...string) ([]graph.ID, error) {
if len(pubkeys) == 0 { if len(pubkeys) == 0 {
return nil, nil return nil, nil
} }
IDs, err := r.client.HMGet(ctx, KeyKeyIndex, pubkeys...).Result() IDs, err := db.client.HMGet(ctx, KeyKeyIndex, pubkeys...).Result()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to fetch the node IDs of %v: %w", pubkeys, err) return nil, fmt.Errorf("failed to fetch the node IDs of %v: %w", pubkeys, err)
} }
@@ -398,10 +398,12 @@ func (r RedisDB) NodeIDs(ctx context.Context, pubkeys ...string) ([]graph.ID, er
for i, ID := range IDs { for i, ID := range IDs {
switch ID { switch ID {
case nil: case nil:
nodes[i] = "" // empty ID means missing pubkey // empty ID means missing pubkey
nodes[i] = ""
default: default:
nodes[i] = graph.ID(ID.(string)) // no need to type-assert because everything in redis is a string // direct type convertion because everything in redis is a string
nodes[i] = graph.ID(ID.(string))
} }
} }
@@ -410,12 +412,12 @@ func (r RedisDB) NodeIDs(ctx context.Context, pubkeys ...string) ([]graph.ID, er
// Pubkeys returns a slice of pubkeys assosiated with the node IDs. // Pubkeys returns a slice of pubkeys assosiated with the node IDs.
// If a node ID is not found, an empty pubkey "" is returned // If a node ID is not found, an empty pubkey "" is returned
func (r RedisDB) Pubkeys(ctx context.Context, nodes ...graph.ID) ([]string, error) { func (db RedisDB) Pubkeys(ctx context.Context, nodes ...graph.ID) ([]string, error) {
if len(nodes) == 0 { if len(nodes) == 0 {
return nil, nil return nil, nil
} }
pipe := r.client.Pipeline() pipe := db.client.Pipeline()
cmds := make([]*redis.StringCmd, len(nodes)) cmds := make([]*redis.StringCmd, len(nodes))
for i, ID := range nodes { for i, ID := range nodes {
cmds[i] = pipe.HGet(ctx, node(ID), NodePubkey) cmds[i] = pipe.HGet(ctx, node(ID), NodePubkey)
@@ -443,12 +445,56 @@ func (r RedisDB) Pubkeys(ctx context.Context, nodes ...graph.ID) ([]string, erro
return pubkeys, nil return pubkeys, nil
} }
type MissingHandler func(ctx context.Context, db RedisDB, pubkey string) (graph.ID, error)
func Ignore(context.Context, RedisDB, string) (graph.ID, error) { return "", nil }
func Sentinel(context.Context, RedisDB, string) (graph.ID, error) { return "-1", nil }
func AddValid(ctx context.Context, db RedisDB, pubkey string) (graph.ID, error) {
if !nostr.IsValidPublicKey(pubkey) {
return "", nil
}
return db.AddNode(ctx, pubkey)
}
// Resolve pubkeys into node IDs. If a pubkey is missing (ID = ""), it applies the onMissing handler.
func (db RedisDB) Resolve(ctx context.Context, pubkeys []string, onMissing MissingHandler) ([]graph.ID, error) {
IDs, err := db.NodeIDs(ctx, pubkeys...)
if err != nil {
return nil, fmt.Errorf("failed to resolve pubkeys: %w", err)
}
j := 0 // write index
for i, ID := range IDs {
switch ID {
case "":
ID, err = onMissing(ctx, db, pubkeys[i])
if err != nil {
return nil, fmt.Errorf("failed to resolve pubkey %q: %w", pubkeys[i], err)
}
if ID != "" {
IDs[j] = ID
j++
}
default:
if j != i {
// write only if necessary
IDs[j] = ID
}
j++
}
}
return IDs[:j], nil
}
// ScanNodes to return a batch of node IDs of size roughly proportional to limit. // ScanNodes to return a batch of node IDs of size roughly proportional to limit.
// Limit controls how much "work" is invested in fetching the batch, hence it's not precise. // Limit controls how much "work" is invested in fetching the batch, hence it's not precise.
// Learn more about scan: https://redis.io/docs/latest/commands/scan/ // Learn more about scan: https://redis.io/docs/latest/commands/scan/
func (r RedisDB) ScanNodes(ctx context.Context, cursor uint64, limit int) ([]graph.ID, uint64, error) { func (db RedisDB) ScanNodes(ctx context.Context, cursor uint64, limit int) ([]graph.ID, uint64, error) {
match := KeyNodePrefix + "*" match := KeyNodePrefix + "*"
keys, cursor, err := r.client.Scan(ctx, cursor, match, int64(limit)).Result() keys, cursor, err := db.client.Scan(ctx, cursor, match, int64(limit)).Result()
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("failed to scan for keys matching %s: %w", match, err) return nil, 0, fmt.Errorf("failed to scan for keys matching %s: %w", match, err)
} }

View File

@@ -238,7 +238,7 @@ func TestUpdateFollows(t *testing.T) {
} }
defer db.flushAll() defer db.flushAll()
delta := &graph.Delta{ delta := graph.Delta{
Kind: nostr.KindFollowList, Kind: nostr.KindFollowList,
Node: "0", Node: "0",
Remove: []graph.ID{"1"}, Remove: []graph.ID{"1"},
@@ -324,6 +324,57 @@ func TestNodeIDs(t *testing.T) {
} }
} }
func TestResolve(t *testing.T) {
tests := []struct {
name string
setup func() (RedisDB, error)
pubkeys []string
onMissing MissingHandler
expected []graph.ID
}{
{
name: "empty database",
setup: Empty,
pubkeys: []string{"0"},
onMissing: Sentinel,
expected: []graph.ID{"-1"},
},
{
name: "node not found, ignore",
setup: OneNode,
pubkeys: []string{"1"},
onMissing: Ignore,
expected: []graph.ID{},
},
{
name: "valid",
setup: Simple,
pubkeys: []string{"0", "69", "1"},
onMissing: Ignore,
expected: []graph.ID{"0", "1"},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
db, err := test.setup()
if err != nil {
t.Fatalf("setup failed: %v", err)
}
defer db.flushAll()
nodes, err := db.Resolve(ctx, test.pubkeys, test.onMissing)
if err != nil {
t.Fatalf("expected error nil, got %v", err)
}
if !reflect.DeepEqual(nodes, test.expected) {
t.Fatalf("expected nodes %v, got %v", test.expected, nodes)
}
})
}
}
func TestPubkeys(t *testing.T) { func TestPubkeys(t *testing.T) {
tests := []struct { tests := []struct {
name string name string

View File

@@ -77,7 +77,7 @@ func (c *CachedWalker) Add(node graph.ID, follows []graph.ID) error {
return nil return nil
} }
// Add node and follows as edges. It evicts the LRU element if the capacity has been exeeded. // Add node and follows as [edges]. It evicts the LRU element if the capacity has been exeeded.
func (c *CachedWalker) add(node uint32, follows []uint32) { func (c *CachedWalker) add(node uint32, follows []uint32) {
if e, ok := c.lookup[node]; ok { if e, ok := c.lookup[node]; ok {
// node already present, update value // node already present, update value

View File

@@ -34,9 +34,23 @@ func (w Walk) Len() int {
return len(w.Path) return len(w.Path)
} }
// Visits returns whether the walk visited node // Visits returns whether the walk visits any of the nodes
func (w Walk) Visits(node graph.ID) bool { func (w Walk) Visits(nodes ...graph.ID) bool {
return slices.Contains(w.Path, node) for _, node := range nodes {
if slices.Contains(w.Path, node) {
return true
}
}
return false
}
// VisitsAt returns whether the walk visits any of the nodes at the specified step.
// If the step is outside the bouds of the walk, it returns false.
func (w Walk) VisitsAt(step int, nodes ...graph.ID) bool {
if step < 0 || step >= w.Len() {
return false
}
return slices.Contains(nodes, w.Path[step])
} }
// Index returns the index of node in the walk, or -1 if not present // Index returns the index of node in the walk, or -1 if not present
@@ -89,7 +103,6 @@ func Divergence(w1, w2 Walk) int {
// they are all equal, so no divergence // they are all equal, so no divergence
return -1 return -1
} }
return min return min
} }
@@ -176,9 +189,12 @@ func ToRemove(node graph.ID, walks []Walk) ([]Walk, error) {
return toRemove, nil return toRemove, nil
} }
func ToUpdate(ctx context.Context, walker Walker, delta graph.Delta, walks []Walk) ([]Walk, error) { // ToUpdate returns how the old walks need to be updated to reflect the changes in the graph.
toUpdate := make([]Walk, 0, expectedUpdates(walks, delta)) // In particular, it corrects invalid steps and resamples in order to maintain the correct distribution.
func ToUpdate(ctx context.Context, walker Walker, delta graph.Delta, walks []Walk) (old, new []Walk, err error) {
resampleProbability := resampleProbability(delta) resampleProbability := resampleProbability(delta)
old = make([]Walk, 0, expectedUpdates(walks, delta))
new = make([]Walk, 0, expectedUpdates(walks, delta))
for _, walk := range walks { for _, walk := range walks {
pos := walk.Index(delta.Node) pos := walk.Index(delta.Node)
@@ -187,43 +203,44 @@ func ToUpdate(ctx context.Context, walker Walker, delta graph.Delta, walks []Wal
continue continue
} }
shouldResample := rand.Float64() < resampleProbability resample := rand.Float64() < resampleProbability
isInvalid := (pos < walk.Len()-1) && slices.Contains(delta.Remove, walk.Path[pos+1]) invalid := walk.VisitsAt(pos+1, delta.Remove...)
switch { switch {
case shouldResample: case resample:
// prune and graft with the added nodes to avoid oversampling of common nodes // prune and graft with the added nodes to avoid oversampling of common nodes
updated := walk.Copy() updated := walk.Copy()
updated.Prune(pos + 1) updated.Prune(pos + 1)
if rand.Float64() < Alpha { if rand.Float64() < Alpha {
new, err := generate(ctx, walker, delta.Add...) path, err := generate(ctx, walker, delta.Add...)
if err != nil { if err != nil {
return nil, fmt.Errorf("ToUpdate: failed to generate new segment: %w", err) return nil, nil, fmt.Errorf("ToUpdate: failed to generate new segment: %w", err)
} }
updated.Graft(new) updated.Graft(path)
} }
toUpdate = append(toUpdate, updated) old = append(old, walk)
new = append(new, updated)
case isInvalid: case invalid:
// prune and graft invalid steps with the common nodes // prune and graft invalid steps with the common nodes
updated := walk.Copy() updated := walk.Copy()
updated.Prune(pos + 1) updated.Prune(pos + 1)
new, err := generate(ctx, walker, delta.Keep...) path, err := generate(ctx, walker, delta.Keep...)
if err != nil { if err != nil {
return nil, fmt.Errorf("ToUpdate: failed to generate new segment: %w", err) return nil, nil, fmt.Errorf("ToUpdate: failed to generate new segment: %w", err)
} }
updated.Graft(new) updated.Graft(path)
toUpdate = append(toUpdate, updated) old = append(old, walk)
new = append(new, updated)
}
} }
} return old, new, nil
return toUpdate, nil
} }
// The resample probability that a walk needs to be changed to avoid an oversampling of common nodes. // The resample probability that a walk needs to be changed to avoid an oversampling of common nodes.
@@ -236,9 +253,9 @@ func resampleProbability(delta graph.Delta) float64 {
return 0 return 0
} }
c := float64(len(delta.Keep)) k := float64(len(delta.Keep))
a := float64(len(delta.Add)) a := float64(len(delta.Add))
return a / (a + c) return a / (a + k)
} }
func expectedUpdates(walks []Walk, delta graph.Delta) int { func expectedUpdates(walks []Walk, delta graph.Delta) int {
@@ -248,14 +265,14 @@ func expectedUpdates(walks []Walk, delta graph.Delta) int {
} }
r := float64(len(delta.Remove)) r := float64(len(delta.Remove))
c := float64(len(delta.Keep)) k := float64(len(delta.Keep))
a := float64(len(delta.Add)) a := float64(len(delta.Add))
invalidProbability := Alpha * r / (r + c) invalidP := Alpha * r / (r + k)
resampleProbability := a / (a + c) resampleP := a / (a + k)
updateProbability := invalidProbability + resampleProbability - invalidProbability*resampleProbability updateP := invalidP + resampleP - invalidP*resampleP
expectedUpdates := float64(len(walks)) * updateProbability expected := float64(len(walks)) * updateP
return int(expectedUpdates + 0.5) return int(expected + 0.5)
} }
// returns a random element of a slice. It panics if the slice is empty or nil. // returns a random element of a slice. It panics if the slice is empty or nil.

View File

@@ -124,13 +124,17 @@ func TestUpdateRemove(t *testing.T) {
Alpha = 1 // avoid early stopping, which makes the test deterministic Alpha = 1 // avoid early stopping, which makes the test deterministic
expected := []Walk{{ID: "0", Path: []graph.ID{"0", "3", "2"}}} expected := []Walk{{ID: "0", Path: []graph.ID{"0", "3", "2"}}}
toUpdate, err := ToUpdate(context.Background(), walker, delta, walks) old, new, err := ToUpdate(context.Background(), walker, delta, walks)
if err != nil { if err != nil {
t.Fatalf("expected nil, got %v", err) t.Fatalf("expected nil, got %v", err)
} }
if !reflect.DeepEqual(toUpdate, expected) { if !reflect.DeepEqual(old, walks[:1]) {
t.Errorf("expected %v, got %v", expected, toUpdate) t.Errorf("expected old %v, got %v", walks[:1], old)
}
if !reflect.DeepEqual(new, expected) {
t.Errorf("expected new %v, got %v", expected, new)
} }
} }

View File

@@ -102,11 +102,11 @@ func TestPagerankDynamic(t *testing.T) {
inv := delta.Inverse() inv := delta.Inverse()
test.walker.Update(ctx, inv) test.walker.Update(ctx, inv)
toUpdate, err := walks.ToUpdate(ctx, test.walker, inv, rwalks) _, new, err := walks.ToUpdate(ctx, test.walker, inv, rwalks)
if err != nil { if err != nil {
t.Fatalf("failed to update the walks: %v", err) t.Fatalf("failed to update the walks: %v", err)
} }
store.ReplaceWalks(toUpdate) store.ReplaceWalks(new)
global, err := pagerank.Global(ctx, store, test.nodes...) global, err := pagerank.Global(ctx, store, test.nodes...)
if err != nil { if err != nil {