mirror of
https://github.com/aljazceru/lightning.git
synced 2025-12-19 15:14:23 +01:00
common/random_select: central place for reservoir sampling.
Turns out we can make quite a simple API out of it. Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
This commit is contained in:
@@ -57,6 +57,7 @@ COMMON_SRC_NOGEN := \
|
|||||||
common/ping.c \
|
common/ping.c \
|
||||||
common/psbt_open.c \
|
common/psbt_open.c \
|
||||||
common/pseudorand.c \
|
common/pseudorand.c \
|
||||||
|
common/random_select.c \
|
||||||
common/read_peer_msg.c \
|
common/read_peer_msg.c \
|
||||||
common/setup.c \
|
common/setup.c \
|
||||||
common/socket_close.c \
|
common/socket_close.c \
|
||||||
|
|||||||
11
common/random_select.c
Normal file
11
common/random_select.c
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
#include <common/pseudorand.h>
|
||||||
|
#include <common/random_select.h>
|
||||||
|
|
||||||
|
bool random_select(double weight, double *tot_weight)
|
||||||
|
{
|
||||||
|
*tot_weight += weight;
|
||||||
|
if (weight == 0)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
return pseudorand_double() <= weight / *tot_weight;
|
||||||
|
}
|
||||||
20
common/random_select.h
Normal file
20
common/random_select.h
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
#ifndef LIGHTNING_COMMON_RANDOM_SELECT_H
|
||||||
|
#define LIGHTNING_COMMON_RANDOM_SELECT_H
|
||||||
|
#include "config.h"
|
||||||
|
#include <stdbool.h>
|
||||||
|
|
||||||
|
/* Use weighted reservoir sampling, see:
|
||||||
|
* https://en.wikipedia.org/wiki/Reservoir_sampling#Algorithm_A-Chao
|
||||||
|
* But (currently) the result will consist of only one sample (k=1)
|
||||||
|
*/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* random_select: return true if we should select this one.
|
||||||
|
* @weight: weight for this option (use 1.0 if all the same)
|
||||||
|
* @tot_wieght: returns with sum of weights (must be initialized to zero)
|
||||||
|
*
|
||||||
|
* This always returns true on the first non-zero weight, and weighted
|
||||||
|
* randomly from then on.
|
||||||
|
*/
|
||||||
|
bool random_select(double weight, double *tot_weight);
|
||||||
|
#endif /* LIGHTNING_COMMON_RANDOM_SELECT_H */
|
||||||
@@ -65,6 +65,7 @@ GOSSIPD_COMMON_OBJS := \
|
|||||||
common/per_peer_state.o \
|
common/per_peer_state.o \
|
||||||
common/ping.o \
|
common/ping.o \
|
||||||
common/pseudorand.o \
|
common/pseudorand.o \
|
||||||
|
common/random_select.o \
|
||||||
common/setup.o \
|
common/setup.o \
|
||||||
common/status.o \
|
common/status.o \
|
||||||
common/status_wire.o \
|
common/status_wire.o \
|
||||||
|
|||||||
@@ -9,6 +9,7 @@
|
|||||||
#include <ccan/tal/tal.h>
|
#include <ccan/tal/tal.h>
|
||||||
#include <common/decode_array.h>
|
#include <common/decode_array.h>
|
||||||
#include <common/pseudorand.h>
|
#include <common/pseudorand.h>
|
||||||
|
#include <common/random_select.h>
|
||||||
#include <common/status.h>
|
#include <common/status.h>
|
||||||
#include <common/timeout.h>
|
#include <common/timeout.h>
|
||||||
#include <common/type_to_string.h>
|
#include <common/type_to_string.h>
|
||||||
@@ -454,7 +455,7 @@ static bool get_unannounced_nodes(const tal_t *ctx,
|
|||||||
{
|
{
|
||||||
size_t num = 0;
|
size_t num = 0;
|
||||||
u64 offset;
|
u64 offset;
|
||||||
u64 threshold = pseudorand_u64();
|
double total_weight = 0.0;
|
||||||
|
|
||||||
/* Pick an example short_channel_id at random to query. As a
|
/* Pick an example short_channel_id at random to query. As a
|
||||||
* side-effect this gets the node. */
|
* side-effect this gets the node. */
|
||||||
@@ -475,11 +476,8 @@ static bool get_unannounced_nodes(const tal_t *ctx,
|
|||||||
(*scids)[num++] = c->scid;
|
(*scids)[num++] = c->scid;
|
||||||
} else {
|
} else {
|
||||||
/* Maybe replace one: approx. reservoir sampling */
|
/* Maybe replace one: approx. reservoir sampling */
|
||||||
u64 p = pseudorand_u64();
|
if (random_select(1.0, &total_weight))
|
||||||
if (p > threshold) {
|
|
||||||
(*scids)[pseudorand(max)] = c->scid;
|
(*scids)[pseudorand(max)] = c->scid;
|
||||||
threshold = p;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -63,6 +63,9 @@ void queue_peer_msg(struct peer *peer UNNEEDED, const u8 *msg TAKES UNNEEDED)
|
|||||||
struct peer *random_peer(struct daemon *daemon UNNEEDED,
|
struct peer *random_peer(struct daemon *daemon UNNEEDED,
|
||||||
bool (*check_peer)(const struct peer *peer))
|
bool (*check_peer)(const struct peer *peer))
|
||||||
{ fprintf(stderr, "random_peer called!\n"); abort(); }
|
{ fprintf(stderr, "random_peer called!\n"); abort(); }
|
||||||
|
/* Generated stub for random_select */
|
||||||
|
bool random_select(double weight UNNEEDED, double *tot_weight UNNEEDED)
|
||||||
|
{ fprintf(stderr, "random_select called!\n"); abort(); }
|
||||||
/* Generated stub for status_failed */
|
/* Generated stub for status_failed */
|
||||||
void status_failed(enum status_failreason code UNNEEDED,
|
void status_failed(enum status_failreason code UNNEEDED,
|
||||||
const char *fmt UNNEEDED, ...)
|
const char *fmt UNNEEDED, ...)
|
||||||
|
|||||||
@@ -60,6 +60,7 @@ LIGHTNINGD_COMMON_OBJS := \
|
|||||||
common/per_peer_state.o \
|
common/per_peer_state.o \
|
||||||
common/permute_tx.o \
|
common/permute_tx.o \
|
||||||
common/pseudorand.o \
|
common/pseudorand.o \
|
||||||
|
common/random_select.o \
|
||||||
common/setup.o \
|
common/setup.o \
|
||||||
common/sphinx.o \
|
common/sphinx.o \
|
||||||
common/status_wire.o \
|
common/status_wire.o \
|
||||||
|
|||||||
@@ -16,7 +16,7 @@
|
|||||||
#include <common/jsonrpc_errors.h>
|
#include <common/jsonrpc_errors.h>
|
||||||
#include <common/overflows.h>
|
#include <common/overflows.h>
|
||||||
#include <common/param.h>
|
#include <common/param.h>
|
||||||
#include <common/pseudorand.h>
|
#include <common/random_select.h>
|
||||||
#include <common/timeout.h>
|
#include <common/timeout.h>
|
||||||
#include <common/utils.h>
|
#include <common/utils.h>
|
||||||
#include <errno.h>
|
#include <errno.h>
|
||||||
@@ -489,15 +489,8 @@ static struct route_info **select_inchan(const tal_t *ctx,
|
|||||||
bool *any_offline)
|
bool *any_offline)
|
||||||
{
|
{
|
||||||
/* BOLT11 struct wants an array of arrays (can provide multiple routes) */
|
/* BOLT11 struct wants an array of arrays (can provide multiple routes) */
|
||||||
struct route_info **R;
|
struct route_info **r = NULL;
|
||||||
double wsum, p;
|
double total_weight = 0.0;
|
||||||
|
|
||||||
struct sample {
|
|
||||||
const struct route_info *route;
|
|
||||||
double weight;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct sample *S = tal_arr(tmpctx, struct sample, 0);
|
|
||||||
|
|
||||||
*any_offline = false;
|
*any_offline = false;
|
||||||
|
|
||||||
@@ -505,7 +498,6 @@ static struct route_info **select_inchan(const tal_t *ctx,
|
|||||||
for (size_t i = 0; i < tal_count(inchans); i++) {
|
for (size_t i = 0; i < tal_count(inchans); i++) {
|
||||||
struct peer *peer;
|
struct peer *peer;
|
||||||
struct channel *c;
|
struct channel *c;
|
||||||
struct sample sample;
|
|
||||||
struct amount_msat capacity_to_pay_us, excess, capacity;
|
struct amount_msat capacity_to_pay_us, excess, capacity;
|
||||||
struct amount_sat cumulative_reserve;
|
struct amount_sat cumulative_reserve;
|
||||||
double excess_frac;
|
double excess_frac;
|
||||||
@@ -564,33 +556,23 @@ static struct route_info **select_inchan(const tal_t *ctx,
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* We don't want a 0 probability if 0 excess; it might be the
|
||||||
|
* only one! So bump it by 1 msat */
|
||||||
|
if (!amount_msat_add(&excess, excess, AMOUNT_MSAT(1))) {
|
||||||
|
log_broken(ld->log, "Channel %s excess overflow!",
|
||||||
|
type_to_string(tmpctx, struct short_channel_id, c->scid));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
excess_frac = amount_msat_ratio(excess, capacity);
|
excess_frac = amount_msat_ratio(excess, capacity);
|
||||||
|
|
||||||
sample.route = &inchans[i];
|
if (random_select(excess_frac, &total_weight)) {
|
||||||
sample.weight = excess_frac;
|
tal_free(r);
|
||||||
tal_arr_expand(&S, sample);
|
r = tal_arr(ctx, struct route_info *, 1);
|
||||||
|
r[0] = tal_dup(r, struct route_info, &inchans[i]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!tal_count(S))
|
return r;
|
||||||
return NULL;
|
|
||||||
|
|
||||||
/* Use weighted reservoir sampling, see:
|
|
||||||
* https://en.wikipedia.org/wiki/Reservoir_sampling#Algorithm_A-Chao
|
|
||||||
* But (currently) the result will consist of only one sample (k=1) */
|
|
||||||
R = tal_arr(ctx, struct route_info *, 1);
|
|
||||||
R[0] = tal_dup(R, struct route_info, S[0].route);
|
|
||||||
wsum = S[0].weight;
|
|
||||||
|
|
||||||
for (size_t i = 1; i < tal_count(S); i++) {
|
|
||||||
wsum += S[i].weight;
|
|
||||||
p = S[i].weight / wsum;
|
|
||||||
double random_1 = pseudorand_double(); /* range [0,1) */
|
|
||||||
|
|
||||||
if (random_1 <= p)
|
|
||||||
R[0] = tal_dup(R, struct route_info, S[i].route);
|
|
||||||
}
|
|
||||||
|
|
||||||
return R;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/** select_inchan_mpp
|
/** select_inchan_mpp
|
||||||
@@ -1414,6 +1396,7 @@ static struct command_result *json_waitanyinvoice(struct command *cmd,
|
|||||||
" is non-trivial.");
|
" is non-trivial.");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
static const struct json_command waitanyinvoice_command = {
|
static const struct json_command waitanyinvoice_command = {
|
||||||
"waitanyinvoice",
|
"waitanyinvoice",
|
||||||
"payment",
|
"payment",
|
||||||
@@ -1423,7 +1406,6 @@ static const struct json_command waitanyinvoice_command = {
|
|||||||
};
|
};
|
||||||
AUTODATA(json_command, &waitanyinvoice_command);
|
AUTODATA(json_command, &waitanyinvoice_command);
|
||||||
|
|
||||||
|
|
||||||
/* Wait for an incoming payment matching the `label` in the JSON
|
/* Wait for an incoming payment matching the `label` in the JSON
|
||||||
* command. This will either return immediately if the payment has
|
* command. This will either return immediately if the payment has
|
||||||
* already been received or it may add the `cmd` to the list of
|
* already been received or it may add the `cmd` to the list of
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ LIGHTNINGD_TEST_COMMON_OBJS := \
|
|||||||
common/json.o \
|
common/json.o \
|
||||||
common/key_derive.o \
|
common/key_derive.o \
|
||||||
common/pseudorand.o \
|
common/pseudorand.o \
|
||||||
|
common/random_select.o \
|
||||||
common/memleak.o \
|
common/memleak.o \
|
||||||
common/msg_queue.o \
|
common/msg_queue.o \
|
||||||
common/utils.o \
|
common/utils.o \
|
||||||
|
|||||||
@@ -70,6 +70,7 @@ PLUGIN_COMMON_OBJS := \
|
|||||||
common/node_id.o \
|
common/node_id.o \
|
||||||
common/param.o \
|
common/param.o \
|
||||||
common/pseudorand.o \
|
common/pseudorand.o \
|
||||||
|
common/random_select.o \
|
||||||
common/setup.o \
|
common/setup.o \
|
||||||
common/type_to_string.o \
|
common/type_to_string.o \
|
||||||
common/utils.o \
|
common/utils.o \
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
#include <ccan/tal/str/str.h>
|
#include <ccan/tal/str/str.h>
|
||||||
#include <common/json_stream.h>
|
#include <common/json_stream.h>
|
||||||
#include <common/pseudorand.h>
|
#include <common/pseudorand.h>
|
||||||
|
#include <common/random_select.h>
|
||||||
#include <common/type_to_string.h>
|
#include <common/type_to_string.h>
|
||||||
#include <plugins/libplugin-pay.h>
|
#include <plugins/libplugin-pay.h>
|
||||||
|
|
||||||
@@ -2421,12 +2422,11 @@ static struct command_result *shadow_route_listchannels(struct command *cmd,
|
|||||||
const jsmntok_t *result,
|
const jsmntok_t *result,
|
||||||
struct payment *p)
|
struct payment *p)
|
||||||
{
|
{
|
||||||
/* Use reservoir sampling across the capable channels. */
|
|
||||||
struct shadow_route_data *d = payment_mod_shadowroute_get_data(p);
|
struct shadow_route_data *d = payment_mod_shadowroute_get_data(p);
|
||||||
struct payment_constraints *cons = &d->constraints;
|
struct payment_constraints *cons = &d->constraints;
|
||||||
struct route_info *best = NULL;
|
struct route_info *best = NULL;
|
||||||
|
double total_weight = 0.0;
|
||||||
size_t i;
|
size_t i;
|
||||||
u64 sample = 0;
|
|
||||||
struct amount_msat best_fee;
|
struct amount_msat best_fee;
|
||||||
const jsmntok_t *sattok, *delaytok, *basefeetok, *propfeetok, *desttok,
|
const jsmntok_t *sattok, *delaytok, *basefeetok, *propfeetok, *desttok,
|
||||||
*channelstok, *chan, *scidtok;
|
*channelstok, *chan, *scidtok;
|
||||||
@@ -2438,7 +2438,6 @@ static struct command_result *shadow_route_listchannels(struct command *cmd,
|
|||||||
|
|
||||||
channelstok = json_get_member(buf, result, "channels");
|
channelstok = json_get_member(buf, result, "channels");
|
||||||
json_for_each_arr(i, chan, channelstok) {
|
json_for_each_arr(i, chan, channelstok) {
|
||||||
u64 v = pseudorand(UINT64_MAX);
|
|
||||||
struct route_info curr;
|
struct route_info curr;
|
||||||
struct amount_sat capacity;
|
struct amount_sat capacity;
|
||||||
struct amount_msat fee;
|
struct amount_msat fee;
|
||||||
@@ -2465,28 +2464,27 @@ static struct command_result *shadow_route_listchannels(struct command *cmd,
|
|||||||
json_to_sat(buf, sattok, &capacity);
|
json_to_sat(buf, sattok, &capacity);
|
||||||
json_to_node_id(buf, desttok, &curr.pubkey);
|
json_to_node_id(buf, desttok, &curr.pubkey);
|
||||||
|
|
||||||
if (!best || v > sample) {
|
/* If the capacity is insufficient to pass the amount
|
||||||
/* If the capacity is insufficient to pass the amount
|
* it's not a plausible extension. */
|
||||||
* it's not a plausible extension. */
|
if (amount_msat_greater_sat(p->amount, capacity))
|
||||||
if (amount_msat_greater_sat(p->amount, capacity))
|
continue;
|
||||||
continue;
|
|
||||||
|
|
||||||
if (curr.cltv_expiry_delta > cons->cltv_budget)
|
if (curr.cltv_expiry_delta > cons->cltv_budget)
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
if (!amount_msat_fee(
|
if (!amount_msat_fee(
|
||||||
&fee, p->amount, curr.fee_base_msat,
|
&fee, p->amount, curr.fee_base_msat,
|
||||||
curr.fee_proportional_millionths)) {
|
curr.fee_proportional_millionths)) {
|
||||||
/* Fee computation failed... */
|
/* Fee computation failed... */
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (amount_msat_greater_eq(fee, cons->fee_budget))
|
if (amount_msat_greater_eq(fee, cons->fee_budget))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
|
if (random_select(1.0, &total_weight)) {
|
||||||
best = tal_dup(tmpctx, struct route_info, &curr);
|
best = tal_dup(tmpctx, struct route_info, &curr);
|
||||||
best_fee = fee;
|
best_fee = fee;
|
||||||
sample = v;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user