Refactor add_appointment cli code

This commit is contained in:
Turtle
2019-11-30 00:42:36 -05:00
parent 49657ccbfc
commit 40d7ca1912
5 changed files with 162 additions and 106 deletions

View File

@@ -9,8 +9,8 @@ from getopt import getopt, GetoptError
from requests import ConnectTimeout, ConnectionError from requests import ConnectTimeout, ConnectionError
from uuid import uuid4 from uuid import uuid4
from apps.cli.blob import Blob
from apps.cli.help import help_add_appointment, help_get_appointment from apps.cli.help import help_add_appointment, help_get_appointment
from apps.cli.blob import Blob
from apps.cli import ( from apps.cli import (
DEFAULT_PISA_API_SERVER, DEFAULT_PISA_API_SERVER,
DEFAULT_PISA_API_PORT, DEFAULT_PISA_API_PORT,
@@ -22,9 +22,8 @@ from apps.cli import (
from common.logger import Logger from common.logger import Logger
from common.appointment import Appointment from common.appointment import Appointment
from common.constants import LOCATOR_LEN_HEX
from common.cryptographer import Cryptographer from common.cryptographer import Cryptographer
from common.tools import check_sha256_hex_format from common.tools import check_sha256_hex_format, compute_locator
HTTP_OK = 200 HTTP_OK = 200
@@ -46,11 +45,13 @@ def generate_dummy_appointment():
"to_self_delay": 20, "to_self_delay": 20,
} }
print("Generating dummy appointment data:" "\n\n" + json.dumps(dummy_appointment_data, indent=4, sort_keys=True)) logger.info(
"Generating dummy appointment data:" "\n\n" + json.dumps(dummy_appointment_data, indent=4, sort_keys=True)
)
json.dump(dummy_appointment_data, open("dummy_appointment_data.json", "w")) json.dump(dummy_appointment_data, open("dummy_appointment_data.json", "w"))
print("\nData stored in dummy_appointment_data.json") logger.info("\nData stored in dummy_appointment_data.json")
# Loads and returns Pisa keys from disk # Loads and returns Pisa keys from disk
@@ -61,11 +62,12 @@ def load_key_file_data(file_name):
return key return key
except FileNotFoundError: except FileNotFoundError:
raise FileNotFoundError("File not found.") logger.error("Client's key file not found. Please check your settings.")
return False
except IOError as e:
def compute_locator(tx_id): logger.error("I/O error({}): {}".format(e.errno, e.strerror))
return tx_id[:LOCATOR_LEN_HEX] return False
# Makes sure that the folder APPOINTMENTS_FOLDER_NAME exists, then saves the appointment and signature in it. # Makes sure that the folder APPOINTMENTS_FOLDER_NAME exists, then saves the appointment and signature in it.
@@ -85,12 +87,81 @@ def save_signed_appointment(appointment, signature):
def add_appointment(args): def add_appointment(args):
appointment_data = None # Get appointment data from user.
appointment_data = parse_add_appointment_args(args)
if appointment_data is None:
logger.error("The provided appointment JSON is empty")
return False
valid_txid = check_sha256_hex_format(appointment_data.get("tx_id"))
if not valid_txid:
logger.error("The provided txid is not valid")
return False
tx_id = appointment_data.get("tx_id")
tx = appointment_data.get("tx")
if None not in [tx_id, tx]:
appointment_data["locator"] = compute_locator(tx_id)
appointment_data["encrypted_blob"] = Cryptographer.encrypt(Blob(tx), tx_id)
else:
logger.error("Appointment data is missing some fields.")
return False
appointment = Appointment.from_dict(appointment_data)
signature = get_appointment_signature(appointment)
hex_pk_der = get_pk()
if not (appointment and signature and hex_pk_der):
return False
data = {"appointment": appointment.to_dict(), "signature": signature, "public_key": hex_pk_der.decode("utf-8")}
appointment_json = json.dumps(data, sort_keys=True, separators=(",", ":"))
# Send appointment to the server.
add_appointment_endpoint = "http://{}:{}".format(pisa_api_server, pisa_api_port)
response_json = post_data_to_add_appointment_endpoint(add_appointment_endpoint, appointment_json)
if response_json is None:
return False
signature = response_json.get("signature")
# Check that the server signed the appointment as it should.
if signature is None:
logger.error("The response does not contain the signature of the appointment.")
return False
valid = check_signature(signature, appointment)
if not valid:
logger.error("The returned appointment's signature is invalid")
return False
logger.info("Appointment accepted and signed by Pisa")
# all good, store appointment and signature
try:
save_signed_appointment(appointment.to_dict(), signature)
except OSError as e:
logger.error("There was an error while saving the appointment", error=e)
return False
return True
# Parse arguments passed to add_appointment and handle them accordingly.
# Returns appointment data.
def parse_add_appointment_args(args):
use_help = "Use 'help add_appointment' for help of how to use the command" use_help = "Use 'help add_appointment' for help of how to use the command"
if not args: if not args:
logger.error("No appointment data provided. " + use_help) logger.error("No appointment data provided. " + use_help)
return False return None
arg_opt = args.pop(0) arg_opt = args.pop(0)
@@ -102,7 +173,7 @@ def add_appointment(args):
fin = args.pop(0) fin = args.pop(0)
if not os.path.isfile(fin): if not os.path.isfile(fin):
logger.error("Can't find file", filename=fin) logger.error("Can't find file", filename=fin)
return False return None
try: try:
with open(fin) as f: with open(fin) as f:
@@ -110,63 +181,19 @@ def add_appointment(args):
except IOError as e: except IOError as e:
logger.error("I/O error", errno=e.errno, error=e.strerror) logger.error("I/O error", errno=e.errno, error=e.strerror)
return False return None
else: else:
appointment_data = json.loads(arg_opt) appointment_data = json.loads(arg_opt)
except json.JSONDecodeError: except json.JSONDecodeError:
logger.error("Non-JSON encoded data provided as appointment. " + use_help) logger.error("Non-JSON encoded data provided as appointment. " + use_help)
return False return None
if not appointment_data: return appointment_data
logger.error("The provided JSON is empty")
return False
valid_locator = check_sha256_hex_format(appointment_data.get("tx_id"))
if not valid_locator:
logger.error("The provided locator is not valid")
return False
add_appointment_endpoint = "http://{}:{}".format(pisa_api_server, pisa_api_port)
appointment = Appointment.from_dict(appointment_data)
try:
sk_der = load_key_file_data(CLI_PRIVATE_KEY)
cli_sk = Cryptographer.load_private_key_der(sk_der)
except ValueError:
logger.error("Failed to deserialize the public key. It might be in an unsupported format")
return False
except FileNotFoundError:
logger.error("Client's private key file not found. Please check your settings")
return False
except IOError as e:
logger.error("I/O error", errno=e.errno, error=e.strerror)
return False
signature = Cryptographer.sign(appointment.serialize(), cli_sk)
try:
cli_pk_der = load_key_file_data(CLI_PUBLIC_KEY)
hex_pk_der = binascii.hexlify(cli_pk_der)
except FileNotFoundError:
logger.error("Client's public key file not found. Please check your settings")
return False
except IOError as e:
logger.error("I/O error", errno=e.errno, error=e.strerror)
return False
# FIXME: Exceptions for hexlify need to be covered
data = {"appointment": appointment, "signature": signature, "public_key": hex_pk_der.decode("utf-8")}
appointment_json = json.dumps(data, sort_keys=True, separators=(",", ":"))
# Sends appointment data to add_appointment endpoint to be processed by the server.
def post_data_to_add_appointment_endpoint(add_appointment_endpoint, appointment_json):
logger.info("Sending appointment to PISA") logger.info("Sending appointment to PISA")
try: try:
@@ -176,15 +203,15 @@ def add_appointment(args):
except json.JSONDecodeError: except json.JSONDecodeError:
logger.error("The response was not valid JSON") logger.error("The response was not valid JSON")
return False return None
except ConnectTimeout: except ConnectTimeout:
logger.error("Can't connect to pisa API. Connection timeout") logger.error("Can't connect to pisa API. Connection timeout")
return False return None
except ConnectionError: except ConnectionError:
logger.error("Can't connect to pisa API. Server cannot be reached") logger.error("Can't connect to pisa API. Server cannot be reached")
return False return None
if r.status_code != HTTP_OK: if r.status_code != HTTP_OK:
if "error" not in response_json: if "error" not in response_json:
@@ -196,14 +223,17 @@ def add_appointment(args):
status_code=r.status_code, status_code=r.status_code,
description=error, description=error,
) )
return False return None
if "signature" not in response_json: if "signature" not in response_json:
logger.error("The response does not contain the signature of the appointment") logger.error("The response does not contain the signature of the appointment")
return False return None
signature = response_json["signature"] return response_json
# verify that the returned signature is valid
# Verify that the signature returned from the watchtower is valid.
def check_signature(signature, appointment):
try: try:
pisa_pk_der = load_key_file_data(PISA_PUBLIC_KEY) pisa_pk_der = load_key_file_data(PISA_PUBLIC_KEY)
pisa_pk = Cryptographer.load_public_key_der(pisa_pk_der) pisa_pk = Cryptographer.load_public_key_der(pisa_pk_der)
@@ -212,7 +242,7 @@ def add_appointment(args):
logger.error("Failed to deserialize the public key. It might be in an unsupported format") logger.error("Failed to deserialize the public key. It might be in an unsupported format")
return False return False
is_sig_valid = Cryptographer.verify(appointment.serialize(), signature, pisa_pk) return Cryptographer.verify(appointment.serialize(), signature, pisa_pk)
except FileNotFoundError: except FileNotFoundError:
logger.error("Pisa's public key file not found. Please check your settings") logger.error("Pisa's public key file not found. Please check your settings")
@@ -222,21 +252,6 @@ def add_appointment(args):
logger.error("I/O error", errno=e.errno, error=e.strerror) logger.error("I/O error", errno=e.errno, error=e.strerror)
return False return False
if not is_sig_valid:
logger.error("The returned appointment's signature is invalid")
return False
logger.info("Appointment accepted and signed by Pisa")
# all good, store appointment and signature
try:
save_signed_appointment(appointment, signature)
except OSError as e:
logger.error("There was an error while saving the appointment", error=e)
return False
return True
def get_appointment(args): def get_appointment(args):
if not args: if not args:
@@ -260,8 +275,9 @@ def get_appointment(args):
try: try:
r = requests.get(url=get_appointment_endpoint + parameters, timeout=5) r = requests.get(url=get_appointment_endpoint + parameters, timeout=5)
logger.info("Appointment response returned from server: " + str(r))
return True
print(json.dumps(r.json(), indent=4, sort_keys=True))
except ConnectTimeout: except ConnectTimeout:
logger.error("Can't connect to pisa API. Connection timeout") logger.error("Can't connect to pisa API. Connection timeout")
return False return False
@@ -270,7 +286,47 @@ def get_appointment(args):
logger.error("Can't connect to pisa API. Server cannot be reached") logger.error("Can't connect to pisa API. Server cannot be reached")
return False return False
return True
def get_appointment_signature(appointment):
try:
sk_der = load_key_file_data(CLI_PRIVATE_KEY)
cli_sk = Cryptographer.load_private_key_der(sk_der)
signature = Cryptographer.sign(appointment.serialize(), cli_sk)
return signature
except ValueError:
logger.error("Failed to deserialize the public key. It might be in an unsupported format")
return False
except FileNotFoundError:
logger.error("Client's private key file not found. Please check your settings")
return False
except IOError as e:
logger.error("I/O error", errno=e.errno, error=e.strerror)
return False
def get_pk():
try:
cli_pk_der = load_key_file_data(CLI_PUBLIC_KEY)
hex_pk_der = binascii.hexlify(cli_pk_der)
return hex_pk_der
except FileNotFoundError:
logger.error("Client's public key file not found. Please check your settings")
return False
except IOError as e:
logger.error("I/O error", errno=e.errno, error=e.strerror)
return False
except binascii.Error as e:
logger.error("Could not successfully encode public key as hex: ", e)
return False
def show_usage(): def show_usage():

