pyln-client/gossmap: Don't mix bytes and GossmapNodeId

Do not mix bytes and GossmapNodeId when accessing Gossmap.nodes dicts.

Therefore the definion got GossmapNodeId also needed to be pulled to the
beginning of the file.
This commit is contained in:
Michael Schmoock
2021-09-08 06:25:21 +09:30
committed by Rusty Russell
parent ac27217114
commit ba2bcac530

View File

@@ -49,14 +49,43 @@ class GossmapHalfchannel(object):
return "GossmapHalfchannel[{}x{}]".format(str(self.channel.scid), self.direction) return "GossmapHalfchannel[{}x{}]".format(str(self.channel.scid), self.direction)
class GossmapNodeId(object):
def __init__(self, buf: bytes):
if len(buf) != 33 or (buf[0] != 2 and buf[0] != 3):
raise ValueError("{} is not a valid node_id".format(buf.hex()))
self.nodeid = buf
def to_pubkey(self) -> PublicKey:
return PublicKey(self.nodeid)
def __eq__(self, other):
if not isinstance(other, GossmapNodeId):
return False
return self.nodeid == other.nodeid
def __hash__(self):
return self.nodeid.__hash__()
def __repr__(self):
return "GossmapNodeId[{}]".format(self.nodeid.hex())
def from_str(self, s: str):
if s.startswith('0x'):
s = s[2:]
if len(s) != 67:
raise ValueError(f"{s} is not a valid hexstring of a node_id")
return GossmapNodeId(bytes.fromhex(s))
class GossmapChannel(object): class GossmapChannel(object):
"""A channel: fields of channel_announcement are in .fields, optional updates are in .updates_fields, which can be None if there has been no channel update.""" """A channel: fields of channel_announcement are in .fields, optional updates are in .updates_fields, which can be None if there has been no channel update."""
def __init__(self, def __init__(self,
fields: Dict[str, Any], fields: Dict[str, Any],
announce_offset: int, announce_offset: int,
scid, scid,
node1_id: bytes, node1_id: GossmapNodeId,
node2_id: bytes, node2_id: GossmapNodeId,
is_private: bool): is_private: bool):
self.fields = fields self.fields = fields
self.announce_offset = announce_offset self.announce_offset = announce_offset
@@ -95,35 +124,6 @@ class GossmapChannel(object):
return "GossmapChannel[{}]".format(str(self.scid)) return "GossmapChannel[{}]".format(str(self.scid))
class GossmapNodeId(object):
def __init__(self, buf: bytes):
if len(buf) != 33 or (buf[0] != 2 and buf[0] != 3):
raise ValueError("{} is not a valid node_id".format(buf.hex()))
self.nodeid = buf
def to_pubkey(self) -> PublicKey:
return PublicKey(self.nodeid)
def __eq__(self, other):
if not isinstance(other, GossmapNodeId):
return False
return self.nodeid == other.nodeid
def __hash__(self):
return self.nodeid.__hash__()
def __repr__(self):
return "GossmapNodeId[{}]".format(self.nodeid.hex())
def from_str(self, s: str):
if s.startswith('0x'):
s = s[2:]
if len(s) != 67:
raise ValueError(f"{s} is not a valid hexstring of a node_id")
return GossmapNodeId(bytes.fromhex(s))
class GossmapNode(object): class GossmapNode(object):
"""A node: fields of node_announcement are in .announce_fields, which can be None of there has been no node announcement. """A node: fields of node_announcement are in .announce_fields, which can be None of there has been no node announcement.
@@ -145,7 +145,7 @@ class Gossmap(object):
self.store_filename = store_filename self.store_filename = store_filename
self.store_file = open(store_filename, "rb") self.store_file = open(store_filename, "rb")
self.store_buf = bytes() self.store_buf = bytes()
self.nodes: Dict[bytes, GossmapNode] = {} self.nodes: Dict[GossmapNodeId, GossmapNode] = {}
self.channels: Dict[ShortChannelId, GossmapChannel] = {} self.channels: Dict[ShortChannelId, GossmapChannel] = {}
self._last_scid: str = None self._last_scid: str = None
version = self.store_file.read(1) version = self.store_file.read(1)