From 5142dc81f6902462cb3d892506107cce4adc9964 Mon Sep 17 00:00:00 2001 From: niftynei Date: Wed, 24 Mar 2021 16:36:32 -0500 Subject: [PATCH] pyln-proto: write out length of arrays of subtypes to wire We weren't writing out the length of a nested subtype's dynamicarraylenght, now we do. The trick is to iterate through the fields on a subtype (since the length field is added separately) and to also iterate down through the otherfield values as we 'descend' --- .../pyln/proto/message/array_types.py | 15 ++++++++---- .../pyln-proto/pyln/proto/message/message.py | 15 ++++++++---- contrib/pyln-proto/tests/test_message.py | 23 +++++++++++++++++++ 3 files changed, 45 insertions(+), 8 deletions(-) diff --git a/contrib/pyln-proto/pyln/proto/message/array_types.py b/contrib/pyln-proto/pyln/proto/message/array_types.py index 60c7011da..8ab0c8ef3 100644 --- a/contrib/pyln-proto/pyln/proto/message/array_types.py +++ b/contrib/pyln-proto/pyln/proto/message/array_types.py @@ -48,9 +48,16 @@ wants an array of some type. return [self.elemtype.val_to_py(i, otherfields) for i in v] - def write(self, io_out: BufferedIOBase, v: List[Any], otherfields: Dict[str, Any]) -> None: - for i in v: - self.elemtype.write(io_out, i, otherfields) + def write(self, io_out: BufferedIOBase, vals: List[Any], otherfields: Dict[str, Any]) -> None: + name = self.name.split('.')[1] + if otherfields and name in otherfields: + otherfields = otherfields[name] + for i, val in enumerate(vals): + if isinstance(otherfields, list) and len(otherfields) > i: + fields = otherfields[i] + else: + fields = otherfields + self.elemtype.write(io_out, val, fields) def read_arr(self, io_in: BufferedIOBase, otherfields: Dict[str, Any], arraysize: Optional[int]) -> List[Any]: """arraysize None means take rest of io entirely and exactly""" @@ -179,7 +186,7 @@ they're implied by the length of other fields""" if mylen != len(otherfields[lens.name]): return [fieldname] # Field might be missing! - if lens.name in otherfields: + if otherfields and lens.name in otherfields: mylen = len(otherfields[lens.name]) return [] diff --git a/contrib/pyln-proto/pyln/proto/message/message.py b/contrib/pyln-proto/pyln/proto/message/message.py index d7c709fb9..69474e462 100644 --- a/contrib/pyln-proto/pyln/proto/message/message.py +++ b/contrib/pyln-proto/pyln/proto/message/message.py @@ -297,10 +297,17 @@ other types. Since 'msgtype' is almost identical, it inherits from this too. def write(self, io_out: BufferedIOBase, v: Dict[str, Any], otherfields: Dict[str, Any]) -> None: self._raise_if_badvals(v) - for fname, val in v.items(): - field = self.find_field(fname) - assert field - field.fieldtype.write(io_out, val, otherfields) + for f in self.fields: + if f.name in v: + val = v[f.name] + else: + if f.option is not None: + raise ValueError("Missing field {} {}".format(f.name, otherfields)) + val = None + + if self.name in otherfields: + otherfields = otherfields[self.name] + f.fieldtype.write(io_out, val, otherfields) def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Optional[Dict[str, Any]]: vals = {} diff --git a/contrib/pyln-proto/tests/test_message.py b/contrib/pyln-proto/tests/test_message.py index 3e79faf9d..c88297ab0 100644 --- a/contrib/pyln-proto/tests/test_message.py +++ b/contrib/pyln-proto/tests/test_message.py @@ -90,6 +90,29 @@ def test_subtype(): assert m.missing_fields() +def test_subtype_array(): + ns = MessageNamespace() + ns.load_csv(['msgtype,tx_signatures,1', + 'msgdata,tx_signatures,num_witnesses,u16,', + 'msgdata,tx_signatures,witness_stack,witness_stack,num_witnesses', + 'subtype,witness_stack', + 'subtypedata,witness_stack,num_input_witness,u16,', + 'subtypedata,witness_stack,witness_element,witness_element,num_input_witness', + 'subtype,witness_element', + 'subtypedata,witness_element,len,u16,', + 'subtypedata,witness_element,witness,byte,len']) + + for test in [["tx_signatures witness_stack=" + "[{witness_element=[{witness=3045022100ac0fdee3e157f50be3214288cb7f11b03ce33e13b39dadccfcdb1a174fd3729a02200b69b286ac9f0fc5c51f9f04ae5a9827ac11d384cc203a0eaddff37e8d15c1ac01},{witness=02d6a3c2d0cf7904ab6af54d7c959435a452b24a63194e1c4e7c337d3ebbb3017b}]}]", + bytes.fromhex('00010001000200483045022100ac0fdee3e157f50be3214288cb7f11b03ce33e13b39dadccfcdb1a174fd3729a02200b69b286ac9f0fc5c51f9f04ae5a9827ac11d384cc203a0eaddff37e8d15c1ac01002102d6a3c2d0cf7904ab6af54d7c959435a452b24a63194e1c4e7c337d3ebbb3017b')]]: + m = Message.from_str(ns, test[0]) + assert m.to_str() == test[0] + buf = io.BytesIO() + m.write(buf) + assert buf.getvalue().hex() == test[1].hex() + assert Message.read(ns, io.BytesIO(test[1])).to_str() == test[0] + + def test_tlv(): ns = MessageNamespace() ns.load_csv(['msgtype,test1,1',