mirror of
https://github.com/aljazceru/lightning.git
synced 2025-12-19 15:14:23 +01:00
bolt-gen: for wire messages, print out optional fields (if present)
optional fields should be printed, if they exist. so let's print them!
This commit is contained in:
committed by
Rusty Russell
parent
40154b35f0
commit
316edb39a4
@@ -33,6 +33,11 @@ void print${options.enum_name}_message(const u8 *msg)
|
||||
## definition for printing field sets
|
||||
<%def name="print_fieldset(fields, nested, cursor, plen)">
|
||||
% for f in fields:
|
||||
% if f.is_extension():
|
||||
if (plen <= 0)
|
||||
return;
|
||||
printf("(${','.join(f.extension_names)}):");
|
||||
% endif
|
||||
% if f.len_field_of:
|
||||
${f.type_obj.type_name()} ${f.name} = fromwire_${f.type_obj.name}(${cursor}, ${plen});${truncate_check(nested)} <% continue %> \
|
||||
% endif
|
||||
|
||||
@@ -38,7 +38,7 @@ def next_line(args, lines):
|
||||
|
||||
# Class definitions, to keep things classy
|
||||
class Field(object):
|
||||
def __init__(self, name, type_obj, extension=False,
|
||||
def __init__(self, name, type_obj, extensions=[],
|
||||
field_comments=[], optional=False):
|
||||
self.name = name
|
||||
self.type_obj = type_obj
|
||||
@@ -46,7 +46,7 @@ class Field(object):
|
||||
self.len_field_of = None
|
||||
self.len_field = None
|
||||
|
||||
self.is_extension = extension
|
||||
self.extension_names = extensions
|
||||
self.is_optional = optional
|
||||
self.field_comments = field_comments
|
||||
|
||||
@@ -79,7 +79,7 @@ class Field(object):
|
||||
return self.is_optional
|
||||
|
||||
def is_extension(self):
|
||||
return self.is_extension
|
||||
return bool(self.extension_names)
|
||||
|
||||
def size(self):
|
||||
if self.count:
|
||||
@@ -119,15 +119,11 @@ class Field(object):
|
||||
class FieldSet(object):
|
||||
def __init__(self):
|
||||
self.fields = OrderedDict()
|
||||
self.extension_fields = False
|
||||
self.len_fields = {}
|
||||
|
||||
def add_data_field(self, field_name, type_obj, count=1,
|
||||
is_extension=[], comments=[], optional=False):
|
||||
if is_extension:
|
||||
self.extension_fields = True
|
||||
|
||||
field = Field(field_name, type_obj, extension=bool(is_extension),
|
||||
extensions=[], comments=[], optional=False):
|
||||
field = Field(field_name, type_obj, extensions=extensions,
|
||||
field_comments=comments, optional=optional)
|
||||
if bool(count):
|
||||
try:
|
||||
@@ -248,9 +244,10 @@ class Type(FieldSet):
|
||||
return name, False
|
||||
|
||||
def add_data_field(self, field_name, type_obj, count=1,
|
||||
is_extension=[], comments=[], optional=False):
|
||||
extensions=[], comments=[], optional=False):
|
||||
FieldSet.add_data_field(self, field_name, type_obj, count,
|
||||
is_extension, comments=comments, optional=optional)
|
||||
extensions=extensions,
|
||||
comments=comments, optional=optional)
|
||||
if type_obj.name not in self.depends_on:
|
||||
self.depends_on[type_obj.name] = type_obj
|
||||
|
||||
@@ -459,7 +456,13 @@ class Master(object):
|
||||
subtypes = self.get_ordered_subtypes()
|
||||
stuff['structs'] = subtypes + self.tlv_messages()
|
||||
stuff['tlvs'] = self.tlvs
|
||||
stuff['messages'] = list(self.messages.values()) + list(self.extension_msgs.values())
|
||||
|
||||
# We leave out extension messages in the printing pages. Any extension
|
||||
# fields will get printed under the 'original' message, if present
|
||||
if options.print_wire:
|
||||
stuff['messages'] = list(self.messages.values())
|
||||
else:
|
||||
stuff['messages'] = list(self.messages.values()) + list(self.extension_msgs.values())
|
||||
stuff['subtypes'] = subtypes
|
||||
|
||||
print(template.render(**stuff), file=output)
|
||||
@@ -550,6 +553,9 @@ def main(options, args=None, output=sys.stdout, lines=None):
|
||||
# we'll refer to 'optional' message fields as 'extensions')
|
||||
#
|
||||
if bool(tokens[5:]): # is an extension field
|
||||
if optional:
|
||||
raise ValueError("Extension fields cannot be optional. {}:{}"
|
||||
.format(ln, line))
|
||||
extension_name = "{}_{}".format(tokens[1], tokens[5])
|
||||
orig_msg = msg
|
||||
msg = master.find_message(extension_name)
|
||||
@@ -558,6 +564,13 @@ def main(options, args=None, output=sys.stdout, lines=None):
|
||||
msg.enumname = msg.name
|
||||
msg.name = extension_name
|
||||
master.add_extension_msg(msg.name, msg)
|
||||
# If this is a print_wire page, add the extension fields to the
|
||||
# original message, so we can print them if present.
|
||||
if options.print_wire:
|
||||
orig_msg.add_data_field(tokens[2], type_obj, count=count,
|
||||
extensions=tokens[5:],
|
||||
comments=list(comment_set),
|
||||
optional=optional)
|
||||
|
||||
if collapse:
|
||||
count = 1
|
||||
|
||||
@@ -43,6 +43,7 @@ msgdata,test_msg,test_sbt_varlen_varsize,subtype_varlen_varsize,
|
||||
msgdata,test_msg,test_sbt_arrays,subtype_arrays,
|
||||
# test extension fields
|
||||
msgdata,test_msg,extension_1,test_features,,option_short_id
|
||||
msgdata,test_msg,extension_2,test_short_id,,option_one,option_two
|
||||
|
||||
msgtype,test_tlv1,2
|
||||
msgdata,test_tlv1,test_struct,test_short_id,
|
||||
|
||||
Reference in New Issue
Block a user