mirror of
https://github.com/aljazceru/lightning.git
synced 2025-12-19 23:24:27 +01:00
pyln.proto.message: more mypy fixes.
This includes some real bugfixes, since it noticed some places we were being loose with different types! Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
This commit is contained in:
committed by
Christian Decker
parent
3882e8bdf7
commit
11a0de877e
@@ -17,7 +17,7 @@ check-flake8:
|
|||||||
|
|
||||||
# mypy . does not recurse. I have no idea why...
|
# mypy . does not recurse. I have no idea why...
|
||||||
check-mypy:
|
check-mypy:
|
||||||
mypy --ignore-missing-imports `find * -name '*.py'`
|
mypy --ignore-missing-imports `find pyln/proto/message/ -name '*.py'`
|
||||||
|
|
||||||
$(SDIST_FILE):
|
$(SDIST_FILE):
|
||||||
python3 setup.py sdist
|
python3 setup.py sdist
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
from .fundamental_types import FieldType, IntegerType, split_field
|
from .fundamental_types import FieldType, IntegerType, split_field
|
||||||
from typing import List, Optional, Dict, Tuple, TYPE_CHECKING, Any, Union
|
from typing import List, Optional, Dict, Tuple, TYPE_CHECKING, Any, Union, cast
|
||||||
from io import BufferedIOBase
|
from io import BufferedIOBase
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .message import SubtypeType, TlvStreamType
|
from .message import SubtypeType, TlvMessageType, MessageTypeField
|
||||||
|
|
||||||
|
|
||||||
class ArrayType(FieldType):
|
class ArrayType(FieldType):
|
||||||
@@ -98,7 +98,7 @@ class SizedArrayType(ArrayType):
|
|||||||
class EllipsisArrayType(ArrayType):
|
class EllipsisArrayType(ArrayType):
|
||||||
"""This is used for ... fields at the end of a tlv: the array ends
|
"""This is used for ... fields at the end of a tlv: the array ends
|
||||||
when the tlv ends"""
|
when the tlv ends"""
|
||||||
def __init__(self, tlv: 'TlvStreamType', name: str, elemtype: FieldType):
|
def __init__(self, tlv: 'TlvMessageType', name: str, elemtype: FieldType):
|
||||||
super().__init__(tlv, name, elemtype)
|
super().__init__(tlv, name, elemtype)
|
||||||
|
|
||||||
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> List[Any]:
|
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> List[Any]:
|
||||||
@@ -119,13 +119,13 @@ class LengthFieldType(FieldType):
|
|||||||
super().__init__(inttype.name)
|
super().__init__(inttype.name)
|
||||||
self.underlying_type = inttype
|
self.underlying_type = inttype
|
||||||
# You can be length for more than one field!
|
# You can be length for more than one field!
|
||||||
self.len_for: List[DynamicArrayType] = []
|
self.len_for: List['MessageTypeField'] = []
|
||||||
|
|
||||||
def is_optional(self) -> bool:
|
def is_optional(self) -> bool:
|
||||||
"""This field value is always implies, never specified directly"""
|
"""This field value is always implies, never specified directly"""
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def add_length_for(self, field: 'DynamicArrayType') -> None:
|
def add_length_for(self, field: 'MessageTypeField') -> None:
|
||||||
assert isinstance(field.fieldtype, DynamicArrayType)
|
assert isinstance(field.fieldtype, DynamicArrayType)
|
||||||
self.len_for.append(field)
|
self.len_for.append(field)
|
||||||
|
|
||||||
@@ -160,7 +160,7 @@ class LengthFieldType(FieldType):
|
|||||||
they're implied by the length of other fields"""
|
they're implied by the length of other fields"""
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> None:
|
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Optional[int]:
|
||||||
"""We store this, but it'll be removed from the fields as soon as it's used (i.e. by DynamicArrayType's val_from_bin)"""
|
"""We store this, but it'll be removed from the fields as soon as it's used (i.e. by DynamicArrayType's val_from_bin)"""
|
||||||
return self.underlying_type.read(io_in, otherfields)
|
return self.underlying_type.read(io_in, otherfields)
|
||||||
|
|
||||||
@@ -186,11 +186,11 @@ they're implied by the length of other fields"""
|
|||||||
|
|
||||||
class DynamicArrayType(ArrayType):
|
class DynamicArrayType(ArrayType):
|
||||||
"""This is used for arrays where another field controls the size"""
|
"""This is used for arrays where another field controls the size"""
|
||||||
def __init__(self, outer: 'SubtypeType', name: str, elemtype: FieldType, lenfield: LengthFieldType):
|
def __init__(self, outer: 'SubtypeType', name: str, elemtype: FieldType, lenfield: 'MessageTypeField'):
|
||||||
super().__init__(outer, name, elemtype)
|
super().__init__(outer, name, elemtype)
|
||||||
assert type(lenfield.fieldtype) is LengthFieldType
|
assert type(lenfield.fieldtype) is LengthFieldType
|
||||||
self.lenfield = lenfield
|
self.lenfield = lenfield
|
||||||
|
|
||||||
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> List[Any]:
|
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> List[Any]:
|
||||||
return super().read_arr(io_in, otherfields,
|
return super().read_arr(io_in, otherfields,
|
||||||
self.lenfield.fieldtype._maybe_calc_value(self.lenfield.name, otherfields))
|
cast(LengthFieldType, self.lenfield.fieldtype)._maybe_calc_value(self.lenfield.name, otherfields))
|
||||||
|
|||||||
@@ -59,6 +59,15 @@ These are further specialized.
|
|||||||
def val_to_str(self, v: Any, otherfields: Dict[str, Any]) -> str:
|
def val_to_str(self, v: Any, otherfields: Dict[str, Any]) -> str:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def val_from_str(self, s: str) -> Tuple[Any, str]:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def write(self, io_out: BufferedIOBase, v: Any, otherfields: Dict[str, Any]) -> None:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Any:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def val_to_py(self, v: Any, otherfields: Dict[str, Any]) -> Any:
|
def val_to_py(self, v: Any, otherfields: Dict[str, Any]) -> Any:
|
||||||
"""Convert to a python object: for simple fields, this means a string"""
|
"""Convert to a python object: for simple fields, this means a string"""
|
||||||
return self.val_to_str(v, otherfields)
|
return self.val_to_str(v, otherfields)
|
||||||
@@ -83,7 +92,7 @@ class IntegerType(FieldType):
|
|||||||
a, b = split_field(s)
|
a, b = split_field(s)
|
||||||
return int(a), b
|
return int(a), b
|
||||||
|
|
||||||
def val_to_py(self, v: Any, otherfields: Dict[str, Any]) -> int:
|
def val_to_py(self, v: Any, otherfields: Dict[str, Any]) -> Any:
|
||||||
"""Convert to a python object: for integer fields, this means an int"""
|
"""Convert to a python object: for integer fields, this means an int"""
|
||||||
return int(v)
|
return int(v)
|
||||||
|
|
||||||
@@ -240,7 +249,7 @@ class BigSizeType(FieldType):
|
|||||||
return int(v)
|
return int(v)
|
||||||
|
|
||||||
|
|
||||||
def fundamental_types():
|
def fundamental_types() -> List[FieldType]:
|
||||||
# From 01-messaging.md#fundamental-types:
|
# From 01-messaging.md#fundamental-types:
|
||||||
return [IntegerType('byte', 1, 'B'),
|
return [IntegerType('byte', 1, 'B'),
|
||||||
IntegerType('u16', 2, '>H'),
|
IntegerType('u16', 2, '>H'),
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
import struct
|
import struct
|
||||||
from io import BufferedIOBase, BytesIO
|
from io import BufferedIOBase, BytesIO
|
||||||
from .fundamental_types import fundamental_types, BigSizeType, split_field, try_unpack, FieldType
|
from .fundamental_types import fundamental_types, BigSizeType, split_field, try_unpack, FieldType, IntegerType
|
||||||
from .array_types import (
|
from .array_types import (
|
||||||
SizedArrayType, DynamicArrayType, LengthFieldType, EllipsisArrayType
|
SizedArrayType, DynamicArrayType, LengthFieldType, EllipsisArrayType
|
||||||
)
|
)
|
||||||
from typing import Dict, List, Optional, Tuple, Any, Union, cast
|
from typing import Dict, List, Optional, Tuple, Any, Union, Callable, cast
|
||||||
|
|
||||||
|
|
||||||
class MessageNamespace(object):
|
class MessageNamespace(object):
|
||||||
@@ -12,7 +12,7 @@ class MessageNamespace(object):
|
|||||||
domain, such as within a given BOLT"""
|
domain, such as within a given BOLT"""
|
||||||
def __init__(self, csv_lines: List[str] = []):
|
def __init__(self, csv_lines: List[str] = []):
|
||||||
self.subtypes: Dict[str, SubtypeType] = {}
|
self.subtypes: Dict[str, SubtypeType] = {}
|
||||||
self.fundamentaltypes: Dict[str, SubtypeType] = {}
|
self.fundamentaltypes: Dict[str, FieldType] = {}
|
||||||
self.tlvtypes: Dict[str, TlvStreamType] = {}
|
self.tlvtypes: Dict[str, TlvStreamType] = {}
|
||||||
self.messagetypes: Dict[str, MessageType] = {}
|
self.messagetypes: Dict[str, MessageType] = {}
|
||||||
|
|
||||||
@@ -28,27 +28,35 @@ domain, such as within a given BOLT"""
|
|||||||
for v in other.subtypes.values():
|
for v in other.subtypes.values():
|
||||||
ret.add_subtype(v)
|
ret.add_subtype(v)
|
||||||
ret.tlvtypes = self.tlvtypes.copy()
|
ret.tlvtypes = self.tlvtypes.copy()
|
||||||
for v in other.tlvtypes.values():
|
for tlv in other.tlvtypes.values():
|
||||||
ret.add_tlvtype(v)
|
ret.add_tlvtype(tlv)
|
||||||
ret.messagetypes = self.messagetypes.copy()
|
ret.messagetypes = self.messagetypes.copy()
|
||||||
for v in other.messagetypes.values():
|
for v in other.messagetypes.values():
|
||||||
ret.add_messagetype(v)
|
ret.add_messagetype(v)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
def _check_unique(self, name: str) -> None:
|
||||||
|
"""Raise an exception if name already used"""
|
||||||
|
funtype = self.get_fundamentaltype(name)
|
||||||
|
if funtype:
|
||||||
|
raise ValueError('Already have {}'.format(funtype))
|
||||||
|
subtype = self.get_subtype(name)
|
||||||
|
if subtype:
|
||||||
|
raise ValueError('Already have {}'.format(subtype))
|
||||||
|
tlvtype = self.get_tlvtype(name)
|
||||||
|
if tlvtype:
|
||||||
|
raise ValueError('Already have {}'.format(tlvtype))
|
||||||
|
|
||||||
def add_subtype(self, t: 'SubtypeType') -> None:
|
def add_subtype(self, t: 'SubtypeType') -> None:
|
||||||
prev = self.get_type(t.name)
|
self._check_unique(t.name)
|
||||||
if prev:
|
|
||||||
raise ValueError('Already have {}'.format(prev))
|
|
||||||
self.subtypes[t.name] = t
|
self.subtypes[t.name] = t
|
||||||
|
|
||||||
def add_fundamentaltype(self, t: 'SubtypeType') -> None:
|
def add_fundamentaltype(self, t: FieldType) -> None:
|
||||||
assert not self.get_type(t.name)
|
self._check_unique(t.name)
|
||||||
self.fundamentaltypes[t.name] = t
|
self.fundamentaltypes[t.name] = t
|
||||||
|
|
||||||
def add_tlvtype(self, t: 'TlvStreamType') -> None:
|
def add_tlvtype(self, t: 'TlvStreamType') -> None:
|
||||||
prev = self.get_type(t.name)
|
self._check_unique(t.name)
|
||||||
if prev:
|
|
||||||
raise ValueError('Already have {}'.format(prev))
|
|
||||||
self.tlvtypes[t.name] = t
|
self.tlvtypes[t.name] = t
|
||||||
|
|
||||||
def add_messagetype(self, m: 'MessageType') -> None:
|
def add_messagetype(self, m: 'MessageType') -> None:
|
||||||
@@ -70,7 +78,7 @@ domain, such as within a given BOLT"""
|
|||||||
return m
|
return m
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_fundamentaltype(self, name: str) -> Optional['SubtypeType']:
|
def get_fundamentaltype(self, name: str) -> Optional[FieldType]:
|
||||||
if name in self.fundamentaltypes:
|
if name in self.fundamentaltypes:
|
||||||
return self.fundamentaltypes[name]
|
return self.fundamentaltypes[name]
|
||||||
return None
|
return None
|
||||||
@@ -85,14 +93,6 @@ domain, such as within a given BOLT"""
|
|||||||
return self.tlvtypes[name]
|
return self.tlvtypes[name]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_type(self, name: str) -> Optional['SubtypeType']:
|
|
||||||
t = self.get_fundamentaltype(name)
|
|
||||||
if t is None:
|
|
||||||
t = self.get_subtype(name)
|
|
||||||
if t is None:
|
|
||||||
t = self.get_tlvtype(name)
|
|
||||||
return t
|
|
||||||
|
|
||||||
def load_csv(self, lines: List[str]) -> None:
|
def load_csv(self, lines: List[str]) -> None:
|
||||||
"""Load a series of comma-separate-value lines into the namespace"""
|
"""Load a series of comma-separate-value lines into the namespace"""
|
||||||
vals: Dict[str, List[List[str]]] = {'msgtype': [],
|
vals: Dict[str, List[List[str]]] = {'msgtype': [],
|
||||||
@@ -152,23 +152,22 @@ class MessageTypeField(object):
|
|||||||
return self.full_name
|
return self.full_name
|
||||||
|
|
||||||
|
|
||||||
class SubtypeType(object):
|
class SubtypeType(FieldType):
|
||||||
"""This defines a 'subtype' in BOLT-speak. It consists of fields of
|
"""This defines a 'subtype' in BOLT-speak. It consists of fields of
|
||||||
other types. Since 'msgtype' and 'tlvtype' are almost identical, they
|
other types. Since 'msgtype' is almost identical, it inherits from this too.
|
||||||
inherit from this too.
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def __init__(self, name: str):
|
def __init__(self, name: str):
|
||||||
self.name = name
|
super().__init__(name)
|
||||||
self.fields: List[FieldType] = []
|
self.fields: List[MessageTypeField] = []
|
||||||
|
|
||||||
def find_field(self, fieldname: str):
|
def find_field(self, fieldname: str) -> Optional[MessageTypeField]:
|
||||||
for f in self.fields:
|
for f in self.fields:
|
||||||
if f.name == fieldname:
|
if f.name == fieldname:
|
||||||
return f
|
return f
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def add_field(self, field: FieldType):
|
def add_field(self, field: MessageTypeField) -> None:
|
||||||
if self.find_field(field.name):
|
if self.find_field(field.name):
|
||||||
raise ValueError("{}: duplicate field {}".format(self, field))
|
raise ValueError("{}: duplicate field {}".format(self, field))
|
||||||
self.fields.append(field)
|
self.fields.append(field)
|
||||||
@@ -192,12 +191,16 @@ inherit from this too.
|
|||||||
.format(parts))
|
.format(parts))
|
||||||
return SubtypeType(parts[0])
|
return SubtypeType(parts[0])
|
||||||
|
|
||||||
def _field_from_csv(self, namespace: MessageNamespace, parts: List[str], ellipsisok=False, option: str = None) -> MessageTypeField:
|
def _field_from_csv(self, namespace: MessageNamespace, parts: List[str], option: str = None) -> MessageTypeField:
|
||||||
"""Takes msgdata/subtypedata after first two fields
|
"""Takes msgdata/subtypedata after first two fields
|
||||||
e.g. [...]timestamp_node_id_1,u32,
|
e.g. [...]timestamp_node_id_1,u32,
|
||||||
|
|
||||||
"""
|
"""
|
||||||
basetype = namespace.get_type(parts[1])
|
basetype = namespace.get_fundamentaltype(parts[1])
|
||||||
|
if basetype is None:
|
||||||
|
basetype = namespace.get_subtype(parts[1])
|
||||||
|
if basetype is None:
|
||||||
|
basetype = namespace.get_tlvtype(parts[1])
|
||||||
if basetype is None:
|
if basetype is None:
|
||||||
raise ValueError('Unknown type {}'.format(parts[1]))
|
raise ValueError('Unknown type {}'.format(parts[1]))
|
||||||
|
|
||||||
@@ -206,7 +209,8 @@ inherit from this too.
|
|||||||
lenfield = self.find_field(parts[2])
|
lenfield = self.find_field(parts[2])
|
||||||
if lenfield is not None:
|
if lenfield is not None:
|
||||||
# If we didn't know that field was a length, we do now!
|
# If we didn't know that field was a length, we do now!
|
||||||
if type(lenfield.fieldtype) is not LengthFieldType:
|
if not isinstance(lenfield.fieldtype, LengthFieldType):
|
||||||
|
assert isinstance(lenfield.fieldtype, IntegerType)
|
||||||
lenfield.fieldtype = LengthFieldType(lenfield.fieldtype)
|
lenfield.fieldtype = LengthFieldType(lenfield.fieldtype)
|
||||||
field = MessageTypeField(self.name, parts[0],
|
field = MessageTypeField(self.name, parts[0],
|
||||||
DynamicArrayType(self,
|
DynamicArrayType(self,
|
||||||
@@ -215,7 +219,9 @@ inherit from this too.
|
|||||||
lenfield),
|
lenfield),
|
||||||
option)
|
option)
|
||||||
lenfield.fieldtype.add_length_for(field)
|
lenfield.fieldtype.add_length_for(field)
|
||||||
elif ellipsisok and parts[2] == '...':
|
elif parts[2] == '...':
|
||||||
|
# ... is only valid for a TLV.
|
||||||
|
assert isinstance(self, TlvMessageType)
|
||||||
field = MessageTypeField(self.name, parts[0],
|
field = MessageTypeField(self.name, parts[0],
|
||||||
EllipsisArrayType(self,
|
EllipsisArrayType(self,
|
||||||
parts[0], basetype),
|
parts[0], basetype),
|
||||||
@@ -264,8 +270,10 @@ inherit from this too.
|
|||||||
raise ValueError("Unknown fields specified: {}".format(unknown))
|
raise ValueError("Unknown fields specified: {}".format(unknown))
|
||||||
|
|
||||||
for f in defined.difference(have):
|
for f in defined.difference(have):
|
||||||
if not f.fieldtype.is_optional():
|
field = self.find_field(f)
|
||||||
raise ValueError("Missing value for {}".format(f))
|
assert field
|
||||||
|
if not field.fieldtype.is_optional():
|
||||||
|
raise ValueError("Missing value for {}".format(field))
|
||||||
|
|
||||||
def val_to_str(self, v: Dict[str, Any], otherfields: Dict[str, Any]) -> str:
|
def val_to_str(self, v: Dict[str, Any], otherfields: Dict[str, Any]) -> str:
|
||||||
self._raise_if_badvals(v)
|
self._raise_if_badvals(v)
|
||||||
@@ -273,6 +281,7 @@ inherit from this too.
|
|||||||
sep = ''
|
sep = ''
|
||||||
for fname, val in v.items():
|
for fname, val in v.items():
|
||||||
field = self.find_field(fname)
|
field = self.find_field(fname)
|
||||||
|
assert field
|
||||||
s += sep + fname + '=' + field.fieldtype.val_to_str(val, otherfields)
|
s += sep + fname + '=' + field.fieldtype.val_to_str(val, otherfields)
|
||||||
sep = ','
|
sep = ','
|
||||||
|
|
||||||
@@ -281,16 +290,19 @@ inherit from this too.
|
|||||||
def val_to_py(self, val: Dict[str, Any], otherfields: Dict[str, Any]) -> Dict[str, Any]:
|
def val_to_py(self, val: Dict[str, Any], otherfields: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
ret: Dict[str, Any] = {}
|
ret: Dict[str, Any] = {}
|
||||||
for k, v in val.items():
|
for k, v in val.items():
|
||||||
ret[k] = self.find_field(k).fieldtype.val_to_py(v, val)
|
field = self.find_field(k)
|
||||||
|
assert field
|
||||||
|
ret[k] = field.fieldtype.val_to_py(v, val)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def write(self, io_out: BufferedIOBase, v: Dict[str, Any], otherfields: Dict[str, Any]) -> None:
|
def write(self, io_out: BufferedIOBase, v: Dict[str, Any], otherfields: Dict[str, Any]) -> None:
|
||||||
self._raise_if_badvals(v)
|
self._raise_if_badvals(v)
|
||||||
for fname, val in v.items():
|
for fname, val in v.items():
|
||||||
field = self.find_field(fname)
|
field = self.find_field(fname)
|
||||||
|
assert field
|
||||||
field.fieldtype.write(io_out, val, otherfields)
|
field.fieldtype.write(io_out, val, otherfields)
|
||||||
|
|
||||||
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Dict[str, Any]:
|
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||||
vals = {}
|
vals = {}
|
||||||
for field in self.fields:
|
for field in self.fields:
|
||||||
val = field.fieldtype.read(io_in, otherfields)
|
val = field.fieldtype.read(io_in, otherfields)
|
||||||
@@ -383,25 +395,46 @@ class MessageType(SubtypeType):
|
|||||||
messagetype.add_field(field)
|
messagetype.add_field(field)
|
||||||
|
|
||||||
|
|
||||||
class TlvStreamType(SubtypeType):
|
class TlvMessageType(MessageType):
|
||||||
"""A TlvStreamType is just a Subtype, but its fields are
|
"""A 'tlvtype' in BOLT-speak"""
|
||||||
TlvMessageTypes. In the CSV format these are created implicitly, when
|
|
||||||
a tlvtype line (which defines a TlvMessageType within the TlvType,
|
def __init__(self, name: str, value: str):
|
||||||
confusingly) refers to them.
|
super().__init__(name, value)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return "tlvmsgtype-{}".format(self.name)
|
||||||
|
|
||||||
|
|
||||||
|
class TlvStreamType(FieldType):
|
||||||
|
"""A TlvStreamType's fields are TlvMessageTypes. In the CSV format
|
||||||
|
these are created implicitly, when a tlvtype line (which defines a
|
||||||
|
TlvMessageType within the TlvType, confusingly) refers to them.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def __init__(self, name):
|
def __init__(self, name):
|
||||||
super().__init__(name)
|
super().__init__(name)
|
||||||
|
self.fields: List[TlvMessageType] = []
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return "tlvstreamtype-{}".format(self.name)
|
return "tlvstreamtype-{}".format(self.name)
|
||||||
|
|
||||||
def find_field_by_number(self, num: int) -> Optional['TlvMessageType']:
|
def find_field(self, fieldname: str) -> Optional[TlvMessageType]:
|
||||||
|
for f in self.fields:
|
||||||
|
if f.name == fieldname:
|
||||||
|
return f
|
||||||
|
return None
|
||||||
|
|
||||||
|
def find_field_by_number(self, num: int) -> Optional[TlvMessageType]:
|
||||||
for f in self.fields:
|
for f in self.fields:
|
||||||
if f.number == num:
|
if f.number == num:
|
||||||
return f
|
return f
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def add_field(self, field: TlvMessageType) -> None:
|
||||||
|
if self.find_field(field.name):
|
||||||
|
raise ValueError("{}: duplicate field {}".format(self, field))
|
||||||
|
self.fields.append(field)
|
||||||
|
|
||||||
def is_optional(self) -> bool:
|
def is_optional(self) -> bool:
|
||||||
"""You can omit a tlvstream= altogether"""
|
"""You can omit a tlvstream= altogether"""
|
||||||
return True
|
return True
|
||||||
@@ -438,7 +471,7 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
|
|||||||
raise ValueError("Unknown tlv field {}.{}"
|
raise ValueError("Unknown tlv field {}.{}"
|
||||||
.format(tlvstream, parts[1]))
|
.format(tlvstream, parts[1]))
|
||||||
|
|
||||||
subfield = field._field_from_csv(namespace, parts[2:], ellipsisok=True)
|
subfield = field._field_from_csv(namespace, parts[2:])
|
||||||
field.add_field(subfield)
|
field.add_field(subfield)
|
||||||
|
|
||||||
def val_from_str(self, s: str) -> Tuple[Dict[str, Any], str]:
|
def val_from_str(self, s: str) -> Tuple[Dict[str, Any], str]:
|
||||||
@@ -480,7 +513,9 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
|
|||||||
def val_to_py(self, val: Dict[str, Any], otherfields: Dict[str, Any]) -> Dict[str, Any]:
|
def val_to_py(self, val: Dict[str, Any], otherfields: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
ret: Dict[str, Any] = {}
|
ret: Dict[str, Any] = {}
|
||||||
for k, v in val.items():
|
for k, v in val.items():
|
||||||
ret[k] = self.find_field(k).val_to_py(v, val)
|
field = self.find_field(k)
|
||||||
|
assert field
|
||||||
|
ret[k] = field.val_to_py(v, val)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def write(self, io_out: BufferedIOBase, v: Optional[Dict[str, Any]], otherfields: Dict[str, Any]) -> None:
|
def write(self, io_out: BufferedIOBase, v: Optional[Dict[str, Any]], otherfields: Dict[str, Any]) -> None:
|
||||||
@@ -490,14 +525,16 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
|
|||||||
|
|
||||||
# Make a tuple of (fieldnum, val_to_bin, val) so we can sort into
|
# Make a tuple of (fieldnum, val_to_bin, val) so we can sort into
|
||||||
# ascending order as TLV spec requires.
|
# ascending order as TLV spec requires.
|
||||||
def write_raw_val(iobuf, val, otherfields: Dict[str, Any]):
|
def write_raw_val(iobuf: BufferedIOBase, val: Any, otherfields: Dict[str, Any]) -> None:
|
||||||
iobuf.write(val)
|
iobuf.write(val)
|
||||||
|
|
||||||
def get_value(tup):
|
def get_value(tup):
|
||||||
"""Get value from num, fun, val tuple"""
|
"""Get value from num, fun, val tuple"""
|
||||||
return tup[0]
|
return tup[0]
|
||||||
|
|
||||||
ordered = []
|
ordered: List[Tuple[int,
|
||||||
|
Callable[[BufferedIOBase, Any, Dict[str, Any]], None],
|
||||||
|
Any]] = []
|
||||||
for fieldname in v:
|
for fieldname in v:
|
||||||
f = self.find_field(fieldname)
|
f = self.find_field(fieldname)
|
||||||
if f is None:
|
if f is None:
|
||||||
@@ -510,13 +547,13 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
|
|||||||
|
|
||||||
for typenum, writefunc, val in ordered:
|
for typenum, writefunc, val in ordered:
|
||||||
buf = BytesIO()
|
buf = BytesIO()
|
||||||
writefunc(buf, val, otherfields)
|
writefunc(cast(BufferedIOBase, buf), val, otherfields)
|
||||||
BigSizeType.write(io_out, typenum)
|
BigSizeType.write(io_out, typenum)
|
||||||
BigSizeType.write(io_out, len(buf.getvalue()))
|
BigSizeType.write(io_out, len(buf.getvalue()))
|
||||||
io_out.write(buf.getvalue())
|
io_out.write(buf.getvalue())
|
||||||
|
|
||||||
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Dict[str, Any]:
|
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Dict[Union[str, int], Any]:
|
||||||
vals: Dict[str, Any] = {}
|
vals: Dict[Union[str, int], Any] = {}
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
tlv_type = BigSizeType.read(io_in)
|
tlv_type = BigSizeType.read(io_in)
|
||||||
@@ -543,16 +580,6 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
|
|||||||
return " {}={}".format(name, self.val_to_str(v, {}))
|
return " {}={}".format(name, self.val_to_str(v, {}))
|
||||||
|
|
||||||
|
|
||||||
class TlvMessageType(MessageType):
|
|
||||||
"""A 'tlvtype' in BOLT-speak"""
|
|
||||||
|
|
||||||
def __init__(self, name: str, value: str):
|
|
||||||
super().__init__(name, value)
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return "tlvmsgtype-{}".format(self.name)
|
|
||||||
|
|
||||||
|
|
||||||
class Message(object):
|
class Message(object):
|
||||||
"""A particular message instance"""
|
"""A particular message instance"""
|
||||||
def __init__(self, messagetype: MessageType, **kwargs):
|
def __init__(self, messagetype: MessageType, **kwargs):
|
||||||
@@ -679,7 +706,8 @@ Must not have missing fields.
|
|||||||
"""Convert to a Python native object: dicts, lists, strings, ints"""
|
"""Convert to a Python native object: dicts, lists, strings, ints"""
|
||||||
ret: Dict[str, Union[Dict[str, Any], List[Any], str, int]] = {}
|
ret: Dict[str, Union[Dict[str, Any], List[Any], str, int]] = {}
|
||||||
for f, v in self.fields.items():
|
for f, v in self.fields.items():
|
||||||
fieldtype = self.messagetype.find_field(f).fieldtype
|
field = self.messagetype.find_field(f)
|
||||||
ret[f] = fieldtype.val_to_py(v, self.fields)
|
assert field
|
||||||
|
ret[f] = field.fieldtype.val_to_py(v, self.fields)
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|||||||
Reference in New Issue
Block a user