common/route: route_from_dijkstra returns route_hop array.

This is what (most) callers actually want, so unify it into one place.

Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
This commit is contained in:
Rusty Russell
2021-05-22 16:40:01 +09:30
parent e531a38963
commit 2bb365a931
12 changed files with 225 additions and 222 deletions

View File

@@ -4,8 +4,8 @@
#include <ccan/tal/str/str.h>
#include <ccan/time/time.h>
#include <common/dijkstra.h>
#include <common/features.h>
#include <common/gossmap.h>
#include <common/node_id.h>
#include <common/pseudorand.h>
#include <common/random_select.h>
#include <common/route.h>
@@ -64,31 +64,82 @@ u64 route_score_cheaper(u32 distance,
return ((u64)costs_to_score(cost, risk) << 32) + distance;
}
struct route **route_from_dijkstra(const tal_t *ctx,
const struct gossmap *map,
const struct dijkstra *dij,
const struct gossmap_node *cur)
/* Recursive version: return false if we can't get there.
*
* amount and cltv are updated, and reflect the amount we
* and delay would have to put into the first channel (usually
* ignored, since we don't pay for our own channels!).
*/
static bool dijkstra_to_hops(struct route_hop **hops,
const struct gossmap *gossmap,
const struct dijkstra *dij,
const struct gossmap_node *cur,
struct amount_msat *amount,
u32 *cltv)
{
struct route **path = tal_arr(ctx, struct route *, 0);
u32 curidx = gossmap_node_idx(map, cur);
u32 curidx = gossmap_node_idx(gossmap, cur);
u32 dist = dijkstra_distance(dij, curidx);
struct gossmap_chan *c;
const struct gossmap_node *next;
size_t num_hops = tal_count(*hops);
const struct half_chan *h;
while (dijkstra_distance(dij, curidx) != 0) {
struct route *r;
if (dist == 0)
return true;
if (dijkstra_distance(dij, curidx) == UINT_MAX)
return tal_free(path);
if (dist == UINT_MAX)
return false;
r = tal(path, struct route);
r->c = dijkstra_best_chan(dij, curidx);
if (r->c->half[0].nodeidx == curidx) {
r->dir = 0;
} else {
assert(r->c->half[1].nodeidx == curidx);
r->dir = 1;
}
tal_arr_expand(&path, r);
cur = gossmap_nth_node(map, r->c, !r->dir);
curidx = gossmap_node_idx(map, cur);
tal_resize(hops, num_hops + 1);
/* OK, populate other fields. */
c = dijkstra_best_chan(dij, curidx);
if (c->half[0].nodeidx == curidx) {
(*hops)[num_hops].direction = 0;
} else {
assert(c->half[1].nodeidx == curidx);
(*hops)[num_hops].direction = 1;
}
return path;
(*hops)[num_hops].scid = gossmap_chan_scid(gossmap, c);
/* Find other end of channel. */
next = gossmap_nth_node(gossmap, c, !(*hops)[num_hops].direction);
gossmap_node_get_id(gossmap, next, &(*hops)[num_hops].node_id);
if (gossmap_node_get_feature(gossmap, next, OPT_VAR_ONION) != -1)
(*hops)[num_hops].style = ROUTE_HOP_TLV;
else
(*hops)[num_hops].style = ROUTE_HOP_LEGACY;
/* These are (ab)used by others. */
(*hops)[num_hops].blinding = NULL;
(*hops)[num_hops].enctlv = NULL;
if (!dijkstra_to_hops(hops, gossmap, dij, next, amount, cltv))
return false;
(*hops)[num_hops].amount = *amount;
(*hops)[num_hops].delay = *cltv;
h = &c->half[(*hops)[num_hops].direction];
if (!amount_msat_add_fee(amount, h->base_fee, h->proportional_fee))
/* Shouldn't happen, since we said it would route,
* amounts must be sane. */
abort();
*cltv += h->delay;
return true;
}
struct route_hop *route_from_dijkstra(const tal_t *ctx,
const struct gossmap *map,
const struct dijkstra *dij,
const struct gossmap_node *src,
struct amount_msat final_amount,
u32 final_cltv)
{
struct route_hop *hops = tal_arr(ctx, struct route_hop, 0);
if (!dijkstra_to_hops(&hops, map, dij, src, &final_amount, &final_cltv))
return tal_free(hops);
return hops;
}

View File

@@ -2,14 +2,41 @@
#ifndef LIGHTNING_COMMON_ROUTE_H
#define LIGHTNING_COMMON_ROUTE_H
#include "config.h"
#include <bitcoin/short_channel_id.h>
#include <common/amount.h>
#include <common/node_id.h>
struct dijkstra;
struct gossmap;
struct gossmap_chan;
struct gossmap_node;
struct route {
int dir;
struct gossmap_chan *c;
enum route_hop_style {
ROUTE_HOP_LEGACY = 1,
ROUTE_HOP_TLV = 2,
};
/**
* struct route_hop: a hop in a route.
*
* @scid: the short_channel_id.
* @direction: 0 (dest node_id < src node_id), 1 (dest node_id > src).
* @node_id: the node_id of the destination of this hop.
* @amount: amount to send through this hop.
* @delay: total cltv delay at this hop.
* @blinding: blinding key for this hop (if any)
* @enctlv: encrypted TLV for this hop (if any)
* @style: onion encoding style for this hop.
*/
struct route_hop {
struct short_channel_id scid;
int direction;
struct node_id node_id;
struct amount_msat amount;
u32 delay;
struct pubkey *blinding;
u8 *enctlv;
enum route_hop_style style;
};
/* Can c carry amount in dir? */
@@ -37,8 +64,10 @@ u64 route_score_cheaper(u32 distance,
struct amount_msat risk);
/* Extract route tal_arr from completed dijkstra: NULL if none. */
struct route **route_from_dijkstra(const tal_t *ctx,
const struct gossmap *map,
const struct dijkstra *dij,
const struct gossmap_node *cur);
struct route_hop *route_from_dijkstra(const tal_t *ctx,
const struct gossmap *map,
const struct dijkstra *dij,
const struct gossmap_node *src,
struct amount_msat final_amount,
u32 final_cltv);
#endif /* LIGHTNING_COMMON_ROUTE_H */

View File

@@ -143,14 +143,16 @@ static void add_connection(int store_fd,
}
static bool channel_is_between(const struct gossmap *gossmap,
const struct route *route,
const struct route_hop *route,
const struct gossmap_node *a,
const struct gossmap_node *b)
{
if (route->c->half[route->dir].nodeidx
const struct gossmap_chan *c = gossmap_find_chan(gossmap, &route->scid);
if (c->half[route->direction].nodeidx
!= gossmap_node_idx(gossmap, a))
return false;
if (route->c->half[!route->dir].nodeidx
if (c->half[!route->direction].nodeidx
!= gossmap_node_idx(gossmap, b))
return false;
@@ -177,7 +179,7 @@ int main(void)
struct node_id a, b, c, d;
struct gossmap_node *a_node, *b_node, *c_node, *d_node;
const struct dijkstra *dij;
struct route **route;
struct route_hop *route;
int store_fd;
struct gossmap *gossmap;
const double riskfactor = 1.0;
@@ -238,17 +240,19 @@ int main(void)
dij = dijkstra(tmpctx, gossmap, c_node, AMOUNT_MSAT(1000), riskfactor,
route_can_carry_unless_disabled,
route_score_cheaper, NULL);
route = route_from_dijkstra(tmpctx, gossmap, dij, a_node);
route = route_from_dijkstra(tmpctx, gossmap, dij, a_node,
AMOUNT_MSAT(1000), 0);
assert(route);
assert(tal_count(route) == 2);
assert(channel_is_between(gossmap, route[0], a_node, b_node));
assert(channel_is_between(gossmap, route[1], b_node, c_node));
assert(channel_is_between(gossmap, &route[0], a_node, b_node));
assert(channel_is_between(gossmap, &route[1], b_node, c_node));
/* We should not be able to find a route that exceeds our own capacity */
dij = dijkstra(tmpctx, gossmap, c_node, AMOUNT_MSAT(1000001), riskfactor,
route_can_carry_unless_disabled,
route_score_cheaper, NULL);
route = route_from_dijkstra(tmpctx, gossmap, dij, a_node);
route = route_from_dijkstra(tmpctx, gossmap, dij, a_node,
AMOUNT_MSAT(1000), 0);
assert(!route);
/* Now test with a query that exceeds the channel capacity after adding
@@ -256,7 +260,8 @@ int main(void)
dij = dijkstra(tmpctx, gossmap, c_node, AMOUNT_MSAT(999999), riskfactor,
route_can_carry_unless_disabled,
route_score_cheaper, NULL);
route = route_from_dijkstra(tmpctx, gossmap, dij, a_node);
route = route_from_dijkstra(tmpctx, gossmap, dij, a_node,
AMOUNT_MSAT(999999), 0);
assert(!route);
/* This should fail to return a route because it is smaller than these
@@ -264,7 +269,8 @@ int main(void)
dij = dijkstra(tmpctx, gossmap, c_node, AMOUNT_MSAT(1), riskfactor,
route_can_carry_unless_disabled,
route_score_cheaper, NULL);
route = route_from_dijkstra(tmpctx, gossmap, dij, a_node);
route = route_from_dijkstra(tmpctx, gossmap, dij, a_node,
AMOUNT_MSAT(1), 0);
assert(!route);
/* {'active': True, 'short_id': '6990:2:1/0', 'fee_per_kw': 10, 'delay': 5, 'message_flags': 1, 'htlc_maximum_msat': 500000, 'htlc_minimum_msat': 100, 'channel_flags': 0, 'destination': '02cca6c5c966fcf61d121e3a70e03a1cd9eeeea024b26ea666ce974d43b242e636', 'source': '03c173897878996287a8100469f954dd820fcd8941daed91c327f168f3329be0bf', 'last_update': 1504064344}, */
@@ -282,7 +288,8 @@ int main(void)
dij = dijkstra(tmpctx, gossmap, d_node, AMOUNT_MSAT(499968), riskfactor,
route_can_carry_unless_disabled,
route_score_cheaper, NULL);
route = route_from_dijkstra(tmpctx, gossmap, dij, a_node);
route = route_from_dijkstra(tmpctx, gossmap, dij, a_node,
AMOUNT_MSAT(499968), 0);
assert(route);
/* This should fail to return a route because it's larger than the
@@ -290,7 +297,8 @@ int main(void)
dij = dijkstra(tmpctx, gossmap, d_node, AMOUNT_MSAT(499968+1), riskfactor,
route_can_carry_unless_disabled,
route_score_cheaper, NULL);
route = route_from_dijkstra(tmpctx, gossmap, dij, a_node);
route = route_from_dijkstra(tmpctx, gossmap, dij, a_node,
AMOUNT_MSAT(499968+1), 0);
assert(!route);
tal_free(tmpctx);

View File

@@ -131,14 +131,15 @@ static void add_connection(int store_fd,
}
static bool channel_is_between(const struct gossmap *gossmap,
const struct route *route,
const struct route_hop *route,
const struct gossmap_node *a,
const struct gossmap_node *b)
{
if (route->c->half[route->dir].nodeidx
const struct gossmap_chan *c = gossmap_find_chan(gossmap, &route->scid);
if (c->half[route->direction].nodeidx
!= gossmap_node_idx(gossmap, a))
return false;
if (route->c->half[!route->dir].nodeidx
if (c->half[!route->direction].nodeidx
!= gossmap_node_idx(gossmap, b))
return false;
@@ -173,7 +174,7 @@ int main(void)
struct gossmap_node *a_node, *b_node, *c_node, *d_node;
struct privkey tmp;
const struct dijkstra *dij;
struct route **route;
struct route_hop *route;
int store_fd;
struct gossmap *gossmap;
const double riskfactor = 1.0;
@@ -207,9 +208,11 @@ int main(void)
route_can_carry_unless_disabled,
route_score_cheaper, NULL);
route = route_from_dijkstra(tmpctx, gossmap, dij, a_node);
route = route_from_dijkstra(tmpctx, gossmap, dij, a_node, AMOUNT_MSAT(1000), 10);
assert(route);
assert(tal_count(route) == 1);
assert(amount_msat_eq(route[0].amount, AMOUNT_MSAT(1000)));
assert(route[0].delay == 10);
/* A<->B<->C */
memset(&tmp, 'c', sizeof(tmp));
@@ -224,9 +227,14 @@ int main(void)
dij = dijkstra(tmpctx, gossmap, c_node, AMOUNT_MSAT(1000), riskfactor,
route_can_carry_unless_disabled,
route_score_cheaper, NULL);
route = route_from_dijkstra(tmpctx, gossmap, dij, a_node);
route = route_from_dijkstra(tmpctx, gossmap, dij, a_node,
AMOUNT_MSAT(1000), 11);
assert(route);
assert(tal_count(route) == 2);
assert(amount_msat_eq(route[1].amount, AMOUNT_MSAT(1000)));
assert(route[1].delay == 11);
assert(amount_msat_eq(route[0].amount, AMOUNT_MSAT(1001)));
assert(route[0].delay == 12);
/* A<->D<->C: Lower base, higher percentage. */
memset(&tmp, 'd', sizeof(tmp));
@@ -246,22 +254,32 @@ int main(void)
dij = dijkstra(tmpctx, gossmap, c_node, AMOUNT_MSAT(1000), riskfactor,
route_can_carry_unless_disabled,
route_score_cheaper, NULL);
route = route_from_dijkstra(tmpctx, gossmap, dij, a_node);
route = route_from_dijkstra(tmpctx, gossmap, dij, a_node,
AMOUNT_MSAT(1000), 12);
assert(route);
assert(tal_count(route) == 2);
assert(channel_is_between(gossmap, route[0], a_node, d_node));
assert(channel_is_between(gossmap, route[1], d_node, c_node));
assert(channel_is_between(gossmap, &route[0], a_node, d_node));
assert(channel_is_between(gossmap, &route[1], d_node, c_node));
assert(amount_msat_eq(route[1].amount, AMOUNT_MSAT(1000)));
assert(route[1].delay == 12);
assert(amount_msat_eq(route[0].amount, AMOUNT_MSAT(1000)));
assert(route[0].delay == 13);
/* Will go via B for large amounts. */
dij = dijkstra(tmpctx, gossmap, c_node, AMOUNT_MSAT(3000000), riskfactor,
route_can_carry_unless_disabled,
route_score_cheaper, NULL);
route = route_from_dijkstra(tmpctx, gossmap, dij, a_node);
route = route_from_dijkstra(tmpctx, gossmap, dij, a_node,
AMOUNT_MSAT(3000000), 13);
assert(route);
assert(tal_count(route) == 2);
assert(channel_is_between(gossmap, route[0], a_node, b_node));
assert(channel_is_between(gossmap, route[1], b_node, c_node));
assert(channel_is_between(gossmap, &route[0], a_node, b_node));
assert(channel_is_between(gossmap, &route[1], b_node, c_node));
assert(amount_msat_eq(route[1].amount, AMOUNT_MSAT(3000000)));
assert(route[1].delay == 13);
assert(amount_msat_eq(route[0].amount, AMOUNT_MSAT(3000000 + 3 + 1)));
assert(route[0].delay == 14);
/* Make B->C inactive, force it back via D */
update_connection(store_fd, &b, &c, 1, 1, 1, true);
@@ -276,11 +294,16 @@ int main(void)
dij = dijkstra(tmpctx, gossmap, c_node, AMOUNT_MSAT(3000000), riskfactor,
route_can_carry_unless_disabled,
route_score_cheaper, NULL);
route = route_from_dijkstra(tmpctx, gossmap, dij, a_node);
route = route_from_dijkstra(tmpctx, gossmap, dij, a_node,
AMOUNT_MSAT(3000000), 14);
assert(route);
assert(tal_count(route) == 2);
assert(channel_is_between(gossmap, route[0], a_node, d_node));
assert(channel_is_between(gossmap, route[1], d_node, c_node));
assert(channel_is_between(gossmap, &route[0], a_node, d_node));
assert(channel_is_between(gossmap, &route[1], d_node, c_node));
assert(amount_msat_eq(route[1].amount, AMOUNT_MSAT(3000000)));
assert(route[1].delay == 14);
assert(amount_msat_eq(route[0].amount, AMOUNT_MSAT(3000000 + 6)));
assert(route[0].delay == 15);
tal_free(tmpctx);
secp256k1_context_destroy(secp256k1_ctx);