it's alive

This commit is contained in:
pippellia-btc
2025-06-04 17:54:00 +02:00
parent 29ef016392
commit 6cbd20f452
11 changed files with 466 additions and 97 deletions

32
cmd/.env Normal file
View File

@@ -0,0 +1,32 @@
# System
EVENTS_CAPACITY=100000
PUBKEYS_CAPACITY=100000
INIT_PUBKEYS=3bf0c63fcb93463407af97a5e5ee64fa883d107ef9e558472c4eb9aaaefa459d
# Connections
REDIS_ADDRESS=localhost:6380
SQLITE_URL=../../events/relay.sqlite
# Firehose
FIREHOSE_OFFSET=60 # in seconds
# Firehose and Fetcher
RELAYS=wss://relay.primal.net,wss://relay.damus.io,wss://nos.lol,wss://eden.nostr.land,wss://relay.current.fyi,wss://nostr.wine,wss://relay.nostr.band
#wss://nostr.orangepill.dev,wss://brb.io,wss://nostr.bitcoiner.social,wss://nostr-pub.wellorder.net,wss://nostr.oxtr.dev,wss://relay.nostr.mom,wss://nostr.fmt.wiz.biz,wss://puravida.nostr.land,wss://nostr.wine,wss://nostr.milou.lol,wss://atlas.nostr.land,wss://offchain.pub,wss://nostr-pub.semisol.dev,wss://nostr.onsats.org,wss://relay.nostr.com.au,wss://relay.nostrati.com,wss://nostr.inosta.cc,wss://relay.nostr.info
# Fetcher
FETCHER_BATCH=100
FETCHER_INTERVAL=30 # in seconds
# Arbiter
# these multipliers must satisfy: (1 + promotion) > demotion > 1;
# if the first inequality is not satisfied, there will be cyclical promotion --> demotion --> promotion...
# if the second inequality is not satisfied, it's impossible for an active node to be demoted
ARBITER_PROMOTION=0.0
ARBITER_DEMOTION=0.0
ARBITER_ACTIVATION=0.0
ARBITER_PROMOTION_WAIT=0
ARBITER_PING_WAIT=10
# Processor
PROCESSOR_PRINT_EVERY=5000

186
cmd/config.go Normal file
View File

