diff --git a/tools/generate-wire.py b/tools/generate-wire.py index 7a6ff8fc2..ead9ffcc5 100755 --- a/tools/generate-wire.py +++ b/tools/generate-wire.py @@ -504,8 +504,9 @@ class Message(object): elif f.is_tlv: if not f.is_variable_size(): raise TypeError('TLV {} not variable size'.format(f.name)) - subcalls.append('if (!fromwire__{tlv_name}(ctx, &cursor, &plen, &{tlv_len}, {tlv_name}))' + subcalls.append('{tlv_name} = fromwire__{tlv_name}(ctx, &cursor, &plen, &{tlv_len});' .format(tlv_name=f.name, tlv_len=f.lenvar)) + subcalls.append('if (!{tlv_name})'.format(tlv_name=f.name)) subcalls.append('return false;') elif f.is_variable_size(): subcalls.append("//2nd case {name}".format(name=f.name)) @@ -846,10 +847,12 @@ tlv__type_impl_towire_template = """static void towire__{tlv_name}(const tal_t * {fields}}} """ -tlv__type_impl_fromwire_template = """static bool fromwire__{tlv_name}(const tal_t *ctx, const u8 **p, size_t *plen, const u16 *len, struct {tlv_name} *{tlv_name}) {{ +tlv__type_impl_fromwire_template = """static struct {tlv_name} *fromwire__{tlv_name}(const tal_t *ctx, const u8 **p, size_t *plen, const u16 *len) {{ \tu8 msg_type, msg_len; \tif (*plen < *len) -\t\treturn false; +\t\treturn NULL; + +\tstruct {tlv_name} *{tlv_name} = talz(ctx, struct {tlv_name}); \twhile (*plen) {{ \t\tmsg_type = fromwire_u8(p, plen); @@ -865,14 +868,25 @@ tlv__type_impl_fromwire_template = """static bool fromwire__{tlv_name}(const tal \t\t\tplen -= msg_len; \t\t}} \t}} -\treturn *p != NULL; +\tif (!*p) {{ +\t\ttal_free({tlv_name}); +\t\treturn NULL; +\t}} +\treturn {tlv_name}; }} """ case_tmpl = """\t\tcase {tlv_msg_enum}: -\t\t\t{tlv_name}->{tlv_msg_name} = tal(ctx, struct tlv_msg_{tlv_msg_name}); -\t\t\tif (!fromwire_{tlv_name}_{tlv_msg_name}({ctx_arg}*p, plen, msg_len, {tlv_name}->{tlv_msg_name})) -\t\t\t\treturn false; +\t\t\tif ({tlv_name}->{tlv_msg_name} != NULL) {{ +\t\t\t\tfromwire_fail(p, plen); +\t\t\t\ttal_free({tlv_name}); +\t\t\t\treturn NULL; +\t\t\t}} +\t\t\t{tlv_name}->{tlv_msg_name} = tal({tlv_name}, struct tlv_msg_{tlv_msg_name}); +\t\t\tif (!fromwire_{tlv_name}_{tlv_msg_name}({ctx_arg}*p, plen, msg_len, {tlv_name}->{tlv_msg_name})) {{ +\t\t\t\ttal_free({tlv_name}); +\t\t\t\treturn NULL; +\t\t\t}} \t\t\tbreak; """ @@ -906,7 +920,7 @@ def print_tlv_towire(tlv_field_name, messages): def print_tlv_fromwire(tlv_field_name, messages): cases = "" for m in messages: - ctx_arg = 'ctx, ' if m.has_variable_fields else '' + ctx_arg = tlv_field_name + ', ' if m.has_variable_fields else '' cases += case_tmpl.format(ctx_arg=ctx_arg, tlv_msg_enum=m.enum.name, tlv_name=tlv_field_name,