View File

@@ -1,4 +1,5 @@
import re import re
from common.constants import LOCATOR_LEN_HEX
def check_sha256_hex_format(value): def check_sha256_hex_format(value):
@@ -12,3 +13,15 @@ def check_sha256_hex_format(value):
:mod:`bool`: Whether or not the value matches the format. :mod:`bool`: Whether or not the value matches the format.
""" """
return isinstance(value, str) and re.match(r"^[0-9A-Fa-f]{64}$", value) is not None return isinstance(value, str) and re.match(r"^[0-9A-Fa-f]{64}$", value) is not None
def compute_locator(tx_id):
"""
Computes an appointment locator given a transaction id.
Args:
tx_id (:obj:`str`): the transaction id used to compute the locator.
Returns:
(:obj:`str`): The computed locator.
"""
return tx_id[:LOCATOR_LEN_HEX]

View File

@@ -3,7 +3,7 @@ from queue import Queue
from threading import Thread from threading import Thread
from common.cryptographer import Cryptographer from common.cryptographer import Cryptographer
from common.constants import LOCATOR_LEN_HEX from common.tools import compute_locator
from common.logger import Logger from common.logger import Logger
from pisa.cleaner import Cleaner from pisa.cleaner import Cleaner
@@ -71,20 +71,6 @@ class Watcher:
if not isinstance(responder, Responder): if not isinstance(responder, Responder):
self.responder = Responder(db_manager) self.responder = Responder(db_manager)
@staticmethod
def compute_locator(tx_id):
"""
Computes an appointment locator given a transaction id.
Args:
tx_id (:obj:`str`): the transaction id used to compute the locator.
Returns:
(:obj:`str`): The computed locator.
"""
return tx_id[:LOCATOR_LEN_HEX]
def add_appointment(self, appointment): def add_appointment(self, appointment):
""" """
Adds a new appointment to the ``appointments`` dictionary if ``max_appointments`` has not been reached. Adds a new appointment to the ``appointments`` dictionary if ``max_appointments`` has not been reached.
@@ -238,7 +224,7 @@ class Watcher:
found. found.
""" """
potential_locators = {Watcher.compute_locator(txid): txid for txid in txids} potential_locators = {compute_locator(txid): txid for txid in txids}
# Check is any of the tx_ids in the received block is an actual match # Check is any of the tx_ids in the received block is an actual match
intersection = set(self.locator_uuid_map.keys()).intersection(potential_locators.keys()) intersection = set(self.locator_uuid_map.keys()).intersection(potential_locators.keys())

