diff --git a/common/cryptographer.py b/common/cryptographer.py index 67ff6e0..fee5001 100644 --- a/common/cryptographer.py +++ b/common/cryptographer.py @@ -146,6 +146,35 @@ class Cryptographer: return blob + @staticmethod + def load_key_file(file_path): + """ + Loads a key from a key file. + + Args: + file_path (:obj:`str`): the path to the key file to be loaded. + + Returns: + :obj:`bytes` or :obj:`None`: the key file data if the file can be found and read. ``None`` otherwise. + """ + + if not isinstance(file_path, str): + logger.error("Key file path was expected, {} received".format(type(file_path))) + return None + + try: + with open(file_path, "rb") as key_file: + key = key_file.read() + return key + + except FileNotFoundError: + logger.error("Key file not found. Please check your settings") + return None + + except IOError as e: + logger.error("I/O error({}): {}".format(e.errno, e.strerror)) + return None + @staticmethod def load_public_key_der(pk_der): """ @@ -199,7 +228,7 @@ class Cryptographer: return sk except UnsupportedAlgorithm: - raise ValueError("Could not deserialize the private key (unsupported algorithm).") + logger.error("Could not deserialize the private key (unsupported algorithm)") except ValueError: logger.error("The provided data cannot be deserialized (wrong size or format)") @@ -207,6 +236,8 @@ class Cryptographer: except TypeError: logger.error("The provided data cannot be deserialized (wrong type)") + return None + @staticmethod def sign(data, sk, rtype="str"): """ diff --git a/pisa/pisad.py b/pisa/pisad.py index e7d771e..89602b5 100644 --- a/pisa/pisad.py +++ b/pisa/pisad.py @@ -3,6 +3,7 @@ from sys import argv, exit from signal import signal, SIGINT, SIGQUIT, SIGTERM from common.logger import Logger +from common.cryptographer import Cryptographer from pisa import config, LOG_PREFIX from pisa.api import API @@ -44,8 +45,9 @@ def main(): else: try: - with open(config.get("PISA_SECRET_KEY"), "rb") as key_file: - secret_key_der = key_file.read() + secret_key_der = Cryptographer.load_key_file(config.get("PISA_SECRET_KEY")) + if not secret_key_der: + raise IOError("PISA private key can't be loaded") watcher = Watcher(db_manager, Responder(db_manager), secret_key_der, config) diff --git a/test/common/unit/test_cryptographer.py b/test/common/unit/test_cryptographer.py index 44a1b77..875cea4 100644 --- a/test/common/unit/test_cryptographer.py +++ b/test/common/unit/test_cryptographer.py @@ -1,3 +1,5 @@ +import os +import pytest import binascii from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.asymmetric import ec @@ -181,6 +183,30 @@ def test_decrypt_wrong_return(): assert True +def test_load_key_file(): + dummy_sk = ec.generate_private_key(ec.SECP256K1, default_backend()) + dummy_sk_der = dummy_sk.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + + # If file exists and has data in it, function should work. + with open("key_test_file", "wb") as f: + f.write(dummy_sk_der) + + appt_data = Cryptographer.load_key_file("key_test_file") + assert appt_data + + os.remove("key_test_file") + + # If file doesn't exist, function should return None + assert Cryptographer.load_key_file("nonexistent_file") is None + + # If something that's not a file_path is passed as parameter the method should also return None + assert Cryptographer.load_key_file(0) is None and Cryptographer.load_key_file(None) is None + + def test_load_public_key_der(): # load_public_key_der expects a byte encoded data. Any other should fail and return None for wtype in WRONG_TYPES: