pygossmap: adds missing __str__, __eq__ and __hash__

Also caches certain __hash__ and __str__ operations,
This way graph operations can be done quicker.
This commit is contained in:
Michael Schmoock
2023-02-14 16:28:03 +01:00
committed by Rusty Russell
parent be60f2ac33
commit eb9cb5ef31

View File

@@ -52,9 +52,25 @@ class GossmapHalfchannel(object):
self.fee_base_msat: int = fields['fee_base_msat']
self.fee_proportional_millionths: int = fields['fee_proportional_millionths']
# Cache the _scidd and hash to have faster operation later
# Unfortunately the @final decorator only comes for python3.8
self._scidd = f"{self.channel.scid}/{self.direction}"
self._numscidd = direction << 63 | self.channel.scid.to_int()
def __repr__(self):
return f"GossmapHalfchannel[{self._scidd}]"
def __eq__(self, other):
if not isinstance(other, GossmapHalfchannel):
return False
return self._numscidd == other._numscidd
def __str__(self):
return self._scidd
def __hash__(self):
return self._numscidd
class GossmapNodeId(object):
def __init__(self, buf: Union[bytes, str]):
@@ -64,6 +80,9 @@ class GossmapNodeId(object):
raise ValueError("{} is not a valid node_id".format(buf.hex()))
self.nodeid = buf
self._hash = self.nodeid.__hash__()
self._str = self.nodeid.hex()
def to_pubkey(self) -> PublicKey:
return PublicKey(self.nodeid)
@@ -78,11 +97,14 @@ class GossmapNodeId(object):
return self.nodeid.__lt__(other.nodeid) # yes, that works
def __hash__(self):
return self.nodeid.__hash__()
return self._hash
def __repr__(self):
return "GossmapNodeId[{}]".format(self.nodeid.hex())
def __str__(self):
return self._str
@classmethod
def from_str(cls, s: str):
if s.startswith('0x'):
@@ -128,6 +150,17 @@ class GossmapChannel(object):
def __repr__(self):
return "GossmapChannel[{}]".format(str(self.scid))
def __str__(self):
return str(self.scid)
def __eq__(self, other):
if not isinstance(other, GossmapChannel):
return False
return self.scid.__eq__(other.scid)
def __hash__(self):
return self.scid.__hash__()
class GossmapNode(object):
"""A node: fields of node_announcement are in .fields,
@@ -141,6 +174,8 @@ class GossmapNode(object):
self.channels: List[GossmapChannel] = []
self.node_id = node_id
self._hash = self.node_id.__hash__()
def __repr__(self):
return f"GossmapNode[{self.node_id.nodeid.hex()}]"
@@ -154,6 +189,12 @@ class GossmapNode(object):
raise ValueError(f"Cannot compare GossmapNode with {type(other)}")
return self.node_id.__lt__(other.node_id)
def __hash__(self):
return self._hash
def __str__(self):
return str(self.node_id)
class Gossmap(object):
"""Class to represent the gossip map of the network"""