diff --git a/contrib/pyln-client/pyln/client/gossmap.py b/contrib/pyln-client/pyln/client/gossmap.py index 25d92942d..0cf5523d6 100755 --- a/contrib/pyln-client/pyln/client/gossmap.py +++ b/contrib/pyln-client/pyln/client/gossmap.py @@ -3,7 +3,7 @@ from pyln.spec.bolt7 import (channel_announcement, channel_update, node_announcement) from pyln.proto import ShortChannelId, PublicKey -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Set, Optional, Union import io import struct @@ -273,12 +273,97 @@ class Gossmap(object): channel = self.get_channel(short_channel_id) return channel.half_channels[direction] + def get_neighbors_hc(self, + source: Union[GossmapNodeId, str, None] = None, + destination: Union[GossmapNodeId, str, None] = None, + depth: int = 0, + excludes: Union[Set[Any], List[Any]] = set()): + """ Returns a set[GossmapHalfchannel]` from `source` or towards + `destination` node ID. Using the optional `depth` greater than `0` + will result in a second, third, ... order list of connected + channels towards or from that node. + Note: only one of `source` or `destination` can be given. """ + assert (source is None) ^ (destination is None), "Only one of source or destination must be given" + assert depth >= 0, "Depth cannot be smaller than 0" + node = self.get_node(source if source else destination) + assert node is not None, "source or destination unknown" + if isinstance(excludes, List): + excludes = set(excludes) + + # first get set of reachable nodes ... + reachable = self.get_neighbors(source, destination, depth, excludes) + # and iterate and check any each source/dest channel from here + result = set() + for node in reachable: + for channel in node.channels: + if channel in excludes: + continue + other = channel.node1 if node != channel.node1 else channel.node2 + if other in reachable or other in excludes: + continue + direction = 0 + if source is not None and node > other: + direction = 1 + if destination is not None and node < other: + direction = 1 + hc = channel.half_channels[direction] + # skip excluded or non existent halfchannels + if hc is None or hc in excludes: + continue + result.add(hc) + return result + def get_node(self, node_id: Union[GossmapNodeId, str]): """ Resolves a node by its public key node_id """ if isinstance(node_id, str): node_id = GossmapNodeId.from_str(node_id) return self.nodes.get(node_id) + def get_neighbors(self, + source: Union[GossmapNodeId, str, None] = None, + destination: Union[GossmapNodeId, str, None] = None, + depth: int = 0, + excludes: Union[Set[Any], List[Any]] = set()): + """ Returns a set of nodes within a given depth from a source node """ + assert (source is None) ^ (destination is None), "Only one of source or destination must be given" + assert depth >= 0, "Depth cannot be smaller than 0" + node = self.get_node(source if source else destination) + assert node is not None, "source or destination unknown" + if isinstance(excludes, List): + excludes = set(excludes) + + result = set() + result.add(node) + inner = set() + inner.add(node) + while depth > 0: + shell = set() + for node in inner: + for channel in node.channels: + if channel in excludes: # skip excluded channels + continue + other = channel.node1 if channel.node1 != node else channel.node2 + direction = 0 + if source is not None and node > other: + direction = 1 + if destination is not None and node < other: + direction = 1 + if channel.half_channels[direction] is None: + continue # one way channel in the wrong direction + halfchannel = channel.half_channels[direction] + if halfchannel in excludes: # skip excluded halfchannels + continue + # skip excluded or already seen nodes + if other in excludes or other in inner or other in result: + continue + shell.add(other) + if len(shell) == 0: + break + depth -= 1 + result.update(shell) + inner = shell + return result + def _update_channel(self, rec: bytes, hdr: GossipStoreHeader): fields = channel_update.read(io.BytesIO(rec[2:]), {}) direction = fields['channel_flags'] & 1 diff --git a/contrib/pyln-client/tests/test_gossmap.py b/contrib/pyln-client/tests/test_gossmap.py index 0daa1b22a..3a6e87bc1 100644 --- a/contrib/pyln-client/tests/test_gossmap.py +++ b/contrib/pyln-client/tests/test_gossmap.py @@ -159,6 +159,8 @@ def test_mesh(tmp_path): scids = [scid12, scid14, scid23, scid25, scid36, scid45, scid47, scid56, scid58, scid69, scid78, scid89] + nodes = [g.get_node(nid) for nid in nodeids] + # check all nodes are there for nodeid in nodeids: node = g.get_node(nodeid) @@ -174,3 +176,121 @@ def test_mesh(tmp_path): assert str(channel.scid) == scid assert channel.half_channels[0] assert channel.half_channels[1] + + # check basic relations + # get_neighbors l5 in the middle depth=0 returns just that node + result = g.get_neighbors(source=nodeids[4]) + assert len(result) == 1 + assert str(next(iter(result)).node_id) == nodeids[4] + result = g.get_neighbors(source=nodeids[4], depth=1) + assert len(result) == 5 + # on depth=1 the cross l2, l4, l5, l6, l8 must be returned + assert nodes[1] in result + assert nodes[3] in result + assert nodes[4] in result + assert nodes[5] in result + assert nodes[7] in result + # on depth>=2 all nodes must be returned as we visited the whole graph + for d in range(2, 4): + result = g.get_neighbors(source=nodeids[4], depth=d) + assert len(result) == 9 + for node in nodes: + assert node in result + # get_neighbors on l9 with depth=3 must return all but l1 + result = g.get_neighbors(nodeids[8], depth=3) + assert len(result) == 8 + assert nodes[0] not in result + # get_neighbors on l9 with depth=4 and excludes l5 must return all but l5 + result = g.get_neighbors(nodeids[8], depth=4, excludes=[nodes[4]]) + assert len(result) == 8 + assert nodes[4] not in result + + # get_neighbors_hc l5 in the middle expect: 25, 45, 65 and 85 + result = g.get_neighbors_hc(source=nodeids[4]) + exp_ids = [nodeids[1], nodeids[3], nodeids[5], nodeids[7]] + exp_scidds = [scid25 + '/1', scid45 + '/0', scid56 + '/1', scid58 + '/0'] + assert len(result) == len(exp_ids) + for halfchan in result: + assert str(halfchan.source.node_id) == nodeids[4] + assert str(halfchan.destination.node_id) in exp_ids + assert str(halfchan) in exp_scidds + + # same but other direction + result = g.get_neighbors_hc(destination=nodeids[4]) + exp_ids = [nodeids[1], nodeids[3], nodeids[5], nodeids[7]] + exp_scidds = [scid25 + '/0', scid45 + '/1', scid56 + '/0', scid58 + '/1'] + assert len(result) == len(exp_ids) + for halfchan in result: + assert str(halfchan.destination.node_id) == nodeids[4] + assert str(halfchan.source.node_id) in exp_ids + assert str(halfchan) in exp_scidds + + # get all channels which have l1 as destination + result = g.get_neighbors_hc(destination=nodeids[0]) + exp_ids = [nodeids[1], nodeids[3]] + exp_scidds = [scid12 + '/0', scid14 + '/1'] + assert len(result) == len(exp_ids) + for halfchan in result: + assert str(halfchan.destination.node_id) == nodeids[0] + assert str(halfchan.source.node_id) in exp_ids + assert str(halfchan) in exp_scidds + + # l5 as destination in the middle but depth=1, so the outer ring + # epxect: 12, 14, 32, 36, 74, 78, 98, 96 + result = g.get_neighbors_hc(destination=nodeids[4], depth=1) + exp_scidds = [scid12 + '/1', scid14 + '/0', scid23 + '/1', scid36 + '/1', + scid47 + '/0', scid69 + '/1', scid78 + '/0', scid89 + '/0'] + assert len(result) == len(exp_scidds) + for halfchan in result: + assert str(halfchan) in exp_scidds + + # same but other direction + result = g.get_neighbors_hc(source=nodeids[4], depth=1) + exp_scidds = [scid12 + '/0', scid14 + '/1', scid23 + '/0', scid36 + '/0', + scid47 + '/1', scid69 + '/0', scid78 + '/1', scid89 + '/1'] + assert len(result) == len(exp_scidds) + for halfchan in result: + assert str(halfchan) in exp_scidds + + # l9 as destination and depth=2 expect: 23 25 45 47 + result = g.get_neighbors_hc(destination=nodeids[8], depth=2) + exp_scidds = [scid23 + '/0', scid25 + '/0', scid45 + '/1', scid47 + '/1'] + assert len(result) == len(exp_scidds) + for halfchan in result: + assert str(halfchan) in exp_scidds + + # l9 as destination depth=2 exclude=[l7] expect: 23 25 45 + result = g.get_neighbors_hc(destination=nodeids[8], depth=2, excludes=[nodes[6]]) + exp_scidds = [scid23 + '/0', scid25 + '/0', scid45 + '/1'] + assert len(result) == len(exp_scidds) + for halfchan in result: + assert str(halfchan) in exp_scidds + + # same as above, but excludes halfchannels of l7 expect: 23 25 45 + hcs = [c.half_channels[0] for c in nodes[6].channels] + hcs += [c.half_channels[1] for c in nodes[6].channels] + result = g.get_neighbors_hc(destination=nodeids[8], depth=2, excludes=hcs) + exp_scidds = [scid23 + '/0', scid25 + '/0', scid45 + '/1'] + assert len(result) == len(exp_scidds) + for halfchan in result: + assert str(halfchan) in exp_scidds + + # again, same as above, but excludes channels of l7 expect: 23 25 45 + chs = [c for c in nodes[6].channels] + result = g.get_neighbors_hc(destination=nodeids[8], depth=2, excludes=chs) + exp_scidds = [scid23 + '/0', scid25 + '/0', scid45 + '/1'] + assert len(result) == len(exp_scidds) + for halfchan in result: + assert str(halfchan) in exp_scidds + + # l9 as destination and depth=3 expect: 12 14 + result = g.get_neighbors_hc(destination=nodeids[8], depth=3) + exp_scidds = [scid12 + '/1', scid14 + '/0'] + assert len(result) == len(exp_scidds) + for halfchan in result: + assert str(halfchan) in exp_scidds + + # l9 as destination and depth>=4 expect: empty set + for d in range(4, 6): + result = g.get_neighbors_hc(destination=nodeids[8], depth=d) + assert len(result) == 0