Updates unit tests to use the new key formats

This commit is contained in:
Sergi Delgado Segura
2020-02-21 13:18:56 +01:00
parent 3db5012145
commit 1837baed2a
3 changed files with 19 additions and 31 deletions

View File

@@ -5,11 +5,8 @@ import requests
from time import sleep from time import sleep
from shutil import rmtree from shutil import rmtree
from threading import Thread from threading import Thread
from binascii import hexlify
from cryptography.hazmat.backends import default_backend from coincurve import PrivateKey
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives import serialization
from common.blob import Blob from common.blob import Blob
from pisa.responder import TransactionTracker from pisa.responder import TransactionTracker
@@ -58,10 +55,10 @@ def db_manager():
def generate_keypair(): def generate_keypair():
client_sk = ec.generate_private_key(ec.SECP256K1, default_backend()) sk = PrivateKey()
client_pk = client_sk.public_key() pk = sk.public_key
return client_sk, client_pk return sk, pk
def get_random_value_hex(nbytes): def get_random_value_hex(nbytes):
@@ -106,9 +103,7 @@ def generate_dummy_appointment_data(real_height=True, start_time_offset=5, end_t
# dummy keys for this test # dummy keys for this test
client_sk, client_pk = generate_keypair() client_sk, client_pk = generate_keypair()
client_pk_der = client_pk.public_bytes( client_pk_hex = client_pk.format().hex()
encoding=serialization.Encoding.DER, format=serialization.PublicFormat.SubjectPublicKeyInfo
)
locator = compute_locator(dispute_txid) locator = compute_locator(dispute_txid)
blob = Blob(dummy_appointment_data.get("tx")) blob = Blob(dummy_appointment_data.get("tx"))
@@ -124,9 +119,8 @@ def generate_dummy_appointment_data(real_height=True, start_time_offset=5, end_t
} }
signature = Cryptographer.sign(Appointment.from_dict(appointment_data).serialize(), client_sk) signature = Cryptographer.sign(Appointment.from_dict(appointment_data).serialize(), client_sk)
pk_hex = hexlify(client_pk_der).decode("utf-8")
data = {"appointment": appointment_data, "signature": signature, "public_key": pk_hex} data = {"appointment": appointment_data, "signature": signature, "public_key": client_pk_hex}
return data, dispute_tx.hex() return data, dispute_tx.hex()

View File

@@ -3,7 +3,6 @@ import pytest
import requests import requests
from time import sleep from time import sleep
from threading import Thread from threading import Thread
from cryptography.hazmat.primitives import serialization
from pisa.api import API from pisa.api import API
from pisa.watcher import Watcher from pisa.watcher import Watcher
@@ -36,13 +35,8 @@ config = get_config()
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def run_api(db_manager): def run_api(db_manager):
sk, pk = generate_keypair() sk, pk = generate_keypair()
sk_der = sk.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
)
watcher = Watcher(db_manager, Responder(db_manager), sk_der, get_config()) watcher = Watcher(db_manager, Responder(db_manager), sk.to_der(), get_config())
chain_monitor = ChainMonitor(watcher.block_queue, watcher.responder.block_queue) chain_monitor = ChainMonitor(watcher.block_queue, watcher.responder.block_queue)
watcher.awake() watcher.awake()
chain_monitor.monitor_chain() chain_monitor.monitor_chain()

View File

@@ -2,8 +2,7 @@ import pytest
from uuid import uuid4 from uuid import uuid4
from shutil import rmtree from shutil import rmtree
from threading import Thread from threading import Thread
from cryptography.hazmat.primitives.asymmetric import ec from coincurve import PrivateKey
from cryptography.hazmat.primitives import serialization
from pisa.watcher import Watcher from pisa.watcher import Watcher
from pisa.responder import Responder from pisa.responder import Responder
@@ -36,11 +35,6 @@ TEST_SET_SIZE = 200
signing_key, public_key = generate_keypair() signing_key, public_key = generate_keypair()
sk_der = signing_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
)
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
@@ -56,7 +50,7 @@ def temp_db_manager():
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def watcher(db_manager): def watcher(db_manager):
watcher = Watcher(db_manager, Responder(db_manager), sk_der, get_config()) watcher = Watcher(db_manager, Responder(db_manager), signing_key.to_der(), get_config())
chain_monitor = ChainMonitor(watcher.block_queue, watcher.responder.block_queue) chain_monitor = ChainMonitor(watcher.block_queue, watcher.responder.block_queue)
chain_monitor.monitor_chain() chain_monitor.monitor_chain()
@@ -96,7 +90,7 @@ def test_init(run_bitcoind, watcher):
assert isinstance(watcher.locator_uuid_map, dict) and len(watcher.locator_uuid_map) == 0 assert isinstance(watcher.locator_uuid_map, dict) and len(watcher.locator_uuid_map) == 0
assert watcher.block_queue.empty() assert watcher.block_queue.empty()
assert isinstance(watcher.config, dict) assert isinstance(watcher.config, dict)
assert isinstance(watcher.signing_key, ec.EllipticCurvePrivateKey) assert isinstance(watcher.signing_key, PrivateKey)
assert isinstance(watcher.responder, Responder) assert isinstance(watcher.responder, Responder)
@@ -109,13 +103,17 @@ def test_add_appointment(watcher):
added_appointment, sig = watcher.add_appointment(appointment) added_appointment, sig = watcher.add_appointment(appointment)
assert added_appointment is True assert added_appointment is True
assert Cryptographer.verify(appointment.serialize(), sig, public_key) assert Cryptographer.verify_rpk(
watcher.signing_key.public_key, Cryptographer.recover_pk(appointment.serialize(), sig)
)
# Check that we can also add an already added appointment (same locator) # Check that we can also add an already added appointment (same locator)
added_appointment, sig = watcher.add_appointment(appointment) added_appointment, sig = watcher.add_appointment(appointment)
assert added_appointment is True assert added_appointment is True
assert Cryptographer.verify(appointment.serialize(), sig, public_key) assert Cryptographer.verify_rpk(
watcher.signing_key.public_key, Cryptographer.recover_pk(appointment.serialize(), sig)
)
def test_add_too_many_appointments(watcher): def test_add_too_many_appointments(watcher):
@@ -129,7 +127,9 @@ def test_add_too_many_appointments(watcher):
added_appointment, sig = watcher.add_appointment(appointment) added_appointment, sig = watcher.add_appointment(appointment)
assert added_appointment is True assert added_appointment is True
assert Cryptographer.verify(appointment.serialize(), sig, public_key) assert Cryptographer.verify_rpk(
watcher.signing_key.public_key, Cryptographer.recover_pk(appointment.serialize(), sig)
)
appointment, dispute_tx = generate_dummy_appointment( appointment, dispute_tx = generate_dummy_appointment(
start_time_offset=START_TIME_OFFSET, end_time_offset=END_TIME_OFFSET start_time_offset=START_TIME_OFFSET, end_time_offset=END_TIME_OFFSET