@@ -0,0 +1,186 @@
package main
import (
"fmt"
"os"
"strconv"
"strings"
"time"
"github/pippellia-btc/crawler/pkg/pipe"
_ "github.com/joho/godotenv/autoload" // autoloading .env
"github.com/nbd-wtf/go-nostr"
)
// SystemConfig holds the system-wide parameters: where to reach the
// backing stores and how large the in-memory channels are.
type SystemConfig struct {
	RedisAddress    string // address of the Redis instance (host:port)
	SQLiteURL       string // path/URL of the SQLite events database
	EventsCapacity  int    // buffer size of the events channel
	PubkeysCapacity int    // buffer size of the pubkeys channel
	InitPubkeys     []string // only used during initialization
}
// NewSystemConfig returns a SystemConfig populated with default values.
// InitPubkeys is left empty; it is only needed when seeding a fresh database.
func NewSystemConfig() SystemConfig {
	var c SystemConfig
	c.RedisAddress = "localhost:6379"
	c.SQLiteURL = "events.sqlite"
	c.EventsCapacity = 1000
	c.PubkeysCapacity = 1000
	return c
}
// Print writes the system configuration to stdout, one field per line.
func (c SystemConfig) Print() {
	fmt.Println("System:")
	fields := []struct {
		format string
		value  any
	}{
		{" RedisAddress: %s\n", c.RedisAddress},
		{" SQLiteURL: %s\n", c.SQLiteURL},
		{" EventsCapacity: %d\n", c.EventsCapacity},
		{" PubkeysCapacity: %d\n", c.PubkeysCapacity},
		{" InitPubkeys: %v\n", c.InitPubkeys},
	}
	for _, f := range fields {
		fmt.Printf(f.format, f.value)
	}
}
// The configuration parameters for the system and the main processes
type Config struct {
SystemConfig
Firehose pipe.FirehoseConfig
Fetcher pipe.FetcherConfig
Arbiter pipe.ArbiterConfig
Processor pipe.ProcessorConfig
}
// NewConfig returns a config with default parameters
func NewConfig() *Config {
return &Config{
SystemConfig: NewSystemConfig(),
Firehose: pipe.NewFirehoseConfig(),
Fetcher: pipe.NewFetcherConfig(),
Arbiter: pipe.NewArbiterConfig(),
Processor: pipe.NewProcessorConfig(),
}
}
// Print writes the whole configuration to stdout: the system section
// followed by each pipeline stage, in pipeline order.
func (c *Config) Print() {
	sections := []func(){
		c.SystemConfig.Print,
		c.Firehose.Print,
		c.Fetcher.Print,
		c.Arbiter.Print,
		c.Processor.Print,
	}
	for _, printSection := range sections {
		printSection()
	}
}
// LoadConfig reads the enviroment variables and parses them into a [Config] struct
func LoadConfig() (*Config, error) {
var config = NewConfig()
var err error
for _, item := range os.Environ() {
keyVal := strings.SplitN(item, "=", 2)
key, val := keyVal[0], keyVal[1]
switch key {
case "REDIS_ADDRESS":
config.RedisAddress = val
case "SQLITE_URL":
config.SQLiteURL = val
case "EVENTS_CAPACITY":
config.EventsCapacity, err = strconv.Atoi(val)
if err != nil {
return nil, fmt.Errorf("error parsing %v: %v", keyVal, err)
}
case "PUBKEYS_CAPACITY":
config.PubkeysCapacity, err = strconv.Atoi(val)
if err != nil {
return nil, fmt.Errorf("error parsing %v: %v", keyVal, err)
}
case "INIT_PUBKEYS":
pubkeys := strings.Split(val, ",")
for _, pk := range pubkeys {
if !nostr.IsValidPublicKey(pk) {
return nil, fmt.Errorf("pubkey %s is not valid", pk)
}
}
config.InitPubkeys = pubkeys
case "FIREHOSE_OFFSET":
offset, err := strconv.Atoi(val)
if err != nil {
return nil, fmt.Errorf("error parsing %v: %v", keyVal, err)
}
config.Fetcher.Interval = time.Duration(offset) * time.Second
case "RELAYS":
relays := strings.Split(val, ",")
if len(relays) == 0 {
return nil, fmt.Errorf("relay list is empty")
}
for _, relay := range relays {
if !nostr.IsValidRelayURL(relay) {
return nil, fmt.Errorf("relay \"%s\" is not a valid url", relay)
}
}
config.Firehose.Relays = relays
config.Fetcher.Relays = relays
case "FETCHER_BATCH":
config.Fetcher.Batch, err = strconv.Atoi(val)
if err != nil {
return nil, fmt.Errorf("error parsing %v: %v", keyVal, err)
}
case "FETCHER_INTERVAL":
interval, err := strconv.Atoi(val)
if err != nil {
return nil, fmt.Errorf("error parsing %v: %v", keyVal, err)
}
config.Fetcher.Interval = time.Duration(interval) * time.Second
case "ARBITER_ACTIVATION":
config.Arbiter.Activation, err = strconv.ParseFloat(val, 64)
if err != nil {
return nil, fmt.Errorf("error parsing %v: %v", keyVal, err)
}
case "ARBITER_PROMOTION":
config.Arbiter.Promotion, err = strconv.ParseFloat(val, 64)
if err != nil {
return nil, fmt.Errorf("error parsing %v: %v", keyVal, err)
}
case "ARBITER_DEMOTION":
config.Arbiter.Demotion, err = strconv.ParseFloat(val, 64)
if err != nil {
return nil, fmt.Errorf("error parsing %v: %v", keyVal, err)
}
case "ARBITER_PING_WAIT":
wait, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return nil, fmt.Errorf("error parsing %v: %v", keyVal, err)
}
config.Arbiter.PingWait = time.Duration(wait) * time.Second
case "ARBITER_PROMOTION_WAIT":
wait, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return nil, fmt.Errorf("error parsing %v: %v", keyVal, err)
}
config.Arbiter.PromotionWait = time.Duration(wait) * time.Second
case "PROCESSOR_PRINT_EVERY":
config.Processor.PrintEvery, err = strconv.Atoi(val)
if err != nil {
return nil, fmt.Errorf("error parsing %v: %v", keyVal, err)
}
}
}
return config, nil
}

122
cmd/crawler.go Normal file
View File

