pyln-client/gossmap: more fixes, make mypy happier.

Mainly fixing type annotations, but some real fixes:

1. GossmapHalfchannel.from_str() should be a classmethod.
2. update_channel had weird, unusable default values (fields can't be NULL,
   since we use it below).

[ There was one more occurence where isinstance should be used above
type() == xyz comparison. -- MS ]

Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
This commit is contained in:
Rusty Russell
2021-09-08 06:25:21 +09:30
parent ba2bcac530
commit 487facf1f0
2 changed files with 18 additions and 17 deletions

View File

@@ -3,7 +3,7 @@
from pyln.spec.bolt7 import (channel_announcement, channel_update, from pyln.spec.bolt7 import (channel_announcement, channel_update,
node_announcement) node_announcement)
from pyln.proto import ShortChannelId, PublicKey from pyln.proto import ShortChannelId, PublicKey
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, Union, cast
import io import io
import struct import struct
@@ -32,7 +32,7 @@ class GossipStoreHeader(object):
class GossmapHalfchannel(object): class GossmapHalfchannel(object):
"""One direction of a GossmapChannel.""" """One direction of a GossmapChannel."""
def __init__(self, channel: GossmapChannel, direction: int, def __init__(self, channel: 'GossmapChannel', direction: int,
timestamp: int, cltv_expiry_delta: int, timestamp: int, cltv_expiry_delta: int,
htlc_minimum_msat: int, htlc_maximum_msat: int, htlc_minimum_msat: int, htlc_maximum_msat: int,
fee_base_msat: int, fee_proportional_millionths: int): fee_base_msat: int, fee_proportional_millionths: int):
@@ -70,12 +70,13 @@ class GossmapNodeId(object):
def __repr__(self): def __repr__(self):
return "GossmapNodeId[{}]".format(self.nodeid.hex()) return "GossmapNodeId[{}]".format(self.nodeid.hex())
def from_str(self, s: str): @classmethod
def from_str(cls, s: str):
if s.startswith('0x'): if s.startswith('0x'):
s = s[2:] s = s[2:]
if len(s) != 67: if len(s) != 67:
raise ValueError(f"{s} is not a valid hexstring of a node_id") raise ValueError(f"{s} is not a valid hexstring of a node_id")
return GossmapNodeId(bytes.fromhex(s)) return cls(bytes.fromhex(s))
class GossmapChannel(object): class GossmapChannel(object):
@@ -96,14 +97,14 @@ class GossmapChannel(object):
self.updates_fields: List[Optional[Dict[str, Any]]] = [None, None] self.updates_fields: List[Optional[Dict[str, Any]]] = [None, None]
self.updates_offset: List[Optional[int]] = [None, None] self.updates_offset: List[Optional[int]] = [None, None]
self.satoshis = None self.satoshis = None
self.half_channels: List[GossmapHalfchannel] = [None, None] self.half_channels: List[Optional[GossmapHalfchannel]] = [None, None]
def update_channel(self, def update_channel(self,
direction: int, direction: int,
fields: List[Optional[Dict[str, Any]]] = [None, None], fields: Dict[str, Any],
off: List[Optional[int]] = [None, None]): off: int):
self.updates_fields[direction] = fields self.updates_fields[direction] = fields
self.updates_offset = off self.updates_offset[direction] = off
half = GossmapHalfchannel(self, direction, half = GossmapHalfchannel(self, direction,
fields['timestamp'], fields['timestamp'],
@@ -131,8 +132,8 @@ class GossmapNode(object):
""" """
def __init__(self, node_id: GossmapNodeId): def __init__(self, node_id: GossmapNodeId):
self.announce_fields: Optional[Dict[str, Any]] = None self.announce_fields: Optional[Dict[str, Any]] = None
self.announce_offset = None self.announce_offset: Optional[int] = None
self.channels = [] self.channels: List[GossmapChannel] = []
self.node_id = node_id self.node_id = node_id
def __repr__(self): def __repr__(self):
@@ -147,10 +148,10 @@ class Gossmap(object):
self.store_buf = bytes() self.store_buf = bytes()
self.nodes: Dict[GossmapNodeId, GossmapNode] = {} self.nodes: Dict[GossmapNodeId, GossmapNode] = {}
self.channels: Dict[ShortChannelId, GossmapChannel] = {} self.channels: Dict[ShortChannelId, GossmapChannel] = {}
self._last_scid: str = None self._last_scid: Optional[str] = None
version = self.store_file.read(1) version = self.store_file.read(1)
if version[0] != GOSSIP_STORE_VERSION: if version[0] != GOSSIP_STORE_VERSION:
raise ValueError("Invalid gossip store version {}".format(version)) raise ValueError("Invalid gossip store version {}".format(int(version)))
self.bytes_read = 1 self.bytes_read = 1
self.refresh() self.refresh()
@@ -200,15 +201,15 @@ class Gossmap(object):
def get_channel(self, short_channel_id: ShortChannelId): def get_channel(self, short_channel_id: ShortChannelId):
""" Resolves a channel by its short channel id """ """ Resolves a channel by its short channel id """
if type(short_channel_id) == str: if isinstance(short_channel_id, str):
short_channel_id = ShortChannelId.from_str(short_channel_id) short_channel_id = ShortChannelId.from_str(short_channel_id)
return self.channels.get(short_channel_id) return self.channels.get(short_channel_id)
def get_node(self, node_id: GossmapNodeId): def get_node(self, node_id: Union[GossmapNodeId, str]):
""" Resolves a node by its public key node_id """ """ Resolves a node by its public key node_id """
if type(node_id) == str: if isinstance(node_id, str):
node_id = GossmapNodeId.from_str(node_id) node_id = GossmapNodeId.from_str(node_id)
return self.nodes.get(node_id) return self.nodes.get(cast(GossmapNodeId, node_id))
def update_channel(self, rec: bytes, off: int): def update_channel(self, rec: bytes, off: int):
fields = channel_update.read(io.BytesIO(rec[2:]), {}) fields = channel_update.read(io.BytesIO(rec[2:]), {})

View File

@@ -310,7 +310,7 @@ other types. Since 'msgtype' is almost identical, it inherits from this too.
f.fieldtype.write(io_out, val, otherfields) f.fieldtype.write(io_out, val, otherfields)
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Optional[Dict[str, Any]]: def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Optional[Dict[str, Any]]:
vals = {} vals: Dict[str, Any] = {}
for field in self.fields: for field in self.fields:
val = field.fieldtype.read(io_in, vals) val = field.fieldtype.read(io_in, vals)
if val is None: if val is None: