diff --git a/cmd/.env b/cmd/.env
new file mode 100644
index 0000000..46634de
--- /dev/null
+++ b/cmd/.env
@@ -0,0 +1,32 @@
+# System
+EVENTS_CAPACITY=100000
+PUBKEYS_CAPACITY=100000
+INIT_PUBKEYS=3bf0c63fcb93463407af97a5e5ee64fa883d107ef9e558472c4eb9aaaefa459d
+
+# Connections
+REDIS_ADDRESS=localhost:6380
+SQLITE_URL=../../events/relay.sqlite
+
+# Firehose
+FIREHOSE_OFFSET=60 # in seconds
+
+# Firehose and Fetcher
+RELAYS=wss://relay.primal.net,wss://relay.damus.io,wss://nos.lol,wss://eden.nostr.land,wss://relay.current.fyi,wss://nostr.wine,wss://relay.nostr.band
+#wss://nostr.orangepill.dev,wss://brb.io,wss://nostr.bitcoiner.social,wss://nostr-pub.wellorder.net,wss://nostr.oxtr.dev,wss://relay.nostr.mom,wss://nostr.fmt.wiz.biz,wss://puravida.nostr.land,wss://nostr.wine,wss://nostr.milou.lol,wss://atlas.nostr.land,wss://offchain.pub,wss://nostr-pub.semisol.dev,wss://nostr.onsats.org,wss://relay.nostr.com.au,wss://relay.nostrati.com,wss://nostr.inosta.cc,wss://relay.nostr.info
+
+# Fetcher
+FETCHER_BATCH=100
+FETCHER_INTERVAL=30 # in seconds
+
+# Arbiter
+# these multipliers must satisfy: (1 + promotion) > demotion > 1
+# if the first inequality is not satisfied, nodes cycle endlessly: promotion --> demotion --> promotion...
+# if the second inequality is not satisfied, an active node can never be demoted
+ARBITER_PROMOTION=0.0
+ARBITER_DEMOTION=0.0
+ARBITER_ACTIVATION=0.0
+ARBITER_PROMOTION_WAIT=0
+ARBITER_PING_WAIT=10
+
+# Processor
+PROCESSOR_PRINT_EVERY=5000
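
Aside: a quick sanity check of the multiplier constraint documented above, as a standalone Go sketch (validMultipliers is illustrative and not part of the crawler):

package main

import "fmt"

// validMultipliers reports whether (1 + promotion) > demotion > 1 holds.
func validMultipliers(promotion, demotion float64) bool {
	return 1+promotion > demotion && demotion > 1
}

func main() {
	fmt.Println(validMultipliers(0.1, 1.05)) // true: 1.1 > 1.05 > 1 (the defaults in NewArbiterConfig)
	fmt.Println(validMultipliers(0.0, 1.2))  // false: 1.0 < 1.2, so a promoted node is immediately demoted again, cyclically
	fmt.Println(validMultipliers(0.1, 0.9))  // false: 0.9 <= 1, so an active node can never be demoted
}
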
diff --git a/cmd/config.go b/cmd/config.go
new file mode 100644
index 0000000..a4ecbcd
--- /dev/null
+++ b/cmd/config.go
@@ -0,0 +1,186 @@
+package main
+
+import (
+	"fmt"
+	"os"
+	"strconv"
+	"strings"
+	"time"
+
+	"github/pippellia-btc/crawler/pkg/pipe"
+
+	_ "github.com/joho/godotenv/autoload" // autoloading .env
+	"github.com/nbd-wtf/go-nostr"
+)
+
+type SystemConfig struct {
+	RedisAddress string
+	SQLiteURL    string
+
+	EventsCapacity  int
+	PubkeysCapacity int
+
+	InitPubkeys []string // only used during initialization
+}
+
+func NewSystemConfig() SystemConfig {
+	return SystemConfig{
+		RedisAddress:    "localhost:6379",
+		SQLiteURL:       "events.sqlite",
+		EventsCapacity:  1000,
+		PubkeysCapacity: 1000,
+	}
+}
+
+func (c SystemConfig) Print() {
+	fmt.Println("System:")
+	fmt.Printf("  RedisAddress: %s\n", c.RedisAddress)
+	fmt.Printf("  SQLiteURL: %s\n", c.SQLiteURL)
+	fmt.Printf("  EventsCapacity: %d\n", c.EventsCapacity)
+	fmt.Printf("  PubkeysCapacity: %d\n", c.PubkeysCapacity)
+	fmt.Printf("  InitPubkeys: %v\n", c.InitPubkeys)
+}
+
+// Config holds the configuration parameters for the system and the main processes.
+type Config struct {
+	SystemConfig
+	Firehose  pipe.FirehoseConfig
+	Fetcher   pipe.FetcherConfig
+	Arbiter   pipe.ArbiterConfig
+	Processor pipe.ProcessorConfig
+}
+
+// NewConfig returns a config with default parameters.
+func NewConfig() *Config {
+	return &Config{
+		SystemConfig: NewSystemConfig(),
+		Firehose:     pipe.NewFirehoseConfig(),
+		Fetcher:      pipe.NewFetcherConfig(),
+		Arbiter:      pipe.NewArbiterConfig(),
+		Processor:    pipe.NewProcessorConfig(),
+	}
+}
+
+func (c *Config) Print() {
+	c.SystemConfig.Print()
+	c.Firehose.Print()
+	c.Fetcher.Print()
+	c.Arbiter.Print()
+	c.Processor.Print()
+}
+
+// LoadConfig reads the environment variables and parses them into a [Config] struct.
+func LoadConfig() (*Config, error) {
+	var config = NewConfig()
+	var err error
+
+	for _, item := range os.Environ() {
+		keyVal := strings.SplitN(item, "=", 2)
+		key, val := keyVal[0], keyVal[1]
+
+		switch key {
+		case "REDIS_ADDRESS":
+			config.RedisAddress = val
+
+		case "SQLITE_URL":
+			config.SQLiteURL = val
+
+		case "EVENTS_CAPACITY":
+			config.EventsCapacity, err = strconv.Atoi(val)
+			if err != nil {
+				return nil, fmt.Errorf("error parsing %v: %v", keyVal, err)
+			}
+
+		case "PUBKEYS_CAPACITY":
+			config.PubkeysCapacity, err = strconv.Atoi(val)
+			if err != nil {
+				return nil, fmt.Errorf("error parsing %v: %v", keyVal, err)
+			}
+
+		case "INIT_PUBKEYS":
+			pubkeys := strings.Split(val, ",")
+			for _, pk := range pubkeys {
+				if !nostr.IsValidPublicKey(pk) {
+					return nil, fmt.Errorf("pubkey %s is not valid", pk)
+				}
+			}
+
+			config.InitPubkeys = pubkeys
+
+		case "FIREHOSE_OFFSET":
+			offset, err := strconv.Atoi(val)
+			if err != nil {
+				return nil, fmt.Errorf("error parsing %v: %v", keyVal, err)
+			}
+			config.Firehose.Offset = time.Duration(offset) * time.Second // assumption: FirehoseConfig exposes an Offset field
+
+		case "RELAYS":
+			relays := strings.Split(val, ",")
+			if len(relays) == 0 {
+				return nil, fmt.Errorf("relay list is empty")
+			}
+
+			for _, relay := range relays {
+				if !nostr.IsValidRelayURL(relay) {
+					return nil, fmt.Errorf("relay \"%s\" is not a valid URL", relay)
+				}
+			}
+
+			config.Firehose.Relays = relays
+			config.Fetcher.Relays = relays
+
+		case "FETCHER_BATCH":
+			config.Fetcher.Batch, err = strconv.Atoi(val)
+			if err != nil {
+				return nil, fmt.Errorf("error parsing %v: %v", keyVal, err)
+			}
+
+		case "FETCHER_INTERVAL":
+			interval, err := strconv.Atoi(val)
+			if err != nil {
+				return nil, fmt.Errorf("error parsing %v: %v", keyVal, err)
+			}
+			config.Fetcher.Interval = time.Duration(interval) * time.Second
+
+		case "ARBITER_ACTIVATION":
+			config.Arbiter.Activation, err = strconv.ParseFloat(val, 64)
+			if err != nil {
+				return nil, fmt.Errorf("error parsing %v: %v", keyVal, err)
+			}
+
+		case "ARBITER_PROMOTION":
+			config.Arbiter.Promotion, err = strconv.ParseFloat(val, 64)
+			if err != nil {
+				return nil, fmt.Errorf("error parsing %v: %v", keyVal, err)
+			}
+
+		case "ARBITER_DEMOTION":
+			config.Arbiter.Demotion, err = strconv.ParseFloat(val, 64)
+			if err != nil {
+				return nil, fmt.Errorf("error parsing %v: %v", keyVal, err)
+			}
+
+		case "ARBITER_PING_WAIT":
+			wait, err := strconv.ParseInt(val, 10, 64)
+			if err != nil {
+				return nil, fmt.Errorf("error parsing %v: %v", keyVal, err)
+			}
+			config.Arbiter.PingWait = time.Duration(wait) * time.Second
+
+		case "ARBITER_PROMOTION_WAIT":
+			wait, err := strconv.ParseInt(val, 10, 64)
+			if err != nil {
+				return nil, fmt.Errorf("error parsing %v: %v", keyVal, err)
+			}
+			config.Arbiter.PromotionWait = time.Duration(wait) * time.Second
+
+		case "PROCESSOR_PRINT_EVERY":
+			config.Processor.PrintEvery, err = strconv.Atoi(val)
+			if err != nil {
+				return nil, fmt.Errorf("error parsing %v: %v", keyVal, err)
+			}
+		}
+	}
+
+	return config, nil
+}
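
Note on the blank import above: godotenv/autoload reads .env in its package init(), so by the time LoadConfig iterates os.Environ(), the file's variables are already present. A minimal sketch of the same mechanism:

package main

import (
	"fmt"
	"os"

	_ "github.com/joho/godotenv/autoload" // loads .env before main runs
)

func main() {
	// with cmd/.env in place, this prints "localhost:6380"
	fmt.Println(os.Getenv("REDIS_ADDRESS"))
}
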
diff --git a/cmd/crawler.go b/cmd/crawler.go
new file mode 100644
index 0000000..0272110
--- /dev/null
+++ b/cmd/crawler.go
@@ -0,0 +1,122 @@
+package main
+
+import (
+	"context"
+	"github/pippellia-btc/crawler/pkg/graph"
+	"github/pippellia-btc/crawler/pkg/pipe"
+	"github/pippellia-btc/crawler/pkg/redb"
+	"github/pippellia-btc/crawler/pkg/walks"
+	"log"
+	"os"
+	"os/signal"
+	"sync"
+	"syscall"
+
+	"github.com/nbd-wtf/go-nostr"
+	"github.com/redis/go-redis/v9"
+)
+
+func main() {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	go handleSignals(cancel)
+
+	config, err := LoadConfig()
+	if err != nil {
+		panic(err)
+	}
+
+	events := make(chan *nostr.Event, config.EventsCapacity)
+	pubkeys := make(chan string, config.PubkeysCapacity)
+
+	db := redb.New(&redis.Options{Addr: config.RedisAddress})
+	count, err := db.NodeCount(ctx)
+	if err != nil {
+		panic(err)
+	}
+
+	if count == 0 {
+		log.Println("initializing crawler from empty database")
+
+		nodes := make([]graph.ID, len(config.InitPubkeys))
+		for i, pk := range config.InitPubkeys {
+			nodes[i], err = db.AddNode(ctx, pk)
+			if err != nil {
+				panic(err)
+			}
+
+			pubkeys <- pk // add to queue
+		}
+
+		walks, err := walks.Generate(ctx, db, nodes...)
+		if err != nil {
+			panic(err)
+		}
+
+		if err := db.AddWalks(ctx, walks...); err != nil {
+			panic(err)
+		}
+
+		log.Printf("successfully added %d init pubkeys", len(config.InitPubkeys))
+	}
+
+	_ = events
+
+	// eventStore, err := eventstore.New(config.SQLiteURL)
+	// if err != nil {
+	// 	panic("failed to connect to the sqlite eventstore: " + err.Error())
+	// }
+
+	var wg sync.WaitGroup
+	wg.Add(3)
+
+	go func() {
+		defer wg.Done()
+		pipe.Firehose(ctx, config.Firehose, db, func(event *nostr.Event) error {
+			select {
+			case events <- event:
+			default:
+				log.Printf("Firehose: channel is full, dropping event ID %s by %s", event.ID, event.PubKey)
+			}
+			return nil
+		})
+	}()
+
+	go func() {
+		defer wg.Done()
+		pipe.Fetcher(ctx, config.Fetcher, pubkeys, func(event *nostr.Event) error {
+			select {
+			case events <- event:
+			default:
+				log.Printf("Fetcher: channel is full, dropping event ID %s by %s", event.ID, event.PubKey)
+			}
+			return nil
+		})
+	}()
+
+	go func() {
+		defer wg.Done()
+		pipe.Arbiter(ctx, config.Arbiter, db, func(pubkey string) error {
+			select {
+			case pubkeys <- pubkey:
+			default:
+				log.Printf("Arbiter: channel is full, dropping pubkey %s", pubkey)
+			}
+			return nil
+		})
+	}()
+
+	log.Println("ready to process events")
+	pipe.Processor(ctx, config.Processor, db, events)
+	wg.Wait()
+}
+
+// handleSignals listens for OS signals and triggers context cancellation.
+func handleSignals(cancel context.CancelFunc) {
+	signals := make(chan os.Signal, 1)
+	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
+	<-signals
+
+	log.Println(" Signal received. Shutting down...")
+	cancel()
+}
diff --git a/go.mod b/go.mod
index e8e30d8..feb2417 100644
--- a/go.mod
+++ b/go.mod
@@ -5,6 +5,7 @@ go 1.24.1
 toolchain go1.24.3
 
 require (
+	github.com/joho/godotenv v1.5.1
 	github.com/nbd-wtf/go-nostr v0.51.12
 	github.com/redis/go-redis/v9 v9.8.0
 )
diff --git a/go.sum b/go.sum
index 337dfb1..64026b9 100644
--- a/go.sum
+++ b/go.sum
@@ -31,6 +31,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
 github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
 github.com/dvyukov/go-fuzz v0.0.0-20200318091601-be3528f3a813/go.mod h1:11Gm+ccJnvAhCNLlf5+cS9KjtbaD5I5zaZpFMsTHWTw=
 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
+github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
+github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
 github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
 github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
 github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
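
Aside on the three callbacks in cmd/crawler.go above: they all use the same non-blocking send, so a full buffer drops work instead of stalling the producer. A hedged sketch of the pattern in isolation (trySend is illustrative, not a helper in the codebase):

package main

import "log"

// trySend enqueues v without blocking; when the buffer is full, it drops v and logs.
func trySend[T any](ch chan<- T, v T, name string) {
	select {
	case ch <- v:
	default:
		log.Printf("%s: channel is full, dropping value", name)
	}
}

func main() {
	events := make(chan int, 1)
	trySend(events, 1, "Firehose")
	trySend(events, 2, "Firehose") // dropped: the buffer of 1 is already full
}
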
diff --git a/pkg/pipe/arbiter.go b/pkg/pipe/arbiter.go
index 5a64faa..1570c2b 100644
--- a/pkg/pipe/arbiter.go
+++ b/pkg/pipe/arbiter.go
@@ -14,25 +14,24 @@ import (
 
 // walksTracker tracks the number of walks that have been updated by [Processor].
 // It's used to wake-up the [Arbiter], which performs work and then resets it to 0.
-var walksTracker *atomic.Int32
+var walksTracker atomic.Int32
 
 type ArbiterConfig struct {
 	Activation float64
 	Promotion  float64
 	Demotion   float64
 
-	PingPeriod time.Duration
-	WaitPeriod time.Duration
+	PromotionWait time.Duration
+	PingWait      time.Duration
 }
 
 func NewArbiterConfig() ArbiterConfig {
 	return ArbiterConfig{
-		Activation: 0.01,
-		Promotion:  0.1,
-		Demotion:   1.05,
-
-		PingPeriod: time.Minute,
-		WaitPeriod: time.Hour,
+		Activation:    0.01,
+		Promotion:     0.1,
+		Demotion:      1.05,
+		PromotionWait: time.Hour,
+		PingWait:      time.Minute,
 	}
 }
 
@@ -41,16 +40,19 @@ func (c ArbiterConfig) Print() {
 	fmt.Printf("  Activation: %f\n", c.Activation)
 	fmt.Printf("  Promotion: %f\n", c.Promotion)
 	fmt.Printf("  Demotion: %f\n", c.Demotion)
-	fmt.Printf("  WaitPeriod: %v\n", c.WaitPeriod)
+	fmt.Printf("  PromotionWait: %v\n", c.PromotionWait)
+	fmt.Printf("  PingWait: %v\n", c.PingWait)
 }
 
 // Arbiter activates when the % of walks changed is greater than a threshold. Then it:
 // - scans through all the nodes in the database
 // - promotes or demotes nodes
 func Arbiter(ctx context.Context, config ArbiterConfig, db redb.RedisDB, send func(pk string) error) {
-	ticker := time.NewTicker(config.PingPeriod)
+	ticker := time.NewTicker(config.PingWait)
 	defer ticker.Stop()
 
+	walksTracker.Add(1_000_000_000) // trigger a scan at startup
+
 	for {
 		select {
 		case <-ctx.Done():
@@ -140,7 +142,7 @@ func arbiterScan(ctx context.Context, config ArbiterConfig, db redb.RedisDB, sen
 		return promoted, demoted, fmt.Errorf("node %s doesn't have an addition record", node.ID)
 	}
 
-	if ranks[i] >= promotionThreshold && time.Since(added) > config.WaitPeriod {
+	if ranks[i] >= promotionThreshold && time.Since(added) > config.PromotionWait {
 		if err := promote(db, node.ID); err != nil {
 			return promoted, demoted, err
 		}
diff --git a/pkg/pipe/intake.go b/pkg/pipe/intake.go
index bcf6398..92c0754 100644
--- a/pkg/pipe/intake.go
+++ b/pkg/pipe/intake.go
@@ -12,7 +12,7 @@ import (
 
 var (
 	relevantKinds = []int{
-		nostr.KindProfileMetadata,
+		//nostr.KindProfileMetadata,
 		nostr.KindFollowList,
 	}
 
@@ -156,8 +156,8 @@ func (c FetcherConfig) Print() {
 	fmt.Printf("  Interval: %v\n", c.Interval)
 }
 
-// Fetcher extracts pubkeys from the channel and queries for their events when either:
-// - the batch is bigger than config.Batch
+// Fetcher extracts pubkeys from the channel and queries for their events:
+// - when the batch is bigger than config.Batch
 // - after config.Interval since the last query.
 func Fetcher(ctx context.Context, config FetcherConfig, pubkeys <-chan string, send func(*nostr.Event) error) {
 	batch := make([]string, 0, config.Batch)
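
Aside: arbiter.go above and processor.go below cooperate through the package-level walksTracker. A hedged sketch of the handshake (the activation formula here is illustrative; see arbiterScan for the real logic):

package main

import (
	"fmt"
	"sync/atomic"
)

var walksTracker atomic.Int32

// the Processor side: count every updated walk
func process(updatedWalks int) { walksTracker.Add(int32(updatedWalks)) }

// the Arbiter side: scan only when enough walks have changed, then reset the counter
func arbiterTick(totalWalks int, activation float64) bool {
	changed := float64(walksTracker.Load()) / float64(totalWalks)
	if changed < activation {
		return false
	}
	walksTracker.Store(0)
	return true
}

func main() {
	process(150)
	fmt.Println(arbiterTick(10000, 0.01)) // true: 1.5% of walks changed, above the 1% activation
}

This is also why Arbiter seeds the counter with 1_000_000_000 at startup: the very first tick always triggers a scan.
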
diff --git a/pkg/pipe/processor.go b/pkg/pipe/processor.go
index f99f2bb..0866e55 100644
--- a/pkg/pipe/processor.go
+++ b/pkg/pipe/processor.go
@@ -21,7 +21,7 @@ type ProcessorConfig struct {
 	PrintEvery int
 }
 
-func NewProcessEventsConfig() ProcessorConfig {
+func NewProcessorConfig() ProcessorConfig {
 	return ProcessorConfig{PrintEvery: 5000}
 }
 
@@ -126,6 +126,7 @@ func processFollowList(cache *walks.CachedWalker, db redb.RedisDB, event *nostr.
 		return err
 	}
 
+	walksTracker.Add(int32(len(new)))
 	return cache.Update(ctx, delta)
 }
 
diff --git a/pkg/redb/graph.go b/pkg/redb/graph.go
index 158d8c8..329c8b3 100644
--- a/pkg/redb/graph.go
+++ b/pkg/redb/graph.go
@@ -39,7 +39,7 @@ type RedisDB struct {
 func New(opt *redis.Options) RedisDB {
 	db := RedisDB{client: redis.NewClient(opt)}
 
-	if err := db.validateWalks(); err != nil {
+	if err := db.init(); err != nil {
 		panic(err)
 	}
 	return db
@@ -49,7 +49,7 @@ func New(opt *redis.Options) RedisDB {
 func (db RedisDB) Size(ctx context.Context) (int, error) {
 	size, err := db.client.DBSize(ctx).Result()
 	if err != nil {
-		return 0, err
+		return 0, fmt.Errorf("failed to fetch the db size: %w", err)
 	}
 	return int(size), nil
 }
@@ -58,7 +58,7 @@ func (db RedisDB) Size(ctx context.Context) (int, error) {
 	nodes, err := db.client.HLen(ctx, KeyKeyIndex).Result()
 	if err != nil {
-		return 0, err
+		return 0, fmt.Errorf("failed to fetch the node count: %w", err)
 	}
 	return int(nodes), nil
 }
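
Aside: in pkg/redb/walks.go below, Walks keeps each HMGet under 100000 IDs by recursively splitting oversized batches in half. The same idea in isolation (fetch is illustrative, not a function in the codebase):

package main

import "fmt"

// fetch splits ids into batches of at most max elements by halving.
func fetch(ids []int, max int) [][]int {
	if len(ids) <= max {
		return [][]int{ids}
	}
	mid := len(ids) / 2
	return append(fetch(ids[:mid], max), fetch(ids[mid:], max)...)
}

func main() {
	batches := fetch(make([]int, 250000), 100000)
	for _, b := range batches {
		fmt.Print(len(b), " ") // 62500 62500 62500 62500
	}
}
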
diff --git a/pkg/redb/walks.go b/pkg/redb/walks.go
index 34b7c42..208a43e 100644
--- a/pkg/redb/walks.go
+++ b/pkg/redb/walks.go
@@ -7,7 +7,6 @@ import (
 	"fmt"
 	"github/pippellia-btc/crawler/pkg/graph"
 	"github/pippellia-btc/crawler/pkg/walks"
-	"math"
 	"slices"
 	"strconv"
 
@@ -30,14 +29,64 @@ var (
 	ErrInvalidLimit = errors.New("limit must be a positive integer, or -1 to fetch all walks")
 )
 
+// init initializes the walk store by checking the existence of [KeyRWS].
+// If it exists, its fields are checked for consistency.
+// If it doesn't, [walks.Alpha] and [walks.N] are stored.
+func (db RedisDB) init() error {
+	ctx := context.Background()
+	exists, err := db.client.Exists(ctx, KeyRWS).Result()
+	if err != nil {
+		return fmt.Errorf("failed to check for existence of %s: %w", KeyRWS, err)
+	}
+
+	switch exists {
+	case 1:
+		// exists, check the values
+		vals, err := db.client.HMGet(ctx, KeyRWS, KeyAlpha, KeyWalksPerNode).Result()
+		if err != nil {
+			return fmt.Errorf("failed to fetch alpha and walksPerNode: %w", err)
+		}
+
+		alpha, err := parseFloat(vals[0])
+		if err != nil {
+			return fmt.Errorf("failed to parse alpha: %w", err)
+		}
+
+		N, err := parseInt(vals[1])
+		if err != nil {
+			return fmt.Errorf("failed to parse walksPerNode: %w", err)
+		}
+
+		if alpha != walks.Alpha {
+			return errors.New("alpha and walks.Alpha are different")
+		}
+
+		if N != walks.N {
+			return errors.New("N and walks.N are different")
+		}
+
+	case 0:
+		// doesn't exist, seed the values
+		err := db.client.HSet(ctx, KeyRWS, KeyAlpha, walks.Alpha, KeyWalksPerNode, walks.N).Err()
+		if err != nil {
+			return fmt.Errorf("failed to set alpha and walksPerNode: %w", err)
+		}
+
+	default:
+		return fmt.Errorf("unexpected exists: %d", exists)
+	}
+
+	return nil
+}
+
 // Walks returns the walks associated with the IDs.
-func (r RedisDB) Walks(ctx context.Context, IDs ...walks.ID) ([]walks.Walk, error) {
+func (db RedisDB) Walks(ctx context.Context, IDs ...walks.ID) ([]walks.Walk, error) {
 	switch {
 	case len(IDs) == 0:
 		return nil, nil
 
 	case len(IDs) <= 100000:
-		vals, err := r.client.HMGet(ctx, KeyWalks, toStrings(IDs)...).Result()
+		vals, err := db.client.HMGet(ctx, KeyWalks, toStrings(IDs)...).Result()
 		if err != nil {
 			return nil, fmt.Errorf("failed to fetch walks: %w", err)
 		}
@@ -58,12 +107,12 @@ func (r RedisDB) Walks(ctx context.Context, IDs ...walks.ID) ([]walks.Walk, erro
 	default:
 		// too many walks for a single call, so we split them in two batches
 		mid := len(IDs) / 2
-		batch1, err := r.Walks(ctx, IDs[:mid]...)
+		batch1, err := db.Walks(ctx, IDs[:mid]...)
 		if err != nil {
 			return nil, err
 		}
 
-		batch2, err := r.Walks(ctx, IDs[mid:]...)
+		batch2, err := db.Walks(ctx, IDs[mid:]...)
 		if err != nil {
 			return nil, err
 		}
@@ -74,24 +123,24 @@ func (r RedisDB) Walks(ctx context.Context, IDs ...walks.ID) ([]walks.Walk, erro
 // WalksVisiting returns up-to limit walks that visit node.
 // Use limit = -1 to fetch all the walks visiting node.
-func (r RedisDB) WalksVisiting(ctx context.Context, node graph.ID, limit int) ([]walks.Walk, error) {
+func (db RedisDB) WalksVisiting(ctx context.Context, node graph.ID, limit int) ([]walks.Walk, error) {
 	switch {
 	case limit == -1:
 		// return all walks visiting node
-		IDs, err := r.client.SMembers(ctx, walksVisiting(node)).Result()
+		IDs, err := db.client.SMembers(ctx, walksVisiting(node)).Result()
 		if err != nil {
 			return nil, fmt.Errorf("failed to fetch %s: %w", walksVisiting(node), err)
 		}
 
-		return r.Walks(ctx, toWalks(IDs)...)
+		return db.Walks(ctx, toWalks(IDs)...)
 
 	case limit > 0:
-		IDs, err := r.client.SRandMemberN(ctx, walksVisiting(node), int64(limit)).Result()
+		IDs, err := db.client.SRandMemberN(ctx, walksVisiting(node), int64(limit)).Result()
 		if err != nil {
 			return nil, fmt.Errorf("failed to fetch %s: %w", walksVisiting(node), err)
 		}
 
-		return r.Walks(ctx, toWalks(IDs)...)
+		return db.Walks(ctx, toWalks(IDs)...)
 
 	default:
 		return nil, ErrInvalidLimit
 	}
@@ -102,11 +151,11 @@ func (r RedisDB) WalksVisiting(ctx context.Context, node graph.ID, limit int) ([
 // The walks are distributed evenly among the nodes:
 // - if limit == -1, all walks are returned (use with few nodes)
 // - if limit < len(nodes), no walks are returned
-func (r RedisDB) WalksVisitingAny(ctx context.Context, nodes []graph.ID, limit int) ([]walks.Walk, error) {
+func (db RedisDB) WalksVisitingAny(ctx context.Context, nodes []graph.ID, limit int) ([]walks.Walk, error) {
 	switch {
 	case limit == -1:
 		// return all walks visiting all nodes
-		pipe := r.client.Pipeline()
+		pipe := db.client.Pipeline()
 		cmds := make([]*redis.StringSliceCmd, len(nodes))
 
 		for i, node := range nodes {
@@ -123,7 +172,7 @@ func (r RedisDB) WalksVisitingAny(ctx context.Context, nodes []graph.ID, limit i
 		}
 
 		unique := unique(IDs)
-		return r.Walks(ctx, toWalks(unique)...)
+		return db.Walks(ctx, toWalks(unique)...)
 
 	case limit > 0:
 		// return limit walks uniformely distributed across all nodes
@@ -132,7 +181,7 @@ func (r RedisDB) WalksVisitingAny(ctx context.Context, nodes []graph.ID, limit i
 		if len(nodes) == 0 || limit/len(nodes) == 0 {
 			return nil, nil
 		}
 
-		pipe := r.client.Pipeline()
+		pipe := db.client.Pipeline()
 		cmds := make([]*redis.StringSliceCmd, len(nodes))
 		for i, node := range nodes {
@@ -149,7 +198,7 @@ func (r RedisDB) WalksVisitingAny(ctx context.Context, nodes []graph.ID, limit i
 		}
 
 		unique := unique(IDs)
-		return r.Walks(ctx, toWalks(unique)...)
+		return db.Walks(ctx, toWalks(unique)...)
 
 	default:
 		// invalid limit
@@ -158,20 +207,20 @@ func (r RedisDB) WalksVisitingAny(ctx context.Context, nodes []graph.ID, limit i
 }
 
 // AddWalks adds all the walks to the database assigning them progressive IDs.
-func (r RedisDB) AddWalks(ctx context.Context, walks ...walks.Walk) error {
+func (db RedisDB) AddWalks(ctx context.Context, walks ...walks.Walk) error {
 	if len(walks) == 0 {
 		return nil
 	}
 
 	// get the IDs outside the transaction, which implies there might be "holes",
 	// meaning IDs not associated with any walk
-	next, err := r.client.HIncrBy(ctx, KeyRWS, KeyLastWalkID, int64(len(walks))).Result()
+	next, err := db.client.HIncrBy(ctx, KeyRWS, KeyLastWalkID, int64(len(walks))).Result()
 	if err != nil {
 		return fmt.Errorf("failed to add walks: failed to increment ID: %w", err)
 	}
 
 	var visits, ID int
-	pipe := r.client.TxPipeline()
+	pipe := db.client.TxPipeline()
 
 	for i, walk := range walks {
 		visits += walk.Len()
@@ -193,13 +242,13 @@ func (r RedisDB) AddWalks(ctx context.Context, walks ...walks.Walk) error {
 }
 
 // RemoveWalks removes all the walks from the database.
-func (r RedisDB) RemoveWalks(ctx context.Context, walks ...walks.Walk) error {
+func (db RedisDB) RemoveWalks(ctx context.Context, walks ...walks.Walk) error {
 	if len(walks) == 0 {
 		return nil
 	}
 
 	var visits int
-	pipe := r.client.TxPipeline()
+	pipe := db.client.TxPipeline()
 
 	for _, walk := range walks {
 		pipe.HDel(ctx, KeyWalks, string(walk.ID))
@@ -219,13 +268,14 @@ func (r RedisDB) RemoveWalks(ctx context.Context, walks ...walks.Walk) error {
 	return nil
 }
 
-func (r RedisDB) ReplaceWalks(ctx context.Context, before, after []walks.Walk) error {
+// ReplaceWalks replaces the old walks with the new ones.
+func (db RedisDB) ReplaceWalks(ctx context.Context, before, after []walks.Walk) error {
 	if err := validateReplacement(before, after); err != nil {
 		return err
 	}
 
 	var visits int64
-	pipe := r.client.TxPipeline()
+	pipe := db.client.TxPipeline()
 
 	for i := range before {
 		div := walks.Divergence(before[i], after[i])
@@ -258,7 +308,7 @@ func (r RedisDB) ReplaceWalks(ctx context.Context, before, after []walks.Walk) e
 			return fmt.Errorf("failed to replace walks: pipeline failed %w", err)
 		}
 
-			pipe = r.client.TxPipeline()
+			pipe = db.client.TxPipeline()
 			visits = 0
 		}
 	}
@@ -294,8 +344,8 @@ func validateReplacement(old, new []walks.Walk) error {
 }
 
 // TotalVisits returns the total number of visits, which is the sum of the lengths of all walks.
-func (r RedisDB) TotalVisits(ctx context.Context) (int, error) {
-	total, err := r.client.HGet(ctx, KeyRWS, KeyTotalVisits).Result()
+func (db RedisDB) TotalVisits(ctx context.Context) (int, error) {
+	total, err := db.client.HGet(ctx, KeyRWS, KeyTotalVisits).Result()
 	if err != nil {
 		return -1, fmt.Errorf("failed to get the total number of visits: %w", err)
 	}
@@ -309,8 +359,8 @@ func (r RedisDB) TotalVisits(ctx context.Context) (int, error) {
 }
 
 // TotalWalks returns the total number of walks.
-func (r RedisDB) TotalWalks(ctx context.Context) (int, error) {
-	total, err := r.client.HLen(ctx, KeyWalks).Result()
+func (db RedisDB) TotalWalks(ctx context.Context) (int, error) {
+	total, err := db.client.HLen(ctx, KeyWalks).Result()
 	if err != nil {
 		return -1, fmt.Errorf("failed to get the total number of walks: %w", err)
 	}
@@ -320,35 +370,8 @@ func (r RedisDB) TotalWalks(ctx context.Context) (int, error) {
 // Visits returns the number of times each specified node was visited during the walks.
 // The returned slice contains counts in the same order as the input nodes.
 // If a node is not found, it returns 0 visits.
-func (r RedisDB) Visits(ctx context.Context, nodes ...graph.ID) ([]int, error) {
-	return r.counts(ctx, walksVisiting, nodes...)
-}
-
-func (r RedisDB) validateWalks() error {
-	vals, err := r.client.HMGet(context.Background(), KeyRWS, KeyAlpha, KeyWalksPerNode).Result()
-	if err != nil {
-		return fmt.Errorf("failed to fetch alpha and walksPerNode %w", err)
-	}
-
-	alpha, err := parseFloat(vals[0])
-	if err != nil {
-		return fmt.Errorf("failed to parse alpha: %w", err)
-	}
-
-	N, err := parseInt(vals[1])
-	if err != nil {
-		return fmt.Errorf("failed to parse walksPerNode: %w", err)
-	}
-
-	if math.Abs(alpha-walks.Alpha) > 1e-10 {
-		return errors.New("alpha and walks.Alpha are different")
-	}
-
-	if N != walks.N {
-		return errors.New("N and walks.N are different")
-	}
-
-	return nil
+func (db RedisDB) Visits(ctx context.Context, nodes ...graph.ID) ([]int, error) {
+	return db.counts(ctx, walksVisiting, nodes...)
 }
 
 // unique returns a slice of unique elements of the input slice.
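
Aside on AddWalks above: HIncrBy reserves a block of len(walks) IDs before the transaction starts, so a failed pipeline leaves unused IDs ("holes") but concurrent writers never hand out the same ID twice. A hedged sketch with an in-memory counter standing in for HINCRBY (reserve is illustrative):

package main

import (
	"fmt"
	"sync/atomic"
)

var lastWalkID atomic.Int64 // stands in for the lastWalkID field of KeyRWS

// reserve returns the first and last ID of a freshly reserved block of n IDs.
func reserve(n int) (first, last int64) {
	next := lastWalkID.Add(int64(n)) // like HIncrBy(ctx, KeyRWS, KeyLastWalkID, n)
	return next - int64(n), next - 1
}

func main() {
	fmt.Println(reserve(3)) // 0 2
	fmt.Println(reserve(2)) // 3 4: distinct even if the first batch's pipeline failed
}
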
diff --git a/pkg/redb/walks_test.go b/pkg/redb/walks_test.go
index 06ebe02..5a00476 100644
--- a/pkg/redb/walks_test.go
+++ b/pkg/redb/walks_test.go
@@ -11,30 +11,30 @@ import (
 	"github.com/redis/go-redis/v9"
 )
 
-func TestValidate(t *testing.T) {
-	tests := []struct {
-		name  string
-		setup func() (RedisDB, error)
-		err   error
-	}{
-		{name: "empty", setup: Empty, err: ErrValueIsNil},
-		{name: "valid", setup: SomeWalks(0)},
-	}
+// func TestValidate(t *testing.T) {
+// 	tests := []struct {
+// 		name  string
+// 		setup func() (RedisDB, error)
+// 		err   error
+// 	}{
+// 		{name: "empty", setup: Empty, err: ErrValueIsNil},
+// 		{name: "valid", setup: SomeWalks(0)},
+// 	}
 
-	for _, test := range tests {
-		t.Run(test.name, func(t *testing.T) {
-			db, err := test.setup()
-			if err != nil {
-				t.Fatalf("setup failed: %v", err)
-			}
-			defer db.flushAll()
+// 	for _, test := range tests {
+// 		t.Run(test.name, func(t *testing.T) {
+// 			db, err := test.setup()
+// 			if err != nil {
+// 				t.Fatalf("setup failed: %v", err)
+// 			}
+// 			defer db.flushAll()
 
-			if err = db.validateWalks(); !errors.Is(err, test.err) {
-				t.Fatalf("expected error %v, got %v", test.err, err)
-			}
-		})
-	}
-}
+// 			if err = db.validateWalks(); !errors.Is(err, test.err) {
+// 				t.Fatalf("expected error %v, got %v", test.err, err)
+// 			}
+// 		})
+// 	}
+// }
 
 func TestWalksVisiting(t *testing.T) {
 	tests := []struct {