@@ -0,0 +1,122 @@
package main
import (
"context"
"github/pippellia-btc/crawler/pkg/graph"
"github/pippellia-btc/crawler/pkg/pipe"
"github/pippellia-btc/crawler/pkg/redb"
"github/pippellia-btc/crawler/pkg/walks"
"log"
"os"
"os/signal"
"sync"
"syscall"
"github.com/nbd-wtf/go-nostr"
"github.com/redis/go-redis/v9"
)
// main wires up the crawler: it loads the configuration, connects to Redis,
// seeds an empty database with the configured init pubkeys, then runs the
// four pipeline stages (Firehose, Fetcher, Arbiter, Processor) until an OS
// signal cancels the shared context.
func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	go handleSignals(cancel)
	config, err := LoadConfig()
	if err != nil {
		panic(err)
	}
	// Buffered channels decouple producers from consumers; when a channel
	// is full the producers below drop items instead of blocking.
	events := make(chan *nostr.Event, config.EventsCapacity)
	pubkeys := make(chan string, config.PubkeysCapacity)
	db := redb.New(&redis.Options{Addr: config.RedisAddress})
	count, err := db.NodeCount(ctx)
	if err != nil {
		panic(err)
	}
	if count == 0 {
		// First run: add each init pubkey as a graph node, queue it for
		// fetching, and generate its random walks before the pipeline starts.
		log.Println("initializing crawler from empty database")
		nodes := make([]graph.ID, len(config.InitPubkeys))
		for i, pk := range config.InitPubkeys {
			nodes[i], err = db.AddNode(ctx, pk)
			if err != nil {
				panic(err)
			}
			// NOTE(review): nothing drains pubkeys yet, so this send blocks
			// forever if len(InitPubkeys) > PubkeysCapacity — confirm.
			pubkeys <- pk // add to queue
		}
		walks, err := walks.Generate(ctx, db, nodes...)
		if err != nil {
			panic(err)
		}
		if err := db.AddWalks(ctx, walks...); err != nil {
			panic(err)
		}
		log.Printf("correctly added %d init pubkeys", len(config.InitPubkeys))
	}
	// NOTE(review): `_ = events` is redundant (events is used below);
	// presumably leftover from the disabled eventstore wiring.
	_ = events
	// eventStore, err := eventstore.New(config.SQLiteURL)
	// if err != nil {
	// panic("failed to connect to the sqlite eventstore: " + err.Error())
	// }
	var wg sync.WaitGroup
	wg.Add(3)
	// Firehose: live relay subscription; forwards events, dropping when full.
	go func() {
		defer wg.Done()
		pipe.Firehose(ctx, config.Firehose, db, func(event *nostr.Event) error {
			select {
			case events <- event:
			default:
				log.Printf("Firehose: channel is full, dropping event ID %s by %s", event.ID, event.PubKey)
			}
			return nil
		})
	}()
	// Fetcher: batched per-pubkey queries; forwards events, dropping when full.
	go func() {
		defer wg.Done()
		pipe.Fetcher(ctx, config.Fetcher, pubkeys, func(event *nostr.Event) error {
			select {
			case events <- event:
			default:
				log.Printf("Fetcher: channel is full, dropping event ID %s by %s", event.ID, event.PubKey)
			}
			return nil
		})
	}()
	// Arbiter: promotes/demotes nodes; feeds promoted pubkeys to the Fetcher.
	go func() {
		defer wg.Done()
		pipe.Arbiter(ctx, config.Arbiter, db, func(pubkey string) error {
			select {
			case pubkeys <- pubkey:
			default:
				log.Printf("Arbiter: channel is full, dropping pubkey %s", pubkey)
			}
			return nil
		})
	}()
	log.Println("ready to process events")
	// Processor runs on the main goroutine and returns when ctx is cancelled.
	pipe.Processor(ctx, config.Processor, db, events)
	wg.Wait()
}
// handleSignals listens for OS signals and triggers context cancellation.
func handleSignals(cancel context.CancelFunc) {
signals := make(chan os.Signal, 1)
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
<-signals
log.Println(" Signal received. Shutting down...")
cancel()
}

1
go.mod
View File

@@ -5,6 +5,7 @@ go 1.24.1
toolchain go1.24.3 toolchain go1.24.3
require ( require (
github.com/joho/godotenv v1.5.1
github.com/nbd-wtf/go-nostr v0.51.12 github.com/nbd-wtf/go-nostr v0.51.12
github.com/redis/go-redis/v9 v9.8.0 github.com/redis/go-redis/v9 v9.8.0
) )

2
go.sum
View File

