database code dump

This commit is contained in:
pippellia-btc
2025-05-27 11:40:44 +02:00
parent 11c5afd4f7
commit 85a2eebc95
10 changed files with 1112 additions and 134 deletions

View File

@@ -1,33 +1,70 @@
package graph
import (
"time"
)
// Constants describing the possible statuses of a Node and the kinds of its Records.
const (
// types of status, stored in Node.Status
StatusActive string = "active" // meaning, we generate random walks for this node
StatusInactive string = "inactive"
// internal record kinds, stored in Record.Kind
// NOTE(review): the values are negative, presumably to avoid clashing with other (non-negative) kinds — confirm.
Addition int = -3
Promotion int = -2
Demotion int = -1
)
// ID is the identifier of a node in the graph.
type ID string
// MarshalBinary returns the raw bytes of the ID. It never returns an error.
// NOTE(review): presumably implemented so that an ID satisfies encoding.BinaryMarshaler
// and can be handed directly to storage clients as a command argument — confirm.
func (id ID) MarshalBinary() ([]byte, error) { return []byte(id), nil }
// Node contains the metadata about a node, including a collection of Records.
type Node struct {
ID ID // identifier of the node in the graph
Pubkey string // public key associated with the node
Status string // either [StatusActive] or [StatusInactive]
Records []Record // timestamped updates of the node, see [Record]
}
// Record contains the timestamp of a node update, together with the kind of update.
type Record struct {
Kind int // either [Addition], [Promotion] or [Demotion]
Timestamp time.Time // when the update happened
}
// Delta represents updates to apply to a Node.
// Add and Remove contain the node IDs to add to or remove from the node's relationships (e.g., follow list),
// while Keep contains the IDs that are part of both the old and the new state.
// Kind identifies the type of relationship being updated.
type Delta struct {
Node ID
Removed []ID
Common []ID
Added []ID
Kind int
Node ID
Remove []ID
Keep []ID
Add []ID
}
// Size returns how many relationships the delta changes,
// that is the number of removed IDs plus the number of added IDs.
// IDs that are kept do not count.
func (d Delta) Size() int {
	removed, added := len(d.Remove), len(d.Add)
	return removed + added
}
// Old returns the old state of the delta, meaning the kept IDs followed by the removed ones.
// NOTE(review): append may write into the backing array of its first argument when that
// slice has spare capacity, so the result can alias (and later overwrite) the delta's own slice.
func (d Delta) Old() []ID {
return append(d.Common, d.Removed...)
return append(d.Keep, d.Remove...)
}
// New returns the new state of the delta, meaning the kept IDs followed by the added ones.
// NOTE(review): append may write into the backing array of its first argument when that
// slice has spare capacity, so the result can alias (and later overwrite) the delta's own slice.
func (d Delta) New() []ID {
return append(d.Common, d.Added...)
return append(d.Keep, d.Add...)
}
// Inverse returns the inverse of the delta, obtained by swapping what is added with what is removed.
// If a delta and its inverse are both applied, the graph returns to its original state.
func (d Delta) Inverse() Delta {
return Delta{
Node: d.Node,
Common: d.Common,
Removed: d.Added,
Added: d.Removed,
Kind: d.Kind,
Node: d.Node,
Keep: d.Keep,
Remove: d.Add,
Add: d.Remove,
}
}

359
pkg/redb/graph.go Normal file
View File

