mirror of
https://github.com/aljazceru/lightning.git
synced 2025-12-21 16:14:23 +01:00
pyln-client/gossmap: more fixes, make mypy happier.
Mainly fixing type annotations, but some real fixes: 1. GossmapHalfchannel.from_str() should be a classmethod. 2. update_channel had weird, unusable default values (fields can't be NULL, since we use it below). [ There was one more occurence where isinstance should be used above type() == xyz comparison. -- MS ] Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
This commit is contained in:
@@ -3,7 +3,7 @@
|
|||||||
from pyln.spec.bolt7 import (channel_announcement, channel_update,
|
from pyln.spec.bolt7 import (channel_announcement, channel_update,
|
||||||
node_announcement)
|
node_announcement)
|
||||||
from pyln.proto import ShortChannelId, PublicKey
|
from pyln.proto import ShortChannelId, PublicKey
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional, Union, cast
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import struct
|
import struct
|
||||||
@@ -32,7 +32,7 @@ class GossipStoreHeader(object):
|
|||||||
|
|
||||||
class GossmapHalfchannel(object):
|
class GossmapHalfchannel(object):
|
||||||
"""One direction of a GossmapChannel."""
|
"""One direction of a GossmapChannel."""
|
||||||
def __init__(self, channel: GossmapChannel, direction: int,
|
def __init__(self, channel: 'GossmapChannel', direction: int,
|
||||||
timestamp: int, cltv_expiry_delta: int,
|
timestamp: int, cltv_expiry_delta: int,
|
||||||
htlc_minimum_msat: int, htlc_maximum_msat: int,
|
htlc_minimum_msat: int, htlc_maximum_msat: int,
|
||||||
fee_base_msat: int, fee_proportional_millionths: int):
|
fee_base_msat: int, fee_proportional_millionths: int):
|
||||||
@@ -70,12 +70,13 @@ class GossmapNodeId(object):
|
|||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "GossmapNodeId[{}]".format(self.nodeid.hex())
|
return "GossmapNodeId[{}]".format(self.nodeid.hex())
|
||||||
|
|
||||||
def from_str(self, s: str):
|
@classmethod
|
||||||
|
def from_str(cls, s: str):
|
||||||
if s.startswith('0x'):
|
if s.startswith('0x'):
|
||||||
s = s[2:]
|
s = s[2:]
|
||||||
if len(s) != 67:
|
if len(s) != 67:
|
||||||
raise ValueError(f"{s} is not a valid hexstring of a node_id")
|
raise ValueError(f"{s} is not a valid hexstring of a node_id")
|
||||||
return GossmapNodeId(bytes.fromhex(s))
|
return cls(bytes.fromhex(s))
|
||||||
|
|
||||||
|
|
||||||
class GossmapChannel(object):
|
class GossmapChannel(object):
|
||||||
@@ -96,14 +97,14 @@ class GossmapChannel(object):
|
|||||||
self.updates_fields: List[Optional[Dict[str, Any]]] = [None, None]
|
self.updates_fields: List[Optional[Dict[str, Any]]] = [None, None]
|
||||||
self.updates_offset: List[Optional[int]] = [None, None]
|
self.updates_offset: List[Optional[int]] = [None, None]
|
||||||
self.satoshis = None
|
self.satoshis = None
|
||||||
self.half_channels: List[GossmapHalfchannel] = [None, None]
|
self.half_channels: List[Optional[GossmapHalfchannel]] = [None, None]
|
||||||
|
|
||||||
def update_channel(self,
|
def update_channel(self,
|
||||||
direction: int,
|
direction: int,
|
||||||
fields: List[Optional[Dict[str, Any]]] = [None, None],
|
fields: Dict[str, Any],
|
||||||
off: List[Optional[int]] = [None, None]):
|
off: int):
|
||||||
self.updates_fields[direction] = fields
|
self.updates_fields[direction] = fields
|
||||||
self.updates_offset = off
|
self.updates_offset[direction] = off
|
||||||
|
|
||||||
half = GossmapHalfchannel(self, direction,
|
half = GossmapHalfchannel(self, direction,
|
||||||
fields['timestamp'],
|
fields['timestamp'],
|
||||||
@@ -131,8 +132,8 @@ class GossmapNode(object):
|
|||||||
"""
|
"""
|
||||||
def __init__(self, node_id: GossmapNodeId):
|
def __init__(self, node_id: GossmapNodeId):
|
||||||
self.announce_fields: Optional[Dict[str, Any]] = None
|
self.announce_fields: Optional[Dict[str, Any]] = None
|
||||||
self.announce_offset = None
|
self.announce_offset: Optional[int] = None
|
||||||
self.channels = []
|
self.channels: List[GossmapChannel] = []
|
||||||
self.node_id = node_id
|
self.node_id = node_id
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
@@ -147,10 +148,10 @@ class Gossmap(object):
|
|||||||
self.store_buf = bytes()
|
self.store_buf = bytes()
|
||||||
self.nodes: Dict[GossmapNodeId, GossmapNode] = {}
|
self.nodes: Dict[GossmapNodeId, GossmapNode] = {}
|
||||||
self.channels: Dict[ShortChannelId, GossmapChannel] = {}
|
self.channels: Dict[ShortChannelId, GossmapChannel] = {}
|
||||||
self._last_scid: str = None
|
self._last_scid: Optional[str] = None
|
||||||
version = self.store_file.read(1)
|
version = self.store_file.read(1)
|
||||||
if version[0] != GOSSIP_STORE_VERSION:
|
if version[0] != GOSSIP_STORE_VERSION:
|
||||||
raise ValueError("Invalid gossip store version {}".format(version))
|
raise ValueError("Invalid gossip store version {}".format(int(version)))
|
||||||
self.bytes_read = 1
|
self.bytes_read = 1
|
||||||
self.refresh()
|
self.refresh()
|
||||||
|
|
||||||
@@ -200,15 +201,15 @@ class Gossmap(object):
|
|||||||
|
|
||||||
def get_channel(self, short_channel_id: ShortChannelId):
|
def get_channel(self, short_channel_id: ShortChannelId):
|
||||||
""" Resolves a channel by its short channel id """
|
""" Resolves a channel by its short channel id """
|
||||||
if type(short_channel_id) == str:
|
if isinstance(short_channel_id, str):
|
||||||
short_channel_id = ShortChannelId.from_str(short_channel_id)
|
short_channel_id = ShortChannelId.from_str(short_channel_id)
|
||||||
return self.channels.get(short_channel_id)
|
return self.channels.get(short_channel_id)
|
||||||
|
|
||||||
def get_node(self, node_id: GossmapNodeId):
|
def get_node(self, node_id: Union[GossmapNodeId, str]):
|
||||||
""" Resolves a node by its public key node_id """
|
""" Resolves a node by its public key node_id """
|
||||||
if type(node_id) == str:
|
if isinstance(node_id, str):
|
||||||
node_id = GossmapNodeId.from_str(node_id)
|
node_id = GossmapNodeId.from_str(node_id)
|
||||||
return self.nodes.get(node_id)
|
return self.nodes.get(cast(GossmapNodeId, node_id))
|
||||||
|
|
||||||
def update_channel(self, rec: bytes, off: int):
|
def update_channel(self, rec: bytes, off: int):
|
||||||
fields = channel_update.read(io.BytesIO(rec[2:]), {})
|
fields = channel_update.read(io.BytesIO(rec[2:]), {})
|
||||||
|
|||||||
@@ -310,7 +310,7 @@ other types. Since 'msgtype' is almost identical, it inherits from this too.
|
|||||||
f.fieldtype.write(io_out, val, otherfields)
|
f.fieldtype.write(io_out, val, otherfields)
|
||||||
|
|
||||||
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||||
vals = {}
|
vals: Dict[str, Any] = {}
|
||||||
for field in self.fields:
|
for field in self.fields:
|
||||||
val = field.fieldtype.read(io_in, vals)
|
val = field.fieldtype.read(io_in, vals)
|
||||||
if val is None:
|
if val is None:
|
||||||
|
|||||||
Reference in New Issue
Block a user