@@ -31,6 +31,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dvyukov/go-fuzz v0.0.0-20200318091601-be3528f3a813/go.mod h1:11Gm+ccJnvAhCNLlf5+cS9KjtbaD5I5zaZpFMsTHWTw= github.com/dvyukov/go-fuzz v0.0.0-20200318091601-be3528f3a813/go.mod h1:11Gm+ccJnvAhCNLlf5+cS9KjtbaD5I5zaZpFMsTHWTw=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=

View File

@@ -14,25 +14,24 @@ import (
// walksTracker tracks the number of walks that have been updated by [Processor]. // walksTracker tracks the number of walks that have been updated by [Processor].
// It's used to wake-up the [Arbiter], which performs work and then resets it to 0. // It's used to wake-up the [Arbiter], which performs work and then resets it to 0.
var walksTracker *atomic.Int32 var walksTracker atomic.Int32
type ArbiterConfig struct { type ArbiterConfig struct {
Activation float64 Activation float64
Promotion float64 Promotion float64
Demotion float64 Demotion float64
PingPeriod time.Duration PromotionWait time.Duration
WaitPeriod time.Duration PingWait time.Duration
} }
func NewArbiterConfig() ArbiterConfig { func NewArbiterConfig() ArbiterConfig {
return ArbiterConfig{ return ArbiterConfig{
Activation: 0.01, Activation: 0.01,
Promotion: 0.1, Promotion: 0.1,
Demotion: 1.05, Demotion: 1.05,
PromotionWait: time.Hour,
PingPeriod: time.Minute, PingWait: time.Minute,
WaitPeriod: time.Hour,
} }
} }
@@ -41,16 +40,19 @@ func (c ArbiterConfig) Print() {
fmt.Printf(" Activation: %f\n", c.Activation) fmt.Printf(" Activation: %f\n", c.Activation)
fmt.Printf(" Promotion: %f\n", c.Promotion) fmt.Printf(" Promotion: %f\n", c.Promotion)
fmt.Printf(" Demotion: %f\n", c.Demotion) fmt.Printf(" Demotion: %f\n", c.Demotion)
fmt.Printf(" WaitPeriod: %v\n", c.WaitPeriod) fmt.Printf(" PromotionWait: %v\n", c.PromotionWait)
fmt.Printf(" PingWait: %v\n", c.PingWait)
} }
// Arbiter activates when the % of walks changed is greater than a threshold. Then it: // Arbiter activates when the % of walks changed is greater than a threshold. Then it:
// - scans through all the nodes in the database // - scans through all the nodes in the database
// - promotes or demotes nodes // - promotes or demotes nodes
func Arbiter(ctx context.Context, config ArbiterConfig, db redb.RedisDB, send func(pk string) error) { func Arbiter(ctx context.Context, config ArbiterConfig, db redb.RedisDB, send func(pk string) error) {
ticker := time.NewTicker(config.PingPeriod) ticker := time.NewTicker(config.PingWait)
defer ticker.Stop() defer ticker.Stop()
walksTracker.Add(1000_000_000) // trigger a scan at startup
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
@@ -140,7 +142,7 @@ func arbiterScan(ctx context.Context, config ArbiterConfig, db redb.RedisDB, sen
return promoted, demoted, fmt.Errorf("node %s doesn't have an addition record", node.ID) return promoted, demoted, fmt.Errorf("node %s doesn't have an addition record", node.ID)
} }
if ranks[i] >= promotionThreshold && time.Since(added) > config.WaitPeriod { if ranks[i] >= promotionThreshold && time.Since(added) > config.PromotionWait {
if err := promote(db, node.ID); err != nil { if err := promote(db, node.ID); err != nil {
return promoted, demoted, err return promoted, demoted, err
} }

View File

@@ -12,7 +12,7 @@ import (
var ( var (
relevantKinds = []int{ relevantKinds = []int{
nostr.KindProfileMetadata, //nostr.KindProfileMetadata,
nostr.KindFollowList, nostr.KindFollowList,
} }
@@ -156,8 +156,8 @@ func (c FetcherConfig) Print() {
fmt.Printf(" Interval: %v\n", c.Interval) fmt.Printf(" Interval: %v\n", c.Interval)
} }
// Fetcher extracts pubkeys from the channel and queries for their events when either: // Fetcher extracts pubkeys from the channel and queries for their events:
// - the batch is bigger than config.Batch // - when the batch is bigger than config.Batch
// - after config.Interval since the last query. // - after config.Interval since the last query.
func Fetcher(ctx context.Context, config FetcherConfig, pubkeys <-chan string, send func(*nostr.Event) error) { func Fetcher(ctx context.Context, config FetcherConfig, pubkeys <-chan string, send func(*nostr.Event) error) {
batch := make([]string, 0, config.Batch) batch := make([]string, 0, config.Batch)

View File

@@ -21,7 +21,7 @@ type ProcessorConfig struct {
PrintEvery int PrintEvery int
} }
func NewProcessEventsConfig() ProcessorConfig { func NewProcessorConfig() ProcessorConfig {
return ProcessorConfig{PrintEvery: 5000} return ProcessorConfig{PrintEvery: 5000}
} }
@@ -126,6 +126,7 @@ func processFollowList(cache *walks.CachedWalker, db redb.RedisDB, event *nostr.
return err return err
} }
walksTracker.Add(int32(len(new)))
return cache.Update(ctx, delta) return cache.Update(ctx, delta)
} }

View File

@@ -39,7 +39,7 @@ type RedisDB struct {
func New(opt *redis.Options) RedisDB { func New(opt *redis.Options) RedisDB {
db := RedisDB{client: redis.NewClient(opt)} db := RedisDB{client: redis.NewClient(opt)}
if err := db.validateWalks(); err != nil { if err := db.init(); err != nil {
panic(err) panic(err)
} }
return db return db
@@ -49,7 +49,7 @@ func New(opt *redis.Options) RedisDB {
func (db RedisDB) Size(ctx context.Context) (int, error) { func (db RedisDB) Size(ctx context.Context) (int, error) {
size, err := db.client.DBSize(ctx).Result() size, err := db.client.DBSize(ctx).Result()
if err != nil { if err != nil {
return 0, err return 0, fmt.Errorf("failed to fetch the db size: %w", err)
} }
return int(size), nil return int(size), nil
} }
@@ -58,7 +58,7 @@ func (db RedisDB) Size(ctx context.Context) (int, error) {
func (db RedisDB) NodeCount(ctx context.Context) (int, error) { func (db RedisDB) NodeCount(ctx context.Context) (int, error) {
nodes, err := db.client.HLen(ctx, KeyKeyIndex).Result() nodes, err := db.client.HLen(ctx, KeyKeyIndex).Result()
if err != nil { if err != nil {
return 0, err return 0, fmt.Errorf("failed to fetch the node count: %w", err)
} }
return int(nodes), nil return int(nodes), nil
} }

View File

@@ -7,7 +7,6 @@ import (
"fmt" "fmt"
"github/pippellia-btc/crawler/pkg/graph" "github/pippellia-btc/crawler/pkg/graph"
"github/pippellia-btc/crawler/pkg/walks" "github/pippellia-btc/crawler/pkg/walks"
"math"
"slices" "slices"
"strconv" "strconv"
@@ -30,14 +29,64 @@ var (
ErrInvalidLimit = errors.New("limit must be a positive integer, or -1 to fetch all walks") ErrInvalidLimit = errors.New("limit must be a positive integer, or -1 to fetch all walks")
) )
// init the walk store checking the existence of [KeyRWS].
// If it exists, check its fields for consistency
// If it doesn't, store [walks.Alpha] and [walks.N]
func (db RedisDB) init() error {
ctx := context.Background()
exists, err := db.client.Exists(ctx, KeyRWS).Result()
if err != nil {
return fmt.Errorf("failed to check for existence of %s %w", KeyRWS, err)
}
switch exists {
case 1:
// exists, check the values
vals, err := db.client.HMGet(ctx, KeyRWS, KeyAlpha, KeyWalksPerNode).Result()
if err != nil {
return fmt.Errorf("failed to fetch alpha and walksPerNode %w", err)
}
alpha, err := parseFloat(vals[0])
if err != nil {
return fmt.Errorf("failed to parse alpha: %w", err)
}
N, err := parseInt(vals[1])
if err != nil {
return fmt.Errorf("failed to parse walksPerNode: %w", err)
}
if alpha != walks.Alpha {
return errors.New("alpha and walks.Alpha are different")
}
if N != walks.N {
return errors.New("N and walks.N are different")
}
case 0:
// doesn't exists, seed the values
err := db.client.HSet(ctx, KeyRWS, KeyAlpha, walks.Alpha, KeyWalksPerNode, walks.N).Err()
if err != nil {
return fmt.Errorf("failed to set alpha and walksPerNode %w", err)
}
default:
return fmt.Errorf("unexpected exists: %d", exists)
}
return nil
}
// Walks returns the walks associated with the IDs. // Walks returns the walks associated with the IDs.
func (r RedisDB) Walks(ctx context.Context, IDs ...walks.ID) ([]walks.Walk, error) { func (db RedisDB) Walks(ctx context.Context, IDs ...walks.ID) ([]walks.Walk, error) {
switch { switch {
case len(IDs) == 0: case len(IDs) == 0:
return nil, nil return nil, nil
case len(IDs) <= 100000: case len(IDs) <= 100000:
vals, err := r.client.HMGet(ctx, KeyWalks, toStrings(IDs)...).Result() vals, err := db.client.HMGet(ctx, KeyWalks, toStrings(IDs)...).Result()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to fetch walks: %w", err) return nil, fmt.Errorf("failed to fetch walks: %w", err)
} }
@@ -58,12 +107,12 @@ func (r RedisDB) Walks(ctx context.Context, IDs ...walks.ID) ([]walks.Walk, erro
default: default:
// too many walks for a single call, so we split them in two batches // too many walks for a single call, so we split them in two batches
mid := len(IDs) / 2 mid := len(IDs) / 2
batch1, err := r.Walks(ctx, IDs[:mid]...) batch1, err := db.Walks(ctx, IDs[:mid]...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
batch2, err := r.Walks(ctx, IDs[mid:]...) batch2, err := db.Walks(ctx, IDs[mid:]...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -74,24 +123,24 @@ func (r RedisDB) Walks(ctx context.Context, IDs ...walks.ID) ([]walks.Walk, erro
// WalksVisiting returns up-to limit walks that visit node. // WalksVisiting returns up-to limit walks that visit node.
// Use limit = -1 to fetch all the walks visiting node. // Use limit = -1 to fetch all the walks visiting node.
func (r RedisDB) WalksVisiting(ctx context.Context, node graph.ID, limit int) ([]walks.Walk, error) { func (db RedisDB) WalksVisiting(ctx context.Context, node graph.ID, limit int) ([]walks.Walk, error) {
switch { switch {
case limit == -1: case limit == -1:
// return all walks visiting node // return all walks visiting node
IDs, err := r.client.SMembers(ctx, walksVisiting(node)).Result() IDs, err := db.client.SMembers(ctx, walksVisiting(node)).Result()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to fetch %s: %w", walksVisiting(node), err) return nil, fmt.Errorf("failed to fetch %s: %w", walksVisiting(node), err)
} }
return r.Walks(ctx, toWalks(IDs)...) return db.Walks(ctx, toWalks(IDs)...)
case limit > 0: case limit > 0:
IDs, err := r.client.SRandMemberN(ctx, walksVisiting(node), int64(limit)).Result() IDs, err := db.client.SRandMemberN(ctx, walksVisiting(node), int64(limit)).Result()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to fetch %s: %w", walksVisiting(node), err) return nil, fmt.Errorf("failed to fetch %s: %w", walksVisiting(node), err)
} }
return r.Walks(ctx, toWalks(IDs)...) return db.Walks(ctx, toWalks(IDs)...)
default: default:
return nil, ErrInvalidLimit return nil, ErrInvalidLimit
@@ -102,11 +151,11 @@ func (r RedisDB) WalksVisiting(ctx context.Context, node graph.ID, limit int) ([
// The walks are distributed evenly among the nodes: // The walks are distributed evenly among the nodes:
// - if limit == -1, all walks are returned (use with few nodes) // - if limit == -1, all walks are returned (use with few nodes)
// - if limit < len(nodes), no walks are returned // - if limit < len(nodes), no walks are returned
func (r RedisDB) WalksVisitingAny(ctx context.Context, nodes []graph.ID, limit int) ([]walks.Walk, error) { func (db RedisDB) WalksVisitingAny(ctx context.Context, nodes []graph.ID, limit int) ([]walks.Walk, error) {
switch { switch {
case limit == -1: case limit == -1:
// return all walks visiting all nodes // return all walks visiting all nodes
pipe := r.client.Pipeline() pipe := db.client.Pipeline()
cmds := make([]*redis.StringSliceCmd, len(nodes)) cmds := make([]*redis.StringSliceCmd, len(nodes))
for i, node := range nodes { for i, node := range nodes {
@@ -123,7 +172,7 @@ func (r RedisDB) WalksVisitingAny(ctx context.Context, nodes []graph.ID, limit i
} }
unique := unique(IDs) unique := unique(IDs)
return r.Walks(ctx, toWalks(unique)...) return db.Walks(ctx, toWalks(unique)...)
case limit > 0: case limit > 0:
// return limit walks uniformely distributed across all nodes // return limit walks uniformely distributed across all nodes
@@ -132,7 +181,7 @@ func (r RedisDB) WalksVisitingAny(ctx context.Context, nodes []graph.ID, limit i
return nil, nil return nil, nil
} }
pipe := r.client.Pipeline() pipe := db.client.Pipeline()
cmds := make([]*redis.StringSliceCmd, len(nodes)) cmds := make([]*redis.StringSliceCmd, len(nodes))
for i, node := range nodes { for i, node := range nodes {
@@ -149,7 +198,7 @@ func (r RedisDB) WalksVisitingAny(ctx context.Context, nodes []graph.ID, limit i
} }
unique := unique(IDs) unique := unique(IDs)
return r.Walks(ctx, toWalks(unique)...) return db.Walks(ctx, toWalks(unique)...)
default: default:
// invalid limit // invalid limit
@@ -158,20 +207,20 @@ func (r RedisDB) WalksVisitingAny(ctx context.Context, nodes []graph.ID, limit i
} }
// AddWalks adds all the walks to the database assigning them progressive IDs. // AddWalks adds all the walks to the database assigning them progressive IDs.
func (r RedisDB) AddWalks(ctx context.Context, walks ...walks.Walk) error { func (db RedisDB) AddWalks(ctx context.Context, walks ...walks.Walk) error {
if len(walks) == 0 { if len(walks) == 0 {
return nil return nil
} }
// get the IDs outside the transaction, which implies there might be "holes", // get the IDs outside the transaction, which implies there might be "holes",
// meaning IDs not associated with any walk // meaning IDs not associated with any walk
next, err := r.client.HIncrBy(ctx, KeyRWS, KeyLastWalkID, int64(len(walks))).Result() next, err := db.client.HIncrBy(ctx, KeyRWS, KeyLastWalkID, int64(len(walks))).Result()
if err != nil { if err != nil {
return fmt.Errorf("failed to add walks: failed to increment ID: %w", err) return fmt.Errorf("failed to add walks: failed to increment ID: %w", err)
} }
var visits, ID int var visits, ID int
pipe := r.client.TxPipeline() pipe := db.client.TxPipeline()
for i, walk := range walks { for i, walk := range walks {
visits += walk.Len() visits += walk.Len()
@@ -193,13 +242,13 @@ func (r RedisDB) AddWalks(ctx context.Context, walks ...walks.Walk) error {
} }
// RemoveWalks removes all the walks from the database. // RemoveWalks removes all the walks from the database.
func (r RedisDB) RemoveWalks(ctx context.Context, walks ...walks.Walk) error { func (db RedisDB) RemoveWalks(ctx context.Context, walks ...walks.Walk) error {
if len(walks) == 0 { if len(walks) == 0 {
return nil return nil
} }
var visits int var visits int
pipe := r.client.TxPipeline() pipe := db.client.TxPipeline()
for _, walk := range walks { for _, walk := range walks {
pipe.HDel(ctx, KeyWalks, string(walk.ID)) pipe.HDel(ctx, KeyWalks, string(walk.ID))
@@ -219,13 +268,14 @@ func (r RedisDB) RemoveWalks(ctx context.Context, walks ...walks.Walk) error {
return nil return nil
} }
func (r RedisDB) ReplaceWalks(ctx context.Context, before, after []walks.Walk) error { // ReplaceWalks replaces the old walks with the new ones.
func (db RedisDB) ReplaceWalks(ctx context.Context, before, after []walks.Walk) error {
if err := validateReplacement(before, after); err != nil { if err := validateReplacement(before, after); err != nil {
return err return err
} }
var visits int64 var visits int64
pipe := r.client.TxPipeline() pipe := db.client.TxPipeline()
for i := range before { for i := range before {
div := walks.Divergence(before[i], after[i]) div := walks.Divergence(before[i], after[i])
@@ -258,7 +308,7 @@ func (r RedisDB) ReplaceWalks(ctx context.Context, before, after []walks.Walk) e
return fmt.Errorf("failed to replace walks: pipeline failed %w", err) return fmt.Errorf("failed to replace walks: pipeline failed %w", err)
} }
pipe = r.client.TxPipeline() pipe = db.client.TxPipeline()
visits = 0 visits = 0
} }
} }
@@ -294,8 +344,8 @@ func validateReplacement(old, new []walks.Walk) error {
} }
// TotalVisits returns the total number of visits, which is the sum of the lengths of all walks. // TotalVisits returns the total number of visits, which is the sum of the lengths of all walks.
func (r RedisDB) TotalVisits(ctx context.Context) (int, error) { func (db RedisDB) TotalVisits(ctx context.Context) (int, error) {
total, err := r.client.HGet(ctx, KeyRWS, KeyTotalVisits).Result() total, err := db.client.HGet(ctx, KeyRWS, KeyTotalVisits).Result()
if err != nil { if err != nil {
return -1, fmt.Errorf("failed to get the total number of visits: %w", err) return -1, fmt.Errorf("failed to get the total number of visits: %w", err)
} }
@@ -309,8 +359,8 @@ func (r RedisDB) TotalVisits(ctx context.Context) (int, error) {
} }
// TotalWalks returns the total number of walks. // TotalWalks returns the total number of walks.
func (r RedisDB) TotalWalks(ctx context.Context) (int, error) { func (db RedisDB) TotalWalks(ctx context.Context) (int, error) {
total, err := r.client.HLen(ctx, KeyWalks).Result() total, err := db.client.HLen(ctx, KeyWalks).Result()
if err != nil { if err != nil {
return -1, fmt.Errorf("failed to get the total number of walks: %w", err) return -1, fmt.Errorf("failed to get the total number of walks: %w", err)
} }
@@ -320,35 +370,8 @@ func (r RedisDB) TotalWalks(ctx context.Context) (int, error) {
// Visits returns the number of times each specified node was visited during the walks. // Visits returns the number of times each specified node was visited during the walks.
// The returned slice contains counts in the same order as the input nodes. // The returned slice contains counts in the same order as the input nodes.
// If a node is not found, it returns 0 visits. // If a node is not found, it returns 0 visits.
func (r RedisDB) Visits(ctx context.Context, nodes ...graph.ID) ([]int, error) { func (db RedisDB) Visits(ctx context.Context, nodes ...graph.ID) ([]int, error) {
return r.counts(ctx, walksVisiting, nodes...) return db.counts(ctx, walksVisiting, nodes...)
}
func (r RedisDB) validateWalks() error {
vals, err := r.client.HMGet(context.Background(), KeyRWS, KeyAlpha, KeyWalksPerNode).Result()
if err != nil {
return fmt.Errorf("failed to fetch alpha and walksPerNode %w", err)
}
alpha, err := parseFloat(vals[0])
if err != nil {
return fmt.Errorf("failed to parse alpha: %w", err)
}
N, err := parseInt(vals[1])
if err != nil {
return fmt.Errorf("failed to parse walksPerNode: %w", err)
}
if math.Abs(alpha-walks.Alpha) > 1e-10 {
return errors.New("alpha and walks.Alpha are different")
}
if N != walks.N {
return errors.New("N and walks.N are different")
}
return nil
} }
// unique returns a slice of unique elements of the input slice. // unique returns a slice of unique elements of the input slice.

View File

@@ -11,30 +11,30 @@ import (
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
) )
func TestValidate(t *testing.T) { // func TestValidate(t *testing.T) {
tests := []struct { // tests := []struct {
name string // name string
setup func() (RedisDB, error) // setup func() (RedisDB, error)
err error // err error
}{ // }{
{name: "empty", setup: Empty, err: ErrValueIsNil}, // {name: "empty", setup: Empty, err: ErrValueIsNil},
{name: "valid", setup: SomeWalks(0)}, // {name: "valid", setup: SomeWalks(0)},
} // }
for _, test := range tests { // for _, test := range tests {
t.Run(test.name, func(t *testing.T) { // t.Run(test.name, func(t *testing.T) {
db, err := test.setup() // db, err := test.setup()
if err != nil { // if err != nil {
t.Fatalf("setup failed: %v", err) // t.Fatalf("setup failed: %v", err)
} // }
defer db.flushAll() // defer db.flushAll()
if err = db.validateWalks(); !errors.Is(err, test.err) { // if err = db.validateWalks(); !errors.Is(err, test.err) {
t.Fatalf("expected error %v, got %v", test.err, err) // t.Fatalf("expected error %v, got %v", test.err, err)
} // }
}) // })
} // }
} // }
func TestWalksVisiting(t *testing.T) { func TestWalksVisiting(t *testing.T) {
tests := []struct { tests := []struct {