@@ -0,0 +1,359 @@
package redb
import (
"context"
"errors"
"fmt"
"github/pippellia-btc/crawler/pkg/graph"
"strconv"
"strings"
"time"
"github.com/nbd-wtf/go-nostr"
"github.com/redis/go-redis/v9"
)
// Names of the redis keys, key prefixes and hash fields used to store the graph.
const (
// redis variable names
KeyDatabase string = "database" // TODO: this can be removed
KeyLastNodeID string = "lastNodeID" // TODO: change it to "next" inside "node" hash
KeyKeyIndex string = "keyIndex" // TODO: change to key_index
KeyNodePrefix string = "node:" // prefix of the HASH holding a node, followed by the node ID
KeyFollowsPrefix string = "follows:" // prefix of the SET of IDs a node follows
KeyFollowersPrefix string = "followers:" // prefix of the SET of IDs following a node
// redis node HASH fields
NodeID string = "id"
NodePubkey string = "pubkey"
NodeStatus string = "status"
NodePromotionTS string = "promotion_TS" // TODO: change to promotion
NodeDemotionTS string = "demotion_TS" // TODO: change to demotion
NodeAddedTS string = "added_TS" // TODO: change to addition
)
// Sentinel errors returned (wrapped) by the RedisDB methods; compare them with errors.Is.
var (
ErrNodeNotFound = errors.New("node not found")
ErrNodeAlreadyExists = errors.New("node already exists")
)
// RedisDB stores and retrieves the graph (nodes, follows and followers) using a redis client.
type RedisDB struct {
client *redis.Client
}
// New returns a RedisDB whose client is created with the provided options.
func New(opt *redis.Options) RedisDB {
return RedisDB{client: redis.NewClient(opt)}
}
// Size reports the DBSize of redis, i.e. the total number of keys it holds.
func (r RedisDB) Size(ctx context.Context) (int, error) {
	keys, err := r.client.DBSize(ctx).Result()
	if err == nil {
		return int(keys), nil
	}
	return 0, err
}
// NodeCount reports how many nodes are stored in redis,
// measured as the number of entries of the keyIndex hash.
func (r RedisDB) NodeCount(ctx context.Context) (int, error) {
	entries, err := r.client.HLen(ctx, KeyKeyIndex).Result()
	if err == nil {
		return int(entries), nil
	}
	return 0, err
}
// NodeByID reads the hash of the node with the given ID and parses it into a node.
// When the hash has no fields, the returned error wraps [ErrNodeNotFound].
func (r RedisDB) NodeByID(ctx context.Context, ID graph.ID) (*graph.Node, error) {
	key := node(ID)
	fields, err := r.client.HGetAll(ctx, key).Result()
	switch {
	case err != nil:
		return nil, fmt.Errorf("failed to fetch %s: %w", key, err)
	case len(fields) == 0:
		return nil, fmt.Errorf("failed to fetch %s: %w", key, ErrNodeNotFound)
	}
	return parseNode(fields)
}
// NodeByKey fetches a node by its pubkey.
//
// The pubkey is first resolved to a node ID through the keyIndex hash, then the node hash is read and parsed.
// If the pubkey is missing from the index, or the node hash has no fields, the returned error
// wraps [ErrNodeNotFound], consistently with NodeByID. In the former case the error also wraps
// redis.Nil, so that callers already checking for it keep working.
func (r RedisDB) NodeByKey(ctx context.Context, pubkey string) (*graph.Node, error) {
	ID, err := r.client.HGet(ctx, KeyKeyIndex, pubkey).Result()
	if errors.Is(err, redis.Nil) {
		// HGET on a missing field yields redis.Nil: the pubkey is simply unknown
		return nil, fmt.Errorf("failed to fetch ID of node with pubkey %s: %w: %w", pubkey, ErrNodeNotFound, err)
	}
	if err != nil {
		return nil, fmt.Errorf("failed to fetch ID of node with pubkey %s: %w", pubkey, err)
	}

	fields, err := r.client.HGetAll(ctx, node(ID)).Result()
	if err != nil {
		return nil, fmt.Errorf("failed to fetch node with pubkey %s: %w", pubkey, err)
	}
	if len(fields) == 0 {
		return nil, fmt.Errorf("failed to fetch node with pubkey %s: %w", pubkey, ErrNodeNotFound)
	}
	return parseNode(fields)
}
// containsNode reports whether the key of the node with the given ID exists in redis.
func (r RedisDB) containsNode(ctx context.Context, ID graph.ID) (bool, error) {
	key := node(ID)
	count, err := r.client.Exists(ctx, key).Result()
	if err != nil {
		return false, fmt.Errorf("failed to check for the existence of %v: %w", key, err)
	}
	found := count == 1
	return found, nil
}
// AddNode adds a new inactive node to the database and returns its assigned ID.
//
// If the pubkey is already present in the keyIndex, the returned error wraps [ErrNodeAlreadyExists].
// The pubkey --> ID mapping and the node hash are written inside a single transaction pipeline.
func (r RedisDB) AddNode(ctx context.Context, pubkey string) (graph.ID, error) {
	exists, err := r.client.HExists(ctx, KeyKeyIndex, pubkey).Result()
	if err != nil {
		return "", fmt.Errorf("failed to check for existence of pubkey %s: %w", pubkey, err)
	}
	if exists {
		return "", fmt.Errorf("failed to add node with pubkey %s: %w", pubkey, ErrNodeAlreadyExists)
	}

	// get the ID outside the transaction, which implies there might be "holes",
	// meaning IDs not associated with any node
	next, err := r.client.HIncrBy(ctx, KeyDatabase, KeyLastNodeID, 1).Result()
	if err != nil {
		// wrap the underlying error so that callers can inspect it, like every other failure path
		return "", fmt.Errorf("failed to add node with pubkey %s: failed to increment ID: %w", pubkey, err)
	}

	// the counter is incremented before use, hence the first assigned ID is 0
	ID := strconv.FormatInt(next-1, 10)

	pipe := r.client.TxPipeline()
	pipe.HSetNX(ctx, KeyKeyIndex, pubkey, ID)
	pipe.HSet(ctx, node(ID), NodeID, ID, NodePubkey, pubkey, NodeStatus, graph.StatusInactive, NodeAddedTS, time.Now().Unix())
	if _, err := pipe.Exec(ctx); err != nil {
		return "", fmt.Errorf("failed to add node with pubkey %s: pipeline failed: %w", pubkey, err)
	}
	return graph.ID(ID), nil
}
// Promote marks the node as active and stores the current unix time as its promotion timestamp.
func (r RedisDB) Promote(ctx context.Context, ID graph.ID) error {
	now := time.Now().Unix()
	if err := r.client.HSet(ctx, node(ID), NodeStatus, graph.StatusActive, NodePromotionTS, now).Err(); err != nil {
		return fmt.Errorf("failed to promote %s: %w", node(ID), err)
	}
	return nil
}
// Demote changes the node status to inactive and records the current unix time as its demotion timestamp.
func (r RedisDB) Demote(ctx context.Context, ID graph.ID) error {
	err := r.client.HSet(ctx, node(ID), NodeStatus, graph.StatusInactive, NodeDemotionTS, time.Now().Unix()).Err()
	if err != nil {
		// the message previously said "promote", which made failures of Demote look like failures of Promote
		return fmt.Errorf("failed to demote %s: %w", node(ID), err)
	}
	return nil
}
// Follows returns the follow list of node, read from the set under the "follows:" prefix.
// If node is not found, it returns [ErrNodeNotFound].
func (r RedisDB) Follows(ctx context.Context, node graph.ID) ([]graph.ID, error) {
return r.members(ctx, follows, node)
}
// Followers returns the list of followers of node, read from the set under the "followers:" prefix.
// If node is not found, it returns [ErrNodeNotFound].
func (r RedisDB) Followers(ctx context.Context, node graph.ID) ([]graph.ID, error) {
return r.members(ctx, followers, node)
}
// members returns the IDs stored in the set of node, where the set name is built by key.
// An empty set triggers an additional existence check on the node: if the node
// is missing the returned error wraps [ErrNodeNotFound], otherwise an empty slice is returned.
func (r RedisDB) members(ctx context.Context, key func(graph.ID) string, node graph.ID) ([]graph.ID, error) {
	set := key(node)
	raw, err := r.client.SMembers(ctx, set).Result()
	if err != nil {
		return nil, fmt.Errorf("failed to fetch %s: %w", set, err)
	}
	if len(raw) > 0 {
		return toIDs(raw), nil
	}

	// no members: find out whether that is because the node doesn't exist
	exists, err := r.containsNode(ctx, node)
	if err != nil {
		return nil, err
	}
	if !exists {
		return nil, fmt.Errorf("failed to fetch %s: %w", set, ErrNodeNotFound)
	}
	return toIDs(raw), nil
}
// FollowCounts returns the number of follows each node has, in the same order as nodes.
// If a node is not found, it returns 0. With no nodes, it returns nil.
func (r RedisDB) FollowCounts(ctx context.Context, nodes ...graph.ID) ([]int, error) {
return r.counts(ctx, follows, nodes...)
}
// FollowerCounts returns the number of followers each node has, in the same order as nodes.
// If a node is not found, it returns 0. With no nodes, it returns nil.
func (r RedisDB) FollowerCounts(ctx context.Context, nodes ...graph.ID) ([]int, error) {
return r.counts(ctx, followers, nodes...)
}
// counts returns the cardinality of the set of each node, where set names are built by key.
// All the SCARD commands are sent in a single pipeline; the i-th count refers to the i-th node.
func (r RedisDB) counts(ctx context.Context, key func(graph.ID) string, nodes ...graph.ID) ([]int, error) {
	if len(nodes) == 0 {
		return nil, nil
	}

	pipe := r.client.Pipeline()
	keys := make([]string, 0, len(nodes))
	cmds := make([]*redis.IntCmd, 0, len(nodes))
	for _, n := range nodes {
		keys = append(keys, key(n))
		cmds = append(cmds, pipe.SCard(ctx, key(n)))
	}

	if _, err := pipe.Exec(ctx); err != nil {
		return nil, fmt.Errorf("failed to count the elements of %v: %w", keys, err)
	}

	result := make([]int, len(cmds))
	for i := range cmds {
		result[i] = int(cmds[i].Val())
	}
	return result, nil
}
// Update applies the delta to the graph.
// A delta of size 0 is a no-op. If the delta's node doesn't exist, the returned error
// wraps [ErrNodeNotFound]. Only deltas of kind nostr.KindFollowList are supported.
func (r RedisDB) Update(ctx context.Context, delta *graph.Delta) error {
	if delta.Size() == 0 {
		return nil
	}

	exists, err := r.containsNode(ctx, delta.Node)
	switch {
	case err != nil:
		return fmt.Errorf("failed to update with delta %v: %w", delta, err)
	case !exists:
		return fmt.Errorf("failed to update with delta %v: %w", delta, ErrNodeNotFound)
	}

	if delta.Kind == nostr.KindFollowList {
		err = r.updateFollows(ctx, delta)
	} else {
		err = fmt.Errorf("unsupported kind %d", delta.Kind)
	}

	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to update with delta %v: %w", delta, err)
}
// updateFollows applies the delta to the follows set of the delta's node and to the
// followers sets of the added and removed nodes, all inside one transaction pipeline.
// Additions are queued before removals.
func (r RedisDB) updateFollows(ctx context.Context, delta *graph.Delta) error {
	tx := r.client.TxPipeline()

	if added := delta.Add; len(added) > 0 {
		// add all node --> added
		tx.SAdd(ctx, follows(delta.Node), toStrings(added))
		for _, target := range added {
			tx.SAdd(ctx, followers(target), delta.Node)
		}
	}

	if removed := delta.Remove; len(removed) > 0 {
		// remove all node --> removed
		tx.SRem(ctx, follows(delta.Node), toStrings(removed))
		for _, target := range removed {
			tx.SRem(ctx, followers(target), delta.Node)
		}
	}

	_, err := tx.Exec(ctx)
	if err != nil {
		return fmt.Errorf("pipeline failed: %w", err)
	}
	return nil
}
// NodeIDs returns a slice of node IDs associated with the pubkeys, in the same order.
// If a pubkey is not found, an empty ID "" is returned in its position.
// With no pubkeys, it returns nil.
func (r RedisDB) NodeIDs(ctx context.Context, pubkeys ...string) ([]graph.ID, error) {
	if len(pubkeys) == 0 {
		return nil, nil
	}

	IDs, err := r.client.HMGet(ctx, KeyKeyIndex, pubkeys...).Result()
	if err != nil {
		return nil, fmt.Errorf("failed to fetch the node IDs of %v: %w", pubkeys, err)
	}

	nodes := make([]graph.ID, len(IDs))
	for i, ID := range IDs {
		if ID == nil {
			// missing pubkey: leave the zero value, which is the empty ID
			continue
		}

		// values are expected to be strings, but a checked assertion turns
		// a surprising reply into an error instead of a panic
		str, ok := ID.(string)
		if !ok {
			return nil, fmt.Errorf("failed to fetch the node IDs of %v: unexpected value type %T", pubkeys, ID)
		}
		nodes[i] = graph.ID(str)
	}
	return nodes, nil
}
// Pubkeys returns the pubkeys associated with the node IDs, in the same order.
// A node ID that is not found yields an empty pubkey "" in its position.
// With no nodes, it returns nil.
func (r RedisDB) Pubkeys(ctx context.Context, nodes ...graph.ID) ([]string, error) {
	if len(nodes) == 0 {
		return nil, nil
	}

	pipe := r.client.Pipeline()
	cmds := make([]*redis.StringCmd, 0, len(nodes))
	for _, ID := range nodes {
		cmds = append(cmds, pipe.HGet(ctx, node(ID), NodePubkey))
	}

	// redis.Nil means some node(s) were not found, which is handled per command below
	_, err := pipe.Exec(ctx)
	if err != nil && !errors.Is(err, redis.Nil) {
		return nil, fmt.Errorf("failed to fetch the pubkeys of %v: pipeline failed: %w", nodes, err)
	}

	pubkeys := make([]string, len(nodes))
	for i, cmd := range cmds {
		pubkey, err := cmd.Result()
		if errors.Is(err, redis.Nil) {
			pubkeys[i] = "" // empty pubkey means missing node
			continue
		}
		if err != nil {
			return nil, fmt.Errorf("failed to fetch the pubkeys of %v: %w", nodes, err)
		}
		pubkeys[i] = pubkey
	}
	return pubkeys, nil
}
// ScanNodes returns a batch of node IDs, together with the cursor to pass to the next call.
// The size of the batch is only roughly proportional to limit, because limit controls
// how much "work" is invested in fetching the batch rather than its exact size.
// Learn more about scan: https://redis.io/docs/latest/commands/scan/
func (r RedisDB) ScanNodes(ctx context.Context, cursor uint64, limit int) ([]graph.ID, uint64, error) {
	pattern := KeyNodePrefix + "*"
	keys, next, err := r.client.Scan(ctx, cursor, pattern, int64(limit)).Result()
	if err != nil {
		return nil, 0, fmt.Errorf("failed to scan for keys matching %s: %w", pattern, err)
	}

	IDs := make([]graph.ID, 0, len(keys))
	for _, key := range keys {
		ID, ok := strings.CutPrefix(key, KeyNodePrefix)
		if !ok {
			return nil, 0, fmt.Errorf("failed to scan for keys matching %s: bad match %s", pattern, ID)
		}
		IDs = append(IDs, graph.ID(ID))
	}
	return IDs, next, nil
}

