mirror of
https://github.com/aljazceru/lightning.git
synced 2025-12-19 23:24:27 +01:00
pyln: Add type-annotations to plugin.py
This should help users that have type-checking enabled.
This commit is contained in:
committed by
Rusty Russell
parent
d27da4d152
commit
49ec800a07
@@ -1,10 +1,12 @@
|
|||||||
|
from .lightning import LightningRpc, Millisatoshi
|
||||||
from binascii import hexlify
|
from binascii import hexlify
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from .lightning import LightningRpc, Millisatoshi
|
|
||||||
from threading import RLock
|
from threading import RLock
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
|
import io
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@@ -12,6 +14,16 @@ import re
|
|||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
|
# Notice that this definition is incomplete as it only checks the
|
||||||
|
# top-level. Arrays and Dicts could contain types that aren't encodeable. This
|
||||||
|
# limitation stems from the fact that recursive types are not really supported
|
||||||
|
# yet.
|
||||||
|
JSONType = Union[str, int, float, bool, None, Dict[str, Any], List[Any]]
|
||||||
|
|
||||||
|
# Yes, decorators are weird...
|
||||||
|
NoneDecoratorType = Callable[..., Callable[..., None]]
|
||||||
|
JsonDecoratorType = Callable[..., Callable[..., JSONType]]
|
||||||
|
|
||||||
|
|
||||||
class MethodType(Enum):
|
class MethodType(Enum):
|
||||||
RPCMETHOD = 0
|
RPCMETHOD = 0
|
||||||
@@ -32,8 +44,10 @@ class Method(object):
|
|||||||
- RPC exposed by RPC passthrough
|
- RPC exposed by RPC passthrough
|
||||||
- HOOK registered to be called synchronously by lightningd
|
- HOOK registered to be called synchronously by lightningd
|
||||||
"""
|
"""
|
||||||
def __init__(self, name, func, mtype=MethodType.RPCMETHOD, category=None,
|
def __init__(self, name: str, func: Callable[..., JSONType],
|
||||||
desc=None, long_desc=None, deprecated=False):
|
mtype: MethodType = MethodType.RPCMETHOD,
|
||||||
|
category: str = None, desc: str = None,
|
||||||
|
long_desc: str = None, deprecated: bool = False):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.func = func
|
self.func = func
|
||||||
self.mtype = mtype
|
self.mtype = mtype
|
||||||
@@ -47,7 +61,8 @@ class Method(object):
|
|||||||
class Request(dict):
|
class Request(dict):
|
||||||
"""A request object that wraps params and allows async return
|
"""A request object that wraps params and allows async return
|
||||||
"""
|
"""
|
||||||
def __init__(self, plugin, req_id, method, params, background=False):
|
def __init__(self, plugin: 'Plugin', req_id: Optional[int], method: str,
|
||||||
|
params: Any, background: bool = False):
|
||||||
self.method = method
|
self.method = method
|
||||||
self.params = params
|
self.params = params
|
||||||
self.background = background
|
self.background = background
|
||||||
@@ -55,15 +70,19 @@ class Request(dict):
|
|||||||
self.state = RequestState.PENDING
|
self.state = RequestState.PENDING
|
||||||
self.id = req_id
|
self.id = req_id
|
||||||
|
|
||||||
def getattr(self, key):
|
def getattr(self, key: str) -> Union[Method, Any, int]:
|
||||||
if key == "params":
|
if key == "params":
|
||||||
return self.params
|
return self.params
|
||||||
elif key == "id":
|
elif key == "id":
|
||||||
return self.id
|
return self.id
|
||||||
elif key == "method":
|
elif key == "method":
|
||||||
return self.method
|
return self.method
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
'Cannot get attribute "{key}" on Request'.format(key=key)
|
||||||
|
)
|
||||||
|
|
||||||
def set_result(self, result):
|
def set_result(self, result: Any) -> None:
|
||||||
if self.state != RequestState.PENDING:
|
if self.state != RequestState.PENDING:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Cannot set the result of a request that is not pending, "
|
"Cannot set the result of a request that is not pending, "
|
||||||
@@ -75,7 +94,7 @@ class Request(dict):
|
|||||||
'result': self.result
|
'result': self.result
|
||||||
})
|
})
|
||||||
|
|
||||||
def set_exception(self, exc):
|
def set_exception(self, exc: Exception) -> None:
|
||||||
if self.state != RequestState.PENDING:
|
if self.state != RequestState.PENDING:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Cannot set the exception of a request that is not pending, "
|
"Cannot set the exception of a request that is not pending, "
|
||||||
@@ -93,7 +112,7 @@ class Request(dict):
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
def _write_result(self, result):
|
def _write_result(self, result: dict) -> None:
|
||||||
self.plugin._write_locked(result)
|
self.plugin._write_locked(result)
|
||||||
|
|
||||||
|
|
||||||
@@ -126,12 +145,20 @@ class Plugin(object):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, stdout=None, stdin=None, autopatch=True, dynamic=True,
|
def __init__(self, stdout: Optional[io.TextIOBase] = None,
|
||||||
init_features=None, node_features=None, invoice_features=None):
|
stdin: Optional[io.TextIOBase] = None, autopatch: bool = True,
|
||||||
self.methods = {'init': Method('init', self._init, MethodType.RPCMETHOD)}
|
dynamic: bool = True,
|
||||||
self.options = {}
|
init_features: Optional[Union[int, str, bytes]] = None,
|
||||||
|
node_features: Optional[Union[int, str, bytes]] = None,
|
||||||
|
invoice_features: Optional[Union[int, str, bytes]] = None):
|
||||||
|
self.methods = {
|
||||||
|
'init': Method('init', self._init, MethodType.RPCMETHOD)
|
||||||
|
}
|
||||||
|
|
||||||
def convert_featurebits(bits):
|
self.options: Dict[str, Dict[str, Any]] = {}
|
||||||
|
|
||||||
|
def convert_featurebits(
|
||||||
|
bits: Optional[Union[int, str, bytes]]) -> Optional[str]:
|
||||||
"""Convert the featurebits into the bytes required to hexencode.
|
"""Convert the featurebits into the bytes required to hexencode.
|
||||||
"""
|
"""
|
||||||
if bits is None:
|
if bits is None:
|
||||||
@@ -149,7 +176,9 @@ class Plugin(object):
|
|||||||
return hexlify(bits).decode('ASCII')
|
return hexlify(bits).decode('ASCII')
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("Could not convert featurebits to hex-encoded string")
|
raise ValueError(
|
||||||
|
"Could not convert featurebits to hex-encoded string"
|
||||||
|
)
|
||||||
|
|
||||||
self.featurebits = {
|
self.featurebits = {
|
||||||
'init': convert_featurebits(init_features),
|
'init': convert_featurebits(init_features),
|
||||||
@@ -158,7 +187,7 @@ class Plugin(object):
|
|||||||
}
|
}
|
||||||
|
|
||||||
# A dict from topics to handler functions
|
# A dict from topics to handler functions
|
||||||
self.subscriptions = {}
|
self.subscriptions: Dict[str, Callable[..., None]] = {}
|
||||||
|
|
||||||
if not stdout:
|
if not stdout:
|
||||||
self.stdout = sys.stdout
|
self.stdout = sys.stdout
|
||||||
@@ -172,17 +201,21 @@ class Plugin(object):
|
|||||||
monkey_patch(self, stdout=True, stderr=True)
|
monkey_patch(self, stdout=True, stderr=True)
|
||||||
|
|
||||||
self.add_method("getmanifest", self._getmanifest, background=False)
|
self.add_method("getmanifest", self._getmanifest, background=False)
|
||||||
self.rpc_filename = None
|
self.rpc_filename: Optional[str] = None
|
||||||
self.lightning_dir = None
|
self.lightning_dir: Optional[str] = None
|
||||||
self.rpc = None
|
self.rpc: Optional[LightningRpc] = None
|
||||||
self.startup = True
|
self.startup = True
|
||||||
self.dynamic = dynamic
|
self.dynamic = dynamic
|
||||||
self.child_init = None
|
self.child_init: Optional[Callable[..., None]] = None
|
||||||
|
|
||||||
self.write_lock = RLock()
|
self.write_lock = RLock()
|
||||||
|
|
||||||
def add_method(self, name, func, background=False, category=None, desc=None,
|
def add_method(self, name: str, func: Callable[..., Any],
|
||||||
long_desc=None, deprecated=False):
|
background: bool = False,
|
||||||
|
category: Optional[str] = None,
|
||||||
|
desc: Optional[str] = None,
|
||||||
|
long_desc: Optional[str] = None,
|
||||||
|
deprecated: bool = False) -> None:
|
||||||
"""Add a plugin method to the dispatch table.
|
"""Add a plugin method to the dispatch table.
|
||||||
|
|
||||||
The function will be expected at call time (see `_dispatch`)
|
The function will be expected at call time (see `_dispatch`)
|
||||||
@@ -221,11 +254,15 @@ class Plugin(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Register the function with the name
|
# Register the function with the name
|
||||||
method = Method(name, func, MethodType.RPCMETHOD, category, desc, long_desc, deprecated)
|
method = Method(
|
||||||
|
name, func, MethodType.RPCMETHOD, category, desc, long_desc,
|
||||||
|
deprecated
|
||||||
|
)
|
||||||
|
|
||||||
method.background = background
|
method.background = background
|
||||||
self.methods[name] = method
|
self.methods[name] = method
|
||||||
|
|
||||||
def add_subscription(self, topic, func):
|
def add_subscription(self, topic: str, func: Callable[..., None]) -> None:
|
||||||
"""Add a subscription to our list of subscriptions.
|
"""Add a subscription to our list of subscriptions.
|
||||||
|
|
||||||
A subscription is an association between a topic and a handler
|
A subscription is an association between a topic and a handler
|
||||||
@@ -243,9 +280,9 @@ class Plugin(object):
|
|||||||
"Topic {} already has a handler".format(topic)
|
"Topic {} already has a handler".format(topic)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Make sure the notification callback has a **kwargs argument so that it
|
# Make sure the notification callback has a **kwargs argument so that
|
||||||
# doesn't break if we add more arguments to the call later on. Issue a
|
# it doesn't break if we add more arguments to the call later
|
||||||
# warning if it does not.
|
# on. Issue a warning if it does not.
|
||||||
s = inspect.signature(func)
|
s = inspect.signature(func)
|
||||||
kinds = [p.kind for p in s.parameters.values()]
|
kinds = [p.kind for p in s.parameters.values()]
|
||||||
if inspect.Parameter.VAR_KEYWORD not in kinds:
|
if inspect.Parameter.VAR_KEYWORD not in kinds:
|
||||||
@@ -257,16 +294,20 @@ class Plugin(object):
|
|||||||
|
|
||||||
self.subscriptions[topic] = func
|
self.subscriptions[topic] = func
|
||||||
|
|
||||||
def subscribe(self, topic):
|
def subscribe(self, topic: str) -> NoneDecoratorType:
|
||||||
"""Function decorator to register a notification handler.
|
"""Function decorator to register a notification handler.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def decorator(f):
|
# Yes, decorator type annotations are just weird, don't think too much
|
||||||
|
# about it...
|
||||||
|
def decorator(f: Callable[..., None]) -> Callable[..., None]:
|
||||||
self.add_subscription(topic, f)
|
self.add_subscription(topic, f)
|
||||||
return f
|
return f
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
def add_option(self, name, default, description, opt_type="string",
|
def add_option(self, name: str, default: Optional[str],
|
||||||
deprecated=False):
|
description: Optional[str],
|
||||||
|
opt_type: str = "string", deprecated: bool = False) -> None:
|
||||||
"""Add an option that we'd like to register with lightningd.
|
"""Add an option that we'd like to register with lightningd.
|
||||||
|
|
||||||
Needs to be called before `Plugin.run`, otherwise we might not
|
Needs to be called before `Plugin.run`, otherwise we might not
|
||||||
@@ -279,7 +320,9 @@ class Plugin(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if opt_type not in ["string", "int", "bool", "flag"]:
|
if opt_type not in ["string", "int", "bool", "flag"]:
|
||||||
raise ValueError('{} not in supported type set (string, int, bool, flag)')
|
raise ValueError(
|
||||||
|
'{} not in supported type set (string, int, bool, flag)'
|
||||||
|
)
|
||||||
|
|
||||||
self.options[name] = {
|
self.options[name] = {
|
||||||
'name': name,
|
'name': name,
|
||||||
@@ -290,7 +333,8 @@ class Plugin(object):
|
|||||||
'deprecated': deprecated,
|
'deprecated': deprecated,
|
||||||
}
|
}
|
||||||
|
|
||||||
def add_flag_option(self, name, description, deprecated=False):
|
def add_flag_option(self, name: str, description: str,
|
||||||
|
deprecated: bool = False) -> None:
|
||||||
"""Add a flag option that we'd like to register with lightningd.
|
"""Add a flag option that we'd like to register with lightningd.
|
||||||
|
|
||||||
Needs to be called before `Plugin.run`, otherwise we might not
|
Needs to be called before `Plugin.run`, otherwise we might not
|
||||||
@@ -300,7 +344,7 @@ class Plugin(object):
|
|||||||
self.add_option(name, None, description, opt_type="flag",
|
self.add_option(name, None, description, opt_type="flag",
|
||||||
deprecated=deprecated)
|
deprecated=deprecated)
|
||||||
|
|
||||||
def get_option(self, name):
|
def get_option(self, name: str) -> str:
|
||||||
if name not in self.options:
|
if name not in self.options:
|
||||||
raise ValueError("No option with name {} registered".format(name))
|
raise ValueError("No option with name {} registered".format(name))
|
||||||
|
|
||||||
@@ -309,31 +353,42 @@ class Plugin(object):
|
|||||||
else:
|
else:
|
||||||
return self.options[name]['default']
|
return self.options[name]['default']
|
||||||
|
|
||||||
def async_method(self, method_name, category=None, desc=None, long_desc=None, deprecated=False):
|
def async_method(self, method_name: str, category: Optional[str] = None,
|
||||||
|
desc: Optional[str] = None,
|
||||||
|
long_desc: Optional[str] = None,
|
||||||
|
deprecated: bool = False) -> NoneDecoratorType:
|
||||||
"""Decorator to add an async plugin method to the dispatch table.
|
"""Decorator to add an async plugin method to the dispatch table.
|
||||||
|
|
||||||
Internally uses add_method.
|
Internally uses add_method.
|
||||||
"""
|
"""
|
||||||
def decorator(f):
|
def decorator(f: Callable[..., None]) -> Callable[..., None]:
|
||||||
self.add_method(method_name, f, background=True, category=category,
|
self.add_method(method_name, f, background=True, category=category,
|
||||||
desc=desc, long_desc=long_desc,
|
desc=desc, long_desc=long_desc,
|
||||||
deprecated=deprecated)
|
deprecated=deprecated)
|
||||||
return f
|
return f
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
def method(self, method_name, category=None, desc=None, long_desc=None, deprecated=False):
|
def method(self, method_name: str, category: Optional[str] = None,
|
||||||
|
desc: Optional[str] = None,
|
||||||
|
long_desc: Optional[str] = None,
|
||||||
|
deprecated: bool = False) -> JsonDecoratorType:
|
||||||
"""Decorator to add a plugin method to the dispatch table.
|
"""Decorator to add a plugin method to the dispatch table.
|
||||||
|
|
||||||
Internally uses add_method.
|
Internally uses add_method.
|
||||||
"""
|
"""
|
||||||
def decorator(f):
|
def decorator(f: Callable[..., JSONType]) -> Callable[..., JSONType]:
|
||||||
self.add_method(method_name, f, background=False, category=category,
|
self.add_method(method_name,
|
||||||
desc=desc, long_desc=long_desc,
|
f,
|
||||||
|
background=False,
|
||||||
|
category=category,
|
||||||
|
desc=desc,
|
||||||
|
long_desc=long_desc,
|
||||||
deprecated=deprecated)
|
deprecated=deprecated)
|
||||||
return f
|
return f
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
def add_hook(self, name, func, background=False):
|
def add_hook(self, name: str, func: Callable[..., JSONType],
|
||||||
|
background: bool = False) -> None:
|
||||||
"""Register a hook that is called synchronously by lightningd on events
|
"""Register a hook that is called synchronously by lightningd on events
|
||||||
"""
|
"""
|
||||||
if name in self.methods:
|
if name in self.methods:
|
||||||
@@ -357,40 +412,47 @@ class Plugin(object):
|
|||||||
method.background = background
|
method.background = background
|
||||||
self.methods[name] = method
|
self.methods[name] = method
|
||||||
|
|
||||||
def hook(self, method_name):
|
def hook(self, method_name: str) -> JsonDecoratorType:
|
||||||
"""Decorator to add a plugin hook to the dispatch table.
|
"""Decorator to add a plugin hook to the dispatch table.
|
||||||
|
|
||||||
Internally uses add_hook.
|
Internally uses add_hook.
|
||||||
"""
|
"""
|
||||||
def decorator(f):
|
def decorator(f: Callable[..., JSONType]) -> Callable[..., JSONType]:
|
||||||
self.add_hook(method_name, f, background=False)
|
self.add_hook(method_name, f, background=False)
|
||||||
return f
|
return f
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
def async_hook(self, method_name):
|
def async_hook(self, method_name: str) -> NoneDecoratorType:
|
||||||
"""Decorator to add an async plugin hook to the dispatch table.
|
"""Decorator to add an async plugin hook to the dispatch table.
|
||||||
|
|
||||||
Internally uses add_hook.
|
Internally uses add_hook.
|
||||||
"""
|
"""
|
||||||
def decorator(f):
|
def decorator(f: Callable[..., None]) -> Callable[..., None]:
|
||||||
self.add_hook(method_name, f, background=True)
|
self.add_hook(method_name, f, background=True)
|
||||||
return f
|
return f
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
def init(self, *args, **kwargs):
|
def init(self) -> NoneDecoratorType:
|
||||||
"""Decorator to add a function called after plugin initialization
|
"""Decorator to add a function called after plugin initialization
|
||||||
"""
|
"""
|
||||||
def decorator(f):
|
def decorator(f: Callable[..., None]) -> Callable[..., None]:
|
||||||
if self.child_init is not None:
|
if self.child_init is not None:
|
||||||
raise ValueError('The @plugin.init decorator should only be used once')
|
raise ValueError(
|
||||||
|
'The @plugin.init decorator should only be used once'
|
||||||
|
)
|
||||||
self.child_init = f
|
self.child_init = f
|
||||||
return f
|
return f
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _coerce_arguments(func, ba):
|
def _coerce_arguments(
|
||||||
|
func: Callable[..., Any],
|
||||||
|
ba: inspect.BoundArguments) -> inspect.BoundArguments:
|
||||||
args = OrderedDict()
|
args = OrderedDict()
|
||||||
annotations = func.__annotations__ if hasattr(func, "__annotations__") else {}
|
annotations = {}
|
||||||
|
if hasattr(func, "__annotations__"):
|
||||||
|
annotations = func.__annotations__
|
||||||
|
|
||||||
for key, val in ba.arguments.items():
|
for key, val in ba.arguments.items():
|
||||||
annotation = annotations.get(key, None)
|
annotation = annotations.get(key, None)
|
||||||
if annotation is not None and annotation == Millisatoshi:
|
if annotation is not None and annotation == Millisatoshi:
|
||||||
@@ -400,7 +462,8 @@ class Plugin(object):
|
|||||||
ba.arguments = args
|
ba.arguments = args
|
||||||
return ba
|
return ba
|
||||||
|
|
||||||
def _bind_pos(self, func, params, request):
|
def _bind_pos(self, func: Callable[..., Any], params: List[str],
|
||||||
|
request: Request) -> inspect.BoundArguments:
|
||||||
"""Positional binding of parameters
|
"""Positional binding of parameters
|
||||||
"""
|
"""
|
||||||
assert(isinstance(params, list))
|
assert(isinstance(params, list))
|
||||||
@@ -409,7 +472,7 @@ class Plugin(object):
|
|||||||
# Collect injections so we can sort them and insert them in the right
|
# Collect injections so we can sort them and insert them in the right
|
||||||
# order later. If we don't apply inject them in increasing order we
|
# order later. If we don't apply inject them in increasing order we
|
||||||
# might shift away an earlier injection.
|
# might shift away an earlier injection.
|
||||||
injections = []
|
injections: List[Tuple[int, Any]] = []
|
||||||
if 'plugin' in sig.parameters:
|
if 'plugin' in sig.parameters:
|
||||||
pos = list(sig.parameters.keys()).index('plugin')
|
pos = list(sig.parameters.keys()).index('plugin')
|
||||||
injections.append((pos, self))
|
injections.append((pos, self))
|
||||||
@@ -425,7 +488,8 @@ class Plugin(object):
|
|||||||
ba.apply_defaults()
|
ba.apply_defaults()
|
||||||
return ba
|
return ba
|
||||||
|
|
||||||
def _bind_kwargs(self, func, params, request):
|
def _bind_kwargs(self, func: Callable[..., Any], params: Dict[str, Any],
|
||||||
|
request: Request) -> inspect.BoundArguments:
|
||||||
"""Keyword based binding of parameters
|
"""Keyword based binding of parameters
|
||||||
"""
|
"""
|
||||||
assert(isinstance(params, dict))
|
assert(isinstance(params, dict))
|
||||||
@@ -445,7 +509,8 @@ class Plugin(object):
|
|||||||
self._coerce_arguments(func, ba)
|
self._coerce_arguments(func, ba)
|
||||||
return ba
|
return ba
|
||||||
|
|
||||||
def _exec_func(self, func, request):
|
def _exec_func(self, func: Callable[..., Any],
|
||||||
|
request: Request) -> JSONType:
|
||||||
params = request.params
|
params = request.params
|
||||||
if isinstance(params, list):
|
if isinstance(params, list):
|
||||||
ba = self._bind_pos(func, params, request)
|
ba = self._bind_pos(func, params, request)
|
||||||
@@ -454,9 +519,11 @@ class Plugin(object):
|
|||||||
ba = self._bind_kwargs(func, params, request)
|
ba = self._bind_kwargs(func, params, request)
|
||||||
return func(*ba.args, **ba.kwargs)
|
return func(*ba.args, **ba.kwargs)
|
||||||
else:
|
else:
|
||||||
raise TypeError("Parameters to function call must be either a dict or a list.")
|
raise TypeError(
|
||||||
|
"Parameters to function call must be either a dict or a list."
|
||||||
|
)
|
||||||
|
|
||||||
def _dispatch_request(self, request):
|
def _dispatch_request(self, request: Request) -> None:
|
||||||
name = request.method
|
name = request.method
|
||||||
|
|
||||||
if name not in self.methods:
|
if name not in self.methods:
|
||||||
@@ -487,7 +554,7 @@ class Plugin(object):
|
|||||||
request.set_exception(e)
|
request.set_exception(e)
|
||||||
self.log(traceback.format_exc())
|
self.log(traceback.format_exc())
|
||||||
|
|
||||||
def _dispatch_notification(self, request):
|
def _dispatch_notification(self, request: Request) -> None:
|
||||||
if request.method not in self.subscriptions:
|
if request.method not in self.subscriptions:
|
||||||
raise ValueError("No subscription for {name} found.".format(
|
raise ValueError("No subscription for {name} found.".format(
|
||||||
name=request.method))
|
name=request.method))
|
||||||
@@ -498,15 +565,19 @@ class Plugin(object):
|
|||||||
except Exception:
|
except Exception:
|
||||||
self.log(traceback.format_exc())
|
self.log(traceback.format_exc())
|
||||||
|
|
||||||
def _write_locked(self, obj):
|
def _write_locked(self, obj: JSONType) -> None:
|
||||||
# ensure_ascii turns UTF-8 into \uXXXX so we need to suppress that,
|
# ensure_ascii turns UTF-8 into \uXXXX so we need to suppress that,
|
||||||
# then utf8 ourselves.
|
# then utf8 ourselves.
|
||||||
s = bytes(json.dumps(obj, cls=LightningRpc.LightningJSONEncoder, ensure_ascii=False) + "\n\n", encoding='utf-8')
|
s = bytes(json.dumps(
|
||||||
|
obj,
|
||||||
|
cls=LightningRpc.LightningJSONEncoder,
|
||||||
|
ensure_ascii=False
|
||||||
|
) + "\n\n", encoding='utf-8')
|
||||||
with self.write_lock:
|
with self.write_lock:
|
||||||
self.stdout.buffer.write(s)
|
self.stdout.buffer.write(s)
|
||||||
self.stdout.flush()
|
self.stdout.flush()
|
||||||
|
|
||||||
def notify(self, method, params):
|
def notify(self, method: str, params: JSONType) -> None:
|
||||||
payload = {
|
payload = {
|
||||||
'jsonrpc': '2.0',
|
'jsonrpc': '2.0',
|
||||||
'method': method,
|
'method': method,
|
||||||
@@ -514,30 +585,35 @@ class Plugin(object):
|
|||||||
}
|
}
|
||||||
self._write_locked(payload)
|
self._write_locked(payload)
|
||||||
|
|
||||||
def log(self, message, level='info'):
|
def log(self, message: str, level: str = 'info') -> None:
|
||||||
# Split the log into multiple lines and print them
|
# Split the log into multiple lines and print them
|
||||||
# individually. Makes tracebacks much easier to read.
|
# individually. Makes tracebacks much easier to read.
|
||||||
for line in message.split('\n'):
|
for line in message.split('\n'):
|
||||||
self.notify('log', {'level': level, 'message': line})
|
self.notify('log', {'level': level, 'message': line})
|
||||||
|
|
||||||
def _parse_request(self, jsrequest):
|
def _parse_request(self, jsrequest: Dict[str, JSONType]) -> Request:
|
||||||
|
i = jsrequest.get('id', None)
|
||||||
|
if not isinstance(i, int) and i is not None:
|
||||||
|
raise ValueError('Non-integer request id "{i}"'.format(i=i))
|
||||||
|
|
||||||
request = Request(
|
request = Request(
|
||||||
plugin=self,
|
plugin=self,
|
||||||
req_id=jsrequest.get('id', None),
|
req_id=i,
|
||||||
method=jsrequest['method'],
|
method=str(jsrequest['method']),
|
||||||
params=jsrequest['params'],
|
params=jsrequest['params'],
|
||||||
background=False,
|
background=False,
|
||||||
)
|
)
|
||||||
return request
|
return request
|
||||||
|
|
||||||
def _multi_dispatch(self, msgs):
|
def _multi_dispatch(self, msgs: List[bytes]) -> bytes:
|
||||||
"""We received a couple of messages, now try to dispatch them all.
|
"""We received a couple of messages, now try to dispatch them all.
|
||||||
|
|
||||||
Returns the last partial message that was not complete yet.
|
Returns the last partial message that was not complete yet.
|
||||||
"""
|
"""
|
||||||
for payload in msgs[:-1]:
|
for payload in msgs[:-1]:
|
||||||
# Note that we use function annotations to do Millisatoshi conversions
|
# Note that we use function annotations to do Millisatoshi
|
||||||
# in _exec_func, so we don't use LightningJSONDecoder here.
|
# conversions in _exec_func, so we don't use LightningJSONDecoder
|
||||||
|
# here.
|
||||||
request = self._parse_request(json.loads(payload.decode('utf8')))
|
request = self._parse_request(json.loads(payload.decode('utf8')))
|
||||||
|
|
||||||
# If this has an 'id'-field, it's a request and returns a
|
# If this has an 'id'-field, it's a request and returns a
|
||||||
@@ -550,7 +626,7 @@ class Plugin(object):
|
|||||||
|
|
||||||
return msgs[-1]
|
return msgs[-1]
|
||||||
|
|
||||||
def run(self):
|
def run(self) -> None:
|
||||||
partial = b""
|
partial = b""
|
||||||
for l in self.stdin.buffer:
|
for l in self.stdin.buffer:
|
||||||
partial += l
|
partial += l
|
||||||
@@ -561,7 +637,7 @@ class Plugin(object):
|
|||||||
|
|
||||||
partial = self._multi_dispatch(msgs)
|
partial = self._multi_dispatch(msgs)
|
||||||
|
|
||||||
def _getmanifest(self, **kwargs):
|
def _getmanifest(self, **kwargs) -> JSONType:
|
||||||
if 'allow-deprecated-apis' in kwargs:
|
if 'allow-deprecated-apis' in kwargs:
|
||||||
self.deprecated_apis = kwargs['allow-deprecated-apis']
|
self.deprecated_apis = kwargs['allow-deprecated-apis']
|
||||||
else:
|
else:
|
||||||
@@ -582,13 +658,21 @@ class Plugin(object):
|
|||||||
doc = inspect.getdoc(method.func)
|
doc = inspect.getdoc(method.func)
|
||||||
if not doc:
|
if not doc:
|
||||||
self.log(
|
self.log(
|
||||||
'RPC method \'{}\' does not have a docstring.'.format(method.name)
|
'RPC method \'{}\' does not have a docstring.'.format(
|
||||||
|
method.name
|
||||||
|
)
|
||||||
)
|
)
|
||||||
doc = "Undocumented RPC method from a plugin."
|
doc = "Undocumented RPC method from a plugin."
|
||||||
doc = re.sub('\n+', ' ', doc)
|
doc = re.sub('\n+', ' ', doc)
|
||||||
|
|
||||||
# Handles out-of-order use of parameters like:
|
# Handles out-of-order use of parameters like:
|
||||||
# def hello_obfus(arg1, arg2, plugin, thing3, request=None, thing5='at', thing6=21)
|
#
|
||||||
|
# ```python3
|
||||||
|
#
|
||||||
|
# def hello_obfus(arg1, arg2, plugin, thing3, request=None,
|
||||||
|
# thing5='at', thing6=21)
|
||||||
|
#
|
||||||
|
# ```
|
||||||
argspec = inspect.getfullargspec(method.func)
|
argspec = inspect.getfullargspec(method.func)
|
||||||
defaults = argspec.defaults
|
defaults = argspec.defaults
|
||||||
num_defaults = len(defaults) if defaults else 0
|
num_defaults = len(defaults) if defaults else 0
|
||||||
@@ -611,7 +695,8 @@ class Plugin(object):
|
|||||||
'description': doc if not method.desc else method.desc
|
'description': doc if not method.desc else method.desc
|
||||||
})
|
})
|
||||||
if method.long_desc:
|
if method.long_desc:
|
||||||
methods[len(methods) - 1]["long_description"] = method.long_desc
|
m = methods[len(methods) - 1]
|
||||||
|
m["long_description"] = method.long_desc
|
||||||
|
|
||||||
manifest = {
|
manifest = {
|
||||||
'options': list(self.options.values()),
|
'options': list(self.options.values()),
|
||||||
@@ -628,12 +713,30 @@ class Plugin(object):
|
|||||||
|
|
||||||
return manifest
|
return manifest
|
||||||
|
|
||||||
def _init(self, options, configuration, request):
|
def _init(self, options: Dict[str, JSONType],
|
||||||
self.rpc_filename = configuration['rpc-file']
|
configuration: Dict[str, JSONType],
|
||||||
self.lightning_dir = configuration['lightning-dir']
|
request: Request) -> JSONType:
|
||||||
|
|
||||||
|
def verify_str(d: Dict[str, JSONType], key: str) -> str:
|
||||||
|
v = d.get(key)
|
||||||
|
if not isinstance(v, str):
|
||||||
|
raise ValueError("Wrong argument to init: expected {key} to be"
|
||||||
|
" a string, got {v}".format(key=key, v=v))
|
||||||
|
return v
|
||||||
|
|
||||||
|
def verify_bool(d: Dict[str, JSONType], key: str) -> bool:
|
||||||
|
v = d.get(key)
|
||||||
|
if not isinstance(v, bool):
|
||||||
|
raise ValueError("Wrong argument to init: expected {key} to be"
|
||||||
|
" a bool, got {v}".format(key=key, v=v))
|
||||||
|
return v
|
||||||
|
|
||||||
|
self.rpc_filename = verify_str(configuration, 'rpc-file')
|
||||||
|
self.lightning_dir = verify_str(configuration, 'lightning-dir')
|
||||||
|
|
||||||
path = os.path.join(self.lightning_dir, self.rpc_filename)
|
path = os.path.join(self.lightning_dir, self.rpc_filename)
|
||||||
self.rpc = LightningRpc(path)
|
self.rpc = LightningRpc(path)
|
||||||
self.startup = configuration['startup']
|
self.startup = verify_bool(configuration, 'startup')
|
||||||
for name, value in options.items():
|
for name, value in options.items():
|
||||||
self.options[name]['value'] = value
|
self.options[name]['value'] = value
|
||||||
|
|
||||||
@@ -647,18 +750,18 @@ class PluginStream(object):
|
|||||||
"""Sink that turns everything that is written to it into a notification.
|
"""Sink that turns everything that is written to it into a notification.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, plugin, level="info"):
|
def __init__(self, plugin: Plugin, level: str = "info"):
|
||||||
self.plugin = plugin
|
self.plugin = plugin
|
||||||
self.level = level
|
self.level = level
|
||||||
self.buff = ''
|
self.buff = ''
|
||||||
|
|
||||||
def write(self, payload):
|
def write(self, payload: str) -> None:
|
||||||
self.buff += payload
|
self.buff += payload
|
||||||
|
|
||||||
if len(payload) > 0 and payload[-1] == '\n':
|
if len(payload) > 0 and payload[-1] == '\n':
|
||||||
self.flush()
|
self.flush()
|
||||||
|
|
||||||
def flush(self):
|
def flush(self) -> None:
|
||||||
lines = self.buff.split('\n')
|
lines = self.buff.split('\n')
|
||||||
if len(lines) < 2:
|
if len(lines) < 2:
|
||||||
return
|
return
|
||||||
@@ -670,7 +773,8 @@ class PluginStream(object):
|
|||||||
self.buff = lines[-1]
|
self.buff = lines[-1]
|
||||||
|
|
||||||
|
|
||||||
def monkey_patch(plugin, stdout=True, stderr=False):
|
def monkey_patch(plugin: Plugin, stdout: bool = True,
|
||||||
|
stderr: bool = False) -> None:
|
||||||
"""Monkey patch stderr and stdout so we use notifications instead.
|
"""Monkey patch stderr and stdout so we use notifications instead.
|
||||||
|
|
||||||
A plugin commonly communicates with lightningd over its stdout and
|
A plugin commonly communicates with lightningd over its stdout and
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import coincurve
|
||||||
import struct
|
import struct
|
||||||
|
|
||||||
|
|
||||||
@@ -66,5 +67,79 @@ class ShortChannelId(object):
|
|||||||
def __str__(self):
|
def __str__(self):
|
||||||
return "{self.block}x{self.txnum}x{self.outnum}".format(self=self)
|
return "{self.block}x{self.txnum}x{self.outnum}".format(self=self)
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other: object) -> bool:
|
||||||
return self.block == other.block and self.txnum == other.txnum and self.outnum == other.outnum
|
if not isinstance(other, ShortChannelId):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return (
|
||||||
|
self.block == other.block
|
||||||
|
and self.txnum == other.txnum
|
||||||
|
and self.outnum == other.outnum
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Secret(object):
|
||||||
|
def __init__(self, data: bytes) -> None:
|
||||||
|
assert(len(data) == 32)
|
||||||
|
self.data = data
|
||||||
|
|
||||||
|
def to_bytes(self) -> bytes:
|
||||||
|
return self.data
|
||||||
|
|
||||||
|
def __eq__(self, other: object) -> bool:
|
||||||
|
return isinstance(other, Secret) and self.data == other.data
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return "Secret[0x{}]".format(self.data.hex())
|
||||||
|
|
||||||
|
|
||||||
|
class PrivateKey(object):
|
||||||
|
def __init__(self, rawkey) -> None:
|
||||||
|
if not isinstance(rawkey, bytes):
|
||||||
|
raise TypeError(f"rawkey must be bytes, {type(rawkey)} received")
|
||||||
|
elif len(rawkey) != 32:
|
||||||
|
raise ValueError(f"rawkey must be 32-byte long. {len(rawkey)} received")
|
||||||
|
|
||||||
|
self.rawkey = rawkey
|
||||||
|
self.key = coincurve.PrivateKey(rawkey)
|
||||||
|
|
||||||
|
def serializeCompressed(self):
|
||||||
|
return self.key.secret
|
||||||
|
|
||||||
|
def public_key(self):
|
||||||
|
return PublicKey(self.key.public_key)
|
||||||
|
|
||||||
|
|
||||||
|
class PublicKey(object):
|
||||||
|
def __init__(self, innerkey):
|
||||||
|
# We accept either 33-bytes raw keys, or an EC PublicKey as returned
|
||||||
|
# by coincurve
|
||||||
|
if isinstance(innerkey, bytes):
|
||||||
|
if innerkey[0] in [2, 3] and len(innerkey) == 33:
|
||||||
|
innerkey = coincurve.PublicKey(innerkey)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Byte keys must be 33-byte long starting from either 02 or 03"
|
||||||
|
)
|
||||||
|
|
||||||
|
elif not isinstance(innerkey, coincurve.keys.PublicKey):
|
||||||
|
raise ValueError(
|
||||||
|
"Key must either be bytes or coincurve.keys.PublicKey"
|
||||||
|
)
|
||||||
|
self.key = innerkey
|
||||||
|
|
||||||
|
def serializeCompressed(self):
|
||||||
|
return self.key.format(compressed=True)
|
||||||
|
|
||||||
|
def to_bytes(self) -> bytes:
|
||||||
|
return self.serializeCompressed()
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return "PublicKey[0x{}]".format(
|
||||||
|
self.serializeCompressed().hex()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def Keypair(object):
|
||||||
|
def __init__(self, priv, pub):
|
||||||
|
self.priv, self.pub = priv, pub
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from cryptography.hazmat.primitives import hashes
|
|||||||
from cryptography.hazmat.primitives.asymmetric import ec
|
from cryptography.hazmat.primitives.asymmetric import ec
|
||||||
from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305
|
from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305
|
||||||
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
|
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
|
||||||
|
from .primitives import Secret, PrivateKey, PublicKey
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
import coincurve
|
import coincurve
|
||||||
import os
|
import os
|
||||||
@@ -55,64 +56,6 @@ def decryptWithAD(k, n, ad, ciphertext):
|
|||||||
return chacha.decrypt(n, ciphertext, ad)
|
return chacha.decrypt(n, ciphertext, ad)
|
||||||
|
|
||||||
|
|
||||||
class PrivateKey(object):
|
|
||||||
def __init__(self, rawkey):
|
|
||||||
if not isinstance(rawkey, bytes):
|
|
||||||
raise TypeError(f"rawkey must be bytes, {type(rawkey)} received")
|
|
||||||
elif len(rawkey) != 32:
|
|
||||||
raise ValueError(f"rawkey must be 32-byte long. {len(rawkey)} received")
|
|
||||||
|
|
||||||
self.rawkey = rawkey
|
|
||||||
self.key = coincurve.PrivateKey(rawkey)
|
|
||||||
|
|
||||||
def serializeCompressed(self):
|
|
||||||
return self.key.secret
|
|
||||||
|
|
||||||
def public_key(self):
|
|
||||||
return PublicKey(self.key.public_key)
|
|
||||||
|
|
||||||
|
|
||||||
class Secret(object):
|
|
||||||
def __init__(self, raw):
|
|
||||||
assert(len(raw) == 32)
|
|
||||||
self.raw = raw
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return "Secret[0x{}]".format(self.raw.hex())
|
|
||||||
|
|
||||||
|
|
||||||
class PublicKey(object):
|
|
||||||
def __init__(self, innerkey):
|
|
||||||
# We accept either 33-bytes raw keys, or an EC PublicKey as returned
|
|
||||||
# by coincurve
|
|
||||||
if isinstance(innerkey, bytes):
|
|
||||||
if innerkey[0] in [2, 3] and len(innerkey) == 33:
|
|
||||||
innerkey = coincurve.PublicKey(innerkey)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"Byte keys must be 33-byte long starting from either 02 or 03"
|
|
||||||
)
|
|
||||||
|
|
||||||
elif not isinstance(innerkey, coincurve.keys.PublicKey):
|
|
||||||
raise ValueError(
|
|
||||||
"Key must either be bytes or coincurve.keys.PublicKey"
|
|
||||||
)
|
|
||||||
self.key = innerkey
|
|
||||||
|
|
||||||
def serializeCompressed(self):
|
|
||||||
return self.key.format(compressed=True)
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return "PublicKey[0x{}]".format(
|
|
||||||
self.serializeCompressed().hex()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def Keypair(object):
|
|
||||||
def __init__(self, priv, pub):
|
|
||||||
self.priv, self.pub = priv, pub
|
|
||||||
|
|
||||||
|
|
||||||
class Sha256Mixer(object):
|
class Sha256Mixer(object):
|
||||||
def __init__(self, base):
|
def __init__(self, base):
|
||||||
self.hash = sha256(base).digest()
|
self.hash = sha256(base).digest()
|
||||||
@@ -174,7 +117,7 @@ class LightningConnection(object):
|
|||||||
h.hash = self.handshake['h']
|
h.hash = self.handshake['h']
|
||||||
h.update(self.handshake['e'].public_key().serializeCompressed())
|
h.update(self.handshake['e'].public_key().serializeCompressed())
|
||||||
es = ecdh(self.handshake['e'], self.remote_pubkey)
|
es = ecdh(self.handshake['e'], self.remote_pubkey)
|
||||||
t = hkdf(salt=self.chaining_key, ikm=es.raw, info=b'')
|
t = hkdf(salt=self.chaining_key, ikm=es.data, info=b'')
|
||||||
assert(len(t) == 64)
|
assert(len(t) == 64)
|
||||||
self.chaining_key, temp_k1 = t[:32], t[32:]
|
self.chaining_key, temp_k1 = t[:32], t[32:]
|
||||||
c = encryptWithAD(temp_k1, self.nonce(0), h.digest(), b'')
|
c = encryptWithAD(temp_k1, self.nonce(0), h.digest(), b'')
|
||||||
@@ -194,7 +137,7 @@ class LightningConnection(object):
|
|||||||
h.update(re.serializeCompressed())
|
h.update(re.serializeCompressed())
|
||||||
es = ecdh(self.local_privkey, re)
|
es = ecdh(self.local_privkey, re)
|
||||||
self.handshake['re'] = re
|
self.handshake['re'] = re
|
||||||
t = hkdf(salt=self.chaining_key, ikm=es.raw, info=b'')
|
t = hkdf(salt=self.chaining_key, ikm=es.data, info=b'')
|
||||||
self.chaining_key, temp_k1 = t[:32], t[32:]
|
self.chaining_key, temp_k1 = t[:32], t[32:]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -210,7 +153,7 @@ class LightningConnection(object):
|
|||||||
h.hash = self.handshake['h']
|
h.hash = self.handshake['h']
|
||||||
h.update(self.handshake['e'].public_key().serializeCompressed())
|
h.update(self.handshake['e'].public_key().serializeCompressed())
|
||||||
ee = ecdh(self.handshake['e'], self.handshake['re'])
|
ee = ecdh(self.handshake['e'], self.handshake['re'])
|
||||||
t = hkdf(salt=self.chaining_key, ikm=ee.raw, info=b'')
|
t = hkdf(salt=self.chaining_key, ikm=ee.data, info=b'')
|
||||||
assert(len(t) == 64)
|
assert(len(t) == 64)
|
||||||
self.chaining_key, self.temp_k2 = t[:32], t[32:]
|
self.chaining_key, self.temp_k2 = t[:32], t[32:]
|
||||||
c = encryptWithAD(self.temp_k2, self.nonce(0), h.digest(), b'')
|
c = encryptWithAD(self.temp_k2, self.nonce(0), h.digest(), b'')
|
||||||
@@ -231,7 +174,7 @@ class LightningConnection(object):
|
|||||||
h.update(re.serializeCompressed())
|
h.update(re.serializeCompressed())
|
||||||
ee = ecdh(self.handshake['e'], re)
|
ee = ecdh(self.handshake['e'], re)
|
||||||
self.chaining_key, self.temp_k2 = hkdf_two_keys(
|
self.chaining_key, self.temp_k2 = hkdf_two_keys(
|
||||||
salt=self.chaining_key, ikm=ee.raw
|
salt=self.chaining_key, ikm=ee.data
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
decryptWithAD(self.temp_k2, self.nonce(0), h.digest(), c)
|
decryptWithAD(self.temp_k2, self.nonce(0), h.digest(), c)
|
||||||
@@ -249,7 +192,7 @@ class LightningConnection(object):
|
|||||||
se = ecdh(self.local_privkey, self.re)
|
se = ecdh(self.local_privkey, self.re)
|
||||||
|
|
||||||
self.chaining_key, self.temp_k3 = hkdf_two_keys(
|
self.chaining_key, self.temp_k3 = hkdf_two_keys(
|
||||||
salt=self.chaining_key, ikm=se.raw
|
salt=self.chaining_key, ikm=se.data
|
||||||
)
|
)
|
||||||
t = encryptWithAD(self.temp_k3, self.nonce(0), h.digest(), b'')
|
t = encryptWithAD(self.temp_k3, self.nonce(0), h.digest(), b'')
|
||||||
m = b'\x00' + c + t
|
m = b'\x00' + c + t
|
||||||
@@ -272,7 +215,7 @@ class LightningConnection(object):
|
|||||||
se = ecdh(self.handshake['e'], self.remote_pubkey)
|
se = ecdh(self.handshake['e'], self.remote_pubkey)
|
||||||
|
|
||||||
self.chaining_key, self.temp_k3 = hkdf_two_keys(
|
self.chaining_key, self.temp_k3 = hkdf_two_keys(
|
||||||
se.raw, self.chaining_key
|
se.data, self.chaining_key
|
||||||
)
|
)
|
||||||
decryptWithAD(self.temp_k3, self.nonce(0), h.digest(), t)
|
decryptWithAD(self.temp_k3, self.nonce(0), h.digest(), t)
|
||||||
self.rn, self.sn = 0, 0
|
self.rn, self.sn = 0, 0
|
||||||
|
|||||||
Reference in New Issue
Block a user