From 168bc547001f4abc7a6f2cc96fa9517e53b2669f Mon Sep 17 00:00:00 2001 From: Christian Decker Date: Fri, 31 Mar 2023 17:52:52 +0200 Subject: [PATCH] msggen: Add VersioningCheck This is a visitor that ensures every new field has at least an `added` field, and that we don't change the `added` or `deprecated` annotation after the fact. --- contrib/msggen/msggen/__main__.py | 3 +++ contrib/msggen/msggen/checks.py | 36 +++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+) create mode 100644 contrib/msggen/msggen/checks.py diff --git a/contrib/msggen/msggen/__main__.py b/contrib/msggen/msggen/__main__.py index 98ba0c526..e8a5ee49a 100644 --- a/contrib/msggen/msggen/__main__.py +++ b/contrib/msggen/msggen/__main__.py @@ -73,6 +73,9 @@ def run(rootdir: Path): p.apply(service) OptionalPatch().apply(service) + # Run the checks here, we should eventually split that out to a + # separate subcommand + VersioningCheck().check(service) generator_chain = GeneratorChain() add_handler_gen_grpc(generator_chain, meta) diff --git a/contrib/msggen/msggen/checks.py b/contrib/msggen/msggen/checks.py new file mode 100644 index 000000000..1f97d87a2 --- /dev/null +++ b/contrib/msggen/msggen/checks.py @@ -0,0 +1,36 @@ +from abc import ABC +from msggen import model + + +class Check(ABC): + """A check is a visitor that throws exceptions on inconsistencies. + + """ + def visit(self, field: model.Field) -> None: + pass + + def check(self, service: model.Service) -> None: + def recurse(f: model.Field): + # First recurse if we have further type definitions + if isinstance(f, model.ArrayField): + self.visit(f.itemtype) + recurse(f.itemtype) + elif isinstance(f, model.CompositeField): + for c in f.fields: + self.visit(c) + recurse(c) + # Now visit ourselves + self.visit(f) + for m in service.methods: + recurse(m.request) + recurse(m.response) + + +class VersioningCheck(Check): + """Check that all schemas have the `added` and `deprecated` annotations. + """ + def visit(self, f: model.Field) -> None: + if not hasattr(f, "added"): + raise ValueError(f"Field {f.path} is missing the `added` annotation") + if not hasattr(f, "deprecated"): + raise ValueError(f"Field {f.path} is missing the `deprecated` annotation")