357
pkg/redb/graph_test.go Normal file
View File

@@ -0,0 +1,357 @@
package redb
import (
"context"
"errors"
"github/pippellia-btc/crawler/pkg/graph"
"reflect"
"testing"
"time"
"github.com/nbd-wtf/go-nostr"
"github.com/redis/go-redis/v9"
)
var ctx = context.Background()
// TestParseNode checks parseNode against a table of field maps.
// Nil and empty maps are expected to produce a nil node and a nil error.
func TestParseNode(t *testing.T) {
tests := []struct {
name string
fields map[string]string
expected *graph.Node
err error
}{
{
name: "nil map",
},
{
name: "empty map",
fields: map[string]string{},
},
{
name: "valid no records",
fields: map[string]string{
NodeID: "19",
NodePubkey: "nineteen",
NodeStatus: graph.StatusActive,
},
expected: &graph.Node{
ID: "19",
Pubkey: "nineteen",
Status: graph.StatusActive,
},
},
{
name: "valid with record",
fields: map[string]string{
NodeID: "19",
NodePubkey: "nineteen",
NodeStatus: graph.StatusActive,
NodeAddedTS: "1",
},
expected: &graph.Node{
ID: "19",
Pubkey: "nineteen",
Status: graph.StatusActive,
Records: []graph.Record{
{Kind: graph.Addition, Timestamp: time.Unix(1, 0)},
},
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
node, err := parseNode(test.fields)
if !errors.Is(err, test.err) {
t.Fatalf("expected %v got %v", test.err, err)
}
if !reflect.DeepEqual(node, test.expected) {
t.Fatalf("ParseNode(): expected node %v got %v", test.expected, node)
}
})
}
}
// TestAddNode checks that AddNode rejects an existing pubkey and that a new pubkey
// is stored as an inactive node with the next ID and a single addition record.
// It requires a redis instance listening on testAddress.
func TestAddNode(t *testing.T) {
	t.Run("node already exists", func(t *testing.T) {
		db, err := OneNode()
		if err != nil {
			t.Fatalf("setup failed: %v", err)
		}
		defer db.flushAll()

		if _, err = db.AddNode(ctx, "0"); !errors.Is(err, ErrNodeAlreadyExists) {
			t.Fatalf("expected error %v, got %v", ErrNodeAlreadyExists, err)
		}
	})

	t.Run("valid", func(t *testing.T) {
		db, err := OneNode()
		if err != nil {
			t.Fatalf("setup failed: %v", err)
		}
		defer db.flushAll()

		// bracket the call with two clock readings: comparing against a single time.Now()
		// taken afterwards fails whenever the call straddles a second boundary
		before := time.Now().Unix()
		ID, err := db.AddNode(ctx, "xxx")
		after := time.Now().Unix()
		if err != nil {
			t.Fatalf("expected nil, got %v", err)
		}

		if ID != "1" {
			t.Fatalf("expected ID %s, got %s", "1", ID)
		}

		node, err := db.NodeByKey(ctx, "xxx")
		if err != nil {
			t.Fatal(err)
		}

		if len(node.Records) != 1 {
			t.Fatalf("expected 1 record, got %v", node.Records)
		}

		added := node.Records[0].Timestamp.Unix()
		if added < before || added > after {
			t.Fatalf("expected addition timestamp in [%d, %d], got %d", before, after, added)
		}

		expected := &graph.Node{
			ID:      "1",
			Pubkey:  "xxx",
			Status:  graph.StatusInactive,
			Records: []graph.Record{{Kind: graph.Addition, Timestamp: time.Unix(added, 0)}},
		}

		if !reflect.DeepEqual(node, expected) {
			t.Fatalf("expected node %v, got %v", expected, node)
		}
	})
}
// TestMembers checks the members method on the follows sets of differently populated databases.
// Missing nodes are expected to yield ErrNodeNotFound, while an existing node
// without follows is expected to yield an empty, non-nil slice.
// It requires a redis instance listening on testAddress.
func TestMembers(t *testing.T) {
tests := []struct {
name string
setup func() (RedisDB, error)
node graph.ID
expected []graph.ID
err error
}{
{
name: "empty database",
setup: Empty,
node: "0",
err: ErrNodeNotFound,
},
{
name: "node not found",
setup: OneNode,
node: "1",
err: ErrNodeNotFound,
},
{
name: "dandling node",
setup: OneNode,
node: "0",
expected: []graph.ID{},
},
{
name: "valid",
setup: Simple,
node: "0",
expected: []graph.ID{"1"},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
db, err := test.setup()
if err != nil {
t.Fatalf("setup failed: %v", err)
}
defer db.flushAll()
follows, err := db.members(ctx, follows, test.node)
if !errors.Is(err, test.err) {
t.Fatalf("expected error %v, got %v", test.err, err)
}
if !reflect.DeepEqual(follows, test.expected) {
t.Errorf("expected follows %v, got %v", test.expected, follows)
}
})
}
}
// TestUpdateFollows applies a follow-list delta to the Simple graph (0 ---> 1),
// replacing the follow of "1" with a follow of "2", and then checks the follows of "0"
// and the followers of both "1" and "2".
// It requires a redis instance listening on testAddress.
func TestUpdateFollows(t *testing.T) {
db, err := Simple()
if err != nil {
t.Fatalf("setup failed: %v", err)
}
defer db.flushAll()
delta := &graph.Delta{
Kind: nostr.KindFollowList,
Node: "0",
Remove: []graph.ID{"1"},
Add: []graph.ID{"2"},
}
if err := db.Update(ctx, delta); err != nil {
t.Fatalf("expected error nil, got %v", err)
}
follows, err := db.Follows(ctx, "0")
if err != nil {
t.Fatalf("expected nil got %v", err)
}
if !reflect.DeepEqual(follows, []graph.ID{"2"}) {
t.Fatalf("expected follows(0) %v, got %v", []graph.ID{"2"}, follows)
}
followers, err := db.Followers(ctx, "1")
if err != nil {
t.Fatalf("expected nil got %v", err)
}
// "1" still exists as a node, so an empty (non-nil) slice is expected rather than an error
if !reflect.DeepEqual(followers, []graph.ID{}) {
t.Fatalf("expected followers(1) %v, got %v", []graph.ID{}, followers)
}
followers, err = db.Followers(ctx, "2")
if err != nil {
t.Fatalf("expected nil got %v", err)
}
if !reflect.DeepEqual(followers, []graph.ID{"0"}) {
t.Fatalf("expected followers(2) %v, got %v", []graph.ID{"0"}, followers)
}
}
// TestNodeIDs checks that NodeIDs resolves pubkeys to IDs in order,
// returning the empty ID for pubkeys that are not found.
// It requires a redis instance listening on testAddress.
func TestNodeIDs(t *testing.T) {
tests := []struct {
name string
setup func() (RedisDB, error)
pubkeys []string
expected []graph.ID
}{
{
name: "empty database",
setup: Empty,
pubkeys: []string{"0"},
expected: []graph.ID{""},
},
{
name: "node not found",
setup: OneNode,
pubkeys: []string{"1"},
expected: []graph.ID{""},
},
{
name: "valid",
setup: Simple,
pubkeys: []string{"0", "1", "69"},
expected: []graph.ID{"0", "1", ""}, // last is not found
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
db, err := test.setup()
if err != nil {
t.Fatalf("setup failed: %v", err)
}
defer db.flushAll()
nodes, err := db.NodeIDs(ctx, test.pubkeys...)
if err != nil {
t.Fatalf("expected error nil, got %v", err)
}
if !reflect.DeepEqual(nodes, test.expected) {
t.Fatalf("expected nodes %v, got %v", test.expected, nodes)
}
})
}
}
// TestPubkeys checks that Pubkeys resolves node IDs to pubkeys in order,
// returning the empty pubkey for IDs that are not found.
// It requires a redis instance listening on testAddress.
func TestPubkeys(t *testing.T) {
tests := []struct {
name string
setup func() (RedisDB, error)
nodes []graph.ID
expected []string
}{
{
name: "empty database",
setup: Empty,
nodes: []graph.ID{"0"},
expected: []string{""},
},
{
name: "node not found",
setup: OneNode,
nodes: []graph.ID{"1"},
expected: []string{""},
},
{
name: "valid",
setup: Simple,
nodes: []graph.ID{"0", "1", "69"},
expected: []string{"0", "1", ""}, // last is not found
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
db, err := test.setup()
if err != nil {
t.Fatalf("setup failed: %v", err)
}
defer db.flushAll()
pubkeys, err := db.Pubkeys(ctx, test.nodes...)
if err != nil {
t.Fatalf("expected error nil, got %v", err)
}
if !reflect.DeepEqual(pubkeys, test.expected) {
t.Fatalf("expected pubkeys %v, got %v", test.expected, pubkeys)
}
})
}
}
// ------------------------------------- HELPERS -------------------------------
// Empty returns a RedisDB connected to testAddress without writing anything to it.
// The error is always nil; it is part of the signature so that Empty can be used as a test setup function.
func Empty() (RedisDB, error) {
return New(&redis.Options{Addr: testAddress}), nil
}
// OneNode returns a RedisDB connected to testAddress to which a single node
// with pubkey "0" has been added. If the setup fails, the database is flushed.
func OneNode() (RedisDB, error) {
	db := New(&redis.Options{Addr: testAddress})
	_, err := db.AddNode(context.Background(), "0")
	if err == nil {
		return db, nil
	}

	db.flushAll()
	return RedisDB{}, err
}
// Simple returns a RedisDB connected to testAddress holding three nodes
// (pubkeys "0", "1" and "2") and a single edge 0 ---> 1.
// If any step of the setup fails, the database is flushed.
func Simple() (RedisDB, error) {
	ctx := context.Background()
	db := New(&redis.Options{Addr: testAddress})

	populate := func() error {
		for _, pk := range []string{"0", "1", "2"} {
			if _, err := db.AddNode(ctx, pk); err != nil {
				return err
			}
		}

		// 0 ---> 1
		if err := db.client.SAdd(ctx, follows("0"), "1").Err(); err != nil {
			return err
		}
		return db.client.SAdd(ctx, followers("1"), "0").Err()
	}

	if err := populate(); err != nil {
		db.flushAll()
		return RedisDB{}, err
	}
	return db, nil
}

