mirror of
https://github.com/aljazceru/crawler_v2.git
synced 2025-12-17 07:24:21 +01:00
395 lines
10 KiB
Go
395 lines
10 KiB
Go
package redb
|
|
|
|
import (
|
|
"cmp"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"github/pippellia-btc/crawler/pkg/graph"
|
|
"github/pippellia-btc/crawler/pkg/walks"
|
|
"slices"
|
|
"strconv"
|
|
|
|
"github.com/redis/go-redis/v9"
|
|
)
|
|
|
|
// Redis keys and hash fields used by the walk store.
const (
	// KeyRWS is the hash holding the walk store parameters and counters.
	KeyRWS = "RWS" // TODO: this can be removed

	// Fields of the KeyRWS hash.
	KeyAlpha        = "alpha"        // TODO: walks:alpha
	KeyWalksPerNode = "walksPerNode" // TODO: walks:N or another
	KeyLastWalkID   = "lastWalkID"   // TODO: walks:next
	KeyTotalVisits  = "totalVisits"  // TODO: walks:total_visits

	// KeyWalks is the hash mapping each walk ID to its formatted walk.
	KeyWalks = "walks"

	// KeyWalksVisitingPrefix is presumably the prefix of the per-node set keys
	// built by walksVisiting (helper not shown here; confirm).
	KeyWalksVisitingPrefix = "walksVisiting:" // TODO: walks_visiting:
)
|
|
|
|
// Sentinel errors returned (possibly wrapped) by the walk store.
// Compare against them with errors.Is.
var (
	// ErrWalkNotFound signals that a requested walk ID is not stored.
	ErrWalkNotFound = errors.New("walk not found")

	// ErrInvalidReplacement signals that the old and new walks passed to a
	// replacement are inconsistent with each other.
	ErrInvalidReplacement = errors.New("invalid walk replacement")

	// ErrInvalidLimit signals a limit that is neither positive nor -1.
	ErrInvalidLimit = errors.New("limit must be a positive integer, or -1 to fetch all walks")
)
|
|
|
|
// init the walk store checking the existence of [KeyRWS].
|
|
// If it exists, check its fields for consistency
|
|
// If it doesn't, store [walks.Alpha] and [walks.N]
|
|
func (db RedisDB) init() error {
|
|
ctx := context.Background()
|
|
exists, err := db.client.Exists(ctx, KeyRWS).Result()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to check for existence of %s %w", KeyRWS, err)
|
|
}
|
|
|
|
switch exists {
|
|
case 1:
|
|
// exists, check the values
|
|
vals, err := db.client.HMGet(ctx, KeyRWS, KeyAlpha, KeyWalksPerNode).Result()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to fetch alpha and walksPerNode %w", err)
|
|
}
|
|
|
|
alpha, err := parseFloat(vals[0])
|
|
if err != nil {
|
|
return fmt.Errorf("failed to parse alpha: %w", err)
|
|
}
|
|
|
|
N, err := parseInt(vals[1])
|
|
if err != nil {
|
|
return fmt.Errorf("failed to parse walksPerNode: %w", err)
|
|
}
|
|
|
|
if alpha != walks.Alpha {
|
|
return errors.New("alpha and walks.Alpha are different")
|
|
}
|
|
|
|
if N != walks.N {
|
|
return errors.New("N and walks.N are different")
|
|
}
|
|
|
|
case 0:
|
|
// doesn't exists, seed the values
|
|
err := db.client.HSet(ctx, KeyRWS, KeyAlpha, walks.Alpha, KeyWalksPerNode, walks.N).Err()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to set alpha and walksPerNode %w", err)
|
|
}
|
|
|
|
default:
|
|
return fmt.Errorf("unexpected exists: %d", exists)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Walks returns the walks associated with the IDs.
|
|
func (db RedisDB) Walks(ctx context.Context, IDs ...walks.ID) ([]walks.Walk, error) {
|
|
switch {
|
|
case len(IDs) == 0:
|
|
return nil, nil
|
|
|
|
case len(IDs) <= 100000:
|
|
vals, err := db.client.HMGet(ctx, KeyWalks, toStrings(IDs)...).Result()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to fetch walks: %w", err)
|
|
}
|
|
|
|
walks := make([]walks.Walk, len(vals))
|
|
for i, val := range vals {
|
|
if val == nil {
|
|
// walk was not found, so return an error
|
|
return nil, fmt.Errorf("failed to fetch walk with ID %s: %w", IDs[i], ErrWalkNotFound)
|
|
}
|
|
|
|
walks[i] = parseWalk(val.(string))
|
|
walks[i].ID = IDs[i]
|
|
}
|
|
|
|
return walks, nil
|
|
|
|
default:
|
|
// too many walks for a single call, so we split them in two batches
|
|
mid := len(IDs) / 2
|
|
batch1, err := db.Walks(ctx, IDs[:mid]...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
batch2, err := db.Walks(ctx, IDs[mid:]...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return append(batch1, batch2...), nil
|
|
}
|
|
}
|
|
|
|
// WalksVisiting returns up-to limit walks that visit node.
|
|
// Use limit = -1 to fetch all the walks visiting node.
|
|
func (db RedisDB) WalksVisiting(ctx context.Context, node graph.ID, limit int) ([]walks.Walk, error) {
|
|
switch {
|
|
case limit == -1:
|
|
// return all walks visiting node
|
|
IDs, err := db.client.SMembers(ctx, walksVisiting(node)).Result()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to fetch %s: %w", walksVisiting(node), err)
|
|
}
|
|
|
|
return db.Walks(ctx, toWalks(IDs)...)
|
|
|
|
case limit > 0:
|
|
IDs, err := db.client.SRandMemberN(ctx, walksVisiting(node), int64(limit)).Result()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to fetch %s: %w", walksVisiting(node), err)
|
|
}
|
|
|
|
return db.Walks(ctx, toWalks(IDs)...)
|
|
|
|
default:
|
|
return nil, ErrInvalidLimit
|
|
}
|
|
}
|
|
|
|
// WalksVisitingAny returns up to limit walks that visit the specified nodes.
|
|
// The walks are distributed evenly among the nodes:
|
|
// - if limit == -1, all walks are returned (use with few nodes)
|
|
// - if limit < len(nodes), no walks are returned
|
|
func (db RedisDB) WalksVisitingAny(ctx context.Context, nodes []graph.ID, limit int) ([]walks.Walk, error) {
|
|
switch {
|
|
case limit == -1:
|
|
// return all walks visiting all nodes
|
|
pipe := db.client.Pipeline()
|
|
cmds := make([]*redis.StringSliceCmd, len(nodes))
|
|
|
|
for i, node := range nodes {
|
|
cmds[i] = pipe.SMembers(ctx, walksVisiting(node))
|
|
}
|
|
|
|
if _, err := pipe.Exec(ctx); err != nil {
|
|
return nil, fmt.Errorf("failed to fetch all walks visiting %d nodes: %w", len(nodes), err)
|
|
}
|
|
|
|
IDs := make([]string, 0, walks.N*len(nodes))
|
|
for _, cmd := range cmds {
|
|
IDs = append(IDs, cmd.Val()...)
|
|
}
|
|
|
|
unique := unique(IDs)
|
|
return db.Walks(ctx, toWalks(unique)...)
|
|
|
|
case limit > 0:
|
|
// return limit walks uniformely distributed across all nodes
|
|
nodeLimit := int64(limit / len(nodes))
|
|
if nodeLimit == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
pipe := db.client.Pipeline()
|
|
cmds := make([]*redis.StringSliceCmd, len(nodes))
|
|
|
|
for i, node := range nodes {
|
|
cmds[i] = pipe.SRandMemberN(ctx, walksVisiting(node), nodeLimit)
|
|
}
|
|
|
|
if _, err := pipe.Exec(ctx); err != nil {
|
|
return nil, fmt.Errorf("failed to fetch %d walks visiting %d nodes: %w", limit, len(nodes), err)
|
|
}
|
|
|
|
IDs := make([]string, 0, limit)
|
|
for _, cmd := range cmds {
|
|
IDs = append(IDs, cmd.Val()...)
|
|
}
|
|
|
|
unique := unique(IDs)
|
|
return db.Walks(ctx, toWalks(unique)...)
|
|
|
|
default:
|
|
// invalid limit
|
|
return nil, fmt.Errorf("failed to fetch walks visiting any: %w", ErrInvalidLimit)
|
|
}
|
|
}
|
|
|
|
// AddWalks adds all the walks to the database assigning them progressive IDs.
|
|
func (db RedisDB) AddWalks(ctx context.Context, walks ...walks.Walk) error {
|
|
if len(walks) == 0 {
|
|
return nil
|
|
}
|
|
|
|
// get the IDs outside the transaction, which implies there might be "holes",
|
|
// meaning IDs not associated with any walk
|
|
next, err := db.client.HIncrBy(ctx, KeyRWS, KeyLastWalkID, int64(len(walks))).Result()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to add walks: failed to increment ID: %w", err)
|
|
}
|
|
|
|
var visits, ID int
|
|
pipe := db.client.TxPipeline()
|
|
|
|
for i, walk := range walks {
|
|
visits += walk.Len()
|
|
ID = int(next) - len(walks) + i // assigning IDs in the same order
|
|
|
|
pipe.HSet(ctx, KeyWalks, ID, formatWalk(walk))
|
|
for _, node := range walk.Path {
|
|
pipe.SAdd(ctx, walksVisiting(node), ID)
|
|
}
|
|
}
|
|
|
|
pipe.HIncrBy(ctx, KeyRWS, KeyTotalVisits, int64(visits))
|
|
|
|
if _, err = pipe.Exec(ctx); err != nil {
|
|
return fmt.Errorf("failed to add walks: pipeline failed %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// RemoveWalks removes all the walks from the database.
|
|
func (db RedisDB) RemoveWalks(ctx context.Context, walks ...walks.Walk) error {
|
|
if len(walks) == 0 {
|
|
return nil
|
|
}
|
|
|
|
var visits int
|
|
pipe := db.client.TxPipeline()
|
|
|
|
for _, walk := range walks {
|
|
pipe.HDel(ctx, KeyWalks, string(walk.ID))
|
|
for _, node := range walk.Path {
|
|
pipe.SRem(ctx, walksVisiting(node), string(walk.ID))
|
|
}
|
|
|
|
visits += walk.Len()
|
|
}
|
|
|
|
pipe.HIncrBy(ctx, KeyRWS, KeyTotalVisits, -int64(visits))
|
|
|
|
if _, err := pipe.Exec(ctx); err != nil {
|
|
return fmt.Errorf("failed to remove walks: pipeline failed %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ReplaceWalks replaces the old walks with the new ones.
|
|
func (db RedisDB) ReplaceWalks(ctx context.Context, before, after []walks.Walk) error {
|
|
if err := validateReplacement(before, after); err != nil {
|
|
return err
|
|
}
|
|
|
|
var visits int64
|
|
pipe := db.client.TxPipeline()
|
|
|
|
for i := range before {
|
|
div := walks.Divergence(before[i], after[i])
|
|
if div == -1 {
|
|
// the two walks are equal, skip
|
|
continue
|
|
}
|
|
|
|
prev := before[i]
|
|
next := after[i]
|
|
ID := after[i].ID
|
|
|
|
pipe.HSet(ctx, KeyWalks, ID, formatWalk(next))
|
|
|
|
for _, node := range prev.Path[div:] {
|
|
pipe.SRem(ctx, walksVisiting(node), ID)
|
|
visits--
|
|
}
|
|
|
|
for _, node := range next.Path[div:] {
|
|
pipe.SAdd(ctx, walksVisiting(node), ID)
|
|
visits++
|
|
}
|
|
|
|
if pipe.Len() > 5000 {
|
|
// execute a partial update when it's too big
|
|
pipe.HIncrBy(ctx, KeyRWS, KeyTotalVisits, visits)
|
|
|
|
if _, err := pipe.Exec(ctx); err != nil {
|
|
return fmt.Errorf("failed to replace walks: pipeline failed %w", err)
|
|
}
|
|
|
|
pipe = db.client.TxPipeline()
|
|
visits = 0
|
|
}
|
|
}
|
|
|
|
pipe.HIncrBy(ctx, KeyRWS, KeyTotalVisits, visits)
|
|
|
|
if _, err := pipe.Exec(ctx); err != nil {
|
|
return fmt.Errorf("failed to replace walks: pipeline failed %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func validateReplacement(old, new []walks.Walk) error {
|
|
if len(old) != len(new) {
|
|
return fmt.Errorf("%w: old and new walks must have the same lenght", ErrInvalidReplacement)
|
|
}
|
|
|
|
seen := make(map[walks.ID]struct{})
|
|
for i := range old {
|
|
if old[i].ID != new[i].ID {
|
|
return fmt.Errorf("%w: IDs don't match at index %d: old=%s, new=%s", ErrInvalidReplacement, i, old[i].ID, new[i].ID)
|
|
}
|
|
|
|
if _, ok := seen[old[i].ID]; ok {
|
|
return fmt.Errorf("%w: repeated walk ID %s", ErrInvalidReplacement, old[i].ID)
|
|
}
|
|
|
|
seen[old[i].ID] = struct{}{}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// TotalVisits returns the total number of visits, which is the sum of the lengths of all walks.
|
|
func (db RedisDB) TotalVisits(ctx context.Context) (int, error) {
|
|
total, err := db.client.HGet(ctx, KeyRWS, KeyTotalVisits).Result()
|
|
if err != nil {
|
|
return -1, fmt.Errorf("failed to get the total number of visits: %w", err)
|
|
}
|
|
|
|
tot, err := strconv.Atoi(total)
|
|
if err != nil {
|
|
return -1, fmt.Errorf("failed to parse the total number of visits: %w", err)
|
|
}
|
|
|
|
return tot, nil
|
|
}
|
|
|
|
// TotalWalks returns the total number of walks.
|
|
func (db RedisDB) TotalWalks(ctx context.Context) (int, error) {
|
|
total, err := db.client.HLen(ctx, KeyWalks).Result()
|
|
if err != nil {
|
|
return -1, fmt.Errorf("failed to get the total number of walks: %w", err)
|
|
}
|
|
return int(total), nil
|
|
}
|
|
|
|
// Visits returns the number of times each specified node was visited during the walks.
|
|
// The returned slice contains counts in the same order as the input nodes.
|
|
// If a node is not found, it returns 0 visits.
|
|
func (db RedisDB) Visits(ctx context.Context, nodes ...graph.ID) ([]int, error) {
|
|
return db.counts(ctx, walksVisiting, nodes...)
|
|
}
|
|
|
|
// unique returns the distinct elements of slice, sorted in ascending order.
// The input is left untouched: sorting and deduplication happen on a copy.
// It returns nil when slice is empty.
func unique[E cmp.Ordered](slice []E) []E {
	if len(slice) == 0 {
		return nil
	}

	// work on a copy so the caller's slice is neither reordered nor aliased
	sorted := slices.Clone(slice)
	slices.Sort(sorted)

	// after sorting, equal elements are adjacent and Compact drops the repeats
	return slices.Compact(sorted)
}
|