diff --git a/contrib/pyln-proto/pyln/proto/message/array_types.py b/contrib/pyln-proto/pyln/proto/message/array_types.py index 8aeebec45..3b4378225 100644 --- a/contrib/pyln-proto/pyln/proto/message/array_types.py +++ b/contrib/pyln-proto/pyln/proto/message/array_types.py @@ -42,28 +42,26 @@ wants an array of some type. return '[' + s + ']' - def val_to_bin(self, v, otherfields): - b = bytes() + def write(self, io_out, v, otherfields): for i in v: - b += self.elemtype.val_to_bin(i, otherfields) - return b + self.elemtype.write(io_out, i, otherfields) - def arr_from_bin(self, bytestream, otherfields, arraysize): - """arraysize None means take rest of bytestream exactly""" - totsize = 0 + def read_arr(self, io_in, otherfields, arraysize): + """arraysize None means take rest of io entirely and exactly""" vals = [] - i = 0 - while True: - if arraysize is None and totsize == len(bytestream): - return vals, totsize - elif i == arraysize: - return vals, totsize - val, size = self.elemtype.val_from_bin(bytestream[totsize:], - otherfields) - totsize += size - i += 1 + while arraysize is None or len(vals) < arraysize: + # Throws an exception on partial read, so None means completely empty. + val = self.elemtype.read(io_in, otherfields) + if val is None: + if arraysize is not None: + raise ValueError('{}: not enough remaining to read' + .format(self)) + break + vals.append(val) + return vals + class SizedArrayType(ArrayType): """A fixed-size array""" @@ -82,13 +80,13 @@ class SizedArrayType(ArrayType): raise ValueError("Length of {} != {}", s, self.arraysize) return a, b - def val_to_bin(self, v, otherfields): + def write(self, io_out, v, otherfields): if len(v) != self.arraysize: raise ValueError("Length of {} != {}", v, self.arraysize) - return super().val_to_bin(v, otherfields) + return super().write(io_out, v, otherfields) - def val_from_bin(self, bytestream, otherfields): - return super().arr_from_bin(bytestream, otherfields, self.arraysize) + def read(self, io_in, otherfields): + return super().read_arr(io_in, otherfields, self.arraysize) class EllipsisArrayType(ArrayType): @@ -97,9 +95,9 @@ when the tlv ends""" def __init__(self, tlv, name, elemtype): super().__init__(tlv, name, elemtype) - def val_from_bin(self, bytestream, otherfields): + def read(self, io_in, otherfields): """Takes rest of bytestream""" - return super().arr_from_bin(bytestream, otherfields, None) + return super().read_arr(io_in, otherfields, None) def only_at_tlv_end(self): """These only make sense at the end of a TLV""" @@ -142,10 +140,6 @@ class LengthFieldType(FieldType): return v return self.calc_value(otherfields) - def val_to_bin(self, _, otherfields): - return self.underlying_type.val_to_bin(self.calc_value(otherfields), - otherfields) - def val_to_str(self, _, otherfields): return self.underlying_type.val_to_str(self.calc_value(otherfields), otherfields) @@ -155,9 +149,13 @@ class LengthFieldType(FieldType): they're implied by the length of other fields""" return '' - def val_from_bin(self, bytestream, otherfields): + def read(self, io_in, otherfields): """We store this, but it'll be removed from the fields as soon as it's used (i.e. by DynamicArrayType's val_from_bin)""" - return self.underlying_type.val_from_bin(bytestream, otherfields) + return self.underlying_type.read(io_in, otherfields) + + def write(self, io_out, _, otherfields): + self.underlying_type.write(io_out, self.calc_value(otherfields), + otherfields) def val_from_str(self, s): raise ValueError('{} is implied, cannot be specified'.format(self)) @@ -182,6 +180,6 @@ class DynamicArrayType(ArrayType): assert type(lenfield.fieldtype) is LengthFieldType self.lenfield = lenfield - def val_from_bin(self, bytestream, otherfields): - return super().arr_from_bin(bytestream, otherfields, - self.lenfield.fieldtype._maybe_calc_value(self.lenfield.name, otherfields)) + def read(self, io_in, otherfields): + return super().read_arr(io_in, otherfields, + self.lenfield.fieldtype._maybe_calc_value(self.lenfield.name, otherfields)) diff --git a/contrib/pyln-proto/pyln/proto/message/fundamental_types.py b/contrib/pyln-proto/pyln/proto/message/fundamental_types.py index e4cd53d40..344a48ad8 100644 --- a/contrib/pyln-proto/pyln/proto/message/fundamental_types.py +++ b/contrib/pyln-proto/pyln/proto/message/fundamental_types.py @@ -1,4 +1,22 @@ import struct +import io +from typing import Optional + + +def try_unpack(name: str, + io_out: io.BufferedIOBase, + structfmt: str, + empty_ok: bool) -> Optional[int]: + """Unpack a single value using struct.unpack. + +If need_all, never return None, otherwise returns None if EOF.""" + b = io_out.read(struct.calcsize(structfmt)) + if len(b) == 0 and empty_ok: + return None + elif len(b) < struct.calcsize(structfmt): + raise ValueError("{}: not enough bytes", name) + + return struct.unpack(structfmt, b)[0] def split_field(s): @@ -57,15 +75,11 @@ class IntegerType(FieldType): a, b = split_field(s) return int(a), b - def val_to_bin(self, v, otherfields): - return struct.pack(self.structfmt, v) + def write(self, io_out, v, otherfields): + io_out.write(struct.pack(self.structfmt, v)) - def val_from_bin(self, bytestream, otherfields): - "Returns value, bytesused" - if self.bytelen > len(bytestream): - raise ValueError('{}: not enough remaining to read'.format(self)) - return struct.unpack_from(self.structfmt, - bytestream)[0], self.bytelen + def read(self, io_in, otherfields): + return try_unpack(self.name, io_in, self.structfmt, empty_ok=True) class ShortChannelIDType(IntegerType): @@ -110,30 +124,24 @@ class TruncatedIntType(FieldType): .format(a, self.name)) return int(a), b - def val_to_bin(self, v, otherfields): + def write(self, io_out, v, otherfields): binval = struct.pack('>Q', v) while len(binval) != 0 and binval[0] == 0: binval = binval[1:] if len(binval) > self.maxbytes: raise ValueError('{} exceeds maximum {} capacity' .format(v, self.name)) - return binval - - def val_from_bin(self, bytestream, otherfields): - "Returns value, bytesused" - binval = bytes() - while len(binval) < len(bytestream): - if len(binval) == 0 and bytestream[len(binval)] == 0: - raise ValueError('{} encoding is not minimal: {}' - .format(self.name, bytestream)) - binval += bytes([bytestream[len(binval)]]) + io_out.write(binval) + def read(self, io_in, otherfields): + binval = io_in.read() if len(binval) > self.maxbytes: raise ValueError('{} is too long for {}'.format(binval, self.name)) - + if len(binval) > 0 and binval[0] == 0: + raise ValueError('{} encoding is not minimal: {}' + .format(self.name, binval)) # Pad with zeroes and convert as u64 - return (struct.unpack_from('>Q', bytes(8 - len(binval)) + binval)[0], - len(binval)) + return struct.unpack_from('>Q', bytes(8 - len(binval)) + binval)[0] class FundamentalHexType(FieldType): @@ -154,16 +162,18 @@ class FundamentalHexType(FieldType): raise ValueError("Length of {} != {}", a, self.bytelen) return ret, b - def val_to_bin(self, v, otherfields): + def write(self, io_out, v, otherfields): if len(bytes(v)) != self.bytelen: raise ValueError("Length of {} != {}", v, self.bytelen) - return bytes(v) + io_out.write(v) - def val_from_bin(self, bytestream, otherfields): - "Returns value, size from bytestream" - if self.bytelen > len(bytestream): + def read(self, io_in, otherfields): + val = io_in.read(self.bytelen) + if len(val) == 0: + return None + elif len(val) != self.bytelen: raise ValueError('{}: not enough remaining'.format(self)) - return bytestream[:self.bytelen], self.bytelen + return val class BigSizeType(FieldType): @@ -177,37 +187,34 @@ class BigSizeType(FieldType): # For the convenience of TLV header parsing @staticmethod - def to_bin(v): + def write(io_out, v, otherfields=None): if v < 253: - return bytes([v]) + io_out.write(bytes([v])) elif v < 2**16: - return bytes([253]) + struct.pack('>H', v) + io_out.write(bytes([253]) + struct.pack('>H', v)) elif v < 2**32: - return bytes([254]) + struct.pack('>I', v) + io_out.write(bytes([254]) + struct.pack('>I', v)) else: - return bytes([255]) + struct.pack('>Q', v) + io_out.write(bytes([255]) + struct.pack('>Q', v)) @staticmethod - def from_bin(bytestream): - "Returns value, bytesused" - if bytestream[0] < 253: - return int(bytestream[0]), 1 - elif bytestream[0] == 253: - return struct.unpack_from('>H', bytestream[1:])[0], 3 - elif bytestream[0] == 254: - return struct.unpack_from('>I', bytestream[1:])[0], 5 + def read(io_in, otherfields=None): + "Returns value, or None on EOF" + b = io_in.read(1) + if len(b) == 0: + return None + if b[0] < 253: + return int(b[0]) + elif b[0] == 253: + return try_unpack('BigSize', io_in, '>H', empty_ok=False) + elif b[0] == 254: + return try_unpack('BigSize', io_in, '>I', empty_ok=False) else: - return struct.unpack_from('>Q', bytestream[1:])[0], 9 + return try_unpack('BigSize', io_in, '>Q', empty_ok=False) def val_to_str(self, v, otherfields): return "{}".format(int(v)) - def val_to_bin(self, v, otherfields): - return self.to_bin(v) - - def val_from_bin(self, bytestream, otherfields): - return self.from_bin(bytestream) - def fundamental_types(): # From 01-messaging.md#fundamental-types: diff --git a/contrib/pyln-proto/pyln/proto/message/message.py b/contrib/pyln-proto/pyln/proto/message/message.py index 0f7a00f9a..e43d497a7 100644 --- a/contrib/pyln-proto/pyln/proto/message/message.py +++ b/contrib/pyln-proto/pyln/proto/message/message.py @@ -1,5 +1,6 @@ import struct -from .fundamental_types import fundamental_types, BigSizeType, split_field +import io +from .fundamental_types import fundamental_types, BigSizeType, split_field, try_unpack from .array_types import ( SizedArrayType, DynamicArrayType, LengthFieldType, EllipsisArrayType ) @@ -253,24 +254,21 @@ inherit from this too. return '{' + s + '}' - def val_to_bin(self, v, otherfields): + def write(self, io_out, v, otherfields): self._raise_if_badvals(v) - b = bytes() for fname, val in v.items(): field = self.find_field(fname) - b += field.fieldtype.val_to_bin(val, otherfields) - return b + field.fieldtype.write(io_out, val, otherfields) - def val_from_bin(self, bytestream, otherfields): - totsize = 0 + def read(self, io_in, otherfields): vals = {} for field in self.fields: - val, size = field.fieldtype.val_from_bin(bytestream[totsize:], - otherfields) - totsize += size + val = field.fieldtype.read(io_in, otherfields) + if val is None: + raise ValueError("{}.{}: short read".format(self, field)) vals[field.name] = val - return vals, totsize + return vals @staticmethod def field_from_csv(namespace, parts): @@ -433,17 +431,15 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8, return '{' + s + '}' - def val_to_bin(self, v, otherfields): - b = bytes() - + def write(self, iobuf, v, otherfields): # If they didn't specify this tlvstream, it's empty. if v is None: - return b + return # Make a tuple of (fieldnum, val_to_bin, val) so we can sort into # ascending order as TLV spec requires. - def copy_val(val, otherfields): - return val + def write_raw_val(iobuf, val, otherfields): + iobuf.write(val) def get_value(tup): """Get value from num, fun, val tuple""" @@ -454,43 +450,40 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8, f = self.find_field(fieldname) if f is None: # fieldname can be an integer for a raw field. - ordered.append((int(fieldname), copy_val, v[fieldname])) + ordered.append((int(fieldname), write_raw_val, v[fieldname])) else: - ordered.append((f.number, f.val_to_bin, v[fieldname])) + ordered.append((f.number, f.write, v[fieldname])) ordered.sort(key=get_value) - for tup in ordered: - value = tup[1](tup[2], otherfields) - b += (BigSizeType.to_bin(tup[0]) - + BigSizeType.to_bin(len(value)) - + value) + for typenum, writefunc, val in ordered: + buf = io.BytesIO() + writefunc(buf, val, otherfields) + BigSizeType.write(iobuf, typenum) + BigSizeType.write(iobuf, len(buf.getvalue())) + iobuf.write(buf.getvalue()) - return b - - def val_from_bin(self, bytestream, otherfields): - totsize = 0 + def read(self, io_in, otherfields): vals = {} - while totsize < len(bytestream): - tlv_type, size = BigSizeType.from_bin(bytestream[totsize:]) - totsize += size - tlv_len, size = BigSizeType.from_bin(bytestream[totsize:]) - totsize += size + while True: + tlv_type = BigSizeType.read(io_in) + if tlv_type is None: + return vals + + tlv_len = BigSizeType.read(io_in) + if tlv_len is None: + raise ValueError("{}: truncated tlv_len field".format(self)) + binval = io_in.read(tlv_len) + if len(binval) != tlv_len: + raise ValueError("{}: truncated tlv {} value" + .format(tlv_type, self)) f = self.find_field_by_number(tlv_type) if f is None: - vals[tlv_type] = bytestream[totsize:totsize + tlv_len] - size = len(vals[tlv_type]) + # Raw fields are allowed, just index by number. + vals[tlv_type] = binval else: - vals[f.name], size = f.val_from_bin(bytestream - [totsize:totsize - + tlv_len], - otherfields) - if size != tlv_len: - raise ValueError("Truncated tlv field") - totsize += size - - return vals, totsize + vals[f.name] = f.read(io.BytesIO(binval), otherfields) def name_and_val(self, name, v): """This is overridden by LengthFieldType to return nothing""" @@ -541,10 +534,15 @@ class Message(object): return missing @staticmethod - def from_bin(namespace, binmsg): - """Decode a binary wire format to a Message within that namespace""" - typenum = struct.unpack_from(">H", binmsg)[0] - off = 2 + def read(namespace, io_in): + """Read and decode a Message within that namespace. + +Returns None on EOF + + """ + typenum = try_unpack('message_type', io_in, ">H", empty_ok=True) + if typenum is None: + return None mtype = namespace.get_msgtype_by_number(typenum) if not mtype: @@ -552,16 +550,21 @@ class Message(object): fields = {} for f in mtype.fields: - v, size = f.fieldtype.val_from_bin(binmsg[off:], fields) - off += size - fields[f.name] = v + fields[f.name] = f.fieldtype.read(io_in, fields) + if fields[f.name] is None: + # optional fields are OK to be missing at end! + raise ValueError('{}: truncated at field {}' + .format(mtype, f.name)) return Message(mtype, **fields) @staticmethod def from_str(namespace, s, incomplete_ok=False): - """Decode a string to a Message within that namespace, of format -msgname [ field=...]*.""" + """Decode a string to a Message within that namespace. + +Format is msgname [ field=...]*. + + """ parts = s.split() mtype = namespace.get_msgtype(parts[0]) @@ -582,14 +585,17 @@ msgname [ field=...]*.""" return m - def to_bin(self): - """Encode a Message into its wire format (must not have missing -fields)""" + def write(self, io_out): + """Write a Message into its wire format. + +Must not have missing fields. + + """ if self.missing_fields(): raise ValueError('Missing fields: {}' .format(self.missing_fields())) - ret = struct.pack(">H", self.messagetype.number) + io_out.write(struct.pack(">H", self.messagetype.number)) for f in self.messagetype.fields: # Optional fields get val == None. Usually this means they don't # write anything, but length fields are an exception: they intuit @@ -598,8 +604,7 @@ fields)""" val = self.fields[f.name] else: val = None - ret += f.fieldtype.val_to_bin(val, self.fields) - return ret + f.fieldtype.write(io_out, val, self.fields) def to_str(self): """Encode a Message into a string""" diff --git a/contrib/pyln-proto/requirements.txt b/contrib/pyln-proto/requirements.txt index 98a17156b..4c579bfd1 100644 --- a/contrib/pyln-proto/requirements.txt +++ b/contrib/pyln-proto/requirements.txt @@ -2,3 +2,4 @@ bitstring==3.1.6 cryptography==2.8 coincurve==13.0.0 base58==1.0.2 +mypy diff --git a/contrib/pyln-proto/tests/test_array_types.py b/contrib/pyln-proto/tests/test_array_types.py index 6ec80c4f2..caf1a4bf3 100644 --- a/contrib/pyln-proto/tests/test_array_types.py +++ b/contrib/pyln-proto/tests/test_array_types.py @@ -1,6 +1,7 @@ #! /usr/bin/python3 from pyln.proto.message.fundamental_types import fundamental_types from pyln.proto.message.array_types import SizedArrayType, DynamicArrayType, EllipsisArrayType, LengthFieldType +import io def test_sized_array(): @@ -32,9 +33,11 @@ def test_sized_array(): + [0, 0, 10, 0, 0, 11, 0, 12])]]: v, _ = arrtype.val_from_str(s) assert arrtype.val_to_str(v, None) == s - v2, _ = arrtype.val_from_bin(b, None) + v2 = arrtype.read(io.BytesIO(b), None) assert v2 == v - assert arrtype.val_to_bin(v, None) == b + buf = io.BytesIO() + arrtype.write(buf, v, None) + assert buf.getvalue() == b def test_ellipsis_array(): @@ -52,23 +55,25 @@ def test_ellipsis_array(): def __init__(self, name): self.name = name - for test in [[EllipsisArrayType(dummy("test1"), "test_arr", byte), - "00010203", - bytes([0, 1, 2, 3])], - [EllipsisArrayType(dummy("test2"), "test_arr", u16), - "[0,1,2,256]", - bytes([0, 0, 0, 1, 0, 2, 1, 0])], - [EllipsisArrayType(dummy("test3"), "test_arr", scid), - "[1x2x3,4x5x6,7x8x9,10x11x12]", - bytes([0, 0, 1, 0, 0, 2, 0, 3] - + [0, 0, 4, 0, 0, 5, 0, 6] - + [0, 0, 7, 0, 0, 8, 0, 9] - + [0, 0, 10, 0, 0, 11, 0, 12])]]: - v, _ = test[0].val_from_str(test[1]) - assert test[0].val_to_str(v, None) == test[1] - v2, _ = test[0].val_from_bin(test[2], None) + for arrtype, s, b in [[EllipsisArrayType(dummy("test1"), "test_arr", byte), + "00010203", + bytes([0, 1, 2, 3])], + [EllipsisArrayType(dummy("test2"), "test_arr", u16), + "[0,1,2,256]", + bytes([0, 0, 0, 1, 0, 2, 1, 0])], + [EllipsisArrayType(dummy("test3"), "test_arr", scid), + "[1x2x3,4x5x6,7x8x9,10x11x12]", + bytes([0, 0, 1, 0, 0, 2, 0, 3] + + [0, 0, 4, 0, 0, 5, 0, 6] + + [0, 0, 7, 0, 0, 8, 0, 9] + + [0, 0, 10, 0, 0, 11, 0, 12])]]: + v, _ = arrtype.val_from_str(s) + assert arrtype.val_to_str(v, None) == s + v2 = arrtype.read(io.BytesIO(b), None) assert v2 == v - assert test[0].val_to_bin(v, None) == test[2] + buf = io.BytesIO() + arrtype.write(buf, v, None) + assert buf.getvalue() == b def test_dynamic_array(): @@ -93,27 +98,29 @@ def test_dynamic_array(): lenfield = field_dummy('lenfield', LengthFieldType(u16)) - for test in [[DynamicArrayType(dummy("test1"), "test_arr", byte, - lenfield), - "00010203", - bytes([0, 1, 2, 3])], - [DynamicArrayType(dummy("test2"), "test_arr", u16, - lenfield), - "[0,1,2,256]", - bytes([0, 0, 0, 1, 0, 2, 1, 0])], - [DynamicArrayType(dummy("test3"), "test_arr", scid, - lenfield), - "[1x2x3,4x5x6,7x8x9,10x11x12]", - bytes([0, 0, 1, 0, 0, 2, 0, 3] - + [0, 0, 4, 0, 0, 5, 0, 6] - + [0, 0, 7, 0, 0, 8, 0, 9] - + [0, 0, 10, 0, 0, 11, 0, 12])]]: + for arrtype, s, b in [[DynamicArrayType(dummy("test1"), "test_arr", byte, + lenfield), + "00010203", + bytes([0, 1, 2, 3])], + [DynamicArrayType(dummy("test2"), "test_arr", u16, + lenfield), + "[0,1,2,256]", + bytes([0, 0, 0, 1, 0, 2, 1, 0])], + [DynamicArrayType(dummy("test3"), "test_arr", scid, + lenfield), + "[1x2x3,4x5x6,7x8x9,10x11x12]", + bytes([0, 0, 1, 0, 0, 2, 0, 3] + + [0, 0, 4, 0, 0, 5, 0, 6] + + [0, 0, 7, 0, 0, 8, 0, 9] + + [0, 0, 10, 0, 0, 11, 0, 12])]]: - lenfield.fieldtype.add_length_for(field_dummy(test[1], test[0])) - v, _ = test[0].val_from_str(test[1]) - otherfields = {test[1]: v} - assert test[0].val_to_str(v, otherfields) == test[1] - v2, _ = test[0].val_from_bin(test[2], otherfields) + lenfield.fieldtype.add_length_for(field_dummy(s, arrtype)) + v, _ = arrtype.val_from_str(s) + otherfields = {s: v} + assert arrtype.val_to_str(v, otherfields) == s + v2 = arrtype.read(io.BytesIO(b), otherfields) assert v2 == v - assert test[0].val_to_bin(v, otherfields) == test[2] + buf = io.BytesIO() + arrtype.write(buf, v, None) + assert buf.getvalue() == b lenfield.fieldtype.len_for = [] diff --git a/contrib/pyln-proto/tests/test_fundamental_types.py b/contrib/pyln-proto/tests/test_fundamental_types.py index a26e0aca9..5dd5ac345 100644 --- a/contrib/pyln-proto/tests/test_fundamental_types.py +++ b/contrib/pyln-proto/tests/test_fundamental_types.py @@ -1,5 +1,6 @@ #! /usr/bin/python3 from pyln.proto.message.fundamental_types import fundamental_types +import io def test_fundamental_types(): @@ -67,8 +68,10 @@ def test_fundamental_types(): for test in expect[t.name]: v, _ = t.val_from_str(test[0]) assert t.val_to_str(v, None) == test[0] - v2, _ = t.val_from_bin(test[1], None) + v2 = t.read(io.BytesIO(test[1]), None) assert v2 == v - assert t.val_to_bin(v, None) == test[1] + buf = io.BytesIO() + t.write(buf, v, None) + assert buf.getvalue() == test[1] assert untested == set(['varint']) diff --git a/contrib/pyln-proto/tests/test_message.py b/contrib/pyln-proto/tests/test_message.py index b2ab2fb6e..186880a79 100644 --- a/contrib/pyln-proto/tests/test_message.py +++ b/contrib/pyln-proto/tests/test_message.py @@ -1,6 +1,7 @@ #! /usr/bin/python3 from pyln.proto.message import MessageNamespace, Message import pytest +import io def test_fundamental(): @@ -51,9 +52,10 @@ def test_static_array(): + [0, 0, 10, 0, 0, 11, 0, 12])]]: m = Message.from_str(ns, test[0]) assert m.to_str() == test[0] - v = m.to_bin() - assert v == test[1] - assert Message.from_bin(ns, test[1]).to_str() == test[0] + buf = io.BytesIO() + m.write(buf) + assert buf.getvalue() == test[1] + assert Message.read(ns, io.BytesIO(test[1])).to_str() == test[0] def test_subtype(): @@ -78,9 +80,10 @@ def test_subtype(): + [0, 0, 0, 7, 0, 0, 0, 8])]]: m = Message.from_str(ns, test[0]) assert m.to_str() == test[0] - v = m.to_bin() - assert v == test[1] - assert Message.from_bin(ns, test[1]).to_str() == test[0] + buf = io.BytesIO() + m.write(buf) + assert buf.getvalue() == test[1] + assert Message.read(ns, io.BytesIO(test[1])).to_str() == test[0] # Test missing field logic. m = Message.from_str(ns, "test1", incomplete_ok=True) @@ -111,16 +114,19 @@ def test_tlv(): + [253, 0, 255, 4, 1, 2, 3, 4])]]: m = Message.from_str(ns, test[0]) assert m.to_str() == test[0] - v = m.to_bin() - assert v == test[1] - assert Message.from_bin(ns, test[1]).to_str() == test[0] + buf = io.BytesIO() + m.write(buf) + assert buf.getvalue() == test[1] + assert Message.read(ns, io.BytesIO(test[1])).to_str() == test[0] # Ordering test (turns into canonical ordering) m = Message.from_str(ns, 'test1 tlvs={tlv1={field1=01020304,field2=5},tlv2={field3=01020304},4=010203}') - assert m.to_bin() == bytes([0, 1] - + [1, 8, 1, 2, 3, 4, 0, 0, 0, 5] - + [4, 3, 1, 2, 3] - + [253, 0, 255, 4, 1, 2, 3, 4]) + buf = io.BytesIO() + m.write(buf) + assert buf.getvalue() == bytes([0, 1] + + [1, 8, 1, 2, 3, 4, 0, 0, 0, 5] + + [4, 3, 1, 2, 3] + + [253, 0, 255, 4, 1, 2, 3, 4]) def test_message_constructor(): @@ -135,10 +141,12 @@ def test_message_constructor(): m = Message(ns.get_msgtype('test1'), tlvs='{tlv1={field1=01020304,field2=5}' ',tlv2={field3=01020304},4=010203}') - assert m.to_bin() == bytes([0, 1] - + [1, 8, 1, 2, 3, 4, 0, 0, 0, 5] - + [4, 3, 1, 2, 3] - + [253, 0, 255, 4, 1, 2, 3, 4]) + buf = io.BytesIO() + m.write(buf) + assert buf.getvalue() == bytes([0, 1] + + [1, 8, 1, 2, 3, 4, 0, 0, 0, 5] + + [4, 3, 1, 2, 3] + + [253, 0, 255, 4, 1, 2, 3, 4]) def test_dynamic_array(): @@ -151,13 +159,15 @@ def test_dynamic_array(): # This one is fine. m = Message(ns.get_msgtype('test1'), arr1='01020304', arr2='[1,2,3,4]') - assert m.to_bin() == bytes([0, 1] - + [0, 4] - + [1, 2, 3, 4] - + [0, 0, 0, 1, - 0, 0, 0, 2, - 0, 0, 0, 3, - 0, 0, 0, 4]) + buf = io.BytesIO() + m.write(buf) + assert buf.getvalue() == bytes([0, 1] + + [0, 4] + + [1, 2, 3, 4] + + [0, 0, 0, 1, + 0, 0, 0, 2, + 0, 0, 0, 3, + 0, 0, 0, 4]) # These ones are not with pytest.raises(ValueError, match='Inconsistent length.*count'):