pytest: add schema support for JSON responses.

This adds our first (basic) schema, and sews support into pyln-testing
so it will load schemas for any method for doc/schemas/{method}.schema.json.

All JSON responses in a test run are checked against the schema (if any).

Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
This commit is contained in:
Rusty Russell
2021-05-26 15:17:01 +09:30
parent 8a67c4a1ba
commit ea99a05249
6 changed files with 237 additions and 4 deletions

View File

@@ -349,6 +349,7 @@ class UnixDomainSocketRpc(object):
"enable": True "enable": True
}, },
}) })
# FIXME: Notification schema support?
_, buf = self._readobj(sock, buf) _, buf = self._readobj(sock, buf)
request = { request = {

View File

@@ -1,13 +1,17 @@
from concurrent import futures from concurrent import futures
from pyln.testing.db import SqliteDbProvider, PostgresDbProvider from pyln.testing.db import SqliteDbProvider, PostgresDbProvider
from pyln.testing.utils import NodeFactory, BitcoinD, ElementsD, env, DEVELOPER, LightningNode, TEST_DEBUG, Throttler from pyln.testing.utils import NodeFactory, BitcoinD, ElementsD, env, DEVELOPER, LightningNode, TEST_DEBUG, Throttler
from pyln.client import Millisatoshi
from typing import Dict from typing import Dict
import json
import jsonschema # type: ignore
import logging import logging
import os import os
import pytest # type: ignore import pytest # type: ignore
import re import re
import shutil import shutil
import string
import sys import sys
import tempfile import tempfile
@@ -202,8 +206,145 @@ def throttler(test_base_dir):
yield Throttler(test_base_dir) yield Throttler(test_base_dir)
def _extra_validator():
"""JSON Schema validator with additions for our specialized types"""
def is_hex(checker, instance):
"""Hex string"""
if not checker.is_type(instance, "string"):
return False
return all(c in string.hexdigits for c in instance)
def is_u64(checker, instance):
"""64-bit integer"""
if not checker.is_type(instance, "integer"):
return False
return instance >= 0 and instance < 2**64
def is_u32(checker, instance):
"""32-bit integer"""
if not checker.is_type(instance, "integer"):
return False
return instance >= 0 and instance < 2**32
def is_u16(checker, instance):
"""16-bit integer"""
if not checker.is_type(instance, "integer"):
return False
return instance >= 0 and instance < 2**16
def is_short_channel_id(checker, instance):
"""Short channel id"""
if not checker.is_type(instance, "string"):
return False
parts = instance.split("x")
if len(parts) != 3:
return False
# May not be integers
try:
blocknum = int(parts[0])
txnum = int(parts[1])
outnum = int(parts[2])
except ValueError:
return False
# BOLT #7:
# ## Definition of `short_channel_id`
#
# The `short_channel_id` is the unique description of the funding transaction.
# It is constructed as follows:
# 1. the most significant 3 bytes: indicating the block height
# 2. the next 3 bytes: indicating the transaction index within the block
# 3. the least significant 2 bytes: indicating the output index that pays to the
# channel.
return (blocknum >= 0 and blocknum < 2**24
and txnum >= 0 and txnum < 2**24
and outnum >= 0 and outnum < 2**16)
def is_pubkey(checker, instance):
"""SEC1 encoded compressed pubkey"""
if not checker.is_type(instance, "hex"):
return False
if len(instance) != 66:
return False
return instance[0:2] == "02" or instance[0:2] == "03"
def is_pubkey32(checker, instance):
"""x-only BIP-340 public key"""
if not checker.is_type(instance, "hex"):
return False
if len(instance) != 64:
return False
return True
def is_signature(checker, instance):
"""DER encoded secp256k1 ECDSA signature"""
if not checker.is_type(instance, "hex"):
return False
if len(instance) > 72 * 2:
return False
return True
def is_bip340sig(checker, instance):
"""Hex encoded secp256k1 Schnorr signature"""
if not checker.is_type(instance, "hex"):
return False
if len(instance) != 64 * 2:
return False
return True
def is_msat(checker, instance):
"""String number ending in msat"""
return type(instance) is Millisatoshi
def is_txid(checker, instance):
"""Bitcoin transaction ID"""
if not checker.is_type(instance, "hex"):
return False
return len(instance) == 64
type_checker = jsonschema.Draft7Validator.TYPE_CHECKER.redefine_many({
"hex": is_hex,
"u64": is_u64,
"u32": is_u32,
"u16": is_u16,
"pubkey": is_pubkey,
"msat": is_msat,
"txid": is_txid,
"signature": is_signature,
"bip340sig": is_bip340sig,
"pubkey32": is_pubkey32,
"short_channel_id": is_short_channel_id,
})
return jsonschema.validators.extend(jsonschema.Draft7Validator,
type_checker=type_checker)
def _load_schema(filename):
"""Load the schema from @filename and create a validator for it"""
with open(filename, 'r') as f:
return _extra_validator()(json.load(f))
@pytest.fixture(autouse=True)
def jsonschemas():
"""Load schema files if they exist"""
try:
schemafiles = os.listdir('doc/schemas')
except FileNotFoundError:
schemafiles = []
schemas = {}
for fname in schemafiles:
if not fname.endswith('.schema.json'):
continue
schemas[fname.rpartition('.schema')[0]] = _load_schema(os.path.join('doc/schemas',
fname))
return schemas
@pytest.fixture @pytest.fixture
def node_factory(request, directory, test_name, bitcoind, executor, db_provider, teardown_checks, node_cls, throttler): def node_factory(request, directory, test_name, bitcoind, executor, db_provider, teardown_checks, node_cls, throttler, jsonschemas):
nf = NodeFactory( nf = NodeFactory(
request, request,
test_name, test_name,
@@ -213,6 +354,7 @@ def node_factory(request, directory, test_name, bitcoind, executor, db_provider,
db_provider=db_provider, db_provider=db_provider,
node_cls=node_cls, node_cls=node_cls,
throttler=throttler, throttler=throttler,
jsonschemas=jsonschemas,
) )
yield nf yield nf

View File

@@ -601,7 +601,17 @@ class PrettyPrintingLightningRpc(LightningRpc):
eyes. It has some overhead since we re-serialize the request and eyes. It has some overhead since we re-serialize the request and
result to json in order to pretty print it. result to json in order to pretty print it.
Also validates (optional) schemas for us.
""" """
def __init__(self, socket_path, executor=None, logger=logging,
patch_json=True, jsonschemas={}):
super().__init__(
socket_path,
executor,
logger,
patch_json,
)
self.jsonschemas = jsonschemas
def call(self, method, payload=None): def call(self, method, payload=None):
id = self.next_id id = self.next_id
@@ -615,6 +625,10 @@ class PrettyPrintingLightningRpc(LightningRpc):
"id": id, "id": id,
"result": res "result": res
}, indent=2)) }, indent=2))
if method in self.jsonschemas:
self.jsonschemas[method].validate(res)
return res return res
@@ -625,6 +639,7 @@ class LightningNode(object):
allow_warning=False, allow_warning=False,
allow_bad_gossip=False, allow_bad_gossip=False,
db=None, port=None, disconnect=None, random_hsm=None, options=None, db=None, port=None, disconnect=None, random_hsm=None, options=None,
jsonschemas={},
**kwargs): **kwargs):
self.bitcoin = bitcoind self.bitcoin = bitcoind
self.executor = executor self.executor = executor
@@ -639,7 +654,7 @@ class LightningNode(object):
self.rc = 0 self.rc = 0
socket_path = os.path.join(lightning_dir, TEST_NETWORK, "lightning-rpc").format(node_id) socket_path = os.path.join(lightning_dir, TEST_NETWORK, "lightning-rpc").format(node_id)
self.rpc = PrettyPrintingLightningRpc(socket_path, self.executor) self.rpc = PrettyPrintingLightningRpc(socket_path, self.executor, jsonschemas=jsonschemas)
self.daemon = LightningD( self.daemon = LightningD(
lightning_dir, bitcoindproxy=bitcoind.get_proxy(), lightning_dir, bitcoindproxy=bitcoind.get_proxy(),
@@ -1196,7 +1211,7 @@ class NodeFactory(object):
"""A factory to setup and start `lightningd` daemons. """A factory to setup and start `lightningd` daemons.
""" """
def __init__(self, request, testname, bitcoind, executor, directory, def __init__(self, request, testname, bitcoind, executor, directory,
db_provider, node_cls, throttler): db_provider, node_cls, throttler, jsonschemas):
if request.node.get_closest_marker("slow_test") and SLOW_MACHINE: if request.node.get_closest_marker("slow_test") and SLOW_MACHINE:
self.valgrind = False self.valgrind = False
else: else:
@@ -1211,6 +1226,7 @@ class NodeFactory(object):
self.db_provider = db_provider self.db_provider = db_provider
self.node_cls = node_cls self.node_cls = node_cls
self.throttler = throttler self.throttler = throttler
self.jsonschemas = jsonschemas
def split_options(self, opts): def split_options(self, opts):
"""Split node options from cli options """Split node options from cli options
@@ -1289,6 +1305,7 @@ class NodeFactory(object):
node = self.node_cls( node = self.node_cls(
node_id, lightning_dir, self.bitcoind, self.executor, self.valgrind, db=db, node_id, lightning_dir, self.bitcoind, self.executor, self.valgrind, db=db,
port=port, options=options, may_fail=may_fail or expect_fail, port=port, options=options, may_fail=may_fail or expect_fail,
jsonschemas=self.jsonschemas,
**kwargs **kwargs
) )

View File

@@ -9,3 +9,4 @@ pytest-timeout ~= 1.4.2
pytest-xdist ~= 2.2.0 pytest-xdist ~= 2.2.0
pytest==6.1.* pytest==6.1.*
python-bitcoinlib==0.11.* python-bitcoinlib==0.11.*
jsonschema==3.2.*

View File

@@ -0,0 +1,72 @@
{
"$schema": "http://json-schema.org/draft-07/schema#",
"type": "object",
"additionalProperties": false,
"properties": {
"pays": {
"type": "array",
"items": {
"type": "object",
"additionalProperties": false,
"properties": {
"payment_hash": {
"type": "hex",
"description": "the hash of the *payment_preimage* which will prove payment",
"maxLength": 64,
"minLength": 64
},
"status": {
"type": "string",
"enum": [ "pending", "failed", "complete" ],
"description": "status of the payment"
},
"destination": {
"type": "pubkey",
"description": "the final destination of the payment if known"
},
"amount_msat": {
"type": "msat",
"description": "the amount the destination received, if known (**status** *complete* or *pending*)"
},
"amount_sent_msat": {
"type": "msat",
"description": "the amount we actually sent, including fees (**status** *complete* or *pending*)"
},
"created_at": {
"type": "u64",
"description": "the UNIX timestamp showing when this payment was initiated"
},
"preimage": {
"type": "hex",
"description": "proof of payment, only if (and always if) **status** is *complete*",
"FIXME": "we should enforce the status/payment_preimage relation in the schema!",
"maxLength": 64,
"minLength": 64
},
"label": {
"type": "string",
"description": "the label, if given to sendpay"
},
"bolt11": {
"type": "string",
"description": "the bolt11 string (if pay supplied one)"
},
"bolt12": {
"type": "string",
"description": "the bolt12 string (if supplied for pay: **experimental-offers** only)."
},
"erroronion": {
"type": "hex",
"description": "the error onion returned on failure, if any."
},
"number_of_parts": {
"type": "u64",
"description": "the number of parts for a successful payment (only if more than one, and **status** is *complete*)."
}
},
"required": [ "payment_hash", "status", "created_at" ]
}
}
},
"required": [ "pays" ]
}

View File

@@ -1,5 +1,5 @@
from utils import DEVELOPER, TEST_NETWORK # noqa: F401,F403 from utils import DEVELOPER, TEST_NETWORK # noqa: F401,F403
from pyln.testing.fixtures import directory, test_base_dir, test_name, chainparams, node_factory, bitcoind, teardown_checks, throttler, db_provider, executor, setup_logging # noqa: F401,F403 from pyln.testing.fixtures import directory, test_base_dir, test_name, chainparams, node_factory, bitcoind, teardown_checks, throttler, db_provider, executor, setup_logging, jsonschemas # noqa: F401,F403
from pyln.testing import utils from pyln.testing import utils
from utils import COMPAT from utils import COMPAT