100
pkg/redb/utils.go Normal file
View File

@@ -0,0 +1,100 @@
package redb
import (
"context"
"github/pippellia-btc/crawler/pkg/graph"
"strconv"
"time"
)
var (
// testAddress is the address of the redis instance used by the test helpers.
// NOTE(review): flushAll wipes every database of the instance it talks to, so this presumably points to a dedicated test instance — confirm.
testAddress = "localhost:6380"
)
// flushAll deletes all the keys of all existing databases.
// The result of the command is discarded, so a failure is not reported to the caller.
func (r RedisDB) flushAll() {
r.client.FlushAll(context.Background())
}
// node returns the redis key of the hash of the node with the given id, that is KeyNodePrefix + id.
func node[ID string | graph.ID](id ID) string {
return KeyNodePrefix + string(id)
}
// follows returns the redis key of the follows set of the node with the given id, that is KeyFollowsPrefix + id.
func follows[ID string | graph.ID](id ID) string {
return KeyFollowsPrefix + string(id)
}
// followers returns the redis key of the followers set of the node with the given id, that is KeyFollowersPrefix + id.
func followers[ID string | graph.ID](id ID) string {
return KeyFollowersPrefix + string(id)
}
// toIDs converts a slice of strings into a slice of graph IDs, preserving the order.
// The result is never nil, even when s is nil or empty.
func toIDs(s []string) []graph.ID {
	IDs := make([]graph.ID, 0, len(s))
	for _, str := range s {
		IDs = append(IDs, graph.ID(str))
	}
	return IDs
}
// toStrings converts a slice of graph IDs into a slice of strings, preserving the order.
// The result is never nil, even when ids is nil or empty.
func toStrings(ids []graph.ID) []string {
	strs := make([]string, 0, len(ids))
	for _, id := range ids {
		strs = append(strs, string(id))
	}
	return strs
}
// parseNode parses the fields of a redis node hash into a node structure.
//
// A nil or empty map yields a nil node and a nil error. Unknown fields are ignored.
// Timestamp fields that are present are converted into Records, always in the order
// addition, promotion, demotion. Reading the fields in a fixed order (instead of ranging
// over the map) makes the order of the Records deterministic across calls.
// An error is returned if a timestamp is not a valid base-10 integer.
func parseNode(fields map[string]string) (*graph.Node, error) {
	if len(fields) == 0 {
		return nil, nil
	}

	node := graph.Node{
		ID:     graph.ID(fields[NodeID]),
		Pubkey: fields[NodePubkey],
		Status: fields[NodeStatus],
	}

	// hash field --> record kind, in the order the records are emitted
	timestamps := []struct {
		field string
		kind  int
	}{
		{field: NodeAddedTS, kind: graph.Addition},
		{field: NodePromotionTS, kind: graph.Promotion},
		{field: NodeDemotionTS, kind: graph.Demotion},
	}

	for _, t := range timestamps {
		val, ok := fields[t.field]
		if !ok {
			continue
		}

		ts, err := parseTimestamp(val)
		if err != nil {
			return nil, err
		}
		node.Records = append(node.Records, graph.Record{Kind: t.kind, Timestamp: ts})
	}
	return &node, nil
}
// parseTimestamp converts a string holding a unix timestamp (seconds, base 10)
// into a time.Time. On invalid input it returns the zero time and the parsing error.
func parseTimestamp(unix string) (time.Time, error) {
	seconds, err := strconv.ParseInt(unix, 10, 64)
	if err == nil {
		return time.Unix(seconds, 0), nil
	}
	return time.Time{}, err
}

