From 5a7d038e6eecaec41516a27b6d3eb702ed57cd98 Mon Sep 17 00:00:00 2001 From: Rusty Russell Date: Mon, 25 Feb 2019 14:45:56 +1030 Subject: [PATCH] pylightning: provide a class for Lightning JSONDecoder. Some JSON functions want a *class*, not just a hook, so provide one. To make it clear that we want an encoding *class* and a decoding *object*, rename the UnixDomainSocketRpc encode parameter to encode_cls. Signed-off-by: Rusty Russell --- contrib/pylightning/lightning/lightning.py | 57 ++++++++++++---------- 1 file changed, 32 insertions(+), 25 deletions(-) diff --git a/contrib/pylightning/lightning/lightning.py b/contrib/pylightning/lightning/lightning.py index e4c9ad270..e2e46f121 100644 --- a/contrib/pylightning/lightning/lightning.py +++ b/contrib/pylightning/lightning/lightning.py @@ -121,9 +121,9 @@ class Millisatoshi: class UnixDomainSocketRpc(object): - def __init__(self, socket_path, executor=None, logger=logging, encoder=json.JSONEncoder, decoder=json.JSONDecoder): + def __init__(self, socket_path, executor=None, logger=logging, encoder_cls=json.JSONEncoder, decoder=json.JSONDecoder()): self.socket_path = socket_path - self.encoder = encoder + self.encoder_cls = encoder_cls self.decoder = decoder self.executor = executor self.logger = logger @@ -133,7 +133,7 @@ class UnixDomainSocketRpc(object): self.next_id = 0 def _writeobj(self, sock, obj): - s = json.dumps(obj, cls=self.encoder) + s = json.dumps(obj, cls=self.encoder_cls) sock.sendall(bytearray(s, 'UTF-8')) def _readobj_compat(self, sock, buff=b''): @@ -245,32 +245,39 @@ class LightningRpc(UnixDomainSocketRpc): pass return json.JSONEncoder.default(self, o) - @staticmethod - def lightning_json_hook(json_object): - return json_object + class LightningJSONDecoder(json.JSONDecoder): + def __init__(self, *, object_hook=None, parse_float=None, parse_int=None, parse_constant=None, strict=True, object_pairs_hook=None): + self.object_hook_next = object_hook + super().__init__(object_hook=self.millisatoshi_hook, parse_float=parse_float, parse_int=parse_int, parse_constant=parse_constant, strict=strict, object_pairs_hook=object_pairs_hook) - @staticmethod - def replace_amounts(obj): - """ - Recursively replace _msat fields with appropriate values with Millisatoshi. - """ - if isinstance(obj, dict): - for k, v in obj.items(): - if k.endswith('msat'): - if isinstance(v, str) and v.endswith('msat'): - obj[k] = Millisatoshi(v) - # Special case for array of msat values - elif isinstance(v, list) and all(isinstance(e, str) and e.endswith('msat') for e in v): - obj[k] = [Millisatoshi(e) for e in v] - else: - obj[k] = LightningRpc.replace_amounts(v) - elif isinstance(obj, list): - obj = [LightningRpc.replace_amounts(e) for e in obj] + @staticmethod + def replace_amounts(obj): + """ + Recursively replace _msat fields with appropriate values with Millisatoshi. + """ + if isinstance(obj, dict): + for k, v in obj.items(): + if k.endswith('msat'): + if isinstance(v, str) and v.endswith('msat'): + obj[k] = Millisatoshi(v) + # Special case for array of msat values + elif isinstance(v, list) and all(isinstance(e, str) and e.endswith('msat') for e in v): + obj[k] = [Millisatoshi(e) for e in v] + else: + obj[k] = LightningRpc.LightningJSONDecoder.replace_amounts(v) + elif isinstance(obj, list): + obj = [LightningRpc.LightningJSONDecoder.replace_amounts(e) for e in obj] - return obj + return obj + + def millisatoshi_hook(self, obj): + obj = LightningRpc.LightningJSONDecoder.replace_amounts(obj) + if self.object_hook_next: + obj = self.object_hook_next(obj) + return obj def __init__(self, socket_path, executor=None, logger=logging): - super().__init__(socket_path, executor, logging, self.LightningJSONEncoder, json.JSONDecoder(object_hook=self.replace_amounts)) + super().__init__(socket_path, executor, logging, self.LightningJSONEncoder, self.LightningJSONDecoder()) def getpeer(self, peer_id, level=None): """