pygossmap: adds get_neighbors and get_neighbors_hc flodding method

This commit is contained in:
Michael Schmoock
2023-02-18 00:01:53 +01:00
committed by Rusty Russell
parent 5a9a3d83c9
commit 9409f2f1ea
2 changed files with 206 additions and 1 deletions

View File

@@ -159,6 +159,8 @@ def test_mesh(tmp_path):
scids = [scid12, scid14, scid23, scid25, scid36, scid45, scid47, scid56,
scid58, scid69, scid78, scid89]
nodes = [g.get_node(nid) for nid in nodeids]
# check all nodes are there
for nodeid in nodeids:
node = g.get_node(nodeid)
@@ -174,3 +176,121 @@ def test_mesh(tmp_path):
assert str(channel.scid) == scid
assert channel.half_channels[0]
assert channel.half_channels[1]
# check basic relations
# get_neighbors l5 in the middle depth=0 returns just that node
result = g.get_neighbors(source=nodeids[4])
assert len(result) == 1
assert str(next(iter(result)).node_id) == nodeids[4]
result = g.get_neighbors(source=nodeids[4], depth=1)
assert len(result) == 5
# on depth=1 the cross l2, l4, l5, l6, l8 must be returned
assert nodes[1] in result
assert nodes[3] in result
assert nodes[4] in result
assert nodes[5] in result
assert nodes[7] in result
# on depth>=2 all nodes must be returned as we visited the whole graph
for d in range(2, 4):
result = g.get_neighbors(source=nodeids[4], depth=d)
assert len(result) == 9
for node in nodes:
assert node in result
# get_neighbors on l9 with depth=3 must return all but l1
result = g.get_neighbors(nodeids[8], depth=3)
assert len(result) == 8
assert nodes[0] not in result
# get_neighbors on l9 with depth=4 and excludes l5 must return all but l5
result = g.get_neighbors(nodeids[8], depth=4, excludes=[nodes[4]])
assert len(result) == 8
assert nodes[4] not in result
# get_neighbors_hc l5 in the middle expect: 25, 45, 65 and 85
result = g.get_neighbors_hc(source=nodeids[4])
exp_ids = [nodeids[1], nodeids[3], nodeids[5], nodeids[7]]
exp_scidds = [scid25 + '/1', scid45 + '/0', scid56 + '/1', scid58 + '/0']
assert len(result) == len(exp_ids)
for halfchan in result:
assert str(halfchan.source.node_id) == nodeids[4]
assert str(halfchan.destination.node_id) in exp_ids
assert str(halfchan) in exp_scidds
# same but other direction
result = g.get_neighbors_hc(destination=nodeids[4])
exp_ids = [nodeids[1], nodeids[3], nodeids[5], nodeids[7]]
exp_scidds = [scid25 + '/0', scid45 + '/1', scid56 + '/0', scid58 + '/1']
assert len(result) == len(exp_ids)
for halfchan in result:
assert str(halfchan.destination.node_id) == nodeids[4]
assert str(halfchan.source.node_id) in exp_ids
assert str(halfchan) in exp_scidds
# get all channels which have l1 as destination
result = g.get_neighbors_hc(destination=nodeids[0])
exp_ids = [nodeids[1], nodeids[3]]
exp_scidds = [scid12 + '/0', scid14 + '/1']
assert len(result) == len(exp_ids)
for halfchan in result:
assert str(halfchan.destination.node_id) == nodeids[0]
assert str(halfchan.source.node_id) in exp_ids
assert str(halfchan) in exp_scidds
# l5 as destination in the middle but depth=1, so the outer ring
# epxect: 12, 14, 32, 36, 74, 78, 98, 96
result = g.get_neighbors_hc(destination=nodeids[4], depth=1)
exp_scidds = [scid12 + '/1', scid14 + '/0', scid23 + '/1', scid36 + '/1',
scid47 + '/0', scid69 + '/1', scid78 + '/0', scid89 + '/0']
assert len(result) == len(exp_scidds)
for halfchan in result:
assert str(halfchan) in exp_scidds
# same but other direction
result = g.get_neighbors_hc(source=nodeids[4], depth=1)
exp_scidds = [scid12 + '/0', scid14 + '/1', scid23 + '/0', scid36 + '/0',
scid47 + '/1', scid69 + '/0', scid78 + '/1', scid89 + '/1']
assert len(result) == len(exp_scidds)
for halfchan in result:
assert str(halfchan) in exp_scidds
# l9 as destination and depth=2 expect: 23 25 45 47
result = g.get_neighbors_hc(destination=nodeids[8], depth=2)
exp_scidds = [scid23 + '/0', scid25 + '/0', scid45 + '/1', scid47 + '/1']
assert len(result) == len(exp_scidds)
for halfchan in result:
assert str(halfchan) in exp_scidds
# l9 as destination depth=2 exclude=[l7] expect: 23 25 45
result = g.get_neighbors_hc(destination=nodeids[8], depth=2, excludes=[nodes[6]])
exp_scidds = [scid23 + '/0', scid25 + '/0', scid45 + '/1']
assert len(result) == len(exp_scidds)
for halfchan in result:
assert str(halfchan) in exp_scidds
# same as above, but excludes halfchannels of l7 expect: 23 25 45
hcs = [c.half_channels[0] for c in nodes[6].channels]
hcs += [c.half_channels[1] for c in nodes[6].channels]
result = g.get_neighbors_hc(destination=nodeids[8], depth=2, excludes=hcs)
exp_scidds = [scid23 + '/0', scid25 + '/0', scid45 + '/1']
assert len(result) == len(exp_scidds)
for halfchan in result:
assert str(halfchan) in exp_scidds
# again, same as above, but excludes channels of l7 expect: 23 25 45
chs = [c for c in nodes[6].channels]
result = g.get_neighbors_hc(destination=nodeids[8], depth=2, excludes=chs)
exp_scidds = [scid23 + '/0', scid25 + '/0', scid45 + '/1']
assert len(result) == len(exp_scidds)
for halfchan in result:
assert str(halfchan) in exp_scidds
# l9 as destination and depth=3 expect: 12 14
result = g.get_neighbors_hc(destination=nodeids[8], depth=3)
exp_scidds = [scid12 + '/1', scid14 + '/0']
assert len(result) == len(exp_scidds)
for halfchan in result:
assert str(halfchan) in exp_scidds
# l9 as destination and depth>=4 expect: empty set
for d in range(4, 6):
result = g.get_neighbors_hc(destination=nodeids[8], depth=d)
assert len(result) == 0