From ff5dfb1cc4c7c96ddded2aa57eb5be9066f60f6c Mon Sep 17 00:00:00 2001 From: Christian Decker Date: Wed, 8 May 2019 13:17:14 +0200 Subject: [PATCH] pylightning: Clean up the argument binding We had a bit of a hand-woven mess in there, trying to inject the extra arguments in the correct places. We now instead treat positional and keyword calls separately and can go back to using the builtin argument binding again. Signed-off-by: Christian Decker --- contrib/pylightning/lightning/plugin.py | 132 +++++++++++------------ contrib/pylightning/tests/test_plugin.py | 127 +++++++++++++++++++++- 2 files changed, 189 insertions(+), 70 deletions(-) diff --git a/contrib/pylightning/lightning/plugin.py b/contrib/pylightning/lightning/plugin.py index bd8be1c3d..4440f3185 100644 --- a/contrib/pylightning/lightning/plugin.py +++ b/contrib/pylightning/lightning/plugin.py @@ -272,79 +272,73 @@ class Plugin(object): return f return decorator - def _exec_func(self, func, request): - params = request.params + @staticmethod + def _coerce_arguments(func, ba): + args = OrderedDict() + for key, val in ba.arguments.items(): + annotation = func.__annotations__.get(key) + if annotation == Millisatoshi: + args[key] = Millisatoshi(val) + else: + args[key] = val + ba.arguments = args + return ba + + def _bind_pos(self, func, params, request): + """Positional binding of parameters + """ + assert(isinstance(params, list)) sig = inspect.signature(func) - arguments = OrderedDict() - for name, value in sig.parameters.items(): - arguments[name] = inspect._empty + # Collect injections so we can sort them and insert them in the right + # order later. If we don't apply inject them in increasing order we + # might shift away an earlier injection. + injections = [] + if 'plugin' in sig.parameters: + pos = list(sig.parameters.keys()).index('plugin') + injections.append((pos, self)) + if 'request' in sig.parameters: + pos = list(sig.parameters.keys()).index('request') + injections.append((pos, request)) + injections = sorted(injections) + for pos, val in injections: + params = params[:pos] + [val] + params[pos:] - # Fill in any injected parameters - if 'plugin' in arguments: - arguments['plugin'] = self - - if 'request' in arguments: - arguments['request'] = request - - args = [] - kwargs = {} - # Now zip the provided arguments and the prefilled a together - if isinstance(params, dict): - for k, v in params.items(): - if k in arguments: - # Explicitly (try to) interpret as Millisatoshi if annotated - if func.__annotations__.get(k) == Millisatoshi: - arguments[k] = Millisatoshi(v) - else: - arguments[k] = v - else: - kwargs[k] = v - else: - pos = 0 - for k, v in arguments.items(): - # Skip already assigned args and special catch-all args - if v is not inspect._empty or k in ['args', 'kwargs']: - continue - - if pos < len(params): - # Apply positional args if we have them - if func.__annotations__.get(k) == Millisatoshi: - arguments[k] = Millisatoshi(params[pos]) - else: - arguments[k] = params[pos] - elif sig.parameters[k].default is inspect.Signature.empty: - # This is a positional arg with no value passed - raise TypeError("Missing required parameter: %s" % sig.parameters[k]) - else: - # For the remainder apply default args - arguments[k] = sig.parameters[k].default - pos += 1 - if len(arguments) < len(params): - args = params[len(arguments):] - - if 'kwargs' in arguments: - arguments['kwargs'] = kwargs - elif len(kwargs) > 0: - raise TypeError("Extra arguments given: {kwargs}".format(kwargs=kwargs)) - - if 'args' in arguments: - arguments['args'] = args - elif len(args) > 0: - raise TypeError("Extra arguments given: {args}".format(args=args)) - - missing = [k for k, v in arguments.items() if v is inspect._empty] - if missing: - raise TypeError("Missing positional arguments ({given} given, " - "expected {expected}): {missing}".format( - missing=", ".join(missing), - given=len(arguments) - len(missing), - expected=len(arguments) - )) - - ba = sig.bind(**arguments) + ba = sig.bind(*params) + self._coerce_arguments(func, ba) ba.apply_defaults() - return func(*ba.args, **ba.kwargs) + return ba + + def _bind_kwargs(self, func, params, request): + """Keyword based binding of parameters + """ + assert(isinstance(params, dict)) + sig = inspect.signature(func) + + # Inject additional parameters if they are in the signature. + if 'plugin' in sig.parameters: + params['plugin'] = self + elif 'plugin' in params: + del params['plugin'] + if 'request' in sig.parameters: + params['request'] = request + elif 'request' in params: + del params['request'] + + ba = sig.bind(**params) + self._coerce_arguments(func, ba) + return ba + + def _exec_func(self, func, request): + params = request.params + if isinstance(params, list): + ba = self._bind_pos(func, params, request) + return func(*ba.args, **ba.kwargs) + elif isinstance(params, dict): + ba = self._bind_kwargs(func, params, request) + return func(*ba.args, **ba.kwargs) + else: + raise TypeError("Parameters to function call must be either a dict or a list.") def _dispatch_request(self, request): name = request.method diff --git a/contrib/pylightning/tests/test_plugin.py b/contrib/pylightning/tests/test_plugin.py index 60f642874..ff7d02a25 100644 --- a/contrib/pylightning/tests/test_plugin.py +++ b/contrib/pylightning/tests/test_plugin.py @@ -1,5 +1,5 @@ from lightning import Plugin -from lightning.plugin import Request +from lightning.plugin import Request, Millisatoshi import itertools import pytest @@ -237,3 +237,128 @@ def test_positional_inject(): method='func', params=[]) ) + + +def test_bind_pos(): + p = Plugin(autopatch=False) + + req = object() + params = ['World'] + + def test1(name): + assert name == 'World' + bound = p._bind_pos(test1, params, req) + test1(*bound.args, **bound.kwargs) + + def test2(name, plugin): + assert name == 'World' + assert plugin == p + bound = p._bind_pos(test2, params, req) + test2(*bound.args, **bound.kwargs) + + def test3(plugin, name): + assert name == 'World' + assert plugin == p + bound = p._bind_pos(test3, params, req) + test3(*bound.args, **bound.kwargs) + + def test4(plugin, name, request): + assert name == 'World' + assert plugin == p + assert request == req + bound = p._bind_pos(test4, params, req) + test4(*bound.args, **bound.kwargs) + + def test5(request, name, plugin): + assert name == 'World' + assert plugin == p + assert request == req + bound = p._bind_pos(test5, params, req) + test5(*bound.args, **bound.kwargs) + + def test6(request, name, plugin, answer=42): + assert name == 'World' + assert plugin == p + assert request == req + assert answer == 42 + bound = p._bind_pos(test6, params, req) + test6(*bound.args, **bound.kwargs) + + # Now mix in a catch-all parameter that needs to be assigned + def test6(request, name, plugin, *args, **kwargs): + assert name == 'World' + assert plugin == p + assert request == req + assert args == (42,) + assert kwargs == {} + bound = p._bind_pos(test6, params + [42], req) + test6(*bound.args, **bound.kwargs) + + +def test_bind_kwargs(): + p = Plugin(autopatch=False) + + req = object() + params = {'name': 'World'} + + def test1(name): + assert name == 'World' + bound = p._bind_kwargs(test1, params, req) + test1(*bound.args, **bound.kwargs) + + def test2(name, plugin): + assert name == 'World' + assert plugin == p + bound = p._bind_kwargs(test2, params, req) + test2(*bound.args, **bound.kwargs) + + def test3(plugin, name): + assert name == 'World' + assert plugin == p + bound = p._bind_kwargs(test3, params, req) + test3(*bound.args, **bound.kwargs) + + def test4(plugin, name, request): + assert name == 'World' + assert plugin == p + assert request == req + bound = p._bind_kwargs(test4, params, req) + test4(*bound.args, **bound.kwargs) + + def test5(request, name, plugin): + assert name == 'World' + assert plugin == p + assert request == req + bound = p._bind_kwargs(test5, params, req) + test5(*bound.args, **bound.kwargs) + + def test6(request, name, plugin, answer=42): + assert name == 'World' + assert plugin == p + assert request == req + assert answer == 42 + bound = p._bind_kwargs(test6, params, req) + test6(*bound.args, **bound.kwargs) + + # Now mix in a catch-all parameter that needs to be assigned + def test6(request, name, plugin, *args, **kwargs): + assert name == 'World' + assert plugin == p + assert request == req + assert args == () + assert kwargs == {'answer': 42} + bound = p._bind_kwargs(test6, {'name': 'World', 'answer': 42}, req) + test6(*bound.args, **bound.kwargs) + + +def test_argument_coercion(): + p = Plugin(autopatch=False) + + def test1(msat: Millisatoshi): + assert isinstance(msat, Millisatoshi) + + ba = p._bind_kwargs(test1, {"msat": "100msat"}, None) + test1(*ba.args) + + ba = p._bind_pos(test1, ["100msat"], None) + test1(*ba.args, **ba.kwargs)