pyln.proto.message: more mypy fixes.

This includes some real bugfixes, since it noticed some places we were
being loose with different types!

Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
This commit is contained in:
Rusty Russell
2020-06-18 14:24:18 +09:30
committed by Christian Decker
parent 3882e8bdf7
commit 11a0de877e
4 changed files with 110 additions and 73 deletions

View File

@@ -1,10 +1,10 @@
import struct
from io import BufferedIOBase, BytesIO
from .fundamental_types import fundamental_types, BigSizeType, split_field, try_unpack, FieldType
from .fundamental_types import fundamental_types, BigSizeType, split_field, try_unpack, FieldType, IntegerType
from .array_types import (
SizedArrayType, DynamicArrayType, LengthFieldType, EllipsisArrayType
)
from typing import Dict, List, Optional, Tuple, Any, Union, cast
from typing import Dict, List, Optional, Tuple, Any, Union, Callable, cast
class MessageNamespace(object):
@@ -12,7 +12,7 @@ class MessageNamespace(object):
domain, such as within a given BOLT"""
def __init__(self, csv_lines: List[str] = []):
self.subtypes: Dict[str, SubtypeType] = {}
self.fundamentaltypes: Dict[str, SubtypeType] = {}
self.fundamentaltypes: Dict[str, FieldType] = {}
self.tlvtypes: Dict[str, TlvStreamType] = {}
self.messagetypes: Dict[str, MessageType] = {}
@@ -28,27 +28,35 @@ domain, such as within a given BOLT"""
for v in other.subtypes.values():
ret.add_subtype(v)
ret.tlvtypes = self.tlvtypes.copy()
for v in other.tlvtypes.values():
ret.add_tlvtype(v)
for tlv in other.tlvtypes.values():
ret.add_tlvtype(tlv)
ret.messagetypes = self.messagetypes.copy()
for v in other.messagetypes.values():
ret.add_messagetype(v)
return ret
def _check_unique(self, name: str) -> None:
"""Raise an exception if name already used"""
funtype = self.get_fundamentaltype(name)
if funtype:
raise ValueError('Already have {}'.format(funtype))
subtype = self.get_subtype(name)
if subtype:
raise ValueError('Already have {}'.format(subtype))
tlvtype = self.get_tlvtype(name)
if tlvtype:
raise ValueError('Already have {}'.format(tlvtype))
def add_subtype(self, t: 'SubtypeType') -> None:
prev = self.get_type(t.name)
if prev:
raise ValueError('Already have {}'.format(prev))
self._check_unique(t.name)
self.subtypes[t.name] = t
def add_fundamentaltype(self, t: 'SubtypeType') -> None:
assert not self.get_type(t.name)
def add_fundamentaltype(self, t: FieldType) -> None:
self._check_unique(t.name)
self.fundamentaltypes[t.name] = t
def add_tlvtype(self, t: 'TlvStreamType') -> None:
prev = self.get_type(t.name)
if prev:
raise ValueError('Already have {}'.format(prev))
self._check_unique(t.name)
self.tlvtypes[t.name] = t
def add_messagetype(self, m: 'MessageType') -> None:
@@ -70,7 +78,7 @@ domain, such as within a given BOLT"""
return m
return None
def get_fundamentaltype(self, name: str) -> Optional['SubtypeType']:
def get_fundamentaltype(self, name: str) -> Optional[FieldType]:
if name in self.fundamentaltypes:
return self.fundamentaltypes[name]
return None
@@ -85,14 +93,6 @@ domain, such as within a given BOLT"""
return self.tlvtypes[name]
return None
def get_type(self, name: str) -> Optional['SubtypeType']:
t = self.get_fundamentaltype(name)
if t is None:
t = self.get_subtype(name)
if t is None:
t = self.get_tlvtype(name)
return t
def load_csv(self, lines: List[str]) -> None:
"""Load a series of comma-separate-value lines into the namespace"""
vals: Dict[str, List[List[str]]] = {'msgtype': [],
@@ -152,23 +152,22 @@ class MessageTypeField(object):
return self.full_name
class SubtypeType(object):
class SubtypeType(FieldType):
"""This defines a 'subtype' in BOLT-speak. It consists of fields of
other types. Since 'msgtype' and 'tlvtype' are almost identical, they
inherit from this too.
other types. Since 'msgtype' is almost identical, it inherits from this too.
"""
def __init__(self, name: str):
self.name = name
self.fields: List[FieldType] = []
super().__init__(name)
self.fields: List[MessageTypeField] = []
def find_field(self, fieldname: str):
def find_field(self, fieldname: str) -> Optional[MessageTypeField]:
for f in self.fields:
if f.name == fieldname:
return f
return None
def add_field(self, field: FieldType):
def add_field(self, field: MessageTypeField) -> None:
if self.find_field(field.name):
raise ValueError("{}: duplicate field {}".format(self, field))
self.fields.append(field)
@@ -192,12 +191,16 @@ inherit from this too.
.format(parts))
return SubtypeType(parts[0])
def _field_from_csv(self, namespace: MessageNamespace, parts: List[str], ellipsisok=False, option: str = None) -> MessageTypeField:
def _field_from_csv(self, namespace: MessageNamespace, parts: List[str], option: str = None) -> MessageTypeField:
"""Takes msgdata/subtypedata after first two fields
e.g. [...]timestamp_node_id_1,u32,
"""
basetype = namespace.get_type(parts[1])
basetype = namespace.get_fundamentaltype(parts[1])
if basetype is None:
basetype = namespace.get_subtype(parts[1])
if basetype is None:
basetype = namespace.get_tlvtype(parts[1])
if basetype is None:
raise ValueError('Unknown type {}'.format(parts[1]))
@@ -206,7 +209,8 @@ inherit from this too.
lenfield = self.find_field(parts[2])
if lenfield is not None:
# If we didn't know that field was a length, we do now!
if type(lenfield.fieldtype) is not LengthFieldType:
if not isinstance(lenfield.fieldtype, LengthFieldType):
assert isinstance(lenfield.fieldtype, IntegerType)
lenfield.fieldtype = LengthFieldType(lenfield.fieldtype)
field = MessageTypeField(self.name, parts[0],
DynamicArrayType(self,
@@ -215,7 +219,9 @@ inherit from this too.
lenfield),
option)
lenfield.fieldtype.add_length_for(field)
elif ellipsisok and parts[2] == '...':
elif parts[2] == '...':
# ... is only valid for a TLV.
assert isinstance(self, TlvMessageType)
field = MessageTypeField(self.name, parts[0],
EllipsisArrayType(self,
parts[0], basetype),
@@ -264,8 +270,10 @@ inherit from this too.
raise ValueError("Unknown fields specified: {}".format(unknown))
for f in defined.difference(have):
if not f.fieldtype.is_optional():
raise ValueError("Missing value for {}".format(f))
field = self.find_field(f)
assert field
if not field.fieldtype.is_optional():
raise ValueError("Missing value for {}".format(field))
def val_to_str(self, v: Dict[str, Any], otherfields: Dict[str, Any]) -> str:
self._raise_if_badvals(v)
@@ -273,6 +281,7 @@ inherit from this too.
sep = ''
for fname, val in v.items():
field = self.find_field(fname)
assert field
s += sep + fname + '=' + field.fieldtype.val_to_str(val, otherfields)
sep = ','
@@ -281,16 +290,19 @@ inherit from this too.
def val_to_py(self, val: Dict[str, Any], otherfields: Dict[str, Any]) -> Dict[str, Any]:
ret: Dict[str, Any] = {}
for k, v in val.items():
ret[k] = self.find_field(k).fieldtype.val_to_py(v, val)
field = self.find_field(k)
assert field
ret[k] = field.fieldtype.val_to_py(v, val)
return ret
def write(self, io_out: BufferedIOBase, v: Dict[str, Any], otherfields: Dict[str, Any]) -> None:
self._raise_if_badvals(v)
for fname, val in v.items():
field = self.find_field(fname)
assert field
field.fieldtype.write(io_out, val, otherfields)
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Dict[str, Any]:
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Optional[Dict[str, Any]]:
vals = {}
for field in self.fields:
val = field.fieldtype.read(io_in, otherfields)
@@ -383,25 +395,46 @@ class MessageType(SubtypeType):
messagetype.add_field(field)
class TlvStreamType(SubtypeType):
"""A TlvStreamType is just a Subtype, but its fields are
TlvMessageTypes. In the CSV format these are created implicitly, when
a tlvtype line (which defines a TlvMessageType within the TlvType,
confusingly) refers to them.
class TlvMessageType(MessageType):
"""A 'tlvtype' in BOLT-speak"""
def __init__(self, name: str, value: str):
super().__init__(name, value)
def __str__(self):
return "tlvmsgtype-{}".format(self.name)
class TlvStreamType(FieldType):
"""A TlvStreamType's fields are TlvMessageTypes. In the CSV format
these are created implicitly, when a tlvtype line (which defines a
TlvMessageType within the TlvType, confusingly) refers to them.
"""
def __init__(self, name):
super().__init__(name)
self.fields: List[TlvMessageType] = []
def __str__(self):
return "tlvstreamtype-{}".format(self.name)
def find_field_by_number(self, num: int) -> Optional['TlvMessageType']:
def find_field(self, fieldname: str) -> Optional[TlvMessageType]:
for f in self.fields:
if f.name == fieldname:
return f
return None
def find_field_by_number(self, num: int) -> Optional[TlvMessageType]:
for f in self.fields:
if f.number == num:
return f
return None
def add_field(self, field: TlvMessageType) -> None:
if self.find_field(field.name):
raise ValueError("{}: duplicate field {}".format(self, field))
self.fields.append(field)
def is_optional(self) -> bool:
"""You can omit a tlvstream= altogether"""
return True
@@ -438,7 +471,7 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
raise ValueError("Unknown tlv field {}.{}"
.format(tlvstream, parts[1]))
subfield = field._field_from_csv(namespace, parts[2:], ellipsisok=True)
subfield = field._field_from_csv(namespace, parts[2:])
field.add_field(subfield)
def val_from_str(self, s: str) -> Tuple[Dict[str, Any], str]:
@@ -480,7 +513,9 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
def val_to_py(self, val: Dict[str, Any], otherfields: Dict[str, Any]) -> Dict[str, Any]:
ret: Dict[str, Any] = {}
for k, v in val.items():
ret[k] = self.find_field(k).val_to_py(v, val)
field = self.find_field(k)
assert field
ret[k] = field.val_to_py(v, val)
return ret
def write(self, io_out: BufferedIOBase, v: Optional[Dict[str, Any]], otherfields: Dict[str, Any]) -> None:
@@ -490,14 +525,16 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
# Make a tuple of (fieldnum, val_to_bin, val) so we can sort into
# ascending order as TLV spec requires.
def write_raw_val(iobuf, val, otherfields: Dict[str, Any]):
def write_raw_val(iobuf: BufferedIOBase, val: Any, otherfields: Dict[str, Any]) -> None:
iobuf.write(val)
def get_value(tup):
"""Get value from num, fun, val tuple"""
return tup[0]
ordered = []
ordered: List[Tuple[int,
Callable[[BufferedIOBase, Any, Dict[str, Any]], None],
Any]] = []
for fieldname in v:
f = self.find_field(fieldname)
if f is None:
@@ -510,13 +547,13 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
for typenum, writefunc, val in ordered:
buf = BytesIO()
writefunc(buf, val, otherfields)
writefunc(cast(BufferedIOBase, buf), val, otherfields)
BigSizeType.write(io_out, typenum)
BigSizeType.write(io_out, len(buf.getvalue()))
io_out.write(buf.getvalue())
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Dict[str, Any]:
vals: Dict[str, Any] = {}
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Dict[Union[str, int], Any]:
vals: Dict[Union[str, int], Any] = {}
while True:
tlv_type = BigSizeType.read(io_in)
@@ -543,16 +580,6 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
return " {}={}".format(name, self.val_to_str(v, {}))
class TlvMessageType(MessageType):
"""A 'tlvtype' in BOLT-speak"""
def __init__(self, name: str, value: str):
super().__init__(name, value)
def __str__(self):
return "tlvmsgtype-{}".format(self.name)
class Message(object):
"""A particular message instance"""
def __init__(self, messagetype: MessageType, **kwargs):
@@ -679,7 +706,8 @@ Must not have missing fields.
"""Convert to a Python native object: dicts, lists, strings, ints"""
ret: Dict[str, Union[Dict[str, Any], List[Any], str, int]] = {}
for f, v in self.fields.items():
fieldtype = self.messagetype.find_field(f).fieldtype
ret[f] = fieldtype.val_to_py(v, self.fields)
field = self.messagetype.find_field(f)
assert field
ret[f] = field.fieldtype.val_to_py(v, self.fields)
return ret