mirror of
https://github.com/aljazceru/lightning.git
synced 2025-12-19 15:14:23 +01:00
tlv: add fromwire_ methods for TLV structs
This commit is contained in:
committed by
Rusty Russell
parent
ef610dcab3
commit
6f2e70a6ac
@@ -232,6 +232,21 @@ fromwire_impl_templ = """bool fromwire_{name}({ctx}const void *p{args})
|
|||||||
}}
|
}}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
fromwire_tlv_impl_templ = """static bool _fromwire_{tlv_name}_{name}({ctx}{args})
|
||||||
|
{{
|
||||||
|
|
||||||
|
\tsize_t start_len, plen;
|
||||||
|
{fields}
|
||||||
|
\tconst u8 *cursor = p;
|
||||||
|
\tplen = tal_count(p);
|
||||||
|
\tif (plen < len)
|
||||||
|
\t\treturn false;
|
||||||
|
\tstart_len = plen;
|
||||||
|
{subcalls}
|
||||||
|
\treturn cursor != NULL && (start_len - plen == len);
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
|
||||||
fromwire_header_templ = """bool fromwire_{name}({ctx}const void *p{args});
|
fromwire_header_templ = """bool fromwire_{name}({ctx}const void *p{args});
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -384,7 +399,72 @@ class Message(object):
|
|||||||
subcalls.append('fromwire_{}(&cursor, &plen, {} + i);'
|
subcalls.append('fromwire_{}(&cursor, &plen, {} + i);'
|
||||||
.format(basetype, name))
|
.format(basetype, name))
|
||||||
|
|
||||||
def print_fromwire(self, is_header):
|
def print_tlv_fromwire(self, tlv_name):
|
||||||
|
""" prints fromwire function definition for a TLV message.
|
||||||
|
these are significantly different in that they take in a struct
|
||||||
|
to populate, instead of fields, as well as a length to read in
|
||||||
|
"""
|
||||||
|
ctx_arg = 'const tal_t *ctx, ' if self.has_variable_fields else ''
|
||||||
|
args = 'const void *p, const u16 len, struct _tlv_msg_{name} *{name}'.format(name=self.name)
|
||||||
|
fields = ['\t{} {};\n'.format(f.fieldtype.name, f.name) for f in self.fields if f.is_len_var]
|
||||||
|
subcalls = CCode()
|
||||||
|
for f in self.fields:
|
||||||
|
basetype = f.basetype()
|
||||||
|
if f.is_tlv:
|
||||||
|
raise TypeError('Nested TLVs arent allowed!!')
|
||||||
|
elif f.optional:
|
||||||
|
raise TypeError('Optional fields on TLV messages not currently supported')
|
||||||
|
|
||||||
|
for c in f.comments:
|
||||||
|
subcalls.append('/*{} */'.format(c))
|
||||||
|
|
||||||
|
if f.is_padding():
|
||||||
|
subcalls.append('fromwire_pad(&cursor, &plen, {});'
|
||||||
|
.format(f.num_elems))
|
||||||
|
elif f.is_array():
|
||||||
|
name = '*{}->{}'.format(self.name, f.name)
|
||||||
|
self.print_fromwire_array('ctx', subcalls, basetype, f, name,
|
||||||
|
f.num_elems)
|
||||||
|
elif f.is_variable_size():
|
||||||
|
subcalls.append("// 2nd case {name}".format(name=f.name))
|
||||||
|
typename = f.fieldtype.name
|
||||||
|
# If structs are varlen, need array of ptrs to them.
|
||||||
|
if basetype in varlen_structs:
|
||||||
|
typename += ' *'
|
||||||
|
subcalls.append('{}->{} = {} ? tal_arr(ctx, {}, {}) : NULL;'
|
||||||
|
.format(self.name, f.name, f.lenvar, typename, f.lenvar))
|
||||||
|
|
||||||
|
name = '{}->{}'.format(self.name, f.name)
|
||||||
|
# Allocate these off the array itself, if they need alloc.
|
||||||
|
self.print_fromwire_array('*' + f.name, subcalls, basetype, f,
|
||||||
|
name, f.lenvar)
|
||||||
|
else:
|
||||||
|
if f.is_assignable():
|
||||||
|
if f.is_len_var:
|
||||||
|
s = '{} = fromwire_{}(&cursor, &plen);'.format(f.name, basetype)
|
||||||
|
else:
|
||||||
|
s = '{}->{} = fromwire_{}(&cursor, &plen);'.format(
|
||||||
|
self.name, f.name, basetype)
|
||||||
|
else:
|
||||||
|
s = 'fromwire_{}(&cursor, &plen, *{}->{});'.format(
|
||||||
|
basetype, self.name, f.name)
|
||||||
|
subcalls.append(s)
|
||||||
|
|
||||||
|
return fromwire_tlv_impl_templ.format(
|
||||||
|
tlv_name=tlv_name,
|
||||||
|
name=self.name,
|
||||||
|
ctx=ctx_arg,
|
||||||
|
args=''.join(args),
|
||||||
|
fields=''.join(fields),
|
||||||
|
subcalls=str(subcalls)
|
||||||
|
)
|
||||||
|
|
||||||
|
def print_fromwire(self, is_header, tlv_name):
|
||||||
|
if self.is_tlv:
|
||||||
|
if is_header:
|
||||||
|
return ''
|
||||||
|
return self.print_tlv_fromwire(tlv_name)
|
||||||
|
|
||||||
ctx_arg = 'const tal_t *ctx, ' if self.has_variable_fields else ''
|
ctx_arg = 'const tal_t *ctx, ' if self.has_variable_fields else ''
|
||||||
|
|
||||||
args = []
|
args = []
|
||||||
@@ -394,6 +474,8 @@ class Message(object):
|
|||||||
continue
|
continue
|
||||||
elif f.is_array():
|
elif f.is_array():
|
||||||
args.append(', {} {}[{}]'.format(f.fieldtype.name, f.name, f.num_elems))
|
args.append(', {} {}[{}]'.format(f.fieldtype.name, f.name, f.num_elems))
|
||||||
|
elif f.is_tlv:
|
||||||
|
args.append(', struct _{} *{}'.format(f.name, f.name))
|
||||||
else:
|
else:
|
||||||
ptrs = '*'
|
ptrs = '*'
|
||||||
# If we're handing a variable array, we need a ptr-to-ptr.
|
# If we're handing a variable array, we need a ptr-to-ptr.
|
||||||
@@ -421,6 +503,12 @@ class Message(object):
|
|||||||
elif f.is_array():
|
elif f.is_array():
|
||||||
self.print_fromwire_array('ctx', subcalls, basetype, f, f.name,
|
self.print_fromwire_array('ctx', subcalls, basetype, f, f.name,
|
||||||
f.num_elems)
|
f.num_elems)
|
||||||
|
elif f.is_tlv:
|
||||||
|
if not f.is_variable_size():
|
||||||
|
raise TypeError('TLV {} not variable size'.format(f.name))
|
||||||
|
subcalls.append('if (!fromwire__{tlv_name}(ctx, &cursor, &{tlv_len}, {tlv_name}))'
|
||||||
|
.format(tlv_name=f.name, tlv_len=f.lenvar))
|
||||||
|
subcalls.append('return false;')
|
||||||
elif f.is_variable_size():
|
elif f.is_variable_size():
|
||||||
subcalls.append("//2nd case {name}".format(name=f.name))
|
subcalls.append("//2nd case {name}".format(name=f.name))
|
||||||
typename = f.fieldtype.name
|
typename = f.fieldtype.name
|
||||||
@@ -472,21 +560,67 @@ class Message(object):
|
|||||||
subcalls=str(subcalls)
|
subcalls=str(subcalls)
|
||||||
)
|
)
|
||||||
|
|
||||||
def print_towire_array(self, subcalls, basetype, f, num_elems):
|
def print_towire_array(self, subcalls, basetype, f, num_elems, is_tlv=False):
|
||||||
|
p_ref = '' if is_tlv else '&'
|
||||||
|
msg_name = self.name + '->' if is_tlv else ''
|
||||||
if f.has_array_helper():
|
if f.has_array_helper():
|
||||||
subcalls.append('towire_{}_array(&p, {}, {});'
|
subcalls.append('towire_{}_array({}p, {}{}, {});'
|
||||||
.format(basetype, f.name, num_elems))
|
.format(basetype, p_ref, msg_name, f.name, num_elems))
|
||||||
else:
|
else:
|
||||||
subcalls.append('for (size_t i = 0; i < {}; i++)'
|
subcalls.append('for (size_t i = 0; i < {}; i++)'
|
||||||
.format(num_elems))
|
.format(num_elems))
|
||||||
if f.fieldtype.is_assignable() or basetype in varlen_structs:
|
if f.fieldtype.is_assignable() or basetype in varlen_structs:
|
||||||
subcalls.append('towire_{}(&p, {}[i]);'
|
subcalls.append('towire_{}({}p, {}{}[i]);'
|
||||||
.format(basetype, f.name))
|
.format(basetype, p_ref, msg_name, f.name))
|
||||||
else:
|
else:
|
||||||
subcalls.append('towire_{}(&p, {} + i);'
|
subcalls.append('towire_{}({}p, {}{} + i);'
|
||||||
.format(basetype, f.name))
|
.format(basetype, p_ref, msg_name, f.name))
|
||||||
|
|
||||||
def print_towire(self, is_header):
|
def print_tlv_towire(self, tlv_name):
|
||||||
|
""" prints towire function definition for a TLV message."""
|
||||||
|
field_decls = []
|
||||||
|
for f in self.fields:
|
||||||
|
if f.is_tlv:
|
||||||
|
raise TypeError("Nested TLVs aren't allowed!! {}->{}".format(tlv_name, f.name))
|
||||||
|
elif f.optional:
|
||||||
|
raise TypeError("Optional fields on TLV messages not currently supported. {}->{}".format(tlv_name, f.name))
|
||||||
|
if f.is_len_var:
|
||||||
|
field_decls.append('\t{0} {1} = tal_count(&{2}->{3});'.format(
|
||||||
|
f.fieldtype.name, f.name, self.name, f.lenvar_for.name
|
||||||
|
))
|
||||||
|
|
||||||
|
subcalls = CCode()
|
||||||
|
for f in self.fields:
|
||||||
|
basetype = f.fieldtype.name
|
||||||
|
if basetype.startswith('struct '):
|
||||||
|
basetype = basetype[7:]
|
||||||
|
elif basetype.startswith('enum '):
|
||||||
|
basetype = basetype[5:]
|
||||||
|
|
||||||
|
for c in f.comments:
|
||||||
|
subcalls.append('/*{} */'.format(c))
|
||||||
|
|
||||||
|
if f.is_padding():
|
||||||
|
subcalls.append('towire_pad(p, {});'.format(f.num_elems))
|
||||||
|
elif f.is_array():
|
||||||
|
self.print_towire_array(subcalls, basetype, f, f.num_elems, is_tlv=True)
|
||||||
|
elif f.is_variable_size():
|
||||||
|
self.print_towire_array(subcalls, basetype, f, f.lenvar, is_tlv=True)
|
||||||
|
elif f.is_len_var:
|
||||||
|
subcalls.append('towire_{}(p, {});'.format(basetype, f.name))
|
||||||
|
else:
|
||||||
|
subcalls.append('towire_{}(p, {}->{});'.format(basetype, self.name, f.name))
|
||||||
|
return tlv_message_towire_stub.format(
|
||||||
|
tlv_name=tlv_name,
|
||||||
|
name=self.name,
|
||||||
|
field_decls='\n'.join(field_decls),
|
||||||
|
subcalls=str(subcalls))
|
||||||
|
|
||||||
|
def print_towire(self, is_header, tlv_name):
|
||||||
|
if self.is_tlv:
|
||||||
|
if is_header:
|
||||||
|
return ''
|
||||||
|
return self.print_tlv_towire(tlv_name)
|
||||||
template = towire_header_templ if is_header else towire_impl_templ
|
template = towire_header_templ if is_header else towire_impl_templ
|
||||||
args = []
|
args = []
|
||||||
for f in self.fields:
|
for f in self.fields:
|
||||||
@@ -494,6 +628,8 @@ class Message(object):
|
|||||||
continue
|
continue
|
||||||
if f.is_array():
|
if f.is_array():
|
||||||
args.append(', const {} {}[{}]'.format(f.fieldtype.name, f.name, f.num_elems))
|
args.append(', const {} {}[{}]'.format(f.fieldtype.name, f.name, f.num_elems))
|
||||||
|
elif f.is_tlv:
|
||||||
|
args.append(', const struct _{} *{}'.format(f.name, f.name))
|
||||||
elif f.is_assignable():
|
elif f.is_assignable():
|
||||||
args.append(', {} {}'.format(f.fieldtype.name, f.name))
|
args.append(', {} {}'.format(f.fieldtype.name, f.name))
|
||||||
elif f.is_variable_size() and f.basetype() in varlen_structs:
|
elif f.is_variable_size() and f.basetype() in varlen_structs:
|
||||||
@@ -504,9 +640,14 @@ class Message(object):
|
|||||||
field_decls = []
|
field_decls = []
|
||||||
for f in self.fields:
|
for f in self.fields:
|
||||||
if f.is_len_var:
|
if f.is_len_var:
|
||||||
field_decls.append('\t{0} {1} = tal_count({2});'.format(
|
if f.lenvar_for.is_tlv:
|
||||||
f.fieldtype.name, f.name, f.lenvar_for.name
|
field_decls.append('\t{0} {1} = sizeof({2});'.format(
|
||||||
))
|
f.fieldtype.name, f.name, f.lenvar_for.name
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
field_decls.append('\t{0} {1} = tal_count({2});'.format(
|
||||||
|
f.fieldtype.name, f.name, f.lenvar_for.name
|
||||||
|
))
|
||||||
|
|
||||||
subcalls = CCode()
|
subcalls = CCode()
|
||||||
for f in self.fields:
|
for f in self.fields:
|
||||||
@@ -524,6 +665,11 @@ class Message(object):
|
|||||||
.format(f.num_elems))
|
.format(f.num_elems))
|
||||||
elif f.is_array():
|
elif f.is_array():
|
||||||
self.print_towire_array(subcalls, basetype, f, f.num_elems)
|
self.print_towire_array(subcalls, basetype, f, f.num_elems)
|
||||||
|
elif f.is_tlv:
|
||||||
|
if not f.is_variable_size():
|
||||||
|
raise TypeError('TLV {} not variable size'.format(f.name))
|
||||||
|
subcalls.append('towire__{tlv_name}(&p, {tlv_name});'.format(
|
||||||
|
tlv_name=f.name))
|
||||||
elif f.is_variable_size():
|
elif f.is_variable_size():
|
||||||
self.print_towire_array(subcalls, basetype, f, f.lenvar)
|
self.print_towire_array(subcalls, basetype, f, f.lenvar)
|
||||||
else:
|
else:
|
||||||
@@ -664,6 +810,13 @@ class Message(object):
|
|||||||
fields=str(fmt_fields))
|
fields=str(fmt_fields))
|
||||||
|
|
||||||
|
|
||||||
|
tlv_message_towire_stub = """static void _towire_{tlv_name}_{name}(u8 **p, struct _tlv_msg_{name} *{name}) {{
|
||||||
|
{field_decls}
|
||||||
|
{subcalls}
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
tlv_msg_struct_template = """
|
tlv_msg_struct_template = """
|
||||||
struct _tlv_msg_{msg_name} {{
|
struct _tlv_msg_{msg_name} {{
|
||||||
{fields}
|
{fields}
|
||||||
@@ -676,6 +829,87 @@ struct _{tlv_name} {{
|
|||||||
}};
|
}};
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
tlv__type_impl_towire_fields = """\tif ({tlv_name}->{name}) {{
|
||||||
|
\t\ttowire_u16(p, {enum});
|
||||||
|
\t\ttowire_u16(p, sizeof(*{tlv_name}->{name}));
|
||||||
|
\t\t_towire_{tlv_name}_{name}(p, {tlv_name}->{name});
|
||||||
|
\t}}
|
||||||
|
"""
|
||||||
|
|
||||||
|
tlv__type_impl_towire_template = """static void towire__{tlv_name}(u8 **p, const struct _{tlv_name} *{tlv_name}) {{
|
||||||
|
{fields}}}
|
||||||
|
"""
|
||||||
|
|
||||||
|
tlv__type_impl_fromwire_template = """static bool fromwire__{tlv_name}(const tal_t *ctx, const u8 **p, const u16 *len, struct _{tlv_name} *{tlv_name}) {{
|
||||||
|
\tu16 msg_type, msg_len;
|
||||||
|
\tconst u8 *cursor = *p;
|
||||||
|
\tsize_t plen = tal_count(p);
|
||||||
|
\tif (plen != *len)
|
||||||
|
\t\treturn false;
|
||||||
|
|
||||||
|
\twhile (cursor && plen) {{
|
||||||
|
\t\tmsg_type = fromwire_u16(&cursor, &plen);
|
||||||
|
\t\tmsg_len = fromwire_u16(&cursor, &plen);
|
||||||
|
\t\tif (plen < msg_len) {{
|
||||||
|
\t\t\tfromwire_fail(&cursor, &plen);
|
||||||
|
\t\t\tbreak;
|
||||||
|
\t\t}}
|
||||||
|
\t\tswitch((enum {tlv_name}_type)msg_type) {{
|
||||||
|
{cases}\t\tdefault:
|
||||||
|
\t\t\t// FIXME: print a warning / message?
|
||||||
|
\t\t\tcursor += msg_len;
|
||||||
|
\t\t\tplen -= msg_len;
|
||||||
|
\t\t}}
|
||||||
|
\t}}
|
||||||
|
\treturn cursor != NULL;
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
|
||||||
|
case_tmpl = """\t\tcase {tlv_msg_enum}:
|
||||||
|
\t\t\tif (!_fromwire_{tlv_name}_{tlv_msg_name}({ctx_arg}cursor, msg_len, {tlv_name}->{tlv_msg_name}))
|
||||||
|
\t\t\t\treturn false;
|
||||||
|
\t\t\tbreak;
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def build_tlv_fromwires(tlv_fields):
|
||||||
|
fromwires = []
|
||||||
|
for field_name, messages in tlv_fields.items():
|
||||||
|
fromwires.append(print_tlv_fromwire(field_name, messages))
|
||||||
|
return fromwires
|
||||||
|
|
||||||
|
|
||||||
|
def build_tlv_towires(tlv_fields):
|
||||||
|
towires = []
|
||||||
|
for field_name, messages in tlv_fields.items():
|
||||||
|
towires.append(print_tlv_towire(field_name, messages))
|
||||||
|
return towires
|
||||||
|
|
||||||
|
|
||||||
|
def print_tlv_towire(tlv_field_name, messages):
|
||||||
|
fields = ""
|
||||||
|
for m in messages:
|
||||||
|
fields += tlv__type_impl_towire_fields.format(
|
||||||
|
tlv_name=tlv_field_name,
|
||||||
|
enum=m.enum.name,
|
||||||
|
name=m.name)
|
||||||
|
return tlv__type_impl_towire_template.format(
|
||||||
|
tlv_name=tlv_field_name,
|
||||||
|
fields=fields)
|
||||||
|
|
||||||
|
|
||||||
|
def print_tlv_fromwire(tlv_field_name, messages):
|
||||||
|
cases = ""
|
||||||
|
for m in messages:
|
||||||
|
ctx_arg = 'ctx, ' if m.has_variable_fields else ''
|
||||||
|
cases += case_tmpl.format(ctx_arg=ctx_arg,
|
||||||
|
tlv_msg_enum=m.enum.name,
|
||||||
|
tlv_name=tlv_field_name,
|
||||||
|
tlv_msg_name=m.name)
|
||||||
|
return tlv__type_impl_fromwire_template.format(
|
||||||
|
tlv_name=tlv_field_name,
|
||||||
|
cases=cases)
|
||||||
|
|
||||||
|
|
||||||
def build_tlv_type_struct(name, messages):
|
def build_tlv_type_struct(name, messages):
|
||||||
inner_structs = CCode()
|
inner_structs = CCode()
|
||||||
@@ -752,7 +986,6 @@ def parse_tlv_file(tlv_field_name):
|
|||||||
# eg commit_sig,132
|
# eg commit_sig,132
|
||||||
tlv_msg = Message(parts[0], Enumtype("TLV_" + parts[0].upper(), parts[1]), tlv_comments, True)
|
tlv_msg = Message(parts[0], Enumtype("TLV_" + parts[0].upper(), parts[1]), tlv_comments, True)
|
||||||
tlv_messages.append(tlv_msg)
|
tlv_messages.append(tlv_msg)
|
||||||
messages.append(tlv_msg)
|
|
||||||
|
|
||||||
tlv_comments = []
|
tlv_comments = []
|
||||||
tlv_prevfield = None
|
tlv_prevfield = None
|
||||||
@@ -994,8 +1227,22 @@ printcases = ['case {enum.name}: printf("{enum.name}:\\n"); printwire_{name}("{n
|
|||||||
if options.printwire:
|
if options.printwire:
|
||||||
decls = [m.print_printwire(options.header) for m in messages + messages_with_option]
|
decls = [m.print_printwire(options.header) for m in messages + messages_with_option]
|
||||||
else:
|
else:
|
||||||
fromwire_decls = [m.print_fromwire(options.header) for m in messages + messages_with_option]
|
towire_decls = []
|
||||||
towire_decls = towire_decls = [m.print_towire(options.header) for m in messages + messages_with_option]
|
fromwire_decls = []
|
||||||
|
|
||||||
|
for tlv_field, tlv_messages in tlv_fields.items():
|
||||||
|
for m in tlv_messages:
|
||||||
|
towire_decls.append(m.print_towire(options.header, tlv_field))
|
||||||
|
fromwire_decls.append(m.print_fromwire(options.header, tlv_field))
|
||||||
|
|
||||||
|
if not options.header:
|
||||||
|
tlv_towires = build_tlv_towires(tlv_fields)
|
||||||
|
tlv_fromwires = build_tlv_fromwires(tlv_fields)
|
||||||
|
towire_decls += tlv_towires
|
||||||
|
fromwire_decls += tlv_fromwires
|
||||||
|
|
||||||
|
towire_decls += [m.print_towire(options.header, '') for m in messages + messages_with_option]
|
||||||
|
fromwire_decls += [m.print_fromwire(options.header, '') for m in messages + messages_with_option]
|
||||||
decls = fromwire_decls + towire_decls
|
decls = fromwire_decls + towire_decls
|
||||||
|
|
||||||
print(template.format(
|
print(template.format(
|
||||||
|
|||||||
Reference in New Issue
Block a user