View File

@@ -16,6 +16,7 @@ from pisa.watcher import Watcher
from pisa.tools import bitcoin_cli from pisa.tools import bitcoin_cli
from pisa.db_manager import DBManager from pisa.db_manager import DBManager
from common.appointment import Appointment from common.appointment import Appointment
from common.tools import compute_locator
from bitcoind_mock.utils import sha256d from bitcoind_mock.utils import sha256d
from bitcoind_mock.transaction import TX from bitcoind_mock.transaction import TX
@@ -103,7 +104,7 @@ def generate_dummy_appointment_data(real_height=True, start_time_offset=5, end_t
encoding=serialization.Encoding.DER, format=serialization.PublicFormat.SubjectPublicKeyInfo encoding=serialization.Encoding.DER, format=serialization.PublicFormat.SubjectPublicKeyInfo
) )
locator = Watcher.compute_locator(dispute_txid) locator = compute_locator(dispute_txid)
blob = Blob(dummy_appointment_data.get("tx")) blob = Blob(dummy_appointment_data.get("tx"))
encrypted_blob = Cryptographer.encrypt(blob, dummy_appointment_data.get("tx_id")) encrypted_blob = Cryptographer.encrypt(blob, dummy_appointment_data.get("tx_id"))

View File

@@ -16,7 +16,7 @@ from test.pisa.unit.conftest import (
) )
from pisa.conf import EXPIRY_DELTA, MAX_APPOINTMENTS from pisa.conf import EXPIRY_DELTA, MAX_APPOINTMENTS
from common.tools import check_sha256_hex_format from common.tools import check_sha256_hex_format, compute_locator
from common.cryptographer import Cryptographer from common.cryptographer import Cryptographer
@@ -46,7 +46,7 @@ def txids():
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def locator_uuid_map(txids): def locator_uuid_map(txids):
return {Watcher.compute_locator(txid): uuid4().hex for txid in txids} return {compute_locator(txid): uuid4().hex for txid in txids}
def create_appointments(n): def create_appointments(n):
@@ -219,7 +219,7 @@ def test_filter_valid_breaches(watcher):
dummy_appointment, _ = generate_dummy_appointment() dummy_appointment, _ = generate_dummy_appointment()
dummy_appointment.encrypted_blob.data = encrypted_blob dummy_appointment.encrypted_blob.data = encrypted_blob
dummy_appointment.locator = Watcher.compute_locator(dispute_txid) dummy_appointment.locator = compute_locator(dispute_txid)
uuid = uuid4().hex uuid = uuid4().hex
appointments = {uuid: dummy_appointment} appointments = {uuid: dummy_appointment}