1
pkg/redb/walks.go Normal file
View File

@@ -0,0 +1 @@
package redb

View File

@@ -179,7 +179,7 @@ func ToUpdate(ctx context.Context, walker Walker, delta graph.Delta, walks []Wal
}
shouldResample = rand.Float64() < resampleProbability
isInvalid = (pos < walk.Len()-1) && slices.Contains(delta.Removed, walk.Path[pos+1])
isInvalid = (pos < walk.Len()-1) && slices.Contains(delta.Remove, walk.Path[pos+1])
switch {
case shouldResample:
@@ -188,7 +188,7 @@ func ToUpdate(ctx context.Context, walker Walker, delta graph.Delta, walks []Wal
updated.Prune(pos + 1)
if rand.Float64() < Alpha {
new, err := generate(ctx, walker, delta.Added...)
new, err := generate(ctx, walker, delta.Add...)
if err != nil {
return nil, fmt.Errorf("ToUpdate: failed to generate new segment: %w", err)
}
@@ -203,7 +203,7 @@ func ToUpdate(ctx context.Context, walker Walker, delta graph.Delta, walks []Wal
updated := walk.Copy()
updated.Prune(pos + 1)
new, err := generate(ctx, walker, delta.Common...)
new, err := generate(ctx, walker, delta.Keep...)
if err != nil {
return nil, fmt.Errorf("ToUpdate: failed to generate new segment: %w", err)
}
@@ -223,24 +223,24 @@ func ToUpdate(ctx context.Context, walker Walker, delta graph.Delta, walks []Wal
// Our goal is to have 1/3 of the walks that continue go to each of 1, 2 and 3.
// This means we have to re-do 2/3 of the walks and make them continue towards 2 or 3.
func resampleProbability(delta graph.Delta) float64 {
if len(delta.Added) == 0 {
if len(delta.Add) == 0 {
return 0
}
c := float64(len(delta.Common))
a := float64(len(delta.Added))
c := float64(len(delta.Keep))
a := float64(len(delta.Add))
return a / (a + c)
}
func expectedUpdates(walks []Walk, delta graph.Delta) int {
if len(delta.Common) == 0 {
if len(delta.Keep) == 0 {
// no nodes have remained, all walks must be re-computed
return len(walks)
}
r := float64(len(delta.Removed))
c := float64(len(delta.Common))
a := float64(len(delta.Added))
r := float64(len(delta.Remove))
c := float64(len(delta.Keep))
a := float64(len(delta.Add))
invalidProbability := Alpha * r / (r + c)
resampleProbability := a / (a + c)

View File

@@ -62,9 +62,9 @@ func TestUpdateRemove(t *testing.T) {
})
delta := graph.Delta{
Node: "0",
Removed: []graph.ID{"1"}, // the old follows were "1" and "3"
Common: []graph.ID{"3"},
Node: "0",
Remove: []graph.ID{"1"}, // the old follows were "1" and "3"
Keep: []graph.ID{"3"},
}
walks := []Walk{