renepay: switch from arc_t to struct arc.

Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
This commit is contained in:
Rusty Russell
2023-08-02 07:00:09 +09:30
parent b793dc9224
commit 15c8a6f6fe
2 changed files with 47 additions and 76 deletions

View File

@@ -305,17 +305,6 @@ static inline struct arc arc_from_parts(u32 chanidx, int chandir, u32 part, bool
return arc; return arc;
} }
typedef union
{
struct{
u32 dual: 1;
u32 part: PARTS_BITS;
u32 chandir: 1;
u32 chanidx: (32-1-PARTS_BITS-1);
};
u32 idx;
} arc_t;
#define MAX(x, y) (((x) > (y)) ? (x) : (y)) #define MAX(x, y) (((x) > (y)) ? (x) : (y))
#define MIN(x, y) (((x) < (y)) ? (x) : (y)) #define MIN(x, y) (((x) < (y)) ? (x) : (y))
@@ -356,8 +345,8 @@ struct linear_network
// notice that a tail node is not needed, // notice that a tail node is not needed,
// because the tail of arc is the head of dual(arc) // because the tail of arc is the head of dual(arc)
arc_t *node_adjacency_next_arc; struct arc *node_adjacency_next_arc;
arc_t *node_adjacency_first_arc; struct arc *node_adjacency_first_arc;
// probability and fee cost associated to an arc // probability and fee cost associated to an arc
s64 *arc_prob_cost, *arc_fee_cost; s64 *arc_prob_cost, *arc_fee_cost;
@@ -381,22 +370,24 @@ struct residual_network {
/* Helper function. /* Helper function.
* Given an arc idx, return the dual's idx in the residual network. */ * Given an arc idx, return the dual's idx in the residual network. */
static arc_t arc_dual(arc_t arc) static struct arc arc_dual(struct arc arc)
{ {
arc.dual ^= 1; arc.idx ^= (1U << ARC_DUAL_BITOFF);
return arc; return arc;
} }
/* Helper function. */ /* Helper function. */
static bool arc_is_dual(const arc_t arc) static bool arc_is_dual(struct arc arc)
{ {
return arc.dual == 1; bool dual;
arc_to_parts(arc, NULL, NULL, NULL, &dual);
return dual;
} }
/* Helper function. /* Helper function.
* Given an arc of the network (not residual) give me the flow. */ * Given an arc of the network (not residual) give me the flow. */
static s64 get_arc_flow( static s64 get_arc_flow(
const struct residual_network *network, const struct residual_network *network,
const arc_t arc) const struct arc arc)
{ {
assert(!arc_is_dual(arc)); assert(!arc_is_dual(arc));
assert(arc_dual(arc).idx < tal_count(network->cap)); assert(arc_dual(arc).idx < tal_count(network->cap));
@@ -406,7 +397,7 @@ static s64 get_arc_flow(
/* Helper function. /* Helper function.
* Given an arc idx, return the node from which this arc emanates in the residual network. */ * Given an arc idx, return the node from which this arc emanates in the residual network. */
static u32 arc_tail(const struct linear_network *linear_network, static u32 arc_tail(const struct linear_network *linear_network,
const arc_t arc) const struct arc arc)
{ {
assert(arc.idx < tal_count(linear_network->arc_tail_node)); assert(arc.idx < tal_count(linear_network->arc_tail_node));
return linear_network->arc_tail_node[ arc.idx ]; return linear_network->arc_tail_node[ arc.idx ];
@@ -414,9 +405,9 @@ static u32 arc_tail(const struct linear_network *linear_network,
/* Helper function. /* Helper function.
* Given an arc idx, return the node that this arc is pointing to in the residual network. */ * Given an arc idx, return the node that this arc is pointing to in the residual network. */
static u32 arc_head(const struct linear_network *linear_network, static u32 arc_head(const struct linear_network *linear_network,
const arc_t arc) const struct arc arc)
{ {
const arc_t dual = arc_dual(arc); const struct arc dual = arc_dual(arc);
assert(dual.idx < tal_count(linear_network->arc_tail_node)); assert(dual.idx < tal_count(linear_network->arc_tail_node));
return linear_network->arc_tail_node[dual.idx]; return linear_network->arc_tail_node[dual.idx];
} }
@@ -424,7 +415,7 @@ static u32 arc_head(const struct linear_network *linear_network,
/* Helper function. /* Helper function.
* Given node idx `node`, return the idx of the first arc whose tail is `node`. * Given node idx `node`, return the idx of the first arc whose tail is `node`.
* */ * */
static arc_t node_adjacency_begin( static struct arc node_adjacency_begin(
const struct linear_network * linear_network, const struct linear_network * linear_network,
const u32 node) const u32 node)
{ {
@@ -434,39 +425,21 @@ static arc_t node_adjacency_begin(
/* Helper function. /* Helper function.
* Is this the end of the adjacency list. */ * Is this the end of the adjacency list. */
static bool node_adjacency_end(const arc_t arc) static bool node_adjacency_end(const struct arc arc)
{ {
return arc.idx == INVALID_INDEX; return arc.idx == INVALID_INDEX;
} }
/* Helper function. /* Helper function.
* Given node idx `node` and `arc`, returns the idx of the next arc whose tail is `node`. */ * Given node idx `node` and `arc`, returns the idx of the next arc whose tail is `node`. */
static arc_t node_adjacency_next( static struct arc node_adjacency_next(
const struct linear_network *linear_network, const struct linear_network *linear_network,
const arc_t arc) const struct arc arc)
{ {
assert(arc.idx < tal_count(linear_network->node_adjacency_next_arc)); assert(arc.idx < tal_count(linear_network->node_adjacency_next_arc));
return linear_network->node_adjacency_next_arc[arc.idx]; return linear_network->node_adjacency_next_arc[arc.idx];
} }
/* Helper function.
* Given a channel index, we should be able to deduce the arc id. */
static arc_t channel_idx_to_arc(
const u32 chan_idx,
int half,
int part,
int dual)
{
arc_t arc;
arc.dual=dual;
arc.part=part;
arc.chandir=half;
arc.chanidx = chan_idx;
/* check that it doesn't overflow */
assert(arc.chanidx == chan_idx);
return arc;
}
// TODO(eduardo): unit test this // TODO(eduardo): unit test this
/* Split a directed channel into parts with linear cost function. */ /* Split a directed channel into parts with linear cost function. */
static void linearize_channel( static void linearize_channel(
@@ -538,14 +511,13 @@ static void init_residual_network(
{ {
const size_t max_num_arcs = linear_network->max_num_arcs; const size_t max_num_arcs = linear_network->max_num_arcs;
const size_t max_num_nodes = linear_network->max_num_nodes; const size_t max_num_nodes = linear_network->max_num_nodes;
for(u32 idx=0;idx<max_num_arcs;++idx)
{
arc_t arc = (arc_t){.idx=idx};
for(struct arc arc = {0};arc.idx < max_num_arcs; ++arc.idx)
{
if(arc_is_dual(arc)) if(arc_is_dual(arc))
continue; continue;
arc_t dual = arc_dual(arc); struct arc dual = arc_dual(arc);
residual_network->cap[arc.idx]=linear_network->capacity[arc.idx]; residual_network->cap[arc.idx]=linear_network->capacity[arc.idx];
residual_network->cap[dual.idx]=0; residual_network->cap[dual.idx]=0;
@@ -562,19 +534,18 @@ static void combine_cost_function(
struct residual_network *residual_network, struct residual_network *residual_network,
s64 mu) s64 mu)
{ {
for(u32 arc_idx=0;arc_idx<linear_network->max_num_arcs;++arc_idx) for(struct arc arc = {0};arc.idx < linear_network->max_num_arcs; ++arc.idx)
{ {
arc_t arc = (arc_t){.idx=arc_idx};
if(arc_tail(linear_network,arc)==INVALID_INDEX) if(arc_tail(linear_network,arc)==INVALID_INDEX)
continue; continue;
const s64 pcost = linear_network->arc_prob_cost[arc_idx], const s64 pcost = linear_network->arc_prob_cost[arc.idx],
fcost = linear_network->arc_fee_cost[arc_idx]; fcost = linear_network->arc_fee_cost[arc.idx];
const s64 combined = pcost==INFINITE || fcost==INFINITE ? INFINITE : const s64 combined = pcost==INFINITE || fcost==INFINITE ? INFINITE :
mu*fcost + (MU_MAX-1-mu)*pcost; mu*fcost + (MU_MAX-1-mu)*pcost;
residual_network->cost[arc_idx] residual_network->cost[arc.idx]
= mu==0 ? pcost : = mu==0 ? pcost :
(mu==(MU_MAX-1) ? fcost : combined); (mu==(MU_MAX-1) ? fcost : combined);
} }
@@ -583,13 +554,13 @@ static void combine_cost_function(
static void linear_network_add_adjacenct_arc( static void linear_network_add_adjacenct_arc(
struct linear_network *linear_network, struct linear_network *linear_network,
const u32 node_idx, const u32 node_idx,
const arc_t arc) const struct arc arc)
{ {
assert(arc.idx < tal_count(linear_network->arc_tail_node)); assert(arc.idx < tal_count(linear_network->arc_tail_node));
linear_network->arc_tail_node[arc.idx] = node_idx; linear_network->arc_tail_node[arc.idx] = node_idx;
assert(node_idx < tal_count(linear_network->node_adjacency_first_arc)); assert(node_idx < tal_count(linear_network->node_adjacency_first_arc));
const arc_t first_arc = linear_network->node_adjacency_first_arc[node_idx]; const struct arc first_arc = linear_network->node_adjacency_first_arc[node_idx];
assert(arc.idx < tal_count(linear_network->node_adjacency_next_arc)); assert(arc.idx < tal_count(linear_network->node_adjacency_next_arc));
linear_network->node_adjacency_next_arc[arc.idx]=first_arc; linear_network->node_adjacency_next_arc[arc.idx]=first_arc;
@@ -614,11 +585,11 @@ static void init_linear_network(
for(size_t i=0;i<tal_count(linear_network->arc_tail_node);++i) for(size_t i=0;i<tal_count(linear_network->arc_tail_node);++i)
linear_network->arc_tail_node[i]=INVALID_INDEX; linear_network->arc_tail_node[i]=INVALID_INDEX;
linear_network->node_adjacency_next_arc = tal_arr(linear_network,arc_t,max_num_arcs); linear_network->node_adjacency_next_arc = tal_arr(linear_network,struct arc,max_num_arcs);
for(size_t i=0;i<tal_count(linear_network->node_adjacency_next_arc);++i) for(size_t i=0;i<tal_count(linear_network->node_adjacency_next_arc);++i)
linear_network->node_adjacency_next_arc[i].idx=INVALID_INDEX; linear_network->node_adjacency_next_arc[i].idx=INVALID_INDEX;
linear_network->node_adjacency_first_arc = tal_arr(linear_network,arc_t,max_num_nodes); linear_network->node_adjacency_first_arc = tal_arr(linear_network,struct arc,max_num_nodes);
for(size_t i=0;i<tal_count(linear_network->node_adjacency_first_arc);++i) for(size_t i=0;i<tal_count(linear_network->node_adjacency_first_arc);++i)
linear_network->node_adjacency_first_arc[i].idx=INVALID_INDEX; linear_network->node_adjacency_first_arc[i].idx=INVALID_INDEX;
@@ -682,7 +653,7 @@ static void init_linear_network(
{ {
// if(capacity[k]==0)continue; // if(capacity[k]==0)continue;
arc_t arc = channel_idx_to_arc(chan_id,half,k,0); struct arc arc = arc_from_parts(chan_id, half, k, false);
linear_network_add_adjacenct_arc(linear_network,node_id,arc); linear_network_add_adjacenct_arc(linear_network,node_id,arc);
@@ -692,7 +663,7 @@ static void init_linear_network(
linear_network->arc_fee_cost[arc.idx] = fee_cost; linear_network->arc_fee_cost[arc.idx] = fee_cost;
// + the respective dual // + the respective dual
arc_t dual = arc_dual(arc); struct arc dual = arc_dual(arc);
linear_network_add_adjacenct_arc(linear_network,next_id,dual); linear_network_add_adjacenct_arc(linear_network,next_id,dual);
@@ -723,7 +694,7 @@ static int find_admissible_path(
const struct residual_network *residual_network, const struct residual_network *residual_network,
const u32 source, const u32 source,
const u32 target, const u32 target,
arc_t *prev) struct arc *prev)
{ {
tal_t *this_ctx = tal(tmpctx,tal_t); tal_t *this_ctx = tal(tmpctx,tal_t);
@@ -753,7 +724,7 @@ static int find_admissible_path(
break; break;
} }
for(arc_t arc = node_adjacency_begin(linear_network,cur); for(struct arc arc = node_adjacency_begin(linear_network,cur);
!node_adjacency_end(arc); !node_adjacency_end(arc);
arc = node_adjacency_next(linear_network,arc)) arc = node_adjacency_next(linear_network,arc))
{ {
@@ -787,7 +758,7 @@ static s64 get_augmenting_flow(
const struct residual_network *residual_network, const struct residual_network *residual_network,
const u32 source, const u32 source,
const u32 target, const u32 target,
const arc_t *prev) const struct arc *prev)
{ {
s64 flow = INFINITE; s64 flow = INFINITE;
@@ -795,7 +766,7 @@ static s64 get_augmenting_flow(
while(cur!=source) while(cur!=source)
{ {
assert(cur<tal_count(prev)); assert(cur<tal_count(prev));
const arc_t arc = prev[cur]; const struct arc arc = prev[cur];
flow = MIN(flow , residual_network->cap[arc.idx]); flow = MIN(flow , residual_network->cap[arc.idx]);
// we are traversing in the opposite direction to the flow, // we are traversing in the opposite direction to the flow,
@@ -813,7 +784,7 @@ static void augment_flow(
struct residual_network *residual_network, struct residual_network *residual_network,
const u32 source, const u32 source,
const u32 target, const u32 target,
const arc_t *prev, const struct arc *prev,
s64 flow) s64 flow)
{ {
u32 cur = target; u32 cur = target;
@@ -821,8 +792,8 @@ static void augment_flow(
while(cur!=source) while(cur!=source)
{ {
assert(cur < tal_count(prev)); assert(cur < tal_count(prev));
const arc_t arc = prev[cur]; const struct arc arc = prev[cur];
const arc_t dual = arc_dual(arc); const struct arc dual = arc_dual(arc);
assert(arc.idx < tal_count(residual_network->cap)); assert(arc.idx < tal_count(residual_network->cap));
assert(dual.idx < tal_count(residual_network->cap)); assert(dual.idx < tal_count(residual_network->cap));
@@ -862,7 +833,7 @@ static int find_feasible_flow(
/* path information /* path information
* prev: is the id of the arc that lead to the node. */ * prev: is the id of the arc that lead to the node. */
arc_t *prev = tal_arr(this_ctx,arc_t,linear_network->max_num_nodes); struct arc *prev = tal_arr(this_ctx,struct arc,linear_network->max_num_nodes);
while(amount>0) while(amount>0)
{ {
@@ -903,7 +874,7 @@ static int find_optimal_path(
const struct residual_network* residual_network, const struct residual_network* residual_network,
const u32 source, const u32 source,
const u32 target, const u32 target,
arc_t *prev) struct arc *prev)
{ {
tal_t *this_ctx = tal(tmpctx,tal_t); tal_t *this_ctx = tal(tmpctx,tal_t);
int ret = RENEPAY_ERR_NOFEASIBLEFLOW; int ret = RENEPAY_ERR_NOFEASIBLEFLOW;
@@ -935,7 +906,7 @@ static int find_optimal_path(
break; break;
} }
for(arc_t arc = node_adjacency_begin(linear_network,cur); for(struct arc arc = node_adjacency_begin(linear_network,cur);
!node_adjacency_end(arc); !node_adjacency_end(arc);
arc = node_adjacency_next(linear_network,arc)) arc = node_adjacency_next(linear_network,arc))
{ {
@@ -971,13 +942,13 @@ static void zero_flow(
for(u32 node=0;node<linear_network->max_num_nodes;++node) for(u32 node=0;node<linear_network->max_num_nodes;++node)
{ {
residual_network->potential[node]=0; residual_network->potential[node]=0;
for(arc_t arc=node_adjacency_begin(linear_network,node); for(struct arc arc=node_adjacency_begin(linear_network,node);
!node_adjacency_end(arc); !node_adjacency_end(arc);
arc = node_adjacency_next(linear_network,arc)) arc = node_adjacency_next(linear_network,arc))
{ {
if(arc_is_dual(arc))continue; if(arc_is_dual(arc))continue;
arc_t dual = arc_dual(arc); struct arc dual = arc_dual(arc);
residual_network->cap[arc.idx] = linear_network->capacity[arc.idx]; residual_network->cap[arc.idx] = linear_network->capacity[arc.idx];
residual_network->cap[dual.idx] = 0; residual_network->cap[dual.idx] = 0;
@@ -1008,7 +979,7 @@ static int optimize_mcf(
int ret = RENEPAY_ERR_OK; int ret = RENEPAY_ERR_OK;
zero_flow(linear_network,residual_network); zero_flow(linear_network,residual_network);
arc_t *prev = tal_arr(this_ctx,arc_t,linear_network->max_num_nodes); struct arc *prev = tal_arr(this_ctx,struct arc,linear_network->max_num_nodes);
const s64 *const distance = dijkstra_distance_data(dijkstra); const s64 *const distance = dijkstra_distance_data(dijkstra);
@@ -1172,7 +1143,7 @@ static struct flow **
// Compute balance on the nodes. // Compute balance on the nodes.
for(u32 n = 0;n<max_num_nodes;++n) for(u32 n = 0;n<max_num_nodes;++n)
{ {
for(arc_t arc = node_adjacency_begin(linear_network,n); for(struct arc arc = node_adjacency_begin(linear_network,n);
!node_adjacency_end(arc); !node_adjacency_end(arc);
arc = node_adjacency_next(linear_network,arc)) arc = node_adjacency_next(linear_network,arc))
{ {
@@ -1180,11 +1151,14 @@ static struct flow **
continue; continue;
u32 m = arc_head(linear_network,arc); u32 m = arc_head(linear_network,arc);
s64 flow = get_arc_flow(residual_network,arc); s64 flow = get_arc_flow(residual_network,arc);
u32 chanidx;
int chandir;
balance[n] -= flow; balance[n] -= flow;
balance[m] += flow; balance[m] += flow;
chan_flow[arc.chanidx].half[arc.chandir] +=flow; arc_to_parts(arc, &chanidx, &chandir, NULL, NULL);
chan_flow[chanidx].half[chandir] +=flow;
} }
} }

View File

@@ -98,15 +98,12 @@ int main(int argc, char *argv[])
arc_to_parts(a, NULL, NULL, NULL, &dual); arc_to_parts(a, NULL, NULL, NULL, &dual);
assert(dual == i); assert(dual == i);
/* This code not converted yet! */
#if 0
assert(arc_is_dual(a) == dual); assert(arc_is_dual(a) == dual);
a = arc_dual(a); a = arc_dual(a);
arc_to_parts(a, NULL, NULL, NULL, &dual); arc_to_parts(a, NULL, NULL, NULL, &dual);
assert(dual == !i); assert(dual == !i);
assert(arc_is_dual(a) == dual); assert(arc_is_dual(a) == dual);
#endif
} }
common_shutdown(); common_shutdown();